#!/bin/sh
set -eu

# ---- Defaults (SECONDS, DEVICES and EXCLUDE can be overridden by flags) ----

# Total wall-clock budget, in seconds, for the benchmark loop.
# SECONDS is a special auto-incrementing variable in bash and ksh: once
# assigned, reading it yields "assigned value + seconds elapsed since the
# assignment". On systems where /bin/sh is bash that would silently stretch
# the deadline computed later. Unsetting it first removes the special
# behaviour for the rest of the script; in other shells this is a no-op.
unset SECONDS
SECONDS=300

# Comma-separated GPU ids to include (empty = every GPU that is discovered)
# and to exclude (empty = none).
DEVICES=""
EXCLUDE=""

# Arguments handed to all_reduce_perf: smallest/largest message size, the
# step factor between sizes, and the iteration count.
MIN_BYTES="512M"
MAX_BYTES="4G"
FACTOR="2"
ITERS="20"

# Location of the benchmark binary.
ALL_REDUCE_BIN="/usr/local/bin/all_reduce_perf"

# usage
# Print the command-line synopsis to stderr and terminate the script with
# exit status 2. Never returns.
usage() {
    # The synopsis lists the -t short form as well, since the argument parser
    # accepts it as an alias for --seconds.
    printf 'usage: %s [--seconds|-t N] [--devices 0,1] [--exclude 2,3]\n' "$0" >&2
    exit 2
}

# normalize_list [CSV]
# Canonicalise a comma-separated list: one item per line, all whitespace
# stripped, empty items dropped, sorted numerically, duplicates removed, then
# joined back with commas. Writes the result to stdout; an empty or missing
# argument produces an empty result.
normalize_list() {
    echo "${1:-}" |
        tr ',' '\n' |
        sed -e 's/[[:space:]]//g' -e '/^$/d' |
        sort -n |
        uniq |
        paste -sd, -
}

# contains_csv NEEDLE [CSV]
# Succeed (status 0) when NEEDLE is exactly one of the comma-separated items
# in CSV, fail (status 1) otherwise. A missing or empty CSV matches nothing
# except an empty NEEDLE.
#
# The comparison is a literal string match done with a quoted case pattern,
# so characters such as '.' or '*' in NEEDLE have no special meaning, no
# helper variables are left behind in the global scope, and no subprocess is
# spawned.
contains_csv() {
    case ",${2:-}," in
        *",${1},"*) return 0 ;;
        *) return 1 ;;
    esac
}

# parse_args ARG...
# Consume the command-line flags and store their values in the globals
# SECONDS, DEVICES and EXCLUDE.
#   --seconds N | -t N   run budget; must be a non-negative decimal integer
#   --devices LIST       comma-separated GPU ids to include
#   --exclude LIST       comma-separated GPU ids to leave out
# A flag without a value, an unknown argument, or an invalid --seconds value
# ends the script through usage.
parse_args() {
    while [ "$#" -gt 0 ]; do
        case "$1" in
            --seconds|-t)
                [ "$#" -ge 2 ] || usage
                # The value is later used in shell arithmetic. Reject anything
                # that is not plain decimal digits; a leading zero is refused
                # too because arithmetic expansion would read it as octal.
                case "$2" in
                    ''|*[!0-9]*|0[0-9]*)
                        echo "invalid value for $1: '$2' (expected a non-negative decimal integer)" >&2
                        usage
                        ;;
                esac
                SECONDS="$2"
                shift 2
                ;;
            --devices)
                [ "$#" -ge 2 ] || usage
                DEVICES="$2"
                shift 2
                ;;
            --exclude)
                [ "$#" -ge 2 ] || usage
                EXCLUDE="$2"
                shift 2
                ;;
            *)
                usage
                ;;
        esac
    done
}

parse_args "$@"

# Fail fast when the benchmark binary is missing or not executable.
if [ ! -x "${ALL_REDUCE_BIN}" ]; then
    echo "all_reduce_perf not found: ${ALL_REDUCE_BIN}" >&2
    exit 1
