// Source: exchange_go/services/binanceservice/lock_manager_optimized.go
// 2025-08-27 14:54:03 +08:00 · 575 lines · 13 KiB · Go

package binanceservice
import (
	"context"
	"crypto/rand"
	"fmt"
	"sync"
	"time"

	"github.com/go-admin-team/go-admin-core/logger"
	"github.com/go-redis/redis/v8"
)
// LockManagerConfig describes a single lock request handled by LockManager.
type LockManagerConfig struct {
// Type selects the lock implementation. CreateLock uses a distributed lock
// for every value other than LockTypeLocal.
Type LockType `json:"type"`
// Scope and Key are combined by buildLockKey into "lock:<scope>:<key>".
Scope LockScope `json:"scope"`
Key string `json:"key"`
// Expiration is copied into the created lock. Distributed locks pass it to
// Redis as the key TTL; local locks are force-released by
// CleanupExpiredLocks once held longer than this. An already registered
// local lock keeps the expiration it was first created with.
Expiration time.Duration `json:"expiration"`
// Timeout bounds AcquireLock; it is applied only when greater than zero.
Timeout time.Duration `json:"timeout"`
// RetryDelay is how long AcquireLock waits between failed attempts.
RetryDelay time.Duration `json:"retry_delay"`
}
// LockType names the lock implementation to use.
type LockType string

// Lock implementations understood by CreateLock.
const (
	// LockTypeLocal selects the in-process LocalLock.
	LockTypeLocal = LockType("local")
	// LockTypeDistributed selects the Redis-backed DistributedLock.
	LockTypeDistributed = LockType("distributed")
)
// LockScope is the namespace segment of a lock key; buildLockKey places it
// between the "lock:" prefix and the caller-supplied key.
type LockScope string

// Scopes used by the predefined lock configurations.
const (
	ScopeOrder    = LockScope("order")
	ScopePosition = LockScope("position")
	ScopeUser     = LockScope("user")
	ScopeSymbol   = LockScope("symbol")
	ScopeGlobal   = LockScope("global")
)
// Lock is the common interface of LocalLock and DistributedLock.
type Lock interface {
// Lock makes one acquisition attempt. Both implementations in this file
// return an error immediately when the lock is already held instead of
// blocking; AcquireLock supplies the retry loop.
Lock(ctx context.Context) error
// Unlock releases the lock and returns an error if it is not marked as held.
Unlock(ctx context.Context) error
// TryLock makes one acquisition attempt and reports contention as
// (false, nil) rather than as an error.
TryLock(ctx context.Context) (bool, error)
// IsLocked reports whether this lock object is currently marked as held.
IsLocked() bool
// GetKey returns the full lock key.
GetKey() string
// GetExpiration returns the configured expiration.
GetExpiration() time.Duration
}
// LocalLock is an in-process lock identified by key. One instance per key is
// registered in LockManager.localLocks and shared by all callers.
type LocalLock struct {
key string
// mu guards locked and acquiredAt in the LocalLock methods.
mu *sync.RWMutex
locked bool
// expiration is the maximum hold time enforced by CleanupExpiredLocks;
// values <= 0 are never force-released.
expiration time.Duration
// acquiredAt is set each time the lock is taken.
acquiredAt time.Time
}
// DistributedLock is a Redis-backed lock taken with SETNX and released or
// renewed through Lua scripts that first compare the stored value.
type DistributedLock struct {
key string
// value is the owner token written to Redis; the release and renewal
// scripts act only when the stored value still equals it.
value string
// expiration is the TTL passed to SETNX and re-applied on renewal.
expiration time.Duration
redisClient *redis.Client
// locked and acquiredAt are plain fields with no mutex; locked is also
// written by the renewal goroutine started in startRefresh.
locked bool
acquiredAt time.Time
// refreshTicker drives periodic renewal; it is nil when no renewal has
// been started or after stopRefreshProcess ran.
refreshTicker *time.Ticker
// stopRefresh is closed by stopRefreshProcess to end the renewal
// goroutine, then replaced with a fresh channel.
stopRefresh chan struct{}
}
// LockManager creates local and distributed locks and provides acquire,
// release and run-under-lock helpers on top of them.
type LockManager struct {
// config is stored by NewLockManager; no code visible in this file reads it.
config *OptimizedConfig
// redisClient is handed to every DistributedLock the manager creates.
redisClient *redis.Client
// localLocks maps the full lock key to its shared LocalLock.
localLocks map[string]*LocalLock
// localMutex guards localLocks.
localMutex sync.RWMutex
// metrics may be nil; AcquireLock records outcomes only when it is set.
metrics *MetricsCollector
}
// LockConfig is already defined in config_optimized.go.

