package migrate

import (
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/go-sql-driver/mysql"
)

// EnsureSchema brings the database up to date with the "up" migrations in
// dir. It makes sure the schema_migrations tracking table exists, then
// applies, in the order returned by Runner.Load, every migration whose name
// is not yet recorded there, recording each one after its SQL succeeds.
//
// A fresh database goes through the same path: the tracking table starts out
// empty, so every migration is pending.
//
// Executing a migration and recording it are separate steps that are not
// wrapped in a transaction. A failure part-way through a migration leaves
// its earlier statements applied but the migration itself unrecorded.
func EnsureSchema(db *sql.DB, dir string) error {
	column, err := ensureSchemaMigrations(db)
	if err != nil {
		return err
	}

	runner := Runner{Dir: dir}
	migrations, err := runner.Load(Up)
	if err != nil {
		return err
	}

	applied, err := loadAppliedMigrations(db, column)
	if err != nil {
		return err
	}

	for _, migration := range migrations {
		if applied[migration.Name] {
			continue
		}
		if err := execStatements(db, migration.SQL); err != nil {
			return fmt.Errorf("apply %s: %w", migration.Name, err)
		}
		if err := recordMigration(db, column, migration.Name); err != nil {
			return err
		}
	}
	return nil
}

// isDatabaseEmpty reports whether the connection's current database
// (DATABASE()) has no entries in information_schema.tables.
func isDatabaseEmpty(db *sql.DB) (bool, error) {
	var count int
	if err := db.QueryRow(`SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = DATABASE()`).Scan(&count); err != nil {
		return false, err
	}
	return count == 0, nil
}

// execStatements splits sqlText into statements (see splitSQLStatements) and
// executes them one at a time on db.
//
// MySQL errors 1060 (ER_DUP_FIELDNAME, column already exists) and 1061
// (ER_DUP_KEYNAME, index already exists) are ignored. Presumably this lets a
// partially applied migration be re-run past ADD COLUMN / ADD INDEX
// statements that already took effect. Any other error stops execution and
// is returned, wrapped with the 1-based statement number.
func execStatements(db *sql.DB, sqlText string) error {
	const (
		errDupFieldName = 1060
		errDupKeyName   = 1061
	)
	for i, stmt := range splitSQLStatements(sqlText) {
		_, err := db.Exec(stmt)
		if err == nil {
			continue
		}
		var mysqlErr *mysql.MySQLError
		if errors.As(err, &mysqlErr) && (mysqlErr.Number == errDupFieldName || mysqlErr.Number == errDupKeyName) {
			continue
		}
		return fmt.Errorf("statement %d: %w", i+1, err)
	}
	return nil
}

// splitSQLStatements splits sqlText on semicolons that are outside quoted
// strings, quoted identifiers and comments.
//
// Each statement is returned with surrounding whitespace trimmed and its
// comments left in place. Statements that are empty or consist only of
// comments are dropped, since the server rejects them as empty queries.
// Custom DELIMITER directives (a client-side mysql CLI feature) are not
// supported.
func splitSQLStatements(sqlText string) []string {
	var statements []string
	start, hasCode := 0, false
	emit := func(end int) {
		if stmt := strings.TrimSpace(sqlText[start:end]); hasCode && stmt != "" {
			statements = append(statements, stmt)
		}
	}
	for i := 0; i < len(sqlText); {
		if sqlText[i] == ';' {
			emit(i)
			i++
			start, hasCode = i, false
			continue
		}
		end, code := sqlTokenEnd(sqlText, i)
		hasCode = hasCode || code
		i = end
	}
	emit(len(sqlText))
	return statements
}

// sqlTokenEnd returns the index just past the lexical token that starts at
// s[i], and whether that token is executable SQL rather than whitespace or a
// comment.
//
// Quoted strings/identifiers and comments are returned as single tokens, so
// semicolons inside them are never seen by the splitter. Every other byte is
// its own token. Comment syntax follows MySQL:
//   - "#" runs to end of line;
//   - "--" must be followed by whitespace/control or end of input;
//   - "/* ... */" is a block comment, except that "/*!" and "/*+" bodies are
//     executed by the server and so count as code.
func sqlTokenEnd(s string, i int) (int, bool) {
	rest := s[i:]
	switch {
	case rest[0] == '\'' || rest[0] == '"' || rest[0] == '`':
		return i + sqlQuotedLen(rest), true
	case rest[0] == '#' || (strings.HasPrefix(rest, "--") && (len(rest) == 2 || rest[2] <= ' ')):
		if j := strings.IndexByte(rest, '\n'); j >= 0 {
			return i + j, false
		}
		return len(s), false
	case strings.HasPrefix(rest, "/*"):
		executable := len(rest) > 2 && (rest[2] == '!' || rest[2] == '+')
		if j := strings.Index(rest[2:], "*/"); j >= 0 {
			return i + 2 + j + 2, executable
		}
		return len(s), executable
	default:
		return i + 1, rest[0] > ' '
	}
}

