#!/bin/sh
set -eu

# Default option values; each can be overridden by the matching flag below.
#
# SECONDS is a special, self-incrementing variable in bash (which some systems
# install as /bin/sh): after an assignment, every read returns the assigned
# value plus the seconds elapsed since. Unsetting it first removes that
# special behaviour, so the configured run time is read back unchanged instead
# of growing by the time spent in nvidia-smi queries and stagger sleeps.
unset SECONDS
SECONDS=5
STAGGER_SECONDS=0
# A value <= 0 sizes each GPU's buffer from its reported total memory.
SIZE_MB=0
# Empty DEVICES selects every GPU nvidia-smi reports; empty EXCLUDE drops none.
DEVICES=""
EXCLUDE=""
PRECISION=""
PRECISION_PLAN=""
PRECISION_PLAN_SECONDS=""
WORKER="/usr/local/lib/bee/bee-gpu-burn-worker"

usage() {
    # Print the one-line synopsis to stderr and exit 2 (command-line error).
    printf '%s\n' "usage: $0 [--seconds N] [--stagger-seconds N] [--size-mb N] [--devices 0,1] [--exclude 2,3] [--precision int8|fp8|fp16|fp32|fp64|fp4] [--precision-plan p1,p2,...,mixed] [--precision-plan-seconds s1,s2,...]" >&2
    exit 2
}

# Normalize a comma-separated list of integers: strip all whitespace, drop
# empty entries, sort numerically, remove duplicates and re-join with commas.
# The result is empty when $1 is empty or omitted.
# printf is used rather than echo so that values such as "-e"/"-n" or strings
# containing backslashes are passed through verbatim by every /bin/sh.
normalize_list() {
    printf '%s\n' "${1:-}" |
        tr ',' '\n' |
        sed 's/[[:space:]]//g' |
        awk 'NF' |
        sort -n |
        uniq |
        paste -sd, -
}

# Succeed (status 0) if $1 equals one of the entries of the comma-separated
# list $2, fail (status 1) otherwise; an empty or omitted list matches no
# non-empty value. The needle is part of a quoted case pattern, so it is
# compared literally: characters such as "." or "*" are not wildcards.
contains_csv() {
    case ",${2:-}," in
        *",$1,"*) return 0 ;;
    esac
    return 1
}

# Parse command-line options. Every option takes exactly one value; an unknown
# option or a missing value prints the usage text and exits 2.
while [ "$#" -gt 0 ]; do
    case "$1" in
        --seconds|-t) [ "$#" -ge 2 ] || usage; SECONDS="$2"; shift 2 ;;
        --stagger-seconds) [ "$#" -ge 2 ] || usage; STAGGER_SECONDS="$2"; shift 2 ;;
        --size-mb|-m) [ "$#" -ge 2 ] || usage; SIZE_MB="$2"; shift 2 ;;
        --devices) [ "$#" -ge 2 ] || usage; DEVICES="$2"; shift 2 ;;
        --exclude) [ "$#" -ge 2 ] || usage; EXCLUDE="$2"; shift 2 ;;
        --precision) [ "$#" -ge 2 ] || usage; PRECISION="$2"; shift 2 ;;
        --precision-plan) [ "$#" -ge 2 ] || usage; PRECISION_PLAN="$2"; shift 2 ;;
        --precision-plan-seconds) [ "$#" -ge 2 ] || usage; PRECISION_PLAN_SECONDS="$2"; shift 2 ;;
        *) usage ;;
    esac
done

# The numeric options are used later in $(( )) arithmetic and integer [ ]
# tests. Accept only plain non-negative decimal integers: bash evaluates a
# non-numeric value inside $(( )) as an arithmetic expression (for example a
# variable name), and it reads a number with a leading zero as octal.
for num_value in "${SECONDS}" "${STAGGER_SECONDS}" "${SIZE_MB}"; do
    case "${num_value}" in
        ''|*[!0-9]*|0[0-9]*)
            echo "invalid numeric option value: '${num_value}' (expected a non-negative integer)" >&2
            usage
            ;;
    esac
done

# Resolve which GPUs to load: start from --devices (or every GPU nvidia-smi
# reports when --devices is omitted), drop anything listed in --exclude, and
# fail unless at least one GPU remains. The result is FINAL, a comma-separated
# list of nvidia-smi GPU indices.
[ -x "${WORKER}" ] || { echo "bee-gpu-burn worker not found: ${WORKER}" >&2; exit 1; }
command -v nvidia-smi >/dev/null 2>&1 || { echo "nvidia-smi not found in PATH" >&2; exit 1; }

ALL_DEVICES=$(nvidia-smi --query-gpu=index --format=csv,noheader,nounits 2>/dev/null | sed 's/[[:space:]]//g' | awk 'NF' | paste -sd, -)
[ -n "${ALL_DEVICES}" ] || { echo "nvidia-smi found no NVIDIA GPUs" >&2; exit 1; }

DEVICES=$(normalize_list "${DEVICES}")
EXCLUDE=$(normalize_list "${EXCLUDE}")

# These lists are split with unquoted expansions below, where a glob character
# would be matched against files in the current directory. A valid GPU index
# list consists of digits and commas only.
case "${DEVICES}" in
    *[!0-9,]*) echo "invalid --devices list: ${DEVICES}" >&2; exit 2 ;;
esac
case "${EXCLUDE}" in
    *[!0-9,]*) echo "invalid --exclude list: ${EXCLUDE}" >&2; exit 2 ;;
