123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142 |
- package proxy
- import (
- "context"
- "golang.org/x/time/rate"
- "io"
- "net"
- "runtime/debug"
- "sync"
- "time"
- )
// burstLimit is the burst size handed to every Reader's limiter. Read asks
// the limiter for one token per byte it returned, so this is effectively a
// byte count. It is far larger than any single Read buffer, so a WaitN call
// is never rejected for exceeding the burst. SetRateLimit drains the whole
// burst up front so it does not grant free bytes.
const burstLimit = 1000000000
// Reader wraps an io.Reader and, once SetRateLimit has been called, delays
// each Read via a *rate.Limiter so the wrapped source is consumed at a
// bounded rate.
//
// Create it with NewReader; a zero-value Reader has a nil ctx.
// NOTE(review): storing a context in a struct is discouraged in Go. In the
// code visible here ctx is only ever context.Background(), so throttled reads
// cannot be cancelled — confirm that is intended.
type Reader struct {
	r       io.Reader       // wrapped source that Read delegates to
	limiter *rate.Limiter   // nil until SetRateLimit is called; nil means Read is not throttled
	ctx     context.Context // context passed to limiter.WaitN in Read
}
// NewReader returns a Reader that passes reads straight through to r until
// SetRateLimit is called. The reader's context is context.Background().
func NewReader(r io.Reader) *Reader {
	reader := new(Reader)
	reader.r = r
	reader.ctx = context.Background()
	return reader
}
// SetRateLimit throttles subsequent reads to about bytesPerSec bytes per
// second.
//
// The limiter is created with burstLimit as its burst, so no single Read is
// too large for it. That burst is spent immediately, so throttling starts
// with the first Read instead of after burstLimit free bytes.
//
// A non-positive or NaN bytesPerSec removes any limit, and reads pass straight
// through. A zero or negative rate would otherwise leave the drained limiter
// unable to grant tokens, so the next throttled Read would fail.
func (s *Reader) SetRateLimit(bytesPerSec float64) {
	if !(bytesPerSec > 0) { // written this way so NaN also disables limiting
		s.limiter = nil
		return
	}
	s.limiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
	s.limiter.AllowN(time.Now(), burstLimit) // spend initial burst
}
// Read reads from the wrapped reader. When a rate limit is set, it then
// blocks until the limiter permits the n bytes just read.
//
// Bytes delivered together with an error are throttled too; an example is a
// final chunk returned alongside io.EOF. This keeps them from bypassing the
// limit. If the underlying read failed, that error is returned. Otherwise any
// error from the limiter is returned together with the byte count.
func (s *Reader) Read(p []byte) (int, error) {
	n, err := s.r.Read(p)
	if s.limiter == nil || n <= 0 {
		// Nothing to throttle: no limit configured, or no bytes delivered.
		return n, err
	}
	ctx := s.ctx
	if ctx == nil {
		// A Reader built as a struct literal rather than via NewReader has no
		// context; WaitN must not be given a nil one.
		ctx = context.Background()
	}
	if werr := s.limiter.WaitN(ctx, n); werr != nil && err == nil {
		err = werr
	}
	return n, err
}
// ConnectHost opens a TCP connection to hostAndPort ("host:port"). The dial
// gives up after timeout milliseconds.
func ConnectHost(hostAndPort string, timeout int) (conn net.Conn, err error) {
	limit := time.Millisecond * time.Duration(timeout)
	return net.DialTimeout("tcp", hostAndPort, limit)
}
// CloseConn closes conn on a best-effort basis; a nil conn is ignored.
//
// It first sets a deadline one millisecond in the future, so that I/O already
// pending on conn fails promptly, and then closes the connection. Both errors
// are deliberately discarded.
func CloseConn(conn net.Conn) {
	if conn == nil {
		return
	}
	_ = conn.SetDeadline(time.Now().Add(time.Millisecond))
	_ = conn.Close()
}
// IoBind pipes data in both directions between dst and src. Each direction
// runs on its own goroutine, and IoBind returns immediately.
//
//   - src -> dst: byte counts are reported to cfn with isPositive == false.
//   - dst -> src: byte counts are reported to cfn with isPositive == true.
//
// When bytesPreSec > 0, reads in each direction are throttled to that many
// bytes per second through a Reader.
//
// fn is invoked at most once, guarded by a shared sync.Once. It receives the
// error from whichever direction stops first, and isSrcErr is IoCopy's flag
// telling whether that error came from the read side. Either fn or cfn may be
// nil. A panic in either goroutine is recovered and logged.
//
// NOTE(review): a recovered panic does not invoke fn, so the caller is never
// told that a direction stopped — confirm this is acceptable.
func IoBind(dst io.ReadWriter, src io.ReadWriter, fn func(isSrcErr bool, err error), cfn func(count int, isPositive bool), bytesPreSec float64) {
	var once sync.Once
	go ioBindPipe(dst, src, false, fn, cfn, bytesPreSec, &once)
	go ioBindPipe(src, dst, true, fn, cfn, bytesPreSec, &once)
}

// ioBindPipe runs one direction of an IoBind: it copies from `from` to `to`
// until either side fails, then reports the error through once/fn.
func ioBindPipe(to io.Writer, from io.Reader, isPositive bool, fn func(isSrcErr bool, err error), cfn func(count int, isPositive bool), bytesPreSec float64, once *sync.Once) {
	defer func() {
		if e := recover(); e != nil {
			// %v, not %s: the panic value can be of any type.
			logger.Sugar().Errorf("IoBind crashed , err : %v , \ntrace:%s", e, string(debug.Stack()))
		}
	}()

	var r io.Reader = from
	if bytesPreSec > 0 {
		limited := NewReader(from)
		limited.SetRateLimit(bytesPreSec)
		r = limited
	}

	var counters []func(count int)
	if cfn != nil {
		counters = append(counters, func(c int) { cfn(c, isPositive) })
	}

	_, isSrcErr, err := IoCopy(to, r, counters...)
	if err != nil && fn != nil {
		once.Do(func() { fn(isSrcErr, err) })
	}
}
// IoCopy copies from src to dst through a 32 KiB buffer until either side
// fails. Like io.Copy it reports the bytes written, but it also reports which
// side failed and notifies callbacks as data moves.
//
// Unlike io.Copy, IoCopy always returns a non-nil error. Reaching the end of
// src is reported as io.EOF with isSrcErr == true. A write error, or a short
// write (returned as io.ErrShortWrite), is reported with isSrcErr == false.
//
// After every write that wrote at least one byte, each non-nil callback in fn
// is called with that write's byte count.
func IoCopy(dst io.Writer, src io.Reader, fn ...func(count int)) (written int64, isSrcErr bool, err error) {
	buf := make([]byte, 32*1024)
	for {
		nr, er := src.Read(buf)
		if nr > 0 {
			nw, ew := dst.Write(buf[:nr])
			if nw > 0 {
				written += int64(nw)
				for _, notify := range fn {
					if notify != nil {
						notify(nw)
					}
				}
			}
			if ew != nil {
				return written, false, ew
			}
			if nw != nr {
				return written, false, io.ErrShortWrite
			}
		}
		// Per the io.Reader contract, bytes returned with an error have already been written above.
		if er != nil {
			return written, true, er
		}
	}
}
|