fi

# The query below discards nvidia-smi's stderr and the pipeline's status is
# that of its last stage, so a missing nvidia-smi would otherwise be reported
# as "found no NVIDIA GPUs". Check for the tool explicitly first.
if ! command -v nvidia-smi >/dev/null 2>&1; then
    echo "nvidia-smi not found in PATH" >&2
    exit 1
fi

# ALL_DEVICES: comma-separated GPU indices as printed by nvidia-smi, with
# whitespace and blank lines removed.
ALL_DEVICES=$(nvidia-smi --query-gpu=index --format=csv,noheader,nounits 2>/dev/null | sed 's/[[:space:]]//g' | awk 'NF' | paste -sd, -)
if [ -z "${ALL_DEVICES}" ]; then
    echo "nvidia-smi found no NVIDIA GPUs" >&2
    exit 1
fi

# Canonicalise both user-supplied lists, then pick the candidate set:
# the requested devices when any were given, otherwise every discovered GPU.
DEVICES=$(normalize_list "${DEVICES}")
EXCLUDE=$(normalize_list "${EXCLUDE}")
SELECTED="${DEVICES:-${ALL_DEVICES}}"

# Build FINAL: the ids from SELECTED, in their existing order, minus every id
# that appears in EXCLUDE.
FINAL=""
# The ids can come straight from the command line. The loop relies on word
# splitting of an unquoted expansion, so pathname expansion is switched off
# while it runs; otherwise an id such as '*' would be replaced by the file
# names in the current directory.
set -f
for id in $(echo "${SELECTED}" | tr ',' ' '); do
    if contains_csv "${id}" "${EXCLUDE}"; then
        continue
    fi
    # Append with a comma separator only when FINAL already holds an id.
    FINAL="${FINAL:+${FINAL},}${id}"
done
set +f

# Stop when the include/exclude filters left nothing to run on.
if [ -z "${FINAL}" ]; then
    echo "no NVIDIA GPUs selected after filters" >&2
    exit 1
fi

# GPU_COUNT: number of non-empty entries in FINAL.
GPU_COUNT=$(echo "${FINAL}" | tr ',' '\n' | awk 'NF { n++ } END { print n + 0 }')
if ! [ "${GPU_COUNT}" -gt 0 ]; then
    echo "selected GPU count is zero" >&2
    exit 1
fi

# Print the effective configuration as key=value lines on stdout.
printf '%s\n' \
    "loader=nccl" \
    "selected_gpus=${FINAL}" \
    "gpu_count=${GPU_COUNT}" \
    "range=${MIN_BYTES}..${MAX_BYTES}" \
    "iters=${ITERS}"

# Exported for the benchmark processes started below; presumably so that CUDA
# numbers the devices the same way as the nvidia-smi indices gathered earlier
# (NOTE(review): confirm against the CUDA environment-variable documentation).
CUDA_DEVICE_ORDER="PCI_BUS_ID"
export CUDA_DEVICE_ORDER

# Absolute end time (epoch seconds) and a counter of started rounds.
deadline=$(( $(date +%s) + SECONDS ))
round=0

# Start a new benchmark round for as long as the deadline has not passed.
# The check happens only between rounds, so a round that is already running
# is not interrupted when the deadline is reached.
now=$(date +%s)
while [ "${now}" -lt "${deadline}" ]; do
    round=$((round + 1))
    echo "round=${round} remaining_sec=$((deadline - now))"
    # CUDA_VISIBLE_DEVICES is set for this one command only.
    CUDA_VISIBLE_DEVICES="${FINAL}" "${ALL_REDUCE_BIN}" \
        -b "${MIN_BYTES}" -e "${MAX_BYTES}" \
        -f "${FACTOR}" \
        -g "${GPU_COUNT}" \
        --iters "${ITERS}"
    now=$(date +%s)
done
