#!/bin/sh
# bee-nvidia-recover — drain NVIDIA clients, then reset a GPU or reload drivers.

set -u

# Print a status line to stdout, prefixed with the tool name.
# Uses printf instead of echo: POSIX echo (e.g. dash's builtin) expands
# backslash escapes, and these messages can carry process command lines
# captured from ps/pgrep that may contain backslashes.
log() {
    printf '%s\n' "[bee-nvidia-recover] $*"
}

# Print a "blocker" diagnostic to stdout (a process or service that holds
# the GPU). printf avoids POSIX echo's backslash-escape expansion, since
# the argument is often a raw ps/pgrep command line.
log_blocker() {
    printf '%s\n' "[bee-nvidia-recover] blocker: $*"
}

# Print command-line help to stdout (callers redirect to stderr as needed).
usage() {
    printf '%s\n' \
        'usage:' \
        '  bee-nvidia-recover restart-drivers' \
        '  bee-nvidia-recover reset-gpu <index>'
}

# True when systemctl can locate a unit file for "$1"; status is
# propagated straight from `systemctl cat`.
unit_exists() {
    systemctl cat "$1" 2>/dev/null 1>/dev/null
}

# True when systemd reports unit "$1" as active. Redirecting both streams
# is equivalent to the former `--quiet` + stderr drop: same exit status,
# no output either way.
unit_is_active() {
    systemctl is-active "$1" >/dev/null 2>&1
}

# Stop a systemd unit only when it is currently active.
# Returns 0 when a stop was issued, 1 when the unit was not active.
stop_unit_if_active() {
    unit="$1"
    unit_is_active "$unit" || return 1
    log "stopping $unit"
    systemctl stop "$unit"
    return 0
}

# Start unit "$1" again, but only when its "was active" marker "$2" is "1"
# and the unit file actually exists on this host.
start_unit_if_marked() {
    unit="$1"
    marker="$2"
    [ "$marker" = "1" ] || return 0
    unit_exists "$unit" || return 0
    log "starting $unit"
    systemctl start "$unit"
}

# Poll (1s interval, 15 attempts max) until no process named "$1" remains.
# Returns 0 once the name is gone, 1 (with a warning) if it outlives the
# polling window.
wait_for_process_exit() {
    name="$1"
    attempt=0
    while :; do
        pgrep -x "$name" >/dev/null 2>&1 || return 0
        attempt=$((attempt + 1))
        if [ "$attempt" -ge 15 ]; then
            log "WARN: $name is still running after stop request"
            return 1
        fi
        sleep 1
    done
}

# Log the ps details (pid, comm, args) behind a PID as a blocker line;
# fall back to the bare PID when ps has no row for it (already exited).
log_pid_details() {
    pid="$1"
    detail=$(ps -p "$pid" -o pid=,comm=,args= 2>/dev/null | sed 's/^[[:space:]]*//')
    if [ -z "$detail" ]; then
        log_blocker "pid $pid"
        return
    fi
    log_blocker "$detail"
}

# Emit the PIDs of compute apps running on GPU "$1", one per line.
# Silently emits nothing when nvidia-smi is not installed.
collect_gpu_compute_pids() {
    gpu="$1"
    command -v nvidia-smi >/dev/null 2>&1 || return 0
    nvidia-smi --id="$gpu" \
        --query-compute-apps=pid \
        --format=csv,noheader,nounits 2>/dev/null \
        | sed 's/^[[:space:]]*//;s/[[:space:]]*$//' \
        | grep -E '^[0-9]+$' || true
}

# Emit the PIDs currently holding /dev/nvidia<"$1"> open (per fuser),
# one per line. Emits nothing when the node or fuser is missing.
collect_gpu_device_pids() {
    node="/dev/nvidia$1"
    [ -e "$node" ] || return 0
    command -v fuser >/dev/null 2>&1 || return 0
    # fuser prints PIDs on stdout; strip any access-mode suffix letters.
    fuser "$node" 2>/dev/null \
        | tr ' ' '\n' \
        | sed 's/[^0-9].*$//' \
        | grep -E '^[0-9]+$' || true
}

# Union of compute-app PIDs and device-node holder PIDs for GPU "$1",
# blank lines dropped and duplicates removed.
collect_gpu_holder_pids() {
    {
        collect_gpu_compute_pids "$1"
        collect_gpu_device_pids "$1"
    } | awk 'NF' | sort -u
}

