58 lines
1.3 KiB
Go
58 lines
1.3 KiB
Go
package middleware
|
|
|
|
import (
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"golang.org/x/time/rate"
)
|
|
|
|
// PerIPMinute 按客户端 IP 的固定窗口速率(每分钟 perMinute 次,burst 取 perMinute 与 64 的较小值)。
|
|
// 进程内 map 可能随 IP 数增长,多实例部署请在网关侧限流。
|
|
func PerIPMinute(enabled bool, perMinute int) gin.HandlerFunc {
|
|
if !enabled || perMinute <= 0 {
|
|
return func(c *gin.Context) { c.Next() }
|
|
}
|
|
burst := perMinute
|
|
if burst > 64 {
|
|
burst = 64
|
|
}
|
|
if burst < 5 {
|
|
burst = 5
|
|
}
|
|
lim := rate.Limit(float64(perMinute) / 60.0)
|
|
var mu sync.Mutex
|
|
limiters := make(map[string]*rate.Limiter)
|
|
return func(c *gin.Context) {
|
|
ip := clientIP(c)
|
|
mu.Lock()
|
|
limiter, ok := limiters[ip]
|
|
if !ok {
|
|
limiter = rate.NewLimiter(lim, burst)
|
|
limiters[ip] = limiter
|
|
}
|
|
mu.Unlock()
|
|
if !limiter.Allow() {
|
|
c.AbortWithStatusJSON(429, gin.H{"error": "rate_limit_exceeded"})
|
|
return
|
|
}
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func clientIP(c *gin.Context) string {
|
|
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
|
|
parts := strings.Split(xff, ",")
|
|
if len(parts) > 0 {
|
|
return strings.TrimSpace(parts[0])
|
|
}
|
|
}
|
|
host, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr))
|
|
if err != nil {
|
|
return c.Request.RemoteAddr
|
|
}
|
|
return host
|
|
}
|