Java封装netty websocket server

Java封装netty websocket server,第1张

Java封装netty websocket server

Netty 是一个高效的 NIO 框架,用 Netty 封装的 WebSocket 服务器不仅稳定,而且性能非常优秀。由于 Netty 会将每个连接(通道)绑定到固定的线程上,在游戏开发中可以利用这一特性实现单用户无锁开发,大量减少锁的使用,从而提高游戏的吞吐量。这里对 Netty 进行简单封装,方便使用。

git地址

https://gitee.com/ichiva/netty-websocket-server.git
介绍
封装netty用于快速创建websocket服务器
快速入门
创建 FastNettyWebSocketServer 实例,实现消息处理方法 onMessage,并开启监听
    /**
     * Quick-start: starts a WebSocket server on port 8080 that logs every text message and replies.
     * FastNettyWebSocketServer is abstract, so an anonymous subclass supplies the onMessage handler.
     */
    public static void main(String[] args) {
        new FastNettyWebSocketServer() {
            @Override
            public void onMessage(WebSocketSession session, String message) {
                System.out.println("收到:" + message);
                send(session, "收到,over");
            }
        }.start(8080);
    }
实现细节

定义WebSocketServer接口

/**
 * Callback interface for a WebSocket endpoint, with helpers for writing frames back to a client.
 * Only {@link #onMessage(WebSocketSession, String)} must be implemented; every other callback
 * defaults to a no-op.
 */
public interface WebSocketServer {

    /**
     * Called when a client requests a WebSocket connection, once its session has been populated.
     * Default: does nothing.
     *
     * @param session the newly created session
     */
    default void onOpen(WebSocketSession session){

    }

    /**
     * Handles a text message received from the client.
     *
     * @param session the sender's session
     * @param message the decoded text payload
     */
    void onMessage(WebSocketSession session, String message);

    /**
     * Handles a binary message received from the client. Default: ignores the message.
     *
     * @param session the sender's session
     * @param message a copy of the frame's payload bytes
     */
    default void onMessage(WebSocketSession session, byte[] message) {

    }

    /**
     * Sends {@code message} to the client as a text frame and flushes immediately.
     * The write is asynchronous; its outcome is not reported to the caller.
     *
     * @param session target session
     * @param message text to send
     */
    default void send(WebSocketSession session, String message){
        session.getChannel().writeAndFlush(new TextWebSocketFrame(message));
    }

    /**
     * Sends {@code message} to the client as a binary frame and flushes immediately.
     * The bytes are copied, so the caller may reuse the array after this returns.
     * The write is asynchronous; its outcome is not reported to the caller.
     *
     * @param session target session
     * @param message bytes to send
     */
    default void send(WebSocketSession session, byte[] message){
        session.getChannel().writeAndFlush(new BinaryWebSocketFrame(Unpooled.copiedBuffer(message)));
    }

    /**
     * Called when the client's channel becomes inactive. Default: does nothing.
     *
     * @param session the session that was removed; NOTE(review): may be null if the channel closed
     *                before a session was ever created — confirm against Sessions.destroy
     */
    default void onClose(WebSocketSession session) {

    }

    /**
     * Called when a message handler throws or the pipeline reports an exception.
     * Default: silently ignores the error; override to log or close the session.
     *
     * @param session the affected session
     * @param e       the error
     */
    default void onError(WebSocketSession session, Throwable e){

    }
}

定义配置接口

/**
 * Settings used to bootstrap the Netty WebSocket server.
 * Only the child handler is mandatory; the port and both event-loop groups have defaults.
 */
public interface NettyWebsocketServerConfig {

    /**
     * Handler installed on every accepted connection (typically a ChannelInitializer that builds
     * the WebSocket pipeline).
     */
    ChannelHandler getChildHandler();

    /**
     * Port the server binds to.
     *
     * @return 8080 unless overridden
     */
    default int getPort() {
        return 8080;
    }

    /**
     * Event-loop group that accepts incoming connections.
     * The default creates a brand-new single-threaded group on EVERY call, so callers should call
     * this once and keep the returned instance.
     */
    default NioEventLoopGroup getBoosGroup() {
        final int acceptorThreads = 1;
        return new NioEventLoopGroup(acceptorThreads);
    }

    /**
     * Event-loop group that performs I/O for accepted connections.
     * The default creates a brand-new group (via the no-arg constructor) on EVERY call, so callers
     * should call this once and keep the returned instance.
     */
    default NioEventLoopGroup getWorkerGroup() {
        return new NioEventLoopGroup();
    }
}

FastNettyWebSocketServer(核心)实现,用于启动和关闭 Netty 服务