# Terminate every PID in the whitespace-separated list "$1": log each one
# as a blocker, send SIGTERM, give them a second, then SIGKILL survivors.
# A no-op (status 0) for an empty list.
kill_pid_list() {
    targets="$1"
    [ -n "$targets" ] || return 0

    for victim in $targets; do
        log_pid_details "$victim"
    done
    log "terminating GPU holder PIDs: $(echo "$targets" | tr '\n' ' ' | sed 's/[[:space:]]*$//')"
    for victim in $targets; do
        kill -TERM "$victim" >/dev/null 2>&1 || true
    done
    sleep 1
    for victim in $targets; do
        kill -0 "$victim" >/dev/null 2>&1 || continue
        log "forcing GPU holder PID $victim to exit"
        kill -KILL "$victim" >/dev/null 2>&1 || true
    done
}

# True when one of the processes holding /dev/nvidia<"$1"> looks like a
# display server (Xorg, Xwayland, bare X, or gnome-shell).
gpu_has_display_holders() {
    dev_holders=$(collect_gpu_device_pids "$1")
    [ -n "$dev_holders" ] || return 1
    for dh_pid in $dev_holders; do
        dh_comm=$(ps -p "$dh_pid" -o comm= 2>/dev/null | sed 's/^[[:space:]]*//;s/[[:space:]]*$//')
        case "$dh_comm" in
            Xorg|Xwayland|X|gnome-shell) return 0 ;;
        esac
    done
    return 1
}

# Stop the DCGM host engine when it is running: log it as a blocker,
# TERM it, escalate to KILL if it lingers, and set hostengine_was_active
# so restore_gpu_clients can restart it. Returns 1 when it was not running.
stop_nv_hostengine_if_running() {
    pgrep -x nv-hostengine >/dev/null 2>&1 || return 1
    pgrep -af "^nv-hostengine$" 2>/dev/null | while IFS= read -r line; do
        [ -n "$line" ] || continue
        log_blocker "$line"
    done
    log "stopping nv-hostengine"
    pkill -TERM -x nv-hostengine >/dev/null 2>&1 || true
    if ! wait_for_process_exit nv-hostengine; then
        pkill -KILL -x nv-hostengine >/dev/null 2>&1 || true
    fi
    hostengine_was_active=1
    return 0
}

# Stop nvidia-fabricmanager when the unit exists and is active; set
# fabric_was_active so it can be restarted later. Returns 1 otherwise.
stop_fabricmanager_if_active() {
    if ! unit_exists nvidia-fabricmanager.service; then
        return 1
    fi
    if ! stop_unit_if_active nvidia-fabricmanager.service; then
        return 1
    fi
    log_blocker "service nvidia-fabricmanager.service"
    fabric_was_active=1
    return 0
}

# Stop any active display-manager units; set display_was_active for the
# restore pass. Returns 0 when at least one unit was stopped, 1 otherwise.
stop_display_stack_if_active() {
    any_stopped=1
    for dm_unit in display-manager.service lightdm.service; do
        unit_exists "$dm_unit" || continue
        stop_unit_if_active "$dm_unit" || continue
        log_blocker "service $dm_unit"
        display_was_active=1
        any_stopped=0
    done
    return "$any_stopped"
}

# Ask nvidia-smi to reset GPU "$1"; nvidia-smi's exit status propagates
# to the caller as the success/failure signal.
try_gpu_reset() {
    log "resetting GPU $1"
    nvidia-smi -r -i "$1"
}

# Quiesce every known NVIDIA client ahead of a full driver reload:
# the DCGM host engine, the fabric manager, any display managers, and
# finally any process still holding a /dev/nvidia* device node open.
# Resets then sets the *_was_active markers that restore_gpu_clients
# consumes afterwards.
drain_gpu_clients() {
    display_was_active=0
    fabric_was_active=0
    hostengine_was_active=0

    # Each helper logs the blocker, stops it, and sets its marker; their
    # bodies are exactly the logic that used to be inlined here.
    stop_nv_hostengine_if_running || true
    stop_fabricmanager_if_active || true
    stop_display_stack_if_active || true

    # Whatever still holds a device node after the services are gone gets
    # killed outright.
    for dev in /dev/nvidia[0-9]*; do
        [ -e "$dev" ] || continue
        holders=$(collect_gpu_device_pids "${dev#/dev/nvidia}")
        kill_pid_list "$holders"
    done
}

# Reverse of the drain: re-enable persistence mode, then restart exactly
# the services whose *_was_active markers were set (unset markers default
# to 0, so a partial drain restores nothing extra).
restore_gpu_clients() {
    if command -v nvidia-smi >/dev/null 2>&1; then
        if nvidia-smi -pm 1 >/dev/null 2>&1; then
            log "enabled NVIDIA persistence mode"
        else
            log "WARN: failed to enable NVIDIA persistence mode"
        fi
    fi

    # Relaunch the DCGM host engine only if we stopped it, the binary is
    # available, and nothing restarted it in the meantime.
    if [ "${hostengine_was_active:-0}" = "1" ] \
        && command -v nv-hostengine >/dev/null 2>&1 \
        && ! pgrep -x nv-hostengine >/dev/null 2>&1; then
        log "starting nv-hostengine"
        nv-hostengine
    fi

    start_unit_if_marked nvidia-fabricmanager.service "${fabric_was_active:-0}"
    start_unit_if_marked display-manager.service "${display_was_active:-0}"
    # display-manager.service may be an alias; start lightdm explicitly
    # when it exists but did not come up with the alias.
    if [ "${display_was_active:-0}" = "1" ] \
        && unit_exists lightdm.service \
        && ! unit_is_active lightdm.service; then
        start_unit_if_marked lightdm.service "1"
    fi
}

