#!/bin/sh
set -eu

# Built-in defaults for the command-line options.
#
# NOTE(review): SECONDS is a special variable in bash and ksh (an elapsed-time
# counter that keeps advancing after it is assigned). Code that reads it back
# under those shells will not see a constant 5; prefer a differently named
# variable for the burn duration.
SECONDS=5
# Worker buffer size in MB; a value <= 0 makes the launch loop derive the size
# from each GPU's total memory instead.
SIZE_MB=0
# Comma-separated GPU ids to load; empty means every GPU reported by nvidia-smi.
DEVICES=""
# Comma-separated GPU ids to leave out of the selection.
EXCLUDE=""
# Executable started once per selected GPU.
WORKER="/usr/local/lib/bee/bee-gpu-burn-worker"

# Write the one-line command synopsis to stderr, then terminate the script
# with exit status 2 (the status used here for command-line errors).
usage() {
    printf 'usage: %s [--seconds N] [--size-mb N] [--devices 0,1] [--exclude 2,3]\n' "$0" 1>&2
    exit 2
}

# normalize_list [CSV]
#
# Canonicalise a comma-separated list given as $1 and print it on stdout:
# all whitespace is removed from every item, empty items are dropped, the
# remaining items are sorted numerically, duplicates are collapsed, and the
# result is joined back together with commas. A missing or empty argument
# yields an empty result.
normalize_list() {
    echo "${1:-}" \
        | tr ',' '\n' \
        | sed -e 's/[[:space:]]//g' -e '/^$/d' \
        | sort -n \
        | uniq \
        | paste -s -d , -
}

# contains_csv NEEDLE [HAYSTACK]
#
# Succeeds (status 0) when NEEDLE is exactly one of the comma-separated items
# of HAYSTACK, and fails (status 1) otherwise. A missing or empty HAYSTACK
# never contains a non-empty NEEDLE.
#
# Both strings are wrapped in commas so that only whole items can match
# ("2" is not found in "12,22"). The match is a literal comparison done with
# a case pattern: characters such as "." or "*" in NEEDLE are not treated as
# wildcards, and no external process is started.
contains_csv() {
    case ",${2:-}," in
        *",$1,"*) return 0 ;;
        *) return 1 ;;
    esac
}

# ---------------------------------------------------------------------------
# Main script body: parse the options, choose the GPUs to load, start one
# worker per GPU in the background, then wait for all of them and print each
# worker's result and log.
#
# Exit status: 0 when every worker succeeded; 1 when a worker failed, the
# worker binary is missing or no GPU is left after filtering; 2 on a usage
# error; 130 / 143 when interrupted by SIGINT / SIGTERM.
# ---------------------------------------------------------------------------

# Burn duration in seconds handed to every worker (default 5). It is kept in
# BURN_SECONDS and the file-level SECONDS variable is deliberately not read:
# bash and ksh treat SECONDS as a live elapsed-time counter, so reading it
# back later would give each worker a different, inflated duration.
BURN_SECONDS=5

while [ "$#" -gt 0 ]; do
    case "$1" in
        --seconds|-t) [ "$#" -ge 2 ] || usage; BURN_SECONDS="$2"; shift 2 ;;
        --size-mb|-m) [ "$#" -ge 2 ] || usage; SIZE_MB="$2"; shift 2 ;;
        --devices) [ "$#" -ge 2 ] || usage; DEVICES="$2"; shift 2 ;;
        --exclude) [ "$#" -ge 2 ] || usage; EXCLUDE="$2"; shift 2 ;;
        *) usage ;;
    esac
done

[ -x "${WORKER}" ] || { echo "bee-gpu-burn worker not found: ${WORKER}" >&2; exit 1; }

# Every GPU index nvidia-smi reports, as a comma-separated list. A failing or
# missing nvidia-smi produces an empty list, which is rejected just below.
ALL_DEVICES=$(nvidia-smi --query-gpu=index --format=csv,noheader,nounits 2>/dev/null | sed 's/[[:space:]]//g' | awk 'NF' | paste -sd, -)
[ -n "${ALL_DEVICES}" ] || { echo "nvidia-smi found no NVIDIA GPUs" >&2; exit 1; }

