Files
PriceForge/cmd/dbbackup/main.go
2026-03-07 23:11:42 +03:00

241 lines
5.3 KiB
Go

package main
import (
"database/sql"
"flag"
"fmt"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"time"
"git.mchus.pro/mchus/priceforge/internal/appstate"
"git.mchus.pro/mchus/priceforge/internal/config"
_ "github.com/go-sql-driver/mysql"
)
func main() {
configPathFlag := flag.String("config", "", "path to config file")
outDirFlag := flag.String("out-dir", "", "output directory for backup files")
flag.Parse()
configPath, err := appstate.ResolveConfigPath(strings.TrimSpace(*configPathFlag))
if err != nil {
fatalf("resolve config path: %v", err)
}
cfg, err := config.Load(configPath)
if err != nil {
fatalf("load config: %v", err)
}
outDir := strings.TrimSpace(*outDirFlag)
if outDir == "" {
outDir = filepath.Join(filepath.Dir(configPath), "backups")
}
if err := os.MkdirAll(outDir, 0700); err != nil {
fatalf("create backup dir: %v", err)
}
db, err := sql.Open("mysql", cfg.Database.DSN())
if err != nil {
fatalf("open database: %v", err)
}
defer db.Close()
if err := db.Ping(); err != nil {
fatalf("ping database: %v", err)
}
filename := fmt.Sprintf("%s_backup_%s.sql", cfg.Database.Name, time.Now().Format("20060102_150405"))
outPath := filepath.Join(outDir, filename)
f, err := os.OpenFile(outPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
if err != nil {
fatalf("create backup file: %v", err)
}
defer f.Close()
if err := writeDump(db, cfg.Database.Name, f); err != nil {
fatalf("write dump: %v", err)
}
if err := f.Sync(); err != nil {
fatalf("sync dump: %v", err)
}
if err := appstate.EnsurePrivateFile(outPath); err != nil {
fatalf("chmod backup file: %v", err)
}
fmt.Println(outPath)
}
func writeDump(db *sql.DB, dbName string, f *os.File) error {
if _, err := fmt.Fprintf(f, "-- PriceForge backup\n-- Database: %s\n-- Generated at: %s\n\n", dbName, time.Now().Format(time.RFC3339)); err != nil {
return err
}
if _, err := f.WriteString("SET FOREIGN_KEY_CHECKS=0;\n\n"); err != nil {
return err
}
tables, err := listTables(db, dbName)
if err != nil {
return err
}
for _, table := range tables {
if err := dumpTable(db, table, f); err != nil {
return fmt.Errorf("dump table %s: %w", table, err)
}
}
_, err = f.WriteString("SET FOREIGN_KEY_CHECKS=1;\n")
return err
}
func listTables(db *sql.DB, dbName string) ([]string, error) {
rows, err := db.Query(`
SELECT table_name
FROM information_schema.tables
WHERE table_schema = ?
AND table_type = 'BASE TABLE'
ORDER BY table_name ASC
`, dbName)
if err != nil {
return nil, err
}
defer rows.Close()
var tables []string
for rows.Next() {
var table string
if err := rows.Scan(&table); err != nil {
return nil, err
}
tables = append(tables, table)
}
if err := rows.Err(); err != nil {
return nil, err
}
sort.Strings(tables)
return tables, nil
}
func dumpTable(db *sql.DB, table string, f *os.File) error {
var tableName, createStmt string
if err := db.QueryRow("SHOW CREATE TABLE "+quoteIdent(table)).Scan(&tableName, &createStmt); err != nil {
return err
}
if _, err := fmt.Fprintf(f, "-- Table: %s\nDROP TABLE IF EXISTS %s;\n%s;\n\n", table, quoteIdent(table), createStmt); err != nil {
return err
}
rows, err := db.Query("SELECT * FROM " + quoteIdent(table))
if err != nil {
return err
}
defer rows.Close()
columns, err := rows.Columns()
if err != nil {
return err
}
values := make([]any, len(columns))
dest := make([]any, len(columns))
rowCount := 0
for i := range values {
dest[i] = &values[i]
}
for rows.Next() {
if err := rows.Scan(dest...); err != nil {
return err
}
rowCount++
parts := make([]string, len(values))
for i, value := range values {
parts[i] = sqlLiteral(value)
}
if _, err := fmt.Fprintf(f, "INSERT INTO %s VALUES (%s);\n", quoteIdent(table), strings.Join(parts, ", ")); err != nil {
return err
}
}
if err := rows.Err(); err != nil {
return err
}
if rowCount > 0 {
if _, err := f.WriteString("\n"); err != nil {
return err
}
}
return nil
}
func quoteIdent(v string) string {
return "`" + strings.ReplaceAll(v, "`", "``") + "`"
}
func sqlLiteral(v any) string {
switch x := v.(type) {
case nil:
return "NULL"
case []byte:
return quoteString(string(x))
case string:
return quoteString(x)
case time.Time:
return quoteString(x.Format("2006-01-02 15:04:05.999999"))
case bool:
if x {
return "1"
}
return "0"
case int:
return strconv.Itoa(x)
case int8:
return strconv.FormatInt(int64(x), 10)
case int16:
return strconv.FormatInt(int64(x), 10)
case int32:
return strconv.FormatInt(int64(x), 10)
case int64:
return strconv.FormatInt(x, 10)
case uint:
return strconv.FormatUint(uint64(x), 10)
case uint8:
return strconv.FormatUint(uint64(x), 10)
case uint16:
return strconv.FormatUint(uint64(x), 10)
case uint32:
return strconv.FormatUint(uint64(x), 10)
case uint64:
return strconv.FormatUint(x, 10)
case float32:
return strconv.FormatFloat(float64(x), 'f', -1, 32)
case float64:
return strconv.FormatFloat(x, 'f', -1, 64)
default:
return quoteString(fmt.Sprint(v))
}
}
func quoteString(v string) string {
replacer := strings.NewReplacer(
"\\", "\\\\",
"'", "''",
"\x00", "\\0",
"\n", "\\n",
"\r", "\\r",
"\t", "\\t",
)
return "'" + replacer.Replace(v) + "'"
}
func fatalf(format string, args ...any) {
fmt.Fprintf(os.Stderr, format+"\n", args...)
os.Exit(1)
}