package oauth2

import (
	"context"
	"errors"
	"log/slog"
	"net/http"
	"net/url"
	"strings"
	"time"

	"giter.top/smart/internal/auth/session"
	"giter.top/smart/pkg/config"
	"giter.top/smart/pkg/security"
	"github.com/gin-gonic/gin"
)

// Service implements the OAuth2 authorization-code flow with mandatory PKCE
// (S256) and opaque access/refresh tokens persisted through Store.
type Service struct {
	cfg   *config.Config
	store *Store
	sess  *session.Store
}

// NewService builds a Service from the application config, the OAuth2 store
// and the browser-session store.
func NewService(cfg *config.Config, store *Store, sess *session.Store) *Service {
	return &Service{cfg: cfg, store: store, sess: sess}
}

// durations returns the authorization-code, access-token and refresh-token
// lifetimes from config. A value that is zero or negative falls back to the
// default (120s, 15m and 720h respectively); a negative TTL would otherwise
// mint codes/tokens that are already expired.
func (s *Service) durations() (authCode, access, refresh time.Duration) {
	orDefault := func(d, def time.Duration) time.Duration {
		if d <= 0 {
			return def
		}
		return d
	}
	authCode = orDefault(s.cfg.Auth.OAuth2.AuthCodeTTL, 120*time.Second)
	access = orDefault(s.cfg.Auth.OAuth2.AccessTokenTTL, 15*time.Minute)
	refresh = orDefault(s.cfg.Auth.OAuth2.RefreshTokenTTL, 720*time.Hour)
	return authCode, access, refresh
}

// checkClientRedirect verifies that clientID names a registered client and
// that redirectURI matches one of the client's registered redirect URIs.
//
// It returns ErrInvalidClient when the store reports ErrNotFound,
// ErrInvalidRedirectURI when the registered URIs cannot be parsed or none
// matches, and any other store error unchanged.
func (s *Service) checkClientRedirect(ctx context.Context, clientID, redirectURI string) error {
	cli, err := s.store.GetClientByClientID(ctx, clientID)
	if err != nil {
		if errors.Is(err, ErrNotFound) {
			return ErrInvalidClient
		}
		return err
	}
	uris, err := ParseRedirectURIs(cli.RedirectURIsJSON)
	if err != nil || !RedirectURIMatch(uris, redirectURI) {
		return ErrInvalidRedirectURI
	}
	return nil
}

// createPKCECode generates a random authorization code, stores it bound to
// the client, redirect URI, user, tenant, scope and S256 code challenge with
// the configured authorization-code TTL, and returns the plaintext code.
func (s *Service) createPKCECode(ctx context.Context, clientID, redirectURI, userID, tenantID, scope, codeChallenge string) (string, error) {
	codePlain, err := security.RandomURLSafe(32)
	if err != nil {
		return "", err
	}
	codeTTL, _, _ := s.durations()
	exp := time.Now().Add(codeTTL)
	if err := s.store.CreateAuthorizationCode(ctx, codePlain, clientID, redirectURI, userID, tenantID, scope, codeChallenge, "S256", exp); err != nil {
		return "", err
	}
	return codePlain, nil
}

// IssueAuthorizationCodeAfterPasswordAuth issues a PKCE-bound authorization
// code for a caller that has already verified the user's username and
// password. The code is created exactly as Authorize creates it.
//
// An empty scope defaults to "openid". It returns ErrPKCERequired unless a
// non-empty codeChallenge with method S256 is supplied, ErrInvalidClient for
// an unknown client, ErrInvalidRedirectURI for an unregistered redirect URI,
// and otherwise any store/randomness error encountered.
func (s *Service) IssueAuthorizationCodeAfterPasswordAuth(ctx context.Context, clientID, redirectURI, userID, tenantID, scope, codeChallenge, challengeMethod string) (codePlain string, err error) {
	if scope == "" {
		scope = "openid"
	}
	if codeChallenge == "" || NormalizeCodeChallengeMethod(challengeMethod) != "s256" {
		return "", ErrPKCERequired
	}
	if err = s.checkClientRedirect(ctx, clientID, redirectURI); err != nil {
		return "", err
	}
	return s.createPKCECode(ctx, clientID, redirectURI, userID, tenantID, scope, codeChallenge)
}

