package api import ( "bytes" "database/sql" "encoding/json" "errors" "net/http" "net/http/httptest" "os" "path/filepath" "strings" "testing" "time" "github.com/go-sql-driver/mysql" "reanimator/internal/ingest" "reanimator/internal/repository" ) func TestIngestLogBundleIdempotent(t *testing.T) { dsn := os.Getenv("DATABASE_DSN") if dsn == "" { t.Skip("DATABASE_DSN not set") } db, err := repository.Open(dsn) if err != nil { t.Fatalf("open db: %v", err) } defer db.Close() if err := applyMigrations(db); err != nil { t.Fatalf("apply migrations: %v", err) } if err := cleanupRegistry(db); err != nil { t.Fatalf("cleanup: %v", err) } customerID := insertCustomer(t, db, "Acme") projectID := insertProject(t, db, customerID, "Core") assetID := insertAsset(t, db, projectID, "server-01", "ASSET-01") mux := http.NewServeMux() RegisterIngestRoutes(mux, IngestDependencies{Service: ingest.NewService(db)}) server := httptest.NewServer(mux) defer server.Close() collectedAt := time.Now().UTC().Format(time.RFC3339) payload := map[string]any{ "asset_id": assetID, "collected_at": collectedAt, "components": []map[string]any{ {"vendor_serial": "VSN-001"}, {"vendor_serial": "VSN-002"}, }, } body, err := json.Marshal(payload) if err != nil { t.Fatalf("marshal payload: %v", err) } resp, err := http.Post(server.URL+"/ingest/logbundle", "application/json", bytes.NewReader(body)) if err != nil { t.Fatalf("post: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusCreated { t.Fatalf("expected 201, got %d", resp.StatusCode) } resp, err = http.Post(server.URL+"/ingest/logbundle", "application/json", bytes.NewReader(body)) if err != nil { t.Fatalf("post duplicate: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Fatalf("expected 200 on duplicate, got %d", resp.StatusCode) } assertCount(t, db, "log_bundles", 1) assertCount(t, db, "observations", 2) assertCountQuery(t, db, "SELECT COUNT(*) FROM installations WHERE removed_at IS NULL", 2) assertCountQuery(t, db, "SELECT COUNT(*) FROM components WHERE first_seen_at IS NOT NULL", 2) } func applyMigrations(db *sql.DB) error { paths := []string{ filepath.Join("migrations", "0001_init", "up.sql"), filepath.Join("migrations", "0002_registry", "up.sql"), filepath.Join("migrations", "0003_ingest", "up.sql"), filepath.Join("migrations", "0004_timeline", "up.sql"), filepath.Join("migrations", "0005_tickets", "up.sql"), } for _, path := range paths { data, err := os.ReadFile(path) if err != nil { return err } if err := execStatements(db, string(data)); err != nil { return err } } return nil } func execStatements(db *sql.DB, sqlText string) error { statements := strings.Split(sqlText, ";") for _, stmt := range statements { stmt = strings.TrimSpace(stmt) if stmt == "" { continue } if _, err := db.Exec(stmt); err != nil { var mysqlErr *mysql.MySQLError if errors.As(err, &mysqlErr) && (mysqlErr.Number == 1060 || mysqlErr.Number == 1061) { continue } return err } } return nil } func cleanupRegistry(db *sql.DB) error { statements := []string{ "DELETE FROM ticket_links", "DELETE FROM tickets", "DELETE FROM timeline_events", "DELETE FROM observations", "DELETE FROM log_bundles", "DELETE FROM installations", "DELETE FROM components", "DELETE FROM assets", "DELETE FROM locations", "DELETE FROM projects", "DELETE FROM customers", "DELETE FROM lots", } for _, stmt := range statements { if _, err := db.Exec(stmt); err != nil { return err } } return nil } func insertCustomer(t *testing.T, db *sql.DB, name string) int64 { t.Helper() result, err := db.Exec(`INSERT INTO customers (name) VALUES (?)`, name) if err != nil { t.Fatalf("insert customer: %v", err) } id, err := result.LastInsertId() if err != nil { t.Fatalf("customer id: %v", err) } return id } func insertProject(t *testing.T, db *sql.DB, customerID int64, name string) int64 { t.Helper() result, err := db.Exec(`INSERT INTO projects (customer_id, name) VALUES (?, ?)`, customerID, name) if err != nil { t.Fatalf("insert project: %v", err) } id, err := result.LastInsertId() if err != nil { t.Fatalf("project id: %v", err) } return id } func insertAsset(t *testing.T, db *sql.DB, projectID int64, name, serial string) int64 { t.Helper() result, err := db.Exec(`INSERT INTO assets (project_id, name, vendor_serial) VALUES (?, ?, ?)`, projectID, name, serial) if err != nil { t.Fatalf("insert asset: %v", err) } id, err := result.LastInsertId() if err != nil { t.Fatalf("asset id: %v", err) } return id } func assertCount(t *testing.T, db *sql.DB, table string, expected int) { t.Helper() query := "SELECT COUNT(*) FROM " + table assertCountQuery(t, db, query, expected) } func assertCountQuery(t *testing.T, db *sql.DB, query string, expected int) { t.Helper() var count int if err := db.QueryRow(query).Scan(&count); err != nil { t.Fatalf("count query: %v", err) } if count != expected { t.Fatalf("expected %d, got %d for query %q", expected, count, query) } }