io.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. package proxy
  2. import (
  3. "context"
  4. "golang.org/x/time/rate"
  5. "io"
  6. "net"
  7. "runtime/debug"
  8. "sync"
  9. "time"
  10. )
  // burstLimit is the token-bucket burst size handed to rate.NewLimiter by
// Reader.SetRateLimit: 10^9 tokens, where one token corresponds to one byte
// returned from Reader.Read.
const (
	burstLimit = 1000000000
)
  12. type Reader struct {
  13. r io.Reader
  14. limiter *rate.Limiter
  15. ctx context.Context
  16. }
  17. func NewReader(r io.Reader) *Reader {
  18. return &Reader{
  19. r: r,
  20. ctx: context.Background(),
  21. }
  22. }
  23. func (s *Reader) SetRateLimit(bytesPerSec float64) {
  24. s.limiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
  25. s.limiter.AllowN(time.Now(), burstLimit) // spend initial burst
  26. }
  27. func (s *Reader) Read(p []byte) (int, error) {
  28. if s.limiter == nil {
  29. return s.r.Read(p)
  30. }
  31. n, err := s.r.Read(p)
  32. if err != nil {
  33. return n, err
  34. }
  35. if err := s.limiter.WaitN(s.ctx, n); err != nil {
  36. return n, err
  37. }
  38. return n, nil
  39. }
  // ConnectHost opens a TCP connection to hostAndPort ("host:port"). The dial
// gives up after timeout milliseconds; a timeout of 0 means no dial timeout,
// as with net.DialTimeout. The results are exactly those of net.DialTimeout.
func ConnectHost(hostAndPort string, timeout int) (conn net.Conn, err error) {
	limit := time.Millisecond * time.Duration(timeout)
	return net.DialTimeout("tcp", hostAndPort, limit)
}
  // CloseConn closes conn and does nothing when conn is nil.
// Before closing, it sets a read/write deadline one millisecond ahead,
// presumably so any I/O still blocked on conn returns promptly (TODO confirm
// intent).
func CloseConn(conn net.Conn) {
	if conn == nil {
		return
	}
	// Best-effort teardown: errors from both calls are intentionally ignored.
	_ = conn.SetDeadline(time.Now().Add(time.Millisecond))
	_ = conn.Close()
}
  50. func IoBind(dst io.ReadWriter, src io.ReadWriter, fn func(isSrcErr bool, err error), cfn func(count int, isPositive bool), bytesPreSec float64) {
  51. var one = &sync.Once{}
  52. go func() {
  53. defer func() {
  54. if e := recover(); e != nil {
  55. logger.Sugar().Errorf("IoBind crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
  56. }
  57. }()
  58. var err error
  59. var isSrcErr bool
  60. if bytesPreSec > 0 {
  61. newReader := NewReader(src)
  62. newReader.SetRateLimit(bytesPreSec)
  63. _, isSrcErr, err = IoCopy(dst, newReader, func(c int) {
  64. cfn(c, false)
  65. })
  66. } else {
  67. _, isSrcErr, err = IoCopy(dst, src, func(c int) {
  68. cfn(c, false)
  69. })
  70. }
  71. if err != nil {
  72. one.Do(func() {
  73. fn(isSrcErr, err)
  74. })
  75. }
  76. }()
  77. go func() {
  78. defer func() {
  79. if e := recover(); e != nil {
  80. logger.Sugar().Errorf("IoBind crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
  81. }
  82. }()
  83. var err error
  84. var isSrcErr bool
  85. if bytesPreSec > 0 {
  86. newReader := NewReader(dst)
  87. newReader.SetRateLimit(bytesPreSec)
  88. _, isSrcErr, err = IoCopy(src, newReader, func(c int) {
  89. cfn(c, true)
  90. })
  91. } else {
  92. _, isSrcErr, err = IoCopy(src, dst, func(c int) {
  93. cfn(c, true)
  94. })
  95. }
  96. if err != nil {
  97. one.Do(func() {
  98. fn(isSrcErr, err)
  99. })
  100. }
  101. }()
  102. }
  // IoCopy copies from src to dst through a 32 KiB buffer until either side
// fails. It returns the number of bytes written, whether the terminating
// error came from src, and that error.
//
// Unlike io.Copy, it never returns a nil error. Reaching the end of src is
// reported as isSrcErr == true with whatever error src.Read returned
// (typically io.EOF). Write failures, including short writes, are reported
// with isSrcErr == false.
//
// Every non-nil callback in fn is called, in order, with the byte count of
// each write that made progress.
//
// A Write that reports a negative count, or more bytes than it was given, is
// treated as having written nothing and fails with io.ErrShortWrite.
func IoCopy(dst io.Writer, src io.Reader, fn ...func(count int)) (written int64, isSrcErr bool, err error) {
	buf := make([]byte, 32*1024)
	for {
		nr, er := src.Read(buf)
		if nr > 0 {
			nw, ew := dst.Write(buf[:nr])
			if nw < 0 || nw > nr {
				// A misbehaving Writer reported an impossible count. Keep it
				// out of written and out of the callbacks.
				nw = 0
				if ew == nil {
					ew = io.ErrShortWrite
				}
			}
			if nw > 0 {
				written += int64(nw)
				for _, cb := range fn {
					if cb != nil {
						cb(nw)
					}
				}
			}
			if ew != nil {
				return written, false, ew
			}
			if nw != nr {
				return written, false, io.ErrShortWrite
			}
		}
		// Bytes returned alongside an error were written above before we stop.
		if er != nil {
			return written, true, er
		}
	}
}