client.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. package client
  2. import (
  3. "errors"
  4. "fmt"
  5. "git.bvbej.com/bvbej/base-golang/pkg/ticker"
  6. "git.bvbej.com/bvbej/base-golang/pkg/websocket/client/service"
  7. "git.bvbej.com/bvbej/base-golang/pkg/websocket/codec"
  8. _ "git.bvbej.com/bvbej/base-golang/pkg/websocket/codec/json"
  9. _ "git.bvbej.com/bvbej/base-golang/pkg/websocket/codec/protobuf"
  10. "git.bvbej.com/bvbej/base-golang/tool"
  11. "github.com/gorilla/websocket"
  12. "go.uber.org/zap"
  13. "net/http"
  14. "net/url"
  15. "reflect"
  16. "sync/atomic"
  17. "time"
  18. )
  19. const (
  20. writeWait = 20 * time.Second
  21. pongWait = 60 * time.Second
  22. reconnectWait = 3 * time.Second
  23. pingPeriod = (pongWait * 9) / 10
  24. maxFrameMessageLen = 16 * 1024
  25. maxSendBuffer = 32
  26. )
  27. var (
  28. _ Client = (*client)(nil)
  29. ErrBrokenPipe = errors.New("send to broken pipe")
  30. ErrBufferPoolExceed = errors.New("send buffer exceed")
  31. )
  32. type Client interface {
  33. readLoop()
  34. writeLoop()
  35. ping()
  36. reconnect()
  37. onReceive(msg []byte) error
  38. onSend(msg []byte) error
  39. connect() error
  40. Send(router string, data any) error
  41. Connect(requestHeader http.Header) error
  42. OnReceiveError(f func(error))
  43. OnReconnected(f func(error))
  44. Close()
  45. }
  46. type client struct {
  47. url url.URL
  48. requestHeader http.Header
  49. logger *zap.Logger
  50. session *service.Session
  51. isConnected atomic.Bool
  52. routerCodec codec.Codec
  53. send chan []byte
  54. handlers map[string]*service.Handler // 注册的方法列表
  55. onReceiveErr func(error)
  56. pingTicker ticker.Ticker
  57. checkConnTicker ticker.Ticker
  58. onReconnect func(error)
  59. }
  60. func New(logger *zap.Logger, url url.URL, decoder string, handlers any) (Client, error) {
  61. if !tool.InArray(url.Scheme, []string{"ws", "wss"}) {
  62. return nil, errors.New(`param: scheme not supported`)
  63. }
  64. routerCodec := codec.GetCodec(decoder)
  65. if routerCodec == nil {
  66. return nil, errors.New(`param: codec not supported`)
  67. }
  68. components := service.RegisterHandler(handlers)
  69. if len(components) == 0 {
  70. return nil, errors.New(`param: handlers unqualified`)
  71. }
  72. c := &client{
  73. logger: logger,
  74. isConnected: atomic.Bool{},
  75. routerCodec: routerCodec,
  76. url: url,
  77. send: make(chan []byte, maxSendBuffer),
  78. handlers: components,
  79. pingTicker: ticker.New(pingPeriod),
  80. checkConnTicker: ticker.New(reconnectWait),
  81. }
  82. return c, nil
  83. }
  84. func (c *client) readLoop() {
  85. _ = c.session.Conn.SetReadDeadline(time.Now().Add(pongWait))
  86. c.session.Conn.SetPongHandler(func(string) error {
  87. _ = c.session.Conn.SetReadDeadline(time.Now().Add(pongWait))
  88. return nil
  89. })
  90. for {
  91. _, data, err := c.session.Conn.ReadMessage()
  92. if err != nil {
  93. c.isConnected.Store(false)
  94. break
  95. }
  96. err = c.onReceive(data)
  97. if err != nil && c.onReceiveErr != nil {
  98. c.onReceiveErr(err)
  99. }
  100. }
  101. }
  102. func (c *client) writeLoop() {
  103. for msg := range c.send {
  104. _ = c.session.Conn.SetWriteDeadline(time.Now().Add(writeWait))
  105. err := c.session.Conn.WriteMessage(websocket.BinaryMessage, msg)
  106. if err != nil {
  107. c.logger.Sugar().Errorf("writeLoop err: %s", err)
  108. }
  109. }
  110. }
  111. func (c *client) ping() {
  112. c.pingTicker.Process(func() {
  113. _ = c.session.Conn.SetWriteDeadline(time.Now().Add(writeWait))
  114. _ = c.session.Conn.WriteMessage(websocket.PingMessage, nil)
  115. })
  116. }
  117. func (c *client) reconnect() {
  118. c.checkConnTicker.Process(func() {
  119. if c.isConnected.Load() {
  120. return
  121. }
  122. err := c.connect()
  123. if c.onReconnect != nil {
  124. c.onReconnect(err)
  125. }
  126. })
  127. }
  128. func (c *client) connect() error {
  129. conn, _, err := websocket.DefaultDialer.Dial(c.url.String(), c.requestHeader)
  130. if err != nil {
  131. return fmt.Errorf("dial: %s", err)
  132. }
  133. c.session = service.NewSession(conn)
  134. c.isConnected.Store(true)
  135. go c.readLoop()
  136. return nil
  137. }
  138. func (c *client) onReceive(msg []byte) error {
  139. _, msgPack, err := c.routerCodec.Unmarshal(msg)
  140. if err != nil {
  141. return fmt.Errorf("onreceive: %v", err)
  142. }
  143. router, ok := msgPack.Router.(string)
  144. if !ok {
  145. return fmt.Errorf("onreceive: invalid router:%v", msgPack.Router)
  146. }
  147. s, ok := c.handlers[router]
  148. if !ok {
  149. return fmt.Errorf("onreceive: function not registed router:%s err:%v", msgPack.Router, err)
  150. }
  151. if msgPack.Err != nil {
  152. return fmt.Errorf("%s:%s", router, msgPack.Err)
  153. }
  154. var args = []reflect.Value{s.Receiver, reflect.ValueOf(c.session), reflect.ValueOf(msgPack.DataPtr)}
  155. s.Method.Func.Call(args)
  156. return nil
  157. }
  158. func (c *client) onSend(msg []byte) (err error) {
  159. defer func() {
  160. if e := recover(); e != nil {
  161. err = ErrBrokenPipe
  162. }
  163. }()
  164. if !c.isConnected.Load() {
  165. return ErrBrokenPipe
  166. }
  167. if len(c.send) >= maxSendBuffer {
  168. return ErrBufferPoolExceed
  169. }
  170. if len(msg) > maxFrameMessageLen {
  171. return
  172. }
  173. c.send <- msg
  174. return nil
  175. }
  176. func (c *client) Connect(requestHeader http.Header) error {
  177. c.requestHeader = requestHeader
  178. err := c.connect()
  179. if err != nil {
  180. return err
  181. }
  182. go c.ping()
  183. go c.writeLoop()
  184. go c.reconnect()
  185. return nil
  186. }
  187. func (c *client) Send(router string, data any) error {
  188. rb, err := c.routerCodec.Marshal(router, data, nil)
  189. if err != nil {
  190. return fmt.Errorf("service: %v", err)
  191. }
  192. return c.onSend(rb)
  193. }
  194. func (c *client) OnReceiveError(f func(error)) {
  195. c.onReceiveErr = f
  196. }
  197. func (c *client) OnReconnected(f func(error)) {
  198. c.onReconnect = f
  199. }
  200. func (c *client) Close() {
  201. close(c.send)
  202. c.pingTicker.Stop()
  203. c.checkConnTicker.Stop()
  204. _ = c.session.Conn.Close()
  205. }