#include "e_param.h"
#include "param.h"

// Private state for a type "e" pairing. pairing_init_e_param() allocates one
// of these, stores it in pairing->data, and e_pairing() reads it back.
struct e_pairing_data_s {
    // Field set up with field_init_fp() from param->q; also used as GT.
    field_t Fq;
    // Curve built by curve_init_cc_ab() from the a and b parameters.
    curve_t Eq;
    // Final exponent: (order of Fq - 1) / r.
    mpz_t tateexp;
    // Map handed to cc_miller(); set to Fq->set by pairing_init_e_param().
    fieldmap mapbase;
};
// One-element array form, so a variable of this type can be passed around
// like a pointer.
typedef struct e_pairing_data_s e_pairing_data_t[1];
// Plain pointer form.
typedef struct e_pairing_data_s *e_pairing_data_ptr;

// Initializes the five GMP integers stored in ep (q, r, h, a, b), in that
// order. Must be paired with e_param_clear().
void e_param_init(e_param_t ep)
{
    mpz_ptr members[] = { ep->q, ep->r, ep->h, ep->a, ep->b };
    const int count = (int) (sizeof members / sizeof members[0]);
    int k;

    for (k = 0; k < count; k++) {
	mpz_init(members[k]);
    }
}

// Releases the five GMP integers stored in ep (q, r, h, a, b), in that
// order. Counterpart of e_param_init().
void e_param_clear(e_param_t ep)
{
    mpz_ptr members[] = { ep->q, ep->r, ep->h, ep->a, ep->b };
    const int count = (int) (sizeof members / sizeof members[0]);
    int k;

    for (k = 0; k < count; k++) {
	mpz_clear(members[k]);
    }
}

// Generates type "e" parameters into p.
//
// Searches for
//   r = 2^exp2 + sign1 * 2^exp1 + sign0     (probable prime, exp2 = rbits)
//   h = 28 * h0^2                           (h0 drawn via pbc_mpz_random
//                                            with 2^hbits as the limit)
//   q = r^2 * h + 1                         (probable prime)
// and then stores the coefficients a, b of a curve over F_q built from
// j = -3375, twisted if a random point is not killed by n = r^2 * h.
//
// p     - receives q, r, h, a, b, exp2, exp1, sign1, sign0; its mpz members
//         must already be initialized
// rbits - exponent of the leading power of two in r
// qbits - rough target size of q, used only to derive hbits
//
// Randomness for exp1 and the two signs comes from rand().
//
// NOTE(review): rbits <= 1 never yields a nonzero exp1, and rbits == 0
// divides by zero in the modulo; callers presumably pass sane sizes, but
// nothing here checks.
// NOTE(review): Fq and cc are initialized below but never released in this
// function.
void e_param_gen(e_param_t p, int rbits, int qbits)
{
    // 5 bits for the factor 28; the remaining bits are split between r^2
    // and h0^2.
    int hbits = (qbits - 5) / 2 - rbits;
    mpz_ptr q_mpz = p->q;
    mpz_ptr r_mpz = p->r;
    mpz_ptr h_mpz = p->h;
    mpz_t n;
    field_t Fq;
    curve_t cc;
    element_t j;
    point_t P;
    common_curve_ptr coeffs;
    int found = 0;

    p->exp2 = rbits;
    mpz_init(n);

    while (!found) {
	int attempt;

	// r starts as 2^exp2
	mpz_set_ui(r_mpz, 0);
	mpz_setbit(r_mpz, p->exp2);

	// q is only scratch space here: it holds 2^exp1
	mpz_set_ui(q_mpz, 0);
	do {
	    p->exp1 = rand() % rbits;
	} while (p->exp1 == 0);
	mpz_setbit(q_mpz, p->exp1);

	// r +/-= 2^exp1
	p->sign1 = (rand() % 2) ? 1 : -1;
	if (p->sign1 > 0) {
	    mpz_add(r_mpz, r_mpz, q_mpz);
	} else {
	    mpz_sub(r_mpz, r_mpz, q_mpz);
	}

	// r +/-= 1
	p->sign0 = (rand() % 2) ? 1 : -1;
	if (p->sign0 > 0) {
	    mpz_add_ui(r_mpz, r_mpz, 1);
	} else {
	    mpz_sub_ui(r_mpz, r_mpz, 1);
	}

	if (mpz_probab_prime_p(r_mpz, 10)) {
	    // up to ten candidate values of h for this r
	    for (attempt = 0; attempt < 10 && !found; attempt++) {
		// q is still scratch space: it holds 2^hbits
		mpz_set_ui(q_mpz, 0);
		mpz_setbit(q_mpz, hbits);
		pbc_mpz_random(h_mpz, q_mpz);
		mpz_mul(h_mpz, h_mpz, h_mpz);
		mpz_mul_ui(h_mpz, h_mpz, 28);
		// from here on q holds its real value, r^2 * h + 1
		mpz_mul(n, r_mpz, r_mpz);
		mpz_mul(n, n, h_mpz);
		mpz_add_ui(q_mpz, n, 1);
		if (mpz_probab_prime_p(q_mpz, 10)) {
		    found = 1;
		}
	    }
	}
    }

    field_init_fp(Fq, q_mpz);
    element_init(j, Fq);
    element_set_si(j, -3375);
    curve_init_cc_j(cc, j);
    element_clear(j);

    // The curve may be the wrong twist: multiply a random point by n and
    // switch to the twist when the result is not the point at infinity.
    point_init(P, cc);
    point_random(P);
    point_mul(P, n, P);
    if (!point_is_inf(P)) {
	twist_curve(cc);
    }
    point_clear(P);

    // Read the coefficients only after the possible twist.
    coeffs = (common_curve_ptr) cc->data;
    mpz_set(p->a, coeffs->a->data);
    mpz_set(p->b, coeffs->b->data);

    mpz_clear(n);
}