public abstract class FastNettyWebSocketServer implements WebSocketServer {
    private Channel serverChannel;
    private NettyWebsocketServerConfig config;

    public void start(NettyWebsocketServerConfig config) {
        this.config = config;
        ServerBootstrap server = new ServerBootstrap();
        server.group(config.getBoosGroup(), config.getWorkerGroup());
        server.channel(NioServerSocketChannel.class);
        server.childHandler(config.getChildHandler());

        ChannelFuture future = server.bind(config.getPort());
        future.addListener(f -> {
            if (f.isDone() && f.isSuccess()) {
                this.serverChannel = future.channel();
                log.info("Start ws server success");
                log.info("boos group thread number {}", config.getBoosGroup().executorCount());
                log.info("worker group thread number {}", config.getWorkerGroup().executorCount());
            }
            if (f.isDone() && f.cause() != null) {
                log.error("Start ws server fail throw={}", f.cause().getMessage());
                future.channel().close();
            }
        });
    }

    public void start(final int port) {
        start(new NettyWebsocketServerConfig() {
            @Override
            public ChannelHandler getChildHandler() {
                return new WebSocketChannelInitializer(FastNettyWebSocketServer.this);
            }

            @Override
            public int getPort() {
                return port;
            }
        });
    }

    public void start() {
        start(8080);
    }

    public void stop() {
        if (serverChannel != null && serverChannel.isOpen()) {
            final int waitSec = 10;
            CountDownLatch latch = new CountDownLatch(1);
            serverChannel.close().addListener(f -> {
                config.getWorkerGroup().schedule(() -> {
                    log.info("Shutdown dispatcher success...");
                    config.getWorkerGroup().shutdownGracefully();
                    latch.countDown();
                }, waitSec - 2, TimeUnit.SECONDS);

                log.info("Close ws server socket success={}", f.isSuccess());
                config.getBoosGroup().shutdownGracefully();
            });

            try {
                boolean flag = latch.await(waitSec, TimeUnit.SECONDS);
                if(!flag){
                    log.warn("Shutdown ws server timeout");
                }
            } catch (InterruptedException e) {
                log.warn("Shutdown ws server interrupted exception={}", e.getMessage());
            }
        }
    }
}

默认的通道初始化器(ChannelInitializer)实现

/**
 * Builds the pipeline for each accepted connection: HTTP codec and aggregation for the
 * handshake, WebSocket compression and protocol handling, and finally dispatch to the
 * WebSocketServer callbacks.
 */
public class WebSocketChannelInitializer extends ChannelInitializer<SocketChannel> {

    private final WebSocketServer server;

    /**
     * @param server callbacks that receive the connection's messages and lifecycle events
     */
    public WebSocketChannelInitializer(WebSocketServer server) {
        this.server = server;
    }

    @Override
    protected void initChannel(SocketChannel ch) {
        ChannelPipeline pipeline = ch.pipeline();
        // Combined HttpRequestDecoder + HttpResponseEncoder; the WebSocket handshake is an HTTP request.
        pipeline.addLast("httpServerCodec", new HttpServerCodec());
        // Lets large payloads (e.g. ChunkedFile) be written in chunks instead of fully in memory.
        pipeline.addLast(new ChunkedWriteHandler());
        // Aggregates HttpMessage + HttpContent parts into a FullHttpRequest (which the handshake
        // handler expects). Must come after HttpServerCodec. Aggregated content is capped at 2 KiB.
        pipeline.addLast(new HttpObjectAggregator(1024 * 2));
        // WebSocket compression extension; requires allowExtensions (3rd argument below) to be true.
        pipeline.addLast(new WebSocketServerCompressionHandler());
        // WebSocket endpoint at "/"; 65536 is the max frame payload size, so raise it if clients
        // send larger messages.
        pipeline.addLast(new WebSocketServerAuthProtocolHandler("/", null, true, 65536, server));
        // Deliberately no LengthFieldPrepender: WebSocket frames are self-delimiting (RFC 6455).
        // A prepender here would only act on raw ByteBuf writes and corrupt the frame stream.
        // Application logic.
        pipeline.addLast(new WebSocketServerChannelInboundHandler(server));
    }
}

提供session支持

public class WebSocketServerChannelInboundHandler extends SimpleChannelInboundHandler {

    private final WebSocketServer webSocketServer;

