241 lines
5.3 KiB
Go
241 lines
5.3 KiB
Go
package main
|
|
|
|
import (
|
|
"database/sql"
|
|
"flag"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.mchus.pro/mchus/priceforge/internal/appstate"
|
|
"git.mchus.pro/mchus/priceforge/internal/config"
|
|
_ "github.com/go-sql-driver/mysql"
|
|
)
|
|
|
|
func main() {
|
|
configPathFlag := flag.String("config", "", "path to config file")
|
|
outDirFlag := flag.String("out-dir", "", "output directory for backup files")
|
|
flag.Parse()
|
|
|
|
configPath, err := appstate.ResolveConfigPath(strings.TrimSpace(*configPathFlag))
|
|
if err != nil {
|
|
fatalf("resolve config path: %v", err)
|
|
}
|
|
|
|
cfg, err := config.Load(configPath)
|
|
if err != nil {
|
|
fatalf("load config: %v", err)
|
|
}
|
|
|
|
outDir := strings.TrimSpace(*outDirFlag)
|
|
if outDir == "" {
|
|
outDir = filepath.Join(filepath.Dir(configPath), "backups")
|
|
}
|
|
if err := os.MkdirAll(outDir, 0700); err != nil {
|
|
fatalf("create backup dir: %v", err)
|
|
}
|
|
|
|
db, err := sql.Open("mysql", cfg.Database.DSN())
|
|
if err != nil {
|
|
fatalf("open database: %v", err)
|
|
}
|
|
defer db.Close()
|
|
|
|
if err := db.Ping(); err != nil {
|
|
fatalf("ping database: %v", err)
|
|
}
|
|
|
|
filename := fmt.Sprintf("%s_backup_%s.sql", cfg.Database.Name, time.Now().Format("20060102_150405"))
|
|
outPath := filepath.Join(outDir, filename)
|
|
|
|
f, err := os.OpenFile(outPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
|
|
if err != nil {
|
|
fatalf("create backup file: %v", err)
|
|
}
|
|
defer f.Close()
|
|
|
|
if err := writeDump(db, cfg.Database.Name, f); err != nil {
|
|
fatalf("write dump: %v", err)
|
|
}
|
|
if err := f.Sync(); err != nil {
|
|
fatalf("sync dump: %v", err)
|
|
}
|
|
if err := appstate.EnsurePrivateFile(outPath); err != nil {
|
|
fatalf("chmod backup file: %v", err)
|
|
}
|
|
|
|
fmt.Println(outPath)
|
|
}
|
|
|
|
func writeDump(db *sql.DB, dbName string, f *os.File) error {
|
|
if _, err := fmt.Fprintf(f, "-- PriceForge backup\n-- Database: %s\n-- Generated at: %s\n\n", dbName, time.Now().Format(time.RFC3339)); err != nil {
|
|
return err
|
|
}
|
|
if _, err := f.WriteString("SET FOREIGN_KEY_CHECKS=0;\n\n"); err != nil {
|
|
return err
|
|
}
|
|
|
|
tables, err := listTables(db, dbName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, table := range tables {
|
|
if err := dumpTable(db, table, f); err != nil {
|
|
return fmt.Errorf("dump table %s: %w", table, err)
|
|
}
|
|
}
|
|
|
|
_, err = f.WriteString("SET FOREIGN_KEY_CHECKS=1;\n")
|
|
return err
|
|
}
|
|
|
|
func listTables(db *sql.DB, dbName string) ([]string, error) {
|
|
rows, err := db.Query(`
|
|
SELECT table_name
|
|
FROM information_schema.tables
|
|
WHERE table_schema = ?
|
|
AND table_type = 'BASE TABLE'
|
|
ORDER BY table_name ASC
|
|
`, dbName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var tables []string
|
|
for rows.Next() {
|
|
var table string
|
|
if err := rows.Scan(&table); err != nil {
|
|
return nil, err
|
|
}
|
|
tables = append(tables, table)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
sort.Strings(tables)
|
|
return tables, nil
|
|
}
|
|
|
|
func dumpTable(db *sql.DB, table string, f *os.File) error {
|
|
var tableName, createStmt string
|
|
if err := db.QueryRow("SHOW CREATE TABLE "+quoteIdent(table)).Scan(&tableName, &createStmt); err != nil {
|
|
return err
|
|
}
|
|
|
|
if _, err := fmt.Fprintf(f, "-- Table: %s\nDROP TABLE IF EXISTS %s;\n%s;\n\n", table, quoteIdent(table), createStmt); err != nil {
|
|
return err
|
|
}
|
|
|
|
rows, err := db.Query("SELECT * FROM " + quoteIdent(table))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
columns, err := rows.Columns()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
values := make([]any, len(columns))
|
|
dest := make([]any, len(columns))
|
|
rowCount := 0
|
|
for i := range values {
|
|
dest[i] = &values[i]
|
|
}
|
|
|
|
for rows.Next() {
|
|
if err := rows.Scan(dest...); err != nil {
|
|
return err
|
|
}
|
|
rowCount++
|
|
|
|
parts := make([]string, len(values))
|
|
for i, value := range values {
|
|
parts[i] = sqlLiteral(value)
|
|
}
|
|
if _, err := fmt.Fprintf(f, "INSERT INTO %s VALUES (%s);\n", quoteIdent(table), strings.Join(parts, ", ")); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return err
|
|
}
|
|
if rowCount > 0 {
|
|
if _, err := f.WriteString("\n"); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func quoteIdent(v string) string {
|
|
return "`" + strings.ReplaceAll(v, "`", "``") + "`"
|
|
}
|
|
|
|
func sqlLiteral(v any) string {
|
|
switch x := v.(type) {
|
|
case nil:
|
|
return "NULL"
|
|
case []byte:
|
|
return quoteString(string(x))
|
|
case string:
|
|
return quoteString(x)
|
|
case time.Time:
|
|
return quoteString(x.Format("2006-01-02 15:04:05.999999"))
|
|
case bool:
|
|
if x {
|
|
return "1"
|
|
}
|
|
return "0"
|
|
case int:
|
|
return strconv.Itoa(x)
|
|
case int8:
|
|
return strconv.FormatInt(int64(x), 10)
|
|
case int16:
|
|
return strconv.FormatInt(int64(x), 10)
|
|
case int32:
|
|
return strconv.FormatInt(int64(x), 10)
|
|
case int64:
|
|
return strconv.FormatInt(x, 10)
|
|
case uint:
|
|
return strconv.FormatUint(uint64(x), 10)
|
|
case uint8:
|
|
return strconv.FormatUint(uint64(x), 10)
|
|
case uint16:
|
|
return strconv.FormatUint(uint64(x), 10)
|
|
case uint32:
|
|
return strconv.FormatUint(uint64(x), 10)
|
|
case uint64:
|
|
return strconv.FormatUint(x, 10)
|
|
case float32:
|
|
return strconv.FormatFloat(float64(x), 'f', -1, 32)
|
|
case float64:
|
|
return strconv.FormatFloat(x, 'f', -1, 64)
|
|
default:
|
|
return quoteString(fmt.Sprint(v))
|
|
}
|
|
}
|
|
|
|
func quoteString(v string) string {
|
|
replacer := strings.NewReplacer(
|
|
"\\", "\\\\",
|
|
"'", "''",
|
|
"\x00", "\\0",
|
|
"\n", "\\n",
|
|
"\r", "\\r",
|
|
"\t", "\\t",
|
|
)
|
|
return "'" + replacer.Replace(v) + "'"
|
|
}
|
|
|
|
func fatalf(format string, args ...any) {
|
|
fmt.Fprintf(os.Stderr, format+"\n", args...)
|
|
os.Exit(1)
|
|
}
|