Files
smart-go/internal/iam/repository/role_repository.go
T
2026-04-23 18:58:13 +08:00

148 lines
4.6 KiB
Go

package repository
import (
"context"
"giter.top/smart/internal/iam/entity"
"giter.top/smart/pkg/utils/id"
"gorm.io/gorm"
)
// RoleRepository persists roles and the role-to-menu associations attached
// to them.
type RoleRepository interface {
	// Create inserts r as a new role.
	Create(ctx context.Context, r *entity.Role) error
	// Update saves every field of r.
	Update(ctx context.Context, r *entity.Role) error
	// Delete removes the role with the given id.
	Delete(ctx context.Context, id string) error
	// GetByID loads a single role; it returns ErrNotFound when none exists.
	GetByID(ctx context.Context, id string) (*entity.Role, error)
	// List returns one page of a tenant's roles, optionally filtered by
	// substrings of name and code, together with the unpaged match count.
	List(ctx context.Context, tenantID, name, code string, page, pageSize int) ([]entity.Role, int64, error)
	// ExistsCode reports whether the tenant already has a role using code,
	// ignoring the role identified by excludeID when it is non-empty.
	ExistsCode(ctx context.Context, tenantID, code, excludeID string) (bool, error)
	// CountUsers returns how many user-role rows reference roleID.
	CountUsers(ctx context.Context, roleID string) (int64, error)

	// ReplaceRoleMenus swaps the role's menu assignments for menuIDs inside
	// a single transaction.
	ReplaceRoleMenus(ctx context.Context, roleID string, menuIDs []string) error
	// ListMenuIDsByRole returns the menu IDs assigned to one role.
	ListMenuIDsByRole(ctx context.Context, roleID string) ([]string, error)
	// ListMenuIDsByRoles returns the distinct menu IDs assigned to any of
	// the given roles.
	ListMenuIDsByRoles(ctx context.Context, roleIDs []string) ([]string, error)
	// ListRolesByUser returns the roles linked to a user.
	ListRolesByUser(ctx context.Context, userID string) ([]entity.Role, error)
}
// roleRepository is the GORM-backed implementation of RoleRepository.
type roleRepository struct {
	// db is the shared handle; every method scopes it with WithContext(ctx).
	db *gorm.DB
}

