/*
 * Copyright (c) 2019 Yubico AB. All rights reserved.
 * Use of this source code is governed by a BSD-style
 * license that can be found in the LICENSE file.
 */

#include <assert.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>

#include "mutator_aux.h"
#include "wiredata_fido2.h"
#include "wiredata_u2f.h"
#include "dummy.h"

#include "fido.h"
#include "fido/es256.h"
#include "fido/rs256.h"
#include "fido/eddsa.h"

#include "../openbsd-compat/openbsd-compat.h"

/*
 * Tags identifying each field of struct param in the serialised fuzzer
 * input.  pack() and unpack() hand them to the pack_*()/unpack_*()
 * helpers together with the field being encoded/decoded.
 * NOTE(review): changing a value presumably invalidates existing corpora.
 */
#define TAG_U2F		0x01
#define TAG_TYPE	0x02
#define TAG_CDH		0x03
#define TAG_RP_ID	0x04
#define TAG_EXT		0x05
#define TAG_SEED	0x06
#define TAG_UP		0x07
#define TAG_UV		0x08
#define TAG_WIRE_DATA	0x09
#define TAG_CRED_COUNT	0x0a
#define TAG_CRED	0x0b
#define TAG_ES256	0x0c
#define TAG_RS256	0x0d
#define TAG_PIN		0x0e
#define TAG_EDDSA	0x0f

/*
 * Parameter set defining a FIDO2 get assertion operation.
 *
 * This is the decoded form of one fuzzer input: unpack() fills it from
 * the raw bytes, pack() serialises it back, and LLVMFuzzerTestOneInput()
 * drives get_assert() and verify_assert() with it.  For the one-byte
 * flags (u2f, up, uv) only bit 0 is consulted.
 */
struct param {
	char		pin[MAXSTR];	/* PIN; not passed when bit 0 of u2f is set */
	char		rp_id[MAXSTR];	/* relying party id */
	int		ext;		/* bit 0 enables hmac-secret in get_assert(); passed verbatim to verify_assert() */
	int		seed;		/* argument to prng_init() */
	struct blob	cdh;		/* client data hash */
	struct blob	cred;		/* allow-list credential id; also reused as hmac salt */
	struct blob	es256;		/* raw key bytes for es256_pk_from_ptr() */
	struct blob	rs256;		/* raw key bytes for rs256_pk_from_ptr() */
	struct blob	eddsa;		/* raw key bytes for eddsa_pk_from_ptr() */
	struct blob	wire_data;	/* HID reports installed via set_wire_data() */
	uint8_t		cred_count;	/* number of times cred is added to the allow list */
	uint8_t		type;		/* low two bits: 0 = ES256, 1 = RS256, 2/3 = EdDSA */
	uint8_t		u2f;		/* bit 0: force U2F mode */
	uint8_t		up;		/* bit 0: request user presence */
	uint8_t		uv;		/* bit 0: request user verification */
};

/*
 * Collection of HID reports from an authenticator answering a FIDO2
 * get assertion: init, authenticatorGetInfo, key agreement, PIN token
 * and the assertion itself (per the WIREDATA_* names).  Presumably
 * recorded with the dummy parameters that pack_dummy() packs
 * (dummy_pin, dummy_rp_id, dummy_cdh, ...) — TODO confirm.
 *
 * Used as the wire data of the seed input built by pack_dummy(), and as
 * the starting point for wire-data mutation in LLVMFuzzerCustomMutator()
 * when bit 0 of u2f is clear.
 */
static const uint8_t dummy_wire_data_fido[] = {
	WIREDATA_CTAP_INIT,
	WIREDATA_CTAP_CBOR_INFO,
	WIREDATA_CTAP_CBOR_AUTHKEY,
	WIREDATA_CTAP_CBOR_PINTOKEN,
	WIREDATA_CTAP_CBOR_ASSERT,
};

/*
 * Collection of HID reports from an authenticator answering a U2F
 * authentication.  The repeated WIREDATA_CTAP_U2F_6985 replies precede
 * the final authentication response; presumably they are the 0x6985
 * "conditions not satisfied" (user presence pending) status that makes
 * the client retry — TODO confirm.
 *
 * Used as the starting point for wire-data mutation in
 * LLVMFuzzerCustomMutator() when bit 0 of u2f is set.
 */