esac

SELECTED="${DEVICES:-${ALL_DEVICES}}"

FINAL=""
for id in $(echo "${SELECTED}" | tr ',' ' '); do
    [ -n "${id}" ] || continue
    # Reject indices nvidia-smi did not report instead of handing them to the
    # worker via CUDA_VISIBLE_DEVICES.
    if ! contains_csv "${id}" "${ALL_DEVICES}"; then
        echo "unknown NVIDIA GPU index: ${id} (available: ${ALL_DEVICES})" >&2
        exit 1
    fi
    if contains_csv "${id}" "${EXCLUDE}"; then
        continue
    fi
    FINAL="${FINAL:+${FINAL},}${id}"
done

[ -n "${FINAL}" ] || { echo "no NVIDIA GPUs selected after filters" >&2; exit 1; }

echo "loader=bee-gpu-burn"
echo "selected_gpus=${FINAL}"
echo "stagger_seconds=${STAGGER_SECONDS}"

# Enumerate CUDA devices by PCI bus ID so that CUDA_VISIBLE_DEVICES=<id> below
# is intended to address the same GPU as nvidia-smi index <id>.
# NOTE(review): relies on nvidia-smi indices also following PCI bus order.
export CUDA_DEVICE_ORDER="PCI_BUS_ID"

# Launched workers, as space-separated "pid:gpu_id:log_path" entries.
WORKERS=""
TMP_DIR=$(mktemp -d)

# Signal handler: terminate every launched worker, remove the log directory
# and exit with status $1. Asynchronous commands started by a non-interactive
# shell ignore SIGINT, so without the explicit kill an interrupted run would
# leave the workers loading the GPUs after this script is gone.
stop_workers() {
    trap - INT TERM
    for worker_spec in ${WORKERS}; do
        kill "${worker_spec%%:*}" 2>/dev/null || true
    done
    rm -rf "${TMP_DIR}"
    exit "$1"
}
trap 'rm -rf "${TMP_DIR}"' EXIT
trap 'stop_workers 130' INT
trap 'stop_workers 143' TERM

# Optional precision flags, built once as positional parameters so that each
# value reaches the worker as exactly one argument (no word splitting or
# globbing). Option parsing has already shifted away every original argument.
set --
[ -z "${PRECISION}" ] || set -- "$@" --precision "${PRECISION}"
[ -z "${PRECISION_PLAN}" ] || set -- "$@" --precision-plan "${PRECISION_PLAN}"
[ -z "${PRECISION_PLAN_SECONDS}" ] || set -- "$@" --precision-plan-seconds "${PRECISION_PLAN_SECONDS}"

GPU_COUNT=$(echo "${FINAL}" | tr ',' '\n' | awk 'NF' | wc -l | tr -d '[:space:]')
gpu_pos=0
for id in $(echo "${FINAL}" | tr ',' ' '); do
    gpu_pos=$((gpu_pos + 1))
    log="${TMP_DIR}/gpu-${id}.log"

    # A size <= 0 means: use 95% of this GPU's total memory as reported by
    # nvidia-smi, or 512 MB when that query yields nothing usable.
    gpu_size_mb="${SIZE_MB}"
    if [ "${gpu_size_mb}" -le 0 ] 2>/dev/null; then
        total_mb=$(nvidia-smi --id="${id}" --query-gpu=memory.total --format=csv,noheader,nounits 2>/dev/null | tr -d '[:space:]')
        if [ -n "${total_mb}" ] && [ "${total_mb}" -gt 0 ] 2>/dev/null; then
            gpu_size_mb=$(( total_mb * 95 / 100 ))
        else
            gpu_size_mb=512
        fi
    fi

    # Each later GPU starts STAGGER_SECONDS after the previous one, so earlier
    # GPUs run correspondingly longer and, launch overhead aside, every worker
    # is scheduled to stop at the same time.
    extra_sec=$(( STAGGER_SECONDS * (GPU_COUNT - gpu_pos) ))
    gpu_seconds=$(( SECONDS + extra_sec ))
    echo "starting gpu ${id} size=${gpu_size_mb}MB seconds=${gpu_seconds}"
    CUDA_VISIBLE_DEVICES="${id}" \
        "${WORKER}" --device 0 --seconds "${gpu_seconds}" --size-mb "${gpu_size_mb}" "$@" >"${log}" 2>&1 &
    pid=$!
    WORKERS="${WORKERS} ${pid}:${id}:${log}"
    if [ "${STAGGER_SECONDS}" -gt 0 ] && [ "${gpu_pos}" -lt "${GPU_COUNT}" ]; then
        sleep "${STAGGER_SECONDS}"
    fi
done

# Reap the workers in launch order, report each one's outcome and replay its
# log prefixed with "[gpu N] ". Exit 1 if any worker failed, otherwise 0.
status=0
for entry in ${WORKERS}; do
    worker_pid=${entry%%:*}
    id_and_log=${entry#*:}
    id=${id_and_log%%:*}
    log=${id_and_log#*:}

    rc=0
    wait "${worker_pid}" || rc=$?
    if [ "${rc}" -eq 0 ]; then
        echo "gpu ${id} finished: OK"
    else
        echo "gpu ${id} finished: FAILED rc=${rc}"
        status=1
    fi

    sed "s/^/[gpu ${id}] /" "${log}" || true
done

exit "${status}"
