#!/bin/sh
# bee-nvidia-recover — drain NVIDIA clients, then reset a GPU or reload drivers.

set -u

# Print one status line on stdout, prefixed with the tool tag.
# Arguments: the message words; they are joined with single spaces ("$*").
# printf '%s\n' is used instead of echo because echo's treatment of
# backslashes is implementation-defined under /bin/sh (some shells expand
# sequences such as \n or \c), which would mangle a logged value.
log() {
    printf '%s\n' "[bee-nvidia-recover] $*"
}

# Print one "blocker:" line on stdout naming a service or process that is
# being stopped or killed.
# Arguments: the description words; they are joined with single spaces ("$*").
# The text is often a raw process command line, so it is printed with
# printf '%s\n' rather than echo: echo may expand backslash sequences under
# some /bin/sh implementations and would then alter what is reported.
log_blocker() {
    printf '%s\n' "[bee-nvidia-recover] blocker: $*"
}

# Write the command synopsis to stdout; callers decide where it ends up
# (the dispatcher redirects it to stderr).
usage() {
    printf '%s\n' \
        'usage:' \
        '  bee-nvidia-recover restart-drivers' \
        '  bee-nvidia-recover reset-gpu <index>'
}

# Run `systemctl cat` on the unit named in $1 and pass its exit status on.
# Everything systemctl prints (stdout and stderr) is thrown away, so this is
# usable as a silent yes/no probe.
unit_exists() {
    {
        systemctl cat "$1"
    } >/dev/null 2>&1
}

# Ask `systemctl is-active --quiet` about the unit named in $1 and pass its
# exit status on. Only stderr is discarded here.
unit_is_active() {
    {
        systemctl is-active --quiet "$1"
    } 2>/dev/null
}

# Stop the systemd unit named in $1, but only if unit_is_active reports it
# as active.
# Returns: 0 when the unit was active and a stop was issued (the result of
#          `systemctl stop` itself is not reflected in the status),
#          1 when the unit was not active and nothing was done.
stop_unit_if_active() {
    target_unit="$1"
    unit_is_active "$target_unit" || return 1
    log "stopping $target_unit"
    systemctl stop "$target_unit"
    return 0
}

# Start a systemd unit when its marker says so.
# Arguments: $1 - unit name
#            $2 - marker; the unit is started only when this is exactly "1"
# The unit must also pass unit_exists. When either condition fails nothing
# happens and the status is 0; otherwise the status is that of
# `systemctl start`.
start_unit_if_marked() {
    wanted_unit="$1"
    start_flag="$2"
    [ "$start_flag" = "1" ] || return 0
    unit_exists "$wanted_unit" || return 0
    log "starting $wanted_unit"
    systemctl start "$wanted_unit"
}

# Poll until no process whose name is exactly $1 (pgrep -x) remains.
# The process list is checked up to 15 times with a one-second sleep between
# checks. Returns 0 as soon as a check finds nothing; after the 15th check
# still finds the process, logs a WARN line and returns 1.
wait_for_process_exit() {
    proc_name="$1"
    checks_done=0
    while :; do
        pgrep -x "$proc_name" >/dev/null 2>&1 || return 0
        checks_done=$((checks_done + 1))
        if [ "$checks_done" -lt 15 ]; then
            sleep 1
            continue
        fi
        log "WARN: $proc_name is still running after stop request"
        return 1
    done
}

# Kill every process whose full command line matches the pattern in $1
# (pgrep/pkill -f matching).
# When nothing matches, the function does nothing. Otherwise each matching
# process is reported through log_blocker, SIGTERM is sent, and one second
# later SIGKILL is sent to whatever still matches. Failures of pkill are
# deliberately ignored; the function always returns 0.
kill_pattern() {
    match_expr="$1"
    pgrep -f "$match_expr" >/dev/null 2>&1 || return 0

    pgrep -af "$match_expr" 2>/dev/null | while IFS= read -r proc_line; do
        if [ -n "$proc_line" ]; then
            log_blocker "$proc_line"
        fi
    done
    log "killing processes matching: $match_expr"
    pkill -TERM -f "$match_expr" >/dev/null 2>&1 || true
    sleep 1
    pkill -KILL -f "$match_expr" >/dev/null 2>&1 || true
}

# Stop the services and processes this script treats as GPU clients.
# Globals written:
#   display_was_active - 1 if display-manager.service or lightdm.service was
#                        active and got stopped, else 0
#   fabric_was_active  - 1 if nvidia-fabricmanager.service was active and got
#                        stopped, else 0
#   (both are read later by restore_gpu_clients; the loop variables unit and
#   pattern are also left set, since plain sh has no function-local scope)
# Order of operations: display manager units, fabric manager, nv-hostengine
# (SIGTERM, then SIGKILL if wait_for_process_exit gives up), and finally
# kill_pattern for a fixed list of command-line patterns.
drain_gpu_clients() {
    display_was_active=0
    fabric_was_active=0

    for unit in display-manager.service lightdm.service; do
        if unit_exists "$unit" && stop_unit_if_active "$unit"; then
            log_blocker "service $unit"
            display_was_active=1
        fi
    done

    if unit_exists nvidia-fabricmanager.service && stop_unit_if_active nvidia-fabricmanager.service; then
        log_blocker "service nvidia-fabricmanager.service"
        fabric_was_active=1
    fi

    if pgrep -x nv-hostengine >/dev/null 2>&1; then
        # List by exact process name (-x, no -f): the same match the check
        # above and the pkill calls below use. Matching the anchored name
        # against the full command line (-f) would miss an engine that was
        # started through a path or with arguments, so it would be killed
        # without ever being reported as a blocker.
        pgrep -ax nv-hostengine 2>/dev/null | while IFS= read -r line; do
            [ -n "$line" ] || continue
            log_blocker "$line"
        done
        log "stopping nv-hostengine"
        pkill -TERM -x nv-hostengine >/dev/null 2>&1 || true
        # Escalate to SIGKILL only when the graceful wait times out.
        wait_for_process_exit nv-hostengine || pkill -KILL -x nv-hostengine >/dev/null 2>&1 || true
    fi

    for pattern in \
        "nvidia-smi" \
        "dcgmi" \
        "nvvs" \
        "dcgmproftester" \
        "all_reduce_perf" \
        "nvtop" \
        "bee-gpu-burn" \
        "bee-john-gpu-stress" \
        "bee-nccl-gpu-stress" \
        "Xorg" \
        "Xwayland"; do
        kill_pattern "$pattern"
    done
}

