feat: 优化web
This commit is contained in:
@@ -0,0 +1,189 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"giter.top/smart/internal/auth/oauth2"
|
||||
"giter.top/smart/internal/auth/session"
|
||||
"giter.top/smart/internal/iam/entity"
|
||||
iamrepo "giter.top/smart/internal/iam/repository"
|
||||
"giter.top/smart/pkg/config"
|
||||
"giter.top/smart/pkg/utils/codec"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// LoginHandler JSON 登录:校验密码后签发 OAuth2 授权码(PKCE),并可选下发会话 Cookie(与 /oauth/authorize 兼容)。
|
||||
type LoginHandler struct {
|
||||
cfg *config.Config
|
||||
users iamrepo.UserRepository
|
||||
sess *session.Store
|
||||
oauth *oauth2.Service
|
||||
}
|
||||
|
||||
// NewLoginHandler 构造。
|
||||
func NewLoginHandler(cfg *config.Config, users iamrepo.UserRepository, sess *session.Store, oauth *oauth2.Service) *LoginHandler {
|
||||
return &LoginHandler{cfg: cfg, users: users, sess: sess, oauth: oauth}
|
||||
}
|
||||
|
||||
type loginBody struct {
|
||||
TenantID string `json:"tenant_id"`
|
||||
UserName string `json:"user_name"`
|
||||
Password string `json:"password"`
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
CodeChallenge string `json:"code_challenge"`
|
||||
CodeChallengeMethod string `json:"code_challenge_method"`
|
||||
State string `json:"state"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
type apiEnvelope struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Login POST /api/v1/auth/login
|
||||
func (h *LoginHandler) Login(c *gin.Context) {
|
||||
var req loginBody
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, apiEnvelope{Code: 400, Msg: "请求参数无效: " + err.Error(), Data: nil})
|
||||
return
|
||||
}
|
||||
if req.UserName == "" || req.Password == "" {
|
||||
c.JSON(http.StatusOK, apiEnvelope{Code: 400, Msg: "缺少 user_name 或 password", Data: nil})
|
||||
return
|
||||
}
|
||||
if req.ClientID == "" || req.RedirectURI == "" || req.CodeChallenge == "" {
|
||||
c.JSON(http.StatusOK, apiEnvelope{Code: 400, Msg: "缺少 client_id、redirect_uri 或 code_challenge", Data: nil})
|
||||
return
|
||||
}
|
||||
if req.CodeChallengeMethod == "" {
|
||||
c.JSON(http.StatusOK, apiEnvelope{Code: 400, Msg: "缺少 code_challenge_method", Data: nil})
|
||||
return
|
||||
}
|
||||
|
||||
tid := req.TenantID
|
||||
if tid == "" {
|
||||
tid = entity.PlatformTenantID
|
||||
}
|
||||
|
||||
u, err := h.users.GetByUserName(c.Request.Context(), tid, req.UserName)
|
||||
if err != nil {
|
||||
slog.Warn("auth_login_failed", "reason", "user_not_found", "tenant_id", tid, "user_name", req.UserName, "client_ip", c.ClientIP())
|
||||
c.JSON(http.StatusUnauthorized, apiEnvelope{Code: 401, Msg: "用户名或密码错误", Data: nil})
|
||||
return
|
||||
}
|
||||
if err := codec.VerifyPassword(req.Password, u.PasswordHash); err != nil {
|
||||
slog.Warn("auth_login_failed", "reason", "bad_password", "tenant_id", tid, "user_name", req.UserName, "client_ip", c.ClientIP())
|
||||
c.JSON(http.StatusUnauthorized, apiEnvelope{Code: 401, Msg: "用户名或密码错误", Data: nil})
|
||||
return
|
||||
}
|
||||
if u.Status != 1 {
|
||||
slog.Warn("auth_login_failed", "reason", "user_disabled", "tenant_id", tid, "user_id", u.ID, "client_ip", c.ClientIP())
|
||||
c.JSON(http.StatusForbidden, apiEnvelope{Code: 403, Msg: "用户已禁用", Data: nil})
|
||||
return
|
||||
}
|
||||
|
||||
codePlain, err := h.oauth.IssueAuthorizationCodeAfterPasswordAuth(
|
||||
c.Request.Context(),
|
||||
req.ClientID,
|
||||
req.RedirectURI,
|
||||
u.ID,
|
||||
u.TenantID,
|
||||
req.Scope,
|
||||
req.CodeChallenge,
|
||||
req.CodeChallengeMethod,
|
||||
)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, oauth2.ErrInvalidClient):
|
||||
c.JSON(http.StatusOK, apiEnvelope{Code: 400, Msg: "无效的 client_id", Data: nil})
|
||||
return
|
||||
case errors.Is(err, oauth2.ErrInvalidRedirectURI):
|
||||
c.JSON(http.StatusOK, apiEnvelope{Code: 400, Msg: "redirect_uri 与客户端登记不一致", Data: nil})
|
||||
return
|
||||
case errors.Is(err, oauth2.ErrPKCERequired):
|
||||
c.JSON(http.StatusOK, apiEnvelope{Code: 400, Msg: "code_challenge 或 code_challenge_method 无效(需 S256)", Data: nil})
|
||||
return
|
||||
default:
|
||||
slog.Error("auth_login_issue_code", "err", err)
|
||||
c.JSON(http.StatusInternalServerError, apiEnvelope{Code: 500, Msg: "服务器错误", Data: nil})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
sid, err := h.sess.Create(c.Request.Context(), u.ID, u.TenantID)
|
||||
if err != nil {
|
||||
slog.Error("auth_login_session", "err", err)
|
||||
c.JSON(http.StatusInternalServerError, apiEnvelope{Code: 500, Msg: "会话创建失败", Data: nil})
|
||||
return
|
||||
}
|
||||
h.setSessionCookie(c, sid)
|
||||
|
||||
data := gin.H{
|
||||
"authorization_code": codePlain,
|
||||
}
|
||||
if req.State != "" {
|
||||
data["state"] = req.State
|
||||
}
|
||||
|
||||
slog.Info("auth_login_ok", "tenant_id", u.TenantID, "user_id", u.ID, "user_name", req.UserName, "client_ip", c.ClientIP())
|
||||
c.JSON(http.StatusOK, apiEnvelope{Code: 200, Msg: "操作成功", Data: data})
|
||||
}
|
||||
|
||||
// Logout POST /api/v1/auth/logout
|
||||
func (h *LoginHandler) Logout(c *gin.Context) {
|
||||
sid, err := c.Cookie(h.cfg.Auth.Session.CookieName)
|
||||
if err == nil && sid != "" {
|
||||
_ = h.sess.Delete(c.Request.Context(), sid)
|
||||
}
|
||||
h.clearSessionCookie(c)
|
||||
c.JSON(http.StatusOK, apiEnvelope{Code: 200, Msg: "操作成功", Data: nil})
|
||||
}
|
||||
|
||||
func (h *LoginHandler) setSessionCookie(c *gin.Context, sid string) {
|
||||
same := sameSite(h.cfg.Auth.Session.SameSite)
|
||||
ttl := h.cfg.Auth.Session.TTL
|
||||
if ttl == 0 {
|
||||
ttl = 168 * time.Hour
|
||||
}
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: h.cfg.Auth.Session.CookieName,
|
||||
Value: sid,
|
||||
Path: "/",
|
||||
Domain: h.cfg.Auth.Session.CookieDomain,
|
||||
MaxAge: int(ttl.Seconds()),
|
||||
Secure: h.cfg.Auth.Session.CookieSecure,
|
||||
HttpOnly: true,
|
||||
SameSite: same,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *LoginHandler) clearSessionCookie(c *gin.Context) {
|
||||
same := sameSite(h.cfg.Auth.Session.SameSite)
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: h.cfg.Auth.Session.CookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Domain: h.cfg.Auth.Session.CookieDomain,
|
||||
MaxAge: -1,
|
||||
Secure: h.cfg.Auth.Session.CookieSecure,
|
||||
HttpOnly: true,
|
||||
SameSite: same,
|
||||
})
|
||||
}
|
||||
|
||||
func sameSite(s string) http.SameSite {
|
||||
switch s {
|
||||
case "strict":
|
||||
return http.SameSiteStrictMode
|
||||
case "none":
|
||||
return http.SameSiteNoneMode
|
||||
default:
|
||||
return http.SameSiteLaxMode
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"giter.top/smart/internal/auth/handler"
|
||||
"giter.top/smart/internal/auth/oauth2"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AuthRoutes 认证相关 HTTP(OAuth2、登录)。
|
||||
type AuthRoutes struct {
|
||||
bearer gin.HandlerFunc
|
||||
loginRL gin.HandlerFunc
|
||||
tokenRL gin.HandlerFunc
|
||||
oauthH *oauth2.Handler
|
||||
loginH *handler.LoginHandler
|
||||
}
|
||||
|
||||
// NewAuthRoutes 构造(loginRL/tokenRL 使用 Wire 专用类型,见 wire_provider.go)。
|
||||
func NewAuthRoutes(bearer gin.HandlerFunc, loginRL LoginRateLimitWire, tokenRL TokenRateLimitWire, oauthH *oauth2.Handler, loginH *handler.LoginHandler) *AuthRoutes {
|
||||
return &AuthRoutes{
|
||||
bearer: bearer,
|
||||
loginRL: gin.HandlerFunc(loginRL),
|
||||
tokenRL: gin.HandlerFunc(tokenRL),
|
||||
oauthH: oauthH,
|
||||
loginH: loginH,
|
||||
}
|
||||
}
|
||||
|
||||
// Register 实现 server.HttpRoutes:OAuth 在根路径,/api/v1 挂 Bearer 与登录。
|
||||
func (r *AuthRoutes) Register(engine *gin.Engine, apiGroup *gin.RouterGroup) {
|
||||
apiGroup.Use(r.bearer)
|
||||
apiGroup.POST("/auth/login", r.loginRL, r.loginH.Login)
|
||||
apiGroup.POST("/auth/logout", r.loginH.Logout)
|
||||
|
||||
engine.GET("/oauth/authorize", r.oauthH.Authorize)
|
||||
engine.POST("/oauth/token", r.tokenRL, r.oauthH.Token)
|
||||
engine.POST("/oauth/introspect", r.tokenRL, r.oauthH.Introspect)
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"giter.top/smart/internal/auth/oauth2"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Context keys for auth principal
|
||||
const (
|
||||
CtxUserID = "auth_user_id"
|
||||
CtxTenantID = "auth_tenant_id"
|
||||
CtxScope = "auth_scope"
|
||||
)
|
||||
|
||||
// NewBearer 解析 opaque Bearer access_token,写入上下文;无 Bearer 或无效时继续放行(兼容未迁移接口)。
|
||||
func NewBearer(store *oauth2.Store) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
h := c.GetHeader("Authorization")
|
||||
const prefix = "Bearer "
|
||||
if !strings.HasPrefix(h, prefix) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
raw := strings.TrimSpace(strings.TrimPrefix(h, prefix))
|
||||
if raw == "" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
p, err := store.LookupAccessToken(c.Request.Context(), raw)
|
||||
if err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
c.Set(CtxUserID, p.UserID)
|
||||
c.Set(CtxTenantID, p.TenantID)
|
||||
c.Set(CtxScope, p.Scope)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// PerIPMinute 按客户端 IP 的固定窗口速率(每分钟 perMinute 次,burst 取 perMinute 与 64 的较小值)。
|
||||
// 进程内 map 可能随 IP 数增长,多实例部署请在网关侧限流。
|
||||
func PerIPMinute(enabled bool, perMinute int) gin.HandlerFunc {
|
||||
if !enabled || perMinute <= 0 {
|
||||
return func(c *gin.Context) { c.Next() }
|
||||
}
|
||||
burst := perMinute
|
||||
if burst > 64 {
|
||||
burst = 64
|
||||
}
|
||||
if burst < 5 {
|
||||
burst = 5
|
||||
}
|
||||
lim := rate.Limit(float64(perMinute) / 60.0)
|
||||
var mu sync.Mutex
|
||||
limiters := make(map[string]*rate.Limiter)
|
||||
return func(c *gin.Context) {
|
||||
ip := clientIP(c)
|
||||
mu.Lock()
|
||||
limiter, ok := limiters[ip]
|
||||
if !ok {
|
||||
limiter = rate.NewLimiter(lim, burst)
|
||||
limiters[ip] = limiter
|
||||
}
|
||||
mu.Unlock()
|
||||
if !limiter.Allow() {
|
||||
c.AbortWithStatusJSON(429, gin.H{"error": "rate_limit_exceeded"})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func clientIP(c *gin.Context) string {
|
||||
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
|
||||
parts := strings.Split(xff, ",")
|
||||
if len(parts) > 0 {
|
||||
return strings.TrimSpace(parts[0])
|
||||
}
|
||||
}
|
||||
host, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr))
|
||||
if err != nil {
|
||||
return c.Request.RemoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package oauth2
|
||||
|
||||
import "errors"
|
||||
|
||||
// JSON 登录签发授权码时与 Authorize 对齐校验。
|
||||
var (
|
||||
ErrInvalidClient = errors.New("oauth2: invalid client_id")
|
||||
ErrInvalidRedirectURI = errors.New("oauth2: invalid redirect_uri")
|
||||
ErrPKCERequired = errors.New("oauth2: invalid code_challenge or code_challenge_method")
|
||||
)
|
||||
@@ -0,0 +1,28 @@
|
||||
package oauth2
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
// Handler 绑定 Gin 与 Service。
|
||||
type Handler struct {
|
||||
svc *Service
|
||||
}
|
||||
|
||||
// NewHandler 构造。
|
||||
func NewHandler(svc *Service) *Handler {
|
||||
return &Handler{svc: svc}
|
||||
}
|
||||
|
||||
// Authorize GET /oauth/authorize
|
||||
func (h *Handler) Authorize(c *gin.Context) {
|
||||
h.svc.Authorize(c)
|
||||
}
|
||||
|
||||
// Token POST /oauth/token
|
||||
func (h *Handler) Token(c *gin.Context) {
|
||||
h.svc.Token(c)
|
||||
}
|
||||
|
||||
// Introspect POST /oauth/introspect
|
||||
func (h *Handler) Introspect(c *gin.Context) {
|
||||
h.svc.Introspect(c)
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package oauth2
|
||||
|
||||
import "time"
|
||||
|
||||
// OAuthClient oauth_client
|
||||
type OAuthClient struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(36)"`
|
||||
ClientID string `gorm:"size:64;not null;uniqueIndex"`
|
||||
ClientSecretHash *string `gorm:"size:255"`
|
||||
RedirectURIsJSON string `gorm:"column:redirect_uris;type:text;not null"`
|
||||
IsPublic bool `gorm:"not null;default:true"`
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
}
|
||||
|
||||
func (OAuthClient) TableName() string { return "oauth_client" }
|
||||
|
||||
// OAuthAuthorizationCode oauth_authorization_code
|
||||
type OAuthAuthorizationCode struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(36)"`
|
||||
CodeHash string `gorm:"size:64;not null;uniqueIndex"`
|
||||
ClientID string `gorm:"size:64;not null"`
|
||||
RedirectURI string `gorm:"type:text;not null"`
|
||||
UserID string `gorm:"size:36;not null"`
|
||||
TenantID string `gorm:"size:36;not null"`
|
||||
Scope string `gorm:"type:text;not null"`
|
||||
CodeChallenge string `gorm:"size:128;not null"`
|
||||
CodeChallengeMethod string `gorm:"size:16;not null"`
|
||||
ExpiresAt time.Time `gorm:"not null"`
|
||||
Used bool `gorm:"not null;default:false"`
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
}
|
||||
|
||||
func (OAuthAuthorizationCode) TableName() string { return "oauth_authorization_code" }
|
||||
|
||||
// OAuthAccessToken oauth_access_token
|
||||
type OAuthAccessToken struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(36)"`
|
||||
TokenHash string `gorm:"size:64;not null;uniqueIndex"`
|
||||
ClientID string `gorm:"size:64;not null"`
|
||||
UserID string `gorm:"size:36;not null"`
|
||||
TenantID string `gorm:"size:36;not null"`
|
||||
Scope string `gorm:"type:text;not null"`
|
||||
ExpiresAt time.Time `gorm:"not null"`
|
||||
RevokedAt *time.Time `gorm:""`
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
}
|
||||
|
||||
func (OAuthAccessToken) TableName() string { return "oauth_access_token" }
|
||||
|
||||
// OAuthRefreshToken oauth_refresh_token
|
||||
type OAuthRefreshToken struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(36)"`
|
||||
TokenHash string `gorm:"size:64;not null;uniqueIndex"`
|
||||
AccessTokenID string `gorm:"size:36;not null;index"`
|
||||
ClientID string `gorm:"size:64;not null"`
|
||||
UserID string `gorm:"size:36;not null"`
|
||||
TenantID string `gorm:"size:36;not null"`
|
||||
Scope string `gorm:"type:text;not null"`
|
||||
ExpiresAt time.Time `gorm:"not null"`
|
||||
RevokedAt *time.Time `gorm:""`
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
}
|
||||
|
||||
func (OAuthRefreshToken) TableName() string { return "oauth_refresh_token" }
|
||||
@@ -0,0 +1,23 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// VerifyPKCES256 校验 code_verifier 是否与 code_challenge(S256)一致。
|
||||
func VerifyPKCES256(codeVerifier, codeChallenge string) bool {
|
||||
if codeVerifier == "" || codeChallenge == "" {
|
||||
return false
|
||||
}
|
||||
sum := sha256.Sum256([]byte(codeVerifier))
|
||||
expected := base64.RawURLEncoding.EncodeToString(sum[:])
|
||||
return subtle.ConstantTimeCompare([]byte(expected), []byte(codeChallenge)) == 1
|
||||
}
|
||||
|
||||
// NormalizeCodeChallengeMethod 返回小写方法名;仅支持 S256(OAuth 2.1 推荐)。
|
||||
func NormalizeCodeChallengeMethod(m string) string {
|
||||
return strings.TrimSpace(strings.ToLower(m))
|
||||
}
|
||||
@@ -0,0 +1,341 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"giter.top/smart/internal/auth/session"
|
||||
"giter.top/smart/pkg/config"
|
||||
"giter.top/smart/pkg/security"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Service OAuth2 授权码 + PKCE + opaque token。
|
||||
type Service struct {
|
||||
cfg *config.Config
|
||||
store *Store
|
||||
sess *session.Store
|
||||
}
|
||||
|
||||
// NewService 构造。
|
||||
func NewService(cfg *config.Config, store *Store, sess *session.Store) *Service {
|
||||
return &Service{cfg: cfg, store: store, sess: sess}
|
||||
}
|
||||
|
||||
func (s *Service) durations() (authCode, access, refresh time.Duration) {
|
||||
authCode = s.cfg.Auth.OAuth2.AuthCodeTTL
|
||||
if authCode == 0 {
|
||||
authCode = 120 * time.Second
|
||||
}
|
||||
access = s.cfg.Auth.OAuth2.AccessTokenTTL
|
||||
if access == 0 {
|
||||
access = 15 * time.Minute
|
||||
}
|
||||
refresh = s.cfg.Auth.OAuth2.RefreshTokenTTL
|
||||
if refresh == 0 {
|
||||
refresh = 720 * time.Hour
|
||||
}
|
||||
return authCode, access, refresh
|
||||
}
|
||||
|
||||
// IssueAuthorizationCodeAfterPasswordAuth 在已通过用户名密码校验的上下文中签发 PKCE 绑定授权码(与 Authorize 中 CreateAuthorizationCode 一致)。
|
||||
func (s *Service) IssueAuthorizationCodeAfterPasswordAuth(ctx context.Context, clientID, redirectURI, userID, tenantID, scope, codeChallenge, challengeMethod string) (codePlain string, err error) {
|
||||
if scope == "" {
|
||||
scope = "openid"
|
||||
}
|
||||
method := NormalizeCodeChallengeMethod(challengeMethod)
|
||||
if codeChallenge == "" || method != "s256" {
|
||||
return "", ErrPKCERequired
|
||||
}
|
||||
cli, err := s.store.GetClientByClientID(ctx, clientID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNotFound) {
|
||||
return "", ErrInvalidClient
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
uris, err := ParseRedirectURIs(cli.RedirectURIsJSON)
|
||||
if err != nil || !RedirectURIMatch(uris, redirectURI) {
|
||||
return "", ErrInvalidRedirectURI
|
||||
}
|
||||
codePlain, err = security.RandomURLSafe(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
codeTTL, _, _ := s.durations()
|
||||
exp := time.Now().Add(codeTTL)
|
||||
if err := s.store.CreateAuthorizationCode(ctx, codePlain, clientID, redirectURI, userID, tenantID, scope, codeChallenge, "S256", exp); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return codePlain, nil
|
||||
}
|
||||
|
||||
func (s *Service) publicAuthorizeURL(c *gin.Context) string {
|
||||
base := strings.TrimRight(s.cfg.Auth.PublicBaseURL, "/")
|
||||
if base == "" {
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
if xf := c.GetHeader("X-Forwarded-Proto"); xf == "https" {
|
||||
scheme = "https"
|
||||
}
|
||||
base = scheme + "://" + c.Request.Host
|
||||
}
|
||||
return base + "/oauth/authorize?" + c.Request.URL.RawQuery
|
||||
}
|
||||
|
||||
// Authorize GET /oauth/authorize
|
||||
func (s *Service) Authorize(c *gin.Context) {
|
||||
q := c.Request.URL.Query()
|
||||
if q.Get("response_type") != "code" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported_response_type"})
|
||||
return
|
||||
}
|
||||
clientID := q.Get("client_id")
|
||||
redirectURI := q.Get("redirect_uri")
|
||||
state := q.Get("state")
|
||||
scope := q.Get("scope")
|
||||
if scope == "" {
|
||||
scope = "openid"
|
||||
}
|
||||
challenge := q.Get("code_challenge")
|
||||
method := NormalizeCodeChallengeMethod(q.Get("code_challenge_method"))
|
||||
if challenge == "" || method != "s256" {
|
||||
s.redirectOAuthError(c, redirectURI, state, "invalid_request", "code_challenge and code_challenge_method=S256 required")
|
||||
return
|
||||
}
|
||||
|
||||
cli, err := s.store.GetClientByClientID(c.Request.Context(), clientID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_client"})
|
||||
return
|
||||
}
|
||||
uris, err := ParseRedirectURIs(cli.RedirectURIsJSON)
|
||||
if err != nil || !RedirectURIMatch(uris, redirectURI) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_redirect_uri"})
|
||||
return
|
||||
}
|
||||
|
||||
sid, err := c.Cookie(s.cfg.Auth.Session.CookieName)
|
||||
if err != nil || sid == "" {
|
||||
login := strings.TrimRight(s.cfg.Auth.OAuth2.FrontendLoginURL, "?")
|
||||
ret := s.publicAuthorizeURL(c)
|
||||
u, _ := url.Parse(login)
|
||||
q2 := u.Query()
|
||||
q2.Set("return_to", ret)
|
||||
u.RawQuery = q2.Encode()
|
||||
c.Redirect(http.StatusFound, u.String())
|
||||
return
|
||||
}
|
||||
userID, tenantID, err := s.sess.Get(c.Request.Context(), sid)
|
||||
if err != nil {
|
||||
login := strings.TrimRight(s.cfg.Auth.OAuth2.FrontendLoginURL, "?")
|
||||
ret := s.publicAuthorizeURL(c)
|
||||
u, _ := url.Parse(login)
|
||||
q2 := u.Query()
|
||||
q2.Set("return_to", ret)
|
||||
u.RawQuery = q2.Encode()
|
||||
c.Redirect(http.StatusFound, u.String())
|
||||
return
|
||||
}
|
||||
|
||||
codePlain, err := security.RandomURLSafe(32)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||||
return
|
||||
}
|
||||
codeTTL, _, _ := s.durations()
|
||||
exp := time.Now().Add(codeTTL)
|
||||
if err := s.store.CreateAuthorizationCode(c.Request.Context(), codePlain, clientID, redirectURI, userID, tenantID, scope, challenge, "S256", exp); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||||
return
|
||||
}
|
||||
redir, err := url.Parse(redirectURI)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_redirect_uri"})
|
||||
return
|
||||
}
|
||||
rq := redir.Query()
|
||||
rq.Set("code", codePlain)
|
||||
if state != "" {
|
||||
rq.Set("state", state)
|
||||
}
|
||||
redir.RawQuery = rq.Encode()
|
||||
c.Redirect(http.StatusFound, redir.String())
|
||||
}
|
||||
|
||||
func (s *Service) redirectOAuthError(c *gin.Context, redirectURI, state, errCode, desc string) {
|
||||
if redirectURI == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": errCode, "error_description": desc})
|
||||
return
|
||||
}
|
||||
u, e := url.Parse(redirectURI)
|
||||
if e != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": errCode})
|
||||
return
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("error", errCode)
|
||||
q.Set("error_description", desc)
|
||||
if state != "" {
|
||||
q.Set("state", state)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
c.Redirect(http.StatusFound, u.String())
|
||||
}
|
||||
|
||||
// Token POST /oauth/token
|
||||
func (s *Service) Token(c *gin.Context) {
|
||||
if err := c.Request.ParseForm(); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request"})
|
||||
return
|
||||
}
|
||||
gt := c.PostForm("grant_type")
|
||||
switch gt {
|
||||
case "authorization_code":
|
||||
s.tokenAuthorizationCode(c)
|
||||
case "refresh_token":
|
||||
s.tokenRefresh(c)
|
||||
default:
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported_grant_type"})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) tokenAuthorizationCode(c *gin.Context) {
|
||||
code := c.PostForm("code")
|
||||
redirectURI := c.PostForm("redirect_uri")
|
||||
clientID := c.PostForm("client_id")
|
||||
verifier := c.PostForm("code_verifier")
|
||||
if code == "" || redirectURI == "" || clientID == "" || verifier == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request"})
|
||||
return
|
||||
}
|
||||
row, err := s.store.ConsumeAuthorizationCode(c.Request.Context(), code, clientID, redirectURI)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_grant"})
|
||||
return
|
||||
}
|
||||
if !VerifyPKCES256(verifier, row.CodeChallenge) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_grant", "error_description": "pkce verification failed"})
|
||||
return
|
||||
}
|
||||
_, accessTTL, refreshTTL := s.durations()
|
||||
accessPlain, err := security.RandomURLSafe(32)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||||
return
|
||||
}
|
||||
refreshPlain, err := security.RandomURLSafe(48)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||||
return
|
||||
}
|
||||
if err := s.store.IssueAccessAndRefresh(c.Request.Context(), accessPlain, refreshPlain, clientID, row.UserID, row.TenantID, row.Scope, accessTTL, refreshTTL); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||||
return
|
||||
}
|
||||
slog.Info("oauth2_token_issued", "grant_type", "authorization_code", "client_id", clientID, "user_id", row.UserID, "tenant_id", row.TenantID, "client_ip", c.ClientIP())
|
||||
s.jsonAccessToken(c, accessPlain, refreshPlain, accessTTL)
|
||||
}
|
||||
|
||||
func (s *Service) tokenRefresh(c *gin.Context) {
|
||||
refresh := c.PostForm("refresh_token")
|
||||
clientID := c.PostForm("client_id")
|
||||
if refresh == "" || clientID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request"})
|
||||
return
|
||||
}
|
||||
_, accessTTL, refreshTTL := s.durations()
|
||||
newAccess, err := security.RandomURLSafe(32)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||||
return
|
||||
}
|
||||
newRefresh, err := security.RandomURLSafe(48)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||||
return
|
||||
}
|
||||
if err := s.store.RotateByRefreshToken(c.Request.Context(), clientID, refresh, newAccess, newRefresh, accessTTL, refreshTTL); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_grant"})
|
||||
return
|
||||
}
|
||||
slog.Info("oauth2_token_issued", "grant_type", "refresh_token", "client_id", clientID, "client_ip", c.ClientIP())
|
||||
s.jsonAccessToken(c, newAccess, newRefresh, accessTTL)
|
||||
}
|
||||
|
||||
func (s *Service) jsonAccessToken(c *gin.Context, access, refresh string, accessTTL time.Duration) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"access_token": access,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": int(accessTTL.Seconds()),
|
||||
"refresh_token": refresh,
|
||||
})
|
||||
}
|
||||
|
||||
// Introspect POST /oauth/introspect(RFC 7662),与 opaque 查表语义一致。
|
||||
func (s *Service) Introspect(c *gin.Context) {
|
||||
if err := c.Request.ParseForm(); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"active": false})
|
||||
return
|
||||
}
|
||||
tok := c.PostForm("token")
|
||||
hint := strings.TrimSpace(c.PostForm("token_type_hint"))
|
||||
if tok == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"active": false})
|
||||
return
|
||||
}
|
||||
ctx := c.Request.Context()
|
||||
|
||||
tryRefreshFirst := hint == "refresh_token"
|
||||
if tryRefreshFirst {
|
||||
if row, err := s.store.LookupRefreshTokenRow(ctx, tok); err == nil {
|
||||
slog.Info("oauth2_introspect", "active", true, "token_type", "refresh_token", "client_id", row.ClientID, "sub", row.UserID)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"active": true,
|
||||
"scope": row.Scope,
|
||||
"client_id": row.ClientID,
|
||||
"token_type": "refresh_token",
|
||||
"sub": row.UserID,
|
||||
"exp": row.ExpiresAt.Unix(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"active": false})
|
||||
return
|
||||
}
|
||||
|
||||
if row, err := s.store.LookupAccessTokenRow(ctx, tok); err == nil {
|
||||
slog.Info("oauth2_introspect", "active", true, "token_type", "access_token", "client_id", row.ClientID, "sub", row.UserID)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"active": true,
|
||||
"scope": row.Scope,
|
||||
"client_id": row.ClientID,
|
||||
"token_type": "access_token",
|
||||
"sub": row.UserID,
|
||||
"exp": row.ExpiresAt.Unix(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if row, err := s.store.LookupRefreshTokenRow(ctx, tok); err == nil {
|
||||
slog.Info("oauth2_introspect", "active", true, "token_type", "refresh_token", "client_id", row.ClientID, "sub", row.UserID)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"active": true,
|
||||
"scope": row.Scope,
|
||||
"client_id": row.ClientID,
|
||||
"token_type": "refresh_token",
|
||||
"sub": row.UserID,
|
||||
"exp": row.ExpiresAt.Unix(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"active": false})
|
||||
}
|
||||
@@ -0,0 +1,265 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"giter.top/smart/pkg/utils/id"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// ErrNotFound 未找到记录。
|
||||
var ErrNotFound = errors.New("oauth2: not found")
|
||||
|
||||
func hashToken(raw string) string {
|
||||
sum := sha256.Sum256([]byte(raw))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// Store OAuth 持久化。
|
||||
type Store struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewStore 创建 Store。
|
||||
func NewStore(db *gorm.DB) *Store {
|
||||
return &Store{db: db}
|
||||
}
|
||||
|
||||
// GetClientByClientID 按 client_id 查客户端。
|
||||
func (st *Store) GetClientByClientID(ctx context.Context, clientID string) (*OAuthClient, error) {
|
||||
var row OAuthClient
|
||||
err := st.db.WithContext(ctx).Where("client_id = ?", clientID).First(&row).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &row, nil
|
||||
}
|
||||
|
||||
// ParseRedirectURIs 解析 redirect_uris JSON 数组。
|
||||
func ParseRedirectURIs(raw string) ([]string, error) {
|
||||
var uris []string
|
||||
if raw == "" {
|
||||
return nil, errors.New("empty redirect_uris")
|
||||
}
|
||||
if err := json.Unmarshal([]byte(raw), &uris); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return uris, nil
|
||||
}
|
||||
|
||||
// RedirectURIMatch OAuth 2.1 精确匹配。
|
||||
func RedirectURIMatch(allowed []string, u string) bool {
|
||||
for _, x := range allowed {
|
||||
if x == u {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CreateAuthorizationCode 写入授权码(code 明文仅返回给调用方,库存哈希)。
|
||||
func (st *Store) CreateAuthorizationCode(ctx context.Context, codePlain string, clientID, redirectURI, userID, tenantID, scope, challenge, method string, expiresAt time.Time) error {
|
||||
row := OAuthAuthorizationCode{
|
||||
ID: id.New(),
|
||||
CodeHash: hashToken(codePlain),
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
UserID: userID,
|
||||
TenantID: tenantID,
|
||||
Scope: scope,
|
||||
CodeChallenge: challenge,
|
||||
CodeChallengeMethod: method,
|
||||
ExpiresAt: expiresAt,
|
||||
Used: false,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
return st.db.WithContext(ctx).Create(&row).Error
|
||||
}
|
||||
|
||||
// ConsumeAuthorizationCode 校验并一次性消费授权码,返回行数据供发 token。
|
||||
func (st *Store) ConsumeAuthorizationCode(ctx context.Context, codePlain, clientID, redirectURI string) (*OAuthAuthorizationCode, error) {
|
||||
h := hashToken(codePlain)
|
||||
var out *OAuthAuthorizationCode
|
||||
err := st.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
var row OAuthAuthorizationCode
|
||||
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("code_hash = ?", h).First(&row).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
if row.Used {
|
||||
return ErrNotFound
|
||||
}
|
||||
if time.Now().After(row.ExpiresAt) {
|
||||
return ErrNotFound
|
||||
}
|
||||
if row.ClientID != clientID || row.RedirectURI != redirectURI {
|
||||
return ErrNotFound
|
||||
}
|
||||
if err := tx.Model(&OAuthAuthorizationCode{}).Where("id = ?", row.ID).Update("used", true).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
out = &row
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// TokenPrincipal opaque access token 解析结果。
|
||||
type TokenPrincipal struct {
|
||||
UserID string
|
||||
TenantID string
|
||||
Scope string
|
||||
}
|
||||
|
||||
// LookupAccessToken 按明文 access token 查有效记录。
|
||||
func (st *Store) LookupAccessToken(ctx context.Context, raw string) (*TokenPrincipal, error) {
|
||||
h := hashToken(raw)
|
||||
var row OAuthAccessToken
|
||||
err := st.db.WithContext(ctx).Where("token_hash = ? AND revoked_at IS NULL", h).First(&row).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if time.Now().After(row.ExpiresAt) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return &TokenPrincipal{UserID: row.UserID, TenantID: row.TenantID, Scope: row.Scope}, nil
|
||||
}
|
||||
|
||||
// LookupAccessTokenRow 按明文查 access token 行(自省用)。
|
||||
func (st *Store) LookupAccessTokenRow(ctx context.Context, raw string) (*OAuthAccessToken, error) {
|
||||
h := hashToken(raw)
|
||||
var row OAuthAccessToken
|
||||
err := st.db.WithContext(ctx).Where("token_hash = ? AND revoked_at IS NULL", h).First(&row).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if time.Now().After(row.ExpiresAt) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return &row, nil
|
||||
}
|
||||
|
||||
// LookupRefreshTokenRow 按明文查 refresh token 行(自省用)。
|
||||
func (st *Store) LookupRefreshTokenRow(ctx context.Context, raw string) (*OAuthRefreshToken, error) {
|
||||
h := hashToken(raw)
|
||||
var row OAuthRefreshToken
|
||||
err := st.db.WithContext(ctx).Where("token_hash = ? AND revoked_at IS NULL", h).First(&row).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if time.Now().After(row.ExpiresAt) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return &row, nil
|
||||
}
|
||||
|
||||
// IssueAccessAndRefresh 写入 access + refresh(opaque 明文仅调用方返回给客户端)。
|
||||
func (st *Store) IssueAccessAndRefresh(ctx context.Context, accessPlain, refreshPlain, clientID, userID, tenantID, scope string, accessTTL, refreshTTL time.Duration) error {
|
||||
now := time.Now()
|
||||
accessID := id.New()
|
||||
refreshID := id.New()
|
||||
at := OAuthAccessToken{
|
||||
ID: accessID,
|
||||
TokenHash: hashToken(accessPlain),
|
||||
ClientID: clientID,
|
||||
UserID: userID,
|
||||
TenantID: tenantID,
|
||||
Scope: scope,
|
||||
ExpiresAt: now.Add(accessTTL),
|
||||
CreatedAt: now,
|
||||
}
|
||||
rt := OAuthRefreshToken{
|
||||
ID: refreshID,
|
||||
TokenHash: hashToken(refreshPlain),
|
||||
AccessTokenID: accessID,
|
||||
ClientID: clientID,
|
||||
UserID: userID,
|
||||
TenantID: tenantID,
|
||||
Scope: scope,
|
||||
ExpiresAt: now.Add(refreshTTL),
|
||||
CreatedAt: now,
|
||||
}
|
||||
return st.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Create(&at).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Create(&rt).Error
|
||||
})
|
||||
}
|
||||
|
||||
// RotateByRefreshToken 使用 refresh 换发新 access+refresh,旧令牌作废;client_id 须与注册一致。
|
||||
func (st *Store) RotateByRefreshToken(ctx context.Context, clientID, refreshPlain, newAccessPlain, newRefreshPlain string, accessTTL, refreshTTL time.Duration) error {
|
||||
h := hashToken(refreshPlain)
|
||||
return st.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
var row OAuthRefreshToken
|
||||
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("token_hash = ? AND revoked_at IS NULL", h).First(&row).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
if row.ClientID != clientID {
|
||||
return ErrNotFound
|
||||
}
|
||||
if time.Now().After(row.ExpiresAt) {
|
||||
return ErrNotFound
|
||||
}
|
||||
now := time.Now()
|
||||
if err := tx.Model(&OAuthRefreshToken{}).Where("id = ?", row.ID).Update("revoked_at", now).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Model(&OAuthAccessToken{}).Where("id = ?", row.AccessTokenID).Update("revoked_at", now).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
newAID := id.New()
|
||||
newRID := id.New()
|
||||
at := OAuthAccessToken{
|
||||
ID: newAID,
|
||||
TokenHash: hashToken(newAccessPlain),
|
||||
ClientID: row.ClientID,
|
||||
UserID: row.UserID,
|
||||
TenantID: row.TenantID,
|
||||
Scope: row.Scope,
|
||||
ExpiresAt: now.Add(accessTTL),
|
||||
CreatedAt: now,
|
||||
}
|
||||
rt := OAuthRefreshToken{
|
||||
ID: newRID,
|
||||
TokenHash: hashToken(newRefreshPlain),
|
||||
AccessTokenID: newAID,
|
||||
ClientID: row.ClientID,
|
||||
UserID: row.UserID,
|
||||
TenantID: row.TenantID,
|
||||
Scope: row.Scope,
|
||||
ExpiresAt: now.Add(refreshTTL),
|
||||
CreatedAt: now,
|
||||
}
|
||||
if err := tx.Create(&at).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Create(&rt).Error
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package scope
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Split 将空格分隔的 scope 拆成列表。
|
||||
func Split(scope string) []string {
|
||||
if strings.TrimSpace(scope) == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Fields(scope)
|
||||
out := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
out = append(out, p)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Contains 判断 scope 字符串是否包含指定权限标记。
|
||||
func Contains(scope, want string) bool {
|
||||
if want == "" {
|
||||
return true
|
||||
}
|
||||
for _, s := range Split(scope) {
|
||||
if s == want {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasAPIAccess 约定含 `api` 或 `api.*` 前缀即表示可访问业务 API(可与 IAM 菜单权限组合使用)。
|
||||
func HasAPIAccess(scope string) bool {
|
||||
for _, s := range Split(scope) {
|
||||
if s == "api" || strings.HasPrefix(s, "api.") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"giter.top/smart/pkg/config"
|
||||
"giter.top/smart/pkg/security"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// ErrInvalidSession 会话不存在或已过期。
|
||||
var ErrInvalidSession = errors.New("session: invalid or expired")
|
||||
|
||||
const redisKeyPrefix = "auth:sess:"
|
||||
|
||||
type payload struct {
|
||||
UserID string `json:"user_id"`
|
||||
TenantID string `json:"tenant_id"`
|
||||
}
|
||||
|
||||
// Store Redis 会话(供 OAuth authorize 与登出)。
|
||||
type Store struct {
|
||||
rdb redis.UniversalClient
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewStore 创建会话存储。
|
||||
func NewStore(rdb redis.UniversalClient, cfg *config.Config) *Store {
|
||||
return &Store{rdb: rdb, cfg: cfg}
|
||||
}
|
||||
|
||||
func (s *Store) ttl() time.Duration {
|
||||
t := s.cfg.Auth.Session.TTL
|
||||
if t == 0 {
|
||||
return 168 * time.Hour
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// Create 创建会话并返回 session id(写入 Cookie 用)。
|
||||
func (s *Store) Create(ctx context.Context, userID, tenantID string) (sid string, err error) {
|
||||
sid, err = security.RandomURLSafe(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
b, err := json.Marshal(payload{UserID: userID, TenantID: tenantID})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return sid, s.rdb.Set(ctx, redisKeyPrefix+sid, b, s.ttl()).Err()
|
||||
}
|
||||
|
||||
// Get 解析会话。
|
||||
func (s *Store) Get(ctx context.Context, sid string) (userID, tenantID string, err error) {
|
||||
b, err := s.rdb.Get(ctx, redisKeyPrefix+sid).Bytes()
|
||||
if err == redis.Nil {
|
||||
return "", "", ErrInvalidSession
|
||||
}
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
var p payload
|
||||
if err := json.Unmarshal(b, &p); err != nil {
|
||||
return "", "", ErrInvalidSession
|
||||
}
|
||||
return p.UserID, p.TenantID, nil
|
||||
}
|
||||
|
||||
// Delete 登出时删除。
|
||||
func (s *Store) Delete(ctx context.Context, sid string) error {
|
||||
return s.rdb.Del(ctx, redisKeyPrefix+sid).Err()
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"giter.top/smart/internal/auth/handler"
|
||||
"giter.top/smart/internal/auth/middleware"
|
||||
"giter.top/smart/internal/auth/oauth2"
|
||||
"giter.top/smart/internal/auth/session"
|
||||
"giter.top/smart/pkg/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
// ProviderSet Wire 注入。
|
||||
var ProviderSet = wire.NewSet(
|
||||
session.NewStore,
|
||||
oauth2.NewStore,
|
||||
oauth2.NewService,
|
||||
oauth2.NewHandler,
|
||||
handler.NewLoginHandler,
|
||||
ProvideBearer,
|
||||
ProvideLoginRLimitWire,
|
||||
ProvideTokenRLimitWire,
|
||||
NewAuthRoutes,
|
||||
)
|
||||
|
||||
// ProvideBearer 提供 Gin 中间件。
|
||||
func ProvideBearer(store *oauth2.Store) gin.HandlerFunc {
|
||||
return middleware.NewBearer(store)
|
||||
}
|
||||
|
||||
// LoginRateLimitWire、TokenRateLimitWire 用于 Wire 区分多个 gin.HandlerFunc 形参。
|
||||
type LoginRateLimitWire gin.HandlerFunc
|
||||
type TokenRateLimitWire gin.HandlerFunc
|
||||
|
||||
// ProvideLoginRLimitWire 登录接口限流。
|
||||
func ProvideLoginRLimitWire(cfg *config.Config) LoginRateLimitWire {
|
||||
return LoginRateLimitWire(middleware.PerIPMinute(cfg.Auth.RateLimit.Enabled, cfg.Auth.RateLimit.LoginPerMinute))
|
||||
}
|
||||
|
||||
// ProvideTokenRLimitWire 令牌与自省端点限流。
|
||||
func ProvideTokenRLimitWire(cfg *config.Config) TokenRateLimitWire {
|
||||
return TokenRateLimitWire(middleware.PerIPMinute(cfg.Auth.RateLimit.Enabled, cfg.Auth.RateLimit.TokenPerMinute))
|
||||
}
|
||||
Reference in New Issue
Block a user