325 lines
7.5 KiB
Go
325 lines
7.5 KiB
Go
package db
|
|
|
|
import (
	"context"
	"fmt"
	"strings"
	"sync"
	"time"

	"gorm.io/driver/mysql"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)
|
|
|
|
// Timing defaults installed by NewConnectionManager.
const (
	// defaultConnectTimeout bounds each ping issued while connecting or
	// checking liveness.
	defaultConnectTimeout = 5 * time.Second
	// defaultPingInterval is how long a successful check is trusted before
	// the connection is verified again.
	defaultPingInterval = 30 * time.Second
	// defaultReconnectCooldown is how long GetDB waits after a failed attempt
	// before trying to connect again.
	defaultReconnectCooldown = 10 * time.Second
)

// Pool limits applied by connect to every newly opened *sql.DB.
const (
	maxOpenConns    = 10
	maxIdleConns    = 2
	connMaxLifetime = 5 * time.Minute
)
|
|
|
|
// ConnectionStatus is a point-in-time snapshot of a ConnectionManager's
// connection state, as produced by GetStatus.
type ConnectionStatus struct {
	IsConnected bool      // true while the manager holds a database handle
	LastCheck   time.Time // when a connection was last attempted or verified; zero value if never
	LastError   string    // empty if no error
	DSNHost     string    // host:port for display (without password!)
}
|
|
|
|
// ConnectionManager lazily opens and caches a gorm MySQL connection pool for a
// single DSN. Mutable state is guarded by mu. The exported methods acquire mu
// themselves, so they may be called from multiple goroutines.
type ConnectionManager struct {
	dsn     string // DSN passed to mysql.Open; presumably includes credentials, so never display it
	dsnHost string // optional display value for GetStatus; when empty, GetStatus derives it from dsn

	mu                sync.RWMutex    // protects db and state
	db                *gorm.DB        // current connection (nil if not connected)
	lastError         error           // last connection error
	lastCheck         time.Time       // time of last check/attempt
	connectTimeout    time.Duration   // timeout for connection (default: 5s)
	pingInterval      time.Duration   // minimum interval between pings (default: 30s)
	reconnectCooldown time.Duration   // pause after failed attempt (default: 10s)
}
|
|
|
|
// NewConnectionManager creates a new ConnectionManager instance
|
|
func NewConnectionManager(dsn string, dsnHost string) *ConnectionManager {
|
|
return &ConnectionManager{
|
|
dsn: dsn,
|
|
dsnHost: dsnHost,
|
|
connectTimeout: defaultConnectTimeout,
|
|
pingInterval: defaultPingInterval,
|
|
reconnectCooldown: defaultReconnectCooldown,
|
|
db: nil,
|
|
lastError: nil,
|
|
lastCheck: time.Time{},
|
|
}
|
|
}
|
|
|
|
// GetDB returns the current database connection, establishing it if needed
|
|
// Thread-safe and respects connection cooldowns
|
|
func (cm *ConnectionManager) GetDB() (*gorm.DB, error) {
|
|
if cm.dsn == "" {
|
|
return nil, fmt.Errorf("database DSN is empty")
|
|
}
|
|
|
|
// First check if we already have a valid connection
|
|
cm.mu.RLock()
|
|
if cm.db != nil {
|
|
// Check if connection is still valid and within ping interval
|
|
if time.Since(cm.lastCheck) < cm.pingInterval {
|
|
cm.mu.RUnlock()
|
|
return cm.db, nil
|
|
}
|
|
}
|
|
cm.mu.RUnlock()
|
|
|
|
// Upgrade to write lock
|
|
cm.mu.Lock()
|
|
defer cm.mu.Unlock()
|
|
|
|
// Double-check: someone else might have connected while we were waiting for the write lock
|
|
if cm.db != nil {
|
|
// Check if connection is still valid and within ping interval
|
|
if time.Since(cm.lastCheck) < cm.pingInterval {
|
|
return cm.db, nil
|
|
}
|
|
}
|
|
|
|
// Check if we're in cooldown period after a failed attempt
|
|
if cm.lastError != nil && time.Since(cm.lastCheck) < cm.reconnectCooldown {
|
|
return nil, cm.lastError
|
|
}
|
|
|
|
// Attempt to connect
|
|
err := cm.connect()
|
|
if err != nil {
|
|
// Drop stale handle so callers don't treat it as an active connection.
|
|
cm.db = nil
|
|
cm.lastError = err
|
|
cm.lastCheck = time.Now()
|
|
return nil, err
|
|
}
|
|
|
|
// Update last check time and return success
|
|
cm.lastCheck = time.Now()
|
|
cm.lastError = nil
|
|
return cm.db, nil
|
|
}
|
|
|
|
// connect establishes a new database connection
|
|
func (cm *ConnectionManager) connect() error {
|
|
// Create context with timeout
|
|
ctx, cancel := context.WithTimeout(context.Background(), cm.connectTimeout)
|
|
defer cancel()
|
|
|
|
// Open database connection
|
|
db, err := gorm.Open(mysql.Open(cm.dsn), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("opening database connection: %w", err)
|
|
}
|
|
|
|
// Test the connection
|
|
sqlDB, err := db.DB()
|
|
if err != nil {
|
|
return fmt.Errorf("getting sql.DB: %w", err)
|
|
}
|
|
|
|
// Ping with timeout
|
|
if err = sqlDB.PingContext(ctx); err != nil {
|
|
return fmt.Errorf("pinging database: %w", err)
|
|
}
|
|
|
|
// Set connection pool settings
|
|
sqlDB.SetMaxOpenConns(maxOpenConns)
|
|
sqlDB.SetMaxIdleConns(maxIdleConns)
|
|
sqlDB.SetConnMaxLifetime(connMaxLifetime)
|
|
|
|
// Store the connection
|
|
cm.db = db
|
|
|
|
return nil
|
|
}
|
|
|
|
// IsOnline checks if the database is currently connected and responsive.
|
|
// If disconnected, it tries to reconnect (respecting cooldowns in GetDB).
|
|
func (cm *ConnectionManager) IsOnline() bool {
|
|
cm.mu.RLock()
|
|
isDisconnected := cm.db == nil
|
|
lastErr := cm.lastError
|
|
checkedRecently := time.Since(cm.lastCheck) < cm.pingInterval
|
|
cm.mu.RUnlock()
|
|
|
|
// Try reconnect in disconnected state.
|
|
if isDisconnected {
|
|
_, err := cm.GetDB()
|
|
return err == nil
|
|
}
|
|
|
|
// If we've checked recently, return cached result.
|
|
if checkedRecently {
|
|
return lastErr == nil
|
|
}
|
|
|
|
// Need to perform actual ping.
|
|
cm.mu.Lock()
|
|
defer cm.mu.Unlock()
|
|
|
|
// Double-check after acquiring write lock
|
|
if cm.db == nil {
|
|
return false
|
|
}
|
|
|
|
// Perform ping with timeout
|
|
ctx, cancel := context.WithTimeout(context.Background(), cm.connectTimeout)
|
|
defer cancel()
|
|
|
|
sqlDB, err := cm.db.DB()
|
|
if err != nil {
|
|
cm.lastError = err
|
|
cm.lastCheck = time.Now()
|
|
cm.db = nil
|
|
return false
|
|
}
|
|
|
|
if err = sqlDB.PingContext(ctx); err != nil {
|
|
cm.lastError = err
|
|
cm.lastCheck = time.Now()
|
|
cm.db = nil
|
|
return false
|
|
}
|
|
|
|
// Update last check time and return success
|
|
cm.lastCheck = time.Now()
|
|
cm.lastError = nil
|
|
return true
|
|
}
|
|
|
|
// TryConnect forces a new connection attempt (for UI "Reconnect" button)
|
|
// Ignores cooldown period
|
|
func (cm *ConnectionManager) TryConnect() error {
|
|
cm.mu.Lock()
|
|
defer cm.mu.Unlock()
|
|
|
|
// Attempt to connect
|
|
err := cm.connect()
|
|
if err != nil {
|
|
cm.lastError = err
|
|
cm.lastCheck = time.Now()
|
|
return err
|
|
}
|
|
|
|
// Update last check time and clear error
|
|
cm.lastCheck = time.Now()
|
|
cm.lastError = nil
|
|
return nil
|
|
}
|
|
|
|
// Disconnect closes the current database connection
|
|
func (cm *ConnectionManager) Disconnect() {
|
|
cm.mu.Lock()
|
|
defer cm.mu.Unlock()
|
|
|
|
if cm.db != nil {
|
|
sqlDB, err := cm.db.DB()
|
|
if err == nil {
|
|
sqlDB.Close()
|
|
}
|
|
}
|
|
cm.db = nil
|
|
cm.lastError = nil
|
|
}
|
|
|
|
// GetLastError returns the last connection error (thread-safe)
|
|
func (cm *ConnectionManager) GetLastError() error {
|
|
cm.mu.RLock()
|
|
defer cm.mu.RUnlock()
|
|
return cm.lastError
|
|
}
|
|
|
|
// GetStatus returns the current connection status
|
|
func (cm *ConnectionManager) GetStatus() ConnectionStatus {
|
|
cm.mu.RLock()
|
|
defer cm.mu.RUnlock()
|
|
|
|
status := ConnectionStatus{
|
|
IsConnected: cm.db != nil,
|
|
LastCheck: cm.lastCheck,
|
|
LastError: "",
|
|
DSNHost: "",
|
|
}
|
|
|
|
if cm.lastError != nil {
|
|
status.LastError = cm.lastError.Error()
|
|
}
|
|
|
|
if cm.dsnHost != "" {
|
|
status.DSNHost = cm.dsnHost
|
|
} else {
|
|
status.DSNHost = extractHostFromDSN(cm.dsn)
|
|
}
|
|
|
|
return status
|
|
}
|
|
|
|
// extractHostFromDSN returns the address part of a go-sql-driver/mysql style
// DSN, i.e. the text inside the parentheses of
//
//	[user[:password]@][protocol[(address)]]/dbname[?params]
//
// For example, it returns "localhost:3306" for
// "user:pass@tcp(localhost:3306)/app", and the socket path for a unix(...)
// address. It returns "" when the DSN has no parenthesised address.
//
// Mirroring the driver's parser, it splits on the last '/' and then on the
// last '@' before it. A password containing '/' or '@' is therefore not
// mistaken for part of the address. A DSN without any '/' is tolerated and
// treated entirely as the address section. The result is intended for display.
func extractHostFromDSN(dsn string) string {
	// Drop "/dbname?params": the address section ends at the last '/'.
	section := dsn
	if slash := strings.LastIndexByte(dsn, '/'); slash >= 0 {
		section = dsn[:slash]
	}

	// Drop the "user[:password]@" prefix.
	if at := strings.LastIndexByte(section, '@'); at >= 0 {
		section = section[at+1:]
	}

	// What remains is "protocol(address)", a bare "protocol", or "".
	open := strings.IndexByte(section, '(')
	if open < 0 || !strings.HasSuffix(section, ")") {
		return ""
	}
	return section[open+1 : len(section)-1]
}
|