// NewLockManager builds a LockManager with an empty local-lock registry.
// The arguments are stored as given; metrics may be nil.
func NewLockManager(config *OptimizedConfig, redisClient *redis.Client, metrics *MetricsCollector) *LockManager {
	lm := new(LockManager)
	lm.config = config
	lm.redisClient = redisClient
	lm.metrics = metrics
	lm.localLocks = map[string]*LocalLock{}
	return lm
}
// CreateLock returns a lock object for config without acquiring it.
// LockTypeLocal yields the shared per-key LocalLock; every other type,
// including unknown values, yields a new DistributedLock.
func (lm *LockManager) CreateLock(config LockManagerConfig) Lock {
	if config.Type == LockTypeLocal {
		return lm.createLocalLock(config)
	}
	// Distributed is both the explicit choice and the fallback.
	return lm.createDistributedLock(config)
}
// createLocalLock returns the LocalLock registered for config's key, creating
// and registering it on first use. An existing lock is returned unchanged, so
// config.Expiration only takes effect for the first caller of a given key.
func (lm *LockManager) createLocalLock(config LockManagerConfig) *LocalLock {
	fullKey := lm.buildLockKey(config.Scope, config.Key)

	lm.localMutex.Lock()
	defer lm.localMutex.Unlock()

	existing, found := lm.localLocks[fullKey]
	if found {
		return existing
	}

	created := &LocalLock{
		key:        fullKey,
		mu:         new(sync.RWMutex),
		expiration: config.Expiration,
	}
	lm.localLocks[fullKey] = created
	return created
}
// createDistributedLock returns a new, unacquired DistributedLock for config.
// Each call generates a fresh owner token from the current UnixNano time and
// an 8-character random suffix.
func (lm *LockManager) createDistributedLock(config LockManagerConfig) *DistributedLock {
	dl := new(DistributedLock)
	dl.key = lm.buildLockKey(config.Scope, config.Key)
	dl.value = fmt.Sprintf("%d_%s", time.Now().UnixNano(), generateRandomString(8))
	dl.expiration = config.Expiration
	dl.redisClient = lm.redisClient
	dl.stopRefresh = make(chan struct{})
	return dl
}
// buildLockKey returns the full lock key in the form "lock:<scope>:<key>".
func (lm *LockManager) buildLockKey(scope LockScope, key string) string {
	const prefix = "lock:"
	return prefix + string(scope) + ":" + key
}
// AcquireLock creates the lock described by config and retries acquiring it
// until it succeeds or the context ends.
//
// When config.Timeout > 0 the attempts are bounded by that timeout in
// addition to ctx; otherwise only ctx limits the wait. Between failed
// attempts the call sleeps for config.RetryDelay.
//
// On success the acquired lock is returned and, if metrics are configured, a
// successful operation with the total wait time is recorded. On failure a
// failed operation is recorded on every exit path, and the returned error
// wraps the context error (so errors.Is works for context.DeadlineExceeded
// and context.Canceled) and includes the last attempt error, if any, which
// helps tell contention apart from a Redis failure.
func (lm *LockManager) AcquireLock(ctx context.Context, config LockManagerConfig) (Lock, error) {
	lock := lm.CreateLock(config)
	startTime := time.Now()

	// Derive a bounded context only when a timeout was requested.
	lockCtx := ctx
	if config.Timeout > 0 {
		var cancel context.CancelFunc
		lockCtx, cancel = context.WithTimeout(ctx, config.Timeout)
		defer cancel()
	}

	// fail records the failed attempt and builds the error for both
	// give-up points in the loop below.
	fail := func(lastErr error) (Lock, error) {
		if lm.metrics != nil {
			lm.metrics.RecordLockOperation(false, time.Since(startTime))
		}
		if lastErr != nil {
			return nil, fmt.Errorf("获取锁超时: %s: %w (最后一次错误: %v)", config.Key, lockCtx.Err(), lastErr)
		}
		return nil, fmt.Errorf("获取锁超时: %s: %w", config.Key, lockCtx.Err())
	}

	var lastErr error
	for {
		// Give up before attempting if the context has already ended.
		if lockCtx.Err() != nil {
			return fail(lastErr)
		}

		if lastErr = lock.Lock(lockCtx); lastErr == nil {
			if lm.metrics != nil {
				lm.metrics.RecordLockOperation(true, time.Since(startTime))
			}
			return lock, nil
		}

		// Wait for the retry delay, or stop early if the context ends.
		select {
		case <-lockCtx.Done():
			return fail(lastErr)
		case <-time.After(config.RetryDelay):
		}
	}
}
// ReleaseLock unlocks lock and returns the result of its Unlock method.
// A nil lock is a no-op that returns nil.
func (lm *LockManager) ReleaseLock(ctx context.Context, lock Lock) error {
	if lock != nil {
		return lock.Unlock(ctx)
	}
	return nil
}
// WithLock acquires the lock described by config, runs operation, and
// releases the lock afterwards, also when operation panics.
//
// It returns the wrapped acquisition error if the lock cannot be obtained,
// otherwise the error returned by operation. A failure to release is logged
// and not returned.
func (lm *LockManager) WithLock(ctx context.Context, config LockManagerConfig, operation func() error) error {
	lock, err := lm.AcquireLock(ctx, config)
	if err != nil {
		return fmt.Errorf("获取锁失败: %w", err)
	}

	defer func() {
		releaseCtx := ctx
		if ctx.Err() != nil {
			// The caller's context ended while operation ran. Releasing
			// with it would fail for Redis-backed locks and leave the key
			// held until it expires, so release on a short detached context.
			const releaseTimeout = 5 * time.Second
			var cancel context.CancelFunc
			releaseCtx, cancel = context.WithTimeout(context.Background(), releaseTimeout)
			defer cancel()
		}
		if unlockErr := lm.ReleaseLock(releaseCtx, lock); unlockErr != nil {
			logger.Errorf("释放锁失败 [%s]: %v", config.Key, unlockErr)
		}
	}()

	return operation()
}
// CleanupExpiredLocks force-releases every registered local lock that is held
// and has been held longer than its positive expiration, logging a warning
// for each. Entries are never removed from the registry.
//
// Each lock's state is read and changed under that lock's own mutex, so the
// cleanup does not race with concurrent LocalLock.Lock/Unlock/TryLock calls.
func (lm *LockManager) CleanupExpiredLocks() {
	lm.localMutex.Lock()
	defer lm.localMutex.Unlock()

	now := time.Now()
	for key, lock := range lm.localLocks {
		lock.mu.Lock()
		expired := lock.locked && lock.expiration > 0 && now.Sub(lock.acquiredAt) > lock.expiration
		if expired {
			lock.locked = false
		}
		lock.mu.Unlock()

		// Log after dropping the per-lock mutex to keep the critical section short.
		if expired {
			logger.Warnf("本地锁已过期并被清理: %s", key)
		}
	}
}
// GetLockStatus returns a snapshot of the local-lock registry.
//
// The result has two entries: "local_locks", a map from lock key to a map
// with "locked", "acquired_at" and "expiration", and "total_local_locks",
// the number of registered local locks. Distributed locks are not included.
//
// Each lock's fields are read under that lock's read mutex so the snapshot
// does not race with concurrent acquire and release calls.
func (lm *LockManager) GetLockStatus() map[string]interface{} {
	lm.localMutex.RLock()
	defer lm.localMutex.RUnlock()

	localLocks := make(map[string]interface{}, len(lm.localLocks))
	for key, lock := range lm.localLocks {
		lock.mu.RLock()
		entry := map[string]interface{}{
			"locked":      lock.locked,
			"acquired_at": lock.acquiredAt,
			"expiration":  lock.expiration,
		}
		lock.mu.RUnlock()
		localLocks[key] = entry
	}

	return map[string]interface{}{
		"local_locks":       localLocks,
		"total_local_locks": len(lm.localLocks),
	}
}
// LocalLock implementation.

