package main import ( "database/sql" "flag" "fmt" "os" "path/filepath" "sort" "strconv" "strings" "time" "git.mchus.pro/mchus/priceforge/internal/appstate" "git.mchus.pro/mchus/priceforge/internal/config" _ "github.com/go-sql-driver/mysql" ) func main() { configPathFlag := flag.String("config", "", "path to config file") outDirFlag := flag.String("out-dir", "", "output directory for backup files") flag.Parse() configPath, err := appstate.ResolveConfigPath(strings.TrimSpace(*configPathFlag)) if err != nil { fatalf("resolve config path: %v", err) } cfg, err := config.Load(configPath) if err != nil { fatalf("load config: %v", err) } outDir := strings.TrimSpace(*outDirFlag) if outDir == "" { outDir = filepath.Join(filepath.Dir(configPath), "backups") } if err := os.MkdirAll(outDir, 0700); err != nil { fatalf("create backup dir: %v", err) } db, err := sql.Open("mysql", cfg.Database.DSN()) if err != nil { fatalf("open database: %v", err) } defer db.Close() if err := db.Ping(); err != nil { fatalf("ping database: %v", err) } filename := fmt.Sprintf("%s_backup_%s.sql", cfg.Database.Name, time.Now().Format("20060102_150405")) outPath := filepath.Join(outDir, filename) f, err := os.OpenFile(outPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600) if err != nil { fatalf("create backup file: %v", err) } defer f.Close() if err := writeDump(db, cfg.Database.Name, f); err != nil { fatalf("write dump: %v", err) } if err := f.Sync(); err != nil { fatalf("sync dump: %v", err) } if err := appstate.EnsurePrivateFile(outPath); err != nil { fatalf("chmod backup file: %v", err) } fmt.Println(outPath) } func writeDump(db *sql.DB, dbName string, f *os.File) error { if _, err := fmt.Fprintf(f, "-- PriceForge backup\n-- Database: %s\n-- Generated at: %s\n\n", dbName, time.Now().Format(time.RFC3339)); err != nil { return err } if _, err := f.WriteString("SET FOREIGN_KEY_CHECKS=0;\n\n"); err != nil { return err } tables, err := listTables(db, dbName) if err != nil { return err } for _, table := range tables { if err := dumpTable(db, table, f); err != nil { return fmt.Errorf("dump table %s: %w", table, err) } } _, err = f.WriteString("SET FOREIGN_KEY_CHECKS=1;\n") return err } func listTables(db *sql.DB, dbName string) ([]string, error) { rows, err := db.Query(` SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = 'BASE TABLE' ORDER BY table_name ASC `, dbName) if err != nil { return nil, err } defer rows.Close() var tables []string for rows.Next() { var table string if err := rows.Scan(&table); err != nil { return nil, err } tables = append(tables, table) } if err := rows.Err(); err != nil { return nil, err } sort.Strings(tables) return tables, nil } func dumpTable(db *sql.DB, table string, f *os.File) error { var tableName, createStmt string if err := db.QueryRow("SHOW CREATE TABLE "+quoteIdent(table)).Scan(&tableName, &createStmt); err != nil { return err } if _, err := fmt.Fprintf(f, "-- Table: %s\nDROP TABLE IF EXISTS %s;\n%s;\n\n", table, quoteIdent(table), createStmt); err != nil { return err } rows, err := db.Query("SELECT * FROM " + quoteIdent(table)) if err != nil { return err } defer rows.Close() columns, err := rows.Columns() if err != nil { return err } values := make([]any, len(columns)) dest := make([]any, len(columns)) rowCount := 0 for i := range values { dest[i] = &values[i] } for rows.Next() { if err := rows.Scan(dest...); err != nil { return err } rowCount++ parts := make([]string, len(values)) for i, value := range values { parts[i] = sqlLiteral(value) } if _, err := fmt.Fprintf(f, "INSERT INTO %s VALUES (%s);\n", quoteIdent(table), strings.Join(parts, ", ")); err != nil { return err } } if err := rows.Err(); err != nil { return err } if rowCount > 0 { if _, err := f.WriteString("\n"); err != nil { return err } } return nil } func quoteIdent(v string) string { return "`" + strings.ReplaceAll(v, "`", "``") + "`" } func sqlLiteral(v any) string { switch x := v.(type) { case nil: return "NULL" case []byte: return quoteString(string(x)) case string: return quoteString(x) case time.Time: return quoteString(x.Format("2006-01-02 15:04:05.999999")) case bool: if x { return "1" } return "0" case int: return strconv.Itoa(x) case int8: return strconv.FormatInt(int64(x), 10) case int16: return strconv.FormatInt(int64(x), 10) case int32: return strconv.FormatInt(int64(x), 10) case int64: return strconv.FormatInt(x, 10) case uint: return strconv.FormatUint(uint64(x), 10) case uint8: return strconv.FormatUint(uint64(x), 10) case uint16: return strconv.FormatUint(uint64(x), 10) case uint32: return strconv.FormatUint(uint64(x), 10) case uint64: return strconv.FormatUint(x, 10) case float32: return strconv.FormatFloat(float64(x), 'f', -1, 32) case float64: return strconv.FormatFloat(x, 'f', -1, 64) default: return quoteString(fmt.Sprint(v)) } } func quoteString(v string) string { replacer := strings.NewReplacer( "\\", "\\\\", "'", "''", "\x00", "\\0", "\n", "\\n", "\r", "\\r", "\t", "\\t", ) return "'" + replacer.Replace(v) + "'" } func fatalf(format string, args ...any) { fmt.Fprintf(os.Stderr, format+"\n", args...) os.Exit(1) }