Files
2026-04-23 18:58:13 +08:00

266 lines
7.6 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package oauth2
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"time"
"giter.top/smart/pkg/utils/id"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// ErrNotFound is returned by Store lookups when no usable record exists.
//
// Besides genuinely missing rows, Store also returns it for records that exist
// but must not be used (already consumed, revoked, expired, or bound to a
// different client/redirect URI), so callers see one uniform error.
var (
	ErrNotFound = errors.New("oauth2: not found")
)
// hashToken returns the lowercase hex-encoded SHA-256 digest of raw.
// Codes and tokens are persisted only in this hashed form; lookups hash the
// presented plaintext the same way before querying.
func hashToken(raw string) string {
	digest := sha256.New()
	digest.Write([]byte(raw)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(digest.Sum(nil))
}
// Store is the OAuth persistence layer: it reads registered clients and
// stores/looks up authorization codes and access/refresh tokens via GORM.
type Store struct {
	// db is the shared GORM handle; every method scopes it with WithContext.
	db *gorm.DB
}

// NewStore returns a Store that uses db for all queries.
func NewStore(db *gorm.DB) *Store {
	s := new(Store)
	s.db = db
	return s
}
// GetClientByClientID loads the registered client whose client_id equals
// clientID.
//
// It returns ErrNotFound when no such client exists; any other database error
// is returned unchanged.
func (st *Store) GetClientByClientID(ctx context.Context, clientID string) (*OAuthClient, error) {
	client := new(OAuthClient)
	res := st.db.WithContext(ctx).Where("client_id = ?", clientID).First(client)
	switch {
	case errors.Is(res.Error, gorm.ErrRecordNotFound):
		return nil, ErrNotFound
	case res.Error != nil:
		return nil, res.Error
	}
	return client, nil
}
// ParseRedirectURIs decodes a client's redirect_uris value, which is stored as
// a JSON array of strings (for example `["https://app.example/cb"]`).
//
// An empty string, a JSON `null`, and an empty array all mean "no redirect URI
// registered". All three are rejected with the same "empty redirect_uris"
// error, so a caller never receives an empty allow-list together with a nil
// error. Input that is not a JSON array of strings returns the json decoder's
// error.
func ParseRedirectURIs(raw string) ([]string, error) {
	if raw == "" {
		return nil, errors.New("empty redirect_uris")
	}
	var uris []string
	if err := json.Unmarshal([]byte(raw), &uris); err != nil {
		return nil, err
	}
	// json.Unmarshal accepts `null` (leaving uris nil) and `[]` without error;
	// treat them like the empty-string case instead of silently succeeding.
	if len(uris) == 0 {
		return nil, errors.New("empty redirect_uris")
	}
	return uris, nil
}
// RedirectURIMatch reports whether u is byte-for-byte equal to one of the
// allowed redirect URIs. OAuth 2.1 requires exact string matching, so neither
// side is normalised (no case folding, trailing-slash or query handling).
func RedirectURIMatch(allowed []string, u string) bool {
	for i := range allowed {
		if allowed[i] == u {
			return true
		}
	}
	return false
}
// CreateAuthorizationCode persists a newly issued authorization code.
//
// Only hashToken(codePlain) is stored. The plaintext code is never written to
// the database, so the caller is responsible for handing it to the client. The
// row records the client, redirect URI, user, tenant, scope, the
// code_challenge / code_challenge_method pair (PKCE) and the expiry. It starts
// out unused.
func (st *Store) CreateAuthorizationCode(ctx context.Context, codePlain string, clientID, redirectURI, userID, tenantID, scope, challenge, method string, expiresAt time.Time) error {
	code := &OAuthAuthorizationCode{
		ID:                  id.New(),
		CodeHash:            hashToken(codePlain),
		ClientID:            clientID,
		RedirectURI:         redirectURI,
		UserID:              userID,
		TenantID:            tenantID,
		Scope:               scope,
		CodeChallenge:       challenge,
		CodeChallengeMethod: method,
		ExpiresAt:           expiresAt,
		Used:                false,
		CreatedAt:           time.Now(),
	}
	return st.db.WithContext(ctx).Create(code).Error
}
// ConsumeAuthorizationCode validates an authorization code and marks it as
// used inside one transaction. It returns the stored row so the caller can
// issue tokens from it.
//
// codePlain is hashed with hashToken before lookup. Every rejection returns
// ErrNotFound, so callers cannot tell the cases apart:
//   - unknown code
//   - code already used
//   - code expired
//   - client_id or redirect_uri differs from what was recorded
//
// Database failures are returned unchanged.
func (st *Store) ConsumeAuthorizationCode(ctx context.Context, codePlain, clientID, redirectURI string) (*OAuthAuthorizationCode, error) {
	h := hashToken(codePlain)
	var out *OAuthAuthorizationCode
	err := st.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		var row OAuthAuthorizationCode
		// SELECT ... FOR UPDATE serialises concurrent redemptions of the same
		// code where the database supports row locks.
		if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("code_hash = ?", h).First(&row).Error; err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				return ErrNotFound
			}
			return err
		}
		if row.Used {
			return ErrNotFound
		}
		if time.Now().After(row.ExpiresAt) {
			return ErrNotFound
		}
		if row.ClientID != clientID || row.RedirectURI != redirectURI {
			return ErrNotFound
		}
		// Compare-and-set on used = false: only one transaction can flip the
		// flag, so single use does not rely solely on the row lock above.
		res := tx.Model(&OAuthAuthorizationCode{}).
			Where("id = ? AND used = ?", row.ID, false).
			Update("used", true)
		if res.Error != nil {
			return res.Error
		}
		if res.RowsAffected == 0 {
			// Another redemption won the race; treat this one as unusable.
			return ErrNotFound
		}
		// Keep the returned copy consistent with what was just persisted.
		row.Used = true
		out = &row
		return nil
	})
	if err != nil {
		return nil, err
	}
	return out, nil
}
// TokenPrincipal is the identity an opaque access token resolves to: the user
// and tenant it was issued for, plus the granted scope string.
type TokenPrincipal struct {
	UserID, TenantID, Scope string
}
// LookupAccessToken resolves a plaintext (opaque) access token to the
// principal it was issued for.
//
// It delegates to LookupAccessTokenRow, so token authentication and
// introspection apply exactly the same validity rules:
//   - the hash matches
//   - the token is not revoked
//   - the token is not past ExpiresAt
//
// Returns ErrNotFound for unknown, revoked, or expired tokens. Database errors
// are returned unchanged.
func (st *Store) LookupAccessToken(ctx context.Context, raw string) (*TokenPrincipal, error) {
	row, err := st.LookupAccessTokenRow(ctx, raw)
	if err != nil {
		return nil, err
	}
	return &TokenPrincipal{UserID: row.UserID, TenantID: row.TenantID, Scope: row.Scope}, nil
}
// LookupAccessTokenRow returns the full access-token row for a plaintext token
// (used for introspection).
//
// The token must match by hash, have no revoked_at, and not be past
// ExpiresAt. Otherwise ErrNotFound is returned. Database errors are returned
// unchanged.
func (st *Store) LookupAccessTokenRow(ctx context.Context, raw string) (*OAuthAccessToken, error) {
	tok := new(OAuthAccessToken)
	q := st.db.WithContext(ctx).Where("token_hash = ? AND revoked_at IS NULL", hashToken(raw))
	switch err := q.First(tok).Error; {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, ErrNotFound
	case err != nil:
		return nil, err
	case time.Now().After(tok.ExpiresAt):
		return nil, ErrNotFound
	}
	return tok, nil
}
// LookupRefreshTokenRow returns the full refresh-token row for a plaintext
// token (used for introspection).
//
// Only a non-revoked, unexpired token matching by hash is returned. Every
// other outcome yields ErrNotFound, except database errors, which are passed
// through.
func (st *Store) LookupRefreshTokenRow(ctx context.Context, raw string) (*OAuthRefreshToken, error) {
	var tok OAuthRefreshToken
	res := st.db.WithContext(ctx).
		Where("token_hash = ? AND revoked_at IS NULL", hashToken(raw)).
		First(&tok)
	if res.Error != nil {
		if errors.Is(res.Error, gorm.ErrRecordNotFound) {
			return nil, ErrNotFound
		}
		return nil, res.Error
	}
	if expired := time.Now().After(tok.ExpiresAt); expired {
		return nil, ErrNotFound
	}
	return &tok, nil
}
// IssueAccessAndRefresh stores a new access token and its paired refresh token
// in one transaction.
//
// Both tokens are opaque. Only their hashes are persisted; the plaintext
// values are returned to the client by the caller. The refresh token records
// the access token's ID in AccessTokenID. Each token expires accessTTL /
// refreshTTL after the common issue time.
func (st *Store) IssueAccessAndRefresh(ctx context.Context, accessPlain, refreshPlain, clientID, userID, tenantID, scope string, accessTTL, refreshTTL time.Duration) error {
	issuedAt := time.Now()
	access := &OAuthAccessToken{
		ID:        id.New(),
		TokenHash: hashToken(accessPlain),
		ClientID:  clientID,
		UserID:    userID,
		TenantID:  tenantID,
		Scope:     scope,
		ExpiresAt: issuedAt.Add(accessTTL),
		CreatedAt: issuedAt,
	}
	refresh := &OAuthRefreshToken{
		ID:            id.New(),
		TokenHash:     hashToken(refreshPlain),
		AccessTokenID: access.ID,
		ClientID:      clientID,
		UserID:        userID,
		TenantID:      tenantID,
		Scope:         scope,
		ExpiresAt:     issuedAt.Add(refreshTTL),
		CreatedAt:     issuedAt,
	}
	return st.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		if err := tx.Create(access).Error; err != nil {
			return err
		}
		return tx.Create(refresh).Error
	})
}
// RotateByRefreshToken exchanges a refresh token for a new access/refresh
// pair in one transaction, revoking the presented refresh token and its paired
// access token.
//
// The presented token must be non-revoked, unexpired, and issued to clientID.
// Otherwise ErrNotFound is returned. The new pair inherits the old token's
// client, user, tenant and scope. Only hashes of newAccessPlain and
// newRefreshPlain are stored.
func (st *Store) RotateByRefreshToken(ctx context.Context, clientID, refreshPlain, newAccessPlain, newRefreshPlain string, accessTTL, refreshTTL time.Duration) error {
	h := hashToken(refreshPlain)
	return st.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		var row OAuthRefreshToken
		if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("token_hash = ? AND revoked_at IS NULL", h).First(&row).Error; err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				return ErrNotFound
			}
			return err
		}
		if row.ClientID != clientID {
			return ErrNotFound
		}
		// One timestamp for the expiry check, the revocations and the new pair.
		now := time.Now()
		if now.After(row.ExpiresAt) {
			return ErrNotFound
		}
		// Compare-and-set on revoked_at IS NULL: two concurrent rotations of
		// the same refresh token cannot both succeed, even without row locks.
		res := tx.Model(&OAuthRefreshToken{}).
			Where("id = ? AND revoked_at IS NULL", row.ID).
			Update("revoked_at", now)
		if res.Error != nil {
			return res.Error
		}
		if res.RowsAffected == 0 {
			return ErrNotFound
		}
		// Revoke the paired access token, keeping any earlier revocation time.
		if err := tx.Model(&OAuthAccessToken{}).
			Where("id = ? AND revoked_at IS NULL", row.AccessTokenID).
			Update("revoked_at", now).Error; err != nil {
			return err
		}
		newAID := id.New()
		newRID := id.New()
		at := OAuthAccessToken{
			ID:        newAID,
			TokenHash: hashToken(newAccessPlain),
			ClientID:  row.ClientID,
			UserID:    row.UserID,
			TenantID:  row.TenantID,
			Scope:     row.Scope,
			ExpiresAt: now.Add(accessTTL),
			CreatedAt: now,
		}
		rt := OAuthRefreshToken{
			ID:            newRID,
			TokenHash:     hashToken(newRefreshPlain),
			AccessTokenID: newAID,
			ClientID:      row.ClientID,
			UserID:        row.UserID,
			TenantID:      row.TenantID,
			Scope:         row.Scope,
			ExpiresAt:     now.Add(refreshTTL),
			CreatedAt:     now,
		}
		if err := tx.Create(&at).Error; err != nil {
			return err
		}
		return tx.Create(&rt).Error
	})
}