209 lines
5.3 KiB
Go
209 lines
5.3 KiB
Go
package api
|
|
|
|
import (
|
|
"bytes"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/go-sql-driver/mysql"
|
|
|
|
"reanimator/internal/ingest"
|
|
"reanimator/internal/repository"
|
|
)
|
|
|
|
func TestIngestLogBundleIdempotent(t *testing.T) {
|
|
dsn := os.Getenv("DATABASE_DSN")
|
|
if dsn == "" {
|
|
t.Skip("DATABASE_DSN not set")
|
|
}
|
|
|
|
db, err := repository.Open(dsn)
|
|
if err != nil {
|
|
t.Fatalf("open db: %v", err)
|
|
}
|
|
defer db.Close()
|
|
|
|
if err := applyMigrations(db); err != nil {
|
|
t.Fatalf("apply migrations: %v", err)
|
|
}
|
|
if err := cleanupRegistry(db); err != nil {
|
|
t.Fatalf("cleanup: %v", err)
|
|
}
|
|
|
|
customerID := insertCustomer(t, db, "Acme")
|
|
projectID := insertProject(t, db, customerID, "Core")
|
|
assetID := insertAsset(t, db, projectID, "server-01", "ASSET-01")
|
|
|
|
mux := http.NewServeMux()
|
|
RegisterIngestRoutes(mux, IngestDependencies{Service: ingest.NewService(db)})
|
|
server := httptest.NewServer(mux)
|
|
defer server.Close()
|
|
|
|
collectedAt := time.Now().UTC().Format(time.RFC3339)
|
|
payload := map[string]any{
|
|
"asset_id": assetID,
|
|
"collected_at": collectedAt,
|
|
"components": []map[string]any{
|
|
{"vendor_serial": "VSN-001"},
|
|
{"vendor_serial": "VSN-002"},
|
|
},
|
|
}
|
|
body, err := json.Marshal(payload)
|
|
if err != nil {
|
|
t.Fatalf("marshal payload: %v", err)
|
|
}
|
|
|
|
resp, err := http.Post(server.URL+"/ingest/logbundle", "application/json", bytes.NewReader(body))
|
|
if err != nil {
|
|
t.Fatalf("post: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusCreated {
|
|
t.Fatalf("expected 201, got %d", resp.StatusCode)
|
|
}
|
|
|
|
resp, err = http.Post(server.URL+"/ingest/logbundle", "application/json", bytes.NewReader(body))
|
|
if err != nil {
|
|
t.Fatalf("post duplicate: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Fatalf("expected 200 on duplicate, got %d", resp.StatusCode)
|
|
}
|
|
|
|
assertCount(t, db, "log_bundles", 1)
|
|
assertCount(t, db, "observations", 2)
|
|
assertCountQuery(t, db, "SELECT COUNT(*) FROM installations WHERE removed_at IS NULL", 2)
|
|
assertCountQuery(t, db, "SELECT COUNT(*) FROM components WHERE first_seen_at IS NOT NULL", 2)
|
|
}
|
|
|
|
func applyMigrations(db *sql.DB) error {
|
|
paths := []string{
|
|
filepath.Join("migrations", "0001_init", "up.sql"),
|
|
filepath.Join("migrations", "0002_registry", "up.sql"),
|
|
filepath.Join("migrations", "0003_ingest", "up.sql"),
|
|
filepath.Join("migrations", "0004_timeline", "up.sql"),
|
|
filepath.Join("migrations", "0005_tickets", "up.sql"),
|
|
filepath.Join("migrations", "0006_analytics", "up.sql"),
|
|
filepath.Join("migrations", "0007_hardware_ingest", "up.sql"),
|
|
filepath.Join("migrations", "0008_service_uniques", "up.sql"),
|
|
}
|
|
for _, path := range paths {
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := execStatements(db, string(data)); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func execStatements(db *sql.DB, sqlText string) error {
|
|
statements := strings.Split(sqlText, ";")
|
|
for _, stmt := range statements {
|
|
stmt = strings.TrimSpace(stmt)
|
|
if stmt == "" {
|
|
continue
|
|
}
|
|
if _, err := db.Exec(stmt); err != nil {
|
|
var mysqlErr *mysql.MySQLError
|
|
if errors.As(err, &mysqlErr) && (mysqlErr.Number == 1060 || mysqlErr.Number == 1061) {
|
|
continue
|
|
}
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// cleanupRegistry empties the registry tables by running an unconditional
// "DELETE FROM <table>" for each table, in the listed order.
//
// NOTE(review): the order presumably removes referencing rows before the
// rows they point to, so foreign keys do not reject a delete. Confirm
// against the migrations.
//
// It stops at the first failing statement and returns that error.
func cleanupRegistry(db *sql.DB) error {
	tables := []string{
		"failure_events",
		"asset_firmware_states",
		"ticket_links",
		"tickets",
		"timeline_events",
		"observations",
		"log_bundles",
		"installations",
		"components",
		"assets",
		"locations",
		"projects",
		"customers",
		"lots",
	}
	for _, table := range tables {
		// Table names are fixed constants above, so concatenation is safe.
		if _, err := db.Exec("DELETE FROM " + table); err != nil {
			return err
		}
	}
	return nil
}
|
|
|
|
// insertCustomer adds a customers row with the given name and returns the
// id reported by LastInsertId. Any database error aborts the calling test
// through t.Fatalf.
func insertCustomer(t *testing.T, db *sql.DB, name string) int64 {
	t.Helper()
	const stmt = `INSERT INTO customers (name) VALUES (?)`
	res, execErr := db.Exec(stmt, name)
	if execErr != nil {
		t.Fatalf("insert customer: %v", execErr)
	}
	customerID, idErr := res.LastInsertId()
	if idErr != nil {
		t.Fatalf("customer id: %v", idErr)
	}
	return customerID
}
|
|
|
|
// insertProject adds a projects row for customerID with the given name and
// returns the id reported by LastInsertId. Any database error aborts the
// calling test through t.Fatalf.
func insertProject(t *testing.T, db *sql.DB, customerID int64, name string) int64 {
	t.Helper()
	const stmt = `INSERT INTO projects (customer_id, name) VALUES (?, ?)`
	res, execErr := db.Exec(stmt, customerID, name)
	if execErr != nil {
		t.Fatalf("insert project: %v", execErr)
	}
	projectID, idErr := res.LastInsertId()
	if idErr != nil {
		t.Fatalf("project id: %v", idErr)
	}
	return projectID
}
|
|
|
|
// insertAsset adds an assets row under projectID with the given name and
// vendor serial, and returns the id reported by LastInsertId. Any database
// error aborts the calling test through t.Fatalf.
func insertAsset(t *testing.T, db *sql.DB, projectID int64, name, serial string) int64 {
	t.Helper()
	const stmt = `INSERT INTO assets (project_id, name, vendor_serial) VALUES (?, ?, ?)`
	res, execErr := db.Exec(stmt, projectID, name, serial)
	if execErr != nil {
		t.Fatalf("insert asset: %v", execErr)
	}
	assetID, idErr := res.LastInsertId()
	if idErr != nil {
		t.Fatalf("asset id: %v", idErr)
	}
	return assetID
}
|
|
|
|
func assertCount(t *testing.T, db *sql.DB, table string, expected int) {
|
|
t.Helper()
|
|
query := "SELECT COUNT(*) FROM " + table
|
|
assertCountQuery(t, db, query, expected)
|
|
}
|
|
|
|
// assertCountQuery runs query, which must return a single row with one
// integer column, and fails the test unless that value equals expected.
// A query or scan error also fails the test.
func assertCountQuery(t *testing.T, db *sql.DB, query string, expected int) {
	t.Helper()
	var got int
	scanErr := db.QueryRow(query).Scan(&got)
	if scanErr != nil {
		t.Fatalf("count query: %v", scanErr)
	}
	if got == expected {
		return
	}
	t.Fatalf("expected %d, got %d for query %q", expected, got, query)
}
|