Files
smart-go/internal/iam/repository/user_repository.go
T
2026-04-23 18:58:13 +08:00

180 lines
5.8 KiB
Go

package repository
import (
	"context"
	"errors"

	"giter.top/smart/internal/iam/entity"
	"giter.top/smart/pkg/utils/id"
	"gorm.io/gorm"
)
// UserRepository is the data-access interface for users and their
// department / role link rows.
type UserRepository interface {
// Create inserts u.
Create(ctx context.Context, u *entity.User) error
// Update persists u through gorm's Save.
Update(ctx context.Context, u *entity.User) error
// Delete removes the user whose id column equals id.
Delete(ctx context.Context, id string) error
// GetByID returns the user with the given id, or ErrNotFound when no row matches.
GetByID(ctx context.Context, id string) (*entity.User, error)
// GetByUserName returns the user with the given user name inside tenantID,
// or ErrNotFound when no row matches.
GetByUserName(ctx context.Context, tenantID string, userName string) (*entity.User, error)
// ExistsUserName reports whether userName is already used inside tenantID.
// A non-empty excludeID leaves the user with that id out of the check.
ExistsUserName(ctx context.Context, tenantID string, userName string, excludeID string) (bool, error)
// CountByDept counts the distinct users attached to deptID, either through
// iam_user.dept_id or through a row in iam_user_dept.
CountByDept(ctx context.Context, deptID string) (int64, error)
// List returns one page of the tenant's users plus the total number of users
// matching the filters. Nil deptID / roleID / status and an empty keyword
// disable the respective filter; non-positive page and pageSize fall back to
// 1 and 10.
List(ctx context.Context, tenantID string, deptID *string, roleID *string, keyword string, status *int16, page, pageSize int) ([]entity.User, int64, error)
// ReplaceUserDepts deletes the user's department links and recreates them from
// deptIDs inside one transaction, flagging the link to primaryDept as primary.
ReplaceUserDepts(ctx context.Context, userID string, primaryDept string, deptIDs []string) error
// ReplaceUserRoles deletes the user's role links and recreates them from
// roleIDs inside one transaction.
ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string) error
// ListRoleIDs returns the role_id of every role link of userID.
ListRoleIDs(ctx context.Context, userID string) ([]string, error)
// ListDeptIDs returns the dept_id of every department link of userID.
ListDeptIDs(ctx context.Context, userID string) ([]string, error)
}
// userRepository is the gorm-backed implementation that NewUserRepository
// hands out as a UserRepository.
type userRepository struct {
// db is the gorm handle every method issues its queries through.
db *gorm.DB
}
// NewUserRepository wraps db in a UserRepository.
func NewUserRepository(db *gorm.DB) UserRepository {
	repo := userRepository{db: db}
	return &repo
}
// Create inserts u and returns the error reported by gorm, if any.
func (r *userRepository) Create(ctx context.Context, u *entity.User) error {
	res := r.db.WithContext(ctx).Create(u)
	return res.Error
}
// Update persists u through gorm's Save and returns the error reported by
// gorm, if any.
func (r *userRepository) Update(ctx context.Context, u *entity.User) error {
	res := r.db.WithContext(ctx).Save(u)
	return res.Error
}
// Delete removes the user whose id column equals id and returns the error
// reported by gorm, if any.
func (r *userRepository) Delete(ctx context.Context, id string) error {
	scoped := r.db.WithContext(ctx)
	res := scoped.Delete(&entity.User{}, "id = ?", id)
	return res.Error
}
// GetByID loads the first user whose id column equals id.
//
// It returns ErrNotFound when gorm reports that no record matched; any other
// database error is returned unchanged.
func (r *userRepository) GetByID(ctx context.Context, id string) (*entity.User, error) {
	var user entity.User
	err := r.db.WithContext(ctx).Where("id = ?", id).First(&user).Error
	switch {
	case err == nil:
		return &user, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		// errors.Is also matches a wrapped ErrRecordNotFound, which a plain
		// == comparison would let through as a generic error.
		return nil, ErrNotFound
	default:
		return nil, err
	}
}
// GetByUserName loads the first user matching both tenantID and userName.
//
// It returns ErrNotFound when gorm reports that no record matched; any other
// database error is returned unchanged.
func (r *userRepository) GetByUserName(ctx context.Context, tenantID string, userName string) (*entity.User, error) {
	var user entity.User
	err := r.db.WithContext(ctx).
		Where("tenant_id = ? AND user_name = ?", tenantID, userName).
		First(&user).Error
	switch {
	case err == nil:
		return &user, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		// errors.Is also matches a wrapped ErrRecordNotFound, which a plain
		// == comparison would let through as a generic error.
		return nil, ErrNotFound
	default:
		return nil, err
	}
}
// ExistsUserName reports whether at least one user in tenantID already has
// userName. When excludeID is non-empty the user with that id is left out of
// the check. The boolean is derived from the row count and is returned
// together with any error from the count query.
func (r *userRepository) ExistsUserName(ctx context.Context, tenantID string, userName string, excludeID string) (bool, error) {
	query := r.db.WithContext(ctx).
		Model(&entity.User{}).
		Where("tenant_id = ? AND user_name = ?", tenantID, userName)
	if len(excludeID) > 0 {
		query = query.Where("id <> ?", excludeID)
	}

	var matches int64
	res := query.Count(&matches)
	return matches > 0, res.Error
}
// CountByDept counts the users attached to deptID. A user qualifies either
// through iam_user.dept_id (rows with deleted_at set are skipped) or through a
// row in iam_user_dept; the UNION collapses users reached both ways into one.
func (r *userRepository) CountByDept(ctx context.Context, deptID string) (int64, error) {
	const query = `
SELECT COUNT(*) FROM (
SELECT id FROM iam_user WHERE dept_id = ? AND deleted_at IS NULL
UNION
SELECT user_id FROM iam_user_dept WHERE dept_id = ?
) t`

	var total int64
	res := r.db.WithContext(ctx).Raw(query, deptID, deptID).Scan(&total)
	return total, res.Error
}
// List returns one page of the users of tenantID together with the total
// number of users matching the filters (counted before pagination).
//
// Filters, each skipped when its argument is nil / empty:
//   - deptID:  the user's own dept_id equals it, or iam_user_dept links the
//     user to it.
//   - roleID:  the user has a UserRole row with that role_id.
//   - keyword: substring match against user_name, real_name, phone or email.
//     The keyword is not escaped, so '%' and '_' inside it act as LIKE
//     wildcards.
//   - status:  exact match on the status column.
//
// Non-positive page and pageSize fall back to 1 and 10. Rows are ordered
// newest first; id is the tiebreaker so that users sharing a created_at value
// keep a fixed position and cannot repeat or go missing between pages.
func (r *userRepository) List(ctx context.Context, tenantID string, deptID *string, roleID *string, keyword string, status *int16, page, pageSize int) ([]entity.User, int64, error) {
	q := r.db.WithContext(ctx).Model(&entity.User{}).Where("tenant_id = ?", tenantID)

	if deptID != nil {
		dept := *deptID
		q = q.Where("dept_id = ? OR id IN (SELECT user_id FROM iam_user_dept WHERE dept_id = ?)", dept, dept)
	}
	if roleID != nil {
		roleUsers := r.db.WithContext(ctx).
			Model(&entity.UserRole{}).
			Select("user_id").
			Where("role_id = ?", *roleID)
		q = q.Where("id IN (?)", roleUsers)
	}
	if keyword != "" {
		pattern := "%" + keyword + "%"
		q = q.Where("user_name LIKE ? OR real_name LIKE ? OR phone LIKE ? OR email LIKE ?", pattern, pattern, pattern, pattern)
	}
	if status != nil {
		q = q.Where("status = ?", *status)
	}

	var total int64
	if err := q.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	if page <= 0 {
		page = 1
	}
	if pageSize <= 0 {
		pageSize = 10
	}
	offset := (page - 1) * pageSize

	var rows []entity.User
	err := q.Order("created_at DESC, id DESC").
		Offset(offset).
		Limit(pageSize).
		Find(&rows).Error
	return rows, total, err
}
// ReplaceUserDepts replaces every department link of userID inside a single
// transaction: existing UserDept rows are deleted, then one row is created per
// distinct entry of deptIDs (in input order), with IsPrimary set on the entry
// equal to primaryDept.
//
// A non-empty primaryDept always ends up with a link row: when deptIDs does
// not contain it (including the empty-list case) it is inserted after the
// listed departments. Any error aborts and rolls back the transaction.
func (r *userRepository) ReplaceUserDepts(ctx context.Context, userID string, primaryDept string, deptIDs []string) error {
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		if err := tx.Where("user_id = ?", userID).Delete(&entity.UserDept{}).Error; err != nil {
			return err
		}

		seen := make(map[string]struct{}, len(deptIDs)+1)
		// link inserts one row for deptID unless that department was already
		// linked earlier in this call.
		link := func(deptID string) error {
			if _, dup := seen[deptID]; dup {
				return nil
			}
			seen[deptID] = struct{}{}
			row := entity.UserDept{
				ID:        id.New(),
				UserID:    userID,
				DeptID:    deptID,
				IsPrimary: deptID == primaryDept,
			}
			return tx.Create(&row).Error
		}

		for _, deptID := range deptIDs {
			if err := link(deptID); err != nil {
				return err
			}
		}
		// Previously the primary department was only added when deptIDs was
		// empty, so a non-empty list that omitted it lost the primary link.
		if primaryDept != "" {
			return link(primaryDept)
		}
		return nil
	})
}
// ReplaceUserRoles replaces every role link of userID inside a single
// transaction: existing UserRole rows are deleted, then one row is created per
// distinct entry of roleIDs, in input order. Repeated role IDs are inserted
// only once, matching how ReplaceUserDepts treats repeated department IDs.
// Any error aborts and rolls back the transaction.
func (r *userRepository) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string) error {
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		if err := tx.Where("user_id = ?", userID).Delete(&entity.UserRole{}).Error; err != nil {
			return err
		}

		seen := make(map[string]struct{}, len(roleIDs))
		for _, roleID := range roleIDs {
			if _, dup := seen[roleID]; dup {
				continue
			}
			seen[roleID] = struct{}{}

			row := entity.UserRole{ID: id.New(), UserID: userID, RoleID: roleID}
			if err := tx.Create(&row).Error; err != nil {
				return err
			}
		}
		return nil
	})
}
// ListRoleIDs plucks the role_id column of every UserRole row belonging to
// userID. The slice is returned together with any query error.
func (r *userRepository) ListRoleIDs(ctx context.Context, userID string) ([]string, error) {
	var roleIDs []string
	res := r.db.WithContext(ctx).
		Model(&entity.UserRole{}).
		Where("user_id = ?", userID).
		Pluck("role_id", &roleIDs)
	return roleIDs, res.Error
}
// ListDeptIDs plucks the dept_id column of every UserDept row belonging to
// userID. The slice is returned together with any query error.
func (r *userRepository) ListDeptIDs(ctx context.Context, userID string) ([]string, error) {
	var deptIDs []string
	res := r.db.WithContext(ctx).
		Model(&entity.UserDept{}).
		Where("user_id = ?", userID).
		Pluck("dept_id", &deptIDs)
	return deptIDs, res.Error
}