// publicAuthorizeURL reconstructs the externally visible URL of the current
// /oauth/authorize request (including its query string) so the login page can
// send the user back here after signing in.
//
// cfg.Auth.PublicBaseURL is preferred. Without it, the scheme is inferred from
// the TLS state or X-Forwarded-Proto, and the host from the request's Host.
func (s *Service) publicAuthorizeURL(c *gin.Context) string {
	base := strings.TrimRight(s.cfg.Auth.PublicBaseURL, "/")
	if base == "" {
		scheme := "http"
		// Chained proxies may send a comma-separated list; the first entry is
		// the scheme the client actually used.
		proto, _, _ := strings.Cut(c.GetHeader("X-Forwarded-Proto"), ",")
		if c.Request.TLS != nil || strings.EqualFold(strings.TrimSpace(proto), "https") {
			scheme = "https"
		}
		// NOTE(review): Host is client-supplied; configure PublicBaseURL in
		// production so return_to cannot be steered by a forged Host header.
		base = scheme + "://" + c.Request.Host
	}
	return base + "/oauth/authorize?" + c.Request.URL.RawQuery
}

// redirectToFrontendLogin sends the browser to the configured frontend login
// page with a return_to parameter pointing back at this authorize request.
// An unparsable login URL is a server misconfiguration and yields a 500
// server_error instead of a nil-pointer panic.
func (s *Service) redirectToFrontendLogin(c *gin.Context) {
	login := strings.TrimRight(s.cfg.Auth.OAuth2.FrontendLoginURL, "?")
	u, err := url.Parse(login)
	if err != nil {
		slog.Error("oauth2_frontend_login_url_invalid", "error", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
		return
	}
	q := u.Query()
	q.Set("return_to", s.publicAuthorizeURL(c))
	u.RawQuery = q.Encode()
	c.Redirect(http.StatusFound, u.String())
}

// Authorize handles GET /oauth/authorize.
//
// It requires response_type=code and a PKCE S256 code_challenge. A missing
// scope defaults to "openid". Without a valid session cookie the user is
// redirected to the frontend login page. Otherwise an authorization code is
// stored and the browser is redirected to redirect_uri with code (and state,
// if given).
//
// The client and redirect_uri are validated BEFORE any error is delivered by
// redirect. An unverified redirect_uri must never be redirected to
// (RFC 6749 §4.1.2.1); otherwise the endpoint becomes an open redirect.
func (s *Service) Authorize(c *gin.Context) {
	q := c.Request.URL.Query()
	if q.Get("response_type") != "code" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported_response_type"})
		return
	}
	clientID := q.Get("client_id")
	redirectURI := q.Get("redirect_uri")
	state := q.Get("state")
	scope := q.Get("scope")
	if scope == "" {
		scope = "openid"
	}
	ctx := c.Request.Context()

	if err := s.checkClientRedirect(ctx, clientID, redirectURI); err != nil {
		switch {
		case errors.Is(err, ErrInvalidClient):
			c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_client"})
		case errors.Is(err, ErrInvalidRedirectURI):
			c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_redirect_uri"})
		default:
			slog.Error("oauth2_authorize_client_lookup_failed", "client_id", clientID, "error", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
		}
		return
	}
	// Parse before persisting anything, so a bad URI never leaves an orphaned code.
	redir, err := url.Parse(redirectURI)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_redirect_uri"})
		return
	}

	// redirect_uri is now known to be registered, so errors may go back to it.
	challenge := q.Get("code_challenge")
	if challenge == "" || NormalizeCodeChallengeMethod(q.Get("code_challenge_method")) != "s256" {
		s.redirectOAuthError(c, redirectURI, state, "invalid_request", "code_challenge and code_challenge_method=S256 required")
		return
	}

	sid, err := c.Cookie(s.cfg.Auth.Session.CookieName)
	if err != nil || sid == "" {
		s.redirectToFrontendLogin(c)
		return
	}
	userID, tenantID, err := s.sess.Get(ctx, sid)
	if err != nil {
		s.redirectToFrontendLogin(c)
		return
	}

	codePlain, err := s.createPKCECode(ctx, clientID, redirectURI, userID, tenantID, scope, challenge)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
		return
	}
	rq := redir.Query()
	rq.Set("code", codePlain)
	if state != "" {
		rq.Set("state", state)
	}
	redir.RawQuery = rq.Encode()
	c.Redirect(http.StatusFound, redir.String())
}