# Bring GPU-related services back after a reset or driver reload.
# Globals read: fabric_was_active, display_was_active (written by
#   drain_gpu_clients; treated as 0 when unset).
# Steps, each best-effort:
#   1. enable persistence mode via `nvidia-smi -pm 1` if nvidia-smi is found;
#   2. start nv-hostengine if the command is found and no process with that
#      exact name is running;
#   3. start nvidia-fabricmanager.service and display-manager.service when
#      their markers are "1";
#   4. if the display marker is "1" and lightdm.service exists but is still
#      not active, start it as well.
# Failures in steps 1 and 2 are logged as WARN lines and do not stop the
# remaining steps.
restore_gpu_clients() {
    if command -v nvidia-smi >/dev/null 2>&1; then
        if nvidia-smi -pm 1 >/dev/null 2>&1; then
            log "enabled NVIDIA persistence mode"
        else
            log "WARN: failed to enable NVIDIA persistence mode"
        fi
    fi

    if command -v nv-hostengine >/dev/null 2>&1 && ! pgrep -x nv-hostengine >/dev/null 2>&1; then
        log "starting nv-hostengine"
        # Report a failed start the same way the persistence-mode step does,
        # instead of letting it pass without a tagged log line.
        nv-hostengine || log "WARN: failed to start nv-hostengine"
    fi

    start_unit_if_marked nvidia-fabricmanager.service "${fabric_was_active:-0}"
    start_unit_if_marked display-manager.service "${display_was_active:-0}"
    if [ "${display_was_active:-0}" = "1" ] && unit_exists lightdm.service && ! unit_is_active lightdm.service; then
        start_unit_if_marked lightdm.service "1"
    fi
}

# Unload and reload the NVIDIA kernel driver stack.
# Steps: drain GPU clients, rmmod each of nvidia_uvm / nvidia_drm /
# nvidia_modeset / nvidia that lsmod lists (in that order), remove the
# /dev/nvidia* device nodes, run /usr/local/bin/bee-nvidia-load, then restore
# the GPU clients.
# Returns: the exit status of bee-nvidia-load. Clients are restored before
#          returning whether or not the load succeeded, and the load status
#          is what the caller sees rather than the status of the restore step.
# A module that fails to unload is logged as a WARN and the sequence
# continues.
restart_drivers() {
    drain_gpu_clients
    for mod in nvidia_uvm nvidia_drm nvidia_modeset nvidia; do
        if lsmod | awk '{print $1}' | grep -qx "$mod"; then
            log "unloading module $mod"
            rmmod "$mod" || log "WARN: failed to unload module $mod"
        fi
    done
    # Best effort: the nodes may not exist.
    rm -f /dev/nvidiactl /dev/nvidia-uvm /dev/nvidia-uvm-tools /dev/nvidia[0-9]* 2>/dev/null || true
    log "reloading NVIDIA driver stack"
    /usr/local/bin/bee-nvidia-load
    load_status=$?
    if [ "$load_status" -ne 0 ]; then
        log "ERROR: bee-nvidia-load failed with exit status $load_status"
    fi
    restore_gpu_clients
    return "$load_status"
}

# Reset one GPU with `nvidia-smi -r -i <index>`.
# Arguments: $1 - GPU selector passed to nvidia-smi -i (must be non-empty)
# Returns: 2 if the selector is empty (nothing is drained or reset in that
#          case); otherwise the exit status of `nvidia-smi -r`.
# GPU clients are drained before the reset and restored afterwards even when
# the reset fails; the reset status is what the caller sees rather than the
# status of the restore step.
reset_gpu() {
    index="${1:-}"
    # Refuse before draining: stopping every GPU client for a reset that
    # cannot possibly succeed would be pure disruption.
    if [ -z "$index" ]; then
        log "ERROR: reset-gpu requires a non-empty GPU index"
        return 2
    fi
    drain_gpu_clients
    log "resetting GPU $index"
    nvidia-smi -r -i "$index"
    reset_status=$?
    if [ "$reset_status" -ne 0 ]; then
        log "ERROR: GPU $index reset failed (nvidia-smi exit status $reset_status)"
    fi
    restore_gpu_clients
    return "$reset_status"
}

# Entry point: dispatch on the first command-line argument.
#   restart-drivers     -> restart_drivers
#   reset-gpu <index>   -> reset_gpu <index> (exactly two arguments required)
# Anything else, or a wrong argument count for reset-gpu, prints the usage
# text on stderr and exits with status 2. Otherwise the script's exit status
# is that of the selected action.
main() {
    case "${1:-}" in
        reset-gpu)
            [ "$#" -eq 2 ] || {
                usage >&2
                exit 2
            }
            reset_gpu "$2"
            ;;
        restart-drivers)
            restart_drivers
            ;;
        *)
            usage >&2
            exit 2
            ;;
    esac
}

main "$@"
