diff --git a/audit/internal/webui/server.go b/audit/internal/webui/server.go
index e4165af..4948bf5 100644
--- a/audit/internal/webui/server.go
+++ b/audit/internal/webui/server.go
@@ -1,13 +1,16 @@
 package webui
 
 import (
+	"bufio"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"html"
+	"io"
 	"log/slog"
 	"math"
 	"mime"
+	"net"
 	"net/http"
 	"os"
 	"path/filepath"
@@ -373,6 +376,48 @@ func (w *trackingResponseWriter) Write(p []byte) (int, error) {
 	return w.ResponseWriter.Write(p)
 }
 
+// Flush marks the response as started and forwards to the wrapped writer's
+// http.Flusher when it implements one; otherwise the flush is a no-op.
+func (w *trackingResponseWriter) Flush() {
+	w.wroteHeader = true
+	if f, ok := w.ResponseWriter.(http.Flusher); ok {
+		f.Flush()
+	}
+}
+
+// Hijack forwards to the wrapped writer's http.Hijacker. When the wrapped
+// writer cannot be hijacked, the returned error wraps http.ErrNotSupported.
+func (w *trackingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+	h, ok := w.ResponseWriter.(http.Hijacker)
+	if !ok {
+		return nil, nil, fmt.Errorf("hijacking not supported: %w", http.ErrNotSupported)
+	}
+	return h.Hijack()
+}
+
+// Push forwards to the wrapped writer's http.Pusher, or returns
+// http.ErrNotSupported when the wrapped writer does not implement it.
+func (w *trackingResponseWriter) Push(target string, opts *http.PushOptions) error {
+	p, ok := w.ResponseWriter.(http.Pusher)
+	if !ok {
+		return http.ErrNotSupported
+	}
+	return p.Push(target, opts)
+}
+
+// ReadFrom streams r into the wrapped writer, using its io.ReaderFrom
+// implementation when it has one and io.Copy otherwise.
+func (w *trackingResponseWriter) ReadFrom(r io.Reader) (int64, error) {
+	// Record the response as started before taking either path: the io.Copy
+	// fallback writes to the wrapped writer directly, so nothing else would
+	// set wroteHeader for it.
+	w.wroteHeader = true
+	if rf, ok := w.ResponseWriter.(io.ReaderFrom); ok {
+		return rf.ReadFrom(r)
+	}
+	return io.Copy(w.ResponseWriter, r)
+}
+
 func recoverMiddleware(next http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		tw := &trackingResponseWriter{ResponseWriter: w}
diff --git a/audit/internal/webui/server_test.go b/audit/internal/webui/server_test.go
index abe1ea5..dfe7b8e 100644
--- a/audit/internal/webui/server_test.go
+++ b/audit/internal/webui/server_test.go
@@ -51,6 +51,32 @@ func TestRecoverMiddlewareReturns500OnPanic(t *testing.T) {
 	}
 }
 
+func TestRecoverMiddlewarePreservesStreamingInterfaces(t *testing.T) {
+	handler := recoverMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if !sseStart(w) {
+			return
+		}
+		if !sseWrite(w, "tick", "ok") {
+			t.Fatal("expected sse write to succeed")
+		}
+	}))
+	rec := httptest.NewRecorder()
+	req := httptest.NewRequest(http.MethodGet, "/stream", nil)
+
+	handler.ServeHTTP(rec, req)
+
+	if rec.Code != http.StatusOK {
+		t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
+	}
+	if got := rec.Header().Get("Content-Type"); got != "text/event-stream" {
+		t.Fatalf("content-type=%q", got)
+	}
+	body := rec.Body.String()
+	if !strings.Contains(body, "event: tick\n") || !strings.Contains(body, "data: ok\n\n") {
+		t.Fatalf("body=%q", body)
+	}
+}
+
 func TestChartDataFromSamplesUsesFullHistory(t *testing.T) {
 	samples := []platform.LiveMetricSample{
 		{
diff --git a/audit/internal/webui/tasks.go b/audit/internal/webui/tasks.go
index 743a20b..b376db3 100644
--- a/audit/internal/webui/tasks.go
+++ b/audit/internal/webui/tasks.go
@@ -429,43 +429,8 @@ func (q *taskQueue) worker() {
 			wg.Add(1)
 			goRecoverOnce("task "+t.Target, func() {
 				defer wg.Done()
-				defer func() {
-					if rec := recover(); rec != nil {
-						msg := fmt.Sprintf("task panic: %v", rec)
-						slog.Error("task panic",
-							"task_id", t.ID,
-							"target", t.Target,
-							"panic", fmt.Sprint(rec),
-							"stack", string(debug.Stack()),
-						)
-						j.append("ERROR: " + msg)
-						j.finish(msg)
-					}
-				}()
-
-				if q.kmsgWatcher != nil && isSATTarget(t.Target) {
-					q.kmsgWatcher.NotifyTaskStarted(t.ID, t.Target)
-				}
-
-				q.runTask(t, j, taskCtx)
-
-				if q.kmsgWatcher != nil {
-					q.kmsgWatcher.NotifyTaskFinished(t.ID)
-				}
-
-				q.mu.Lock()
-				now2 := time.Now()
-				t.DoneAt = &now2
-				if t.Status == TaskRunning {
-					if j.err != "" {
-						t.Status = TaskFailed
-						t.ErrMsg = j.err
-					} else {
-						t.Status = TaskDone
-					}
-				}
-				q.persistLocked()
-				q.mu.Unlock()
+				defer taskCancel()
+				q.executeTask(t, j, taskCtx)
 			})
 		}
 		wg.Wait()
@@ -481,6 +446,61 @@ func (q *taskQueue) worker() {
 	}
 }
 
+// executeTask runs t through runTask and always finalizes it. The deferred
+// calls run in reverse order: a panic is recovered, logged and recorded on j
+// first; then the kmsg watcher is told the task finished (only if it was told
+// the task started); finally finalizeTaskRun records the task's end state.
+func (q *taskQueue) executeTask(t *Task, j *jobState, ctx context.Context) {
+	startedKmsgWatch := false
+	defer q.finalizeTaskRun(t, j)
+	defer func() {
+		if startedKmsgWatch && q.kmsgWatcher != nil {
+			q.kmsgWatcher.NotifyTaskFinished(t.ID)
+		}
+	}()
+	defer func() {
+		if rec := recover(); rec != nil {
+			msg := fmt.Sprintf("task panic: %v", rec)
+			slog.Error("task panic",
+				"task_id", t.ID,
+				"target", t.Target,
+				"panic", fmt.Sprint(rec),
+				"stack", string(debug.Stack()),
+			)
+			j.append("ERROR: " + msg)
+			j.finish(msg)
+		}
+	}()
+
+	if q.kmsgWatcher != nil && isSATTarget(t.Target) {
+		q.kmsgWatcher.NotifyTaskStarted(t.ID, t.Target)
+		startedKmsgWatch = true
+	}
+
+	q.runTask(t, j, ctx)
+}
+
+// finalizeTaskRun sets t.DoneAt and, if t is still TaskRunning, resolves its
+// status from j.err (TaskFailed with that message, otherwise TaskDone with an
+// empty ErrMsg). It holds q.mu for the update and the persistLocked call.
+func (q *taskQueue) finalizeTaskRun(t *Task, j *jobState) {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+
+	now := time.Now()
+	t.DoneAt = &now
+	if t.Status == TaskRunning {
+		if j.err != "" {
+			t.Status = TaskFailed
+			t.ErrMsg = j.err
+		} else {
+			t.Status = TaskDone
+			t.ErrMsg = ""
+		}
+	}
+	q.persistLocked()
+}
+
 // setCPUGovernor writes the given governor to all CPU scaling_governor sysfs files.
 // Silently ignores errors (e.g. when cpufreq is not available).
 func setCPUGovernor(governor string) {
diff --git a/audit/internal/webui/tasks_test.go b/audit/internal/webui/tasks_test.go
index 27a32f6..01c1659 100644
--- a/audit/internal/webui/tasks_test.go
+++ b/audit/internal/webui/tasks_test.go
@@ -467,3 +467,52 @@ func TestRunTaskInstallUsesSharedCommandStreaming(t *testing.T) {
 		t.Fatalf("unexpected error: %q", j.err)
 	}
 }
+
+func TestExecuteTaskMarksPanicsAsFailedAndClosesKmsgWindow(t *testing.T) {
+	dir := t.TempDir()
+	q := &taskQueue{
+		opts:        &HandlerOptions{App: &app.App{}},
+		statePath:   filepath.Join(dir, "tasks-state.json"),
+		logsDir:     filepath.Join(dir, "tasks"),
+		kmsgWatcher: newKmsgWatcher(nil),
+	}
+	tk := &Task{
+		ID:        "cpu-panic-1",
+		Name:      "CPU SAT",
+		Target:    "cpu",
+		Status:    TaskRunning,
+		CreatedAt: time.Now(),
+	}
+	j := &jobState{}
+
+	orig := runCPUAcceptancePackCtx
+	runCPUAcceptancePackCtx = func(_ *app.App, _ context.Context, _ string, _ int, _ func(string)) (string, error) {
+		panic("boom")
+	}
+	defer func() { runCPUAcceptancePackCtx = orig }()
+
+	q.executeTask(tk, j, context.Background())
+
+	if tk.Status != TaskFailed {
+		t.Fatalf("status=%q want %q", tk.Status, TaskFailed)
+	}
+	if tk.DoneAt == nil {
+		t.Fatal("expected done_at to be set")
+	}
+	if !strings.Contains(tk.ErrMsg, "task panic: boom") {
+		t.Fatalf("task error=%q", tk.ErrMsg)
+	}
+	if !strings.Contains(j.err, "task panic: boom") {
+		t.Fatalf("job error=%q", j.err)
+	}
+	q.kmsgWatcher.mu.Lock()
+	activeCount := q.kmsgWatcher.activeCount
+	window := q.kmsgWatcher.window
+	q.kmsgWatcher.mu.Unlock()
+	if activeCount != 0 {
+		t.Fatalf("activeCount=%d want 0", activeCount)
+	}
+	if window != nil {
+		t.Fatalf("expected kmsg window to be cleared, got %+v", window)
+	}
+}