feat: 优化web
This commit is contained in:
@@ -0,0 +1,189 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"giter.top/smart/internal/auth/oauth2"
|
||||
"giter.top/smart/internal/auth/session"
|
||||
"giter.top/smart/internal/iam/entity"
|
||||
iamrepo "giter.top/smart/internal/iam/repository"
|
||||
"giter.top/smart/pkg/config"
|
||||
"giter.top/smart/pkg/utils/codec"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// LoginHandler JSON 登录:校验密码后签发 OAuth2 授权码(PKCE),并可选下发会话 Cookie(与 /oauth/authorize 兼容)。
|
||||
type LoginHandler struct {
|
||||
cfg *config.Config
|
||||
users iamrepo.UserRepository
|
||||
sess *session.Store
|
||||
oauth *oauth2.Service
|
||||
}
|
||||
|
||||
// NewLoginHandler 构造。
|
||||
func NewLoginHandler(cfg *config.Config, users iamrepo.UserRepository, sess *session.Store, oauth *oauth2.Service) *LoginHandler {
|
||||
return &LoginHandler{cfg: cfg, users: users, sess: sess, oauth: oauth}
|
||||
}
|
||||
|
||||
type loginBody struct {
|
||||
TenantID string `json:"tenant_id"`
|
||||
UserName string `json:"user_name"`
|
||||
Password string `json:"password"`
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
CodeChallenge string `json:"code_challenge"`
|
||||
CodeChallengeMethod string `json:"code_challenge_method"`
|
||||
State string `json:"state"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
type apiEnvelope struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Login POST /api/v1/auth/login
|
||||
func (h *LoginHandler) Login(c *gin.Context) {
|
||||
var req loginBody
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, apiEnvelope{Code: 400, Msg: "请求参数无效: " + err.Error(), Data: nil})
|
||||
return
|
||||
}
|
||||
if req.UserName == "" || req.Password == "" {
|
||||
c.JSON(http.StatusOK, apiEnvelope{Code: 400, Msg: "缺少 user_name 或 password", Data: nil})
|
||||
return
|
||||
}
|
||||
if req.ClientID == "" || req.RedirectURI == "" || req.CodeChallenge == "" {
|
||||
c.JSON(http.StatusOK, apiEnvelope{Code: 400, Msg: "缺少 client_id、redirect_uri 或 code_challenge", Data: nil})
|
||||
return
|
||||
}
|
||||
if req.CodeChallengeMethod == "" {
|
||||
c.JSON(http.StatusOK, apiEnvelope{Code: 400, Msg: "缺少 code_challenge_method", Data: nil})
|
||||
return
|
||||
}
|
||||
|
||||
tid := req.TenantID
|
||||
if tid == "" {
|
||||
tid = entity.PlatformTenantID
|
||||
}
|
||||
|
||||
u, err := h.users.GetByUserName(c.Request.Context(), tid, req.UserName)
|
||||
if err != nil {
|
||||
slog.Warn("auth_login_failed", "reason", "user_not_found", "tenant_id", tid, "user_name", req.UserName, "client_ip", c.ClientIP())
|
||||
c.JSON(http.StatusUnauthorized, apiEnvelope{Code: 401, Msg: "用户名或密码错误", Data: nil})
|
||||
return
|
||||
}
|
||||
if err := codec.VerifyPassword(req.Password, u.PasswordHash); err != nil {
|
||||
slog.Warn("auth_login_failed", "reason", "bad_password", "tenant_id", tid, "user_name", req.UserName, "client_ip", c.ClientIP())
|
||||
c.JSON(http.StatusUnauthorized, apiEnvelope{Code: 401, Msg: "用户名或密码错误", Data: nil})
|
||||
return
|
||||
}
|
||||
if u.Status != 1 {
|
||||
slog.Warn("auth_login_failed", "reason", "user_disabled", "tenant_id", tid, "user_id", u.ID, "client_ip", c.ClientIP())
|
||||
c.JSON(http.StatusForbidden, apiEnvelope{Code: 403, Msg: "用户已禁用", Data: nil})
|
||||
return
|
||||
}
|
||||
|
||||
codePlain, err := h.oauth.IssueAuthorizationCodeAfterPasswordAuth(
|
||||
c.Request.Context(),
|
||||
req.ClientID,
|
||||
req.RedirectURI,
|
||||
u.ID,
|
||||
u.TenantID,
|
||||
req.Scope,
|
||||
req.CodeChallenge,
|
||||
req.CodeChallengeMethod,
|
||||
)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, oauth2.ErrInvalidClient):
|
||||
c.JSON(http.StatusOK, apiEnvelope{Code: 400, Msg: "无效的 client_id", Data: nil})
|
||||
return
|
||||
case errors.Is(err, oauth2.ErrInvalidRedirectURI):
|
||||
c.JSON(http.StatusOK, apiEnvelope{Code: 400, Msg: "redirect_uri 与客户端登记不一致", Data: nil})
|
||||
return
|
||||
case errors.Is(err, oauth2.ErrPKCERequired):
|
||||
c.JSON(http.StatusOK, apiEnvelope{Code: 400, Msg: "code_challenge 或 code_challenge_method 无效(需 S256)", Data: nil})
|
||||
return
|
||||
default:
|
||||
slog.Error("auth_login_issue_code", "err", err)
|
||||
c.JSON(http.StatusInternalServerError, apiEnvelope{Code: 500, Msg: "服务器错误", Data: nil})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
sid, err := h.sess.Create(c.Request.Context(), u.ID, u.TenantID)
|
||||
if err != nil {
|
||||
slog.Error("auth_login_session", "err", err)
|
||||
c.JSON(http.StatusInternalServerError, apiEnvelope{Code: 500, Msg: "会话创建失败", Data: nil})
|
||||
return
|
||||
}
|
||||
h.setSessionCookie(c, sid)
|
||||
|
||||
data := gin.H{
|
||||
"authorization_code": codePlain,
|
||||
}
|
||||
if req.State != "" {
|
||||
data["state"] = req.State
|
||||
}
|
||||
|
||||
slog.Info("auth_login_ok", "tenant_id", u.TenantID, "user_id", u.ID, "user_name", req.UserName, "client_ip", c.ClientIP())
|
||||
c.JSON(http.StatusOK, apiEnvelope{Code: 200, Msg: "操作成功", Data: data})
|
||||
}
|
||||
|
||||
// Logout POST /api/v1/auth/logout
|
||||
func (h *LoginHandler) Logout(c *gin.Context) {
|
||||
sid, err := c.Cookie(h.cfg.Auth.Session.CookieName)
|
||||
if err == nil && sid != "" {
|
||||
_ = h.sess.Delete(c.Request.Context(), sid)
|
||||
}
|
||||
h.clearSessionCookie(c)
|
||||
c.JSON(http.StatusOK, apiEnvelope{Code: 200, Msg: "操作成功", Data: nil})
|
||||
}
|
||||
|
||||
func (h *LoginHandler) setSessionCookie(c *gin.Context, sid string) {
|
||||
same := sameSite(h.cfg.Auth.Session.SameSite)
|
||||
ttl := h.cfg.Auth.Session.TTL
|
||||
if ttl == 0 {
|
||||
ttl = 168 * time.Hour
|
||||
}
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: h.cfg.Auth.Session.CookieName,
|
||||
Value: sid,
|
||||
Path: "/",
|
||||
Domain: h.cfg.Auth.Session.CookieDomain,
|
||||
MaxAge: int(ttl.Seconds()),
|
||||
Secure: h.cfg.Auth.Session.CookieSecure,
|
||||
HttpOnly: true,
|
||||
SameSite: same,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *LoginHandler) clearSessionCookie(c *gin.Context) {
|
||||
same := sameSite(h.cfg.Auth.Session.SameSite)
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: h.cfg.Auth.Session.CookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Domain: h.cfg.Auth.Session.CookieDomain,
|
||||
MaxAge: -1,
|
||||
Secure: h.cfg.Auth.Session.CookieSecure,
|
||||
HttpOnly: true,
|
||||
SameSite: same,
|
||||
})
|
||||
}
|
||||
|
||||
func sameSite(s string) http.SameSite {
|
||||
switch s {
|
||||
case "strict":
|
||||
return http.SameSiteStrictMode
|
||||
case "none":
|
||||
return http.SameSiteNoneMode
|
||||
default:
|
||||
return http.SameSiteLaxMode
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"giter.top/smart/internal/auth/handler"
|
||||
"giter.top/smart/internal/auth/oauth2"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AuthRoutes 认证相关 HTTP(OAuth2、登录)。
|
||||
type AuthRoutes struct {
|
||||
bearer gin.HandlerFunc
|
||||
loginRL gin.HandlerFunc
|
||||
tokenRL gin.HandlerFunc
|
||||
oauthH *oauth2.Handler
|
||||
loginH *handler.LoginHandler
|
||||
}
|
||||
|
||||
// NewAuthRoutes 构造(loginRL/tokenRL 使用 Wire 专用类型,见 wire_provider.go)。
|
||||
func NewAuthRoutes(bearer gin.HandlerFunc, loginRL LoginRateLimitWire, tokenRL TokenRateLimitWire, oauthH *oauth2.Handler, loginH *handler.LoginHandler) *AuthRoutes {
|
||||
return &AuthRoutes{
|
||||
bearer: bearer,
|
||||
loginRL: gin.HandlerFunc(loginRL),
|
||||
tokenRL: gin.HandlerFunc(tokenRL),
|
||||
oauthH: oauthH,
|
||||
loginH: loginH,
|
||||
}
|
||||
}
|
||||
|
||||
// Register 实现 server.HttpRoutes:OAuth 在根路径,/api/v1 挂 Bearer 与登录。
|
||||
func (r *AuthRoutes) Register(engine *gin.Engine, apiGroup *gin.RouterGroup) {
|
||||
apiGroup.Use(r.bearer)
|
||||
apiGroup.POST("/auth/login", r.loginRL, r.loginH.Login)
|
||||
apiGroup.POST("/auth/logout", r.loginH.Logout)
|
||||
|
||||
engine.GET("/oauth/authorize", r.oauthH.Authorize)
|
||||
engine.POST("/oauth/token", r.tokenRL, r.oauthH.Token)
|
||||
engine.POST("/oauth/introspect", r.tokenRL, r.oauthH.Introspect)
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"giter.top/smart/internal/auth/oauth2"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Context keys for auth principal
|
||||
const (
|
||||
CtxUserID = "auth_user_id"
|
||||
CtxTenantID = "auth_tenant_id"
|
||||
CtxScope = "auth_scope"
|
||||
)
|
||||
|
||||
// NewBearer 解析 opaque Bearer access_token,写入上下文;无 Bearer 或无效时继续放行(兼容未迁移接口)。
|
||||
func NewBearer(store *oauth2.Store) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
h := c.GetHeader("Authorization")
|
||||
const prefix = "Bearer "
|
||||
if !strings.HasPrefix(h, prefix) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
raw := strings.TrimSpace(strings.TrimPrefix(h, prefix))
|
||||
if raw == "" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
p, err := store.LookupAccessToken(c.Request.Context(), raw)
|
||||
if err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
c.Set(CtxUserID, p.UserID)
|
||||
c.Set(CtxTenantID, p.TenantID)
|
||||
c.Set(CtxScope, p.Scope)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// PerIPMinute 按客户端 IP 的固定窗口速率(每分钟 perMinute 次,burst 取 perMinute 与 64 的较小值)。
|
||||
// 进程内 map 可能随 IP 数增长,多实例部署请在网关侧限流。
|
||||
func PerIPMinute(enabled bool, perMinute int) gin.HandlerFunc {
|
||||
if !enabled || perMinute <= 0 {
|
||||
return func(c *gin.Context) { c.Next() }
|
||||
}
|
||||
burst := perMinute
|
||||
if burst > 64 {
|
||||
burst = 64
|
||||
}
|
||||
if burst < 5 {
|
||||
burst = 5
|
||||
}
|
||||
lim := rate.Limit(float64(perMinute) / 60.0)
|
||||
var mu sync.Mutex
|
||||
limiters := make(map[string]*rate.Limiter)
|
||||
return func(c *gin.Context) {
|
||||
ip := clientIP(c)
|
||||
mu.Lock()
|
||||
limiter, ok := limiters[ip]
|
||||
if !ok {
|
||||
limiter = rate.NewLimiter(lim, burst)
|
||||
limiters[ip] = limiter
|
||||
}
|
||||
mu.Unlock()
|
||||
if !limiter.Allow() {
|
||||
c.AbortWithStatusJSON(429, gin.H{"error": "rate_limit_exceeded"})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func clientIP(c *gin.Context) string {
|
||||
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
|
||||
parts := strings.Split(xff, ",")
|
||||
if len(parts) > 0 {
|
||||
return strings.TrimSpace(parts[0])
|
||||
}
|
||||
}
|
||||
host, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr))
|
||||
if err != nil {
|
||||
return c.Request.RemoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package oauth2
|
||||
|
||||
import "errors"
|
||||
|
||||
// JSON 登录签发授权码时与 Authorize 对齐校验。
|
||||
var (
|
||||
ErrInvalidClient = errors.New("oauth2: invalid client_id")
|
||||
ErrInvalidRedirectURI = errors.New("oauth2: invalid redirect_uri")
|
||||
ErrPKCERequired = errors.New("oauth2: invalid code_challenge or code_challenge_method")
|
||||
)
|
||||
@@ -0,0 +1,28 @@
|
||||
package oauth2
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
// Handler 绑定 Gin 与 Service。
|
||||
type Handler struct {
|
||||
svc *Service
|
||||
}
|
||||
|
||||
// NewHandler 构造。
|
||||
func NewHandler(svc *Service) *Handler {
|
||||
return &Handler{svc: svc}
|
||||
}
|
||||
|
||||
// Authorize GET /oauth/authorize
|
||||
func (h *Handler) Authorize(c *gin.Context) {
|
||||
h.svc.Authorize(c)
|
||||
}
|
||||
|
||||
// Token POST /oauth/token
|
||||
func (h *Handler) Token(c *gin.Context) {
|
||||
h.svc.Token(c)
|
||||
}
|
||||
|
||||
// Introspect POST /oauth/introspect
|
||||
func (h *Handler) Introspect(c *gin.Context) {
|
||||
h.svc.Introspect(c)
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package oauth2
|
||||
|
||||
import "time"
|
||||
|
||||
// OAuthClient oauth_client
|
||||
type OAuthClient struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(36)"`
|
||||
ClientID string `gorm:"size:64;not null;uniqueIndex"`
|
||||
ClientSecretHash *string `gorm:"size:255"`
|
||||
RedirectURIsJSON string `gorm:"column:redirect_uris;type:text;not null"`
|
||||
IsPublic bool `gorm:"not null;default:true"`
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
}
|
||||
|
||||
func (OAuthClient) TableName() string { return "oauth_client" }
|
||||
|
||||
// OAuthAuthorizationCode oauth_authorization_code
|
||||
type OAuthAuthorizationCode struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(36)"`
|
||||
CodeHash string `gorm:"size:64;not null;uniqueIndex"`
|
||||
ClientID string `gorm:"size:64;not null"`
|
||||
RedirectURI string `gorm:"type:text;not null"`
|
||||
UserID string `gorm:"size:36;not null"`
|
||||
TenantID string `gorm:"size:36;not null"`
|
||||
Scope string `gorm:"type:text;not null"`
|
||||
CodeChallenge string `gorm:"size:128;not null"`
|
||||
CodeChallengeMethod string `gorm:"size:16;not null"`
|
||||
ExpiresAt time.Time `gorm:"not null"`
|
||||
Used bool `gorm:"not null;default:false"`
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
}
|
||||
|
||||
func (OAuthAuthorizationCode) TableName() string { return "oauth_authorization_code" }
|
||||
|
||||
// OAuthAccessToken oauth_access_token
|
||||
type OAuthAccessToken struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(36)"`
|
||||
TokenHash string `gorm:"size:64;not null;uniqueIndex"`
|
||||
ClientID string `gorm:"size:64;not null"`
|
||||
UserID string `gorm:"size:36;not null"`
|
||||
TenantID string `gorm:"size:36;not null"`
|
||||
Scope string `gorm:"type:text;not null"`
|
||||
ExpiresAt time.Time `gorm:"not null"`
|
||||
RevokedAt *time.Time `gorm:""`
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
}
|
||||
|
||||
func (OAuthAccessToken) TableName() string { return "oauth_access_token" }
|
||||
|
||||
// OAuthRefreshToken oauth_refresh_token
|
||||
type OAuthRefreshToken struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(36)"`
|
||||
TokenHash string `gorm:"size:64;not null;uniqueIndex"`
|
||||
AccessTokenID string `gorm:"size:36;not null;index"`
|
||||
ClientID string `gorm:"size:64;not null"`
|
||||
UserID string `gorm:"size:36;not null"`
|
||||
TenantID string `gorm:"size:36;not null"`
|
||||
Scope string `gorm:"type:text;not null"`
|
||||
ExpiresAt time.Time `gorm:"not null"`
|
||||
RevokedAt *time.Time `gorm:""`
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
}
|
||||
|
||||
func (OAuthRefreshToken) TableName() string { return "oauth_refresh_token" }
|
||||
@@ -0,0 +1,23 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// VerifyPKCES256 校验 code_verifier 是否与 code_challenge(S256)一致。
|
||||
func VerifyPKCES256(codeVerifier, codeChallenge string) bool {
|
||||
if codeVerifier == "" || codeChallenge == "" {
|
||||
return false
|
||||
}
|
||||
sum := sha256.Sum256([]byte(codeVerifier))
|
||||
expected := base64.RawURLEncoding.EncodeToString(sum[:])
|
||||
return subtle.ConstantTimeCompare([]byte(expected), []byte(codeChallenge)) == 1
|
||||
}
|
||||
|
||||
// NormalizeCodeChallengeMethod 返回小写方法名;仅支持 S256(OAuth 2.1 推荐)。
|
||||
func NormalizeCodeChallengeMethod(m string) string {
|
||||
return strings.TrimSpace(strings.ToLower(m))
|
||||
}
|
||||
@@ -0,0 +1,341 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"giter.top/smart/internal/auth/session"
|
||||
"giter.top/smart/pkg/config"
|
||||
"giter.top/smart/pkg/security"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Service OAuth2 授权码 + PKCE + opaque token。
|
||||
type Service struct {
|
||||
cfg *config.Config
|
||||
store *Store
|
||||
sess *session.Store
|
||||
}
|
||||
|
||||
// NewService 构造。
|
||||
func NewService(cfg *config.Config, store *Store, sess *session.Store) *Service {
|
||||
return &Service{cfg: cfg, store: store, sess: sess}
|
||||
}
|
||||
|
||||
func (s *Service) durations() (authCode, access, refresh time.Duration) {
|
||||
authCode = s.cfg.Auth.OAuth2.AuthCodeTTL
|
||||
if authCode == 0 {
|
||||
authCode = 120 * time.Second
|
||||
}
|
||||
access = s.cfg.Auth.OAuth2.AccessTokenTTL
|
||||
if access == 0 {
|
||||
access = 15 * time.Minute
|
||||
}
|
||||
refresh = s.cfg.Auth.OAuth2.RefreshTokenTTL
|
||||
if refresh == 0 {
|
||||
refresh = 720 * time.Hour
|
||||
}
|
||||
return authCode, access, refresh
|
||||
}
|
||||
|
||||
// IssueAuthorizationCodeAfterPasswordAuth 在已通过用户名密码校验的上下文中签发 PKCE 绑定授权码(与 Authorize 中 CreateAuthorizationCode 一致)。
|
||||
func (s *Service) IssueAuthorizationCodeAfterPasswordAuth(ctx context.Context, clientID, redirectURI, userID, tenantID, scope, codeChallenge, challengeMethod string) (codePlain string, err error) {
|
||||
if scope == "" {
|
||||
scope = "openid"
|
||||
}
|
||||
method := NormalizeCodeChallengeMethod(challengeMethod)
|
||||
if codeChallenge == "" || method != "s256" {
|
||||
return "", ErrPKCERequired
|
||||
}
|
||||
cli, err := s.store.GetClientByClientID(ctx, clientID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNotFound) {
|
||||
return "", ErrInvalidClient
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
uris, err := ParseRedirectURIs(cli.RedirectURIsJSON)
|
||||
if err != nil || !RedirectURIMatch(uris, redirectURI) {
|
||||
return "", ErrInvalidRedirectURI
|
||||
}
|
||||
codePlain, err = security.RandomURLSafe(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
codeTTL, _, _ := s.durations()
|
||||
exp := time.Now().Add(codeTTL)
|
||||
if err := s.store.CreateAuthorizationCode(ctx, codePlain, clientID, redirectURI, userID, tenantID, scope, codeChallenge, "S256", exp); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return codePlain, nil
|
||||
}
|
||||
|
||||
func (s *Service) publicAuthorizeURL(c *gin.Context) string {
|
||||
base := strings.TrimRight(s.cfg.Auth.PublicBaseURL, "/")
|
||||
if base == "" {
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
if xf := c.GetHeader("X-Forwarded-Proto"); xf == "https" {
|
||||
scheme = "https"
|
||||
}
|
||||
base = scheme + "://" + c.Request.Host
|
||||
}
|
||||
return base + "/oauth/authorize?" + c.Request.URL.RawQuery
|
||||
}
|
||||
|
||||
// Authorize GET /oauth/authorize
|
||||
func (s *Service) Authorize(c *gin.Context) {
|
||||
q := c.Request.URL.Query()
|
||||
if q.Get("response_type") != "code" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported_response_type"})
|
||||
return
|
||||
}
|
||||
clientID := q.Get("client_id")
|
||||
redirectURI := q.Get("redirect_uri")
|
||||
state := q.Get("state")
|
||||
scope := q.Get("scope")
|
||||
if scope == "" {
|
||||
scope = "openid"
|
||||
}
|
||||
challenge := q.Get("code_challenge")
|
||||
method := NormalizeCodeChallengeMethod(q.Get("code_challenge_method"))
|
||||
if challenge == "" || method != "s256" {
|
||||
s.redirectOAuthError(c, redirectURI, state, "invalid_request", "code_challenge and code_challenge_method=S256 required")
|
||||
return
|
||||
}
|
||||
|
||||
cli, err := s.store.GetClientByClientID(c.Request.Context(), clientID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_client"})
|
||||
return
|
||||
}
|
||||
uris, err := ParseRedirectURIs(cli.RedirectURIsJSON)
|
||||
if err != nil || !RedirectURIMatch(uris, redirectURI) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_redirect_uri"})
|
||||
return
|
||||
}
|
||||
|
||||
sid, err := c.Cookie(s.cfg.Auth.Session.CookieName)
|
||||
if err != nil || sid == "" {
|
||||
login := strings.TrimRight(s.cfg.Auth.OAuth2.FrontendLoginURL, "?")
|
||||
ret := s.publicAuthorizeURL(c)
|
||||
u, _ := url.Parse(login)
|
||||
q2 := u.Query()
|
||||
q2.Set("return_to", ret)
|
||||
u.RawQuery = q2.Encode()
|
||||
c.Redirect(http.StatusFound, u.String())
|
||||
return
|
||||
}
|
||||
userID, tenantID, err := s.sess.Get(c.Request.Context(), sid)
|
||||
if err != nil {
|
||||
login := strings.TrimRight(s.cfg.Auth.OAuth2.FrontendLoginURL, "?")
|
||||
ret := s.publicAuthorizeURL(c)
|
||||
u, _ := url.Parse(login)
|
||||
q2 := u.Query()
|
||||
q2.Set("return_to", ret)
|
||||
u.RawQuery = q2.Encode()
|
||||
c.Redirect(http.StatusFound, u.String())
|
||||
return
|
||||
}
|
||||
|
||||
codePlain, err := security.RandomURLSafe(32)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||||
return
|
||||
}
|
||||
codeTTL, _, _ := s.durations()
|
||||
exp := time.Now().Add(codeTTL)
|
||||
if err := s.store.CreateAuthorizationCode(c.Request.Context(), codePlain, clientID, redirectURI, userID, tenantID, scope, challenge, "S256", exp); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||||
return
|
||||
}
|
||||
redir, err := url.Parse(redirectURI)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_redirect_uri"})
|
||||
return
|
||||
}
|
||||
rq := redir.Query()
|
||||
rq.Set("code", codePlain)
|
||||
if state != "" {
|
||||
rq.Set("state", state)
|
||||
}
|
||||
redir.RawQuery = rq.Encode()
|
||||
c.Redirect(http.StatusFound, redir.String())
|
||||
}
|
||||
|
||||
func (s *Service) redirectOAuthError(c *gin.Context, redirectURI, state, errCode, desc string) {
|
||||
if redirectURI == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": errCode, "error_description": desc})
|
||||
return
|
||||
}
|
||||
u, e := url.Parse(redirectURI)
|
||||
if e != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": errCode})
|
||||
return
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("error", errCode)
|
||||
q.Set("error_description", desc)
|
||||
if state != "" {
|
||||
q.Set("state", state)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
c.Redirect(http.StatusFound, u.String())
|
||||
}
|
||||
|
||||
// Token POST /oauth/token
|
||||
func (s *Service) Token(c *gin.Context) {
|
||||
if err := c.Request.ParseForm(); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request"})
|
||||
return
|
||||
}
|
||||
gt := c.PostForm("grant_type")
|
||||
switch gt {
|
||||
case "authorization_code":
|
||||
s.tokenAuthorizationCode(c)
|
||||
case "refresh_token":
|
||||
s.tokenRefresh(c)
|
||||
default:
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported_grant_type"})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) tokenAuthorizationCode(c *gin.Context) {
|
||||
code := c.PostForm("code")
|
||||
redirectURI := c.PostForm("redirect_uri")
|
||||
clientID := c.PostForm("client_id")
|
||||
verifier := c.PostForm("code_verifier")
|
||||
if code == "" || redirectURI == "" || clientID == "" || verifier == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request"})
|
||||
return
|
||||
}
|
||||
row, err := s.store.ConsumeAuthorizationCode(c.Request.Context(), code, clientID, redirectURI)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_grant"})
|
||||
return
|
||||
}
|
||||
if !VerifyPKCES256(verifier, row.CodeChallenge) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_grant", "error_description": "pkce verification failed"})
|
||||
return
|
||||
}
|
||||
_, accessTTL, refreshTTL := s.durations()
|
||||
accessPlain, err := security.RandomURLSafe(32)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||||
return
|
||||
}
|
||||
refreshPlain, err := security.RandomURLSafe(48)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||||
return
|
||||
}
|
||||
if err := s.store.IssueAccessAndRefresh(c.Request.Context(), accessPlain, refreshPlain, clientID, row.UserID, row.TenantID, row.Scope, accessTTL, refreshTTL); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||||
return
|
||||
}
|
||||
slog.Info("oauth2_token_issued", "grant_type", "authorization_code", "client_id", clientID, "user_id", row.UserID, "tenant_id", row.TenantID, "client_ip", c.ClientIP())
|
||||
s.jsonAccessToken(c, accessPlain, refreshPlain, accessTTL)
|
||||
}
|
||||
|
||||
func (s *Service) tokenRefresh(c *gin.Context) {
|
||||
refresh := c.PostForm("refresh_token")
|
||||
clientID := c.PostForm("client_id")
|
||||
if refresh == "" || clientID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request"})
|
||||
return
|
||||
}
|
||||
_, accessTTL, refreshTTL := s.durations()
|
||||
newAccess, err := security.RandomURLSafe(32)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||||
return
|
||||
}
|
||||
newRefresh, err := security.RandomURLSafe(48)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||||
return
|
||||
}
|
||||
if err := s.store.RotateByRefreshToken(c.Request.Context(), clientID, refresh, newAccess, newRefresh, accessTTL, refreshTTL); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_grant"})
|
||||
return
|
||||
}
|
||||
slog.Info("oauth2_token_issued", "grant_type", "refresh_token", "client_id", clientID, "client_ip", c.ClientIP())
|
||||
s.jsonAccessToken(c, newAccess, newRefresh, accessTTL)
|
||||
}
|
||||
|
||||
func (s *Service) jsonAccessToken(c *gin.Context, access, refresh string, accessTTL time.Duration) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"access_token": access,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": int(accessTTL.Seconds()),
|
||||
"refresh_token": refresh,
|
||||
})
|
||||
}
|
||||
|
||||
// Introspect POST /oauth/introspect(RFC 7662),与 opaque 查表语义一致。
|
||||
func (s *Service) Introspect(c *gin.Context) {
|
||||
if err := c.Request.ParseForm(); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"active": false})
|
||||
return
|
||||
}
|
||||
tok := c.PostForm("token")
|
||||
hint := strings.TrimSpace(c.PostForm("token_type_hint"))
|
||||
if tok == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"active": false})
|
||||
return
|
||||
}
|
||||
ctx := c.Request.Context()
|
||||
|
||||
tryRefreshFirst := hint == "refresh_token"
|
||||
if tryRefreshFirst {
|
||||
if row, err := s.store.LookupRefreshTokenRow(ctx, tok); err == nil {
|
||||
slog.Info("oauth2_introspect", "active", true, "token_type", "refresh_token", "client_id", row.ClientID, "sub", row.UserID)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"active": true,
|
||||
"scope": row.Scope,
|
||||
"client_id": row.ClientID,
|
||||
"token_type": "refresh_token",
|
||||
"sub": row.UserID,
|
||||
"exp": row.ExpiresAt.Unix(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"active": false})
|
||||
return
|
||||
}
|
||||
|
||||
if row, err := s.store.LookupAccessTokenRow(ctx, tok); err == nil {
|
||||
slog.Info("oauth2_introspect", "active", true, "token_type", "access_token", "client_id", row.ClientID, "sub", row.UserID)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"active": true,
|
||||
"scope": row.Scope,
|
||||
"client_id": row.ClientID,
|
||||
"token_type": "access_token",
|
||||
"sub": row.UserID,
|
||||
"exp": row.ExpiresAt.Unix(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if row, err := s.store.LookupRefreshTokenRow(ctx, tok); err == nil {
|
||||
slog.Info("oauth2_introspect", "active", true, "token_type", "refresh_token", "client_id", row.ClientID, "sub", row.UserID)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"active": true,
|
||||
"scope": row.Scope,
|
||||
"client_id": row.ClientID,
|
||||
"token_type": "refresh_token",
|
||||
"sub": row.UserID,
|
||||
"exp": row.ExpiresAt.Unix(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"active": false})
|
||||
}
|
||||
@@ -0,0 +1,265 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"giter.top/smart/pkg/utils/id"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// ErrNotFound 未找到记录。
|
||||
var ErrNotFound = errors.New("oauth2: not found")
|
||||
|
||||
func hashToken(raw string) string {
|
||||
sum := sha256.Sum256([]byte(raw))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// Store OAuth 持久化。
|
||||
type Store struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewStore 创建 Store。
|
||||
func NewStore(db *gorm.DB) *Store {
|
||||
return &Store{db: db}
|
||||
}
|
||||
|
||||
// GetClientByClientID 按 client_id 查客户端。
|
||||
func (st *Store) GetClientByClientID(ctx context.Context, clientID string) (*OAuthClient, error) {
|
||||
var row OAuthClient
|
||||
err := st.db.WithContext(ctx).Where("client_id = ?", clientID).First(&row).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &row, nil
|
||||
}
|
||||
|
||||
// ParseRedirectURIs 解析 redirect_uris JSON 数组。
|
||||
func ParseRedirectURIs(raw string) ([]string, error) {
|
||||
var uris []string
|
||||
if raw == "" {
|
||||
return nil, errors.New("empty redirect_uris")
|
||||
}
|
||||
if err := json.Unmarshal([]byte(raw), &uris); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return uris, nil
|
||||
}
|
||||
|
||||
// RedirectURIMatch OAuth 2.1 精确匹配。
|
||||
func RedirectURIMatch(allowed []string, u string) bool {
|
||||
for _, x := range allowed {
|
||||
if x == u {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CreateAuthorizationCode 写入授权码(code 明文仅返回给调用方,库存哈希)。
|
||||
func (st *Store) CreateAuthorizationCode(ctx context.Context, codePlain string, clientID, redirectURI, userID, tenantID, scope, challenge, method string, expiresAt time.Time) error {
|
||||
row := OAuthAuthorizationCode{
|
||||
ID: id.New(),
|
||||
CodeHash: hashToken(codePlain),
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
UserID: userID,
|
||||
TenantID: tenantID,
|
||||
Scope: scope,
|
||||
CodeChallenge: challenge,
|
||||
CodeChallengeMethod: method,
|
||||
ExpiresAt: expiresAt,
|
||||
Used: false,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
return st.db.WithContext(ctx).Create(&row).Error
|
||||
}
|
||||
|
||||
// ConsumeAuthorizationCode 校验并一次性消费授权码,返回行数据供发 token。
|
||||
func (st *Store) ConsumeAuthorizationCode(ctx context.Context, codePlain, clientID, redirectURI string) (*OAuthAuthorizationCode, error) {
|
||||
h := hashToken(codePlain)
|
||||
var out *OAuthAuthorizationCode
|
||||
err := st.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
var row OAuthAuthorizationCode
|
||||
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("code_hash = ?", h).First(&row).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
if row.Used {
|
||||
return ErrNotFound
|
||||
}
|
||||
if time.Now().After(row.ExpiresAt) {
|
||||
return ErrNotFound
|
||||
}
|
||||
if row.ClientID != clientID || row.RedirectURI != redirectURI {
|
||||
return ErrNotFound
|
||||
}
|
||||
if err := tx.Model(&OAuthAuthorizationCode{}).Where("id = ?", row.ID).Update("used", true).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
out = &row
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// TokenPrincipal opaque access token 解析结果。
|
||||
type TokenPrincipal struct {
|
||||
UserID string
|
||||
TenantID string
|
||||
Scope string
|
||||
}
|
||||
|
||||
// LookupAccessToken 按明文 access token 查有效记录。
|
||||
func (st *Store) LookupAccessToken(ctx context.Context, raw string) (*TokenPrincipal, error) {
|
||||
h := hashToken(raw)
|
||||
var row OAuthAccessToken
|
||||
err := st.db.WithContext(ctx).Where("token_hash = ? AND revoked_at IS NULL", h).First(&row).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if time.Now().After(row.ExpiresAt) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return &TokenPrincipal{UserID: row.UserID, TenantID: row.TenantID, Scope: row.Scope}, nil
|
||||
}
|
||||
|
||||
// LookupAccessTokenRow 按明文查 access token 行(自省用)。
|
||||
func (st *Store) LookupAccessTokenRow(ctx context.Context, raw string) (*OAuthAccessToken, error) {
|
||||
h := hashToken(raw)
|
||||
var row OAuthAccessToken
|
||||
err := st.db.WithContext(ctx).Where("token_hash = ? AND revoked_at IS NULL", h).First(&row).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if time.Now().After(row.ExpiresAt) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return &row, nil
|
||||
}
|
||||
|
||||
// LookupRefreshTokenRow 按明文查 refresh token 行(自省用)。
|
||||
func (st *Store) LookupRefreshTokenRow(ctx context.Context, raw string) (*OAuthRefreshToken, error) {
|
||||
h := hashToken(raw)
|
||||
var row OAuthRefreshToken
|
||||
err := st.db.WithContext(ctx).Where("token_hash = ? AND revoked_at IS NULL", h).First(&row).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if time.Now().After(row.ExpiresAt) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return &row, nil
|
||||
}
|
||||
|
||||
// IssueAccessAndRefresh 写入 access + refresh(opaque 明文仅调用方返回给客户端)。
|
||||
func (st *Store) IssueAccessAndRefresh(ctx context.Context, accessPlain, refreshPlain, clientID, userID, tenantID, scope string, accessTTL, refreshTTL time.Duration) error {
|
||||
now := time.Now()
|
||||
accessID := id.New()
|
||||
refreshID := id.New()
|
||||
at := OAuthAccessToken{
|
||||
ID: accessID,
|
||||
TokenHash: hashToken(accessPlain),
|
||||
ClientID: clientID,
|
||||
UserID: userID,
|
||||
TenantID: tenantID,
|
||||
Scope: scope,
|
||||
ExpiresAt: now.Add(accessTTL),
|
||||
CreatedAt: now,
|
||||
}
|
||||
rt := OAuthRefreshToken{
|
||||
ID: refreshID,
|
||||
TokenHash: hashToken(refreshPlain),
|
||||
AccessTokenID: accessID,
|
||||
ClientID: clientID,
|
||||
UserID: userID,
|
||||
TenantID: tenantID,
|
||||
Scope: scope,
|
||||
ExpiresAt: now.Add(refreshTTL),
|
||||
CreatedAt: now,
|
||||
}
|
||||
return st.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Create(&at).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Create(&rt).Error
|
||||
})
|
||||
}
|
||||
|
||||
// RotateByRefreshToken 使用 refresh 换发新 access+refresh,旧令牌作废;client_id 须与注册一致。
|
||||
func (st *Store) RotateByRefreshToken(ctx context.Context, clientID, refreshPlain, newAccessPlain, newRefreshPlain string, accessTTL, refreshTTL time.Duration) error {
|
||||
h := hashToken(refreshPlain)
|
||||
return st.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
var row OAuthRefreshToken
|
||||
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("token_hash = ? AND revoked_at IS NULL", h).First(&row).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
if row.ClientID != clientID {
|
||||
return ErrNotFound
|
||||
}
|
||||
if time.Now().After(row.ExpiresAt) {
|
||||
return ErrNotFound
|
||||
}
|
||||
now := time.Now()
|
||||
if err := tx.Model(&OAuthRefreshToken{}).Where("id = ?", row.ID).Update("revoked_at", now).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Model(&OAuthAccessToken{}).Where("id = ?", row.AccessTokenID).Update("revoked_at", now).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
newAID := id.New()
|
||||
newRID := id.New()
|
||||
at := OAuthAccessToken{
|
||||
ID: newAID,
|
||||
TokenHash: hashToken(newAccessPlain),
|
||||
ClientID: row.ClientID,
|
||||
UserID: row.UserID,
|
||||
TenantID: row.TenantID,
|
||||
Scope: row.Scope,
|
||||
ExpiresAt: now.Add(accessTTL),
|
||||
CreatedAt: now,
|
||||
}
|
||||
rt := OAuthRefreshToken{
|
||||
ID: newRID,
|
||||
TokenHash: hashToken(newRefreshPlain),
|
||||
AccessTokenID: newAID,
|
||||
ClientID: row.ClientID,
|
||||
UserID: row.UserID,
|
||||
TenantID: row.TenantID,
|
||||
Scope: row.Scope,
|
||||
ExpiresAt: now.Add(refreshTTL),
|
||||
CreatedAt: now,
|
||||
}
|
||||
if err := tx.Create(&at).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Create(&rt).Error
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package scope
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Split 将空格分隔的 scope 拆成列表。
|
||||
func Split(scope string) []string {
|
||||
if strings.TrimSpace(scope) == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Fields(scope)
|
||||
out := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
out = append(out, p)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Contains 判断 scope 字符串是否包含指定权限标记。
|
||||
func Contains(scope, want string) bool {
|
||||
if want == "" {
|
||||
return true
|
||||
}
|
||||
for _, s := range Split(scope) {
|
||||
if s == want {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasAPIAccess 约定含 `api` 或 `api.*` 前缀即表示可访问业务 API(可与 IAM 菜单权限组合使用)。
|
||||
func HasAPIAccess(scope string) bool {
|
||||
for _, s := range Split(scope) {
|
||||
if s == "api" || strings.HasPrefix(s, "api.") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"giter.top/smart/pkg/config"
|
||||
"giter.top/smart/pkg/security"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// ErrInvalidSession 会话不存在或已过期。
|
||||
var ErrInvalidSession = errors.New("session: invalid or expired")
|
||||
|
||||
const redisKeyPrefix = "auth:sess:"
|
||||
|
||||
type payload struct {
|
||||
UserID string `json:"user_id"`
|
||||
TenantID string `json:"tenant_id"`
|
||||
}
|
||||
|
||||
// Store Redis 会话(供 OAuth authorize 与登出)。
|
||||
type Store struct {
|
||||
rdb redis.UniversalClient
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewStore 创建会话存储。
|
||||
func NewStore(rdb redis.UniversalClient, cfg *config.Config) *Store {
|
||||
return &Store{rdb: rdb, cfg: cfg}
|
||||
}
|
||||
|
||||
func (s *Store) ttl() time.Duration {
|
||||
t := s.cfg.Auth.Session.TTL
|
||||
if t == 0 {
|
||||
return 168 * time.Hour
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// Create 创建会话并返回 session id(写入 Cookie 用)。
|
||||
func (s *Store) Create(ctx context.Context, userID, tenantID string) (sid string, err error) {
|
||||
sid, err = security.RandomURLSafe(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
b, err := json.Marshal(payload{UserID: userID, TenantID: tenantID})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return sid, s.rdb.Set(ctx, redisKeyPrefix+sid, b, s.ttl()).Err()
|
||||
}
|
||||
|
||||
// Get 解析会话。
|
||||
func (s *Store) Get(ctx context.Context, sid string) (userID, tenantID string, err error) {
|
||||
b, err := s.rdb.Get(ctx, redisKeyPrefix+sid).Bytes()
|
||||
if err == redis.Nil {
|
||||
return "", "", ErrInvalidSession
|
||||
}
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
var p payload
|
||||
if err := json.Unmarshal(b, &p); err != nil {
|
||||
return "", "", ErrInvalidSession
|
||||
}
|
||||
return p.UserID, p.TenantID, nil
|
||||
}
|
||||
|
||||
// Delete 登出时删除。
|
||||
func (s *Store) Delete(ctx context.Context, sid string) error {
|
||||
return s.rdb.Del(ctx, redisKeyPrefix+sid).Err()
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"giter.top/smart/internal/auth/handler"
|
||||
"giter.top/smart/internal/auth/middleware"
|
||||
"giter.top/smart/internal/auth/oauth2"
|
||||
"giter.top/smart/internal/auth/session"
|
||||
"giter.top/smart/pkg/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
// ProviderSet Wire 注入。
|
||||
var ProviderSet = wire.NewSet(
|
||||
session.NewStore,
|
||||
oauth2.NewStore,
|
||||
oauth2.NewService,
|
||||
oauth2.NewHandler,
|
||||
handler.NewLoginHandler,
|
||||
ProvideBearer,
|
||||
ProvideLoginRLimitWire,
|
||||
ProvideTokenRLimitWire,
|
||||
NewAuthRoutes,
|
||||
)
|
||||
|
||||
// ProvideBearer 提供 Gin 中间件。
|
||||
func ProvideBearer(store *oauth2.Store) gin.HandlerFunc {
|
||||
return middleware.NewBearer(store)
|
||||
}
|
||||
|
||||
// LoginRateLimitWire、TokenRateLimitWire 用于 Wire 区分多个 gin.HandlerFunc 形参。
|
||||
type LoginRateLimitWire gin.HandlerFunc
|
||||
type TokenRateLimitWire gin.HandlerFunc
|
||||
|
||||
// ProvideLoginRLimitWire 登录接口限流。
|
||||
func ProvideLoginRLimitWire(cfg *config.Config) LoginRateLimitWire {
|
||||
return LoginRateLimitWire(middleware.PerIPMinute(cfg.Auth.RateLimit.Enabled, cfg.Auth.RateLimit.LoginPerMinute))
|
||||
}
|
||||
|
||||
// ProvideTokenRLimitWire 令牌与自省端点限流。
|
||||
func ProvideTokenRLimitWire(cfg *config.Config) TokenRateLimitWire {
|
||||
return TokenRateLimitWire(middleware.PerIPMinute(cfg.Auth.RateLimit.Enabled, cfg.Auth.RateLimit.TokenPerMinute))
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
package data
|
||||
|
||||
import (
|
||||
"giter.top/smart/pkg/cache"
|
||||
"giter.top/smart/pkg/db"
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
var ProviderSet = wire.NewSet(db.NewDB, cache.NewRedis)
|
||||
@@ -0,0 +1,24 @@
|
||||
package entity
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Dept 部门 iam_dept(根部门 parent_id 为空字符串)
|
||||
type Dept struct {
|
||||
ID string `json:"id" gorm:"primaryKey;type:varchar(36);not null"`
|
||||
TenantID string `json:"tenant_id" gorm:"size:36;not null;index:idx_dept_tenant"`
|
||||
ParentID string `json:"parent_id" gorm:"size:36;default:'';index:idx_dept_parent"`
|
||||
DeptName string `json:"dept_name" gorm:"size:128;not null"`
|
||||
DeptPath string `json:"dept_path" gorm:"type:text"`
|
||||
LeaderID *string `json:"leader_id" gorm:"size:36"`
|
||||
SortOrder int `json:"sort_order" gorm:"default:0"`
|
||||
Status int16 `json:"status" gorm:"default:1"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||
}
|
||||
|
||||
func (Dept) TableName() string { return "iam_dept" }
|
||||
@@ -0,0 +1,32 @@
|
||||
package entity
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// PublicOverviewPerms 动态导航中「概览页」类公开权限标识(PRD:所有用户默认可见,需在菜单中配置同名 perms)
|
||||
const PublicOverviewPerms = "public:overview"
|
||||
|
||||
// Menu 菜单 iam_menu(全局,不按租户分表;根节点 parent_id 为空字符串)
|
||||
type Menu struct {
|
||||
ID string `json:"id" gorm:"primaryKey;type:varchar(36);not null"`
|
||||
ParentID string `json:"parent_id" gorm:"size:36;default:'';index:idx_menu_parent"`
|
||||
MenuName string `json:"menu_name" gorm:"size:128;not null"`
|
||||
MenuType int16 `json:"menu_type" gorm:"not null"` // 1目录 2菜单 3按钮
|
||||
Perms string `json:"perms" gorm:"size:128;uniqueIndex"`
|
||||
Path string `json:"path" gorm:"size:255"`
|
||||
Component string `json:"component" gorm:"size:255"`
|
||||
Icon string `json:"icon" gorm:"size:64"`
|
||||
SortOrder int `json:"sort_order" gorm:"default:0"`
|
||||
IsVisible bool `json:"is_visible" gorm:"default:true"`
|
||||
IsBuiltin bool `json:"is_builtin" gorm:"default:false"`
|
||||
ExternalLink string `json:"external_link" gorm:"size:512"`
|
||||
Status int16 `json:"status" gorm:"default:1"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||
}
|
||||
|
||||
func (Menu) TableName() string { return "iam_menu" }
|
||||
@@ -0,0 +1,43 @@
|
||||
package entity
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 数据范围并集优先级(数值越大权限越大)
|
||||
const (
|
||||
DataScopeSelf int16 = 1
|
||||
DataScopeDept int16 = 2
|
||||
DataScopeDeptTree int16 = 3
|
||||
DataScopeAll int16 = 4
|
||||
)
|
||||
|
||||
// Role 角色 iam_role
|
||||
type Role struct {
|
||||
ID string `json:"id" gorm:"primaryKey;type:varchar(36);not null"`
|
||||
TenantID string `json:"tenant_id" gorm:"size:36;not null;index:idx_role_tenant"`
|
||||
RoleCode string `json:"role_code" gorm:"size:64;not null"`
|
||||
RoleName string `json:"role_name" gorm:"size:128;not null"`
|
||||
DataScope int16 `json:"data_scope" gorm:"default:4"` // 1本人 2本部门 3本部门及子部门 4全部
|
||||
Description string `json:"description" gorm:"size:512"`
|
||||
IsBuiltin bool `json:"is_builtin" gorm:"default:false"`
|
||||
Status int16 `json:"status" gorm:"default:1"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||
}
|
||||
|
||||
func (Role) TableName() string { return "iam_role" }
|
||||
|
||||
|
||||
// RoleMenu 角色菜单 iam_role_menu
|
||||
type RoleMenu struct {
|
||||
ID string `json:"id" gorm:"primaryKey;type:varchar(36);not null"`
|
||||
RoleID string `json:"role_id" gorm:"size:36;not null;uniqueIndex:uk_role_menu"`
|
||||
MenuID string `json:"menu_id" gorm:"size:36;not null;uniqueIndex:uk_role_menu"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
func (RoleMenu) TableName() string { return "iam_role_menu" }
|
||||
@@ -0,0 +1,25 @@
|
||||
package entity
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// PlatformTenantID 平台租户主键(与初始化数据一致;菜单维护等仅平台租户可操作)
|
||||
const PlatformTenantID = "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
// Tenant 租户 iam_tenant
|
||||
type Tenant struct {
|
||||
ID string `json:"id" gorm:"primaryKey;type:varchar(36);not null"`
|
||||
TenantCode string `json:"tenant_code" gorm:"size:64;uniqueIndex;not null"`
|
||||
TenantName string `json:"tenant_name" gorm:"size:128;not null"`
|
||||
AdminUserID *string `json:"admin_user_id" gorm:"size:36"`
|
||||
Status int16 `json:"status" gorm:"default:1"` // 1 正常 0 冻结 -1 删除(逻辑)
|
||||
ExpireTime *time.Time `json:"expire_time"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||
}
|
||||
|
||||
func (Tenant) TableName() string { return "iam_tenant" }
|
||||
@@ -0,0 +1,53 @@
|
||||
package entity
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// User 用户 iam_user
|
||||
type User struct {
|
||||
ID string `json:"id" gorm:"primaryKey;type:varchar(36);not null"`
|
||||
TenantID string `json:"tenant_id" gorm:"size:36;not null;index:idx_user_tenant"`
|
||||
DeptID *string `json:"dept_id" gorm:"size:36;index:idx_user_dept"`
|
||||
UserName string `json:"user_name" gorm:"size:64;not null"`
|
||||
RealName string `json:"real_name" gorm:"size:64"`
|
||||
PasswordHash string `json:"-" gorm:"size:255;not null"`
|
||||
Phone string `json:"phone" gorm:"size:20"`
|
||||
Email string `json:"email" gorm:"size:128"`
|
||||
Avatar string `json:"avatar" gorm:"size:512"`
|
||||
Gender int16 `json:"gender" gorm:"default:0"`
|
||||
Status int16 `json:"status" gorm:"default:1"`
|
||||
LoginAttempts int `json:"login_attempts" gorm:"default:0"`
|
||||
LockedUntil *time.Time `json:"locked_until"`
|
||||
LastLoginAt *time.Time `json:"last_login_at"`
|
||||
LastLoginIP string `json:"last_login_ip" gorm:"size:45"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||
}
|
||||
|
||||
func (User) TableName() string { return "iam_user" }
|
||||
|
||||
|
||||
// UserDept 用户部门关联 iam_user_dept
|
||||
type UserDept struct {
|
||||
ID string `json:"id" gorm:"primaryKey;type:varchar(36);not null"`
|
||||
UserID string `json:"user_id" gorm:"size:36;not null;uniqueIndex:uk_user_dept"`
|
||||
DeptID string `json:"dept_id" gorm:"size:36;not null;uniqueIndex:uk_user_dept"`
|
||||
IsPrimary bool `json:"is_primary" gorm:"default:false"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
func (UserDept) TableName() string { return "iam_user_dept" }
|
||||
|
||||
// UserRole 用户角色 iam_user_role
|
||||
type UserRole struct {
|
||||
ID string `json:"id" gorm:"primaryKey;type:varchar(36);not null"`
|
||||
UserID string `json:"user_id" gorm:"size:36;not null;uniqueIndex:uk_user_role"`
|
||||
RoleID string `json:"role_id" gorm:"size:36;not null;uniqueIndex:uk_user_role"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
func (UserRole) TableName() string { return "iam_user_role" }
|
||||
@@ -0,0 +1,95 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"giter.top/smart/internal/iam/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type DeptHandler struct {
|
||||
svc service.DeptService
|
||||
}
|
||||
|
||||
func NewDeptHandler(svc service.DeptService) *DeptHandler {
|
||||
return &DeptHandler{svc: svc}
|
||||
}
|
||||
|
||||
func (h *DeptHandler) Tree(c *gin.Context) {
|
||||
tid := headerTenantID(c)
|
||||
keyword := c.Query("keyword")
|
||||
var leaderID *string
|
||||
if s := c.Query("leader_id"); s != "" {
|
||||
leaderID = &s
|
||||
}
|
||||
tree, err := h.svc.Tree(c.Request.Context(), tid, keyword, leaderID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, tree)
|
||||
}
|
||||
|
||||
func (h *DeptHandler) Create(c *gin.Context) {
|
||||
tid := headerTenantID(c)
|
||||
var req service.CreateDeptRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
d, err := h.svc.Create(c.Request.Context(), tid, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusCreated, d)
|
||||
}
|
||||
|
||||
func (h *DeptHandler) Update(c *gin.Context) {
|
||||
tid := headerTenantID(c)
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
|
||||
return
|
||||
}
|
||||
var req service.UpdateDeptRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
d, err := h.svc.Update(c.Request.Context(), tid, id, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, d)
|
||||
}
|
||||
|
||||
func (h *DeptHandler) Delete(c *gin.Context) {
|
||||
tid := headerTenantID(c)
|
||||
var ids []string
|
||||
if err := c.ShouldBindJSON(&ids); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := h.svc.Delete(c.Request.Context(), tid, ids); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (h *DeptHandler) Get(c *gin.Context) {
|
||||
tid := headerTenantID(c)
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
|
||||
return
|
||||
}
|
||||
d, err := h.svc.Get(c.Request.Context(), tid, id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, d)
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
authmw "giter.top/smart/internal/auth/middleware"
|
||||
"giter.top/smart/internal/iam/entity"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func atoiDef(s string, def int) int {
|
||||
if s == "" {
|
||||
return def
|
||||
}
|
||||
v, err := strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return def
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// headerTenantID 当前租户:优先 OAuth2 Bearer 解析结果,其次 X-Tenant-ID,缺省平台租户。
|
||||
func headerTenantID(c *gin.Context) string {
|
||||
if v, ok := c.Get(authmw.CtxTenantID); ok {
|
||||
if s, ok2 := v.(string); ok2 && s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
s := c.GetHeader("X-Tenant-ID")
|
||||
if s == "" {
|
||||
return entity.PlatformTenantID
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// headerUserID 当前用户:优先 OAuth2 opaque access_token 对应用户,其次 X-User-ID。
|
||||
func headerUserID(c *gin.Context) string {
|
||||
if v, ok := c.Get(authmw.CtxUserID); ok {
|
||||
if s, ok2 := v.(string); ok2 && s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return c.GetHeader("X-User-ID")
|
||||
}
|
||||
|
||||
// headerGrantorUserID 请求头 X-Grantor-User-ID(授权人,用于防越权校验)
|
||||
func headerGrantorUserID(c *gin.Context) *string {
|
||||
s := c.GetHeader("X-Grantor-User-ID")
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return &s
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"giter.top/smart/internal/iam/entity"
|
||||
"giter.top/smart/internal/iam/repository"
|
||||
"giter.top/smart/internal/iam/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type MenuHandler struct {
|
||||
svc service.MenuService
|
||||
}
|
||||
|
||||
func NewMenuHandler(svc service.MenuService) *MenuHandler {
|
||||
return &MenuHandler{svc: svc}
|
||||
}
|
||||
|
||||
func isPlatformAdmin(c *gin.Context) bool {
|
||||
return headerTenantID(c) == entity.PlatformTenantID
|
||||
}
|
||||
|
||||
func (h *MenuHandler) Create(c *gin.Context) {
|
||||
var req service.CreateMenuRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
m, err := h.svc.Create(c.Request.Context(), &req, isPlatformAdmin(c))
|
||||
if err != nil {
|
||||
if errors.Is(err, repository.ErrForbidden) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "仅平台管理员可维护菜单"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusCreated, m)
|
||||
}
|
||||
|
||||
func (h *MenuHandler) Update(c *gin.Context) {
|
||||
mid := c.Param("id")
|
||||
if mid == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
|
||||
return
|
||||
}
|
||||
var req service.UpdateMenuRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
m, err := h.svc.Update(c.Request.Context(), mid, &req, isPlatformAdmin(c))
|
||||
if err != nil {
|
||||
if errors.Is(err, repository.ErrForbidden) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "仅平台管理员可维护菜单"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, m)
|
||||
}
|
||||
|
||||
func (h *MenuHandler) Delete(c *gin.Context) {
|
||||
var ids []string
|
||||
if err := c.ShouldBindJSON(&ids); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := h.svc.Delete(c.Request.Context(), ids, isPlatformAdmin(c)); err != nil {
|
||||
if errors.Is(err, repository.ErrForbidden) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "仅平台管理员可维护菜单"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (h *MenuHandler) Get(c *gin.Context) {
|
||||
mid := c.Param("id")
|
||||
if mid == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
|
||||
return
|
||||
}
|
||||
m, err := h.svc.Get(c.Request.Context(), mid)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, m)
|
||||
}
|
||||
|
||||
func (h *MenuHandler) Tree(c *gin.Context) {
|
||||
var mt *int16
|
||||
if s := c.Query("menu_type"); s != "" {
|
||||
v64, err := strconv.ParseInt(s, 10, 16)
|
||||
if err == nil {
|
||||
v := int16(v64)
|
||||
mt = &v
|
||||
}
|
||||
}
|
||||
tree, err := h.svc.Tree(c.Request.Context(), mt)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, tree)
|
||||
}
|
||||
|
||||
func (h *MenuHandler) Nav(c *gin.Context) {
|
||||
uid := headerUserID(c)
|
||||
if uid == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "需要 X-User-ID"})
|
||||
return
|
||||
}
|
||||
tree, err := h.svc.NavForUser(c.Request.Context(), uid)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, tree)
|
||||
}
|
||||
|
||||
func (h *MenuHandler) Perms(c *gin.Context) {
|
||||
uid := headerUserID(c)
|
||||
if uid == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "需要 X-User-ID"})
|
||||
return
|
||||
}
|
||||
perms, err := h.svc.PermsForUser(c.Request.Context(), uid)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"perms": perms})
|
||||
}
|
||||
@@ -0,0 +1,117 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"giter.top/smart/internal/iam/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type RoleHandler struct {
|
||||
svc service.RoleService
|
||||
}
|
||||
|
||||
func NewRoleHandler(svc service.RoleService) *RoleHandler {
|
||||
return &RoleHandler{svc: svc}
|
||||
}
|
||||
|
||||
func (h *RoleHandler) Create(c *gin.Context) {
|
||||
tid := headerTenantID(c)
|
||||
var req service.CreateRoleRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
role, err := h.svc.Create(c.Request.Context(), tid, &req, headerGrantorUserID(c))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusCreated, role)
|
||||
}
|
||||
|
||||
func (h *RoleHandler) Update(c *gin.Context) {
|
||||
tid := headerTenantID(c)
|
||||
rid := c.Param("id")
|
||||
if rid == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
|
||||
return
|
||||
}
|
||||
var req service.UpdateRoleRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
r, err := h.svc.Update(c.Request.Context(), tid, rid, &req, headerGrantorUserID(c))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, r)
|
||||
}
|
||||
|
||||
func (h *RoleHandler) Delete(c *gin.Context) {
|
||||
tid := headerTenantID(c)
|
||||
var ids []string
|
||||
if err := c.ShouldBindJSON(&ids); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := h.svc.Delete(c.Request.Context(), tid, ids); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (h *RoleHandler) Get(c *gin.Context) {
|
||||
tid := headerTenantID(c)
|
||||
rid := c.Param("id")
|
||||
if rid == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
|
||||
return
|
||||
}
|
||||
r, err := h.svc.Get(c.Request.Context(), tid, rid)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, r)
|
||||
}
|
||||
|
||||
func (h *RoleHandler) List(c *gin.Context) {
|
||||
tid := headerTenantID(c)
|
||||
name := c.Query("name")
|
||||
code := c.Query("code")
|
||||
page := atoiDef(c.Query("page"), 1)
|
||||
pageSize := atoiDef(c.Query("page_size"), 10)
|
||||
resp, err := h.svc.List(c.Request.Context(), tid, name, code, page, pageSize)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
type assignMenusBody struct {
|
||||
MenuIDs []string `json:"menu_ids"`
|
||||
}
|
||||
|
||||
func (h *RoleHandler) AssignMenus(c *gin.Context) {
|
||||
tid := headerTenantID(c)
|
||||
rid := c.Param("id")
|
||||
if rid == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
|
||||
return
|
||||
}
|
||||
var body assignMenusBody
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := h.svc.AssignMenus(c.Request.Context(), tid, rid, body.MenuIDs, headerGrantorUserID(c)); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"giter.top/smart/internal/iam/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type TenantHandler struct {
|
||||
svc service.TenantService
|
||||
}
|
||||
|
||||
func NewTenantHandler(svc service.TenantService) *TenantHandler {
|
||||
return &TenantHandler{svc: svc}
|
||||
}
|
||||
|
||||
func (h *TenantHandler) Create(c *gin.Context) {
|
||||
var req service.CreateTenantRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
t, err := h.svc.Create(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusCreated, t)
|
||||
}
|
||||
|
||||
func (h *TenantHandler) Update(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
|
||||
return
|
||||
}
|
||||
var req service.UpdateTenantRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
t, err := h.svc.Update(c.Request.Context(), id, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, t)
|
||||
}
|
||||
|
||||
func (h *TenantHandler) Delete(c *gin.Context) {
|
||||
var ids []string
|
||||
if err := c.ShouldBindJSON(&ids); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := h.svc.Delete(c.Request.Context(), ids); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (h *TenantHandler) Get(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
|
||||
return
|
||||
}
|
||||
t, err := h.svc.Get(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, t)
|
||||
}
|
||||
|
||||
func (h *TenantHandler) List(c *gin.Context) {
|
||||
name := c.Query("name")
|
||||
code := c.Query("code")
|
||||
var status *int16
|
||||
if s := c.Query("status"); s != "" {
|
||||
v64, err := strconv.ParseInt(s, 10, 16)
|
||||
if err == nil {
|
||||
v := int16(v64)
|
||||
status = &v
|
||||
}
|
||||
}
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10"))
|
||||
resp, err := h.svc.List(c.Request.Context(), name, code, status, page, pageSize)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
@@ -0,0 +1,123 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"giter.top/smart/internal/iam/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type UserHandler struct {
|
||||
svc service.UserService
|
||||
}
|
||||
|
||||
func NewUserHandler(svc service.UserService) *UserHandler {
|
||||
return &UserHandler{svc: svc}
|
||||
}
|
||||
|
||||
func (h *UserHandler) Create(c *gin.Context) {
|
||||
tid := headerTenantID(c)
|
||||
var req service.CreateUserRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
u, err := h.svc.Create(c.Request.Context(), tid, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusCreated, u)
|
||||
}
|
||||
|
||||
func (h *UserHandler) Update(c *gin.Context) {
|
||||
tid := headerTenantID(c)
|
||||
uid := c.Param("id")
|
||||
if uid == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
|
||||
return
|
||||
}
|
||||
var req service.UpdateUserRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
u, err := h.svc.Update(c.Request.Context(), tid, uid, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, u)
|
||||
}
|
||||
|
||||
func (h *UserHandler) Delete(c *gin.Context) {
|
||||
tid := headerTenantID(c)
|
||||
var ids []string
|
||||
if err := c.ShouldBindJSON(&ids); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := h.svc.Delete(c.Request.Context(), tid, ids); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (h *UserHandler) Get(c *gin.Context) {
|
||||
tid := headerTenantID(c)
|
||||
uid := c.Param("id")
|
||||
if uid == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
|
||||
return
|
||||
}
|
||||
u, err := h.svc.Get(c.Request.Context(), tid, uid)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, u)
|
||||
}
|
||||
|
||||
func (h *UserHandler) List(c *gin.Context) {
|
||||
tid := headerTenantID(c)
|
||||
q := &service.UserListQuery{
|
||||
Keyword: c.Query("keyword"),
|
||||
Page: atoiDef(c.Query("page"), 1),
|
||||
PageSize: atoiDef(c.Query("page_size"), 10),
|
||||
}
|
||||
if s := c.Query("dept_id"); s != "" {
|
||||
q.DeptID = &s
|
||||
}
|
||||
if s := c.Query("role_id"); s != "" {
|
||||
q.RoleID = &s
|
||||
}
|
||||
if s := c.Query("status"); s != "" {
|
||||
v64, err := strconv.ParseInt(s, 10, 16)
|
||||
if err == nil {
|
||||
v := int16(v64)
|
||||
q.Status = &v
|
||||
}
|
||||
}
|
||||
resp, err := h.svc.List(c.Request.Context(), tid, q)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (h *UserHandler) DataScope(c *gin.Context) {
|
||||
uid := headerUserID(c)
|
||||
if uid == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "需要 X-User-ID"})
|
||||
return
|
||||
}
|
||||
ds, err := h.svc.DataScopeForUser(c.Request.Context(), uid)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"data_scope": ds})
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
package iam
|
||||
|
||||
import (
|
||||
"giter.top/smart/internal/iam/handler"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type IamRoutes struct {
|
||||
tenantHandler *handler.TenantHandler
|
||||
deptHandler *handler.DeptHandler
|
||||
roleHandler *handler.RoleHandler
|
||||
userHandler *handler.UserHandler
|
||||
menuHandler *handler.MenuHandler
|
||||
}
|
||||
|
||||
func NewIamRoutes(tenantHandler *handler.TenantHandler, deptHandler *handler.DeptHandler, roleHandler *handler.RoleHandler, userHandler *handler.UserHandler, menuHandler *handler.MenuHandler) *IamRoutes {
|
||||
return &IamRoutes{
|
||||
tenantHandler: tenantHandler,
|
||||
deptHandler: deptHandler,
|
||||
roleHandler: roleHandler,
|
||||
userHandler: userHandler,
|
||||
menuHandler: menuHandler,
|
||||
}
|
||||
}
|
||||
// TODO 添加注册信息
|
||||
func (s *IamRoutes) Register(engine *gin.Engine, apiGroup *gin.RouterGroup) {
|
||||
// group :=engine.Group("/iam")
|
||||
group := apiGroup.Group("/iam")
|
||||
s.registerTenantRoutes(group)
|
||||
s.registerDeptRoutes(group)
|
||||
s.registerRoleRoutes(group)
|
||||
s.registerUserRoutes(group)
|
||||
s.registerMenuRoutes(group)
|
||||
}
|
||||
|
||||
func (s *IamRoutes) registerTenantRoutes(group *gin.RouterGroup) {
|
||||
tg := group.Group("/tenant")
|
||||
{
|
||||
tg.POST("/create", s.tenantHandler.Create)
|
||||
tg.PUT("/update/:id", s.tenantHandler.Update)
|
||||
tg.DELETE("/delete-batch", s.tenantHandler.Delete)
|
||||
tg.GET("/get/:id", s.tenantHandler.Get)
|
||||
tg.GET("/list", s.tenantHandler.List)
|
||||
}
|
||||
}
|
||||
func (s *IamRoutes) registerDeptRoutes(group *gin.RouterGroup) {
|
||||
dg := group.Group("/dept")
|
||||
{
|
||||
dg.POST("/create", s.deptHandler.Create)
|
||||
dg.PUT("/update/:id", s.deptHandler.Update)
|
||||
dg.DELETE("/delete-batch", s.deptHandler.Delete)
|
||||
dg.GET("/get/:id", s.deptHandler.Get)
|
||||
dg.GET("/tree", s.deptHandler.Tree)
|
||||
}
|
||||
}
|
||||
func (s *IamRoutes) registerRoleRoutes(group *gin.RouterGroup) {
|
||||
rg := group.Group("/role")
|
||||
{
|
||||
rg.POST("/create", s.roleHandler.Create)
|
||||
rg.PUT("/update/:id", s.roleHandler.Update)
|
||||
rg.DELETE("/delete-batch", s.roleHandler.Delete)
|
||||
rg.GET("/get/:id", s.roleHandler.Get)
|
||||
rg.GET("/list", s.roleHandler.List)
|
||||
}
|
||||
}
|
||||
func (s *IamRoutes) registerUserRoutes(group *gin.RouterGroup) {
|
||||
ug := group.Group("/user")
|
||||
{
|
||||
ug.POST("/create", s.userHandler.Create)
|
||||
ug.PUT("/update/:id", s.userHandler.Update)
|
||||
ug.DELETE("/delete-batch", s.userHandler.Delete)
|
||||
ug.GET("/get/:id", s.userHandler.Get)
|
||||
ug.GET("/list", s.userHandler.List)
|
||||
}
|
||||
}
|
||||
func (s *IamRoutes) registerMenuRoutes(group *gin.RouterGroup) {
|
||||
mg := group.Group("/menu")
|
||||
{
|
||||
mg.POST("/create", s.menuHandler.Create)
|
||||
mg.PUT("/update/:id", s.menuHandler.Update)
|
||||
mg.DELETE("/delete-batch", s.menuHandler.Delete)
|
||||
mg.GET("/get/:id", s.menuHandler.Get)
|
||||
mg.GET("/tree", s.menuHandler.Tree)
|
||||
mg.GET("/nav", s.menuHandler.Nav)
|
||||
mg.GET("/perms", s.menuHandler.Perms)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"giter.top/smart/internal/iam/entity"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// DeptRepository 部门数据访问
|
||||
type DeptRepository interface {
|
||||
Create(ctx context.Context, d *entity.Dept) error
|
||||
Update(ctx context.Context, d *entity.Dept) error
|
||||
Delete(ctx context.Context, id string) error
|
||||
GetByID(ctx context.Context, id string) (*entity.Dept, error)
|
||||
ListByTenant(ctx context.Context, tenantID string) ([]entity.Dept, error)
|
||||
CountChildren(ctx context.Context, id string) (int64, error)
|
||||
ExistsSiblingName(ctx context.Context, tenantID, parentID, name string, excludeID string) (bool, error)
|
||||
FindRoot(ctx context.Context, tenantID string) (*entity.Dept, error)
|
||||
UpdatePath(ctx context.Context, id string, path string) error
|
||||
}
|
||||
|
||||
type deptRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewDeptRepository(db *gorm.DB) DeptRepository {
|
||||
return &deptRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *deptRepository) Create(ctx context.Context, d *entity.Dept) error {
|
||||
return r.db.WithContext(ctx).Create(d).Error
|
||||
}
|
||||
|
||||
func (r *deptRepository) Update(ctx context.Context, d *entity.Dept) error {
|
||||
return r.db.WithContext(ctx).Save(d).Error
|
||||
}
|
||||
|
||||
func (r *deptRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entity.Dept{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
func (r *deptRepository) GetByID(ctx context.Context, id string) (*entity.Dept, error) {
|
||||
var out entity.Dept
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&out).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (r *deptRepository) ListByTenant(ctx context.Context, tenantID string) ([]entity.Dept, error) {
|
||||
var rows []entity.Dept
|
||||
err := r.db.WithContext(ctx).Where("tenant_id = ?", tenantID).Order("sort_order ASC, created_at ASC").Find(&rows).Error
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (r *deptRepository) CountChildren(ctx context.Context, id string) (int64, error) {
|
||||
var n int64
|
||||
err := r.db.WithContext(ctx).Model(&entity.Dept{}).Where("parent_id = ?", id).Count(&n).Error
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *deptRepository) ExistsSiblingName(ctx context.Context, tenantID, parentID, name string, excludeID string) (bool, error) {
|
||||
q := r.db.WithContext(ctx).Model(&entity.Dept{}).Where("tenant_id = ? AND parent_id = ? AND dept_name = ?", tenantID, parentID, name)
|
||||
if excludeID != "" {
|
||||
q = q.Where("id <> ?", excludeID)
|
||||
}
|
||||
var n int64
|
||||
err := q.Count(&n).Error
|
||||
return n > 0, err
|
||||
}
|
||||
|
||||
func (r *deptRepository) FindRoot(ctx context.Context, tenantID string) (*entity.Dept, error) {
|
||||
var out entity.Dept
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("tenant_id = ? AND (parent_id = '' OR parent_id = '0')", tenantID).
|
||||
First(&out).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (r *deptRepository) UpdatePath(ctx context.Context, id string, path string) error {
|
||||
return r.db.WithContext(ctx).Model(&entity.Dept{}).Where("id = ?", id).Update("dept_path", path).Error
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package repository
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrNotFound = errors.New("not found")
|
||||
ErrConflict = errors.New("conflict")
|
||||
ErrInvalidState = errors.New("invalid state")
|
||||
ErrForbidden = errors.New("forbidden")
|
||||
)
|
||||
@@ -0,0 +1,111 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"giter.top/smart/internal/iam/entity"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// MenuRepository 菜单
|
||||
type MenuRepository interface {
|
||||
Create(ctx context.Context, m *entity.Menu) error
|
||||
Update(ctx context.Context, m *entity.Menu) error
|
||||
Delete(ctx context.Context, id string) error
|
||||
GetByID(ctx context.Context, id string) (*entity.Menu, error)
|
||||
ListAll(ctx context.Context) ([]entity.Menu, error)
|
||||
ListByType(ctx context.Context, menuType *int16) ([]entity.Menu, error)
|
||||
ExistsPerms(ctx context.Context, perms string, excludeID string) (bool, error)
|
||||
CountChildren(ctx context.Context, parentID string) (int64, error)
|
||||
CountRoleRefs(ctx context.Context, menuID string) (int64, error)
|
||||
ListByPerms(ctx context.Context, perms string) ([]entity.Menu, error)
|
||||
ListIDsByPermsIn(ctx context.Context, perms []string) ([]string, error)
|
||||
}
|
||||
|
||||
type menuRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewMenuRepository(db *gorm.DB) MenuRepository {
|
||||
return &menuRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *menuRepository) Create(ctx context.Context, m *entity.Menu) error {
|
||||
return r.db.WithContext(ctx).Create(m).Error
|
||||
}
|
||||
|
||||
func (r *menuRepository) Update(ctx context.Context, m *entity.Menu) error {
|
||||
return r.db.WithContext(ctx).Save(m).Error
|
||||
}
|
||||
|
||||
func (r *menuRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entity.Menu{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
func (r *menuRepository) GetByID(ctx context.Context, id string) (*entity.Menu, error) {
|
||||
var out entity.Menu
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&out).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (r *menuRepository) ListAll(ctx context.Context) ([]entity.Menu, error) {
|
||||
var rows []entity.Menu
|
||||
err := r.db.WithContext(ctx).Order("sort_order ASC, created_at ASC").Find(&rows).Error
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (r *menuRepository) ListByType(ctx context.Context, menuType *int16) ([]entity.Menu, error) {
|
||||
q := r.db.WithContext(ctx).Model(&entity.Menu{})
|
||||
if menuType != nil {
|
||||
q = q.Where("menu_type = ?", *menuType)
|
||||
}
|
||||
var rows []entity.Menu
|
||||
err := q.Order("sort_order ASC, created_at ASC").Find(&rows).Error
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (r *menuRepository) ExistsPerms(ctx context.Context, perms string, excludeID string) (bool, error) {
|
||||
if perms == "" {
|
||||
return false, nil
|
||||
}
|
||||
q := r.db.WithContext(ctx).Model(&entity.Menu{}).Where("perms = ?", perms)
|
||||
if excludeID != "" {
|
||||
q = q.Where("id <> ?", excludeID)
|
||||
}
|
||||
var n int64
|
||||
err := q.Count(&n).Error
|
||||
return n > 0, err
|
||||
}
|
||||
|
||||
func (r *menuRepository) CountChildren(ctx context.Context, parentID string) (int64, error) {
|
||||
var n int64
|
||||
err := r.db.WithContext(ctx).Model(&entity.Menu{}).Where("parent_id = ?", parentID).Count(&n).Error
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *menuRepository) CountRoleRefs(ctx context.Context, menuID string) (int64, error) {
|
||||
var n int64
|
||||
err := r.db.WithContext(ctx).Model(&entity.RoleMenu{}).Where("menu_id = ?", menuID).Count(&n).Error
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *menuRepository) ListByPerms(ctx context.Context, perms string) ([]entity.Menu, error) {
|
||||
var rows []entity.Menu
|
||||
err := r.db.WithContext(ctx).Where("perms = ?", perms).Find(&rows).Error
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (r *menuRepository) ListIDsByPermsIn(ctx context.Context, perms []string) ([]string, error) {
|
||||
if len(perms) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
var ids []string
|
||||
err := r.db.WithContext(ctx).Model(&entity.Menu{}).Where("perms IN ?", perms).Pluck("id", &ids).Error
|
||||
return ids, err
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"giter.top/smart/internal/iam/entity"
|
||||
"giter.top/smart/pkg/utils/id"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// RoleRepository 角色与角色菜单
|
||||
type RoleRepository interface {
|
||||
Create(ctx context.Context, r *entity.Role) error
|
||||
Update(ctx context.Context, r *entity.Role) error
|
||||
Delete(ctx context.Context, id string) error
|
||||
GetByID(ctx context.Context, id string) (*entity.Role, error)
|
||||
List(ctx context.Context, tenantID string, name, code string, page, pageSize int) ([]entity.Role, int64, error)
|
||||
ExistsCode(ctx context.Context, tenantID string, code string, excludeID string) (bool, error)
|
||||
CountUsers(ctx context.Context, roleID string) (int64, error)
|
||||
ReplaceRoleMenus(ctx context.Context, roleID string, menuIDs []string) error
|
||||
ListMenuIDsByRole(ctx context.Context, roleID string) ([]string, error)
|
||||
ListMenuIDsByRoles(ctx context.Context, roleIDs []string) ([]string, error)
|
||||
ListRolesByUser(ctx context.Context, userID string) ([]entity.Role, error)
|
||||
}
|
||||
|
||||
type roleRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewRoleRepository(db *gorm.DB) RoleRepository {
|
||||
return &roleRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *roleRepository) Create(ctx context.Context, row *entity.Role) error {
|
||||
return r.db.WithContext(ctx).Create(row).Error
|
||||
}
|
||||
|
||||
func (r *roleRepository) Update(ctx context.Context, row *entity.Role) error {
|
||||
return r.db.WithContext(ctx).Save(row).Error
|
||||
}
|
||||
|
||||
func (r *roleRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entity.Role{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
func (r *roleRepository) GetByID(ctx context.Context, id string) (*entity.Role, error) {
|
||||
var out entity.Role
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&out).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (r *roleRepository) List(ctx context.Context, tenantID string, name, code string, page, pageSize int) ([]entity.Role, int64, error) {
|
||||
q := r.db.WithContext(ctx).Model(&entity.Role{}).Where("tenant_id = ?", tenantID)
|
||||
if name != "" {
|
||||
q = q.Where("role_name LIKE ?", "%"+name+"%")
|
||||
}
|
||||
if code != "" {
|
||||
q = q.Where("role_code LIKE ?", "%"+code+"%")
|
||||
}
|
||||
var total int64
|
||||
if err := q.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 {
|
||||
pageSize = 10
|
||||
}
|
||||
offset := (page - 1) * pageSize
|
||||
var rows []entity.Role
|
||||
err := q.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&rows).Error
|
||||
return rows, total, err
|
||||
}
|
||||
|
||||
func (r *roleRepository) ExistsCode(ctx context.Context, tenantID string, code string, excludeID string) (bool, error) {
|
||||
q := r.db.WithContext(ctx).Model(&entity.Role{}).Where("tenant_id = ? AND role_code = ?", tenantID, code)
|
||||
if excludeID != "" {
|
||||
q = q.Where("id <> ?", excludeID)
|
||||
}
|
||||
var n int64
|
||||
err := q.Count(&n).Error
|
||||
return n > 0, err
|
||||
}
|
||||
|
||||
func (r *roleRepository) CountUsers(ctx context.Context, roleID string) (int64, error) {
|
||||
var n int64
|
||||
err := r.db.WithContext(ctx).Model(&entity.UserRole{}).Where("role_id = ?", roleID).Count(&n).Error
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *roleRepository) ReplaceRoleMenus(ctx context.Context, roleID string, menuIDs []string) error {
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Where("role_id = ?", roleID).Delete(&entity.RoleMenu{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
for _, mid := range menuIDs {
|
||||
rm := entity.RoleMenu{ID: id.New(), RoleID: roleID, MenuID: mid}
|
||||
if err := tx.Create(&rm).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (r *roleRepository) ListMenuIDsByRole(ctx context.Context, roleID string) ([]string, error) {
|
||||
var ids []string
|
||||
err := r.db.WithContext(ctx).Model(&entity.RoleMenu{}).Where("role_id = ?", roleID).Pluck("menu_id", &ids).Error
|
||||
return ids, err
|
||||
}
|
||||
|
||||
func (r *roleRepository) ListMenuIDsByRoles(ctx context.Context, roleIDs []string) ([]string, error) {
|
||||
if len(roleIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
var raw []string
|
||||
err := r.db.WithContext(ctx).Model(&entity.RoleMenu{}).Where("role_id IN ?", roleIDs).Pluck("menu_id", &raw).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
seen := make(map[string]struct{}, len(raw))
|
||||
var ids []string
|
||||
for _, menuID := range raw {
|
||||
if _, ok := seen[menuID]; ok {
|
||||
continue
|
||||
}
|
||||
seen[menuID] = struct{}{}
|
||||
ids = append(ids, menuID)
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (r *roleRepository) ListRolesByUser(ctx context.Context, userID string) ([]entity.Role, error) {
|
||||
var roles []entity.Role
|
||||
err := r.db.WithContext(ctx).Table("iam_role").
|
||||
Joins("JOIN iam_user_role ur ON ur.role_id = iam_role.id").
|
||||
Where("ur.user_id = ?", userID).
|
||||
Find(&roles).Error
|
||||
return roles, err
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"giter.top/smart/internal/iam/entity"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TenantRepository 租户数据访问
|
||||
type TenantRepository interface {
|
||||
Create(ctx context.Context, t *entity.Tenant) error
|
||||
Update(ctx context.Context, t *entity.Tenant) error
|
||||
GetByID(ctx context.Context, id string) (*entity.Tenant, error)
|
||||
GetByCode(ctx context.Context, code string) (*entity.Tenant, error)
|
||||
List(ctx context.Context, name, code string, status *int16, page, pageSize int) ([]entity.Tenant, int64, error)
|
||||
CountUsers(ctx context.Context, tenantID string) (int64, error)
|
||||
CountDepts(ctx context.Context, tenantID string) (int64, error)
|
||||
ExistsCode(ctx context.Context, code string, excludeID string) (bool, error)
|
||||
}
|
||||
|
||||
type tenantRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewTenantRepository(db *gorm.DB) TenantRepository {
|
||||
return &tenantRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *tenantRepository) Create(ctx context.Context, t *entity.Tenant) error {
|
||||
return r.db.WithContext(ctx).Create(t).Error
|
||||
}
|
||||
|
||||
func (r *tenantRepository) Update(ctx context.Context, t *entity.Tenant) error {
|
||||
return r.db.WithContext(ctx).Save(t).Error
|
||||
}
|
||||
|
||||
func (r *tenantRepository) GetByID(ctx context.Context, id string) (*entity.Tenant, error) {
|
||||
var out entity.Tenant
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&out).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (r *tenantRepository) GetByCode(ctx context.Context, code string) (*entity.Tenant, error) {
|
||||
var out entity.Tenant
|
||||
err := r.db.WithContext(ctx).Where("tenant_code = ?", code).First(&out).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (r *tenantRepository) List(ctx context.Context, name, code string, status *int16, page, pageSize int) ([]entity.Tenant, int64, error) {
|
||||
q := r.db.WithContext(ctx).Model(&entity.Tenant{})
|
||||
if name != "" {
|
||||
q = q.Where("tenant_name LIKE ?", "%"+name+"%")
|
||||
}
|
||||
if code != "" {
|
||||
q = q.Where("tenant_code LIKE ?", "%"+code+"%")
|
||||
}
|
||||
if status != nil {
|
||||
q = q.Where("status = ?", *status)
|
||||
}
|
||||
var total int64
|
||||
if err := q.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 {
|
||||
pageSize = 10
|
||||
}
|
||||
offset := (page - 1) * pageSize
|
||||
var rows []entity.Tenant
|
||||
err := q.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&rows).Error
|
||||
return rows, total, err
|
||||
}
|
||||
|
||||
func (r *tenantRepository) CountUsers(ctx context.Context, tenantID string) (int64, error) {
|
||||
var n int64
|
||||
err := r.db.WithContext(ctx).Model(&entity.User{}).Where("tenant_id = ?", tenantID).Count(&n).Error
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *tenantRepository) CountDepts(ctx context.Context, tenantID string) (int64, error) {
|
||||
var n int64
|
||||
err := r.db.WithContext(ctx).Model(&entity.Dept{}).Where("tenant_id = ?", tenantID).Count(&n).Error
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *tenantRepository) ExistsCode(ctx context.Context, code string, excludeID string) (bool, error) {
|
||||
q := r.db.WithContext(ctx).Model(&entity.Tenant{}).Where("tenant_code = ?", code)
|
||||
if excludeID != "" {
|
||||
q = q.Where("id <> ?", excludeID)
|
||||
}
|
||||
var n int64
|
||||
err := q.Count(&n).Error
|
||||
return n > 0, err
|
||||
}
|
||||
@@ -0,0 +1,179 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"giter.top/smart/internal/iam/entity"
|
||||
"giter.top/smart/pkg/utils/id"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// UserRepository 用户数据访问
|
||||
type UserRepository interface {
|
||||
Create(ctx context.Context, u *entity.User) error
|
||||
Update(ctx context.Context, u *entity.User) error
|
||||
Delete(ctx context.Context, id string) error
|
||||
GetByID(ctx context.Context, id string) (*entity.User, error)
|
||||
GetByUserName(ctx context.Context, tenantID string, userName string) (*entity.User, error)
|
||||
ExistsUserName(ctx context.Context, tenantID string, userName string, excludeID string) (bool, error)
|
||||
CountByDept(ctx context.Context, deptID string) (int64, error)
|
||||
List(ctx context.Context, tenantID string, deptID *string, roleID *string, keyword string, status *int16, page, pageSize int) ([]entity.User, int64, error)
|
||||
ReplaceUserDepts(ctx context.Context, userID string, primaryDept string, deptIDs []string) error
|
||||
ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string) error
|
||||
ListRoleIDs(ctx context.Context, userID string) ([]string, error)
|
||||
ListDeptIDs(ctx context.Context, userID string) ([]string, error)
|
||||
}
|
||||
|
||||
type userRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewUserRepository(db *gorm.DB) UserRepository {
|
||||
return &userRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *userRepository) Create(ctx context.Context, u *entity.User) error {
|
||||
return r.db.WithContext(ctx).Create(u).Error
|
||||
}
|
||||
|
||||
func (r *userRepository) Update(ctx context.Context, u *entity.User) error {
|
||||
return r.db.WithContext(ctx).Save(u).Error
|
||||
}
|
||||
|
||||
func (r *userRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entity.User{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
func (r *userRepository) GetByID(ctx context.Context, id string) (*entity.User, error) {
|
||||
var out entity.User
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&out).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) GetByUserName(ctx context.Context, tenantID string, userName string) (*entity.User, error) {
|
||||
var out entity.User
|
||||
err := r.db.WithContext(ctx).Where("tenant_id = ? AND user_name = ?", tenantID, userName).First(&out).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) ExistsUserName(ctx context.Context, tenantID string, userName string, excludeID string) (bool, error) {
|
||||
q := r.db.WithContext(ctx).Model(&entity.User{}).Where("tenant_id = ? AND user_name = ?", tenantID, userName)
|
||||
if excludeID != "" {
|
||||
q = q.Where("id <> ?", excludeID)
|
||||
}
|
||||
var n int64
|
||||
err := q.Count(&n).Error
|
||||
return n > 0, err
|
||||
}
|
||||
|
||||
func (r *userRepository) CountByDept(ctx context.Context, deptID string) (int64, error) {
|
||||
var n int64
|
||||
err := r.db.WithContext(ctx).Raw(`
|
||||
SELECT COUNT(*) FROM (
|
||||
SELECT id FROM iam_user WHERE dept_id = ? AND deleted_at IS NULL
|
||||
UNION
|
||||
SELECT user_id FROM iam_user_dept WHERE dept_id = ?
|
||||
) t`, deptID, deptID).Scan(&n).Error
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *userRepository) List(ctx context.Context, tenantID string, deptID *string, roleID *string, keyword string, status *int16, page, pageSize int) ([]entity.User, int64, error) {
|
||||
q := r.db.WithContext(ctx).Model(&entity.User{}).Where("tenant_id = ?", tenantID)
|
||||
if deptID != nil {
|
||||
d := *deptID
|
||||
q = q.Where("dept_id = ? OR id IN (SELECT user_id FROM iam_user_dept WHERE dept_id = ?)", d, d)
|
||||
}
|
||||
if roleID != nil {
|
||||
sub := r.db.WithContext(ctx).Model(&entity.UserRole{}).Select("user_id").Where("role_id = ?", *roleID)
|
||||
q = q.Where("id IN (?)", sub)
|
||||
}
|
||||
if keyword != "" {
|
||||
kw := "%" + keyword + "%"
|
||||
q = q.Where("user_name LIKE ? OR real_name LIKE ? OR phone LIKE ? OR email LIKE ?", kw, kw, kw, kw)
|
||||
}
|
||||
if status != nil {
|
||||
q = q.Where("status = ?", *status)
|
||||
}
|
||||
var total int64
|
||||
if err := q.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 {
|
||||
pageSize = 10
|
||||
}
|
||||
offset := (page - 1) * pageSize
|
||||
var rows []entity.User
|
||||
err := q.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&rows).Error
|
||||
return rows, total, err
|
||||
}
|
||||
|
||||
func (r *userRepository) ReplaceUserDepts(ctx context.Context, userID string, primaryDept string, deptIDs []string) error {
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Where("user_id = ?", userID).Delete(&entity.UserDept{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
seen := map[string]struct{}{}
|
||||
for _, did := range deptIDs {
|
||||
if _, ok := seen[did]; ok {
|
||||
continue
|
||||
}
|
||||
seen[did] = struct{}{}
|
||||
ud := entity.UserDept{
|
||||
ID: id.New(),
|
||||
UserID: userID,
|
||||
DeptID: did,
|
||||
IsPrimary: did == primaryDept,
|
||||
}
|
||||
if err := tx.Create(&ud).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(deptIDs) == 0 && primaryDept != "" {
|
||||
ud := entity.UserDept{ID: id.New(), UserID: userID, DeptID: primaryDept, IsPrimary: true}
|
||||
return tx.Create(&ud).Error
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (r *userRepository) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string) error {
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Where("user_id = ?", userID).Delete(&entity.UserRole{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
for _, rid := range roleIDs {
|
||||
ur := entity.UserRole{ID: id.New(), UserID: userID, RoleID: rid}
|
||||
if err := tx.Create(&ur).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (r *userRepository) ListRoleIDs(ctx context.Context, userID string) ([]string, error) {
|
||||
var ids []string
|
||||
err := r.db.WithContext(ctx).Model(&entity.UserRole{}).Where("user_id = ?", userID).Pluck("role_id", &ids).Error
|
||||
return ids, err
|
||||
}
|
||||
|
||||
func (r *userRepository) ListDeptIDs(ctx context.Context, userID string) ([]string, error) {
|
||||
var ids []string
|
||||
err := r.db.WithContext(ctx).Model(&entity.UserDept{}).Where("user_id = ?", userID).Pluck("dept_id", &ids).Error
|
||||
return ids, err
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
package service
|
||||
|
||||
// DefaultTenantAdminRoleCode 新租户初始化时的单位管理员角色编码
|
||||
const DefaultTenantAdminRoleCode = "tenant_admin"
|
||||
@@ -0,0 +1,326 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"giter.top/smart/internal/iam/entity"
|
||||
"giter.top/smart/internal/iam/repository"
|
||||
"giter.top/smart/pkg/utils/id"
|
||||
)
|
||||
|
||||
// DeptService 部门
|
||||
type DeptService interface {
|
||||
Tree(ctx context.Context, tenantID string, keyword string, leaderID *string) ([]DeptNode, error)
|
||||
Create(ctx context.Context, tenantID string, req *CreateDeptRequest) (*entity.Dept, error)
|
||||
Update(ctx context.Context, tenantID string, id string, req *UpdateDeptRequest) (*entity.Dept, error)
|
||||
Delete(ctx context.Context, tenantID string, ids []string) error
|
||||
Get(ctx context.Context, tenantID string, id string) (*entity.Dept, error)
|
||||
}
|
||||
|
||||
type CreateDeptRequest struct {
|
||||
ParentID string `json:"parent_id"`
|
||||
DeptName string `json:"dept_name" binding:"required,max=128"`
|
||||
LeaderID *string `json:"leader_id"`
|
||||
SortOrder int `json:"sort_order"`
|
||||
}
|
||||
|
||||
type UpdateDeptRequest struct {
|
||||
ParentID *string `json:"parent_id"`
|
||||
DeptName *string `json:"dept_name" binding:"omitempty,max=128"`
|
||||
LeaderID *string `json:"leader_id"`
|
||||
SortOrder *int `json:"sort_order"`
|
||||
}
|
||||
|
||||
// DeptNode 树节点
|
||||
type DeptNode struct {
|
||||
entity.Dept
|
||||
Children []DeptNode `json:"children,omitempty"`
|
||||
}
|
||||
|
||||
type deptService struct {
|
||||
depts repository.DeptRepository
|
||||
users repository.UserRepository
|
||||
}
|
||||
|
||||
func NewDeptService(depts repository.DeptRepository, users repository.UserRepository) DeptService {
|
||||
return &deptService{depts: depts, users: users}
|
||||
}
|
||||
|
||||
func isDeptRoot(d *entity.Dept) bool {
|
||||
return d.ParentID == "" || d.ParentID == "0"
|
||||
}
|
||||
|
||||
func (s *deptService) Tree(ctx context.Context, tenantID string, keyword string, leaderID *string) ([]DeptNode, error) {
|
||||
rows, err := s.depts.ListByTenant(ctx, tenantID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filtered := rows
|
||||
if keyword != "" || leaderID != nil {
|
||||
filtered = make([]entity.Dept, 0)
|
||||
for _, d := range rows {
|
||||
if keyword != "" && !strings.Contains(d.DeptName, keyword) {
|
||||
continue
|
||||
}
|
||||
if leaderID != nil && (d.LeaderID == nil || *d.LeaderID != *leaderID) {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, d)
|
||||
}
|
||||
if keyword != "" || leaderID != nil {
|
||||
filtered = s.includeAncestors(rows, filtered)
|
||||
}
|
||||
}
|
||||
return buildDeptTree(filtered, ""), nil
|
||||
}
|
||||
|
||||
func (s *deptService) includeAncestors(all []entity.Dept, matched []entity.Dept) []entity.Dept {
|
||||
idSet := map[string]struct{}{}
|
||||
byID := map[string]entity.Dept{}
|
||||
for _, d := range all {
|
||||
byID[d.ID] = d
|
||||
}
|
||||
for _, d := range matched {
|
||||
cur := d
|
||||
for {
|
||||
idSet[cur.ID] = struct{}{}
|
||||
if isDeptRoot(&cur) {
|
||||
break
|
||||
}
|
||||
p, ok := byID[cur.ParentID]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
cur = p
|
||||
}
|
||||
}
|
||||
out := make([]entity.Dept, 0, len(idSet))
|
||||
for _, d := range all {
|
||||
if _, ok := idSet[d.ID]; ok {
|
||||
out = append(out, d)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildDeptTree(rows []entity.Dept, parentID string) []DeptNode {
|
||||
children := map[string][]entity.Dept{}
|
||||
for _, d := range rows {
|
||||
pid := d.ParentID
|
||||
if d.ParentID == "0" {
|
||||
pid = ""
|
||||
}
|
||||
children[pid] = append(children[pid], d)
|
||||
}
|
||||
var walk func(pid string) []DeptNode
|
||||
walk = func(pid string) []DeptNode {
|
||||
list := children[pid]
|
||||
out := make([]DeptNode, 0, len(list))
|
||||
for _, d := range list {
|
||||
out = append(out, DeptNode{Dept: d, Children: walk(d.ID)})
|
||||
}
|
||||
return out
|
||||
}
|
||||
return walk(parentID)
|
||||
}
|
||||
|
||||
func (s *deptService) Create(ctx context.Context, tenantID string, req *CreateDeptRequest) (*entity.Dept, error) {
|
||||
parentKey := req.ParentID
|
||||
if parentKey == "0" {
|
||||
parentKey = ""
|
||||
}
|
||||
ok, err := s.depts.ExistsSiblingName(ctx, tenantID, parentKey, req.DeptName, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ok {
|
||||
return nil, fmt.Errorf("同级部门名称已存在")
|
||||
}
|
||||
d := &entity.Dept{
|
||||
ID: id.New(),
|
||||
TenantID: tenantID,
|
||||
ParentID: parentKey,
|
||||
DeptName: req.DeptName,
|
||||
LeaderID: req.LeaderID,
|
||||
SortOrder: req.SortOrder,
|
||||
Status: 1,
|
||||
}
|
||||
if err := s.depts.Create(ctx, d); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
path := fmt.Sprintf("/%s/", d.ID)
|
||||
if parentKey != "" {
|
||||
p, err := s.depts.GetByID(ctx, parentKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if p.TenantID != tenantID {
|
||||
return nil, fmt.Errorf("父部门不属于当前租户")
|
||||
}
|
||||
base := p.DeptPath
|
||||
if base == "" {
|
||||
base = fmt.Sprintf("/%s/", p.ID)
|
||||
}
|
||||
path = base + fmt.Sprintf("%s/", d.ID)
|
||||
}
|
||||
if err := s.depts.UpdatePath(ctx, d.ID, path); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.DeptPath = path
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (s *deptService) Update(ctx context.Context, tenantID string, id string, req *UpdateDeptRequest) (*entity.Dept, error) {
|
||||
d, err := s.depts.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return nil, fmt.Errorf("部门不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if d.TenantID != tenantID {
|
||||
return nil, fmt.Errorf("部门不属于当前租户")
|
||||
}
|
||||
if isDeptRoot(d) {
|
||||
if req.ParentID != nil && *req.ParentID != "" && *req.ParentID != "0" {
|
||||
return nil, fmt.Errorf("根部门禁止移动")
|
||||
}
|
||||
if req.DeptName != nil && *req.DeptName != "" && *req.DeptName != d.DeptName {
|
||||
return nil, fmt.Errorf("根部门禁止重命名")
|
||||
}
|
||||
}
|
||||
curParent := d.ParentID
|
||||
if curParent == "0" {
|
||||
curParent = ""
|
||||
}
|
||||
var newParentForName string
|
||||
if req.ParentID != nil {
|
||||
np := *req.ParentID
|
||||
if np == "0" {
|
||||
np = ""
|
||||
}
|
||||
newParentForName = np
|
||||
} else {
|
||||
newParentForName = curParent
|
||||
}
|
||||
if req.DeptName != nil && *req.DeptName != "" {
|
||||
ok, err := s.depts.ExistsSiblingName(ctx, tenantID, newParentForName, *req.DeptName, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ok {
|
||||
return nil, fmt.Errorf("同级部门名称已存在")
|
||||
}
|
||||
d.DeptName = *req.DeptName
|
||||
}
|
||||
if req.ParentID != nil {
|
||||
npID := *req.ParentID
|
||||
if npID == "0" {
|
||||
npID = ""
|
||||
}
|
||||
if npID != curParent {
|
||||
if npID == id {
|
||||
return nil, fmt.Errorf("不能将部门移动到自身之下")
|
||||
}
|
||||
if npID != "" {
|
||||
if s.isDescendant(ctx, id, npID) {
|
||||
return nil, fmt.Errorf("禁止移动至子部门(防环)")
|
||||
}
|
||||
np, err := s.depts.GetByID(ctx, npID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("父部门无效")
|
||||
}
|
||||
if np.TenantID != tenantID {
|
||||
return nil, fmt.Errorf("父部门不属于当前租户")
|
||||
}
|
||||
d.ParentID = npID
|
||||
base := np.DeptPath
|
||||
if base == "" {
|
||||
base = fmt.Sprintf("/%s/", np.ID)
|
||||
}
|
||||
d.DeptPath = base + fmt.Sprintf("%s/", d.ID)
|
||||
} else {
|
||||
d.ParentID = ""
|
||||
d.DeptPath = fmt.Sprintf("/%s/", d.ID)
|
||||
}
|
||||
_ = s.depts.UpdatePath(ctx, d.ID, d.DeptPath)
|
||||
}
|
||||
}
|
||||
if req.LeaderID != nil {
|
||||
d.LeaderID = req.LeaderID
|
||||
}
|
||||
if req.SortOrder != nil {
|
||||
d.SortOrder = *req.SortOrder
|
||||
}
|
||||
if err := s.depts.Update(ctx, d); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (s *deptService) isDescendant(ctx context.Context, rootID, nodeID string) bool {
|
||||
if nodeID == rootID {
|
||||
return true
|
||||
}
|
||||
cur, err := s.depts.GetByID(ctx, nodeID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < 64 && cur.ParentID != "" && cur.ParentID != "0"; i++ {
|
||||
if cur.ParentID == rootID {
|
||||
return true
|
||||
}
|
||||
cur, err = s.depts.GetByID(ctx, cur.ParentID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *deptService) Delete(ctx context.Context, tenantID string, ids []string) error {
|
||||
for _, did := range ids {
|
||||
d, err := s.depts.GetByID(ctx, did)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if d.TenantID != tenantID {
|
||||
return fmt.Errorf("部门 %s 不属于当前租户", did)
|
||||
}
|
||||
if isDeptRoot(d) {
|
||||
return fmt.Errorf("根部门禁止删除")
|
||||
}
|
||||
n, err := s.depts.CountChildren(ctx, did)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n > 0 {
|
||||
return fmt.Errorf("部门 %s 存在子部门", did)
|
||||
}
|
||||
uc, err := s.users.CountByDept(ctx, did)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if uc > 0 {
|
||||
return fmt.Errorf("部门 %s 仍存在用户", did)
|
||||
}
|
||||
if err := s.depts.Delete(ctx, did); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *deptService) Get(ctx context.Context, tenantID string, id string) (*entity.Dept, error) {
|
||||
d, err := s.depts.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if d.TenantID != tenantID {
|
||||
return nil, fmt.Errorf("部门不属于当前租户")
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
@@ -0,0 +1,319 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"giter.top/smart/internal/iam/entity"
|
||||
"giter.top/smart/internal/iam/repository"
|
||||
"giter.top/smart/pkg/utils/id"
|
||||
)
|
||||
|
||||
// MenuService 菜单(全局资源)
|
||||
type MenuService interface {
|
||||
Create(ctx context.Context, req *CreateMenuRequest, isPlatform bool) (*entity.Menu, error)
|
||||
Update(ctx context.Context, mid string, req *UpdateMenuRequest, isPlatform bool) (*entity.Menu, error)
|
||||
Delete(ctx context.Context, ids []string, isPlatform bool) error
|
||||
Get(ctx context.Context, mid string) (*entity.Menu, error)
|
||||
Tree(ctx context.Context, menuType *int16) ([]MenuNode, error)
|
||||
NavForUser(ctx context.Context, userID string) ([]MenuNode, error)
|
||||
PermsForUser(ctx context.Context, userID string) ([]string, error)
|
||||
}
|
||||
|
||||
type CreateMenuRequest struct {
|
||||
ParentID string `json:"parent_id"`
|
||||
MenuName string `json:"menu_name" binding:"required,max=128"`
|
||||
MenuType int16 `json:"menu_type" binding:"required"`
|
||||
Perms string `json:"perms"`
|
||||
Path string `json:"path"`
|
||||
Component string `json:"component"`
|
||||
Icon string `json:"icon"`
|
||||
SortOrder int `json:"sort_order"`
|
||||
IsVisible bool `json:"is_visible"`
|
||||
IsBuiltin bool `json:"is_builtin"`
|
||||
ExternalLink string `json:"external_link"`
|
||||
}
|
||||
|
||||
type UpdateMenuRequest struct {
|
||||
ParentID *string `json:"parent_id"`
|
||||
MenuName *string `json:"menu_name"`
|
||||
SortOrder *int `json:"sort_order"`
|
||||
IsVisible *bool `json:"is_visible"`
|
||||
Path *string `json:"path"`
|
||||
Component *string `json:"component"`
|
||||
Icon *string `json:"icon"`
|
||||
ExternalLink *string `json:"external_link"`
|
||||
Status *int16 `json:"status"`
|
||||
}
|
||||
|
||||
// MenuNode 菜单树节点
|
||||
type MenuNode struct {
|
||||
entity.Menu
|
||||
Children []MenuNode `json:"children,omitempty"`
|
||||
}
|
||||
|
||||
type menuService struct {
|
||||
menus repository.MenuRepository
|
||||
roles repository.RoleRepository
|
||||
users repository.UserRepository
|
||||
}
|
||||
|
||||
func NewMenuService(menus repository.MenuRepository, roles repository.RoleRepository, users repository.UserRepository) MenuService {
|
||||
return &menuService{menus: menus, roles: roles, users: users}
|
||||
}
|
||||
|
||||
func normalizeMenuParent(pid string) string {
|
||||
if pid == "0" {
|
||||
return ""
|
||||
}
|
||||
return pid
|
||||
}
|
||||
|
||||
func (s *menuService) Create(ctx context.Context, req *CreateMenuRequest, isPlatform bool) (*entity.Menu, error) {
|
||||
if !isPlatform {
|
||||
return nil, repository.ErrForbidden
|
||||
}
|
||||
if req.Perms != "" {
|
||||
ok, err := s.menus.ExistsPerms(ctx, req.Perms, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ok {
|
||||
return nil, fmt.Errorf("权限标识已存在")
|
||||
}
|
||||
}
|
||||
m := &entity.Menu{
|
||||
ID: id.New(),
|
||||
ParentID: normalizeMenuParent(req.ParentID),
|
||||
MenuName: req.MenuName,
|
||||
MenuType: req.MenuType,
|
||||
Perms: req.Perms,
|
||||
Path: req.Path,
|
||||
Component: req.Component,
|
||||
Icon: req.Icon,
|
||||
SortOrder: req.SortOrder,
|
||||
IsVisible: req.IsVisible,
|
||||
IsBuiltin: req.IsBuiltin,
|
||||
ExternalLink: req.ExternalLink,
|
||||
Status: 1,
|
||||
}
|
||||
if err := s.menus.Create(ctx, m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (s *menuService) Update(ctx context.Context, mid string, req *UpdateMenuRequest, isPlatform bool) (*entity.Menu, error) {
|
||||
if !isPlatform {
|
||||
return nil, repository.ErrForbidden
|
||||
}
|
||||
m, err := s.menus.GetByID(ctx, mid)
|
||||
if err != nil {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return nil, fmt.Errorf("菜单不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if m.IsBuiltin {
|
||||
return nil, fmt.Errorf("系统内置菜单禁止修改")
|
||||
}
|
||||
if req.MenuName != nil {
|
||||
m.MenuName = *req.MenuName
|
||||
}
|
||||
if req.SortOrder != nil {
|
||||
m.SortOrder = *req.SortOrder
|
||||
}
|
||||
if req.IsVisible != nil {
|
||||
m.IsVisible = *req.IsVisible
|
||||
}
|
||||
if req.Path != nil {
|
||||
m.Path = *req.Path
|
||||
}
|
||||
if req.Component != nil {
|
||||
m.Component = *req.Component
|
||||
}
|
||||
if req.Icon != nil {
|
||||
m.Icon = *req.Icon
|
||||
}
|
||||
if req.ExternalLink != nil {
|
||||
m.ExternalLink = *req.ExternalLink
|
||||
}
|
||||
if req.Status != nil {
|
||||
m.Status = *req.Status
|
||||
}
|
||||
if err := s.menus.Update(ctx, m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (s *menuService) Delete(ctx context.Context, ids []string, isPlatform bool) error {
|
||||
if !isPlatform {
|
||||
return repository.ErrForbidden
|
||||
}
|
||||
for _, mid := range ids {
|
||||
m, err := s.menus.GetByID(ctx, mid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if m.IsBuiltin {
|
||||
return fmt.Errorf("系统内置菜单禁止删除")
|
||||
}
|
||||
n, err := s.menus.CountChildren(ctx, mid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n > 0 {
|
||||
return fmt.Errorf("存在子菜单,无法删除")
|
||||
}
|
||||
rn, err := s.menus.CountRoleRefs(ctx, mid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rn > 0 {
|
||||
return fmt.Errorf("菜单仍被角色引用")
|
||||
}
|
||||
if err := s.menus.Delete(ctx, mid); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *menuService) Get(ctx context.Context, mid string) (*entity.Menu, error) {
|
||||
return s.menus.GetByID(ctx, mid)
|
||||
}
|
||||
|
||||
func (s *menuService) Tree(ctx context.Context, menuType *int16) ([]MenuNode, error) {
|
||||
rows, err := s.menus.ListByType(ctx, menuType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buildMenuTreeRows(rows), nil
|
||||
}
|
||||
|
||||
func buildMenuTreeRows(rows []entity.Menu) []MenuNode {
|
||||
byParent := map[string][]entity.Menu{}
|
||||
for _, m := range rows {
|
||||
pid := normalizeMenuParent(m.ParentID)
|
||||
byParent[pid] = append(byParent[pid], m)
|
||||
}
|
||||
for k := range byParent {
|
||||
sort.Slice(byParent[k], func(i, j int) bool {
|
||||
if byParent[k][i].SortOrder != byParent[k][j].SortOrder {
|
||||
return byParent[k][i].SortOrder < byParent[k][j].SortOrder
|
||||
}
|
||||
return byParent[k][i].ID < byParent[k][j].ID
|
||||
})
|
||||
}
|
||||
var walk func(pid string) []MenuNode
|
||||
walk = func(pid string) []MenuNode {
|
||||
list := byParent[pid]
|
||||
out := make([]MenuNode, 0, len(list))
|
||||
for _, m := range list {
|
||||
out = append(out, MenuNode{Menu: m, Children: walk(m.ID)})
|
||||
}
|
||||
return out
|
||||
}
|
||||
return walk("")
|
||||
}
|
||||
|
||||
func (s *menuService) NavForUser(ctx context.Context, userID string) ([]MenuNode, error) {
|
||||
rids, err := s.users.ListRoleIDs(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
menuIDs, err := s.roles.ListMenuIDsByRoles(ctx, rids)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allowed := map[string]struct{}{}
|
||||
for _, mid := range menuIDs {
|
||||
allowed[mid] = struct{}{}
|
||||
}
|
||||
pub, err := s.menus.ListByPerms(ctx, entity.PublicOverviewPerms)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, m := range pub {
|
||||
allowed[m.ID] = struct{}{}
|
||||
}
|
||||
all, err := s.menus.ListAll(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
byID := map[string]entity.Menu{}
|
||||
for _, m := range all {
|
||||
byID[m.ID] = m
|
||||
}
|
||||
for _, m := range all {
|
||||
if _, ok := allowed[m.ID]; !ok {
|
||||
continue
|
||||
}
|
||||
cur := m
|
||||
for {
|
||||
pid := normalizeMenuParent(cur.ParentID)
|
||||
if pid == "" {
|
||||
break
|
||||
}
|
||||
p, ok := byID[pid]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
allowed[p.ID] = struct{}{}
|
||||
cur = p
|
||||
}
|
||||
}
|
||||
filtered := make([]entity.Menu, 0)
|
||||
for _, m := range all {
|
||||
if _, ok := allowed[m.ID]; ok && m.Status == 1 && m.IsVisible {
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
}
|
||||
tree := buildMenuTreeRows(filtered)
|
||||
return pruneEmptyDirs(tree), nil
|
||||
}
|
||||
|
||||
func pruneEmptyDirs(nodes []MenuNode) []MenuNode {
|
||||
out := make([]MenuNode, 0, len(nodes))
|
||||
for _, n := range nodes {
|
||||
ch := pruneEmptyDirs(n.Children)
|
||||
if n.MenuType == 1 && len(ch) == 0 {
|
||||
continue
|
||||
}
|
||||
n.Children = ch
|
||||
out = append(out, n)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *menuService) PermsForUser(ctx context.Context, userID string) ([]string, error) {
|
||||
rids, err := s.users.ListRoleIDs(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mids, err := s.roles.ListMenuIDsByRoles(ctx, rids)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
all, err := s.menus.ListAll(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
idset := map[string]struct{}{}
|
||||
for _, mid := range mids {
|
||||
idset[mid] = struct{}{}
|
||||
}
|
||||
var perms []string
|
||||
for _, m := range all {
|
||||
if _, ok := idset[m.ID]; !ok {
|
||||
continue
|
||||
}
|
||||
if m.Perms != "" {
|
||||
perms = append(perms, m.Perms)
|
||||
}
|
||||
}
|
||||
return perms, nil
|
||||
}
|
||||
@@ -0,0 +1,223 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"giter.top/smart/internal/iam/entity"
|
||||
"giter.top/smart/internal/iam/repository"
|
||||
"giter.top/smart/pkg/utils/id"
|
||||
)
|
||||
|
||||
// RoleService 角色
|
||||
type RoleService interface {
|
||||
Create(ctx context.Context, tenantID string, req *CreateRoleRequest, grantorUserID *string) (*entity.Role, error)
|
||||
Update(ctx context.Context, tenantID string, rid string, req *UpdateRoleRequest, grantorUserID *string) (*entity.Role, error)
|
||||
Delete(ctx context.Context, tenantID string, ids []string) error
|
||||
Get(ctx context.Context, tenantID string, rid string) (*entity.Role, error)
|
||||
List(ctx context.Context, tenantID string, name, code string, page, pageSize int) (*RoleListResponse, error)
|
||||
AssignMenus(ctx context.Context, tenantID string, roleID string, menuIDs []string, grantorUserID *string) error
|
||||
}
|
||||
|
||||
type CreateRoleRequest struct {
|
||||
RoleCode string `json:"role_code" binding:"required,max=64"`
|
||||
RoleName string `json:"role_name" binding:"required,max=128"`
|
||||
DataScope int16 `json:"data_scope" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
MenuIDs []string `json:"menu_ids"`
|
||||
}
|
||||
|
||||
type UpdateRoleRequest struct {
|
||||
RoleName *string `json:"role_name"`
|
||||
DataScope *int16 `json:"data_scope"`
|
||||
Description *string `json:"description"`
|
||||
MenuIDs []string `json:"menu_ids"`
|
||||
}
|
||||
|
||||
type RoleListResponse struct {
|
||||
Items []entity.Role `json:"items"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
|
||||
type roleService struct {
|
||||
roles repository.RoleRepository
|
||||
users repository.UserRepository
|
||||
menus repository.MenuRepository
|
||||
}
|
||||
|
||||
func NewRoleService(roles repository.RoleRepository, users repository.UserRepository, menus repository.MenuRepository) RoleService {
|
||||
return &roleService{roles: roles, users: users, menus: menus}
|
||||
}
|
||||
|
||||
func (s *roleService) grantorMenuSet(ctx context.Context, grantorUserID string) (map[string]struct{}, error) {
|
||||
rids, err := s.users.ListRoleIDs(ctx, grantorUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids, err := s.roles.ListMenuIDsByRoles(ctx, rids)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := make(map[string]struct{}, len(ids))
|
||||
for _, mid := range ids {
|
||||
m[mid] = struct{}{}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (s *roleService) assertMenuSubset(ctx context.Context, grantorUserID *string, menuIDs []string) error {
|
||||
if grantorUserID == nil || *grantorUserID == "" {
|
||||
return nil
|
||||
}
|
||||
allowed, err := s.grantorMenuSet(ctx, *grantorUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, mid := range menuIDs {
|
||||
if _, ok := allowed[mid]; !ok {
|
||||
return fmt.Errorf("防越权: 不能分配自身未拥有的菜单权限 (menu_id=%s)", mid)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *roleService) Create(ctx context.Context, tenantID string, req *CreateRoleRequest, grantorUserID *string) (*entity.Role, error) {
|
||||
ok, err := s.roles.ExistsCode(ctx, tenantID, req.RoleCode, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ok {
|
||||
return nil, fmt.Errorf("角色编码已存在")
|
||||
}
|
||||
if err := s.assertMenuSubset(ctx, grantorUserID, req.MenuIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r := &entity.Role{
|
||||
ID: id.New(),
|
||||
TenantID: tenantID,
|
||||
RoleCode: req.RoleCode,
|
||||
RoleName: req.RoleName,
|
||||
DataScope: req.DataScope,
|
||||
Description: req.Description,
|
||||
Status: 1,
|
||||
}
|
||||
if err := s.roles.Create(ctx, r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(req.MenuIDs) > 0 {
|
||||
if err := s.roles.ReplaceRoleMenus(ctx, r.ID, req.MenuIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (s *roleService) Update(ctx context.Context, tenantID string, rid string, req *UpdateRoleRequest, grantorUserID *string) (*entity.Role, error) {
|
||||
r, err := s.roles.GetByID(ctx, rid)
|
||||
if err != nil {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return nil, fmt.Errorf("角色不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if r.TenantID != tenantID {
|
||||
return nil, fmt.Errorf("角色不属于当前租户")
|
||||
}
|
||||
if r.IsBuiltin {
|
||||
// 内置角色仅允许改部分字段(MVP:允许改名称与数据范围与菜单需业务再定)
|
||||
}
|
||||
if req.RoleName != nil {
|
||||
r.RoleName = *req.RoleName
|
||||
}
|
||||
if req.DataScope != nil {
|
||||
r.DataScope = *req.DataScope
|
||||
}
|
||||
if req.Description != nil {
|
||||
r.Description = *req.Description
|
||||
}
|
||||
if err := s.roles.Update(ctx, r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if req.MenuIDs != nil {
|
||||
if err := s.assertMenuSubset(ctx, grantorUserID, req.MenuIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.roles.ReplaceRoleMenus(ctx, r.ID, req.MenuIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (s *roleService) Delete(ctx context.Context, tenantID string, ids []string) error {
|
||||
for _, rid := range ids {
|
||||
r, err := s.roles.GetByID(ctx, rid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r.TenantID != tenantID {
|
||||
return fmt.Errorf("角色 %s 不属于当前租户", rid)
|
||||
}
|
||||
if r.IsBuiltin {
|
||||
return fmt.Errorf("内置角色不可删除")
|
||||
}
|
||||
n, err := s.roles.CountUsers(ctx, rid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n > 0 {
|
||||
return fmt.Errorf("角色仍被用户使用")
|
||||
}
|
||||
if err := s.roles.Delete(ctx, rid); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *roleService) Get(ctx context.Context, tenantID string, rid string) (*entity.Role, error) {
|
||||
r, err := s.roles.GetByID(ctx, rid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.TenantID != tenantID {
|
||||
return nil, fmt.Errorf("角色不属于当前租户")
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (s *roleService) List(ctx context.Context, tenantID string, name, code string, page, pageSize int) (*RoleListResponse, error) {
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 {
|
||||
pageSize = 10
|
||||
}
|
||||
rows, total, err := s.roles.List(ctx, tenantID, name, code, page, pageSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tp := int(total) / pageSize
|
||||
if int(total)%pageSize != 0 {
|
||||
tp++
|
||||
}
|
||||
return &RoleListResponse{Items: rows, Total: total, Page: page, PageSize: pageSize, TotalPages: tp}, nil
|
||||
}
|
||||
|
||||
func (s *roleService) AssignMenus(ctx context.Context, tenantID string, roleID string, menuIDs []string, grantorUserID *string) error {
|
||||
r, err := s.roles.GetByID(ctx, roleID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r.TenantID != tenantID {
|
||||
return fmt.Errorf("角色不属于当前租户")
|
||||
}
|
||||
if err := s.assertMenuSubset(ctx, grantorUserID, menuIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.roles.ReplaceRoleMenus(ctx, roleID, menuIDs)
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package service
|
||||
|
||||
import "giter.top/smart/internal/iam/entity"
|
||||
|
||||
// MergeDataScope 多角色数据范围并集:取最大(PRD:全部 > 本部门及子部门 > 本部门 > 仅本人)
|
||||
func MergeDataScope(scopes []int16) int16 {
|
||||
var m int16
|
||||
for _, s := range scopes {
|
||||
if s > m {
|
||||
m = s
|
||||
}
|
||||
}
|
||||
if m == 0 {
|
||||
return entity.DataScopeSelf
|
||||
}
|
||||
return m
|
||||
}
|
||||
@@ -0,0 +1,261 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"giter.top/smart/internal/iam/entity"
|
||||
"giter.top/smart/internal/iam/repository"
|
||||
"giter.top/smart/pkg/utils/id"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TenantService 租户
|
||||
type TenantService interface {
|
||||
Create(ctx context.Context, req *CreateTenantRequest) (*entity.Tenant, error)
|
||||
Update(ctx context.Context, id string, req *UpdateTenantRequest) (*entity.Tenant, error)
|
||||
Delete(ctx context.Context, ids []string) error
|
||||
Get(ctx context.Context, id string) (*entity.Tenant, error)
|
||||
List(ctx context.Context, name, code string, status *int16, page, pageSize int) (*TenantListResponse, error)
|
||||
}
|
||||
|
||||
type CreateTenantRequest struct {
|
||||
TenantCode string `json:"tenant_code" binding:"required,max=64"`
|
||||
TenantName string `json:"tenant_name" binding:"required,max=128"`
|
||||
AdminUserName string `json:"admin_user_name" binding:"required,max=64"`
|
||||
AdminPassword string `json:"admin_password" binding:"required,min=6,max=64"`
|
||||
AdminRealName string `json:"admin_real_name" binding:"max=64"`
|
||||
}
|
||||
|
||||
type UpdateTenantRequest struct {
|
||||
TenantName *string `json:"tenant_name"`
|
||||
TenantCode *string `json:"tenant_code" binding:"omitempty,max=64"`
|
||||
Status *int16 `json:"status"`
|
||||
ExpireTime *string `json:"expire_time"` // RFC3339
|
||||
}
|
||||
|
||||
type TenantListItem struct {
|
||||
entity.Tenant
|
||||
UserCount int64 `json:"user_count"`
|
||||
DeptCount int64 `json:"dept_count"`
|
||||
}
|
||||
|
||||
type TenantListResponse struct {
|
||||
Items []TenantListItem `json:"items"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
|
||||
type tenantService struct {
|
||||
db *gorm.DB
|
||||
tenants repository.TenantRepository
|
||||
depts repository.DeptRepository
|
||||
users repository.UserRepository
|
||||
roles repository.RoleRepository
|
||||
menus repository.MenuRepository
|
||||
}
|
||||
|
||||
func NewTenantService(
|
||||
db *gorm.DB,
|
||||
tenants repository.TenantRepository,
|
||||
depts repository.DeptRepository,
|
||||
users repository.UserRepository,
|
||||
roles repository.RoleRepository,
|
||||
menus repository.MenuRepository,
|
||||
) TenantService {
|
||||
return &tenantService{db: db, tenants: tenants, depts: depts, users: users, roles: roles, menus: menus}
|
||||
}
|
||||
|
||||
func (s *tenantService) Create(ctx context.Context, req *CreateTenantRequest) (*entity.Tenant, error) {
|
||||
ok, err := s.tenants.ExistsCode(ctx, req.TenantCode, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ok {
|
||||
return nil, fmt.Errorf("租户编码已存在")
|
||||
}
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(req.AdminPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var out *entity.Tenant
|
||||
err = s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
t := &entity.Tenant{
|
||||
ID: id.New(),
|
||||
TenantCode: req.TenantCode,
|
||||
TenantName: req.TenantName,
|
||||
Status: 1,
|
||||
}
|
||||
if err := tx.Create(t).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
var ucount int64
|
||||
if err := tx.Model(&entity.User{}).Where("tenant_id = ? AND user_name = ?", t.ID, req.AdminUserName).Count(&ucount).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if ucount > 0 {
|
||||
return fmt.Errorf("管理员账号已存在")
|
||||
}
|
||||
root := &entity.Dept{
|
||||
ID: id.New(),
|
||||
TenantID: t.ID,
|
||||
ParentID: "",
|
||||
DeptName: req.TenantName,
|
||||
SortOrder: 0,
|
||||
Status: 1,
|
||||
}
|
||||
if err := tx.Create(root).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
path := fmt.Sprintf("/%s/", root.ID)
|
||||
if err := tx.Model(&entity.Dept{}).Where("id = ?", root.ID).Update("dept_path", path).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
admin := &entity.User{
|
||||
ID: id.New(),
|
||||
TenantID: t.ID,
|
||||
DeptID: &root.ID,
|
||||
UserName: req.AdminUserName,
|
||||
RealName: req.AdminRealName,
|
||||
PasswordHash: string(hash),
|
||||
Status: 1,
|
||||
}
|
||||
if err := tx.Create(admin).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Create(&entity.UserDept{ID: id.New(), UserID: admin.ID, DeptID: root.ID, IsPrimary: true}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
role := &entity.Role{
|
||||
ID: id.New(),
|
||||
TenantID: t.ID,
|
||||
RoleCode: DefaultTenantAdminRoleCode,
|
||||
RoleName: "超级管理员",
|
||||
DataScope: entity.DataScopeAll,
|
||||
Description: "租户初始化角色",
|
||||
IsBuiltin: true,
|
||||
Status: 1,
|
||||
}
|
||||
if err := tx.Create(role).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
var allMenus []entity.Menu
|
||||
if err := tx.Find(&allMenus).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
for _, m := range allMenus {
|
||||
if err := tx.Create(&entity.RoleMenu{ID: id.New(), RoleID: role.ID, MenuID: m.ID}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := tx.Create(&entity.UserRole{ID: id.New(), UserID: admin.ID, RoleID: role.ID}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
aid := admin.ID
|
||||
if err := tx.Model(t).Update("admin_user_id", aid).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
t.AdminUserID = &aid
|
||||
out = t
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *tenantService) Update(ctx context.Context, id string, req *UpdateTenantRequest) (*entity.Tenant, error) {
|
||||
t, err := s.tenants.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return nil, fmt.Errorf("租户不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if req.TenantName != nil && *req.TenantName != "" {
|
||||
t.TenantName = *req.TenantName
|
||||
if root, err := s.depts.FindRoot(ctx, t.ID); err == nil {
|
||||
root.DeptName = *req.TenantName
|
||||
_ = s.depts.Update(ctx, root)
|
||||
}
|
||||
}
|
||||
if req.TenantCode != nil && *req.TenantCode != "" {
|
||||
ok, err := s.tenants.ExistsCode(ctx, *req.TenantCode, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ok {
|
||||
return nil, fmt.Errorf("租户编码已存在")
|
||||
}
|
||||
t.TenantCode = *req.TenantCode
|
||||
}
|
||||
if req.Status != nil {
|
||||
t.Status = *req.Status
|
||||
}
|
||||
if req.ExpireTime != nil && *req.ExpireTime != "" {
|
||||
et, err := time.Parse(time.RFC3339, *req.ExpireTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("到期时间格式无效: %w", err)
|
||||
}
|
||||
t.ExpireTime = &et
|
||||
if et.Before(time.Now()) {
|
||||
t.Status = 0
|
||||
}
|
||||
}
|
||||
if err := s.tenants.Update(ctx, t); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (s *tenantService) Delete(ctx context.Context, ids []string) error {
|
||||
for _, tid := range ids {
|
||||
n, err := s.tenants.CountUsers(ctx, tid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n > 0 {
|
||||
return fmt.Errorf("租户 %s 仍存在用户,无法删除", tid)
|
||||
}
|
||||
}
|
||||
for _, tid := range ids {
|
||||
if err := s.db.WithContext(ctx).Delete(&entity.Tenant{}, "id = ?", tid).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *tenantService) Get(ctx context.Context, id string) (*entity.Tenant, error) {
|
||||
return s.tenants.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *tenantService) List(ctx context.Context, name, code string, status *int16, page, pageSize int) (*TenantListResponse, error) {
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 {
|
||||
pageSize = 10
|
||||
}
|
||||
rows, total, err := s.tenants.List(ctx, name, code, status, page, pageSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items := make([]TenantListItem, 0, len(rows))
|
||||
for _, t := range rows {
|
||||
uc, _ := s.tenants.CountUsers(ctx, t.ID)
|
||||
dc, _ := s.tenants.CountDepts(ctx, t.ID)
|
||||
items = append(items, TenantListItem{Tenant: t, UserCount: uc, DeptCount: dc})
|
||||
}
|
||||
tp := int(total) / pageSize
|
||||
if int(total)%pageSize != 0 {
|
||||
tp++
|
||||
}
|
||||
return &TenantListResponse{Items: items, Total: total, Page: page, PageSize: pageSize, TotalPages: tp}, nil
|
||||
}
|
||||
@@ -0,0 +1,240 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"giter.top/smart/internal/iam/entity"
|
||||
"giter.top/smart/internal/iam/repository"
|
||||
"giter.top/smart/pkg/utils/id"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// UserService 用户
|
||||
type UserService interface {
|
||||
Create(ctx context.Context, tenantID string, req *CreateUserRequest) (*entity.User, error)
|
||||
Update(ctx context.Context, tenantID string, uid string, req *UpdateUserRequest) (*entity.User, error)
|
||||
Delete(ctx context.Context, tenantID string, ids []string) error
|
||||
Get(ctx context.Context, tenantID string, uid string) (*entity.User, error)
|
||||
List(ctx context.Context, tenantID string, q *UserListQuery) (*UserListResponse, error)
|
||||
DataScopeForUser(ctx context.Context, userID string) (int16, error)
|
||||
}
|
||||
|
||||
type CreateUserRequest struct {
|
||||
UserName string `json:"user_name" binding:"required,max=64"`
|
||||
Password string `json:"password" binding:"required,min=6,max=64"`
|
||||
RealName string `json:"real_name" binding:"max=64"`
|
||||
Phone string `json:"phone"`
|
||||
Email string `json:"email"`
|
||||
DeptID *string `json:"dept_id"`
|
||||
DeptIDs []string `json:"dept_ids"`
|
||||
RoleIDs []string `json:"role_ids"`
|
||||
}
|
||||
|
||||
type UpdateUserRequest struct {
|
||||
RealName *string `json:"real_name"`
|
||||
Phone *string `json:"phone"`
|
||||
Email *string `json:"email"`
|
||||
DeptID *string `json:"dept_id"`
|
||||
DeptIDs []string `json:"dept_ids"`
|
||||
RoleIDs []string `json:"role_ids"`
|
||||
Status *int16 `json:"status"`
|
||||
Password *string `json:"password"`
|
||||
}
|
||||
|
||||
type UserListQuery struct {
|
||||
DeptID *string
|
||||
RoleID *string
|
||||
Keyword string
|
||||
Status *int16
|
||||
Page int
|
||||
PageSize int
|
||||
}
|
||||
|
||||
type UserListResponse struct {
|
||||
Items []entity.User `json:"items"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
|
||||
type userService struct {
|
||||
users repository.UserRepository
|
||||
roles repository.RoleRepository
|
||||
}
|
||||
|
||||
func NewUserService(users repository.UserRepository, roles repository.RoleRepository) UserService {
|
||||
return &userService{users: users, roles: roles}
|
||||
}
|
||||
|
||||
func (s *userService) Create(ctx context.Context, tenantID string, req *CreateUserRequest) (*entity.User, error) {
|
||||
ok, err := s.users.ExistsUserName(ctx, tenantID, req.UserName, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ok {
|
||||
return nil, fmt.Errorf("账号已存在")
|
||||
}
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u := &entity.User{
|
||||
ID: id.New(),
|
||||
TenantID: tenantID,
|
||||
UserName: req.UserName,
|
||||
RealName: req.RealName,
|
||||
Phone: req.Phone,
|
||||
Email: req.Email,
|
||||
PasswordHash: string(hash),
|
||||
Status: 1,
|
||||
}
|
||||
if req.DeptID != nil {
|
||||
u.DeptID = req.DeptID
|
||||
}
|
||||
if err := s.users.Create(ctx, u); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
depts := req.DeptIDs
|
||||
primary := ""
|
||||
if req.DeptID != nil {
|
||||
primary = *req.DeptID
|
||||
}
|
||||
if len(depts) == 0 && primary != "" {
|
||||
depts = []string{primary}
|
||||
}
|
||||
if len(depts) > 0 {
|
||||
if primary == "" {
|
||||
primary = depts[0]
|
||||
}
|
||||
u.DeptID = &primary
|
||||
_ = s.users.Update(ctx, u)
|
||||
if err := s.users.ReplaceUserDepts(ctx, u.ID, primary, depts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if len(req.RoleIDs) > 0 {
|
||||
if err := s.users.ReplaceUserRoles(ctx, u.ID, req.RoleIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (s *userService) Update(ctx context.Context, tenantID string, uid string, req *UpdateUserRequest) (*entity.User, error) {
|
||||
u, err := s.users.GetByID(ctx, uid)
|
||||
if err != nil {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return nil, fmt.Errorf("用户不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if u.TenantID != tenantID {
|
||||
return nil, fmt.Errorf("用户不属于当前租户")
|
||||
}
|
||||
if req.RealName != nil {
|
||||
u.RealName = *req.RealName
|
||||
}
|
||||
if req.Phone != nil {
|
||||
u.Phone = *req.Phone
|
||||
}
|
||||
if req.Email != nil {
|
||||
u.Email = *req.Email
|
||||
}
|
||||
if req.Status != nil {
|
||||
u.Status = *req.Status
|
||||
}
|
||||
if req.Password != nil && *req.Password != "" {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(*req.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.PasswordHash = string(hash)
|
||||
}
|
||||
if err := s.users.Update(ctx, u); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if req.DeptIDs != nil || req.DeptID != nil {
|
||||
depts := req.DeptIDs
|
||||
primary := ""
|
||||
if req.DeptID != nil {
|
||||
primary = *req.DeptID
|
||||
u.DeptID = req.DeptID
|
||||
_ = s.users.Update(ctx, u)
|
||||
}
|
||||
if len(depts) == 0 && primary != "" {
|
||||
depts = []string{primary}
|
||||
}
|
||||
if primary == "" && len(depts) > 0 {
|
||||
primary = depts[0]
|
||||
}
|
||||
if err := s.users.ReplaceUserDepts(ctx, u.ID, primary, depts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if req.RoleIDs != nil {
|
||||
if err := s.users.ReplaceUserRoles(ctx, u.ID, req.RoleIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (s *userService) Delete(ctx context.Context, tenantID string, ids []string) error {
|
||||
for _, uid := range ids {
|
||||
u, err := s.users.GetByID(ctx, uid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if u.TenantID != tenantID {
|
||||
return fmt.Errorf("用户 %s 不属于当前租户", uid)
|
||||
}
|
||||
if err := s.users.Delete(ctx, uid); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *userService) Get(ctx context.Context, tenantID string, uid string) (*entity.User, error) {
|
||||
u, err := s.users.GetByID(ctx, uid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if u.TenantID != tenantID {
|
||||
return nil, fmt.Errorf("用户不属于当前租户")
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (s *userService) List(ctx context.Context, tenantID string, q *UserListQuery) (*UserListResponse, error) {
|
||||
if q.Page <= 0 {
|
||||
q.Page = 1
|
||||
}
|
||||
if q.PageSize <= 0 {
|
||||
q.PageSize = 10
|
||||
}
|
||||
rows, total, err := s.users.List(ctx, tenantID, q.DeptID, q.RoleID, q.Keyword, q.Status, q.Page, q.PageSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tp := int(total) / q.PageSize
|
||||
if int(total)%q.PageSize != 0 {
|
||||
tp++
|
||||
}
|
||||
return &UserListResponse{Items: rows, Total: total, Page: q.Page, PageSize: q.PageSize, TotalPages: tp}, nil
|
||||
}
|
||||
|
||||
func (s *userService) DataScopeForUser(ctx context.Context, userID string) (int16, error) {
|
||||
roles, err := s.roles.ListRolesByUser(ctx, userID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
scopes := make([]int16, 0, len(roles))
|
||||
for _, r := range roles {
|
||||
scopes = append(scopes, r.DataScope)
|
||||
}
|
||||
return MergeDataScope(scopes), nil
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
package iam
|
||||
|
||||
import (
|
||||
"giter.top/smart/internal/iam/handler"
|
||||
"giter.top/smart/internal/iam/repository"
|
||||
"giter.top/smart/internal/iam/service"
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
// HandlerProviderSet 处理程序提供者集合
|
||||
var handlerProviderSet = wire.NewSet(
|
||||
handler.NewTenantHandler,
|
||||
handler.NewDeptHandler,
|
||||
handler.NewRoleHandler,
|
||||
handler.NewUserHandler,
|
||||
handler.NewMenuHandler,
|
||||
)
|
||||
|
||||
|
||||
// ServiceProviderSet 服务提供者集合
|
||||
var serviceProviderSet = wire.NewSet(
|
||||
service.NewTenantService,
|
||||
service.NewDeptService,
|
||||
service.NewRoleService,
|
||||
service.NewUserService,
|
||||
service.NewMenuService,
|
||||
)
|
||||
|
||||
|
||||
// RepositoryProviderSet 仓库提供者集合
|
||||
var repositoryProviderSet = wire.NewSet(
|
||||
repository.NewTenantRepository,
|
||||
repository.NewDeptRepository,
|
||||
repository.NewRoleRepository,
|
||||
repository.NewUserRepository,
|
||||
repository.NewMenuRepository,
|
||||
)
|
||||
|
||||
var ProviderSet = wire.NewSet(
|
||||
handlerProviderSet,
|
||||
serviceProviderSet,
|
||||
repositoryProviderSet,
|
||||
// 路由注册
|
||||
NewIamRoutes,
|
||||
)
|
||||
@@ -0,0 +1,32 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// corsLocalDev 允许本机前端(localhost / 127.0.0.1 任意端口)跨域访问 API 与 OAuth;生产同域部署时可关闭或改为配置白名单。
|
||||
func corsLocalDev() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
o := c.GetHeader("Origin")
|
||||
if o != "" && isLocalDevOrigin(o) {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", o)
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Tenant-ID, X-User-ID, X-Grantor-User-ID")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||
}
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func isLocalDevOrigin(o string) bool {
|
||||
return strings.HasPrefix(o, "http://localhost:") ||
|
||||
strings.HasPrefix(o, "http://127.0.0.1:") ||
|
||||
strings.HasPrefix(o, "http://[::1]:")
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"giter.top/smart/pkg/config"
|
||||
)
|
||||
|
||||
type GrpcServer struct {
|
||||
addr string
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func NewGrpcServer(cfg *config.Config) *GrpcServer {
|
||||
return &GrpcServer{
|
||||
addr: cfg.Server.Grpc.Addr,
|
||||
timeout: cfg.Server.Grpc.Timeout,
|
||||
}
|
||||
}
|
||||
func (s *GrpcServer) Run() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GrpcServer) Stop() error {
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"giter.top/smart/internal/auth"
|
||||
"giter.top/smart/internal/iam"
|
||||
"giter.top/smart/internal/system"
|
||||
"giter.top/smart/pkg/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type HttpServer struct {
|
||||
addr string
|
||||
timeout time.Duration
|
||||
engine *gin.Engine
|
||||
}
|
||||
|
||||
func NewHttpServer(cfg *config.Config,
|
||||
engine *gin.Engine,
|
||||
) *HttpServer {
|
||||
return &HttpServer{
|
||||
addr: cfg.Server.Http.Addr,
|
||||
timeout: cfg.Server.Http.Timeout,
|
||||
engine: engine,
|
||||
}
|
||||
}
|
||||
func (s *HttpServer) Run() error {
|
||||
s.engine.Run(s.addr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *HttpServer) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////
|
||||
type HttpRoutes interface {
|
||||
Register(engine *gin.Engine , apiGroup *gin.RouterGroup)
|
||||
}
|
||||
|
||||
func NewHttpEngine(cfg *config.Config,httpRoutes []HttpRoutes) *gin.Engine {
|
||||
engine := gin.Default()
|
||||
engine.Use(corsLocalDev())
|
||||
// 健康检查端点,供负载均衡或编排探活。
|
||||
engine.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
})
|
||||
// 处理注册的路由
|
||||
apiGroup := engine.Group("/api/v1")
|
||||
for _, r := range httpRoutes {
|
||||
r.Register(engine, apiGroup)
|
||||
}
|
||||
return engine
|
||||
}
|
||||
|
||||
func NewHttpRouteRegistrars(
|
||||
authRoutes *auth.AuthRoutes,
|
||||
systemRoutes *system.SystemRoutes,
|
||||
iamRoutes *iam.IamRoutes,
|
||||
) []HttpRoutes {
|
||||
return []HttpRoutes{
|
||||
authRoutes,
|
||||
systemRoutes,
|
||||
iamRoutes,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"giter.top/smart/pkg/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
var ProviderSet = wire.NewSet(
|
||||
NewHttpEngine,
|
||||
ProvideServers,
|
||||
NewHttpRouteRegistrars,
|
||||
)
|
||||
|
||||
type Server interface {
|
||||
Run() error
|
||||
Stop() error
|
||||
}
|
||||
|
||||
func ProvideServers(cfg *config.Config, engine *gin.Engine) []Server {
|
||||
return []Server{
|
||||
NewHttpServer(cfg, engine),
|
||||
NewGrpcServer(cfg),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
package entity
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// SystemParam 系统参数实体
|
||||
// 用于存储系统运行所需的各种配置参数,支持多种数据类型和分组管理
|
||||
type SystemParam struct {
|
||||
// ID 主键,使用 UUID v4 保证全局唯一性,避免自增 ID 带来的信息泄露风险
|
||||
ID string `json:"id" gorm:"column:id;type:varchar(36);primaryKey;not null;comment:主键"`
|
||||
|
||||
// ParamKey 参数键名,全局唯一,用于标识和访问参数值
|
||||
// 命名规范:小写字母 + 下划线,如:site_name, max_upload_size
|
||||
ParamKey string `json:"param_key" gorm:"column:param_key;type:varchar(100);uniqueIndex;not null;comment:参数键"`
|
||||
|
||||
// ParamValue 参数值,存储实际配置内容
|
||||
// 根据 ParamType 不同,可能是字符串、数字、布尔值或 JSON 数组
|
||||
ParamValue string `json:"param_value" gorm:"column:param_value;type:varchar(1000);not null;comment:参数值"`
|
||||
|
||||
// ParamType 参数类型,决定参数的校验规则和展示方式
|
||||
// 可选值:text(文本), number(数字), boolean(布尔), select(下拉选择)
|
||||
ParamType string `json:"param_type" gorm:"column:param_type;type:varchar(20);not null;default:'text';comment:类型:text,number,boolean,select"`
|
||||
|
||||
// ParamGroup 参数分组,用于对参数进行逻辑分组管理
|
||||
// 常见分组:basic(基础), security(安全), business(业务), system(系统)
|
||||
ParamGroup string `json:"param_group" gorm:"column:param_group;type:varchar(50);not null;default:'default';comment:分组"`
|
||||
|
||||
// ParamDesc 参数描述,说明该参数的用途、取值范围、默认值等信息
|
||||
// 建议包含:参数说明、可选值说明、修改影响等
|
||||
ParamDesc string `json:"param_desc" gorm:"column:param_desc;type:varchar(500);comment:描述"`
|
||||
|
||||
// CreatorID 创建人 ID,记录创建该参数的用户标识
|
||||
// 用于审计追踪,定位参数创建者
|
||||
CreatorID string `json:"creator_id" gorm:"column:creator_id;type:varchar(36);not null;default:'';comment:创建人 ID"`
|
||||
|
||||
// CreateTime 创建时间,记录参数创建的时间点
|
||||
// 使用指针类型,可以区分"未设置"和"已设置"状态
|
||||
// 数据库层面使用 CURRENT_TIMESTAMP 自动填充
|
||||
CreateTime *time.Time `json:"create_time" gorm:"column:create_time;type:datetime;default:current_timestamp;comment:创建时间"`
|
||||
|
||||
// LastUpdaterID 最后更新人 ID,记录最后一次修改该参数的用户标识
|
||||
// 用于审计追踪,定位参数修改者
|
||||
LastUpdaterID string `json:"last_updater_id" gorm:"column:last_updater_id;type:varchar(36);not null;default:'';comment:最后更新人 ID"`
|
||||
|
||||
// UpdateTime 最后更新时间,记录参数最后一次修改的时间点
|
||||
// 使用指针类型,可以区分"未设置"和"已设置"状态
|
||||
// 数据库层面使用 ON UPDATE CURRENT_TIMESTAMP 自动更新
|
||||
UpdateTime *time.Time `json:"update_time" gorm:"column:update_time;type:datetime;default:current_timestamp;on update current_timestamp;comment:最后更新时间"`
|
||||
}
|
||||
|
||||
// TableName 指定表名为 system_param
|
||||
// 遵循数据库命名规范:小写字母 + 下划线,复数形式
|
||||
func (SystemParam) TableName() string {
|
||||
return "system_param"
|
||||
}
|
||||
|
||||
// ParamType 参数类型常量
|
||||
// 定义系统支持的参数类型,用于前端展示和后端校验
|
||||
type ParamType string
|
||||
|
||||
const (
|
||||
// ParamTypeText 文本类型,适用于字符串值
|
||||
ParamTypeText ParamType = "text"
|
||||
// ParamTypeNumber 数字类型,适用于整数值
|
||||
ParamTypeNumber ParamType = "number"
|
||||
// ParamTypeBoolean 布尔类型,适用于 true/false 值
|
||||
ParamTypeBoolean ParamType = "boolean"
|
||||
// ParamTypeSelect 下拉选择类型,适用于预定义选项值
|
||||
ParamTypeSelect ParamType = "select"
|
||||
)
|
||||
|
||||
// ParamGroup 参数分组常量
|
||||
// 定义系统参数的逻辑分组,便于分类管理和权限控制
|
||||
type ParamGroup string
|
||||
|
||||
const (
|
||||
// GroupBasic 基础配置分组,包含系统基本信息
|
||||
// 如:站点名称、Logo、联系方式等
|
||||
GroupBasic ParamGroup = "basic"
|
||||
|
||||
// GroupSecurity 安全配置分组,包含安全相关参数
|
||||
// 如:密码策略、登录限制、Token 有效期等
|
||||
GroupSecurity ParamGroup = "security"
|
||||
|
||||
// GroupBusiness 业务配置分组,包含业务逻辑相关参数
|
||||
// 如:订单配置、支付参数、业务开关等
|
||||
GroupBusiness ParamGroup = "business"
|
||||
|
||||
// GroupSystem 系统配置分组,包含系统运行参数
|
||||
// 如:缓存配置、日志级别、性能参数等
|
||||
GroupSystem ParamGroup = "system"
|
||||
|
||||
// GroupDefault 默认分组,未明确分组的参数归入此类
|
||||
GroupDefault ParamGroup = "default"
|
||||
)
|
||||
@@ -0,0 +1,177 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"giter.top/smart/internal/system/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ParamHandler 系统参数 HTTP 处理器
|
||||
type ParamHandler struct {
|
||||
service service.ParamService
|
||||
}
|
||||
|
||||
// NewParamHandler 创建参数处理器实例
|
||||
func NewParamHandler(svc service.ParamService) *ParamHandler {
|
||||
return &ParamHandler{service: svc}
|
||||
}
|
||||
|
||||
// CreateParam 创建系统参数
|
||||
// @Summary 创建系统参数
|
||||
// @Tags 系统参数
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body service.CreateParamRequest true "创建参数请求"
|
||||
// @Success 201 {object} entity.SystemParam
|
||||
// @Router /api/v1/system/params [post]
|
||||
func (h *ParamHandler) CreateParam(c *gin.Context) {
|
||||
var req service.CreateParamRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: 从上下文获取用户 ID(实际项目中从 JWT token 解析)
|
||||
creatorID := "system"
|
||||
param, err := h.service.CreateParam(c.Request.Context(), &req, creatorID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusCreated, param)
|
||||
}
|
||||
|
||||
// UpdateParam 更新系统参数
|
||||
// @Summary 更新系统参数
|
||||
// @Tags 系统参数
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path string true "参数 ID"
|
||||
// @Param request body service.UpdateParamRequest true "更新参数请求"
|
||||
// @Success 200 {object} entity.SystemParam
|
||||
// @Router /api/v1/system/params/{id} [put]
|
||||
func (h *ParamHandler) UpdateParam(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的 ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var req service.UpdateParamRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: 从上下文获取用户 ID
|
||||
lastUpdaterID := "system"
|
||||
param, err := h.service.UpdateParam(c.Request.Context(), id, &req, lastUpdaterID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, param)
|
||||
}
|
||||
|
||||
// DeleteParams 批量删除系统参数
|
||||
// @Summary 批量删除系统参数
|
||||
// @Tags 系统参数
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body []string true "参数 ID 列表"
|
||||
// @Success 204
|
||||
// @Router /api/v1/system/params/batch [delete]
|
||||
func (h *ParamHandler) DeleteParams(c *gin.Context) {
|
||||
var ids []string
|
||||
if err := c.ShouldBindJSON(&ids); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.service.DeleteParams(c.Request.Context(), ids); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// GetParam 获取单个系统参数
|
||||
// @Summary 获取单个系统参数
|
||||
// @Tags 系统参数
|
||||
// @Produce json
|
||||
// @Param id path string true "参数 ID"
|
||||
// @Success 200 {object} entity.SystemParam
|
||||
// @Router /api/v1/system/params/{id} [get]
|
||||
func (h *ParamHandler) GetParam(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的 ID"})
|
||||
return
|
||||
}
|
||||
|
||||
param, err := h.service.GetParam(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, param)
|
||||
}
|
||||
|
||||
// GetParamByKey 根据键获取系统参数
|
||||
// @Summary 根据键获取系统参数
|
||||
// @Tags 系统参数
|
||||
// @Produce json
|
||||
// @Param key path string true "参数键"
|
||||
// @Success 200 {object} entity.SystemParam
|
||||
// @Router /api/v1/system/params/key/{key} [get]
|
||||
func (h *ParamHandler) GetParamByKey(c *gin.Context) {
|
||||
key := c.Param("key")
|
||||
param, err := h.service.GetParamByKey(c.Request.Context(), key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, param)
|
||||
}
|
||||
|
||||
// ListParams 获取系统参数列表
|
||||
// @Summary 获取系统参数列表
|
||||
// @Tags 系统参数
|
||||
// @Produce json
|
||||
// @Param group query string false "分组"
|
||||
// @Param param_key query string false "参数键(模糊搜索)"
|
||||
// @Param page query int false "页码" default(1)
|
||||
// @Param page_size query int false "每页数量" default(10)
|
||||
// @Success 200 {object} service.ParamListResponse
|
||||
// @Router /api/v1/system/params [get]
|
||||
func (h *ParamHandler) ListParams(c *gin.Context) {
|
||||
group := c.Query("group")
|
||||
paramKey := c.Query("param_key")
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10"))
|
||||
|
||||
response, err := h.service.ListParams(c.Request.Context(), group, paramKey, page, pageSize)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// GetAllParams 获取所有系统参数
|
||||
// @Summary 获取所有系统参数
|
||||
// @Tags 系统参数
|
||||
// @Produce json
|
||||
// @Success 200 {object} map[string]entity.SystemParam
|
||||
// @Router /api/v1/system/params/all [get]
|
||||
func (h *ParamHandler) GetAllParams(c *gin.Context) {
|
||||
params, err := h.service.GetAllParams(c.Request.Context())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, params)
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"giter.top/smart/internal/system/handler"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SystemRoutes 注册 system 模块的 HTTP 路由。
|
||||
type SystemRoutes struct {
|
||||
paramHandler *handler.ParamHandler
|
||||
}
|
||||
// NewSystemRoutes 构造 system 模块的路由注册器,由 Wire 注入。
|
||||
func NewSystemRoutes( paramHandler *handler.ParamHandler) *SystemRoutes {
|
||||
return &SystemRoutes{
|
||||
paramHandler: paramHandler,
|
||||
}
|
||||
}
|
||||
// TODO 添加注册信息
|
||||
func (s *SystemRoutes) Register(engine *gin.Engine, apiGroup *gin.RouterGroup) {
|
||||
group := apiGroup.Group("/system")
|
||||
s.registerParamRoutes(group)
|
||||
}
|
||||
// 系统参数路由
|
||||
func (s *SystemRoutes) registerParamRoutes(group *gin.RouterGroup) {
|
||||
paramGroup := group.Group("/param")
|
||||
{
|
||||
paramGroup.POST("/create", s.paramHandler.CreateParam)
|
||||
paramGroup.PUT("/update", s.paramHandler.UpdateParam)
|
||||
paramGroup.DELETE("/delete-batch", s.paramHandler.DeleteParams)
|
||||
paramGroup.GET("/get", s.paramHandler.GetParam)
|
||||
paramGroup.GET("/list", s.paramHandler.ListParams)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,139 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"giter.top/smart/internal/system/entity"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ErrNotFound 记录未找到
|
||||
var ErrNotFound = errors.New("param not found")
|
||||
|
||||
// ParamRepository 系统参数数据访问层
|
||||
type ParamRepository interface {
|
||||
// Create 创建系统参数
|
||||
Create(ctx context.Context, param *entity.SystemParam) error
|
||||
// Update 更新系统参数
|
||||
Update(ctx context.Context, param *entity.SystemParam) error
|
||||
// Delete 删除系统参数
|
||||
Delete(ctx context.Context, id string) error
|
||||
// DeleteBatch 批量删除
|
||||
DeleteBatch(ctx context.Context, ids []string) error
|
||||
// GetByID 根据 ID 获取
|
||||
GetByID(ctx context.Context, id string) (*entity.SystemParam, error)
|
||||
// GetByKey 根据键获取
|
||||
GetByKey(ctx context.Context, key string) (*entity.SystemParam, error)
|
||||
// List 获取列表(支持分页和筛选)
|
||||
List(ctx context.Context, group string, paramKey string, page, pageSize int) ([]entity.SystemParam, int64, error)
|
||||
// GetAll 获取所有参数(用于缓存)
|
||||
GetAll(ctx context.Context) (map[string]entity.SystemParam, error)
|
||||
// ExistsByKey 检查键是否存在(排除指定 ID)
|
||||
ExistsByKey(ctx context.Context, key string, excludeID string) (bool, error)
|
||||
}
|
||||
|
||||
type paramRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewParamRepository 创建参数仓库实例
|
||||
func NewParamRepository(db *gorm.DB) ParamRepository {
|
||||
return ¶mRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *paramRepository) Create(ctx context.Context, param *entity.SystemParam) error {
|
||||
return r.db.WithContext(ctx).Create(param).Error
|
||||
}
|
||||
|
||||
func (r *paramRepository) Update(ctx context.Context, param *entity.SystemParam) error {
|
||||
return r.db.WithContext(ctx).Save(param).Error
|
||||
}
|
||||
|
||||
func (r *paramRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Where("id = ?", id).Delete(&entity.SystemParam{}).Error
|
||||
}
|
||||
|
||||
func (r *paramRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.db.WithContext(ctx).Where("id IN ?", ids).Delete(&entity.SystemParam{}).Error
|
||||
}
|
||||
|
||||
func (r *paramRepository) GetByID(ctx context.Context, id string) (*entity.SystemParam, error) {
|
||||
var param entity.SystemParam
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(¶m).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return ¶m, nil
|
||||
}
|
||||
|
||||
func (r *paramRepository) GetByKey(ctx context.Context, key string) (*entity.SystemParam, error) {
|
||||
var param entity.SystemParam
|
||||
err := r.db.WithContext(ctx).Where("param_key = ?", key).First(¶m).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return ¶m, nil
|
||||
}
|
||||
|
||||
func (r *paramRepository) List(ctx context.Context, group string, paramKey string, page, pageSize int) ([]entity.SystemParam, int64, error) {
|
||||
var params []entity.SystemParam
|
||||
var total int64
|
||||
query := r.db.WithContext(ctx).Model(&entity.SystemParam{})
|
||||
|
||||
// 应用筛选条件
|
||||
if group != "" {
|
||||
query = query.Where("param_group = ?", group)
|
||||
}
|
||||
if paramKey != "" {
|
||||
query = query.Where("param_key LIKE ?", "%"+paramKey+"%")
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
if pageSize <= 0 {
|
||||
pageSize = 10
|
||||
}
|
||||
err := query.Order("id DESC").Offset(offset).Limit(pageSize).Find(¶ms).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return params, total, nil
|
||||
}
|
||||
|
||||
func (r *paramRepository) GetAll(ctx context.Context) (map[string]entity.SystemParam, error) {
|
||||
var params []entity.SystemParam
|
||||
err := r.db.WithContext(ctx).Find(¶ms).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make(map[string]entity.SystemParam, len(params))
|
||||
for _, param := range params {
|
||||
result[param.ParamKey] = param
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *paramRepository) ExistsByKey(ctx context.Context, key string, excludeID string) (bool, error) {
|
||||
query := r.db.WithContext(ctx).Where("param_key = ?", key)
|
||||
if excludeID != "" {
|
||||
query = query.Where("id != ?", excludeID)
|
||||
}
|
||||
var count int64
|
||||
err := query.Model(&entity.SystemParam{}).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
@@ -0,0 +1,271 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"giter.top/smart/internal/system/entity"
|
||||
"giter.top/smart/internal/system/repository"
|
||||
"giter.top/smart/pkg/utils/id"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// ErrInvalidParam 参数无效
|
||||
var ErrInvalidParam = errors.New("invalid param")
|
||||
|
||||
// ParamService 系统参数业务逻辑层
|
||||
type ParamService interface {
|
||||
// CreateParam 创建系统参数
|
||||
CreateParam(ctx context.Context, req *CreateParamRequest, creatorID string) (*entity.SystemParam, error)
|
||||
// UpdateParam 更新系统参数
|
||||
UpdateParam(ctx context.Context, id string, req *UpdateParamRequest, lastUpdaterID string) (*entity.SystemParam, error)
|
||||
// DeleteParam 删除系统参数
|
||||
DeleteParam(ctx context.Context, id string) error
|
||||
// DeleteParams 批量删除
|
||||
DeleteParams(ctx context.Context, ids []string) error
|
||||
// GetParam 获取单个参数
|
||||
GetParam(ctx context.Context, id string) (*entity.SystemParam, error)
|
||||
// GetParamByKey 根据键获取参数
|
||||
GetParamByKey(ctx context.Context, key string) (*entity.SystemParam, error)
|
||||
// ListParams 获取参数列表
|
||||
ListParams(ctx context.Context, group string, paramKey string, page, pageSize int) (*ParamListResponse, error)
|
||||
// GetAllParams 获取所有参数(用于缓存)
|
||||
GetAllParams(ctx context.Context) (map[string]entity.SystemParam, error)
|
||||
// GetParamValue 获取参数值(便捷方法)
|
||||
GetParamValue(ctx context.Context, key string) (string, error)
|
||||
// GetParamValueWithDefault 获取参数值,不存在则返回默认值
|
||||
GetParamValueWithDefault(ctx context.Context, key string, defaultValue string) string
|
||||
}
|
||||
|
||||
// CreateParamRequest 创建参数请求
|
||||
type CreateParamRequest struct {
|
||||
ParamKey string `json:"param_key" binding:"required,max=100"`
|
||||
ParamValue string `json:"param_value" binding:"required"`
|
||||
ParamType string `json:"param_type" binding:"required,oneof=text number boolean select"`
|
||||
ParamGroup string `json:"param_group" binding:"required,max:50"`
|
||||
ParamDesc string `json:"param_desc" max:"500"`
|
||||
}
|
||||
|
||||
// UpdateParamRequest 更新参数请求
|
||||
type UpdateParamRequest struct {
|
||||
ParamValue string `json:"param_value"`
|
||||
ParamType string `json:"param_type" binding:"omitempty,oneof=text number boolean select"`
|
||||
ParamDesc string `json:"param_desc" max:"500"`
|
||||
}
|
||||
|
||||
// ParamListResponse 参数列表响应
|
||||
type ParamListResponse struct {
|
||||
Items []entity.SystemParam `json:"items"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
|
||||
type paramService struct {
|
||||
repo repository.ParamRepository
|
||||
cache redis.UniversalClient
|
||||
cacheKey string
|
||||
}
|
||||
|
||||
// NewParamService 创建参数服务实例(与 cache.NewRedisClient 返回的 redis.UniversalClient 一致,便于 Wire 注入)
|
||||
func NewParamService(repo repository.ParamRepository, cacheClient redis.UniversalClient) ParamService {
|
||||
return ¶mService{
|
||||
repo: repo,
|
||||
cache: cacheClient,
|
||||
cacheKey: "system:params:*",
|
||||
}
|
||||
}
|
||||
|
||||
func (s *paramService) CreateParam(ctx context.Context, req *CreateParamRequest, creatorID string) (*entity.SystemParam, error) {
|
||||
// 生成唯一 ID (UUID v7)
|
||||
id := id.New()
|
||||
|
||||
// 检查键是否已存在
|
||||
exists, err := s.repo.ExistsByKey(ctx, req.ParamKey, "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("检查参数键失败:%w", err)
|
||||
}
|
||||
if exists {
|
||||
return nil, fmt.Errorf("参数键 %s 已存在", req.ParamKey)
|
||||
}
|
||||
|
||||
param := &entity.SystemParam{
|
||||
ID: id,
|
||||
ParamKey: req.ParamKey,
|
||||
ParamValue: req.ParamValue,
|
||||
ParamType: req.ParamType,
|
||||
ParamGroup: req.ParamGroup,
|
||||
ParamDesc: req.ParamDesc,
|
||||
CreatorID: creatorID,
|
||||
LastUpdaterID: creatorID,
|
||||
}
|
||||
|
||||
if err := s.repo.Create(ctx, param); err != nil {
|
||||
return nil, fmt.Errorf("创建参数失败:%w", err)
|
||||
}
|
||||
|
||||
// 刷新缓存
|
||||
s.refreshCache(ctx)
|
||||
return param, nil
|
||||
}
|
||||
|
||||
func (s *paramService) UpdateParam(ctx context.Context, id string, req *UpdateParamRequest, lastUpdaterID string) (*entity.SystemParam, error) {
|
||||
// 获取现有参数
|
||||
param, err := s.repo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return nil, fmt.Errorf("参数不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("获取参数失败:%w", err)
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.ParamValue != "" {
|
||||
param.ParamValue = req.ParamValue
|
||||
}
|
||||
if req.ParamType != "" {
|
||||
param.ParamType = req.ParamType
|
||||
}
|
||||
if req.ParamDesc != "" {
|
||||
param.ParamDesc = req.ParamDesc
|
||||
}
|
||||
|
||||
param.LastUpdaterID = lastUpdaterID
|
||||
if err := s.repo.Update(ctx, param); err != nil {
|
||||
return nil, fmt.Errorf("更新参数失败:%w", err)
|
||||
}
|
||||
|
||||
// 刷新缓存
|
||||
s.refreshCache(ctx)
|
||||
return param, nil
|
||||
}
|
||||
|
||||
func (s *paramService) DeleteParam(ctx context.Context, id string) error {
|
||||
if err := s.repo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("删除参数失败:%w", err)
|
||||
}
|
||||
|
||||
// 刷新缓存
|
||||
s.refreshCache(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *paramService) DeleteParams(ctx context.Context, ids []string) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
if err := s.repo.DeleteBatch(ctx, ids); err != nil {
|
||||
return fmt.Errorf("批量删除参数失败:%w", err)
|
||||
}
|
||||
|
||||
// 刷新缓存
|
||||
s.refreshCache(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *paramService) GetParam(ctx context.Context, id string) (*entity.SystemParam, error) {
|
||||
param, err := s.repo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return nil, fmt.Errorf("参数不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return param, nil
|
||||
}
|
||||
|
||||
func (s *paramService) GetParamByKey(ctx context.Context, key string) (*entity.SystemParam, error) {
|
||||
param, err := s.repo.GetByKey(ctx, key)
|
||||
if err != nil {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return nil, fmt.Errorf("参数 %s 不存在", key)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return param, nil
|
||||
}
|
||||
|
||||
func (s *paramService) ListParams(ctx context.Context, group string, paramKey string, page, pageSize int) (*ParamListResponse, error) {
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 {
|
||||
pageSize = 10
|
||||
}
|
||||
|
||||
items, total, err := s.repo.List(ctx, group, paramKey, page, pageSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取参数列表失败:%w", err)
|
||||
}
|
||||
|
||||
totalPages := int(total) / pageSize
|
||||
if int(total)%pageSize != 0 {
|
||||
totalPages++
|
||||
}
|
||||
|
||||
return &ParamListResponse{
|
||||
Items: items,
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
TotalPages: totalPages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *paramService) GetAllParams(ctx context.Context) (map[string]entity.SystemParam, error) {
|
||||
// 先从缓存获取
|
||||
if s.cache != nil {
|
||||
cached := s.cache.Get(ctx, "system:params:all").Val()
|
||||
if cached != "" {
|
||||
var params map[string]entity.SystemParam
|
||||
if err := json.Unmarshal([]byte(cached), ¶ms); err == nil {
|
||||
return params, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("解析缓存数据失败:%w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 缓存未命中,从数据库获取
|
||||
params, err := s.repo.GetAll(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 写入缓存
|
||||
if s.cache != nil {
|
||||
data, _ := json.Marshal(params)
|
||||
s.cache.Set(ctx, "system:params:all", string(data), 0) // 0 表示永不过期
|
||||
}
|
||||
|
||||
return params, nil
|
||||
}
|
||||
|
||||
func (s *paramService) GetParamValue(ctx context.Context, key string) (string, error) {
|
||||
param, err := s.GetParamByKey(ctx, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return param.ParamValue, nil
|
||||
}
|
||||
|
||||
func (s *paramService) GetParamValueWithDefault(ctx context.Context, key string, defaultValue string) string {
|
||||
value, err := s.GetParamValue(ctx, key)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// refreshCache 刷新缓存
|
||||
func (s *paramService) refreshCache(ctx context.Context) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 删除缓存,让下次请求重新构建
|
||||
s.cache.Del(ctx, "system:params:all")
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"giter.top/smart/internal/system/handler"
|
||||
"giter.top/smart/internal/system/repository"
|
||||
"giter.top/smart/internal/system/service"
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
// HandlerProviderSet 处理程序提供者集合
|
||||
var handlerProviderSet = wire.NewSet(
|
||||
handler.NewParamHandler,
|
||||
)
|
||||
|
||||
|
||||
// ServiceProviderSet 服务提供者集合
|
||||
var serviceProviderSet = wire.NewSet(
|
||||
service.NewParamService,
|
||||
)
|
||||
|
||||
|
||||
// RepositoryProviderSet 仓库提供者集合
|
||||
var repositoryProviderSet = wire.NewSet(
|
||||
repository.NewParamRepository,
|
||||
)
|
||||
|
||||
var ProviderSet = wire.NewSet(
|
||||
handlerProviderSet,
|
||||
serviceProviderSet,
|
||||
repositoryProviderSet,
|
||||
NewSystemRoutes,
|
||||
)
|
||||
Reference in New Issue
Block a user