static const uint8_t dummy_wire_data_u2f[] = {
	WIREDATA_CTAP_INIT,
	WIREDATA_CTAP_U2F_6985,
	WIREDATA_CTAP_U2F_6985,
	WIREDATA_CTAP_U2F_6985,
	WIREDATA_CTAP_U2F_6985,
	WIREDATA_CTAP_U2F_AUTH,
};

/*
 * libFuzzer hooks implemented further down.  They have external linkage,
 * so they are declared up front (presumably to keep missing-prototype
 * warnings quiet).
 */
int    LLVMFuzzerTestOneInput(const uint8_t *, size_t);
size_t LLVMFuzzerCustomMutator(uint8_t *, size_t, size_t, unsigned int);

/*
 * Decode a serialised fuzzer input into *p.
 *
 * Fields are consumed front to back in a fixed order, each introduced by
 * its TAG_* identifier; pack() emits them in exactly the same order.
 * Decoding stops at the first field the unpack_*() helpers reject.
 *
 * Returns 0 on success, -1 on failure (in which case *p may hold a
 * partially decoded prefix).
 */
static int
unpack(const uint8_t *ptr, size_t len, struct param *p) NO_MSAN
{
	/* the unpack_*() helpers take a non-const uint8_t ** cursor */
	uint8_t **cursor = (void *)&ptr;

	/* one-byte flags and counters */
	if (unpack_byte(TAG_UV, cursor, &len, &p->uv) < 0)
		return (-1);
	if (unpack_byte(TAG_UP, cursor, &len, &p->up) < 0)
		return (-1);
	if (unpack_byte(TAG_U2F, cursor, &len, &p->u2f) < 0)
		return (-1);
	if (unpack_byte(TAG_TYPE, cursor, &len, &p->type) < 0)
		return (-1);
	if (unpack_byte(TAG_CRED_COUNT, cursor, &len, &p->cred_count) < 0)
		return (-1);

	/* integers */
	if (unpack_int(TAG_EXT, cursor, &len, &p->ext) < 0)
		return (-1);
	if (unpack_int(TAG_SEED, cursor, &len, &p->seed) < 0)
		return (-1);

	/* strings */
	if (unpack_string(TAG_RP_ID, cursor, &len, p->rp_id) < 0)
		return (-1);
	if (unpack_string(TAG_PIN, cursor, &len, p->pin) < 0)
		return (-1);

	/* blobs */
	if (unpack_blob(TAG_WIRE_DATA, cursor, &len, &p->wire_data) < 0)
		return (-1);
	if (unpack_blob(TAG_RS256, cursor, &len, &p->rs256) < 0)
		return (-1);
	if (unpack_blob(TAG_ES256, cursor, &len, &p->es256) < 0)
		return (-1);
	if (unpack_blob(TAG_EDDSA, cursor, &len, &p->eddsa) < 0)
		return (-1);
	if (unpack_blob(TAG_CRED, cursor, &len, &p->cred) < 0)
		return (-1);
	if (unpack_blob(TAG_CDH, cursor, &len, &p->cdh) < 0)
		return (-1);

	return (0);
}

/*
 * Serialise *p into the buffer ptr of capacity len, writing the fields
 * in the same order unpack() reads them.
 *
 * Returns the number of bytes written, or 0 if any pack_*() helper
 * reports failure.
 */