// Writes the parameters in p to stream: the type tag "e", then the big
// integers q, r, h, a, b, then the small integers exp2, exp1, sign1, sign0.
// Keys and their order match what e_param_inp_str() looks up.
void e_param_out_str(FILE *stream, e_param_ptr p)
{
    struct {
	char *key;
	mpz_ptr value;
    } big[] = {
	{ "q", p->q },
	{ "r", p->r },
	{ "h", p->h },
	{ "a", p->a },
	{ "b", p->b },
    };
    struct {
	char *key;
	int value;
    } small[] = {
	{ "exp2", p->exp2 },
	{ "exp1", p->exp1 },
	{ "sign1", p->sign1 },
	{ "sign0", p->sign0 },
    };
    size_t k;

    param_out_type(stream, "e");
    for (k = 0; k < sizeof big / sizeof big[0]; k++) {
	param_out_mpz(stream, big[k].key, big[k].value);
    }
    for (k = 0; k < sizeof small / sizeof small[0]; k++) {
	param_out_int(stream, small[k].key, small[k].value);
    }
}

// Reads parameters from stream into p. The stream is parsed into a
// temporary symbol table with param_read(); q, r, h, a, b are then fetched
// with lookup_mpz() and exp2, exp1, sign1, sign0 with lookup_int(). The
// table is emptied and released before returning.
// The mpz members of p must already be initialized.
void e_param_inp_str(e_param_ptr p, FILE *stream)
{
    symtab_t tab;
    struct {
	mpz_ptr dest;
	char *key;
    } big[] = {
	{ p->q, "q" },
	{ p->r, "r" },
	{ p->h, "h" },
	{ p->a, "a" },
	{ p->b, "b" },
    };
    size_t k;

    symtab_init(tab);
    param_read(tab, stream);

    for (k = 0; k < sizeof big / sizeof big[0]; k++) {
	lookup_mpz(big[k].dest, tab, big[k].key);
    }

    p->exp2 = lookup_int(tab, "exp2");
    p->exp1 = lookup_int(tab, "exp1");
    p->sign1 = lookup_int(tab, "sign1");
    p->sign0 = lookup_int(tab, "sign0");

    param_clear_tab(tab);
    symtab_clear(tab);
}

