166 lines
3.6 KiB
Go
166 lines
3.6 KiB
Go
package migrate
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/go-sql-driver/mysql"
|
|
)
|
|
|
|
// EnsureSchema applies all "up" migrations when the database has no tables.
|
|
func EnsureSchema(db *sql.DB, dir string) error {
|
|
columnName, err := ensureSchemaMigrations(db)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
empty, err := isDatabaseEmpty(db)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
runner := Runner{Dir: dir}
|
|
migrations, err := runner.Load(Up)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if empty {
|
|
for _, migration := range migrations {
|
|
if err := execStatements(db, migration.SQL); err != nil {
|
|
return fmt.Errorf("apply %s: %w", migration.Name, err)
|
|
}
|
|
if err := recordMigration(db, columnName, migration.Name); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
applied, err := loadAppliedMigrations(db, columnName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, migration := range migrations {
|
|
if applied[migration.Name] {
|
|
continue
|
|
}
|
|
if err := execStatements(db, migration.SQL); err != nil {
|
|
return fmt.Errorf("apply %s: %w", migration.Name, err)
|
|
}
|
|
if err := recordMigration(db, columnName, migration.Name); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// isDatabaseEmpty reports whether information_schema.tables lists no entries
// for the connection's current database, as returned by DATABASE().
//
// NOTE: if no default database is selected, DATABASE() is NULL. The
// comparison then matches nothing, and the result is true.
func isDatabaseEmpty(db *sql.DB) (bool, error) {
	const countTables = `SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = DATABASE()`

	var tables int
	row := db.QueryRow(countTables)
	if scanErr := row.Scan(&tables); scanErr != nil {
		return false, scanErr
	}
	return tables == 0, nil
}
|
|
|
|
func execStatements(db *sql.DB, sqlText string) error {
|
|
statements := strings.Split(sqlText, ";")
|
|
for _, stmt := range statements {
|
|
stmt = strings.TrimSpace(stmt)
|
|
if stmt == "" {
|
|
continue
|
|
}
|
|
if _, err := db.Exec(stmt); err != nil {
|
|
var mysqlErr *mysql.MySQLError
|
|
if errors.As(err, &mysqlErr) && (mysqlErr.Number == 1060 || mysqlErr.Number == 1061) {
|
|
continue
|
|
}
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ensureSchemaMigrations(db *sql.DB) (string, error) {
|
|
column, err := schemaMigrationsColumn(db)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if column != "" {
|
|
return column, nil
|
|
}
|
|
|
|
if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS schema_migrations (
|
|
id BIGINT PRIMARY KEY AUTO_INCREMENT,
|
|
name VARCHAR(255) NOT NULL UNIQUE,
|
|
applied_at TIMESTAMP NOT NULL
|
|
)`); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return "name", nil
|
|
}
|
|
|
|
// schemaMigrationsColumn reports which column of the schema_migrations table
// in the current database (DATABASE()) holds migration identifiers.
//
// It returns "name" if that column exists, otherwise "version" if that column
// exists. It returns "" when neither is present, which includes the case
// where the table does not exist at all. Probes run in that order and stop at
// the first hit.
func schemaMigrationsColumn(db *sql.DB) (string, error) {
	probes := []struct {
		column string
		query  string
	}{
		{
			column: "name",
			query: `
	SELECT COUNT(*)
	FROM information_schema.columns
	WHERE table_schema = DATABASE()
	AND table_name = 'schema_migrations'
	AND column_name = 'name'
	`,
		},
		{
			column: "version",
			query: `
	SELECT COUNT(*)
	FROM information_schema.columns
	WHERE table_schema = DATABASE()
	AND table_name = 'schema_migrations'
	AND column_name = 'version'
	`,
		},
	}

	for _, probe := range probes {
		var hits int
		if err := db.QueryRow(probe.query).Scan(&hits); err != nil {
			return "", err
		}
		if hits > 0 {
			return probe.column, nil
		}
	}
	return "", nil
}
|
|
|
|
// loadAppliedMigrations returns the set of migration identifiers already
// recorded in schema_migrations, read from the given column. The column is
// presumably the one reported by ensureSchemaMigrations ("name" or
// "version"); confirm with callers.
//
// Values are scanned into strings, and database/sql converts numeric values
// to their decimal text. On success the map is non-nil, even when no rows
// exist. On any error the map is nil.
//
// column is spliced into the query text rather than bound as a parameter, so
// it must be a trusted identifier, never external input.
func loadAppliedMigrations(db *sql.DB, column string) (map[string]bool, error) {
	rows, err := db.Query("SELECT " + column + " FROM schema_migrations")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	seen := make(map[string]bool)
	for rows.Next() {
		var id string
		if scanErr := rows.Scan(&id); scanErr != nil {
			return nil, scanErr
		}
		seen[id] = true
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return seen, nil
}
|
|
|
|
// recordMigration inserts a row into schema_migrations marking name as
// applied. The identifier is stored in the given column, and applied_at is
// set to the current time in UTC.
//
// column is spliced into the statement text, so it must be a trusted
// identifier. name is bound as a parameter.
func recordMigration(db *sql.DB, column, name string) error {
	insert := "INSERT INTO schema_migrations (" + column + ", applied_at) VALUES (?, ?)"
	appliedAt := time.Now().UTC()
	if _, execErr := db.Exec(insert, name, appliedAt); execErr != nil {
		return execErr
	}
	return nil
}
|