static size_t
pack(uint8_t *ptr, size_t len, const struct param *p)
{
	const size_t capacity = len;	/* len shrinks as fields are written */

	if (pack_byte(TAG_UV, &ptr, &len, p->uv) < 0)
		goto fail;
	if (pack_byte(TAG_UP, &ptr, &len, p->up) < 0)
		goto fail;
	if (pack_byte(TAG_U2F, &ptr, &len, p->u2f) < 0)
		goto fail;
	if (pack_byte(TAG_TYPE, &ptr, &len, p->type) < 0)
		goto fail;
	if (pack_byte(TAG_CRED_COUNT, &ptr, &len, p->cred_count) < 0)
		goto fail;
	if (pack_int(TAG_EXT, &ptr, &len, p->ext) < 0)
		goto fail;
	if (pack_int(TAG_SEED, &ptr, &len, p->seed) < 0)
		goto fail;
	if (pack_string(TAG_RP_ID, &ptr, &len, p->rp_id) < 0)
		goto fail;
	if (pack_string(TAG_PIN, &ptr, &len, p->pin) < 0)
		goto fail;
	if (pack_blob(TAG_WIRE_DATA, &ptr, &len, &p->wire_data) < 0)
		goto fail;
	if (pack_blob(TAG_RS256, &ptr, &len, &p->rs256) < 0)
		goto fail;
	if (pack_blob(TAG_ES256, &ptr, &len, &p->es256) < 0)
		goto fail;
	if (pack_blob(TAG_EDDSA, &ptr, &len, &p->eddsa) < 0)
		goto fail;
	if (pack_blob(TAG_CRED, &ptr, &len, &p->cred) < 0)
		goto fail;
	if (pack_blob(TAG_CDH, &ptr, &len, &p->cdh) < 0)
		goto fail;

	/* bytes used = starting capacity minus what is left */
	return (capacity - len);
fail:
	return (0);
}

/*
 * Serialised size of a struct param whose strings and blobs are all
 * `max` bytes long, as reported by the len_*() helpers.  The field
 * counts mirror what pack() emits.
 */
static size_t
input_len(int max)
{
	size_t total = 0;

	total += 5 * len_byte();		/* uv, up, u2f, type, cred_count */
	total += 2 * len_int();			/* ext, seed */
	total += 2 * len_string(max);		/* rp_id, pin */
	total += 6 * len_blob(max);		/* wire_data, rs256, es256, eddsa, cred, cdh */

	return (total);
}

/*
 * Run fido_dev_get_assert() against the emulated device ("nodev", wired
 * to the dev_open/dev_close/dev_read/dev_write callbacks) and leave the
 * result in assert.
 *
 * Only bit 0 of u2f, ext, up and uv is consulted.  cred is added to the
 * allow list cred_count times and is also reused as the hmac salt.  When
 * U2F mode is forced, no PIN is passed.  If the device cannot be created
 * or opened, the function returns without touching assert.
 */
static void
get_assert(fido_assert_t *assert, uint8_t u2f, const struct blob *cdh,
    const char *rp_id, int ext, uint8_t up, uint8_t uv, const char *pin,
    uint8_t cred_count, struct blob *cred)
{
	const bool	 force_u2f = (u2f & 1) != 0;
	fido_dev_io_t	 io;
	fido_dev_t	*dev;
	uint8_t		 remaining;

	/* route all device I/O through the harness callbacks */
	memset(&io, 0, sizeof(io));
	io.open = dev_open;
	io.close = dev_close;
	io.read = dev_read;
	io.write = dev_write;

	if ((dev = fido_dev_new()) == NULL ||
	    fido_dev_set_io_functions(dev, &io) != FIDO_OK ||
	    fido_dev_open(dev, "nodev") != FIDO_OK) {
		fido_dev_free(&dev);
		return;
	}

	if (force_u2f)
		fido_dev_force_u2f(dev);

	/* the same credential id may appear many times on purpose */
	remaining = cred_count;
	while (remaining-- > 0)
		fido_assert_allow_cred(assert, cred->body, cred->len);

	fido_assert_set_clientdata_hash(assert, cdh->body, cdh->len);
	fido_assert_set_rp(assert, rp_id);

	if ((ext & 1) != 0)
		fido_assert_set_extensions(assert, FIDO_EXT_HMAC_SECRET);
	if ((up & 1) != 0)
		fido_assert_set_up(assert, FIDO_OPT_TRUE);
	if ((uv & 1) != 0)
		fido_assert_set_uv(assert, FIDO_OPT_TRUE);

	/* XXX reuse cred as hmac salt to keep struct param small */
	fido_assert_set_hmac_salt(assert, cred->body, cred->len);

	fido_dev_get_assert(dev, assert, force_u2f ? NULL : pin);

	fido_dev_cancel(dev);
	fido_dev_close(dev);
	fido_dev_free(&dev);
}

