package oauth2

import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"time"

	"giter.top/smart/pkg/utils/id"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

// ErrNotFound is returned when a requested record does not exist or cannot be
// used: an unknown client; an authorization code that is already used, expired
// or bound to a different client/redirect URI; or an access/refresh token that
// is revoked, expired or (for rotation) bound to a different client.
var ErrNotFound = errors.New("oauth2: not found")

// hashToken returns the hex-encoded SHA-256 digest of raw. Codes and tokens are
// persisted only in this hashed form; the plaintext is never stored.
func hashToken(raw string) string {
	sum := sha256.Sum256([]byte(raw))
	return hex.EncodeToString(sum[:])
}

// mapNotFound translates GORM's "record not found" into the package-level
// ErrNotFound and passes every other error (including nil) through unchanged.
func mapNotFound(err error) error {
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return ErrNotFound
	}
	return err
}

// Store is the OAuth persistence layer (clients, authorization codes, tokens).
type Store struct {
	db *gorm.DB
}

// NewStore returns a Store backed by db.
func NewStore(db *gorm.DB) *Store {
	return &Store{db: db}
}

// GetClientByClientID looks up a registered client by its client_id.
// It returns ErrNotFound when no such client exists.
func (st *Store) GetClientByClientID(ctx context.Context, clientID string) (*OAuthClient, error) {
	var row OAuthClient
	if err := st.db.WithContext(ctx).Where("client_id = ?", clientID).First(&row).Error; err != nil {
		return nil, mapNotFound(err)
	}
	return &row, nil
}

// ParseRedirectURIs decodes a client's redirect_uris value, stored as a JSON
// array of strings. An empty input is rejected; JSON decoding errors are
// returned unchanged.
func ParseRedirectURIs(raw string) ([]string, error) {
	if raw == "" {
		return nil, errors.New("empty redirect_uris")
	}
	var uris []string
	if err := json.Unmarshal([]byte(raw), &uris); err != nil {
		return nil, err
	}
	return uris, nil
}

// RedirectURIMatch reports whether u is exactly equal to one of the allowed
// redirect URIs. OAuth 2.1 requires exact string matching, so no prefix,
// wildcard or normalization is applied.
func RedirectURIMatch(allowed []string, u string) bool {
	for _, x := range allowed {
		if x == u {
			return true
		}
	}
	return false
}

// CreateAuthorizationCode stores a new, unused authorization code. Only the
// SHA-256 hash of codePlain is persisted; the plaintext stays with the caller,
// which hands it to the client. challenge/method hold the PKCE parameters.
func (st *Store) CreateAuthorizationCode(ctx context.Context, codePlain string, clientID, redirectURI, userID, tenantID, scope, challenge, method string, expiresAt time.Time) error {
	row := OAuthAuthorizationCode{
		ID:                  id.New(),
		CodeHash:            hashToken(codePlain),
		ClientID:            clientID,
		RedirectURI:         redirectURI,
		UserID:              userID,
		TenantID:            tenantID,
		Scope:               scope,
		CodeChallenge:       challenge,
		CodeChallengeMethod: method,
		ExpiresAt:           expiresAt,
		Used:                false,
		CreatedAt:           time.Now(),
	}
	return st.db.WithContext(ctx).Create(&row).Error
}

// ConsumeAuthorizationCode validates codePlain and marks it used, so that a
// code can be exchanged at most once. The code must be unused, unexpired and
// bound to the given clientID and redirectURI; otherwise ErrNotFound is
// returned. On success it returns the stored row (as read before it was marked
// used) so the caller can issue tokens and verify PKCE.
func (st *Store) ConsumeAuthorizationCode(ctx context.Context, codePlain, clientID, redirectURI string) (*OAuthAuthorizationCode, error) {
	h := hashToken(codePlain)
	var out *OAuthAuthorizationCode
	err := st.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		var row OAuthAuthorizationCode
		if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("code_hash = ?", h).First(&row).Error; err != nil {
			return mapNotFound(err)
		}
		if row.Used || time.Now().After(row.ExpiresAt) {
			return ErrNotFound
		}
		if row.ClientID != clientID || row.RedirectURI != redirectURI {
			return ErrNotFound
		}
		// Compare-and-set: only flip used while it is still false, and require
		// that this statement actually changed the row. Single use then does not
		// depend solely on the row lock above, which a dialect may not support.
		res := tx.Model(&OAuthAuthorizationCode{}).
			Where("id = ? AND used = ?", row.ID, false).
			Update("used", true)
		if res.Error != nil {
			return res.Error
		}
		if res.RowsAffected == 0 {
			// A concurrent exchange consumed the code first.
			return ErrNotFound
		}
		out = &row
		return nil
	})
	if err != nil {
		return nil, err
	}
	return out, nil
}

// TokenPrincipal is the identity resolved from an opaque access token.
type TokenPrincipal struct {
	UserID   string
	TenantID string
	Scope    string
}

// LookupAccessToken resolves a plaintext access token to its principal.
// It returns ErrNotFound when the token is unknown, revoked or expired.
func (st *Store) LookupAccessToken(ctx context.Context, raw string) (*TokenPrincipal, error) {
	row, err := st.LookupAccessTokenRow(ctx, raw)
	if err != nil {
		return nil, err
	}
	return &TokenPrincipal{UserID: row.UserID, TenantID: row.TenantID, Scope: row.Scope}, nil
}