// NewRoleRepository returns a RoleRepository that issues its queries
// through db.
func NewRoleRepository(db *gorm.DB) RoleRepository {
	repo := &roleRepository{db: db}
	return repo
}
// Create inserts row as a new role record.
func (r *roleRepository) Create(ctx context.Context, row *entity.Role) error {
	result := r.db.WithContext(ctx).Create(row)
	return result.Error
}
// Update writes row back using GORM's Save, which writes every column
// (zero values included) rather than only the changed ones.
// NOTE(review): Save falls back to an INSERT when row's primary key is empty;
// confirm callers always pass a role that was loaded first.
func (r *roleRepository) Update(ctx context.Context, row *entity.Role) error {
	result := r.db.WithContext(ctx).Save(row)
	return result.Error
}
// Delete removes the role whose id column equals id. RowsAffected is not
// inspected, so an unknown id does not by itself produce an error.
func (r *roleRepository) Delete(ctx context.Context, id string) error {
	tx := r.db.WithContext(ctx)
	return tx.Delete(&entity.Role{}, "id = ?", id).Error
}
// GetByID loads the role whose id column equals id.
//
// It returns ErrNotFound when no such role exists, and any other database
// error unchanged.
//
// The lookup uses Limit(1).Find and checks RowsAffected instead of First.
// This detects a miss without comparing against gorm.ErrRecordNotFound using
// == (which stops matching if GORM wraps the error). It also stops GORM's
// logger from reporting an expected miss as a "record not found" error
// (unless the logger already sets IgnoreRecordNotFoundError).
func (r *roleRepository) GetByID(ctx context.Context, id string) (*entity.Role, error) {
	var out entity.Role
	res := r.db.WithContext(ctx).Where("id = ?", id).Limit(1).Find(&out)
	if res.Error != nil {
		return nil, res.Error
	}
	if res.RowsAffected == 0 {
		return nil, ErrNotFound
	}
	return &out, nil
}
// List returns one page of tenantID's roles, newest (created_at) first,
// plus the total number of roles that match the filters regardless of
// pagination.
//
// name and code, when non-empty, are matched as substrings of role_name and
// role_code through LIKE. page is 1-based. A page below 1 becomes 1, and a
// pageSize below 1 becomes 10.
// NOTE(review): '%' and '_' inside name/code are not escaped, so they act as
// LIKE wildcards.
func (r *roleRepository) List(ctx context.Context, tenantID string, name, code string, page, pageSize int) ([]entity.Role, int64, error) {
	query := r.db.WithContext(ctx).Model(&entity.Role{}).Where("tenant_id = ?", tenantID)

	filters := []struct {
		clause string
		term   string
	}{
		{clause: "role_name LIKE ?", term: name},
		{clause: "role_code LIKE ?", term: code},
	}
	for _, f := range filters {
		if f.term == "" {
			continue
		}
		query = query.Where(f.clause, "%"+f.term+"%")
	}

	var total int64
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	if page < 1 {
		page = 1
	}
	if pageSize < 1 {
		pageSize = 10
	}

	var rows []entity.Role
	err := query.
		Order("created_at DESC").
		Offset((page - 1) * pageSize).
		Limit(pageSize).
		Find(&rows).Error
	return rows, total, err
}
// ExistsCode reports whether tenantID already has at least one role whose
// role_code equals code. A non-empty excludeID leaves that role out of the
// check, so an update can keep its own code.
func (r *roleRepository) ExistsCode(ctx context.Context, tenantID string, code string, excludeID string) (bool, error) {
	query := r.db.WithContext(ctx).
		Model(&entity.Role{}).
		Where("tenant_id = ? AND role_code = ?", tenantID, code)
	if excludeID != "" {
		query = query.Where("id <> ?", excludeID)
	}

	var matches int64
	err := query.Count(&matches).Error
	return matches > 0, err
}
// CountUsers returns the number of user-role rows whose role_id is roleID.
func (r *roleRepository) CountUsers(ctx context.Context, roleID string) (int64, error) {
	var total int64
	query := r.db.WithContext(ctx).Model(&entity.UserRole{}).Where("role_id = ?", roleID)
	if err := query.Count(&total).Error; err != nil {
		return total, err
	}
	return total, nil
}
func (r *roleRepository) ReplaceRoleMenus(ctx context.Context, roleID string, menuIDs []string) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Where("role_id = ?", roleID).Delete(&entity.RoleMenu{}).Error; err != nil {
return err
}
for _, mid := range menuIDs {
rm := entity.RoleMenu{ID: id.New(), RoleID: roleID, MenuID: mid}
if err := tx.Create(&rm).Error; err != nil {
return err
}
}
return nil
})
}
// ListMenuIDsByRole returns the menu_id of every role-menu row belonging to
// roleID. No ORDER BY is applied, so the order is whatever the database
// returns.
func (r *roleRepository) ListMenuIDsByRole(ctx context.Context, roleID string) ([]string, error) {
	var menuIDs []string
	result := r.db.WithContext(ctx).
		Model(&entity.RoleMenu{}).
		Where("role_id = ?", roleID).
		Pluck("menu_id", &menuIDs)
	return menuIDs, result.Error
}
// ListMenuIDsByRoles returns the distinct menu IDs assigned to any role in
// roleIDs. They are kept in first-seen order of the rows the database returns.
// An empty roleIDs short-circuits to (nil, nil) without querying.
func (r *roleRepository) ListMenuIDsByRoles(ctx context.Context, roleIDs []string) ([]string, error) {
	if len(roleIDs) == 0 {
		return nil, nil
	}

	var plucked []string
	if err := r.db.WithContext(ctx).
		Model(&entity.RoleMenu{}).
		Where("role_id IN ?", roleIDs).
		Pluck("menu_id", &plucked).Error; err != nil {
		return nil, err
	}

	// Several roles may grant the same menu; keep only its first occurrence.
	var distinct []string
	taken := make(map[string]bool, len(plucked))
	for _, menuID := range plucked {
		if !taken[menuID] {
			taken[menuID] = true
			distinct = append(distinct, menuID)
		}
	}
	return distinct, nil
}
// ListRolesByUser returns the roles linked to userID through the
// iam_user_role join table. No ORDER BY is applied. Because this is an inner
// join, a role appears once per matching link row.
// NOTE(review): the table names are hard-coded strings, whereas CountUsers
// goes through entity.UserRole. Confirm they match the entities' table names.
func (r *roleRepository) ListRolesByUser(ctx context.Context, userID string) ([]entity.Role, error) {
	const (
		roleTable = "iam_role"
		linkJoin  = "JOIN iam_user_role ur ON ur.role_id = iam_role.id"
	)

	var roles []entity.Role
	result := r.db.WithContext(ctx).
		Table(roleTable).
		Joins(linkJoin).
		Where("ur.user_id = ?", userID).
		Find(&roles)
	return roles, result.Error
}