diff --git a/audit/internal/platform/install_to_ram.go b/audit/internal/platform/install_to_ram.go index 0960963..4f1e87f 100644 --- a/audit/internal/platform/install_to_ram.go +++ b/audit/internal/platform/install_to_ram.go @@ -22,6 +22,14 @@ var runRemountMedium = func() ([]byte, error) { return exec.Command("bee-remount-medium").CombinedOutput() } +var umountLiveMedium = func() error { + return exec.Command("umount", "/run/live/medium").Run() +} + +var ejectDevice = func(device string) error { + return exec.Command("eject", device).Run() +} + func (s *System) IsLiveMediaInRAM() bool { return s.LiveMediaRAMState().InRAM } @@ -261,7 +269,8 @@ bindMedium: if status.InRAM { log(fmt.Sprintf("Verification passed: live medium now served from %s.", describeLiveBootSource(status))) } - log("Done. Squashfs files are in RAM. Installation media can be safely disconnected.") + detachInstallMedium(status, log) + log("Done. Squashfs files are in RAM. Installation media has been detached when possible.") return nil } @@ -309,6 +318,34 @@ func ensureLiveMediumAvailable(log func(string)) ([]string, bool) { return squashfsFiles, sourceAvailable } +func detachInstallMedium(status LiveBootSource, log func(string)) { + if log == nil { + log = func(string) {} + } + + log("Detaching original installation medium...") + if err := umountLiveMedium(); err != nil { + log(fmt.Sprintf("Warning: could not unmount /run/live/medium: %v", err)) + } else { + log("Unmounted /run/live/medium.") + } + + device := strings.TrimSpace(status.Device) + if device == "" { + device = strings.TrimSpace(status.Source) + } + if device == "" || !strings.HasPrefix(device, "/dev/") { + log("No block device identified for eject; skipping media eject.") + return + } + + if err := ejectDevice(device); err != nil { + log(fmt.Sprintf("Warning: could not eject %s: %v", device, err)) + return + } + log(fmt.Sprintf("Ejected %s.", device)) +} + func verifyInstallToRAMStatus(status LiveBootSource, dstDir string, mediumRebound bool, log func(string)) error { if status.InRAM { return nil diff --git a/audit/internal/platform/install_to_ram_test.go b/audit/internal/platform/install_to_ram_test.go index 706adc0..73cf043 100644 --- a/audit/internal/platform/install_to_ram_test.go +++ b/audit/internal/platform/install_to_ram_test.go @@ -208,3 +208,75 @@ func TestEnsureLiveMediumAvailableRemountsSource(t *testing.T) { t.Fatalf("expected remount success log, logs=%v", logs) } } + +func TestDetachInstallMedium(t *testing.T) { + t.Parallel() + + origUmount := umountLiveMedium + origEject := ejectDevice + t.Cleanup(func() { + umountLiveMedium = origUmount + ejectDevice = origEject + }) + + t.Run("success", func(t *testing.T) { + var umountCalled bool + var ejected string + umountLiveMedium = func() error { + umountCalled = true + return nil + } + ejectDevice = func(device string) error { + ejected = device + return nil + } + var logs []string + detachInstallMedium(LiveBootSource{Kind: "cdrom", Device: "/dev/sr1"}, func(msg string) { logs = append(logs, msg) }) + if !umountCalled { + t.Fatal("expected umountLiveMedium to be called") + } + if ejected != "/dev/sr1" { + t.Fatalf("ejected=%q want /dev/sr1", ejected) + } + if len(logs) < 3 { + t.Fatalf("logs=%v", logs) + } + }) + + t.Run("no device", func(t *testing.T) { + umountLiveMedium = func() error { return nil } + ejectDevice = func(device string) error { + t.Fatalf("unexpected eject for %q", device) + return nil + } + var logs []string + detachInstallMedium(LiveBootSource{Kind: "ram", Source: "tmpfs"}, func(msg string) { logs = append(logs, msg) }) + found := false + for _, msg := range logs { + if msg == "No block device identified for eject; skipping media eject." { + found = true + break + } + } + if !found { + t.Fatalf("logs=%v", logs) + } + }) + + t.Run("eject failure is warning only", func(t *testing.T) { + umountLiveMedium = func() error { return nil } + ejectDevice = func(device string) error { return fmt.Errorf("exit status 1") } + var logs []string + detachInstallMedium(LiveBootSource{Kind: "usb", Device: "/dev/sdb1"}, func(msg string) { logs = append(logs, msg) }) + found := false + for _, msg := range logs { + if msg == "Warning: could not eject /dev/sdb1: exit status 1" { + found = true + break + } + } + if !found { + t.Fatalf("logs=%v", logs) + } + }) +}