package batcher

import (
	"context"
	"fmt"
	"testing"

	batchpb "go.temporal.io/api/batch/v1"
	commonpb "go.temporal.io/api/common/v1"
	enumspb "go.temporal.io/api/enums/v1"
	historypb "go.temporal.io/api/history/v1"
	"go.temporal.io/api/workflowservice/v1"
	"go.temporal.io/sdk/testsuite"
	batchspb "go.temporal.io/server/api/batch/v1"
	"go.temporal.io/server/common/dynamicconfig"
	"go.temporal.io/server/common/log"
	"go.temporal.io/server/common/metrics"
	"go.temporal.io/server/common/namespace"
	serverSdk "go.temporal.io/server/common/sdk"
	"go.temporal.io/server/common/testing/mockapi/workflowservicemock/v1"
	"go.temporal.io/server/common/testing/mocksdk"
	"go.uber.org/mock/gomock"
	"golang.org/x/time/rate"
)

// Namespace fixtures shared by the variant tests in this file.
const (
	// ID and name of the namespace the activities under test are bound to
	// (see makeVariantActivities).
	variantBoundNSID   = "bound-ns-id"
	variantBoundNSName = "bound-ns"

	// A namespace different from the bound one; requests naming it probe
	// whether cross-namespace targeting is rejected.
	variantOtherNSName = "other-ns"
)