// sqlQuotedLen returns the byte length of the quoted token at the start of s,
// including both quote characters. s[0] is the opening quote (', " or `).
//
// A doubled quote character is an escaped quote. Inside ' and " strings a
// backslash escapes the next byte; this assumes the server's sql_mode does
// not include NO_BACKSLASH_ESCAPES. An unterminated token runs to the end of
// s.
func sqlQuotedLen(s string) int {
	q := s[0]
	for i := 1; i < len(s); i++ {
		switch s[i] {
		case q:
			if i+1 < len(s) && s[i+1] == q {
				i++ // doubled quote stands for a literal quote character
				continue
			}
			return i + 1
		case '\\':
			if q != '`' {
				i++ // skip the escaped byte
			}
		}
	}
	return len(s)
}

// ensureSchemaMigrations returns the column that identifies applied
// migrations in schema_migrations.
//
// When neither a "name" nor a legacy "version" column is found, it creates
// the table keyed by "name". NOTE(review): if schema_migrations exists but
// has neither column, CREATE TABLE IF NOT EXISTS is a no-op and later
// queries against "name" will fail.
func ensureSchemaMigrations(db *sql.DB) (string, error) {
	column, err := schemaMigrationsColumn(db)
	if err != nil {
		return "", err
	}
	if column != "" {
		return column, nil
	}
	const createTable = `CREATE TABLE IF NOT EXISTS schema_migrations (
		id BIGINT PRIMARY KEY AUTO_INCREMENT,
		name VARCHAR(255) NOT NULL UNIQUE,
		applied_at TIMESTAMP NOT NULL
	)`
	if _, err := db.Exec(createTable); err != nil {
		return "", fmt.Errorf("create schema_migrations: %w", err)
	}
	return "name", nil
}

// schemaMigrationsColumn reports which tracking column schema_migrations has
// in the current database. It checks "name" first, then "version", and
// returns "" when neither exists (including when the table is absent).
func schemaMigrationsColumn(db *sql.DB) (string, error) {
	const query = `
		SELECT COUNT(*)
		FROM information_schema.columns
		WHERE table_schema = DATABASE()
		  AND table_name = 'schema_migrations'
		  AND column_name = ?`
	for _, column := range []string{"name", "version"} {
		var count int
		if err := db.QueryRow(query, column).Scan(&count); err != nil {
			return "", fmt.Errorf("inspect schema_migrations.%s: %w", column, err)
		}
		if count > 0 {
			return column, nil
		}
	}
	return "", nil
}

// loadAppliedMigrations returns the set of values stored in column of
// schema_migrations. column is interpolated into the query, so it must be a
// trusted identifier (callers here pass only "name" or "version").
//
// NOTE(review): with a legacy "version" column, the stored values are
// compared against migration names. Confirm they use the same format.
func loadAppliedMigrations(db *sql.DB, column string) (map[string]bool, error) {
	query := fmt.Sprintf("SELECT %s FROM schema_migrations", column)
	rows, err := db.Query(query)
	if err != nil {
		return nil, fmt.Errorf("load applied migrations: %w", err)
	}
	defer rows.Close()

	applied := map[string]bool{}
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return nil, fmt.Errorf("load applied migrations: %w", err)
		}
		applied[name] = true
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("load applied migrations: %w", err)
	}
	return applied, nil
}

// recordMigration inserts name into column of schema_migrations, with the
// current UTC time as applied_at. column must be a trusted identifier.
//
// NOTE(review): this requires an applied_at column, which a legacy
// "version"-keyed table may not have. Confirm against that schema.
func recordMigration(db *sql.DB, column, name string) error {
	query := fmt.Sprintf("INSERT INTO schema_migrations (%s, applied_at) VALUES (?, ?)", column)
	if _, err := db.Exec(query, name, time.Now().UTC()); err != nil {
		return fmt.Errorf("record %s: %w", name, err)
	}
	return nil
}