package proxy import ( "sync" "time" ) type ConnPool interface { initAutoFill(async bool) (err error) Get() (conn any, err error) Put(conn any) ReleaseAll() Len() (length int) } type netPool struct { connects chan any lock *sync.Mutex config poolConfig } type poolConfig struct { Factory func() (any, error) IsActive func(any) bool Release func(any) InitialCap int MaxCap int } func NewConnPool(poolConfig poolConfig) (pool ConnPool, err error) { p := netPool{ config: poolConfig, connects: make(chan any, poolConfig.MaxCap), lock: &sync.Mutex{}, } if poolConfig.MaxCap > 0 { err = p.initAutoFill(false) if err == nil { _ = p.initAutoFill(true) } } return &p, nil } func (p *netPool) initAutoFill(async bool) (err error) { var worker = func() (err error) { for { if p.Len() <= p.config.InitialCap/2 { p.lock.Lock() errN := 0 for i := 0; i < p.config.InitialCap; i++ { c, factoryErr := p.config.Factory() if factoryErr != nil { errN++ if async { continue } else { p.lock.Unlock() return factoryErr } } select { case p.connects <- c: default: p.config.Release(c) break } if p.Len() >= p.config.InitialCap { break } } if errN > 0 { logger.Sugar().Infof("fill conn pool fail , ERRN:%d", errN) } p.lock.Unlock() } if !async { return } time.Sleep(time.Second * 2) } } if async { go func() { _ = worker() }() } else { err = worker() } return } func (p *netPool) Get() (conn any, err error) { p.lock.Lock() defer p.lock.Unlock() select { case conn = <-p.connects: if p.config.IsActive(conn) { return } p.config.Release(conn) default: conn, err = p.config.Factory() if err != nil { return nil, err } return conn, nil } return } func (p *netPool) Put(conn any) { if conn == nil { return } p.lock.Lock() defer p.lock.Unlock() if !p.config.IsActive(conn) { p.config.Release(conn) } select { case p.connects <- conn: default: p.config.Release(conn) } } func (p *netPool) ReleaseAll() { p.lock.Lock() defer p.lock.Unlock() close(p.connects) for c := range p.connects { p.config.Release(c) } p.connects = make(chan any, p.config.InitialCap) } func (p *netPool) Len() (length int) { return len(p.connects) }