    public WebSocketServerChannelInboundHandler(WebSocketServer webSocketServer){
        this.webSocketServer = webSocketServer;
    }

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, Object msg) {
        WebSocketSession session = Sessions.getSession(ctx);

        if(msg instanceof TextWebSocketframe){
            String message = ((TextWebSocketframe) msg).text();

            try {
                webSocketServer.onMessage(session,message);
            }catch (Throwable e){
                webSocketServer.onError(session,e);
            }
        }else if(msg instanceof BinaryWebSocketframe){
            byte[] bytes = ((BinaryWebSocketframe) msg).content().array();

            try {
                webSocketServer.onMessage(session,bytes);
            }catch (Throwable e){
                webSocketServer.onError(session,e);
            }
        }else {
            System.out.println("未知消息类型:" + msg.getClass().getName());
        }
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        super.exceptionCaught(ctx, cause);
        webSocketServer.onError(Sessions.getSession(ctx),cause);
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        super.channelInactive(ctx);
        WebSocketSession destroy = Sessions.destroy(ctx);
        webSocketServer.onClose(destroy);
    }
}
 

扩展uri支持

/**
 * WebSocketServerProtocolHandler variant whose handshake is performed by
 * WebSocketServerAuthHandshakeHandler. That handler matches the request URI by prefix, keeps the
 * full URI and its query parameters on the session, and invokes WebSocketServer.onOpen.
 */
public class WebSocketServerAuthProtocolHandler extends WebSocketServerProtocolHandler {

    private final WebSocketServer webSocketServer;

    public WebSocketServerAuthProtocolHandler(String websocketPath, WebSocketServer webSocketServer) {
        this(websocketPath, null, false, webSocketServer);
    }

    public WebSocketServerAuthProtocolHandler(String websocketPath, String subprotocols, WebSocketServer webSocketServer) {
        this(websocketPath, subprotocols, false, webSocketServer);
    }

    public WebSocketServerAuthProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions, WebSocketServer webSocketServer) {
        this(websocketPath, subprotocols, allowExtensions, 65536, webSocketServer);
    }

    public WebSocketServerAuthProtocolHandler(String websocketPath, String subprotocols,
                                              boolean allowExtensions, int maxFrameSize, WebSocketServer webSocketServer) {
        this(websocketPath, subprotocols, allowExtensions, maxFrameSize, false, webSocketServer);
    }

    /**
     * @param websocketPath     URI prefix that identifies WebSocket handshake requests
     * @param subprotocols      comma-separated supported subprotocols, or null
     * @param allowExtensions   whether WebSocket extensions (e.g. compression) are allowed
     * @param maxFrameSize      maximum frame payload size in bytes
     * @param allowMaskMismatch passed through to the handshaker
     * @param webSocketServer   callbacks notified during the handshake
     */
    public WebSocketServerAuthProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, WebSocketServer webSocketServer) {
        super(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch);
        this._webSocketPathPrefix = websocketPath;
        this._subprotocols = subprotocols;
        this._allowExtensions = allowExtensions;
        this._maxframeSize = maxFrameSize;
        this._allowMaskMismatch = allowMaskMismatch;
        this.webSocketServer = webSocketServer;
    }

    // Handshake settings, kept so handlerAdded can pass them to WebSocketServerAuthHandshakeHandler.
    String _webSocketPathPrefix;
    String _subprotocols;
    boolean _allowExtensions;
    int _maxframeSize;
    boolean _allowMaskMismatch;

    /**
     * Installs the custom handshake handler and a UTF-8 frame validator in front of this handler
     * (each only once per pipeline). super.handlerAdded is intentionally not called.
     * NOTE(review): the parent would otherwise add Netty's own handshake handler — confirm for the
     * Netty version in use.
     */
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) {
        ChannelPipeline cp = ctx.pipeline();
        if (cp.get(WebSocketServerAuthHandshakeHandler.class) == null) {
            // Add the handshake handler before this one.
            cp.addBefore(ctx.name(), WebSocketServerAuthHandshakeHandler.class.getName(),
                    new WebSocketServerAuthHandshakeHandler(_webSocketPathPrefix, _subprotocols,
                            _allowExtensions, _maxframeSize, _allowMaskMismatch, webSocketServer));
        }
        if (cp.get(Utf8FrameValidator.class) == null) {
            // Add UTF-8 validation of text frames before this one.
            cp.addBefore(ctx.name(), Utf8FrameValidator.class.getName(),
                    new Utf8FrameValidator());
        }
    }

}

/**
 * Performs the WebSocket handshake for requests whose URI starts with {@code websocketPath}.
 * Before handshaking, it creates the connection's session (storing the full URI, base URL and
 * query parameters) and calls WebSocketServer.onOpen. After a handshake attempt it replaces itself
 * with a handler that answers any further HTTP request with 403.
 */
public class WebSocketServerAuthHandshakeHandler extends ChannelInboundHandlerAdapter {

