// Package client implements a reconnecting websocket client that dispatches
// incoming messages to registered handlers by router name.
package client

import (
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"reflect"
	"sync/atomic"
	"time"

	"git.bvbej.com/bvbej/base-golang/pkg/websocket/client/service"
	"git.bvbej.com/bvbej/base-golang/pkg/websocket/codec"
	_ "git.bvbej.com/bvbej/base-golang/pkg/websocket/codec/json"
	_ "git.bvbej.com/bvbej/base-golang/pkg/websocket/codec/protobuf"
	"git.bvbej.com/bvbej/base-golang/tool/assist"
	"github.com/gorilla/websocket"
	"go.uber.org/zap"
)

const (
	// writeWait bounds a single write (data or ping) to the peer.
	writeWait = 20 * time.Second
	// pongWait is how long the read side waits for any frame or pong.
	pongWait = 60 * time.Second
	// reconnectWait is the pause between failed dial attempts.
	reconnectWait = 3 * time.Second
	// pingPeriod must stay below pongWait so a pong can arrive in time.
	pingPeriod = (pongWait * 9) / 10
	// maxFrameMessageLen is the largest encoded message onSend accepts.
	maxFrameMessageLen = 16 * 1024
	// maxSendBuffer is the capacity of the outgoing and error queues.
	maxSendBuffer = 32
)

var (
	_ Client = (*client)(nil)

	// ErrBrokenPipe is returned when sending while not connected.
	ErrBrokenPipe = errors.New("send to broken pipe")
	// ErrBufferPoolExceed is returned when the outgoing queue is full.
	ErrBufferPoolExceed = errors.New("send buffer exceed")
	// ErrMessageTooLarge is returned when an encoded message is longer than
	// maxFrameMessageLen.
	ErrMessageTooLarge = errors.New("send message exceeds max frame length")
	// ErrClientClosed is returned by Connect after Disconnect was called.
	ErrClientClosed = errors.New("client disconnected")
	// ErrAlreadyConnected is returned by Connect when the client is already
	// running.
	ErrAlreadyConnected = errors.New("client already connected")
)

// Client is a websocket client. The unexported methods restrict
// implementations to this package.
type Client interface {
	readLoop()
	writeLoop()
	onReceive(msg []byte) error
	onSend(msg []byte) (err error)
	Send(router string, data any) error
	Connect(requestHeader http.Header) error
	OnReceiveError() <-chan error
	Disconnect()
}

type client struct {
	url           url.URL
	requestHeader http.Header
	logger        *zap.Logger
	// session is swapped by the read goroutine on reconnect while the write
	// goroutine reads it, hence the atomic pointer.
	session       atomic.Pointer[service.Session]
	isConnected   *atomic.Bool
	routerCodec   codec.Codec
	send          chan []byte
	handlers      map[string]*service.Handler // registered handler methods, keyed by router
	onReceiveChan chan error
	pingTicker    *time.Ticker

	done       chan struct{} // closed by Disconnect to stop all goroutines
	started    atomic.Bool   // set once Connect has launched the goroutines
	closed     atomic.Bool   // set by the first Disconnect call
	recvClosed atomic.Bool   // guards the single close of onReceiveChan
}

// New validates its arguments and builds a Client that is not yet connected.
//
// url must use the "ws" or "wss" scheme, decoder must name a codec known to
// codec.GetCodec, and handlers must yield at least one handler from
// service.RegisterHandler. Any violation is reported as an error.
func New(logger *zap.Logger, url url.URL, decoder string, handlers any) (Client, error) {
	if !assist.InArray(url.Scheme, []string{"ws", "wss"}) {
		return nil, errors.New(`param: scheme not supported`)
	}

	routerCodec := codec.GetCodec(decoder)
	if routerCodec == nil {
		return nil, errors.New(`param: codec not supported`)
	}

	components := service.RegisterHandler(handlers)
	if len(components) == 0 {
		return nil, errors.New(`param: handlers unqualified`)
	}

	c := &client{
		logger:        logger,
		isConnected:   new(atomic.Bool),
		routerCodec:   routerCodec,
		url:           url,
		send:          make(chan []byte, maxSendBuffer),
		handlers:      components,
		onReceiveChan: make(chan error, maxSendBuffer),
		pingTicker:    time.NewTicker(pingPeriod),
		done:          make(chan struct{}),
	}
	return c, nil
}

