package db import ( "context" "fmt" "sync" "time" "git.mchus.pro/mchus/quoteforge/internal/localdb" "gorm.io/driver/mysql" "gorm.io/gorm" "gorm.io/gorm/logger" ) const ( defaultConnectTimeout = 5 * time.Second defaultPingInterval = 30 * time.Second defaultReconnectCooldown = 10 * time.Second // defaultStatusCheckInterval controls how often the background prober // re-checks connectivity to keep IsOnline() cheap. Request-handling // goroutines must never pay the MySQL dial/read timeout themselves. defaultStatusCheckInterval = 15 * time.Second maxOpenConns = 10 maxIdleConns = 2 connMaxLifetime = 5 * time.Minute ) // ConnectionStatus represents the current status of the database connection type ConnectionStatus struct { IsConnected bool LastCheck time.Time LastError string // empty if no error DSNHost string // host:port for display (without password!) } // ConnectionManager manages database connections with thread-safety and connection pooling type ConnectionManager struct { localDB *localdb.LocalDB // for getting DSN from settings mu sync.RWMutex // protects db/lastError/lastCheck only — never held during network I/O connMu sync.Mutex // serializes actual dial/ping attempts; held *instead of* mu during network I/O db *gorm.DB // current connection (nil if not connected) lastError error // last connection error lastCheck time.Time // time of last check/attempt connectTimeout time.Duration // timeout for connection (default: 5s) pingInterval time.Duration // minimum interval between pings (default: 30s) reconnectCooldown time.Duration // pause after failed attempt (default: 10s) statusCheckInterval time.Duration // background prober cadence (default: 15s) stopStatusLoop chan struct{} // closed by Stop() to end the background loop } // NewConnectionManager creates a new ConnectionManager instance func NewConnectionManager(localDB *localdb.LocalDB) *ConnectionManager { return &ConnectionManager{ localDB: localDB, connectTimeout: defaultConnectTimeout, pingInterval: defaultPingInterval, reconnectCooldown: defaultReconnectCooldown, statusCheckInterval: defaultStatusCheckInterval, db: nil, lastError: nil, lastCheck: time.Time{}, } } // Start launches a background goroutine that keeps the online-status cache // fresh, so that IsOnline() (called from request-handling middleware) never // blocks on network I/O itself. Returns immediately — the app must be able // to serve the local-first UI right away, before connectivity is even known. // Until the first background check completes, IsOnline() reports offline // (the safe default). Stop via ctx cancellation or Stop(). func (cm *ConnectionManager) Start(ctx context.Context) { cm.mu.Lock() if cm.stopStatusLoop == nil { cm.stopStatusLoop = make(chan struct{}) } stopCh := cm.stopStatusLoop cm.mu.Unlock() go func() { // Prime the cache in the background; the dial/read timeout must not // delay server startup. cm.checkOnlineNow() ticker := time.NewTicker(cm.statusCheckInterval) defer ticker.Stop() for { select { case <-ctx.Done(): return case <-stopCh: return case <-ticker.C: cm.checkOnlineNow() } } }() } // Stop ends the background status-refresh loop started by Start. func (cm *ConnectionManager) Stop() { cm.mu.Lock() defer cm.mu.Unlock() if cm.stopStatusLoop != nil { close(cm.stopStatusLoop) cm.stopStatusLoop = nil } } // GetDB returns the current database connection, establishing it if needed. // Thread-safe and respects connection cooldowns. The actual network I/O // (dial/ping) never runs while holding cm.mu, so concurrent readers of the // cached status (IsOnline, GetStatus) are never blocked by an in-flight // connection attempt — only concurrent GetDB/checkOnlineNow callers are // serialized against each other, via connMu. func (cm *ConnectionManager) GetDB() (*gorm.DB, error) { // Handle case where localDB is nil if cm.localDB == nil { return nil, fmt.Errorf("local database not initialized") } if db, err, ok := cm.cachedResult(); ok { return db, err } // Serialize actual connection attempts so concurrent callers don't dial // in parallel. This may block the calling goroutine on network I/O, but // never blocks other goroutines that only read the cached status. cm.connMu.Lock() defer cm.connMu.Unlock() // Re-check: another goroutine may have just finished connecting while we // were waiting for connMu. if db, err, ok := cm.cachedResult(); ok { return db, err } newDB, err := cm.dial() cm.mu.Lock() if err != nil { // Drop stale handle so callers don't treat it as an active connection. cm.db = nil cm.lastError = err cm.lastCheck = time.Now() cm.mu.Unlock() return nil, err } cm.db = newDB cm.lastError = nil cm.lastCheck = time.Now() cm.mu.Unlock() return newDB, nil } // cachedResult returns (db, err, true) if the cached state is still fresh // enough to answer without a new network round-trip: either a live // connection within pingInterval, or a recent failure still within // reconnectCooldown. Returns ok=false if a fresh connection attempt is needed. func (cm *ConnectionManager) cachedResult() (*gorm.DB, error, bool) { cm.mu.RLock() defer cm.mu.RUnlock() if cm.db != nil && time.Since(cm.lastCheck) < cm.pingInterval { return cm.db, nil, true } if cm.db == nil && cm.lastError != nil && time.Since(cm.lastCheck) < cm.reconnectCooldown { return nil, cm.lastError, true } return nil, nil, false } // dial establishes a new database connection. Pure network I/O — must not be // called while holding cm.mu. func (cm *ConnectionManager) dial() (*gorm.DB, error) { // Get DSN from local settings dsn, err := cm.localDB.GetDSN() if err != nil { return nil, fmt.Errorf("getting DSN: %w", err) } // Create context with timeout ctx, cancel := context.WithTimeout(context.Background(), cm.connectTimeout) defer cancel() // Open database connection db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), }) if err != nil { return nil, fmt.Errorf("opening database connection: %w", err) } // Test the connection sqlDB, err := db.DB() if err != nil { return nil, fmt.Errorf("getting sql.DB: %w", err) } // Ping with timeout if err = sqlDB.PingContext(ctx); err != nil { return nil, fmt.Errorf("pinging database: %w", err) } // Set connection pool settings sqlDB.SetMaxOpenConns(maxOpenConns) sqlDB.SetMaxIdleConns(maxIdleConns) sqlDB.SetConnMaxLifetime(connMaxLifetime) return db, nil } // IsOnline returns the cached connectivity status. It never performs network // I/O itself so it is safe to call from request-handling middleware; the // background loop started by Start() (or an explicit TryConnect) is // responsible for keeping the cache fresh. func (cm *ConnectionManager) IsOnline() bool { cm.mu.RLock() defer cm.mu.RUnlock() return cm.db != nil && cm.lastError == nil } // checkOnlineNow checks if the database is currently connected and // responsive, performing real network I/O (dial/ping) as needed. If // disconnected, it tries to reconnect (respecting cooldowns in GetDB). This // must only be called from the background status loop or explicit // user-triggered reconnects, never from request-handling goroutines. func (cm *ConnectionManager) checkOnlineNow() bool { cm.mu.RLock() isDisconnected := cm.db == nil lastErr := cm.lastError checkedRecently := time.Since(cm.lastCheck) < cm.pingInterval cm.mu.RUnlock() // Try reconnect in disconnected state. if isDisconnected { _, err := cm.GetDB() return err == nil } // If we've checked recently, return cached result. if checkedRecently { return lastErr == nil } // Serialize actual ping attempts (network I/O) against other // connect/ping attempts, without ever holding cm.mu during the I/O. cm.connMu.Lock() defer cm.connMu.Unlock() cm.mu.RLock() db := cm.db checkedRecently = time.Since(cm.lastCheck) < cm.pingInterval cm.mu.RUnlock() if db == nil { return false } if checkedRecently { return true } // Perform ping with timeout — no locks held here. ctx, cancel := context.WithTimeout(context.Background(), cm.connectTimeout) defer cancel() sqlDB, err := db.DB() if err == nil { err = sqlDB.PingContext(ctx) } cm.mu.Lock() defer cm.mu.Unlock() if err != nil { cm.lastError = err cm.lastCheck = time.Now() cm.db = nil return false } cm.lastCheck = time.Now() cm.lastError = nil return true } // TryConnect forces a new connection attempt (for UI "Reconnect" button). // Ignores the reconnect cooldown, but still serializes against other // dial attempts via connMu and never holds cm.mu during network I/O. func (cm *ConnectionManager) TryConnect() error { cm.connMu.Lock() defer cm.connMu.Unlock() newDB, err := cm.dial() cm.mu.Lock() defer cm.mu.Unlock() if err != nil { cm.db = nil cm.lastError = err cm.lastCheck = time.Now() return err } cm.db = newDB cm.lastError = nil cm.lastCheck = time.Now() return nil } // Disconnect closes the current database connection func (cm *ConnectionManager) Disconnect() { cm.mu.Lock() defer cm.mu.Unlock() if cm.db != nil { sqlDB, err := cm.db.DB() if err == nil { sqlDB.Close() } } cm.db = nil cm.lastError = nil } // MarkOffline closes the current connection and preserves the last observed error. func (cm *ConnectionManager) MarkOffline(err error) { cm.mu.Lock() defer cm.mu.Unlock() if cm.db != nil { sqlDB, dbErr := cm.db.DB() if dbErr == nil { sqlDB.Close() } } cm.db = nil cm.lastError = err cm.lastCheck = time.Now() } // GetLastError returns the last connection error (thread-safe) func (cm *ConnectionManager) GetLastError() error { cm.mu.RLock() defer cm.mu.RUnlock() return cm.lastError } // GetStatus returns the current connection status func (cm *ConnectionManager) GetStatus() ConnectionStatus { cm.mu.RLock() defer cm.mu.RUnlock() status := ConnectionStatus{ IsConnected: cm.db != nil, LastCheck: cm.lastCheck, LastError: "", DSNHost: "", } if cm.lastError != nil { status.LastError = cm.lastError.Error() } // Extract host from DSN for display if cm.localDB != nil { if dsn, err := cm.localDB.GetDSN(); err == nil { // Parse DSN to extract host:port // Format: user:password@tcp(host:port)/database?... status.DSNHost = extractHostFromDSN(dsn) } } return status } // extractHostFromDSN extracts the host:port part from a DSN string func extractHostFromDSN(dsn string) string { // Find the tcp( part tcpStart := 0 if tcpStart = len("tcp("); tcpStart < len(dsn) && dsn[tcpStart] == '(' { // Look for the closing parenthesis parenEnd := -1 for i := tcpStart + 1; i < len(dsn); i++ { if dsn[i] == ')' { parenEnd = i break } } if parenEnd != -1 { // Extract host:port part between tcp( and ) hostPort := dsn[tcpStart+1 : parenEnd] return hostPort } } // Fallback: try to find host:port by looking for @tcp( pattern atIndex := -1 for i := 0; i < len(dsn)-4; i++ { if dsn[i:i+4] == "@tcp" { atIndex = i break } } if atIndex != -1 { // Look for the opening parenthesis after @tcp parenStart := -1 for i := atIndex + 4; i < len(dsn); i++ { if dsn[i] == '(' { parenStart = i break } } if parenStart != -1 { // Look for the closing parenthesis parenEnd := -1 for i := parenStart + 1; i < len(dsn); i++ { if dsn[i] == ')' { parenEnd = i break } } if parenEnd != -1 { hostPort := dsn[parenStart+1 : parenEnd] return hostPort } } } // If we can't parse it, return empty string return "" }