/*
 * Rebuild a single-statement assertion from its parts and run
 * fido_assert_verify() on it with public key pk of COSE algorithm type.
 * The verification result is discarded.
 *
 * If fido_assert_set_authdata() rejects the authenticator data, the same
 * bytes are retried through fido_assert_set_authdata_raw().  Only bit 0
 * of up and uv is consulted; ext is passed through verbatim.
 */
static void
verify_assert(int type, const unsigned char *cdh_ptr, size_t cdh_len,
    const char *rp_id, const unsigned char *authdata_ptr, size_t authdata_len,
    const unsigned char *sig_ptr, size_t sig_len, uint8_t up, uint8_t uv,
    int ext, void *pk)
{
	fido_assert_t	*stmt;
	int		 rc;

	stmt = fido_assert_new();
	if (stmt == NULL)
		return;

	fido_assert_set_clientdata_hash(stmt, cdh_ptr, cdh_len);
	fido_assert_set_rp(stmt, rp_id);
	fido_assert_set_count(stmt, 1);

	rc = fido_assert_set_authdata(stmt, 0, authdata_ptr, authdata_len);
	if (rc != FIDO_OK) {
		fido_assert_set_authdata_raw(stmt, 0, authdata_ptr,
		    authdata_len);
	}

	fido_assert_set_extensions(stmt, ext);
	if ((up & 1) != 0) {
		fido_assert_set_up(stmt, FIDO_OPT_TRUE);
	}
	if ((uv & 1) != 0) {
		fido_assert_set_uv(stmt, FIDO_OPT_TRUE);
	}
	fido_assert_set_sig(stmt, 0, sig_ptr, sig_len);

	(void)fido_assert_verify(stmt, 0, type, pk);

	fido_assert_free(&stmt);
}

/*
 * Do a dummy conversion to exercise rs256_pk_from_RSA().
 *
 * k is converted to an EVP_PKEY, the RSA key inside it is extracted, and
 * that key is loaded into a fresh rs256_pk_t.  The result is stored into
 * a volatile local and otherwise ignored; every temporary is released
 * before returning.
 */
static void
rs256_convert(const rs256_pk_t *k)
{
	EVP_PKEY	*evp;
	rs256_pk_t	*copy = NULL;
	RSA		*rsa;
	volatile int	 rc;

	if ((evp = rs256_pk_to_EVP_PKEY(k)) == NULL)
		return;

	if ((copy = rs256_pk_new()) != NULL &&
	    (rsa = EVP_PKEY_get0_RSA(evp)) != NULL)
		rc = rs256_pk_from_RSA(copy, rsa);

	/* rs256_pk_free() and EVP_PKEY_free() both tolerate NULL */
	rs256_pk_free(&copy);
	EVP_PKEY_free(evp);
}

/*
 * Do a dummy conversion to exercise eddsa_pk_from_EVP_PKEY().
 *
 * k is converted to an EVP_PKEY, which is then loaded into a fresh
 * eddsa_pk_t.  The result is stored into a volatile local and otherwise
 * ignored; every temporary is released before returning.
 */
static void
eddsa_convert(const eddsa_pk_t *k)
{
	EVP_PKEY	*evp;
	eddsa_pk_t	*copy = NULL;
	volatile int	 rc;

	if ((evp = eddsa_pk_to_EVP_PKEY(k)) == NULL)
		return;

	if ((copy = eddsa_pk_new()) != NULL)
		rc = eddsa_pk_from_EVP_PKEY(copy, evp);

	/* eddsa_pk_free() and EVP_PKEY_free() both tolerate NULL */
	eddsa_pk_free(&copy);
	EVP_PKEY_free(evp);
}

/*
 * libFuzzer entry point: decode one input and drive a full get assertion
 * plus verification with it.
 *
 * Inputs whose size lies outside [input_len(GETLEN_MIN),
 * input_len(GETLEN_MAX)], or that unpack() rejects, are ignored.  The
 * low two bits of p.type pick the public key type (0: ES256, 1: RS256,
 * 2/3: EdDSA); the RS256 and EdDSA keys additionally go through a dummy
 * conversion round trip.  Always returns 0, as libFuzzer expects.
 *
 * NOTE(review): prng_init() is seeded before any libfido2 call; the
 * order of the calls below may matter to whatever consumes that PRNG
 * in the harness — confirm before reordering.
 */
