feat: 优化web
This commit is contained in:
@@ -0,0 +1,265 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"giter.top/smart/pkg/utils/id"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// ErrNotFound 未找到记录。
|
||||
var ErrNotFound = errors.New("oauth2: not found")
|
||||
|
||||
func hashToken(raw string) string {
|
||||
sum := sha256.Sum256([]byte(raw))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// Store OAuth 持久化。
|
||||
type Store struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewStore 创建 Store。
|
||||
func NewStore(db *gorm.DB) *Store {
|
||||
return &Store{db: db}
|
||||
}
|
||||
|
||||
// GetClientByClientID 按 client_id 查客户端。
|
||||
func (st *Store) GetClientByClientID(ctx context.Context, clientID string) (*OAuthClient, error) {
|
||||
var row OAuthClient
|
||||
err := st.db.WithContext(ctx).Where("client_id = ?", clientID).First(&row).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &row, nil
|
||||
}
|
||||
|
||||
// ParseRedirectURIs 解析 redirect_uris JSON 数组。
|
||||
func ParseRedirectURIs(raw string) ([]string, error) {
|
||||
var uris []string
|
||||
if raw == "" {
|
||||
return nil, errors.New("empty redirect_uris")
|
||||
}
|
||||
if err := json.Unmarshal([]byte(raw), &uris); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return uris, nil
|
||||
}
|
||||
|
||||
// RedirectURIMatch OAuth 2.1 精确匹配。
|
||||
func RedirectURIMatch(allowed []string, u string) bool {
|
||||
for _, x := range allowed {
|
||||
if x == u {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CreateAuthorizationCode 写入授权码(code 明文仅返回给调用方,库存哈希)。
|
||||
func (st *Store) CreateAuthorizationCode(ctx context.Context, codePlain string, clientID, redirectURI, userID, tenantID, scope, challenge, method string, expiresAt time.Time) error {
|
||||
row := OAuthAuthorizationCode{
|
||||
ID: id.New(),
|
||||
CodeHash: hashToken(codePlain),
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
UserID: userID,
|
||||
TenantID: tenantID,
|
||||
Scope: scope,
|
||||
CodeChallenge: challenge,
|
||||
CodeChallengeMethod: method,
|
||||
ExpiresAt: expiresAt,
|
||||
Used: false,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
return st.db.WithContext(ctx).Create(&row).Error
|
||||
}
|
||||
|
||||
// ConsumeAuthorizationCode 校验并一次性消费授权码,返回行数据供发 token。
|
||||
func (st *Store) ConsumeAuthorizationCode(ctx context.Context, codePlain, clientID, redirectURI string) (*OAuthAuthorizationCode, error) {
|
||||
h := hashToken(codePlain)
|
||||
var out *OAuthAuthorizationCode
|
||||
err := st.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
var row OAuthAuthorizationCode
|
||||
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("code_hash = ?", h).First(&row).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
if row.Used {
|
||||
return ErrNotFound
|
||||
}
|
||||
if time.Now().After(row.ExpiresAt) {
|
||||
return ErrNotFound
|
||||
}
|
||||
if row.ClientID != clientID || row.RedirectURI != redirectURI {
|
||||
return ErrNotFound
|
||||
}
|
||||
if err := tx.Model(&OAuthAuthorizationCode{}).Where("id = ?", row.ID).Update("used", true).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
out = &row
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// TokenPrincipal opaque access token 解析结果。
|
||||
type TokenPrincipal struct {
|
||||
UserID string
|
||||
TenantID string
|
||||
Scope string
|
||||
}
|
||||
|
||||
// LookupAccessToken 按明文 access token 查有效记录。
|
||||
func (st *Store) LookupAccessToken(ctx context.Context, raw string) (*TokenPrincipal, error) {
|
||||
h := hashToken(raw)
|
||||
var row OAuthAccessToken
|
||||
err := st.db.WithContext(ctx).Where("token_hash = ? AND revoked_at IS NULL", h).First(&row).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if time.Now().After(row.ExpiresAt) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return &TokenPrincipal{UserID: row.UserID, TenantID: row.TenantID, Scope: row.Scope}, nil
|
||||
}
|
||||
|
||||
// LookupAccessTokenRow 按明文查 access token 行(自省用)。
|
||||
func (st *Store) LookupAccessTokenRow(ctx context.Context, raw string) (*OAuthAccessToken, error) {
|
||||
h := hashToken(raw)
|
||||
var row OAuthAccessToken
|
||||
err := st.db.WithContext(ctx).Where("token_hash = ? AND revoked_at IS NULL", h).First(&row).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if time.Now().After(row.ExpiresAt) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return &row, nil
|
||||
}
|
||||
|
||||
// LookupRefreshTokenRow 按明文查 refresh token 行(自省用)。
|
||||
func (st *Store) LookupRefreshTokenRow(ctx context.Context, raw string) (*OAuthRefreshToken, error) {
|
||||
h := hashToken(raw)
|
||||
var row OAuthRefreshToken
|
||||
err := st.db.WithContext(ctx).Where("token_hash = ? AND revoked_at IS NULL", h).First(&row).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if time.Now().After(row.ExpiresAt) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return &row, nil
|
||||
}
|
||||
|
||||
// IssueAccessAndRefresh 写入 access + refresh(opaque 明文仅调用方返回给客户端)。
|
||||
func (st *Store) IssueAccessAndRefresh(ctx context.Context, accessPlain, refreshPlain, clientID, userID, tenantID, scope string, accessTTL, refreshTTL time.Duration) error {
|
||||
now := time.Now()
|
||||
accessID := id.New()
|
||||
refreshID := id.New()
|
||||
at := OAuthAccessToken{
|
||||
ID: accessID,
|
||||
TokenHash: hashToken(accessPlain),
|
||||
ClientID: clientID,
|
||||
UserID: userID,
|
||||
TenantID: tenantID,
|
||||
Scope: scope,
|
||||
ExpiresAt: now.Add(accessTTL),
|
||||
CreatedAt: now,
|
||||
}
|
||||
rt := OAuthRefreshToken{
|
||||
ID: refreshID,
|
||||
TokenHash: hashToken(refreshPlain),
|
||||
AccessTokenID: accessID,
|
||||
ClientID: clientID,
|
||||
UserID: userID,
|
||||
TenantID: tenantID,
|
||||
Scope: scope,
|
||||
ExpiresAt: now.Add(refreshTTL),
|
||||
CreatedAt: now,
|
||||
}
|
||||
return st.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Create(&at).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Create(&rt).Error
|
||||
})
|
||||
}
|
||||
|
||||
// RotateByRefreshToken 使用 refresh 换发新 access+refresh,旧令牌作废;client_id 须与注册一致。
|
||||
func (st *Store) RotateByRefreshToken(ctx context.Context, clientID, refreshPlain, newAccessPlain, newRefreshPlain string, accessTTL, refreshTTL time.Duration) error {
|
||||
h := hashToken(refreshPlain)
|
||||
return st.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
var row OAuthRefreshToken
|
||||
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("token_hash = ? AND revoked_at IS NULL", h).First(&row).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
if row.ClientID != clientID {
|
||||
return ErrNotFound
|
||||
}
|
||||
if time.Now().After(row.ExpiresAt) {
|
||||
return ErrNotFound
|
||||
}
|
||||
now := time.Now()
|
||||
if err := tx.Model(&OAuthRefreshToken{}).Where("id = ?", row.ID).Update("revoked_at", now).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Model(&OAuthAccessToken{}).Where("id = ?", row.AccessTokenID).Update("revoked_at", now).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
newAID := id.New()
|
||||
newRID := id.New()
|
||||
at := OAuthAccessToken{
|
||||
ID: newAID,
|
||||
TokenHash: hashToken(newAccessPlain),
|
||||
ClientID: row.ClientID,
|
||||
UserID: row.UserID,
|
||||
TenantID: row.TenantID,
|
||||
Scope: row.Scope,
|
||||
ExpiresAt: now.Add(accessTTL),
|
||||
CreatedAt: now,
|
||||
}
|
||||
rt := OAuthRefreshToken{
|
||||
ID: newRID,
|
||||
TokenHash: hashToken(newRefreshPlain),
|
||||
AccessTokenID: newAID,
|
||||
ClientID: row.ClientID,
|
||||
UserID: row.UserID,
|
||||
TenantID: row.TenantID,
|
||||
Scope: row.Scope,
|
||||
ExpiresAt: now.Add(refreshTTL),
|
||||
CreatedAt: now,
|
||||
}
|
||||
if err := tx.Create(&at).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Create(&rt).Error
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user