# Start from the requested devices (or all detected ones when --devices was
# not given) and drop everything listed in --exclude.
DEVICES=$(normalize_list "${DEVICES}")
EXCLUDE=$(normalize_list "${EXCLUDE}")
SELECTED="${DEVICES}"
if [ -z "${SELECTED}" ]; then
    SELECTED="${ALL_DEVICES}"
fi

FINAL=""
for id in $(echo "${SELECTED}" | tr ',' ' '); do
    [ -n "${id}" ] || continue
    if contains_csv "${id}" "${EXCLUDE}"; then
        continue
    fi
    if [ -z "${FINAL}" ]; then
        FINAL="${id}"
    else
        FINAL="${FINAL},${id}"
    fi
done

[ -n "${FINAL}" ] || { echo "no NVIDIA GPUs selected after filters" >&2; exit 1; }

echo "loader=bee-gpu-burn"
echo "selected_gpus=${FINAL}"

export CUDA_DEVICE_ORDER="PCI_BUS_ID"

TMP_DIR=$(mktemp -d)

# PIDs of workers that were started but not yet reaped, in launch order; every
# entry is preceded by exactly one space.
LIVE_PIDS=""

# Stop workers that are still running and remove the per-run log directory.
# Runs on every exit path through the EXIT trap.
cleanup() {
    # Do not let a second signal cut the cleanup short.
    trap '' INT TERM
    for live_pid in ${LIVE_PIDS}; do
        kill "${live_pid}" 2>/dev/null || true
    done
    rm -rf "${TMP_DIR}"
}
trap cleanup EXIT
# A signal must end the script (which then runs cleanup) instead of merely
# interrupting the current command. The workers are background jobs and are
# not stopped by the shell on their own, so cleanup signals them explicitly.
trap 'exit 130' INT
trap 'exit 143' TERM

# Launch phase: one background worker per selected GPU. Each worker sees only
# its own GPU through CUDA_VISIBLE_DEVICES, which is why it is always told to
# use device 0.
WORKERS=""
for id in $(echo "${FINAL}" | tr ',' ' '); do
    log="${TMP_DIR}/gpu-${id}.log"
    gpu_size_mb="${SIZE_MB}"
    # A size <= 0 means "automatic": use 95% of the GPU's total memory, or
    # 512 MB when the total cannot be queried.
    if [ "${gpu_size_mb}" -le 0 ] 2>/dev/null; then
        total_mb=$(nvidia-smi --id="${id}" --query-gpu=memory.total --format=csv,noheader,nounits 2>/dev/null | tr -d '[:space:]')
        if [ -n "${total_mb}" ] && [ "${total_mb}" -gt 0 ] 2>/dev/null; then
            gpu_size_mb=$(( total_mb * 95 / 100 ))
        else
            gpu_size_mb=512
        fi
    fi
    echo "starting gpu ${id} size=${gpu_size_mb}MB"
    CUDA_VISIBLE_DEVICES="${id}" \
        "${WORKER}" --device 0 --seconds "${BURN_SECONDS}" --size-mb "${gpu_size_mb}" >"${log}" 2>&1 &
    pid=$!
    LIVE_PIDS="${LIVE_PIDS} ${pid}"
    # One "pid:gpu-id:log-path" record per worker for the collection phase.
    WORKERS="${WORKERS} ${pid}:${id}:${log}"
done

# Collection phase: reap the workers in launch order, report each result and
# replay the worker's log with a "[gpu N] " prefix on every line.
status=0
for spec in ${WORKERS}; do
    pid=${spec%%:*}
    rest=${spec#*:}
    id=${rest%%:*}
    log=${rest#*:}
    if wait "${pid}"; then
        echo "gpu ${id} finished: OK"
    else
        rc=$?
        echo "gpu ${id} finished: FAILED rc=${rc}"
        status=1
    fi
    # The worker has been reaped; forget its PID so that cleanup can never
    # signal an unrelated process that happens to reuse the number.
    LIVE_PIDS=${LIVE_PIDS#" ${pid}"}
    sed "s/^/[gpu ${id}] /" "${log}" || true
done

exit "${status}"