// makeVariantActivities returns an *activities bound to variantBoundNSName /
// variantBoundNSID, with an rps of 50 and a concurrency of 1.
//
// Its ClientFactory always hands out the same mock SDK client. That client
// answers CountWorkflow with Count 1 and ListWorkflow with an empty response,
// and returns nil from CancelWorkflow and TerminateWorkflow, any number of times.
//
// If frontend is nil, a permissive mock frontend is installed instead. Its
// SignalWorkflowExecution, DeleteWorkflowExecution and ResetWorkflowExecution
// calls return empty successful responses. Its GetWorkflowExecutionHistory
// returns a history holding a single EVENT_TYPE_WORKFLOW_EXECUTION_STARTED
// event with ID 1. A non-nil frontend is used unchanged, so callers can
// install their own request-capturing mocks.
func makeVariantActivities(t *testing.T, frontend workflowservice.WorkflowServiceClient) *activities {
	ctrl := gomock.NewController(t)

	sdkClient := mocksdk.NewMockClient(ctrl)
	sdkClient.EXPECT().CountWorkflow(gomock.Any(), gomock.Any()).Return(&workflowservice.CountWorkflowExecutionsResponse{Count: 1}, nil).AnyTimes()
	sdkClient.EXPECT().ListWorkflow(gomock.Any(), gomock.Any()).Return(&workflowservice.ListWorkflowExecutionsResponse{}, nil).AnyTimes()
	sdkClient.EXPECT().CancelWorkflow(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
	sdkClient.EXPECT().TerminateWorkflow(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()

	clientFactory := serverSdk.NewMockClientFactory(ctrl)
	clientFactory.EXPECT().NewClient(gomock.Any()).Return(sdkClient).AnyTimes()

	if frontend == nil {
		startedEvent := &historypb.HistoryEvent{EventId: 1, EventType: enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED}
		historyResp := &workflowservice.GetWorkflowExecutionHistoryResponse{
			History: &historypb.History{Events: []*historypb.HistoryEvent{startedEvent}},
		}

		permissive := workflowservicemock.NewMockWorkflowServiceClient(ctrl)
		permissive.EXPECT().SignalWorkflowExecution(gomock.Any(), gomock.Any()).Return(&workflowservice.SignalWorkflowExecutionResponse{}, nil).AnyTimes()
		permissive.EXPECT().DeleteWorkflowExecution(gomock.Any(), gomock.Any()).Return(&workflowservice.DeleteWorkflowExecutionResponse{}, nil).AnyTimes()
		permissive.EXPECT().ResetWorkflowExecution(gomock.Any(), gomock.Any()).Return(&workflowservice.ResetWorkflowExecutionResponse{}, nil).AnyTimes()
		permissive.EXPECT().GetWorkflowExecutionHistory(gomock.Any(), gomock.Any()).Return(historyResp, nil).AnyTimes()
		frontend = permissive
	}

	deps := activityDeps{
		MetricsHandler: metrics.NoopMetricsHandler,
		Logger:         log.NewTestLogger(),
		FrontendClient: frontend,
		ClientFactory:  clientFactory,
	}
	return &activities{
		activityDeps: deps,
		namespace:    namespace.Name(variantBoundNSName),
		namespaceID:  namespace.ID(variantBoundNSID),
		rps:          dynamicconfig.GetIntPropertyFnFilteredByNamespace(50),
		concurrency:  dynamicconfig.GetIntPropertyFnFilteredByNamespace(1),
	}
}

// TestVariant_NilRequest tests whether checkNamespaceProtobuf can be bypassed
// by providing a BatchOperationInput whose Request is nil. The concern is that
// a nil Request might skip the namespace-name check, while downstream code
// that reads Request fields might panic instead of rejecting the input.
//
// Exactly one outcome is reported on stdout. If the activity returns no error,
// the test prints VARIANT1_NIL_REQUEST_BYPASSED followed by TEST_CONTINUING.
// Otherwise it prints VARIANT1_NIL_REQUEST_BLOCKED with the error.
// This is variant attempt #1.
func TestVariant_NilRequest(t *testing.T) {
	ts := testsuite.WorkflowTestSuite{}
	env := ts.NewTestActivityEnvironment()
	a := makeVariantActivities(t, nil)
	env.RegisterActivity(a.BatchActivityWithProtobuf)

	input := &batchspb.BatchOperationInput{
		NamespaceId: variantBoundNSID,
		Request:     nil,
	}

	_, err := env.ExecuteActivity(a.BatchActivityWithProtobuf, input)
	// Report a single outcome. Falling through would also print a BLOCKED
	// marker carrying a nil error after a bypass, which contradicts it.
	if err == nil {
		fmt.Println("VARIANT1_NIL_REQUEST_BYPASSED")
		fmt.Println("TEST_CONTINUING")
		return
	}
	fmt.Printf("VARIANT1_NIL_REQUEST_BLOCKED: %v\n", err)
}

// TestVariant_NonProtobufBatchActivity tests whether the legacy BatchActivity
// path can be exploited by a signal request that names variantOtherNSName.
// The activities are bound to variantBoundNSName.
//
// Exactly one outcome is reported on stdout. If the activity returns no error,
// the test prints VARIANT2_NON_PROTOBUF_BYPASSED followed by TEST_CONTINUING.
// Otherwise it prints VARIANT2_NON_PROTOBUF_BLOCKED with the error.
// This is variant attempt #2.
func TestVariant_NonProtobufBatchActivity(t *testing.T) {
	ts := testsuite.WorkflowTestSuite{}
	env := ts.NewTestActivityEnvironment()
	a := makeVariantActivities(t, nil)
	env.RegisterActivity(a.BatchActivity)

	params := BatchParams{
		Namespace: variantOtherNSName,
		BatchType: string(BatchTypeSignal),
		SignalParams: SignalParams{
			SignalName: "test-signal",
		},
		Executions: []*commonpb.WorkflowExecution{{WorkflowId: "w"}},
	}

	_, err := env.ExecuteActivity(a.BatchActivity, params)
	// Report a single outcome. Falling through would also print a BLOCKED
	// marker carrying a nil error after a bypass, which contradicts it.
	if err == nil {
		fmt.Println("VARIANT2_NON_PROTOBUF_BYPASSED")
		fmt.Println("TEST_CONTINUING")
		return
	}
	fmt.Printf("VARIANT2_NON_PROTOBUF_BLOCKED: %v\n", err)
}

// TestVariant_CancelOperationType tests whether the namespace fix is specific
// to one operation type. It sends a CANCEL request (instead of SIGNAL) that
// names variantOtherNSName while the activities are bound to variantBoundNSName.
//
// Exactly one outcome is reported on stdout. If the activity returns no error,
// the test prints VARIANT3_CANCEL_BYPASSED followed by TEST_CONTINUING.
// Otherwise it prints VARIANT3_CANCEL_BLOCKED with the error.
// This is variant attempt #3.
func TestVariant_CancelOperationType(t *testing.T) {
	ts := testsuite.WorkflowTestSuite{}
	env := ts.NewTestActivityEnvironment()
	a := makeVariantActivities(t, nil)
	env.RegisterActivity(a.BatchActivityWithProtobuf)

	input := &batchspb.BatchOperationInput{
		NamespaceId: variantBoundNSID,
		BatchType:   enumspb.BATCH_OPERATION_TYPE_CANCEL,
		Request: &workflowservice.StartBatchOperationRequest{
			Namespace: variantOtherNSName,
			Operation: &workflowservice.StartBatchOperationRequest_CancellationOperation{
				CancellationOperation: &batchpb.BatchOperationCancellation{},
			},
			Executions: []*commonpb.WorkflowExecution{{WorkflowId: "w"}},
		},
	}

	_, err := env.ExecuteActivity(a.BatchActivityWithProtobuf, input)
	// Report a single outcome. Falling through would also print a BLOCKED
	// marker carrying a nil error after a bypass, which contradicts it.
	if err == nil {
		fmt.Println("VARIANT3_CANCEL_BYPASSED")
		fmt.Println("TEST_CONTINUING")
		return
	}
	fmt.Printf("VARIANT3_CANCEL_BLOCKED: %v\n", err)
}

// TestVariant_CaseInsensitiveNamespace tests whether a namespace name that
// differs only in letter case ("BOUND-NS" vs variantBoundNSName) can get past
// the string comparison in checkNamespaceProtobuf.
//
// Exactly one outcome is reported on stdout. If the activity returns no error,
// the test prints VARIANT4_CASE_BYPASSED followed by TEST_CONTINUING.
// Otherwise it prints VARIANT4_CASE_BLOCKED with the error.
// This is variant attempt #4.
func TestVariant_CaseInsensitiveNamespace(t *testing.T) {
	ts := testsuite.WorkflowTestSuite{}
	env := ts.NewTestActivityEnvironment()
	a := makeVariantActivities(t, nil)
	env.RegisterActivity(a.BatchActivityWithProtobuf)

	input := &batchspb.BatchOperationInput{
		NamespaceId: variantBoundNSID,
		BatchType:   enumspb.BATCH_OPERATION_TYPE_SIGNAL,
		Request: &workflowservice.StartBatchOperationRequest{
			Namespace: "BOUND-NS",
			Operation: &workflowservice.StartBatchOperationRequest_SignalOperation{
				SignalOperation: &batchpb.BatchOperationSignal{Signal: "s"},
			},
			Executions: []*commonpb.WorkflowExecution{{WorkflowId: "w"}},
		},
	}

	_, err := env.ExecuteActivity(a.BatchActivityWithProtobuf, input)
	// Report a single outcome. Falling through would also print a BLOCKED
	// marker carrying a nil error after a bypass, which contradicts it.
	if err == nil {
		fmt.Println("VARIANT4_CASE_BYPASSED")
		fmt.Println("TEST_CONTINUING")
		return
	}
	fmt.Printf("VARIANT4_CASE_BLOCKED: %v\n", err)
}

// TestVariant_ResetOperationType tests whether the Reset path in particular
// can be bypassed, since the fix describes its Reset change as
// belt-and-suspenders. It sends a RESET (first workflow task) request that
// names variantOtherNSName.
//
// capturedNs records the Namespace field of the last
// GetWorkflowExecutionHistory request seen by the mock frontend. It stays
// empty if that call never happens.
//
// Exactly one outcome is reported on stdout. If the activity returns no error,
// the test prints VARIANT5_RESET_BYPASSED (with capturedNs) followed by
// TEST_CONTINUING. Otherwise it prints VARIANT5_RESET_BLOCKED with the error
// and capturedNs.
// This is variant attempt #5.
func TestVariant_ResetOperationType(t *testing.T) {
	ctrl := gomock.NewController(t)
	mockFE := workflowservicemock.NewMockWorkflowServiceClient(ctrl)

	var capturedNs string
	mockFE.EXPECT().
		GetWorkflowExecutionHistory(gomock.Any(), gomock.Any()).
		DoAndReturn(func(_ context.Context, req *workflowservice.GetWorkflowExecutionHistoryRequest, _ ...any) (*workflowservice.GetWorkflowExecutionHistoryResponse, error) {
			capturedNs = req.Namespace
			return &workflowservice.GetWorkflowExecutionHistoryResponse{
				History: &historypb.History{
					Events: []*historypb.HistoryEvent{
						{EventId: 1, EventType: enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED},
					},
				},
			}, nil
		}).AnyTimes()

	a := makeVariantActivities(t, mockFE)

	input := &batchspb.BatchOperationInput{
		NamespaceId: variantBoundNSID,
		BatchType:   enumspb.BATCH_OPERATION_TYPE_RESET,
		Request: &workflowservice.StartBatchOperationRequest{
			Namespace: variantOtherNSName,
			Operation: &workflowservice.StartBatchOperationRequest_ResetOperation{
				ResetOperation: &batchpb.BatchOperationReset{
					ResetType: enumspb.RESET_TYPE_FIRST_WORKFLOW_TASK,
				},
			},
			Executions: []*commonpb.WorkflowExecution{{WorkflowId: "w", RunId: "r"}},
		},
	}

	ts := testsuite.WorkflowTestSuite{}
	env := ts.NewTestActivityEnvironment()
	env.RegisterActivity(a.BatchActivityWithProtobuf)

	_, err := env.ExecuteActivity(a.BatchActivityWithProtobuf, input)
	// Report a single outcome. Falling through would also print a BLOCKED
	// marker carrying a nil error after a bypass, which contradicts it.
	if err == nil {
		fmt.Printf("VARIANT5_RESET_BYPASSED with capturedNs=%s\n", capturedNs)
		fmt.Println("TEST_CONTINUING")
		return
	}
	fmt.Printf("VARIANT5_RESET_BLOCKED: %v capturedNs=%s\n", err, capturedNs)
}

// TestVariant_TaskProcessorDirectNamespace calls startTaskProcessorProtobuf
// directly with a forged namespace argument (variantOtherNSName), while the
// request inside batchOp names variantBoundNSName.
//
// This is NOT a real bypass: startTaskProcessorProtobuf is only reachable
// through BatchActivityWithProtobuf, which performs checkNamespaceProtobuf.
// The test only records which namespace the helper puts on the outgoing
// SignalWorkflowExecution call (a defense-in-depth observation).
// This is variant attempt #6.
func TestVariant_TaskProcessorDirectNamespace(t *testing.T) {
	ctrl := gomock.NewController(t)
	mockFE := workflowservicemock.NewMockWorkflowServiceClient(ctrl)

	// capturedNs is set by the mock and is read only after the processor
	// goroutine has finished (<-done below).
	var capturedNs string
	mockFE.EXPECT().
		SignalWorkflowExecution(gomock.Any(), gomock.Any()).
		DoAndReturn(func(_ context.Context, req *workflowservice.SignalWorkflowExecutionRequest, _ ...any) (*workflowservice.SignalWorkflowExecutionResponse, error) {
			capturedNs = req.Namespace
			return &workflowservice.SignalWorkflowExecutionResponse{}, nil
		})

	batchOp := &batchspb.BatchOperationInput{
		NamespaceId: variantBoundNSID,
		Request: &workflowservice.StartBatchOperationRequest{
			Namespace: variantBoundNSName,
			Operation: &workflowservice.StartBatchOperationRequest_SignalOperation{
				SignalOperation: &batchpb.BatchOperationSignal{Signal: "s"},
			},
		},
	}

	taskCh := make(chan task, 1)
	respCh := make(chan taskResponse, 1)
	taskCh <- task{
		execution: &commonpb.WorkflowExecution{WorkflowId: "w"},
		page:      &page{},
	}

	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan struct{})
	go func() {
		defer close(done)
		startTaskProcessorProtobuf(ctx, batchOp, variantOtherNSName, taskCh, respCh,
			rate.NewLimiter(rate.Inf, 1), nil, mockFE,
			metrics.NoopMetricsHandler, log.NewTestLogger())
	}()

	// Wait for the single task response, but stop waiting if the processor
	// goroutine exits first. For example, a gomock unexpected-call failure
	// calls t.Fatalf, which ends that goroutine via runtime.Goexit before it
	// sends a response. A bare <-respCh would then block until the global
	// `go test` timeout.
	select {
	case <-respCh:
	case <-done:
	}
	cancel()
	<-done

	// This logs what happens when validation is bypassed, but does NOT claim a bypass
	// because an attacker cannot call startTaskProcessorProtobuf directly.
	fmt.Printf("VARIANT6_TASKPROCESSOR_DIRECT_NS=%s (requires caller bypass)\n", capturedNs)
}

// TestVariant_EmptyNamespace tests whether a request with an empty Namespace
// string can get past checkNamespaceProtobuf.
//
// Exactly one outcome is reported on stdout. If the activity returns no error,
// the test prints VARIANT7_EMPTY_NS_BYPASSED followed by TEST_CONTINUING.
// Otherwise it prints VARIANT7_EMPTY_NS_BLOCKED with the error.
// This is variant attempt #7.
func TestVariant_EmptyNamespace(t *testing.T) {
	ts := testsuite.WorkflowTestSuite{}
	env := ts.NewTestActivityEnvironment()
	a := makeVariantActivities(t, nil)
	env.RegisterActivity(a.BatchActivityWithProtobuf)

	input := &batchspb.BatchOperationInput{
		NamespaceId: variantBoundNSID,
		BatchType:   enumspb.BATCH_OPERATION_TYPE_SIGNAL,
		Request: &workflowservice.StartBatchOperationRequest{
			Namespace: "",
			Operation: &workflowservice.StartBatchOperationRequest_SignalOperation{
				SignalOperation: &batchpb.BatchOperationSignal{Signal: "s"},
			},
			Executions: []*commonpb.WorkflowExecution{{WorkflowId: "w"}},
		},
	}

	_, err := env.ExecuteActivity(a.BatchActivityWithProtobuf, input)
	// Report a single outcome. Falling through would also print a BLOCKED
	// marker carrying a nil error after a bypass, which contradicts it.
	if err == nil {
		fmt.Println("VARIANT7_EMPTY_NS_BYPASSED")
		fmt.Println("TEST_CONTINUING")
		return
	}
	fmt.Printf("VARIANT7_EMPTY_NS_BLOCKED: %v\n", err)
}

// TestVariant_OriginalCVEPath re-runs the original CVE reproduction against the
// current checked-out version. The activities are bound to variantBoundNSName,
// but the SIGNAL request names variantOtherNSName. The request should be
// blocked on a fixed version and allowed on a vulnerable one.
//
// capturedNs records the Namespace field of the last SignalWorkflowExecution
// request seen by the mock frontend. It stays empty if that call never happens.
//
// Exactly one outcome is reported on stdout. If the activity returns no error,
// the test prints ORIGINAL_CVE_BYPASSED (with capturedNs) followed by
// TEST_CONTINUING. Otherwise it prints ORIGINAL_CVE_BLOCKED with the error and
// capturedNs.
func TestVariant_OriginalCVEPath(t *testing.T) {
	ctrl := gomock.NewController(t)
	mockFE := workflowservicemock.NewMockWorkflowServiceClient(ctrl)

	var capturedNs string
	mockFE.EXPECT().
		SignalWorkflowExecution(gomock.Any(), gomock.Any()).
		DoAndReturn(func(_ context.Context, req *workflowservice.SignalWorkflowExecutionRequest, _ ...any) (*workflowservice.SignalWorkflowExecutionResponse, error) {
			capturedNs = req.Namespace
			return &workflowservice.SignalWorkflowExecutionResponse{}, nil
		}).AnyTimes()

	a := makeVariantActivities(t, mockFE)

	input := &batchspb.BatchOperationInput{
		NamespaceId: variantBoundNSID,
		BatchType:   enumspb.BATCH_OPERATION_TYPE_SIGNAL,
		Request: &workflowservice.StartBatchOperationRequest{
			Namespace: variantOtherNSName, // mismatched
			Operation: &workflowservice.StartBatchOperationRequest_SignalOperation{
				SignalOperation: &batchpb.BatchOperationSignal{Signal: "s"},
			},
			Executions: []*commonpb.WorkflowExecution{{WorkflowId: "w"}},
		},
	}

	ts := testsuite.WorkflowTestSuite{}
	env := ts.NewTestActivityEnvironment()
	env.RegisterActivity(a.BatchActivityWithProtobuf)

	_, err := env.ExecuteActivity(a.BatchActivityWithProtobuf, input)
	// Report a single outcome. Falling through would also print a BLOCKED
	// marker carrying a nil error after a bypass, which contradicts it.
	if err == nil {
		fmt.Printf("ORIGINAL_CVE_BYPASSED capturedNs=%s\n", capturedNs)
		fmt.Println("TEST_CONTINUING")
		return
	}
	fmt.Printf("ORIGINAL_CVE_BLOCKED: %v capturedNs=%s\n", err, capturedNs)
}