// Miller-style loop: walks the bits of q with a running point Z (starting
// at P), multiplying the values of tangent/chord lines at (Qx, Qy) into a
// numerator and the values of vertical lines into a denominator.
//
// res     - receives numerator / denominator; its field is also used for
//           the extension-side temporaries
// q       - integer whose bits, from the second-highest down to bit 0,
//           drive the double-and-add steps
// P       - curve point; its curve's field is used for the base-side
//           temporaries
// Qx, Qy  - coordinates at which the lines are evaluated
// mapbase - maps an element of P's field into res's field
//
// Assumes P is in the base field, Q in some field extension.
// The helpers below are nested functions (a GNU C extension) and share the
// temporaries declared in this function.
static void cc_miller(element_t res, mpz_t q, point_t P,
	element_ptr Qx, element_ptr Qy, fieldmap mapbase)
{
    //collate divisions: numerators accumulate in v, denominators in vd,
    //and a single inversion is done at the end
    int m;
    element_t v, vd;
    point_t Z;
    element_t a, b, c;
    common_curve_ptr cc = P->curve->data;
    element_t t0;
    element_t e0, e1;

    //multiplies e by (Qx - Z.x); leaves e untouched when Z is infinity
    void do_vertical(element_t e)
    {
	if (point_is_inf(Z)) {
	    return;
	}
	mapbase(e0, Z->x);
	element_sub(e0, Qx, e0);
	element_mul(e, e, e0);
    }

    //multiplies e by the tangent line at Z evaluated at (Qx, Qy),
    //i.e. by a*Qx + b*Qy + c with a, b, c as described below
    void do_tangent(element_t e)
    {
	//a = -slope_tangent(Z.x, Z.y);
	//b = 1;
	//c = -(Z.y + a * Z.x);
	//but we multiply by 2*Z.y to avoid division

	//a = -Zx * (Zx + Zx + Zx + twicea_2) - a_4;
	//Common curves: a2 = 0 (and cc->a is a_4), so
	//a = -(Zx (Zx + Zx + Zx) + cc->a)
	//b = 2 * Zy
	//c = -(2 Zy^2 + a Zx);
	element_ptr Zx = Z->x;
	element_ptr Zy = Z->y;

	//nothing to multiply in when Z is infinity
	if (point_is_inf(Z)) {
	    return;
	}

	if (element_is0(Zy)) {
	    //order 2: the tangent is the vertical line through Z
	    do_vertical(e);
	    return;
	}

	//a = -(3 Zx^2 + cc->a)
	element_add(a, Zx, Zx);
	element_add(a, a, Zx);
	element_mul(a, a, Zx);
	element_add(a, a, cc->a);
	element_neg(a, a);

	//b = 2 Zy
	element_add(b, Zy, Zy);

	//c = -(a Zx + b Zy)
	element_mul(t0, b, Zy);
	element_mul(c, a, Zx);
	element_add(c, c, t0);
	element_neg(c, c);

	//TODO: use poly_mul_constant?
	//e *= a*Qx + b*Qy + c, with a, b, c mapped through mapbase first
	mapbase(e0, a);
	element_mul(e0, e0, Qx);
	mapbase(e1, b);
	element_mul(e1, e1, Qy);
	element_add(e0, e0, e1);
	mapbase(e1, c);
	element_add(e0, e0, e1);
	element_mul(e, e, e0);
    }

    //multiplies e by the line through Z (called A below) and P (called B)
    //evaluated at (Qx, Qy)
    void do_line(element_ptr e)
    {
	//a = -(B.y - A.y) / (B.x - A.x);
	//b = 1;
	//c = -(A.y + a * A.x);
	//but we'll multiply by B.x - A.x to avoid division

	element_ptr Ax = Z->x;
	element_ptr Ay = Z->y;
	element_ptr Bx = P->x;
	element_ptr By = P->y;

	//we assume B is never O
	//do_vertical() itself returns at once when Z is infinity, so this
	//path leaves e unchanged
	if (point_is_inf(Z)) {
	    do_vertical(e);
	    return;
	}

	//same x-coordinate: either Z == P (use the tangent) or the two
	//points differ only in y (use the vertical line)
	if (!element_cmp(Ax, Bx)) {
	    if (!element_cmp(Ay, By)) {
		do_tangent(e);
		return;
	    }
	    do_vertical(e);
	    return;
	}

	//b = Bx - Ax, a = Ay - By, c = -(a Ax + b Ay)
	element_sub(b, Bx, Ax);
	element_sub(a, Ay, By);
	element_mul(t0, b, Ay);
	element_mul(c, a, Ax);
	element_add(c, c, t0);
	element_neg(c, c);

	//e *= a*Qx + b*Qy + c, with a, b, c mapped through mapbase first
	mapbase(e0, a);
	element_mul(e0, e0, Qx);
	mapbase(e1, b);
	element_mul(e1, e1, Qy);
	element_add(e0, e0, e1);
	mapbase(e1, c);
	element_add(e0, e0, e1);
	element_mul(e, e, e0);
    }

    //a, b, c, t0 live in P's field; e0, e1, v, vd live in res's field
    element_init(a, P->curve->field);
    element_init(b, P->curve->field);
    element_init(c, P->curve->field);
    element_init(t0, P->curve->field);
    element_init(e0, res->field);
    element_init(e1, res->field);

    element_init(v, res->field);
    element_init(vd, res->field);
    point_init(Z, P->curve);

    point_set(Z, P);

    element_set1(v);
    element_set1(vd);
    //start at the bit just below the most significant bit of q
    m = mpz_sizeinbase(q, 2) - 2;

    while(m >= 0) {
	//doubling step: square both accumulators, fold in the tangent at
	//Z, double Z, then fold the vertical at the new Z into vd
	element_mul(v, v, v);
	element_mul(vd, vd, vd);
	do_tangent(v);
	point_double(Z, Z);
	do_vertical(vd);
	//addition step when bit m of q is set: fold in the line through
	//Z and P, add P to Z, then the vertical at the new Z
	if (mpz_tstbit(q, m)) {
	    do_line(v);
	    point_add(Z, Z, P);
	    do_vertical(vd);
	}
	m--;
    }

    //res = v / vd
    element_invert(vd, vd);
    element_mul(res, v, vd);

    element_clear(v);
    element_clear(vd);
    point_clear(Z);
    element_clear(a);
    element_clear(b);
    element_clear(c);
    element_clear(t0);
    element_clear(e0);
    element_clear(e1);
}

