package repository import ( "context" "giter.top/smart/internal/iam/entity" "gorm.io/gorm" ) // TenantRepository 租户数据访问 type TenantRepository interface { Create(ctx context.Context, t *entity.Tenant) error Update(ctx context.Context, t *entity.Tenant) error GetByID(ctx context.Context, id string) (*entity.Tenant, error) GetByCode(ctx context.Context, code string) (*entity.Tenant, error) List(ctx context.Context, name, code string, status *int16, page, pageSize int) ([]entity.Tenant, int64, error) CountUsers(ctx context.Context, tenantID string) (int64, error) CountDepts(ctx context.Context, tenantID string) (int64, error) ExistsCode(ctx context.Context, code string, excludeID string) (bool, error) } type tenantRepository struct { db *gorm.DB } func NewTenantRepository(db *gorm.DB) TenantRepository { return &tenantRepository{db: db} } func (r *tenantRepository) Create(ctx context.Context, t *entity.Tenant) error { return r.db.WithContext(ctx).Create(t).Error } func (r *tenantRepository) Update(ctx context.Context, t *entity.Tenant) error { return r.db.WithContext(ctx).Save(t).Error } func (r *tenantRepository) GetByID(ctx context.Context, id string) (*entity.Tenant, error) { var out entity.Tenant err := r.db.WithContext(ctx).Where("id = ?", id).First(&out).Error if err != nil { if err == gorm.ErrRecordNotFound { return nil, ErrNotFound } return nil, err } return &out, nil } func (r *tenantRepository) GetByCode(ctx context.Context, code string) (*entity.Tenant, error) { var out entity.Tenant err := r.db.WithContext(ctx).Where("tenant_code = ?", code).First(&out).Error if err != nil { if err == gorm.ErrRecordNotFound { return nil, ErrNotFound } return nil, err } return &out, nil } func (r *tenantRepository) List(ctx context.Context, name, code string, status *int16, page, pageSize int) ([]entity.Tenant, int64, error) { q := r.db.WithContext(ctx).Model(&entity.Tenant{}) if name != "" { q = q.Where("tenant_name LIKE ?", "%"+name+"%") } if code != "" { q = q.Where("tenant_code LIKE ?", "%"+code+"%") } if status != nil { q = q.Where("status = ?", *status) } var total int64 if err := q.Count(&total).Error; err != nil { return nil, 0, err } if page <= 0 { page = 1 } if pageSize <= 0 { pageSize = 10 } offset := (page - 1) * pageSize var rows []entity.Tenant err := q.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&rows).Error return rows, total, err } func (r *tenantRepository) CountUsers(ctx context.Context, tenantID string) (int64, error) { var n int64 err := r.db.WithContext(ctx).Model(&entity.User{}).Where("tenant_id = ?", tenantID).Count(&n).Error return n, err } func (r *tenantRepository) CountDepts(ctx context.Context, tenantID string) (int64, error) { var n int64 err := r.db.WithContext(ctx).Model(&entity.Dept{}).Where("tenant_id = ?", tenantID).Count(&n).Error return n, err } func (r *tenantRepository) ExistsCode(ctx context.Context, code string, excludeID string) (bool, error) { q := r.db.WithContext(ctx).Model(&entity.Tenant{}).Where("tenant_code = ?", code) if excludeID != "" { q = q.Where("id <> ?", excludeID) } var n int64 err := q.Count(&n).Error return n > 0, err }