package client import ( "errors" "fmt" "git.bvbej.com/bvbej/base-golang/pkg/websocket/client/service" "git.bvbej.com/bvbej/base-golang/pkg/websocket/codec" _ "git.bvbej.com/bvbej/base-golang/pkg/websocket/codec/json" _ "git.bvbej.com/bvbej/base-golang/pkg/websocket/codec/protobuf" "git.bvbej.com/bvbej/base-golang/tool/assist" "github.com/gorilla/websocket" "go.uber.org/zap" "net/http" "net/url" "reflect" "sync/atomic" "time" ) const ( writeWait = 20 * time.Second pongWait = 60 * time.Second pingPeriod = (pongWait * 9) / 10 maxFrameMessageLen = 16 * 1024 maxSendBuffer = 32 ) var ( _ Client = (*client)(nil) ErrBrokenPipe = errors.New("send to broken pipe") ErrBufferPoolExceed = errors.New("send buffer exceed") ) type Client interface { acceptLoop() sendLoop() onReceive(msg []byte) error onSend(msg []byte) (err error) Send(router string, data any) error Connect(requestHeader http.Header) error OnReceiveError() <-chan error RouterCodec() codec.Codec Disconnect() } type client struct { logger *zap.Logger session *service.Session isSurviving *atomic.Bool routerCodec codec.Codec url url.URL send chan []byte handlers map[string]*service.Handler // 注册的方法列表 onReceiveErr chan error pingTicker *time.Ticker } func New(logger *zap.Logger, url url.URL, decoder string, handlers interface{}) (Client, error) { if !assist.InArray(url.Scheme, []string{"ws", "wss"}) { return nil, errors.New(`param: scheme not supported`) } routerCodec := codec.GetCodec(decoder) if routerCodec == nil { return nil, errors.New(`param: codec not supported`) } components := service.RegisterHandler(handlers) if len(components) == 0 { return nil, errors.New(`param: handlers unqualified`) } isSurviving := new(atomic.Bool) isSurviving.Store(false) c := &client{ logger: logger, isSurviving: isSurviving, routerCodec: routerCodec, url: url, send: make(chan []byte, maxSendBuffer), handlers: components, onReceiveErr: make(chan error, maxSendBuffer), pingTicker: time.NewTicker(pingPeriod), } return c, nil } func (c *client) acceptLoop() { defer func() { c.isSurviving.Store(false) }() _ = c.session.Conn.SetReadDeadline(time.Now().Add(pongWait)) c.session.Conn.SetPongHandler(func(string) error { _ = c.session.Conn.SetReadDeadline(time.Now().Add(pongWait)) return nil }) for c.session.Conn != nil { _, data, err := c.session.Conn.ReadMessage() if err != nil { break } err = c.onReceive(data) if err != nil { c.onReceiveErr <- err } } } func (c *client) sendLoop() { defer func() { c.isSurviving.Store(false) }() for { select { case msg := <-c.send: _ = c.session.Conn.SetWriteDeadline(time.Now().Add(writeWait)) if err := c.session.Conn.WriteMessage(websocket.BinaryMessage, msg); err != nil { return } case <-c.pingTicker.C: _ = c.session.Conn.SetWriteDeadline(time.Now().Add(writeWait)) if err := c.session.Conn.WriteMessage(websocket.PingMessage, nil); err != nil { return } } } } func (c *client) onReceive(msg []byte) error { _, msgPack, err := c.routerCodec.Unmarshal(msg) if err != nil { return fmt.Errorf("onreceive: %v", err) } router, ok := msgPack.Router.(string) if !ok { return fmt.Errorf("onreceive: invalid router:%v", msgPack.Router) } s, ok := c.handlers[router] if !ok { return fmt.Errorf("onreceive: function not registed router:%s err:%v", msgPack.Router, err) } var args = []reflect.Value{s.Receiver, reflect.ValueOf(c.session), reflect.ValueOf(msgPack.DataPtr)} s.Method.Func.Call(args) return nil } func (c *client) onSend(msg []byte) (err error) { defer func() { if e := recover(); e != nil { err = ErrBrokenPipe } }() if !c.isSurviving.Load() { return ErrBrokenPipe } if len(c.send) >= maxSendBuffer { return ErrBufferPoolExceed } if len(msg) > maxFrameMessageLen { return } c.send <- msg return nil } func (c *client) Send(router string, data any) error { rb, err := c.routerCodec.Marshal(router, data, nil) if err != nil { return fmt.Errorf("service: %v", err) } return c.onSend(rb) } func (c *client) OnReceiveError() <-chan error { return c.onReceiveErr } func (c *client) RouterCodec() codec.Codec { return c.routerCodec } func (c *client) Connect(requestHeader http.Header) error { conn, _, err := websocket.DefaultDialer.Dial(c.url.String(), requestHeader) if err != nil { return fmt.Errorf("dial: %s", err) } c.session = service.NewSession(conn) go c.acceptLoop() go c.sendLoop() c.isSurviving.Store(true) return nil } func (c *client) Disconnect() { close(c.send) c.pingTicker.Stop() _ = c.session.Conn.Close() }