204 lines
4.6 KiB
Go
204 lines
4.6 KiB
Go
package middleware
|
||
|
||
import (
|
||
"errors"
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/golang-jwt/jwt/v5"
|
||
"net/http"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
// JWTClaims JWT 声明
|
||
type JWTClaims struct {
|
||
UserID uint `json:"user_id"`
|
||
Username string `json:"username"`
|
||
TenantID uint `json:"tenant_id"`
|
||
Role string `json:"role"`
|
||
jwt.RegisteredClaims
|
||
}
|
||
|
||
// AuthMiddleware 认证中间件
|
||
func Auth(secretKey string) gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
// 获取 Authorization 头
|
||
authHeader := c.GetHeader("Authorization")
|
||
if authHeader == "" {
|
||
c.JSON(http.StatusUnauthorized, gin.H{
|
||
"error": "缺少认证信息",
|
||
"message": "请提供 Authorization 头"
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 解析 Bearer Token
|
||
parts := strings.SplitN(authHeader, " ", 2)
|
||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||
c.JSON(http.StatusUnauthorized, gin.H{
|
||
"error": "无效的认证格式",
|
||
"message": "请使用格式:Bearer {token}"
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
tokenString := parts[1]
|
||
claims := &JWTClaims{}
|
||
|
||
// 解析 Token
|
||
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
|
||
// 验证签名算法
|
||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||
return nil, errors.New("不支持的签名算法")
|
||
}
|
||
return []byte(secretKey), nil
|
||
})
|
||
|
||
if err != nil || !token.Valid {
|
||
c.JSON(http.StatusUnauthorized, gin.H{
|
||
"error": "无效的 Token",
|
||
"message": err.Error(),
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 将解析出的用户信息存入上下文
|
||
c.Set("user_id", claims.UserID)
|
||
c.Set("username", claims.Username)
|
||
c.Set("tenant_id", claims.TenantID)
|
||
c.Set("role", claims.Role)
|
||
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// GenerateToken 生成 JWT Token
|
||
func GenerateToken(secretKey string, userID uint, username string, tenantID uint, role string) (string, error) {
|
||
expTime := time.Now().Add(24 * time.Hour) // 24 小时过期
|
||
|
||
claims := JWTClaims{
|
||
UserID: userID,
|
||
Username: username,
|
||
TenantID: tenantID,
|
||
Role: role,
|
||
RegisteredClaims: jwt.RegisteredClaims{
|
||
ExpiresAt: jwt.NewNumericDate(expTime),
|
||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||
NotBefore: jwt.NewNumericDate(time.Now()),
|
||
},
|
||
}
|
||
|
||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||
return token.SignedString([]byte(secretKey))
|
||
}
|
||
|
||
// AdminOnly 仅允许管理员访问
|
||
func AdminOnly() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
role, exists := c.Get("role")
|
||
if !exists {
|
||
c.JSON(http.StatusForbidden, gin.H{
|
||
"error": "未授权",
|
||
"message": "请先登录"
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
roleStr, ok := role.(string)
|
||
if !ok {
|
||
c.JSON(http.StatusForbidden, gin.H{
|
||
"error": "角色信息无效",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 检查是否为管理员角色
|
||
isAdminRole := roleStr == "admin" || roleStr == "super_admin" || roleStr == "system_admin"
|
||
if !isAdminRole {
|
||
c.JSON(http.StatusForbidden, gin.H{
|
||
"error": "权限不足",
|
||
"message": "需要管理员权限才能访问此资源"
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// TenantMiddleware 租户中间件(获取当前租户 ID)
|
||
func TenantMiddleware() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
// 尝试从路径参数获取租户 ID
|
||
tenantIDStr := c.Param("tenant_id")
|
||
if tenantIDStr == "" {
|
||
// 如果没有路径参数,从用户信息中获取
|
||
if tenantID, exists := c.Get("tenant_id"); exists {
|
||
c.Set("current_tenant_id", tenantID)
|
||
}
|
||
} else {
|
||
// 尝试解析租户 ID
|
||
if tenantID, err := strconv.ParseUint(tenantIDStr, 10, 32); err == nil {
|
||
c.Set("current_tenant_id", uint(tenantID))
|
||
}
|
||
}
|
||
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// PermissionCheck 权限检查中间件
|
||
func PermissionCheck(requiredPermissions []string) gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
permissionSet, exists := c.Get("permissions")
|
||
if !exists {
|
||
c.JSON(http.StatusForbidden, gin.H{
|
||
"error": "未授权",
|
||
"message": "请先登录"
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
userPermissions, ok := permissionSet.([]string)
|
||
if !ok {
|
||
c.JSON(http.StatusForbidden, gin.H{
|
||
"error": "权限信息无效",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 检查是否拥有所有必需权限
|
||
for _, required := range requiredPermissions {
|
||
if !contains(userPermissions, required) {
|
||
c.JSON(http.StatusForbidden, gin.H{
|
||
"error": "权限不足",
|
||
"message": "需要权限:",
|
||
"required": required,
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
}
|
||
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// contains reports whether item is an element of slice, using == comparison.
// A nil or empty slice never contains anything.
func contains(slice []string, item string) bool {
	for i := range slice {
		if slice[i] == item {
			return true
		}
	}
	return false
}