// LookupAccessTokenRow returns the stored row for a plaintext access token
// (used for introspection). It returns ErrNotFound when the token is unknown,
// revoked or expired.
func (st *Store) LookupAccessTokenRow(ctx context.Context, raw string) (*OAuthAccessToken, error) {
	var row OAuthAccessToken
	err := st.db.WithContext(ctx).Where("token_hash = ? AND revoked_at IS NULL", hashToken(raw)).First(&row).Error
	if err != nil {
		return nil, mapNotFound(err)
	}
	if time.Now().After(row.ExpiresAt) {
		return nil, ErrNotFound
	}
	return &row, nil
}

// LookupRefreshTokenRow returns the stored row for a plaintext refresh token
// (used for introspection). It returns ErrNotFound when the token is unknown,
// revoked or expired.
func (st *Store) LookupRefreshTokenRow(ctx context.Context, raw string) (*OAuthRefreshToken, error) {
	var row OAuthRefreshToken
	err := st.db.WithContext(ctx).Where("token_hash = ? AND revoked_at IS NULL", hashToken(raw)).First(&row).Error
	if err != nil {
		return nil, mapNotFound(err)
	}
	if time.Now().After(row.ExpiresAt) {
		return nil, ErrNotFound
	}
	return &row, nil
}

// newTokenPair builds the access and refresh rows for one issuance, both
// created at now. Only hashes of the plaintext tokens are stored, and the
// refresh row references the access row through AccessTokenID.
func newTokenPair(accessPlain, refreshPlain, clientID, userID, tenantID, scope string, now time.Time, accessTTL, refreshTTL time.Duration) (OAuthAccessToken, OAuthRefreshToken) {
	accessID := id.New()
	refreshID := id.New()
	at := OAuthAccessToken{
		ID:        accessID,
		TokenHash: hashToken(accessPlain),
		ClientID:  clientID,
		UserID:    userID,
		TenantID:  tenantID,
		Scope:     scope,
		ExpiresAt: now.Add(accessTTL),
		CreatedAt: now,
	}
	rt := OAuthRefreshToken{
		ID:            refreshID,
		TokenHash:     hashToken(refreshPlain),
		AccessTokenID: accessID,
		ClientID:      clientID,
		UserID:        userID,
		TenantID:      tenantID,
		Scope:         scope,
		ExpiresAt:     now.Add(refreshTTL),
		CreatedAt:     now,
	}
	return at, rt
}

// createTokenPair inserts the access row and then the refresh row that
// references it, using tx (expected to be an open transaction).
func createTokenPair(tx *gorm.DB, at *OAuthAccessToken, rt *OAuthRefreshToken) error {
	if err := tx.Create(at).Error; err != nil {
		return err
	}
	return tx.Create(rt).Error
}

// IssueAccessAndRefresh atomically stores a new access/refresh token pair.
// The opaque plaintext tokens are only hashed here; returning them to the
// client is the caller's job.
func (st *Store) IssueAccessAndRefresh(ctx context.Context, accessPlain, refreshPlain, clientID, userID, tenantID, scope string, accessTTL, refreshTTL time.Duration) error {
	at, rt := newTokenPair(accessPlain, refreshPlain, clientID, userID, tenantID, scope, time.Now(), accessTTL, refreshTTL)
	return st.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		return createTokenPair(tx, &at, &rt)
	})
}

// RotateByRefreshToken exchanges a refresh token for a new access/refresh
// pair. The presented refresh token and the access token it was issued with
// are revoked. The refresh token must be unrevoked, unexpired and registered
// to clientID; otherwise ErrNotFound is returned. The new pair inherits the
// old token's client, user, tenant and scope.
func (st *Store) RotateByRefreshToken(ctx context.Context, clientID, refreshPlain, newAccessPlain, newRefreshPlain string, accessTTL, refreshTTL time.Duration) error {
	h := hashToken(refreshPlain)
	return st.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		var row OAuthRefreshToken
		if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("token_hash = ? AND revoked_at IS NULL", h).First(&row).Error; err != nil {
			return mapNotFound(err)
		}
		// One clock reading for both the expiry check and the revocation time.
		now := time.Now()
		if row.ClientID != clientID || now.After(row.ExpiresAt) {
			return ErrNotFound
		}
		// Compare-and-set revocation: only the request that actually revokes
		// the still-active refresh token may mint a new pair, even if the row
		// lock above is not honored by the dialect.
		res := tx.Model(&OAuthRefreshToken{}).
			Where("id = ? AND revoked_at IS NULL", row.ID).
			Update("revoked_at", now)
		if res.Error != nil {
			return res.Error
		}
		if res.RowsAffected == 0 {
			// A concurrent rotation already used this refresh token.
			return ErrNotFound
		}
		if err := tx.Model(&OAuthAccessToken{}).Where("id = ?", row.AccessTokenID).Update("revoked_at", now).Error; err != nil {
			return err
		}
		at, rt := newTokenPair(newAccessPlain, newRefreshPlain, row.ClientID, row.UserID, row.TenantID, row.Scope, now, accessTTL, refreshTTL)
		return createTokenPair(tx, &at, &rt)
	})
}