|
@@ -3,26 +3,26 @@ package limiter
|
|
|
import (
|
|
|
"golang.org/x/time/rate"
|
|
|
"sync"
|
|
|
- "time"
|
|
|
)
|
|
|
|
|
|
var _ IPRateLimiter = (*ipRateLimiter)(nil)
|
|
|
|
|
|
type IPRateLimiter interface {
|
|
|
addIP(ip string) *rate.Limiter
|
|
|
- GetLimiter(ip string) *rate.Limiter
|
|
|
+ getLimiter(ip string) *rate.Limiter
|
|
|
+ Allow(ip string) bool
|
|
|
}
|
|
|
|
|
|
type ipRateLimiter struct {
|
|
|
- ips sync.Map
|
|
|
+ ips *sync.Map
|
|
|
limit rate.Limit
|
|
|
burst int
|
|
|
}
|
|
|
|
|
|
-func NewIPRateLimiter(limit time.Duration, burst int) IPRateLimiter {
|
|
|
+func NewIPRateLimiter(limit int, burst int) IPRateLimiter {
|
|
|
return &ipRateLimiter{
|
|
|
- ips: sync.Map{},
|
|
|
- limit: rate.Every(limit),
|
|
|
+ ips: new(sync.Map),
|
|
|
+ limit: rate.Limit(limit),
|
|
|
burst: burst,
|
|
|
}
|
|
|
}
|
|
@@ -32,10 +32,15 @@ func (i *ipRateLimiter) addIP(ip string) *rate.Limiter {
|
|
|
return store.(*rate.Limiter)
|
|
|
}
|
|
|
|
|
|
-func (i *ipRateLimiter) GetLimiter(ip string) *rate.Limiter {
|
|
|
+func (i *ipRateLimiter) getLimiter(ip string) *rate.Limiter {
|
|
|
limiter, exists := i.ips.Load(ip)
|
|
|
if !exists {
|
|
|
return i.addIP(ip)
|
|
|
}
|
|
|
return limiter.(*rate.Limiter)
|
|
|
}
|
|
|
+
|
|
|
+func (i *ipRateLimiter) Allow(ip string) bool {
|
|
|
+ limiter := i.getLimiter(ip)
|
|
|
+ return limiter.Allow()
|
|
|
+}
|