#include <stdio.h>
#include <string.h>
#include <wolfssl/options.h>
#include <wolfssl/wolfcrypt/settings.h>
#include <wolfssl/openssl/evp.h>

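/* CVE-2026-5479 Variant Test Program
 * Drives several distinct EVP API paths that must reject a forged
 * ChaCha20-Poly1305 message (wrong tag, modified AAD, modified ciphertext).
 * Exit status 0 means at least one forgery was accepted (vulnerable).
 */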

/* Buffer lengths in bytes:
 *   KEY_SIZE       ChaCha20-Poly1305 key passed to the EVP init calls
 *   IV_SIZE        IV; also the value given to EVP_CTRL_AEAD_SET_IVLEN
 *   TAG_SIZE       authentication tag read/set via EVP_CTRL_AEAD_GET_TAG/SET_TAG
 *   AAD_SIZE       length of "hello world!!"; the aad array is filled exactly,
 *                  so it deliberately holds no terminating NUL
 *   PLAINTEXT_SIZE plaintext, ciphertext and decrypted buffers
 */
#define KEY_SIZE 32
#define IV_SIZE 12
#define TAG_SIZE 16
#define AAD_SIZE 13
#define PLAINTEXT_SIZE 32

/*
 * Emit `label`, then `len` bytes of `data` as two-digit lowercase hex,
 * then a newline, all on stdout. A non-positive `len` prints no bytes.
 */
static void print_hex(const char *label, const unsigned char *data, int len)
{
    int remaining;

    fputs(label, stdout);
    for (remaining = len; remaining > 0; remaining--) {
        printf("%02x", *data++);
    }
    putchar('\n');
}

/*
 * Process exit statuses. RC_VULNERABLE stays 0 so harnesses that treat a
 * successful exit as "forgery accepted" keep working; setup failures get
 * their own code so a broken run is never mistaken for a patched library.
 */
enum {
    RC_VULNERABLE = 0,
    RC_NOT_VULNERABLE = 1,
    RC_TEST_ERROR = 2
};

/*
 * Encrypt `plaintext` (PLAINTEXT_SIZE bytes) under key/iv with `aad`
 * (AAD_SIZE bytes), producing the reference ciphertext and TAG_SIZE-byte
 * tag that the decryption variants tamper with.
 * Returns 1 on success, or 0 after printing a diagnostic; the context is
 * freed on every path.
 */
static int encrypt_reference(const unsigned char *key, const unsigned char *iv,
                             const unsigned char *aad,
                             const unsigned char *plaintext,
                             unsigned char *ciphertext, unsigned char *tag)
{
    int ok = 0;
    int outLen = 0;
    int finalLen = 0;
    EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();

    if (ctx == NULL) {
        fprintf(stderr, "EVP_CIPHER_CTX_new failed\n");
        return 0;
    }
    if (EVP_EncryptInit_ex(ctx, EVP_chacha20_poly1305(), NULL, NULL, NULL) != 1) {
        fprintf(stderr, "EVP_EncryptInit_ex failed\n");
        goto cleanup;
    }
    if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_SET_IVLEN, IV_SIZE, NULL) != 1) {
        fprintf(stderr, "EVP_CTRL_AEAD_SET_IVLEN failed\n");
        goto cleanup;
    }
    if (EVP_EncryptInit_ex(ctx, NULL, NULL, key, iv) != 1) {
        fprintf(stderr, "EVP_EncryptInit_ex(key,iv) failed\n");
        goto cleanup;
    }
    /* NULL output buffer: the input is consumed as AAD, not as data. */
    if (EVP_EncryptUpdate(ctx, NULL, &outLen, aad, AAD_SIZE) != 1) {
        fprintf(stderr, "EVP_EncryptUpdate AAD failed\n");
        goto cleanup;
    }
    if (EVP_EncryptUpdate(ctx, ciphertext, &outLen, plaintext, PLAINTEXT_SIZE) != 1) {
        fprintf(stderr, "EVP_EncryptUpdate failed\n");
        goto cleanup;
    }
    /* Any Final output follows the Update output, per EVP convention. */
    if (EVP_EncryptFinal_ex(ctx, ciphertext + outLen, &finalLen) != 1) {
        fprintf(stderr, "EVP_EncryptFinal_ex failed\n");
        goto cleanup;
    }
    if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_GET_TAG, TAG_SIZE, tag) != 1) {
        fprintf(stderr, "EVP_CTRL_AEAD_GET_TAG failed\n");
        goto cleanup;
    }
    ok = 1;

cleanup:
    EVP_CIPHER_CTX_free(ctx);
    return ok;
}

/*
 * Variant 1 path: EVP_CipherInit(enc=0) with key and IV up front, then
 * IV length and expected tag via ctrl, EVP_CipherUpdate for AAD and data,
 * and EVP_CipherFinal.
 * On success stores EVP_CipherFinal's return value in *finalRet and
 * returns 1; returns 0 (after a diagnostic) if any setup step fails.
 */
static int run_cipherinit_variant(const unsigned char *key, const unsigned char *iv,
                                  const unsigned char *aad,
                                  const unsigned char *ciphertext,
                                  const unsigned char *expectedTag,
                                  unsigned char *decrypted, int *finalRet)
{
    int ok = 0;
    int outLen = 0;
    int finalLen = 0;
    unsigned char tagCopy[TAG_SIZE];
    EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();

    /* EVP_CIPHER_CTX_ctrl takes a non-const pointer; give it a private copy. */
    memcpy(tagCopy, expectedTag, TAG_SIZE);

    if (ctx == NULL) {
        fprintf(stderr, "Variant 1: EVP_CIPHER_CTX_new failed\n");
        return 0;
    }
    if (EVP_CipherInit(ctx, EVP_chacha20_poly1305(), key, iv, 0) != 1) {
        fprintf(stderr, "EVP_CipherInit failed\n");
        goto cleanup;
    }
    if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_SET_IVLEN, IV_SIZE, NULL) != 1) {
        fprintf(stderr, "Variant 1: SET_IVLEN failed\n");
        goto cleanup;
    }
    /* Without a successfully set tag the verdict below would be meaningless. */
    if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_SET_TAG, TAG_SIZE, tagCopy) != 1) {
        fprintf(stderr, "Variant 1: SET_TAG failed\n");
        goto cleanup;
    }
    if (EVP_CipherUpdate(ctx, NULL, &outLen, aad, AAD_SIZE) != 1) {
        fprintf(stderr, "Variant 1: CipherUpdate AAD failed\n");
        goto cleanup;
    }
    if (EVP_CipherUpdate(ctx, decrypted, &outLen, ciphertext, PLAINTEXT_SIZE) != 1) {
        fprintf(stderr, "Variant 1: CipherUpdate data failed\n");
        goto cleanup;
    }
    /* The verdict itself, not a setup step: 1 means the tag was accepted. */
    *finalRet = EVP_CipherFinal(ctx, decrypted + outLen, &finalLen);
    ok = 1;

cleanup:
    EVP_CIPHER_CTX_free(ctx);
    return ok;
}

/*
 * Two-step EVP_DecryptInit_ex path shared by variants 2 and 3: cipher
 * first, then IV length and expected tag via ctrl, then key/IV, then
 * EVP_DecryptUpdate for AAD and data, and EVP_DecryptFinal_ex.
 * `label` (e.g. "Variant 2") prefixes every diagnostic.
 * On success stores EVP_DecryptFinal_ex's return value in *finalRet and
 * returns 1; returns 0 (after a diagnostic) if any setup step fails.
 */
static int run_decryptinit_ex_variant(const char *label,
                                      const unsigned char *key,
                                      const unsigned char *iv,
                                      const unsigned char *aad,
                                      const unsigned char *ciphertext,
                                      const unsigned char *expectedTag,
                                      unsigned char *decrypted, int *finalRet)
{
    int ok = 0;
    int outLen = 0;
    int finalLen = 0;
    unsigned char tagCopy[TAG_SIZE];
    EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();

    /* EVP_CIPHER_CTX_ctrl takes a non-const pointer; give it a private copy. */
    memcpy(tagCopy, expectedTag, TAG_SIZE);

    if (ctx == NULL) {
        fprintf(stderr, "%s: EVP_CIPHER_CTX_new failed\n", label);
        return 0;
    }
    if (EVP_DecryptInit_ex(ctx, EVP_chacha20_poly1305(), NULL, NULL, NULL) != 1) {
        fprintf(stderr, "%s: DecryptInit failed\n", label);
        goto cleanup;
    }
    if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_SET_IVLEN, IV_SIZE, NULL) != 1) {
        fprintf(stderr, "%s: SET_IVLEN failed\n", label);
        goto cleanup;
    }
    /* Without a successfully set tag the verdict below would be meaningless. */
    if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_SET_TAG, TAG_SIZE, tagCopy) != 1) {
        fprintf(stderr, "%s: SET_TAG failed\n", label);
        goto cleanup;
    }
    if (EVP_DecryptInit_ex(ctx, NULL, NULL, key, iv) != 1) {
        fprintf(stderr, "%s: DecryptInit(key,iv) failed\n", label);
        goto cleanup;
    }
    if (EVP_DecryptUpdate(ctx, NULL, &outLen, aad, AAD_SIZE) != 1) {
        fprintf(stderr, "%s: DecryptUpdate AAD failed\n", label);
        goto cleanup;
    }
    if (EVP_DecryptUpdate(ctx, decrypted, &outLen, ciphertext, PLAINTEXT_SIZE) != 1) {
        fprintf(stderr, "%s: DecryptUpdate data failed\n", label);
        goto cleanup;
    }
    /* The verdict itself, not a setup step: 1 means the message was accepted. */
    *finalRet = EVP_DecryptFinal_ex(ctx, decrypted + outLen, &finalLen);
    ok = 1;

cleanup:
    EVP_CIPHER_CTX_free(ctx);
    return ok;
}

/*
 * Print one variant's result block and return 1 if the final call
 * returned 1, i.e. the library accepted the forged input.
 */
static int report_variant(const char *title, const char *finalName, int finalRet,
                          const unsigned char *decrypted,
                          const unsigned char *plaintext)
{
    int accepted = (finalRet == 1);

    printf("--- %s ---\n", title);
    printf("%s ret=%d\n", finalName, finalRet);
    printf("Plaintext match: %s\n",
           (memcmp(decrypted, plaintext, PLAINTEXT_SIZE) == 0) ? "YES" : "NO");
    printf("VULNERABLE: %s\n\n", accepted ? "YES" : "NO");
    return accepted;
}

/*
 * Encrypt a fixed message, then try three forged decryptions through
 * different EVP call sequences. Each must be rejected by the final call.
 * Exit status: RC_VULNERABLE (0) if any forgery was accepted,
 * RC_NOT_VULNERABLE (1) if all were rejected, RC_TEST_ERROR (2) if setup
 * failed and no verdict could be reached.
 */
int main(void)
{
    unsigned char key[KEY_SIZE] = {
        0x00,0x01,0x02,0x03,0x04,0x05,0x06,0x07,
        0x08,0x09,0x0a,0x0b,0x0c,0x0d,0x0e,0x0f,
        0x10,0x11,0x12,0x13,0x14,0x15,0x16,0x17,
        0x18,0x19,0x1a,0x1b,0x1c,0x1d,0x1e,0x1f
    };
    unsigned char iv[IV_SIZE] = {
        0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
        0x00,0x00,0x00,0x01
    };
    /* The literal is exactly AAD_SIZE characters, so no NUL is stored. */
    unsigned char aad[AAD_SIZE] = "hello world!!";
    unsigned char plaintext[PLAINTEXT_SIZE] = "The quick brown fox jumps over";
    unsigned char ciphertext[PLAINTEXT_SIZE];
    unsigned char decrypted[PLAINTEXT_SIZE];
    unsigned char tag[TAG_SIZE];
    unsigned char badTag[TAG_SIZE];
    unsigned char modifiedAad[AAD_SIZE];
    unsigned char modifiedCipher[PLAINTEXT_SIZE];
    int finalRet = 0;
    int vuln_detected = 0;

    printf("=== CVE-2026-5479 Variant Test ===\n");
    printf("Testing multiple distinct ChaCha20-Poly1305 EVP paths\n\n");

    if (!encrypt_reference(key, iv, aad, plaintext, ciphertext, tag)) {
        return RC_TEST_ERROR;
    }

    printf("Encryption succeeded.\n");
    print_hex("Ciphertext: ", ciphertext, PLAINTEXT_SIZE);
    print_hex("Tag:        ", tag, TAG_SIZE);
    printf("\n");

    /* Variant 1: EVP_CipherInit(enc=0) path, untouched data, wrong tag. */
    memset(badTag, 0xAB, TAG_SIZE);
    memset(decrypted, 0, PLAINTEXT_SIZE);
    if (!run_cipherinit_variant(key, iv, aad, ciphertext, badTag,
                                decrypted, &finalRet)) {
        return RC_TEST_ERROR;
    }
    vuln_detected |= report_variant("Variant 1: EVP_CipherInit path with bad tag",
                                    "EVP_CipherFinal", finalRet,
                                    decrypted, plaintext);

    /* Variant 2: EVP_DecryptInit_ex path, correct tag, first AAD byte flipped. */
    memcpy(modifiedAad, aad, AAD_SIZE);
    modifiedAad[0] ^= 0xFF;
    memset(decrypted, 0, PLAINTEXT_SIZE);
    if (!run_decryptinit_ex_variant("Variant 2", key, iv, modifiedAad, ciphertext,
                                    tag, decrypted, &finalRet)) {
        return RC_TEST_ERROR;
    }
    vuln_detected |= report_variant("Variant 2: Modified AAD with correct tag",
                                    "EVP_DecryptFinal_ex", finalRet,
                                    decrypted, plaintext);

    /* Variant 3: EVP_DecryptInit_ex path, correct tag, one ciphertext byte flipped. */
    memcpy(modifiedCipher, ciphertext, PLAINTEXT_SIZE);
    modifiedCipher[5] ^= 0xFF;
    memset(decrypted, 0, PLAINTEXT_SIZE);
    if (!run_decryptinit_ex_variant("Variant 3", key, iv, aad, modifiedCipher,
                                    tag, decrypted, &finalRet)) {
        return RC_TEST_ERROR;
    }
    vuln_detected |= report_variant("Variant 3: Flipped ciphertext byte + correct tag",
                                    "EVP_DecryptFinal_ex", finalRet,
                                    decrypted, plaintext);

    return vuln_detected ? RC_VULNERABLE : RC_NOT_VULNERABLE;
}
