Parcourir la source

[add] IPRateLimiter

bvbej il y a 2 ans
Parent
commit
4037baf560
1 fichiers modifiés avec 12 ajouts et 7 suppressions
  1. 12 7
      pkg/limiter/ip.go

+ 12 - 7
pkg/limiter/ip.go

@@ -3,26 +3,26 @@ package limiter
 import (
 	"golang.org/x/time/rate"
 	"sync"
-	"time"
 )
 
 var _ IPRateLimiter = (*ipRateLimiter)(nil)
 
 type IPRateLimiter interface {
 	addIP(ip string) *rate.Limiter
-	GetLimiter(ip string) *rate.Limiter
+	getLimiter(ip string) *rate.Limiter
+	Allow(ip string) bool
 }
 
 type ipRateLimiter struct {
-	ips   sync.Map
+	ips   *sync.Map
 	limit rate.Limit
 	burst int
 }
 
-func NewIPRateLimiter(limit time.Duration, burst int) IPRateLimiter {
+func NewIPRateLimiter(limit int, burst int) IPRateLimiter {
 	return &ipRateLimiter{
-		ips:   sync.Map{},
-		limit: rate.Every(limit),
+		ips:   new(sync.Map),
+		limit: rate.Limit(limit),
 		burst: burst,
 	}
 }
@@ -32,10 +32,15 @@ func (i *ipRateLimiter) addIP(ip string) *rate.Limiter {
 	return store.(*rate.Limiter)
 }
 
-func (i *ipRateLimiter) GetLimiter(ip string) *rate.Limiter {
+func (i *ipRateLimiter) getLimiter(ip string) *rate.Limiter {
 	limiter, exists := i.ips.Load(ip)
 	if !exists {
 		return i.addIP(ip)
 	}
 	return limiter.(*rate.Limiter)
 }
+
+func (i *ipRateLimiter) Allow(ip string) bool {
+	limiter := i.getLimiter(ip)
+	return limiter.Allow()
+}