Unify NVIDIA GPU recovery paths

This commit is contained in:
2026-04-23 20:31:41 +03:00
parent 6112094d45
commit 749fc8a94d
6 changed files with 278 additions and 82 deletions

View File

@@ -2,7 +2,7 @@ package platform
import (
"context"
"os"
"fmt"
"os/exec"
"path/filepath"
"strings"
@@ -188,18 +188,16 @@ func TestBenchmarkCalibrationThrottleReasonIgnoresPowerReasons(t *testing.T) {
}
func TestResetBenchmarkGPUsSkipsWithoutRoot(t *testing.T) {
t.Parallel()
oldGeteuid := benchmarkGeteuid
oldExec := satExecCommand
oldReset := benchmarkResetNvidiaGPU
benchmarkGeteuid = func() int { return 1000 }
satExecCommand = func(name string, args ...string) *exec.Cmd {
t.Fatalf("unexpected command: %s %v", name, args)
return nil
benchmarkResetNvidiaGPU = func(int) (string, error) {
t.Fatal("unexpected reset call")
return "", nil
}
t.Cleanup(func() {
benchmarkGeteuid = oldGeteuid
satExecCommand = oldExec
benchmarkResetNvidiaGPU = oldReset
})
var logs []string
@@ -215,44 +213,52 @@ func TestResetBenchmarkGPUsSkipsWithoutRoot(t *testing.T) {
}
func TestResetBenchmarkGPUsResetsEachGPU(t *testing.T) {
t.Parallel()
dir := t.TempDir()
script := filepath.Join(dir, "nvidia-smi")
argsLog := filepath.Join(dir, "args.log")
if err := os.WriteFile(script, []byte("#!/bin/sh\nprintf '%s\\n' \"$*\" >> "+argsLog+"\nprintf 'ok\\n'\n"), 0755); err != nil {
t.Fatalf("write script: %v", err)
}
oldGeteuid := benchmarkGeteuid
oldSleep := benchmarkSleep
oldLookPath := satLookPath
oldReset := benchmarkResetNvidiaGPU
benchmarkGeteuid = func() int { return 0 }
benchmarkSleep = func(time.Duration) {}
satLookPath = func(file string) (string, error) {
if file == "nvidia-smi" {
return script, nil
}
return exec.LookPath(file)
var calls []int
benchmarkResetNvidiaGPU = func(index int) (string, error) {
calls = append(calls, index)
return "ok\n", nil
}
t.Cleanup(func() {
benchmarkGeteuid = oldGeteuid
benchmarkSleep = oldSleep
satLookPath = oldLookPath
benchmarkResetNvidiaGPU = oldReset
})
failed := resetBenchmarkGPUs(context.Background(), filepath.Join(dir, "verbose.log"), []int{2, 5}, nil)
failed := resetBenchmarkGPUs(context.Background(), filepath.Join(t.TempDir(), "verbose.log"), []int{2, 5}, nil)
if len(failed) != 0 {
t.Fatalf("failed=%v want no failures", failed)
}
raw, err := os.ReadFile(argsLog)
if err != nil {
t.Fatalf("read args log: %v", err)
if got, want := fmt.Sprint(calls), "[2 5]"; got != want {
t.Fatalf("calls=%v want %s", calls, want)
}
got := strings.Fields(string(raw))
want := []string{"-i", "2", "-r", "-i", "5", "-r"}
if strings.Join(got, " ") != strings.Join(want, " ") {
t.Fatalf("args=%v want %v", got, want)
}
// TestResetBenchmarkGPUsTracksFailuresFromSharedReset replaces the
// benchmarkResetNvidiaGPU hook with a stub that fails only for GPU index 5 and
// checks that resetBenchmarkGPUs reports exactly that index as failed when
// asked to reset GPUs 2 and 5.
func TestResetBenchmarkGPUsTracksFailuresFromSharedReset(t *testing.T) {
	// Snapshot the package-level hooks and put them back once the test ends,
	// since the stubs below overwrite shared state.
	prevGeteuid, prevSleep, prevReset := benchmarkGeteuid, benchmarkSleep, benchmarkResetNvidiaGPU
	t.Cleanup(func() {
		benchmarkGeteuid = prevGeteuid
		benchmarkSleep = prevSleep
		benchmarkResetNvidiaGPU = prevReset
	})

	// Report effective UID 0 and make sleeping a no-op so the test is fast.
	benchmarkGeteuid = func() int { return 0 }
	benchmarkSleep = func(time.Duration) {}
	// Only index 5 returns an error; every other index succeeds.
	benchmarkResetNvidiaGPU = func(index int) (string, error) {
		switch index {
		case 5:
			return "busy\n", exec.ErrNotFound
		default:
			return "ok\n", nil
		}
	}

	logPath := filepath.Join(t.TempDir(), "verbose.log")
	failed := resetBenchmarkGPUs(context.Background(), logPath, []int{2, 5}, nil)

	// Compare via the formatted form, which pins both length and contents.
	want := "[5]"
	if got := fmt.Sprint(failed); got != want {
		t.Fatalf("failed=%v want %s", failed, want)
	}
}