feat: 优化web
This commit is contained in:
Vendored
+184
@@ -0,0 +1,184 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"giter.top/smart/pkg/config"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// NewRedis 根据配置创建 Redis 客户端,支持单机、哨兵、集群三种模式。
|
||||
func NewRedis(cfg *config.Config) (redis.UniversalClient) {
|
||||
if cfg == nil {
|
||||
panic("cache: config is nil")
|
||||
}
|
||||
r := cfg.Data.Redis
|
||||
mode := strings.ToLower(strings.TrimSpace(r.Mode))
|
||||
if mode == "" {
|
||||
mode = "standalone"
|
||||
}
|
||||
|
||||
switch mode {
|
||||
case "standalone":
|
||||
addr, err := standaloneAddr(cfg)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
opt := &redis.Options{
|
||||
Addr: addr,
|
||||
DB: r.DB,
|
||||
}
|
||||
applyCommonToClient(opt, cfg)
|
||||
return redis.NewClient(opt)
|
||||
|
||||
case "sentinel":
|
||||
if strings.TrimSpace(r.MasterName) == "" {
|
||||
panic("cache: redis sentinel requires master_name")
|
||||
}
|
||||
if len(r.Addrs) == 0 {
|
||||
panic("cache: redis sentinel requires addrs (sentinel 节点列表)")
|
||||
}
|
||||
opt := &redis.FailoverOptions{
|
||||
MasterName: r.MasterName,
|
||||
SentinelAddrs: r.Addrs,
|
||||
DB: r.DB,
|
||||
}
|
||||
applyCommonToFailover(opt, cfg)
|
||||
return redis.NewFailoverClient(opt)
|
||||
|
||||
case "cluster":
|
||||
if len(r.Addrs) == 0 {
|
||||
panic("cache: redis cluster requires addrs")
|
||||
}
|
||||
opt := &redis.ClusterOptions{
|
||||
Addrs: r.Addrs,
|
||||
}
|
||||
applyCommonToCluster(opt, cfg)
|
||||
return redis.NewClusterClient(opt)
|
||||
|
||||
default:
|
||||
panic(fmt.Sprintf("cache: unsupported redis mode %q", r.Mode))
|
||||
}
|
||||
}
|
||||
|
||||
func standaloneAddr(cfg *config.Config) (string, error) {
|
||||
r := cfg.Data.Redis
|
||||
if strings.TrimSpace(r.Addr) != "" {
|
||||
return r.Addr, nil
|
||||
}
|
||||
if len(r.Addrs) > 0 && strings.TrimSpace(r.Addrs[0]) != "" {
|
||||
return r.Addrs[0], nil
|
||||
}
|
||||
return "", errors.New("cache: redis standalone requires addr or addrs[0]")
|
||||
}
|
||||
|
||||
// Ping 用于启动时探测连接是否可用。
|
||||
func Ping(ctx context.Context, c redis.UniversalClient) error {
|
||||
if c == nil {
|
||||
return errors.New("cache: redis client is nil")
|
||||
}
|
||||
return c.Ping(ctx).Err()
|
||||
}
|
||||
|
||||
func applyCommonToClient(opt *redis.Options, cfg *config.Config) {
|
||||
r := cfg.Data.Redis
|
||||
opt.Username = r.Username
|
||||
opt.Password = r.Password
|
||||
if r.PoolSize > 0 {
|
||||
opt.PoolSize = r.PoolSize
|
||||
}
|
||||
if r.MinIdleConns > 0 {
|
||||
opt.MinIdleConns = r.MinIdleConns
|
||||
}
|
||||
if r.MaxRetries != 0 {
|
||||
opt.MaxRetries = r.MaxRetries
|
||||
}
|
||||
if r.RetryDelay > 0 {
|
||||
opt.MinRetryBackoff = r.RetryDelay
|
||||
}
|
||||
if r.RetryMaxDelay > 0 {
|
||||
opt.MaxRetryBackoff = r.RetryMaxDelay
|
||||
}
|
||||
if r.DialTimeout > 0 {
|
||||
opt.DialTimeout = r.DialTimeout
|
||||
}
|
||||
if r.ReadTimeout > 0 {
|
||||
opt.ReadTimeout = r.ReadTimeout
|
||||
}
|
||||
if r.WriteTimeout > 0 {
|
||||
opt.WriteTimeout = r.WriteTimeout
|
||||
}
|
||||
if r.IdleTimeout > 0 {
|
||||
opt.ConnMaxIdleTime = r.IdleTimeout
|
||||
}
|
||||
}
|
||||
|
||||
func applyCommonToFailover(opt *redis.FailoverOptions, cfg *config.Config) {
|
||||
r := cfg.Data.Redis
|
||||
opt.Username = r.Username
|
||||
opt.Password = r.Password
|
||||
if r.PoolSize > 0 {
|
||||
opt.PoolSize = r.PoolSize
|
||||
}
|
||||
if r.MinIdleConns > 0 {
|
||||
opt.MinIdleConns = r.MinIdleConns
|
||||
}
|
||||
if r.MaxRetries != 0 {
|
||||
opt.MaxRetries = r.MaxRetries
|
||||
}
|
||||
if r.RetryDelay > 0 {
|
||||
opt.MinRetryBackoff = r.RetryDelay
|
||||
}
|
||||
if r.RetryMaxDelay > 0 {
|
||||
opt.MaxRetryBackoff = r.RetryMaxDelay
|
||||
}
|
||||
if r.DialTimeout > 0 {
|
||||
opt.DialTimeout = r.DialTimeout
|
||||
}
|
||||
if r.ReadTimeout > 0 {
|
||||
opt.ReadTimeout = r.ReadTimeout
|
||||
}
|
||||
if r.WriteTimeout > 0 {
|
||||
opt.WriteTimeout = r.WriteTimeout
|
||||
}
|
||||
if r.IdleTimeout > 0 {
|
||||
opt.ConnMaxIdleTime = r.IdleTimeout
|
||||
}
|
||||
}
|
||||
|
||||
func applyCommonToCluster(opt *redis.ClusterOptions, cfg *config.Config) {
|
||||
r := cfg.Data.Redis
|
||||
opt.Username = r.Username
|
||||
opt.Password = r.Password
|
||||
if r.PoolSize > 0 {
|
||||
opt.PoolSize = r.PoolSize
|
||||
}
|
||||
if r.MinIdleConns > 0 {
|
||||
opt.MinIdleConns = r.MinIdleConns
|
||||
}
|
||||
if r.MaxRetries != 0 {
|
||||
opt.MaxRetries = r.MaxRetries
|
||||
}
|
||||
if r.RetryDelay > 0 {
|
||||
opt.MinRetryBackoff = r.RetryDelay
|
||||
}
|
||||
if r.RetryMaxDelay > 0 {
|
||||
opt.MaxRetryBackoff = r.RetryMaxDelay
|
||||
}
|
||||
if r.DialTimeout > 0 {
|
||||
opt.DialTimeout = r.DialTimeout
|
||||
}
|
||||
if r.ReadTimeout > 0 {
|
||||
opt.ReadTimeout = r.ReadTimeout
|
||||
}
|
||||
if r.WriteTimeout > 0 {
|
||||
opt.WriteTimeout = r.WriteTimeout
|
||||
}
|
||||
if r.IdleTimeout > 0 {
|
||||
opt.ConnMaxIdleTime = r.IdleTimeout
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Server struct {
|
||||
Http struct {
|
||||
Addr string `mapstructure:"addr"`
|
||||
Timeout time.Duration `mapstructure:"timeout"`
|
||||
} `mapstructure:"http"`
|
||||
Grpc struct {
|
||||
Addr string `mapstructure:"addr"`
|
||||
Timeout time.Duration `mapstructure:"timeout"`
|
||||
} `mapstructure:"grpc"`
|
||||
} `mapstructure:"server"`
|
||||
Data struct {
|
||||
Database struct {
|
||||
Driver string `mapstructure:"driver"`
|
||||
DSN string `mapstructure:"dsn"`
|
||||
} `mapstructure:"database"`
|
||||
Redis struct {
|
||||
// Mode: standalone(单机)、sentinel(哨兵)、cluster(集群)
|
||||
Mode string `mapstructure:"mode"`
|
||||
Addr string `mapstructure:"addr"`
|
||||
Addrs []string `mapstructure:"addrs"`
|
||||
Password string `mapstructure:"password"`
|
||||
Username string `mapstructure:"username"`
|
||||
DB int `mapstructure:"db"`
|
||||
MasterName string `mapstructure:"master_name"`
|
||||
PoolSize int `mapstructure:"pool_size"`
|
||||
// MinIdleConns 最小空闲连接数
|
||||
MinIdleConns int `mapstructure:"min_idle_conns"`
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
RetryDelay time.Duration `mapstructure:"retry_delay"`
|
||||
RetryMaxDelay time.Duration `mapstructure:"retry_max_delay"`
|
||||
DialTimeout time.Duration `mapstructure:"dial_timeout"`
|
||||
ReadTimeout time.Duration `mapstructure:"read_timeout"`
|
||||
WriteTimeout time.Duration `mapstructure:"write_timeout"`
|
||||
// IdleTimeout 映射为 go-redis ConnMaxIdleTime
|
||||
IdleTimeout time.Duration `mapstructure:"idle_timeout"`
|
||||
} `mapstructure:"redis"`
|
||||
} `mapstructure:"data"`
|
||||
// Auth 认证域(OAuth2、会话等);PublicBaseURL 为浏览器可访问的后端根 URL(用于登录回跳拼接 /oauth/authorize)
|
||||
Auth struct {
|
||||
PublicBaseURL string `mapstructure:"public_base_url"`
|
||||
OAuth2 struct {
|
||||
FrontendLoginURL string `mapstructure:"frontend_login_url"`
|
||||
AuthCodeTTL time.Duration `mapstructure:"auth_code_ttl"`
|
||||
AccessTokenTTL time.Duration `mapstructure:"access_token_ttl"`
|
||||
RefreshTokenTTL time.Duration `mapstructure:"refresh_token_ttl"`
|
||||
} `mapstructure:"oauth2"`
|
||||
Session struct {
|
||||
CookieName string `mapstructure:"cookie_name"`
|
||||
CookieDomain string `mapstructure:"cookie_domain"`
|
||||
CookieSecure bool `mapstructure:"cookie_secure"`
|
||||
SameSite string `mapstructure:"same_site"` // lax, strict, none
|
||||
TTL time.Duration `mapstructure:"ttl"`
|
||||
} `mapstructure:"session"`
|
||||
// RateLimit 登录与令牌端点限流(进程内按 IP;多实例需网关或 Redis 限流)
|
||||
RateLimit struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
LoginPerMinute int `mapstructure:"login_per_minute"`
|
||||
TokenPerMinute int `mapstructure:"token_per_minute"`
|
||||
} `mapstructure:"rate_limit"`
|
||||
} `mapstructure:"auth"`
|
||||
}
|
||||
|
||||
// 加载配置文件
|
||||
func Load(path string) (*Config, error) {
|
||||
v := viper.New()
|
||||
v.SetConfigFile(path)
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var config Config
|
||||
if err := v.Unmarshal(&config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &config, nil
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"giter.top/smart/pkg/config"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func NewDB(cfg *config.Config) *gorm.DB {
|
||||
driver := cfg.Data.Database.Driver
|
||||
var db *gorm.DB
|
||||
var err error
|
||||
switch driver {
|
||||
case "mysql":
|
||||
// db, err = NewMySQLDB(cfg)
|
||||
case "postgres":
|
||||
db, err = NewPgSQLDB(cfg)
|
||||
case "sqlite":
|
||||
// return NewSQLiteDB(cfg)
|
||||
default:
|
||||
panic("unsupported driver")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
package db
|
||||
@@ -0,0 +1,18 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"giter.top/smart/pkg/config"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// pg sql 数据库连接
|
||||
func NewPgSQLDB(cfg *config.Config) (*gorm.DB , error) {
|
||||
db, err := gorm.Open(postgres.Open(cfg.Data.Database.DSN), &gorm.Config{})
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to connect to postgres database")
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
)
|
||||
|
||||
// RandomURLSafe 生成 URL-safe 随机串(用于 opaque token、authorization code 等)。
|
||||
func RandomURLSafe(nBytes int) (string, error) {
|
||||
b := make([]byte, nBytes)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
package codec
|
||||
|
||||
import "golang.org/x/crypto/bcrypt"
|
||||
|
||||
// HashPassword 将明文密码生成为 bcrypt 哈希字符串(与业务中 bcrypt.DefaultCost 一致)。
|
||||
func HashPassword(password string) (string, error) {
|
||||
b, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// VerifyPassword 校验明文是否与 bcrypt 哈希匹配。
|
||||
// password 明文密码
|
||||
// hashedPassword 哈希密码
|
||||
func VerifyPassword(password,hashedPassword string) error {
|
||||
return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package id
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
generator IDGenerator
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
type IDGenerator interface {
|
||||
generate() string
|
||||
}
|
||||
|
||||
func New() string {
|
||||
once.Do(func() {
|
||||
generator = NewUUIDGenerator()
|
||||
})
|
||||
return generator.generate()
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package id
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
type UUIDGenerator struct {
|
||||
}
|
||||
|
||||
func NewUUIDGenerator() IDGenerator {
|
||||
return &UUIDGenerator{}
|
||||
}
|
||||
|
||||
func (g *UUIDGenerator) generate() string {
|
||||
id, _ := uuid.NewV7()
|
||||
return id.String()
|
||||
}
|
||||
Reference in New Issue
Block a user