// Lock makes a single attempt to take the local lock. It does not block: if
// the lock is already held it returns an error immediately. ctx is not used.
func (ll *LocalLock) Lock(ctx context.Context) error {
	ll.mu.Lock()
	defer ll.mu.Unlock()

	if !ll.locked {
		ll.acquiredAt = time.Now()
		ll.locked = true
		return nil
	}
	return fmt.Errorf("锁已被占用: %s", ll.key)
}
// Unlock marks the local lock as free. It returns an error if the lock is not
// currently held. ctx is not used.
func (ll *LocalLock) Unlock(ctx context.Context) error {
	ll.mu.Lock()
	defer ll.mu.Unlock()

	if ll.locked {
		ll.locked = false
		return nil
	}
	return fmt.Errorf("锁未被占用: %s", ll.key)
}
// TryLock makes a single attempt to take the local lock and reports whether
// it succeeded. The error is always nil; ctx is not used.
func (ll *LocalLock) TryLock(ctx context.Context) (bool, error) {
	ll.mu.Lock()
	defer ll.mu.Unlock()

	acquired := !ll.locked
	if acquired {
		ll.locked = true
		ll.acquiredAt = time.Now()
	}
	return acquired, nil
}
// IsLocked reports whether the local lock is currently held, reading the flag
// under the read lock.
func (ll *LocalLock) IsLocked() bool {
	ll.mu.RLock()
	held := ll.locked
	ll.mu.RUnlock()
	return held
}
// GetKey returns the full lock key this lock was created with.
func (ll *LocalLock) GetKey() string {
return ll.key
}
// GetExpiration returns the expiration this lock was created with.
func (ll *LocalLock) GetExpiration() time.Duration {
return ll.expiration
}
// DistributedLock implementation.

