342 lines
10 KiB
Go
342 lines
10 KiB
Go
package oauth2
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"log/slog"
|
||
"net/http"
|
||
"net/url"
|
||
"strings"
|
||
"time"
|
||
|
||
"giter.top/smart/internal/auth/session"
|
||
"giter.top/smart/pkg/config"
|
||
"giter.top/smart/pkg/security"
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
// Service OAuth2 授权码 + PKCE + opaque token。
|
||
type Service struct {
|
||
cfg *config.Config
|
||
store *Store
|
||
sess *session.Store
|
||
}
|
||
|
||
// NewService 构造。
|
||
func NewService(cfg *config.Config, store *Store, sess *session.Store) *Service {
|
||
return &Service{cfg: cfg, store: store, sess: sess}
|
||
}
|
||
|
||
func (s *Service) durations() (authCode, access, refresh time.Duration) {
|
||
authCode = s.cfg.Auth.OAuth2.AuthCodeTTL
|
||
if authCode == 0 {
|
||
authCode = 120 * time.Second
|
||
}
|
||
access = s.cfg.Auth.OAuth2.AccessTokenTTL
|
||
if access == 0 {
|
||
access = 15 * time.Minute
|
||
}
|
||
refresh = s.cfg.Auth.OAuth2.RefreshTokenTTL
|
||
if refresh == 0 {
|
||
refresh = 720 * time.Hour
|
||
}
|
||
return authCode, access, refresh
|
||
}
|
||
|
||
// IssueAuthorizationCodeAfterPasswordAuth 在已通过用户名密码校验的上下文中签发 PKCE 绑定授权码(与 Authorize 中 CreateAuthorizationCode 一致)。
|
||
func (s *Service) IssueAuthorizationCodeAfterPasswordAuth(ctx context.Context, clientID, redirectURI, userID, tenantID, scope, codeChallenge, challengeMethod string) (codePlain string, err error) {
|
||
if scope == "" {
|
||
scope = "openid"
|
||
}
|
||
method := NormalizeCodeChallengeMethod(challengeMethod)
|
||
if codeChallenge == "" || method != "s256" {
|
||
return "", ErrPKCERequired
|
||
}
|
||
cli, err := s.store.GetClientByClientID(ctx, clientID)
|
||
if err != nil {
|
||
if errors.Is(err, ErrNotFound) {
|
||
return "", ErrInvalidClient
|
||
}
|
||
return "", err
|
||
}
|
||
uris, err := ParseRedirectURIs(cli.RedirectURIsJSON)
|
||
if err != nil || !RedirectURIMatch(uris, redirectURI) {
|
||
return "", ErrInvalidRedirectURI
|
||
}
|
||
codePlain, err = security.RandomURLSafe(32)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
codeTTL, _, _ := s.durations()
|
||
exp := time.Now().Add(codeTTL)
|
||
if err := s.store.CreateAuthorizationCode(ctx, codePlain, clientID, redirectURI, userID, tenantID, scope, codeChallenge, "S256", exp); err != nil {
|
||
return "", err
|
||
}
|
||
return codePlain, nil
|
||
}
|
||
|
||
func (s *Service) publicAuthorizeURL(c *gin.Context) string {
|
||
base := strings.TrimRight(s.cfg.Auth.PublicBaseURL, "/")
|
||
if base == "" {
|
||
scheme := "http"
|
||
if c.Request.TLS != nil {
|
||
scheme = "https"
|
||
}
|
||
if xf := c.GetHeader("X-Forwarded-Proto"); xf == "https" {
|
||
scheme = "https"
|
||
}
|
||
base = scheme + "://" + c.Request.Host
|
||
}
|
||
return base + "/oauth/authorize?" + c.Request.URL.RawQuery
|
||
}
|
||
|
||
// Authorize GET /oauth/authorize
|
||
func (s *Service) Authorize(c *gin.Context) {
|
||
q := c.Request.URL.Query()
|
||
if q.Get("response_type") != "code" {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported_response_type"})
|
||
return
|
||
}
|
||
clientID := q.Get("client_id")
|
||
redirectURI := q.Get("redirect_uri")
|
||
state := q.Get("state")
|
||
scope := q.Get("scope")
|
||
if scope == "" {
|
||
scope = "openid"
|
||
}
|
||
challenge := q.Get("code_challenge")
|
||
method := NormalizeCodeChallengeMethod(q.Get("code_challenge_method"))
|
||
if challenge == "" || method != "s256" {
|
||
s.redirectOAuthError(c, redirectURI, state, "invalid_request", "code_challenge and code_challenge_method=S256 required")
|
||
return
|
||
}
|
||
|
||
cli, err := s.store.GetClientByClientID(c.Request.Context(), clientID)
|
||
if err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_client"})
|
||
return
|
||
}
|
||
uris, err := ParseRedirectURIs(cli.RedirectURIsJSON)
|
||
if err != nil || !RedirectURIMatch(uris, redirectURI) {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_redirect_uri"})
|
||
return
|
||
}
|
||
|
||
sid, err := c.Cookie(s.cfg.Auth.Session.CookieName)
|
||
if err != nil || sid == "" {
|
||
login := strings.TrimRight(s.cfg.Auth.OAuth2.FrontendLoginURL, "?")
|
||
ret := s.publicAuthorizeURL(c)
|
||
u, _ := url.Parse(login)
|
||
q2 := u.Query()
|
||
q2.Set("return_to", ret)
|
||
u.RawQuery = q2.Encode()
|
||
c.Redirect(http.StatusFound, u.String())
|
||
return
|
||
}
|
||
userID, tenantID, err := s.sess.Get(c.Request.Context(), sid)
|
||
if err != nil {
|
||
login := strings.TrimRight(s.cfg.Auth.OAuth2.FrontendLoginURL, "?")
|
||
ret := s.publicAuthorizeURL(c)
|
||
u, _ := url.Parse(login)
|
||
q2 := u.Query()
|
||
q2.Set("return_to", ret)
|
||
u.RawQuery = q2.Encode()
|
||
c.Redirect(http.StatusFound, u.String())
|
||
return
|
||
}
|
||
|
||
codePlain, err := security.RandomURLSafe(32)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||
return
|
||
}
|
||
codeTTL, _, _ := s.durations()
|
||
exp := time.Now().Add(codeTTL)
|
||
if err := s.store.CreateAuthorizationCode(c.Request.Context(), codePlain, clientID, redirectURI, userID, tenantID, scope, challenge, "S256", exp); err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||
return
|
||
}
|
||
redir, err := url.Parse(redirectURI)
|
||
if err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_redirect_uri"})
|
||
return
|
||
}
|
||
rq := redir.Query()
|
||
rq.Set("code", codePlain)
|
||
if state != "" {
|
||
rq.Set("state", state)
|
||
}
|
||
redir.RawQuery = rq.Encode()
|
||
c.Redirect(http.StatusFound, redir.String())
|
||
}
|
||
|
||
func (s *Service) redirectOAuthError(c *gin.Context, redirectURI, state, errCode, desc string) {
|
||
if redirectURI == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": errCode, "error_description": desc})
|
||
return
|
||
}
|
||
u, e := url.Parse(redirectURI)
|
||
if e != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": errCode})
|
||
return
|
||
}
|
||
q := u.Query()
|
||
q.Set("error", errCode)
|
||
q.Set("error_description", desc)
|
||
if state != "" {
|
||
q.Set("state", state)
|
||
}
|
||
u.RawQuery = q.Encode()
|
||
c.Redirect(http.StatusFound, u.String())
|
||
}
|
||
|
||
// Token POST /oauth/token
|
||
func (s *Service) Token(c *gin.Context) {
|
||
if err := c.Request.ParseForm(); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request"})
|
||
return
|
||
}
|
||
gt := c.PostForm("grant_type")
|
||
switch gt {
|
||
case "authorization_code":
|
||
s.tokenAuthorizationCode(c)
|
||
case "refresh_token":
|
||
s.tokenRefresh(c)
|
||
default:
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported_grant_type"})
|
||
}
|
||
}
|
||
|
||
func (s *Service) tokenAuthorizationCode(c *gin.Context) {
|
||
code := c.PostForm("code")
|
||
redirectURI := c.PostForm("redirect_uri")
|
||
clientID := c.PostForm("client_id")
|
||
verifier := c.PostForm("code_verifier")
|
||
if code == "" || redirectURI == "" || clientID == "" || verifier == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request"})
|
||
return
|
||
}
|
||
row, err := s.store.ConsumeAuthorizationCode(c.Request.Context(), code, clientID, redirectURI)
|
||
if err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_grant"})
|
||
return
|
||
}
|
||
if !VerifyPKCES256(verifier, row.CodeChallenge) {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_grant", "error_description": "pkce verification failed"})
|
||
return
|
||
}
|
||
_, accessTTL, refreshTTL := s.durations()
|
||
accessPlain, err := security.RandomURLSafe(32)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||
return
|
||
}
|
||
refreshPlain, err := security.RandomURLSafe(48)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||
return
|
||
}
|
||
if err := s.store.IssueAccessAndRefresh(c.Request.Context(), accessPlain, refreshPlain, clientID, row.UserID, row.TenantID, row.Scope, accessTTL, refreshTTL); err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||
return
|
||
}
|
||
slog.Info("oauth2_token_issued", "grant_type", "authorization_code", "client_id", clientID, "user_id", row.UserID, "tenant_id", row.TenantID, "client_ip", c.ClientIP())
|
||
s.jsonAccessToken(c, accessPlain, refreshPlain, accessTTL)
|
||
}
|
||
|
||
func (s *Service) tokenRefresh(c *gin.Context) {
|
||
refresh := c.PostForm("refresh_token")
|
||
clientID := c.PostForm("client_id")
|
||
if refresh == "" || clientID == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request"})
|
||
return
|
||
}
|
||
_, accessTTL, refreshTTL := s.durations()
|
||
newAccess, err := security.RandomURLSafe(32)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||
return
|
||
}
|
||
newRefresh, err := security.RandomURLSafe(48)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||
return
|
||
}
|
||
if err := s.store.RotateByRefreshToken(c.Request.Context(), clientID, refresh, newAccess, newRefresh, accessTTL, refreshTTL); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_grant"})
|
||
return
|
||
}
|
||
slog.Info("oauth2_token_issued", "grant_type", "refresh_token", "client_id", clientID, "client_ip", c.ClientIP())
|
||
s.jsonAccessToken(c, newAccess, newRefresh, accessTTL)
|
||
}
|
||
|
||
func (s *Service) jsonAccessToken(c *gin.Context, access, refresh string, accessTTL time.Duration) {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"access_token": access,
|
||
"token_type": "Bearer",
|
||
"expires_in": int(accessTTL.Seconds()),
|
||
"refresh_token": refresh,
|
||
})
|
||
}
|
||
|
||
// Introspect POST /oauth/introspect(RFC 7662),与 opaque 查表语义一致。
|
||
func (s *Service) Introspect(c *gin.Context) {
|
||
if err := c.Request.ParseForm(); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"active": false})
|
||
return
|
||
}
|
||
tok := c.PostForm("token")
|
||
hint := strings.TrimSpace(c.PostForm("token_type_hint"))
|
||
if tok == "" {
|
||
c.JSON(http.StatusOK, gin.H{"active": false})
|
||
return
|
||
}
|
||
ctx := c.Request.Context()
|
||
|
||
tryRefreshFirst := hint == "refresh_token"
|
||
if tryRefreshFirst {
|
||
if row, err := s.store.LookupRefreshTokenRow(ctx, tok); err == nil {
|
||
slog.Info("oauth2_introspect", "active", true, "token_type", "refresh_token", "client_id", row.ClientID, "sub", row.UserID)
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"active": true,
|
||
"scope": row.Scope,
|
||
"client_id": row.ClientID,
|
||
"token_type": "refresh_token",
|
||
"sub": row.UserID,
|
||
"exp": row.ExpiresAt.Unix(),
|
||
})
|
||
return
|
||
}
|
||
c.JSON(http.StatusOK, gin.H{"active": false})
|
||
return
|
||
}
|
||
|
||
if row, err := s.store.LookupAccessTokenRow(ctx, tok); err == nil {
|
||
slog.Info("oauth2_introspect", "active", true, "token_type", "access_token", "client_id", row.ClientID, "sub", row.UserID)
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"active": true,
|
||
"scope": row.Scope,
|
||
"client_id": row.ClientID,
|
||
"token_type": "access_token",
|
||
"sub": row.UserID,
|
||
"exp": row.ExpiresAt.Unix(),
|
||
})
|
||
return
|
||
}
|
||
|
||
if row, err := s.store.LookupRefreshTokenRow(ctx, tok); err == nil {
|
||
slog.Info("oauth2_introspect", "active", true, "token_type", "refresh_token", "client_id", row.ClientID, "sub", row.UserID)
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"active": true,
|
||
"scope": row.Scope,
|
||
"client_id": row.ClientID,
|
||
"token_type": "refresh_token",
|
||
"sub": row.UserID,
|
||
"exp": row.ExpiresAt.Unix(),
|
||
})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{"active": false})
|
||
}
|