// Pairing map for type "e" parameters; installed as pairing->map by
// pairing_init_e_param().
//
// out     - receives the pairing value; its field is used for the temporary
// in1     - first argument; in1->data is passed to cc_miller() as the point P
// in2     - second argument; in2->data is read as the point Q
// pairing - supplies r and the private e_pairing_data_s via pairing->data
//
// A random point R on the curve is drawn, cc_miller() is evaluated at Q + R
// and at R, the first result is divided by the second, and the quotient is
// raised to p->tateexp.
static void e_pairing(element_ptr out, element_ptr in1, element_ptr in2,
	pairing_t pairing)
{
    e_pairing_data_ptr p = pairing->data;
    point_ptr Q = in2->data;
    element_t e0;
    point_t R, QR;

    point_init(R, p->Eq);
    point_init(QR, p->Eq);
    point_random(R);
    element_init(e0, out->field);
    point_add(QR, Q, R);

    //out = f(Q + R), e0 = f(R)
    cc_miller(out, pairing->r, in1->data, QR->x, QR->y, p->mapbase);
    cc_miller(e0, pairing->r, in1->data, R->x, R->y, p->mapbase);

    //out = (f(Q + R) / f(R)) ^ tateexp
    element_invert(e0, e0);
    element_mul(out, out, e0);
    element_pow(out, out, p->tateexp);

    //release every temporary, including e0, which was previously leaked
    //on each call
    element_clear(e0);
    point_clear(R);
    point_clear(QR);
}

// Sets up pairing from type "e" parameters.
//
// pairing - structure to fill in: r, Zr, map, is_almost_coddh, data, G1,
//           G2 and GT are all assigned here
// param   - source of q, r, h and the curve coefficients a, b
//
// G1 and G2 share a single heap-allocated field, and GT is the Fq stored
// inside the private data block. Neither malloc() result is checked.
void pairing_init_e_param(pairing_t pairing, e_param_t param)
{
    e_pairing_data_ptr data;
    element_t coeff_a, coeff_b;

    // group order and the field of exponents
    mpz_init(pairing->r);
    mpz_set(pairing->r, param->r);
    field_init_fp(pairing->Zr, pairing->r);

    pairing->map = e_pairing;
    pairing->is_almost_coddh = generic_is_almost_coddh;

    // private state, reachable later through pairing->data
    data = malloc(sizeof *data);
    pairing->data = data;
    field_init_fp(data->Fq, param->q);

    // curve over Fq with coefficients a and b
    element_init(coeff_a, data->Fq);
    element_init(coeff_b, data->Fq);
    element_set_mpz(coeff_a, param->a);
    element_set_mpz(coeff_b, param->b);
    curve_init_cc_ab(data->Eq, coeff_a, coeff_b);

    // tateexp = (|Fq| - 1) / r
    mpz_init(data->tateexp);
    mpz_sub_ui(data->tateexp, data->Fq->order, 1);
    mpz_divexact(data->tateexp, data->tateexp, pairing->r);

    data->mapbase = data->Fq->set;

    // both source groups are the same curve group
    pairing->G1 = malloc(sizeof(field_t));
    pairing->G2 = pairing->G1;
    field_init_curve_group(pairing->G1, data->Eq, param->h);

    pairing->GT = data->Fq;
    // pairing->phi is left unset (a trace-map assignment was disabled here)

    element_clear(coeff_a);
    element_clear(coeff_b);
}