    private final String websocketPath;
    private final String subprotocols;
    private final boolean allowExtensions;
    private final int maxframePayloadSize;
    private final boolean allowMaskMismatch;
    private final WebSocketServer webSocketServer;

    public WebSocketServerAuthHandshakeHandler(String websocketPath, String subprotocols,
                                               boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, WebSocketServer webSocketServer) {
        this.websocketPath = websocketPath;
        this.subprotocols = subprotocols;
        this.allowExtensions = allowExtensions;
        this.maxframePayloadSize = maxFrameSize;
        this.allowMaskMismatch = allowMaskMismatch;
        this.webSocketServer = webSocketServer;
    }

    @Override
    public void channelRead(final ChannelHandlerContext ctx, Object msg) {
        // Only aggregated HTTP requests can start a handshake; pass anything else along
        // instead of failing with a ClassCastException.
        if (!(msg instanceof FullHttpRequest)) {
            ctx.fireChannelRead(msg);
            return;
        }
        FullHttpRequest req = (FullHttpRequest) msg;
        if (!req.uri().startsWith(websocketPath)) {
            ctx.fireChannelRead(msg);
            return;
        }

        try {
            if (req.method() != GET) {
                sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN));
                return;
            }

            WebSocketSession session = Sessions.createSession(ctx);
            session.setChannel(ctx.channel());
            // NOTE(review): hashCode() is not guaranteed unique across channels; consider Channel.id().
            session.setId(ctx.channel().hashCode());
            session.setUri(req.uri());
            UrlEntity entity = UrlEntity.parse(req.uri());
            session.setUribase(entity.getbaseUrl());
            session.setParams(entity.getParams());

            // onOpen runs BEFORE the handshake response is written, and also for requests whose
            // WebSocket version turns out to be unsupported. An exception thrown here propagates
            // out of channelRead, so no handshake is performed.
            webSocketServer.onOpen(session);
            final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
                    getWebSocketLocation(ctx.pipeline(), req, websocketPath), subprotocols,
                    allowExtensions, maxframePayloadSize, allowMaskMismatch);
            final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
            if (handshaker == null) {
                WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
            } else {
                final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req);
                handshakeFuture.addListener((ChannelFutureListener) future -> {
                    if (!future.isSuccess()) {
                        ctx.fireExceptionCaught(future.cause());
                    } else {
                        ctx.fireUserEventTriggered(
                                WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
                    }
                });
                setHandshaker(ctx.channel(), handshaker);
                ctx.pipeline().replace(this, "WS403Responder",
                        forbiddenHttpRequestResponder());
            }
        } finally {
            req.release();
        }
    }

    // Built from the same (class, name) pair, presumably so that it resolves to the same key that
    // Netty's WebSocketServerProtocolHandler uses to look up the handshaker — confirm per version.
    private static final AttributeKey<WebSocketServerHandshaker> HANDSHAKER_ATTR_KEY =
            AttributeKey.valueOf(WebSocketServerHandshaker.class, "HANDSHAKER");

    /** Stores the handshaker on the channel under HANDSHAKER_ATTR_KEY. */
    static void setHandshaker(Channel channel, WebSocketServerHandshaker handshaker) {
        channel.attr(HANDSHAKER_ATTR_KEY).set(handshaker);
    }

    /** Handler that answers any further FullHttpRequest with 403 and forwards everything else. */
    static ChannelHandler forbiddenHttpRequestResponder() {
        return new ChannelInboundHandlerAdapter() {
            @Override
            public void channelRead(ChannelHandlerContext ctx, Object msg) {
                if (msg instanceof FullHttpRequest) {
                    ((FullHttpRequest) msg).release();
                    FullHttpResponse response =
                            new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.FORBIDDEN);
                    ctx.channel().writeAndFlush(response);
                } else {
                    ctx.fireChannelRead(msg);
                }
            }
        };
    }

    /** Writes {@code res}; closes the connection unless keep-alive was requested and status is 200. */
    private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) {
        ChannelFuture f = ctx.channel().writeAndFlush(res);
        if (!isKeepAlive(req) || res.status().code() != 200) {
            f.addListener(ChannelFutureListener.CLOSE);
        }
    }

    /** Builds the ws:// (or wss:// when an SslHandler is present) URL from the Host header and path. */
    private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {
        String protocol = "ws";
        if (cp.get(SslHandler.class) != null) {
            // SSL in use so use Secure WebSockets
            protocol = "wss";
        }
        return protocol + "://" + req.headers().get(HttpHeaderNames.HOST) + path;
    }
}

git地址

https://gitee.com/ichiva/netty-websocket-server.git

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5707589.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-17
下一篇 2022-12-17

发表评论

登录后才能评论

评论列表(0条)