# Full driver-stack reload: drain all clients, unload the NVIDIA kernel
# modules, clear stale /dev nodes, rerun the loader, restore clients.
restart_drivers() {
    drain_gpu_clients
    # Unload in reverse dependency order: nvidia_uvm/nvidia_drm/
    # nvidia_modeset all depend on nvidia, so nvidia goes last.
    for mod in nvidia_uvm nvidia_drm nvidia_modeset nvidia; do
        if lsmod | awk '{print $1}' | grep -qx "$mod"; then
            log "unloading module $mod"
            # NOTE(review): an rmmod failure (module busy) is not checked;
            # the node removal and reload below proceed regardless — confirm
            # this best-effort behavior is intended.
            rmmod "$mod"
        fi
    done
    # Remove stale device nodes so the loader recreates them cleanly.
    rm -f /dev/nvidiactl /dev/nvidia-uvm /dev/nvidia-uvm-tools /dev/nvidia[0-9]* 2>/dev/null || true
    log "reloading NVIDIA driver stack"
    # External loader script; presumably modprobes the stack and recreates
    # the /dev nodes — not visible from this file.
    /usr/local/bin/bee-nvidia-load
    restore_gpu_clients
}

# Kill the current holders of GPU "$1" (if any), then attempt the reset.
# Returns the reset's exit status. Private helper for reset_gpu; variable
# names are prefixed to avoid clobbering the caller's globals.
kill_holders_and_reset() {
    khr_index="$1"
    khr_holders=$(collect_gpu_holder_pids "$khr_index")
    if [ -n "$khr_holders" ]; then
        kill_pid_list "$khr_holders"
    fi
    try_gpu_reset "$khr_index"
}

# Escalating reset ladder for a single GPU. Each rung kills the current
# holders and retries the reset; between rungs progressively heavier
# services are stopped (host engine, fabric manager, display stack — the
# last only when a display server actually holds this GPU). Stopped
# services are restored on every exit path; returns the final reset status.
reset_gpu() {
    index="$1"
    display_was_active=0
    fabric_was_active=0
    hostengine_was_active=0

    # Attempt 1: direct holders only.
    if kill_holders_and_reset "$index"; then
        restore_gpu_clients
        return 0
    fi

    # Attempt 2: also stop the DCGM host engine.
    stop_nv_hostengine_if_running || true
    if kill_holders_and_reset "$index"; then
        restore_gpu_clients
        return 0
    fi

    # Attempt 3: also stop the fabric manager.
    stop_fabricmanager_if_active || true
    if kill_holders_and_reset "$index"; then
        restore_gpu_clients
        return 0
    fi

    # Attempt 4: stop the display stack, but only when a display server
    # holds this particular GPU.
    if gpu_has_display_holders "$index"; then
        stop_display_stack_if_active || true
        if kill_holders_and_reset "$index"; then
            restore_gpu_clients
            return 0
        fi
    fi

    # Last attempt: log any survivors explicitly, kill them, and take the
    # reset's status as our own.
    holders=$(collect_gpu_holder_pids "$index")
    if [ -n "$holders" ]; then
        log "GPU $index still has holders after targeted drain"
        kill_pid_list "$holders"
    fi
    try_gpu_reset "$index"
    rc=$?
    restore_gpu_clients
    return "$rc"
}

# --- command dispatch ----------------------------------------------------
cmd="${1:-}"
case "$cmd" in
    restart-drivers)
        # Reject stray arguments, matching the reset-gpu arity check.
        if [ "$#" -ne 1 ]; then
            usage >&2
            exit 2
        fi
        restart_drivers
        ;;
    reset-gpu)
        if [ "$#" -ne 2 ]; then
            usage >&2
            exit 2
        fi
        # Validate the index up front: a non-numeric value would silently
        # match no /dev/nvidia* node and confuse nvidia-smi --id.
        case "$2" in
            ''|*[!0-9]*)
                log "invalid GPU index: $2" >&2
                usage >&2
                exit 2
                ;;
        esac
        reset_gpu "$2"
        ;;
    *)
        usage >&2
        exit 2
        ;;
esac