package proxy

import (
	"context"
	"io"
	"net"
	"runtime/debug"
	"sync"
	"time"

	"golang.org/x/time/rate"
)

// burstLimit is the token-bucket capacity given to every limiter created by
// SetRateLimit. It only has to exceed the largest single Read, because WaitN
// is asked for exactly the number of bytes a Read returned.
const burstLimit = 1000 * 1000 * 1000

// Reader wraps an io.Reader and, once SetRateLimit has been called with a
// positive rate, throttles Read to roughly that many bytes per second.
type Reader struct {
	// r is the underlying data source.
	r io.Reader
	// limiter is nil while the Reader is unthrottled.
	limiter *rate.Limiter
	// ctx is handed to limiter.WaitN; NewReader sets it to
	// context.Background().
	ctx context.Context
}

// NewReader returns an unthrottled Reader that reads from r.
func NewReader(r io.Reader) *Reader {
	return &Reader{
		r:   r,
		ctx: context.Background(),
	}
}

// SetRateLimit throttles subsequent Reads to bytesPerSec bytes per second.
//
// A rate of zero or below would never refill the bucket that is drained
// here, so Reads could not make progress; such a value removes the limit
// instead, matching the "bytesPreSec <= 0 means unlimited" convention used
// by IoBind.
func (s *Reader) SetRateLimit(bytesPerSec float64) {
	if bytesPerSec <= 0 {
		s.limiter = nil
		return
	}
	limiter := rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
	// Spend the initial burst so throttling applies from the first Read
	// rather than after burstLimit bytes have passed through.
	limiter.AllowN(time.Now(), burstLimit)
	s.limiter = limiter
}

// Read reads from the wrapped reader and, when a limiter is set and the read
// succeeded, blocks until the limiter has granted n tokens. A read error is
// returned immediately without waiting on the limiter.
func (s *Reader) Read(p []byte) (int, error) {
	n, err := s.r.Read(p)
	if err != nil || s.limiter == nil {
		return n, err
	}
	if err := s.limiter.WaitN(s.ctx, n); err != nil {
		return n, err
	}
	return n, nil
}

// ConnectHost dials hostAndPort over TCP. timeout is the dial timeout in
// milliseconds.
func ConnectHost(hostAndPort string, timeout int) (conn net.Conn, err error) {
	return net.DialTimeout("tcp", hostAndPort, time.Duration(timeout)*time.Millisecond)
}

// CloseConn closes conn if it is non-nil. Errors are deliberately ignored:
// this is best-effort teardown and there is nothing useful to do with them.
func CloseConn(conn net.Conn) {
	if conn == nil {
		return
	}
	// Presumably the near-immediate deadline is meant to unblock any pending
	// Read/Write before the close — TODO confirm intent.
	_ = conn.SetDeadline(time.Now().Add(time.Millisecond))
	_ = conn.Close()
}

// IoBind relays bytes between dst and src in both directions, each direction
// in its own goroutine, and returns immediately.
//
// fn is invoked at most once, by whichever direction stops first, with the
// error that stopped it and whether that error came from the reading side of
// that direction. cfn is invoked after every successful write with the byte
// count; isPositive is false for the src->dst direction and true for the
// dst->src direction. Either callback may be nil. When bytesPreSec is greater
// than zero, each direction is throttled independently to that many bytes per
// second.
//
// A panic inside a direction is recovered and logged; fn is not called for it.
func IoBind(dst io.ReadWriter, src io.ReadWriter, fn func(isSrcErr bool, err error), cfn func(count int, isPositive bool), bytesPreSec float64) {
	one := &sync.Once{}

	// relay copies from -> to until IoCopy stops, then reports through fn.
	relay := func(to io.Writer, from io.Reader, isPositive bool) {
		defer func() {
			if e := recover(); e != nil {
				// %v rather than %s: a recovered value is an arbitrary
				// any, not necessarily a string or an error.
				logger.Sugar().Errorf("IoBind crashed , err : %v , \ntrace:%s", e, string(debug.Stack()))
			}
		}()

		if bytesPreSec > 0 {
			limited := NewReader(from)
			limited.SetRateLimit(bytesPreSec)
			from = limited
		}

		count := func(c int) {
			if cfn != nil {
				cfn(c, isPositive)
			}
		}

		_, isSrcErr, err := IoCopy(to, from, count)
		if err != nil && fn != nil {
			one.Do(func() {
				fn(isSrcErr, err)
			})
		}
	}

	go relay(dst, src, false)
	go relay(src, dst, true)
}

// IoCopy copies from src to dst through a 32 KiB buffer until a read or write
// fails, and returns the number of bytes written.
//
// Unlike io.Copy, it never returns a nil error: io.EOF from src is returned
// as err like any other read error. isSrcErr is true when the copy stopped
// because src.Read failed and false when dst.Write failed or wrote short.
//
// If exactly one callback is supplied in fn and it is non-nil, it is called
// after every write that wrote at least one byte, with that byte count.
func IoCopy(dst io.Writer, src io.Reader, fn ...func(count int)) (written int64, isSrcErr bool, err error) {
	var report func(count int)
	if len(fn) == 1 {
		report = fn[0]
	}

	buf := make([]byte, 32*1024)
	for {
		nr, er := src.Read(buf)
		if nr > 0 {
			nw, ew := dst.Write(buf[:nr])
			if nw > 0 {
				written += int64(nw)
				if report != nil {
					report(nw)
				}
			}
			if ew != nil {
				return written, false, ew
			}
			if nr != nw {
				return written, false, io.ErrShortWrite
			}
		}
		// Checked after the write so bytes returned together with an error
		// (including io.EOF) are still forwarded.
		if er != nil {
			return written, true, er
		}
	}
}