Files
smart-customer-service/backend/internal/middleware/auth.go

204 lines
4.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"errors"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"net/http"
"strconv"
"strings"
"time"
)
// JWTClaims is the payload carried by tokens issued by GenerateToken and
// verified by Auth. The custom fields sit alongside the standard registered
// claims (exp, iat, nbf, ...) provided by the embedded jwt.RegisteredClaims.
type JWTClaims struct {
	// UserID identifies the authenticated user.
	UserID uint `json:"user_id"`
	// Username is the login name of the user.
	Username string `json:"username"`
	// TenantID is the tenant the user belongs to.
	TenantID uint `json:"tenant_id"`
	// Role is checked by AdminOnly ("admin", "super_admin", "system_admin").
	Role string `json:"role"`

	jwt.RegisteredClaims
}
// AuthMiddleware 认证中间件
func Auth(secretKey string) gin.HandlerFunc {
return func(c *gin.Context) {
// 获取 Authorization 头
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "缺少认证信息",
"message": "请提供 Authorization 头"
})
c.Abort()
return
}
// 解析 Bearer Token
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || parts[0] != "Bearer" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "无效的认证格式",
"message": "请使用格式Bearer {token}"
})
c.Abort()
return
}
tokenString := parts[1]
claims := &JWTClaims{}
// 解析 Token
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
// 验证签名算法
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, errors.New("不支持的签名算法")
}
return []byte(secretKey), nil
})
if err != nil || !token.Valid {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "无效的 Token",
"message": err.Error(),
})
c.Abort()
return
}
// 将解析出的用户信息存入上下文
c.Set("user_id", claims.UserID)
c.Set("username", claims.Username)
c.Set("tenant_id", claims.TenantID)
c.Set("role", claims.Role)
c.Next()
}
}
// GenerateToken signs an HS256 JWT carrying the given user ID, username,
// tenant ID and role.
//
// The token is issued at, and valid from, the moment of the call, and it
// expires 24 hours later. It returns the compact serialized token, or the
// error reported by the signer.
//
// NOTE(review): this function does not reject an empty secretKey; confirm a
// strong secret is always supplied.
func GenerateToken(secretKey string, userID uint, username string, tenantID uint, role string) (string, error) {
	const tokenTTL = 24 * time.Hour

	// Capture a single instant so iat, nbf and exp are mutually consistent.
	// Separate time.Now() calls could fall in different seconds after
	// NumericDate truncation.
	now := time.Now()
	claims := JWTClaims{
		UserID:   userID,
		Username: username,
		TenantID: tenantID,
		Role:     role,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(tokenTTL)),
			IssuedAt:  jwt.NewNumericDate(now),
			NotBefore: jwt.NewNumericDate(now),
		},
	}
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	return token.SignedString([]byte(secretKey))
}
// AdminOnly 仅允许管理员访问
func AdminOnly() gin.HandlerFunc {
return func(c *gin.Context) {
role, exists := c.Get("role")
if !exists {
c.JSON(http.StatusForbidden, gin.H{
"error": "未授权",
"message": "请先登录"
})
c.Abort()
return
}
roleStr, ok := role.(string)
if !ok {
c.JSON(http.StatusForbidden, gin.H{
"error": "角色信息无效",
})
c.Abort()
return
}
// 检查是否为管理员角色
isAdminRole := roleStr == "admin" || roleStr == "super_admin" || roleStr == "system_admin"
if !isAdminRole {
c.JSON(http.StatusForbidden, gin.H{
"error": "权限不足",
"message": "需要管理员权限才能访问此资源"
})
c.Abort()
return
}
c.Next()
}
}
// TenantMiddleware returns a Gin middleware that resolves the tenant for the
// current request and stores it under the "current_tenant_id" context key.
//
// A non-empty ":tenant_id" route parameter takes priority and is parsed as an
// unsigned 32-bit decimal number. If it does not parse, "current_tenant_id" is
// left unset. Without the route parameter, the value found under the
// "tenant_id" context key (set by Auth) is copied as-is. The request always
// continues to the next handler.
//
// NOTE(review): the route parameter is not compared with the authenticated
// user's tenant, so a caller can name any tenant in the URL. Confirm that
// handlers enforce tenant ownership downstream.
func TenantMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		if raw := c.Param("tenant_id"); raw != "" {
			if parsed, parseErr := strconv.ParseUint(raw, 10, 32); parseErr == nil {
				c.Set("current_tenant_id", uint(parsed))
			}
		} else if fromToken, ok := c.Get("tenant_id"); ok {
			c.Set("current_tenant_id", fromToken)
		}
		c.Next()
	}
}
// PermissionCheck 权限检查中间件
func PermissionCheck(requiredPermissions []string) gin.HandlerFunc {
return func(c *gin.Context) {
permissionSet, exists := c.Get("permissions")
if !exists {
c.JSON(http.StatusForbidden, gin.H{
"error": "未授权",
"message": "请先登录"
})
c.Abort()
return
}
userPermissions, ok := permissionSet.([]string)
if !ok {
c.JSON(http.StatusForbidden, gin.H{
"error": "权限信息无效",
})
c.Abort()
return
}
// 检查是否拥有所有必需权限
for _, required := range requiredPermissions {
if !contains(userPermissions, required) {
c.JSON(http.StatusForbidden, gin.H{
"error": "权限不足",
"message": "需要权限:",
"required": required,
})
c.Abort()
return
}
}
c.Next()
}
}
// contains reports whether item appears in slice, using exact
// (case-sensitive) string equality. A nil or empty slice contains nothing.
func contains(slice []string, item string) bool {
	found := false
	for i := 0; i < len(slice) && !found; i++ {
		found = slice[i] == item
	}
	return found
}