int
LLVMFuzzerTestOneInput(const uint8_t *data, size_t size)
{
	struct param	 p;
	fido_assert_t	*assert = NULL;
	es256_pk_t	*es256_pk = NULL;
	rs256_pk_t	*rs256_pk = NULL;
	eddsa_pk_t	*eddsa_pk = NULL;
	uint8_t		 flags;
	uint32_t	 sigcount;
	int		 cose_alg = 0;
	void		*pk;	/* aliases whichever *_pk above was allocated */

	memset(&p, 0, sizeof(p));

	if (size < input_len(GETLEN_MIN) || size > input_len(GETLEN_MAX) ||
	    unpack(data, size, &p) < 0)
		return (0);

	prng_init((unsigned int)p.seed);

	fido_init(FIDO_DEBUG);
	fido_set_log_handler(consume_str);

	/* select the verification key from the low two bits of p.type */
	switch (p.type & 3) {
	case 0:
		cose_alg = COSE_ES256;

		if ((es256_pk = es256_pk_new()) == NULL)
			return (0);

		es256_pk_from_ptr(es256_pk, p.es256.body, p.es256.len);
		pk = es256_pk;

		break;
	case 1:
		cose_alg = COSE_RS256;

		if ((rs256_pk = rs256_pk_new()) == NULL)
			return (0);

		rs256_pk_from_ptr(rs256_pk, p.rs256.body, p.rs256.len);
		pk = rs256_pk;

		rs256_convert(pk);

		break;
	default:
		cose_alg = COSE_EDDSA;

		if ((eddsa_pk = eddsa_pk_new()) == NULL)
			return (0);

		eddsa_pk_from_ptr(eddsa_pk, p.eddsa.body, p.eddsa.len);
		pk = eddsa_pk;

		eddsa_convert(pk);

		break;
	}

	if ((assert = fido_assert_new()) == NULL)
		goto out;

	/* the emulated device replays these HID reports */
	set_wire_data(p.wire_data.body, p.wire_data.len);

	get_assert(assert, p.u2f, &p.cdh, p.rp_id, p.ext, p.up, p.uv, p.pin,
	    p.cred_count, &p.cred);

	/* XXX +1 on purpose */
	/*
	 * The final iteration uses an index equal to the statement count,
	 * presumably to exercise the accessors' out-of-range handling.
	 */
	for (size_t i = 0; i <= fido_assert_count(assert); i++) {
		verify_assert(cose_alg,
		    fido_assert_clientdata_hash_ptr(assert),
		    fido_assert_clientdata_hash_len(assert),
		    fido_assert_rp_id(assert),
		    fido_assert_authdata_ptr(assert, i),
		    fido_assert_authdata_len(assert, i),
		    fido_assert_sig_ptr(assert, i),
		    fido_assert_sig_len(assert, i), p.up, p.uv, p.ext, pk);
		/* touch every accessor's output so it is actually read */
		consume(fido_assert_id_ptr(assert, i),
		    fido_assert_id_len(assert, i));
		consume(fido_assert_user_id_ptr(assert, i),
		    fido_assert_user_id_len(assert, i));
		consume(fido_assert_hmac_secret_ptr(assert, i),
		    fido_assert_hmac_secret_len(assert, i));
		consume(fido_assert_user_icon(assert, i),
		    xstrlen(fido_assert_user_icon(assert, i)));
		consume(fido_assert_user_name(assert, i),
		    xstrlen(fido_assert_user_name(assert, i)));
		consume(fido_assert_user_display_name(assert, i),
		    xstrlen(fido_assert_user_display_name(assert, i)));
		flags = fido_assert_flags(assert, i);
		consume(&flags, sizeof(flags));
		sigcount = fido_assert_sigcount(assert, i);
		consume(&sigcount, sizeof(sigcount));
	}

out:
	/* only one key was allocated; freeing the NULL ones is harmless */
	es256_pk_free(&es256_pk);
	rs256_pk_free(&rs256_pk);
	eddsa_pk_free(&eddsa_pk);

	fido_assert_free(&assert);

	return (0);
}