// Lock makes a single attempt to take the distributed lock by writing the
// owner token with SETNX and the configured expiration. It does not block: a
// Redis error is returned wrapped, and an already existing key yields an
// "occupied" error. On success the lock is marked as held and periodic
// renewal is started.
func (dl *DistributedLock) Lock(ctx context.Context) error {
	acquired, err := dl.redisClient.SetNX(ctx, dl.key, dl.value, dl.expiration).Result()
	switch {
	case err != nil:
		return fmt.Errorf("获取分布式锁失败: %w", err)
	case !acquired:
		return fmt.Errorf("分布式锁已被占用: %s", dl.key)
	}

	dl.locked = true
	dl.acquiredAt = time.Now()
	// Keep the key alive while the lock is held.
	dl.startRefresh()
	return nil
}
// Unlock stops renewal and deletes the Redis key if it still holds this
// lock's owner token. It returns an error when this object is not marked as
// held or when the script call fails; in the latter case the object stays
// marked as held. A script result of 0 (key missing or owned by someone else)
// is not treated as an error.
func (dl *DistributedLock) Unlock(ctx context.Context) error {
	if !dl.locked {
		return fmt.Errorf("分布式锁未被占用: %s", dl.key)
	}

	// Stop renewing before the key is removed.
	dl.stopRefreshProcess()

	// Compare-and-delete so that only the owner's key is removed.
	const releaseScript = `
if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("del", KEYS[1])
else
return 0
end
`
	if err := dl.redisClient.Eval(ctx, releaseScript, []string{dl.key}, dl.value).Err(); err != nil {
		return fmt.Errorf("释放分布式锁失败: %w", err)
	}

	dl.locked = false
	return nil
}
// TryLock makes a single SETNX attempt. It returns (true, nil) and starts
// renewal when the key was set, (false, nil) when the key already exists, and
// (false, err) with the wrapped Redis error when the call fails.
func (dl *DistributedLock) TryLock(ctx context.Context) (bool, error) {
	acquired, err := dl.redisClient.SetNX(ctx, dl.key, dl.value, dl.expiration).Result()
	if err != nil {
		return false, fmt.Errorf("尝试获取分布式锁失败: %w", err)
	}
	if !acquired {
		return false, nil
	}

	dl.locked = true
	dl.acquiredAt = time.Now()
	dl.startRefresh()
	return true, nil
}
// IsLocked reports whether this object is marked as holding the lock. It
// returns the local flag only and does not query Redis.
// NOTE(review): the flag is read without synchronization while the renewal
// goroutine in startRefresh may write it.
func (dl *DistributedLock) IsLocked() bool {
return dl.locked
}
// GetKey returns the Redis key used by this lock.
func (dl *DistributedLock) GetKey() string {
return dl.key
}
// GetExpiration returns the TTL this lock was created with.
func (dl *DistributedLock) GetExpiration() time.Duration {
return dl.expiration
}
// startRefresh launches a goroutine that renews the Redis key every third of
// the expiration for as long as the stored value is still this lock's owner
// token. It does nothing when the expiration is not positive or too small to
// yield a positive interval.
//
// The goroutine ends when stopRefreshProcess closes the stop channel, when
// the lock is no longer marked as held, when a renewal call fails, or when
// the script reports that the key is missing or owned by someone else (in
// which case the lock is also marked as not held).
//
// The ticker and stop channel are captured in locals, because
// stopRefreshProcess sets dl.refreshTicker to nil and replaces
// dl.stopRefresh; re-reading the fields inside the loop could dereference nil
// or miss the stop signal. Renewal uses PEXPIRE with milliseconds so that
// sub-second expirations are not truncated to 0, which would delete the key.
func (dl *DistributedLock) startRefresh() {
	if dl.expiration <= 0 {
		return
	}

	// Renew at one third of the expiration.
	refreshInterval := dl.expiration / 3
	if refreshInterval <= 0 {
		// time.NewTicker panics on a non-positive interval.
		return
	}

	ttlMillis := dl.expiration.Milliseconds()
	if ttlMillis < 1 {
		ttlMillis = 1
	}

	ticker := time.NewTicker(refreshInterval)
	stop := dl.stopRefresh
	dl.refreshTicker = ticker

	// Compare-and-renew so that only the owner's key is extended.
	const refreshScript = `
	if redis.call("get", KEYS[1]) == ARGV[1] then
		return redis.call("pexpire", KEYS[1], ARGV[2])
	else
		return 0
	end
	`

	go func() {
		defer ticker.Stop()
		for {
			select {
			case <-stop:
				return
			case <-ticker.C:
				if !dl.locked {
					return
				}

				result := dl.redisClient.Eval(context.Background(), refreshScript, []string{dl.key}, dl.value, ttlMillis)
				if err := result.Err(); err != nil {
					logger.Errorf("分布式锁续期失败 [%s]: %v", dl.key, err)
					return
				}

				renewed, ok := result.Val().(int64)
				if !ok {
					// Unexpected reply type: stop renewing instead of panicking.
					logger.Errorf("分布式锁续期返回了意外结果 [%s]: %v", dl.key, result.Val())
					return
				}
				if renewed == 0 {
					logger.Warnf("分布式锁已被其他进程占用,停止续期 [%s]", dl.key)
					dl.locked = false
					return
				}
			}
		}
	}()
}
// stopRefreshProcess signals the renewal goroutine to stop by closing the
// stop channel, stops the ticker, and then resets both fields so that renewal
// can be started again. It is a no-op when no ticker is set.
func (dl *DistributedLock) stopRefreshProcess() {
	if dl.refreshTicker == nil {
		return
	}
	close(dl.stopRefresh)
	dl.refreshTicker.Stop()
	dl.refreshTicker, dl.stopRefresh = nil, make(chan struct{})
}
// Helper functions.