// closeSession closes the connection held by sess, if there is one.
func closeSession(sess *service.Session) {
	if sess != nil && sess.Conn != nil {
		_ = sess.Conn.Close() // best effort: the connection is being discarded
	}
}

// closeReceive closes onReceiveChan at most once.
func (c *client) closeReceive() {
	if c.recvClosed.CompareAndSwap(false, true) {
		close(c.onReceiveChan)
	}
}

// readLoop reads messages until Disconnect is called, dispatching each one
// through onReceive and redialing whenever a read fails. It is the only
// sender on onReceiveChan and therefore the one that closes it on exit.
func (c *client) readLoop() {
	log := c.logger.Sugar()
	defer func() {
		c.isConnected.Store(false)
		closeSession(c.session.Load())
		c.closeReceive()
		log.Info("readLoop closed")
	}()
	log.Info("readLoop running")

	for {
		select {
		case <-c.done:
			return
		default:
		}

		sess := c.session.Load()
		if sess == nil || sess.Conn == nil {
			return
		}

		_, data, err := sess.Conn.ReadMessage()
		if err != nil {
			c.isConnected.Store(false)
			if !c.reconnect() {
				return
			}
			continue
		}

		if err = c.onReceive(data); err != nil {
			// Blocks while the error queue is full, but never past Disconnect.
			select {
			case c.onReceiveChan <- err:
			case <-c.done:
				return
			}
		}
	}
}

// reconnect redials every reconnectWait until a dial succeeds (returns true)
// or Disconnect is called (returns false).
func (c *client) reconnect() bool {
	for {
		select {
		case <-c.done:
			return false
		default:
		}

		err := c.connect()
		if err == nil {
			c.isConnected.Store(true)
			return true
		}
		c.logger.Sugar().Warnf("reconnect: %v", err)

		select {
		case <-c.done:
			return false
		case <-time.After(reconnectWait):
		}
	}
}

// writeLoop is the single writer goroutine: it sends queued messages and
// periodic pings until Disconnect is called. Pings are issued from here
// rather than from a separate goroutine because a gorilla/websocket
// connection supports only one concurrent writer.
func (c *client) writeLoop() {
	log := c.logger.Sugar()
	defer log.Info("writeLoop closed")
	log.Info("writeLoop running")

	for {
		select {
		case <-c.done:
			return
		case msg := <-c.send:
			c.write(websocket.BinaryMessage, msg)
		case <-c.pingTicker.C:
			c.ping()
		}
	}
}

// write sends one frame on the current connection. While the client is
// disconnected the frame is dropped. On a write failure the connection is
// closed so that the pending read fails and readLoop redials.
func (c *client) write(messageType int, payload []byte) {
	sess := c.session.Load()
	if sess == nil || sess.Conn == nil {
		return
	}
	if !c.isConnected.Load() {
		c.logger.Sugar().Warn("write: dropped frame while disconnected")
		return
	}

	_ = sess.Conn.SetWriteDeadline(time.Now().Add(writeWait))
	if err := sess.Conn.WriteMessage(messageType, payload); err != nil {
		c.logger.Sugar().Warnf("write: %v", err)
		closeSession(sess)
	}
}

// ping sends a single ping frame. It must only be called from writeLoop.
func (c *client) ping() {
	c.write(websocket.PingMessage, nil)
}