static size_t
pack_dummy(uint8_t *ptr, size_t len)
{
	struct param	dummy;
	uint8_t		blob[16384];
	size_t		blob_len;

	memset(&dummy, 0, sizeof(dummy));

	dummy.type = 1; /* rsa */
	dummy.ext = FIDO_EXT_HMAC_SECRET;

	strlcpy(dummy.pin, dummy_pin, sizeof(dummy.pin));
	strlcpy(dummy.rp_id, dummy_rp_id, sizeof(dummy.rp_id));

	dummy.cred.len = sizeof(dummy_cdh); /* XXX */
	dummy.cdh.len = sizeof(dummy_cdh);
	dummy.es256.len = sizeof(dummy_es256);
	dummy.rs256.len = sizeof(dummy_rs256);
	dummy.eddsa.len = sizeof(dummy_eddsa);
	dummy.wire_data.len = sizeof(dummy_wire_data_fido);

	memcpy(&dummy.cred.body, &dummy_cdh, dummy.cred.len); /* XXX */
	memcpy(&dummy.cdh.body, &dummy_cdh, dummy.cdh.len);
	memcpy(&dummy.wire_data.body, &dummy_wire_data_fido,
	    dummy.wire_data.len);
	memcpy(&dummy.es256.body, &dummy_es256, dummy.es256.len);
	memcpy(&dummy.rs256.body, &dummy_rs256, dummy.rs256.len);
	memcpy(&dummy.eddsa.body, &dummy_eddsa, dummy.eddsa.len);

	blob_len = pack(blob, sizeof(blob), &dummy);
	assert(blob_len != 0);

	if (blob_len > len) {
		memcpy(ptr, blob, len);
		return (len);
	}

	memcpy(ptr, blob, blob_len);

	return (blob_len);
}

/*
 * libFuzzer custom mutator.
 *
 * Decodes the current input in place, mutates each field with the
 * mutate_*() helpers and re-encodes the result into data.  Inputs that
 * fail to decode are replaced by the seed input from pack_dummy().
 *
 * Before the wire data is mutated, it is reset to the canned FIDO2 or
 * U2F transcript, chosen by bit 0 of the already mutated u2f flag.
 * Presumably this keeps mutations anchored to a plausible device
 * conversation.
 *
 * The seed field is not mutated.  It is overwritten with libFuzzer's
 * seed, which LLVMFuzzerTestOneInput() later hands to prng_init().
 *
 * Returns the new input size, or 0 if re-encoding failed or the result
 * would not fit in maxsize bytes.
 */
size_t
LLVMFuzzerCustomMutator(uint8_t *data, size_t size, size_t maxsize,
    unsigned int seed) NO_MSAN
{
	struct param	p;
	uint8_t		blob[16384];
	size_t		blob_len;

	memset(&p, 0, sizeof(p));

	if (unpack(data, size, &p) < 0)
		return (pack_dummy(data, maxsize));

	mutate_byte(&p.uv);
	mutate_byte(&p.up);
	mutate_byte(&p.u2f);
	mutate_byte(&p.type);
	mutate_byte(&p.cred_count);

	mutate_int(&p.ext);
	/* seed is used here; no (void) cast needed to silence warnings */
	p.seed = (int)seed;

	if (p.u2f & 1) {
		p.wire_data.len = sizeof(dummy_wire_data_u2f);
		memcpy(&p.wire_data.body, &dummy_wire_data_u2f,
		    p.wire_data.len);
	} else {
		p.wire_data.len = sizeof(dummy_wire_data_fido);
		memcpy(&p.wire_data.body, &dummy_wire_data_fido,
		    p.wire_data.len);
	}

	mutate_blob(&p.wire_data);
	mutate_blob(&p.rs256);
	mutate_blob(&p.es256);
	mutate_blob(&p.eddsa);
	mutate_blob(&p.cred);
	mutate_blob(&p.cdh);

	mutate_string(p.rp_id);
	mutate_string(p.pin);

	blob_len = pack(blob, sizeof(blob), &p);

	if (blob_len == 0 || blob_len > maxsize)
		return (0);

	memcpy(data, blob, blob_len);

	return (blob_len);
}
