package db import ( "context" "fmt" "sync" "time" "git.mchus.pro/mchus/quoteforge/internal/localdb" "gorm.io/driver/mysql" "gorm.io/gorm" "gorm.io/gorm/logger" ) const ( defaultConnectTimeout = 5 * time.Second defaultPingInterval = 30 * time.Second defaultReconnectCooldown = 10 * time.Second maxOpenConns = 10 maxIdleConns = 2 connMaxLifetime = 5 * time.Minute ) // ConnectionStatus represents the current status of the database connection type ConnectionStatus struct { IsConnected bool LastCheck time.Time LastError string // empty if no error DSNHost string // host:port for display (without password!) } // ConnectionManager manages database connections with thread-safety and connection pooling type ConnectionManager struct { localDB *localdb.LocalDB // for getting DSN from settings mu sync.RWMutex // protects db and state db *gorm.DB // current connection (nil if not connected) lastError error // last connection error lastCheck time.Time // time of last check/attempt connectTimeout time.Duration // timeout for connection (default: 5s) pingInterval time.Duration // minimum interval between pings (default: 30s) reconnectCooldown time.Duration // pause after failed attempt (default: 10s) } // NewConnectionManager creates a new ConnectionManager instance func NewConnectionManager(localDB *localdb.LocalDB) *ConnectionManager { return &ConnectionManager{ localDB: localDB, connectTimeout: defaultConnectTimeout, pingInterval: defaultPingInterval, reconnectCooldown: defaultReconnectCooldown, db: nil, lastError: nil, lastCheck: time.Time{}, } } // GetDB returns the current database connection, establishing it if needed // Thread-safe and respects connection cooldowns func (cm *ConnectionManager) GetDB() (*gorm.DB, error) { // Handle case where localDB is nil if cm.localDB == nil { return nil, fmt.Errorf("local database not initialized") } // First check if we already have a valid connection cm.mu.RLock() if cm.db != nil { // Check if connection is still valid and within ping interval if time.Since(cm.lastCheck) < cm.pingInterval { cm.mu.RUnlock() return cm.db, nil } } cm.mu.RUnlock() // Upgrade to write lock cm.mu.Lock() defer cm.mu.Unlock() // Double-check: someone else might have connected while we were waiting for the write lock if cm.db != nil { // Check if connection is still valid and within ping interval if time.Since(cm.lastCheck) < cm.pingInterval { return cm.db, nil } } // Check if we're in cooldown period after a failed attempt if cm.lastError != nil && time.Since(cm.lastCheck) < cm.reconnectCooldown { return nil, cm.lastError } // Attempt to connect err := cm.connect() if err != nil { // Drop stale handle so callers don't treat it as an active connection. cm.db = nil cm.lastError = err cm.lastCheck = time.Now() return nil, err } // Update last check time and return success cm.lastCheck = time.Now() cm.lastError = nil return cm.db, nil } // connect establishes a new database connection func (cm *ConnectionManager) connect() error { // Get DSN from local settings dsn, err := cm.localDB.GetDSN() if err != nil { return fmt.Errorf("getting DSN: %w", err) } // Create context with timeout ctx, cancel := context.WithTimeout(context.Background(), cm.connectTimeout) defer cancel() // Open database connection db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), }) if err != nil { return fmt.Errorf("opening database connection: %w", err) } // Test the connection sqlDB, err := db.DB() if err != nil { return fmt.Errorf("getting sql.DB: %w", err) } // Ping with timeout if err = sqlDB.PingContext(ctx); err != nil { return fmt.Errorf("pinging database: %w", err) } // Set connection pool settings sqlDB.SetMaxOpenConns(maxOpenConns) sqlDB.SetMaxIdleConns(maxIdleConns) sqlDB.SetConnMaxLifetime(connMaxLifetime) // Store the connection cm.db = db return nil } // IsOnline checks if the database is currently connected and responsive. // If disconnected, it tries to reconnect (respecting cooldowns in GetDB). func (cm *ConnectionManager) IsOnline() bool { cm.mu.RLock() isDisconnected := cm.db == nil lastErr := cm.lastError checkedRecently := time.Since(cm.lastCheck) < cm.pingInterval cm.mu.RUnlock() // Try reconnect in disconnected state. if isDisconnected { _, err := cm.GetDB() return err == nil } // If we've checked recently, return cached result. if checkedRecently { return lastErr == nil } // Need to perform actual ping. cm.mu.Lock() defer cm.mu.Unlock() // Double-check after acquiring write lock if cm.db == nil { return false } // Perform ping with timeout ctx, cancel := context.WithTimeout(context.Background(), cm.connectTimeout) defer cancel() sqlDB, err := cm.db.DB() if err != nil { cm.lastError = err cm.lastCheck = time.Now() cm.db = nil return false } if err = sqlDB.PingContext(ctx); err != nil { cm.lastError = err cm.lastCheck = time.Now() cm.db = nil return false } // Update last check time and return success cm.lastCheck = time.Now() cm.lastError = nil return true } // TryConnect forces a new connection attempt (for UI "Reconnect" button) // Ignores cooldown period func (cm *ConnectionManager) TryConnect() error { cm.mu.Lock() defer cm.mu.Unlock() // Attempt to connect err := cm.connect() if err != nil { cm.lastError = err cm.lastCheck = time.Now() return err } // Update last check time and clear error cm.lastCheck = time.Now() cm.lastError = nil return nil } // Disconnect closes the current database connection func (cm *ConnectionManager) Disconnect() { cm.mu.Lock() defer cm.mu.Unlock() if cm.db != nil { sqlDB, err := cm.db.DB() if err == nil { sqlDB.Close() } } cm.db = nil cm.lastError = nil } // GetLastError returns the last connection error (thread-safe) func (cm *ConnectionManager) GetLastError() error { cm.mu.RLock() defer cm.mu.RUnlock() return cm.lastError } // GetStatus returns the current connection status func (cm *ConnectionManager) GetStatus() ConnectionStatus { cm.mu.RLock() defer cm.mu.RUnlock() status := ConnectionStatus{ IsConnected: cm.db != nil, LastCheck: cm.lastCheck, LastError: "", DSNHost: "", } if cm.lastError != nil { status.LastError = cm.lastError.Error() } // Extract host from DSN for display if cm.localDB != nil { if dsn, err := cm.localDB.GetDSN(); err == nil { // Parse DSN to extract host:port // Format: user:password@tcp(host:port)/database?... status.DSNHost = extractHostFromDSN(dsn) } } return status } // extractHostFromDSN extracts the host:port part from a DSN string func extractHostFromDSN(dsn string) string { // Find the tcp( part tcpStart := 0 if tcpStart = len("tcp("); tcpStart < len(dsn) && dsn[tcpStart] == '(' { // Look for the closing parenthesis parenEnd := -1 for i := tcpStart + 1; i < len(dsn); i++ { if dsn[i] == ')' { parenEnd = i break } } if parenEnd != -1 { // Extract host:port part between tcp( and ) hostPort := dsn[tcpStart+1 : parenEnd] return hostPort } } // Fallback: try to find host:port by looking for @tcp( pattern atIndex := -1 for i := 0; i < len(dsn)-4; i++ { if dsn[i:i+4] == "@tcp" { atIndex = i break } } if atIndex != -1 { // Look for the opening parenthesis after @tcp parenStart := -1 for i := atIndex + 4; i < len(dsn); i++ { if dsn[i] == '(' { parenStart = i break } } if parenStart != -1 { // Look for the closing parenthesis parenEnd := -1 for i := parenStart + 1; i < len(dsn); i++ { if dsn[i] == ')' { parenEnd = i break } } if parenEnd != -1 { hostPort := dsn[parenStart+1 : parenEnd] return hostPort } } } // If we can't parse it, return empty string return "" }