// redirectOAuthError reports an OAuth2 error by redirecting to redirectURI
// with error, error_description and (if non-empty) state query parameters.
// When redirectURI is empty or unparsable, the error is returned as a 400
// JSON body instead.
//
// Callers must only pass a redirectURI that has already been validated
// against the client's registration.
func (s *Service) redirectOAuthError(c *gin.Context, redirectURI, state, errCode, desc string) {
	if redirectURI == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": errCode, "error_description": desc})
		return
	}
	u, e := url.Parse(redirectURI)
	if e != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": errCode})
		return
	}
	q := u.Query()
	q.Set("error", errCode)
	q.Set("error_description", desc)
	if state != "" {
		q.Set("state", state)
	}
	u.RawQuery = q.Encode()
	c.Redirect(http.StatusFound, u.String())
}

// Token handles POST /oauth/token for the authorization_code and
// refresh_token grants. Every response, success or error, carries
// Cache-Control: no-store and Pragma: no-cache, as RFC 6749 §5.1/§5.2
// require for token endpoint responses.
func (s *Service) Token(c *gin.Context) {
	c.Header("Cache-Control", "no-store")
	c.Header("Pragma", "no-cache")
	if err := c.Request.ParseForm(); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request"})
		return
	}
	switch c.PostForm("grant_type") {
	case "authorization_code":
		s.tokenAuthorizationCode(c)
	case "refresh_token":
		s.tokenRefresh(c)
	default:
		c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported_grant_type"})
	}
}

// generateTokenPair returns fresh plaintext access and refresh tokens from
// security.RandomURLSafe(32) and security.RandomURLSafe(48) respectively.
func generateTokenPair() (access, refresh string, err error) {
	if access, err = security.RandomURLSafe(32); err != nil {
		return "", "", err
	}
	if refresh, err = security.RandomURLSafe(48); err != nil {
		return "", "", err
	}
	return access, refresh, nil
}

// tokenAuthorizationCode implements the authorization_code grant. It consumes
// the code (bound to client_id and redirect_uri), verifies the PKCE S256
// code_verifier against the stored challenge, then issues and returns a new
// access/refresh token pair. The code is consumed before PKCE verification,
// so a failed verification still burns it.
func (s *Service) tokenAuthorizationCode(c *gin.Context) {
	code := c.PostForm("code")
	redirectURI := c.PostForm("redirect_uri")
	clientID := c.PostForm("client_id")
	verifier := c.PostForm("code_verifier")
	if code == "" || redirectURI == "" || clientID == "" || verifier == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request"})
		return
	}
	ctx := c.Request.Context()
	row, err := s.store.ConsumeAuthorizationCode(ctx, code, clientID, redirectURI)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_grant"})
		return
	}
	if !VerifyPKCES256(verifier, row.CodeChallenge) {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_grant", "error_description": "pkce verification failed"})
		return
	}
	_, accessTTL, refreshTTL := s.durations()
	accessPlain, refreshPlain, err := generateTokenPair()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
		return
	}
	if err := s.store.IssueAccessAndRefresh(ctx, accessPlain, refreshPlain, clientID, row.UserID, row.TenantID, row.Scope, accessTTL, refreshTTL); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
		return
	}
	slog.Info("oauth2_token_issued", "grant_type", "authorization_code", "client_id", clientID, "user_id", row.UserID, "tenant_id", row.TenantID, "client_ip", c.ClientIP())
	s.jsonAccessToken(c, accessPlain, refreshPlain, accessTTL)
}