// connect dials the server, arms the read deadline and pong handler on the
// new connection, and installs it as the current session. The previous
// session's connection, if any, is closed.
func (c *client) connect() error {
	conn, _, err := websocket.DefaultDialer.Dial(c.url.String(), c.requestHeader)
	if err != nil {
		return fmt.Errorf("dial: %w", err)
	}

	_ = conn.SetReadDeadline(time.Now().Add(pongWait))
	conn.SetPongHandler(func(string) error {
		// Every pong extends the read deadline.
		_ = conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})

	closeSession(c.session.Swap(service.NewSession(conn)))
	return nil
}

// onReceive decodes msg, looks up the handler registered for its router and
// invokes it via reflection with the current session and the decoded data.
// It returns an error when decoding fails, the router is not a string, or no
// handler is registered for it.
func (c *client) onReceive(msg []byte) error {
	_, msgPack, err := c.routerCodec.Unmarshal(msg)
	if err != nil {
		return fmt.Errorf("onreceive: %w", err)
	}

	router, ok := msgPack.Router.(string)
	if !ok {
		return fmt.Errorf("onreceive: invalid router:%v", msgPack.Router)
	}

	handler, ok := c.handlers[router]
	if !ok {
		return fmt.Errorf("onreceive: function not registered router:%s", router)
	}

	args := []reflect.Value{
		handler.Receiver,
		reflect.ValueOf(c.session.Load()),
		reflect.ValueOf(msgPack.DataPtr),
	}
	handler.Method.Func.Call(args)
	return nil
}

// onSend queues an already encoded message for writeLoop without blocking.
// It returns ErrBrokenPipe when not connected, ErrMessageTooLarge when msg
// is longer than maxFrameMessageLen, and ErrBufferPoolExceed when the queue
// is full.
func (c *client) onSend(msg []byte) (err error) {
	if !c.isConnected.Load() {
		return ErrBrokenPipe
	}
	if len(msg) > maxFrameMessageLen {
		return ErrMessageTooLarge
	}

	select {
	case c.send <- msg:
		return nil
	default:
		return ErrBufferPoolExceed
	}
}

// Send encodes data for router with the configured codec and queues it.
// See onSend for the queueing errors.
func (c *client) Send(router string, data any) error {
	rb, err := c.routerCodec.Marshal(router, data, nil)
	if err != nil {
		return fmt.Errorf("service: %w", err)
	}
	return c.onSend(rb)
}

// OnReceiveError returns the channel on which errors from handling incoming
// messages are delivered. The channel is closed once the client has shut
// down after Disconnect.
func (c *client) OnReceiveError() <-chan error {
	return c.onReceiveChan
}

// Connect dials the server with requestHeader and starts the read and write
// goroutines. It returns ErrClientClosed after Disconnect, and
// ErrAlreadyConnected if the client is already running. A failed dial may be
// retried.
func (c *client) Connect(requestHeader http.Header) error {
	if c.closed.Load() {
		return ErrClientClosed
	}
	if !c.started.CompareAndSwap(false, true) {
		return ErrAlreadyConnected
	}

	c.requestHeader = requestHeader
	if err := c.connect(); err != nil {
		c.started.Store(false) // nothing is running yet; allow a retry
		return err
	}

	c.isConnected.Store(true)
	go c.writeLoop()
	go c.readLoop()
	return nil
}

// Disconnect stops the client: it signals both goroutines to exit, stops
// the ping ticker and closes the connection. It is safe to call more than
// once. The client cannot be reconnected afterwards.
func (c *client) Disconnect() {
	if !c.closed.CompareAndSwap(false, true) {
		return
	}

	c.isConnected.Store(false)
	c.pingTicker.Stop()
	close(c.done)

	// Closing the connection unblocks a pending read in readLoop.
	closeSession(c.session.Load())

	if !c.started.Load() {
		// readLoop never ran, so nobody else will close the error channel.
		c.closeReceive()
	}
}