// generateRandomString returns a string of the given length made of
// characters from [a-zA-Z0-9]. A length <= 0 yields the empty string.
//
// Bytes are drawn from crypto/rand. Deriving every character from consecutive
// clock reads, as before, produces highly correlated characters and weak lock
// tokens. Mapping a byte onto 62 characters by modulo carries a small bias,
// which is acceptable for an owner token.
func generateRandomString(length int) string {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	if length <= 0 {
		return ""
	}

	b := make([]byte, length)
	if _, err := rand.Read(b); err != nil {
		// Best-effort fallback so callers still receive a token when the
		// system randomness source is unavailable.
		for i := range b {
			b[i] = byte(time.Now().UnixNano())
		}
	}

	for i, v := range b {
		b[i] = charset[int(v)%len(charset)]
	}
	return string(b)
}
// Predefined lock configurations.

// GetOrderLockConfig returns the configuration for a per-order lock: a
// distributed lock keyed by orderID with a 30s expiration, a 5s acquisition
// timeout and a 100ms retry delay.
func GetOrderLockConfig(orderID string) LockManagerConfig {
	cfg := LockManagerConfig{Type: LockTypeDistributed, Scope: ScopeOrder, Key: orderID}
	cfg.Expiration = 30 * time.Second
	cfg.Timeout = 5 * time.Second
	cfg.RetryDelay = 100 * time.Millisecond
	return cfg
}
// GetPositionLockConfig returns the configuration for a per-position lock: a
// distributed lock keyed by "<userID>_<symbol>" with a 60s expiration, a 10s
// acquisition timeout and a 200ms retry delay.
func GetPositionLockConfig(userID, symbol string) LockManagerConfig {
	cfg := LockManagerConfig{Type: LockTypeDistributed, Scope: ScopePosition}
	cfg.Key = userID + "_" + symbol
	cfg.Expiration = 60 * time.Second
	cfg.Timeout = 10 * time.Second
	cfg.RetryDelay = 200 * time.Millisecond
	return cfg
}
// GetUserLockConfig returns the configuration for a per-user lock: a local
// lock keyed by userID with a 30s expiration, a 3s acquisition timeout and a
// 50ms retry delay.
func GetUserLockConfig(userID string) LockManagerConfig {
	cfg := LockManagerConfig{Type: LockTypeLocal, Scope: ScopeUser, Key: userID}
	cfg.Expiration = 30 * time.Second
	cfg.Timeout = 3 * time.Second
	cfg.RetryDelay = 50 * time.Millisecond
	return cfg
}
// GetSymbolLockConfig returns the configuration for a per-symbol lock: a
// local lock keyed by symbol with a 15s expiration, a 2s acquisition timeout
// and a 25ms retry delay.
func GetSymbolLockConfig(symbol string) LockManagerConfig {
	cfg := LockManagerConfig{Type: LockTypeLocal, Scope: ScopeSymbol, Key: symbol}
	cfg.Expiration = 15 * time.Second
	cfg.Timeout = 2 * time.Second
	cfg.RetryDelay = 25 * time.Millisecond
	return cfg
}
// GetGlobalLockConfig returns the configuration for a global lock: a
// distributed lock keyed by operation with a 120s expiration, a 30s
// acquisition timeout and a 500ms retry delay.
func GetGlobalLockConfig(operation string) LockManagerConfig {
	cfg := LockManagerConfig{Type: LockTypeDistributed, Scope: ScopeGlobal, Key: operation}
	cfg.Expiration = 120 * time.Second
	cfg.Timeout = 30 * time.Second
	cfg.RetryDelay = 500 * time.Millisecond
	return cfg
}
// GlobalLockManager is the package-wide lock manager. It is nil until
// InitLockManager is called; the With*Lock helpers run the operation without
// any lock while it is nil.
var GlobalLockManager *LockManager
// InitLockManager creates the global lock manager and starts a background
// goroutine that calls CleanupExpiredLocks on it once per minute. The
// goroutine has no stop mechanism, and each call to InitLockManager starts
// another one.
func InitLockManager(config *OptimizedConfig, redisClient *redis.Client, metrics *MetricsCollector) {
	GlobalLockManager = NewLockManager(config, redisClient, metrics)

	// Periodically force-release expired local locks.
	go func() {
		cleanupTicker := time.NewTicker(time.Minute)
		defer cleanupTicker.Stop()
		for {
			<-cleanupTicker.C
			GlobalLockManager.CleanupExpiredLocks()
		}
	}()
}
// GetLockManager returns the global lock manager, which is nil until
// InitLockManager has been called.
func GetLockManager() *LockManager {
return GlobalLockManager
}
// WithOrderLock runs operation while holding the order lock for orderID.
// If the global lock manager is not initialized, operation runs without a lock.
func WithOrderLock(ctx context.Context, orderID string, operation func() error) error {
	if GlobalLockManager != nil {
		return GlobalLockManager.WithLock(ctx, GetOrderLockConfig(orderID), operation)
	}
	return operation()
}
// WithPositionLock runs operation while holding the position lock for the
// given user and symbol. If the global lock manager is not initialized,
// operation runs without a lock.
func WithPositionLock(ctx context.Context, userID, symbol string, operation func() error) error {
	if GlobalLockManager != nil {
		return GlobalLockManager.WithLock(ctx, GetPositionLockConfig(userID, symbol), operation)
	}
	return operation()
}
// WithUserLock runs operation while holding the user lock for userID.
// If the global lock manager is not initialized, operation runs without a lock.
func WithUserLock(ctx context.Context, userID string, operation func() error) error {
	if GlobalLockManager != nil {
		return GlobalLockManager.WithLock(ctx, GetUserLockConfig(userID), operation)
	}
	return operation()
}