// tokenRefresh implements the refresh_token grant. It rotates the presented
// refresh token for client_id into a new access/refresh pair. Any rotation
// failure is reported as invalid_grant.
func (s *Service) tokenRefresh(c *gin.Context) {
	refresh := c.PostForm("refresh_token")
	clientID := c.PostForm("client_id")
	if refresh == "" || clientID == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request"})
		return
	}
	_, accessTTL, refreshTTL := s.durations()
	newAccess, newRefresh, err := generateTokenPair()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
		return
	}
	if err := s.store.RotateByRefreshToken(c.Request.Context(), clientID, refresh, newAccess, newRefresh, accessTTL, refreshTTL); err != nil {
		// NOTE(review): store/DB failures are also reported as invalid_grant here.
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_grant"})
		return
	}
	slog.Info("oauth2_token_issued", "grant_type", "refresh_token", "client_id", clientID, "client_ip", c.ClientIP())
	s.jsonAccessToken(c, newAccess, newRefresh, accessTTL)
}

// jsonAccessToken writes a 200 Bearer token response with access_token,
// refresh_token and expires_in (accessTTL in whole seconds).
func (s *Service) jsonAccessToken(c *gin.Context, access, refresh string, accessTTL time.Duration) {
	c.JSON(http.StatusOK, gin.H{
		"access_token":  access,
		"token_type":    "Bearer",
		"expires_in":    int(accessTTL.Seconds()),
		"refresh_token": refresh,
	})
}

// Introspect handles POST /oauth/introspect (RFC 7662), using the same
// opaque-token lookups as the rest of the service.
//
// The token is looked up as an access token and as a refresh token. By
// default access tokens are tried first; token_type_hint=refresh_token only
// reverses the order. Per RFC 7662 §2.1, a hint must not stop the server from
// searching the other token type. An unknown or empty token yields
// {"active": false}.
//
// NOTE(review): RFC 7662 §2.1 also requires the endpoint to authenticate or
// authorize its caller; confirm this is enforced by middleware.
func (s *Service) Introspect(c *gin.Context) {
	if err := c.Request.ParseForm(); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"active": false})
		return
	}
	tok := c.PostForm("token")
	hint := strings.TrimSpace(c.PostForm("token_type_hint"))
	if tok == "" {
		c.JSON(http.StatusOK, gin.H{"active": false})
		return
	}
	ctx := c.Request.Context()

	// Each lookup writes the active response and reports true on a hit.
	tryAccess := func() bool {
		row, err := s.store.LookupAccessTokenRow(ctx, tok)
		if err != nil {
			return false
		}
		slog.Info("oauth2_introspect", "active", true, "token_type", "access_token", "client_id", row.ClientID, "sub", row.UserID)
		c.JSON(http.StatusOK, gin.H{
			"active":     true,
			"scope":      row.Scope,
			"client_id":  row.ClientID,
			"token_type": "access_token",
			"sub":        row.UserID,
			"exp":        row.ExpiresAt.Unix(),
		})
		return true
	}
	tryRefresh := func() bool {
		row, err := s.store.LookupRefreshTokenRow(ctx, tok)
		if err != nil {
			return false
		}
		slog.Info("oauth2_introspect", "active", true, "token_type", "refresh_token", "client_id", row.ClientID, "sub", row.UserID)
		c.JSON(http.StatusOK, gin.H{
			"active":     true,
			"scope":      row.Scope,
			"client_id":  row.ClientID,
			"token_type": "refresh_token",
			"sub":        row.UserID,
			"exp":        row.ExpiresAt.Unix(),
		})
		return true
	}

	lookups := []func() bool{tryAccess, tryRefresh}
	if hint == "refresh_token" {
		lookups = []func() bool{tryRefresh, tryAccess}
	}
	for _, lookup := range lookups {
		if lookup() {
			return
		}
	}
	c.JSON(http.StatusOK, gin.H{"active": false})
}