#include "pairing.h"
#include "fieldi.h"

//Multiply v by the vertical line through Z evaluated at Q, i.e. by
//(Q.x - Z.x).  The point at infinity contributes no factor, so v is left
//untouched when Z is O.
void cc_miller_vertical(element_t v, point_t Q, point_t Z)
{
    element_t diff;

    if (!point_is_inf(Z)) {
	element_init(diff, v->field);
	element_sub(diff, Q->x, Z->x);
	element_mul(v, v, diff);
	element_clear(diff);
    }
}

//Multiply v by the tangent line to the curve at Z, evaluated at Q.
//
//For a curve with a2 = 0 (cc->a holds a4) the tangent slope is
//(3 Z.x^2 + a4) / (2 Z.y).  To avoid a field division the line is scaled
//by 2 Z.y, giving coefficients
//    kx = -(3 Z.x^2 + a4)
//    ky = 2 Z.y
//    k0 = -(2 Z.y^2 + kx Z.x)
//and the factor kx Q.x + ky Q.y + k0.
//
//Nothing is multiplied in when Z is O; when Z.y == 0 (Z has order 2) the
//tangent is vertical and cc_miller_vertical() is used instead.
void cc_miller_tangent(element_t v, point_t Q, point_t Z)
{
    common_curve_ptr cc = Z->curve->data;
    element_t kx, ky, k0;
    element_t tmp;

    if (point_is_inf(Z)) {
	return;
    }
    if (element_is0(Z->y)) {
	//order 2: the tangent is vertical
	cc_miller_vertical(v, Q, Z);
	return;
    }

    element_init(kx, v->field);
    element_init(ky, v->field);
    element_init(k0, v->field);
    element_init(tmp, v->field);

    //kx = -(3 Z.x^2 + a4)
    element_add(kx, Z->x, Z->x);
    element_add(kx, kx, Z->x);
    element_mul(kx, kx, Z->x);
    element_add(kx, kx, cc->a);
    element_neg(kx, kx);

    //ky = 2 Z.y
    element_add(ky, Z->y, Z->y);

    //k0 = -(ky Z.y + kx Z.x)
    element_mul(k0, kx, Z->x);
    element_mul(tmp, ky, Z->y);
    element_add(k0, k0, tmp);
    element_neg(k0, k0);

    //v *= kx Q.x + ky Q.y + k0
    element_mul(kx, kx, Q->x);
    element_mul(ky, ky, Q->y);
    element_add(k0, k0, kx);
    element_add(k0, k0, ky);
    element_mul(v, v, k0);

    element_clear(kx);
    element_clear(ky);
    element_clear(k0);
    element_clear(tmp);
}

//Multiply v by the line through A and B, evaluated at Q.
//
//The slope is (B.y - A.y) / (B.x - A.x); scaling the line by (B.x - A.x)
//avoids the division and gives coefficients
//    kx = A.y - B.y
//    ky = B.x - A.x
//    k0 = -(ky A.y + kx A.x)
//and the factor kx Q.x + ky Q.y + k0.
//
//Degenerate cases: A == O gives the (empty) vertical factor, A == B uses
//the tangent, and A == -B uses the vertical line through A.
//B is assumed never to be O.
void cc_miller_line(element_t v, point_t Q, point_t A, point_t B)
{
    element_t kx, ky, k0;
    element_t tmp;

    if (point_is_inf(A)) {
	cc_miller_vertical(v, Q, A);
	return;
    }

    if (element_cmp(A->x, B->x) == 0) {
	if (element_cmp(A->y, B->y) == 0) {
	    //A == B
	    cc_miller_tangent(v, Q, A);
	} else {
	    //A == -B
	    cc_miller_vertical(v, Q, A);
	}
	return;
    }

    element_init(kx, v->field);
    element_init(ky, v->field);
    element_init(k0, v->field);
    element_init(tmp, v->field);

    element_sub(ky, B->x, A->x);
    element_sub(kx, A->y, B->y);

    //k0 = -(kx A.x + ky A.y)
    element_mul(k0, kx, A->x);
    element_mul(tmp, ky, A->y);
    element_add(k0, k0, tmp);
    element_neg(k0, k0);

    //v *= kx Q.x + ky Q.y + k0
    element_mul(kx, kx, Q->x);
    element_mul(ky, ky, Q->y);
    element_add(k0, k0, kx);
    element_add(k0, k0, ky);
    element_mul(v, v, k0);

    element_clear(kx);
    element_clear(ky);
    element_clear(k0);
    element_clear(tmp);
}

//Miller's algorithm (original version), kept for reference.
//Assumes Pbase is in the base field and Q in some field extension; the
//result is meant to be used in a Tate pairing computation.
//Pbase is first lifted into Q's curve with mapbase, so every operation is
//done in the extension field -- this version does not exploit the fact
//that P lies in the base field (Q is only needed inside the
//cc_miller_*() line evaluations).
//Numerator and denominator factors are accumulated separately so that a
//single inversion is done at the end.
void old_cc_miller(element_t res, mpz_t q, point_t Pbase, point_t Q,
	fieldmap mapbase)
{
    int bit;
    element_t num, den;
    point_t P, Z;

    element_init(num, res->field);
    element_init(den, res->field);
    point_init(Z, Q->curve);
    point_init(P, Q->curve);

    point_map(P, mapbase, Pbase);
    point_set(Z, P);

    element_set1(num);
    element_set1(den);

    //scan q from its second most significant bit down to bit 0
    for (bit = mpz_sizeinbase(q, 2) - 2; bit >= 0; bit--) {
	//doubling step
	element_mul(num, num, num);
	element_mul(den, den, den);
	cc_miller_tangent(num, Q, Z);
	point_double(Z, Z);
	cc_miller_vertical(den, Q, Z);

	if (!mpz_tstbit(q, bit)) {
	    continue;
	}

	//addition step
	cc_miller_line(num, Q, Z, P);
	point_add(Z, Z, P);
	cc_miller_vertical(den, Q, Z);
    }

    element_invert(den, den);
    element_mul(res, num, den);

    element_clear(num);
    element_clear(den);
    point_clear(Z);
    point_clear(P);
}

//Miller's algorithm for the Tate pairing on a common curve.
//assumes P is in the base field, Q in some field extension,
//and will be used in Tate pairing computation
//
//Unlike old_cc_miller(), P is not lifted into the extension: the running
//point Z stays on P's curve, line coefficients a, b, c are computed in
//P->curve->field, and only then mapped into res->field with mapbase before
//being combined with Q's coordinates.
//
//res:     receives v / vd (no final exponentiation is applied here)
//q:       scalar whose bits, below the most significant one, drive the loop
//P:       point whose coordinates lie in P->curve->field
//Q:       point whose coordinates are combined with elements of res->field
//mapbase: maps an element of P->curve->field into res->field
//
//NOTE(review): the helpers below are GCC nested functions that capture
//Z, Q, P, a, b, c, t0, e0, e1 and cc from this frame; this does not build
//with compilers lacking that extension.
void cc_miller(element_t res, mpz_t q, point_t P, point_t Q,
	fieldmap mapbase)
{
    //collate divisions: v accumulates numerator factors and vd denominator
    //factors so that only one inversion is needed at the end
    int m;
    element_t v, vd;
    point_t Z;
    element_t a, b, c;
    common_curve_ptr cc = P->curve->data;
    element_t t0;
    element_t e0, e1;

    //e *= Q.x - mapbase(Z.x), the vertical line through Z at Q.
    //No factor is applied when Z is O.
    void do_vertical(element_t e)
    {
	if (point_is_inf(Z)) {
	    return;
	}
	mapbase(e0, Z->x);
	element_sub(e0, Q->x, e0);
	element_mul(e, e, e0);
    }

    //e *= tangent line at Z evaluated at Q; falls back to the vertical
    //line when Z.y == 0, and does nothing when Z is O.
    void do_tangent(element_t e)
    {
	//a = -slope_tangent(Z.x, Z.y);
	//b = 1;
	//c = -(Z.y + a * Z.x);
	//but we multiply by 2*Z.y to avoid division

	//a = -Zx * (Zx + Zx + Zx + twicea_2) - a_4;
	//Common curves: a2 = 0 (and cc->a is a_4), so
	//a = -(Zx (Zx + Zx + Zx) + cc->a)
	//b = 2 * Zy
	//c = -(2 Zy^2 + a Zx);
	element_ptr Zx = Z->x;
	element_ptr Zy = Z->y;

	if (point_is_inf(Z)) {
	    return;
	}

	if (element_is0(Zy)) {
	    //order 2
	    do_vertical(e);
	    return;
	}

	element_add(a, Zx, Zx);
	element_add(a, a, Zx);
	element_mul(a, a, Zx);
	element_add(a, a, cc->a);
	element_neg(a, a);

	element_add(b, Zy, Zy);

	element_mul(t0, b, Zy);
	element_mul(c, a, Zx);
	element_add(c, c, t0);
	element_neg(c, c);

	//lift a, b, c into res->field and form a Q.x + b Q.y + c
	//TODO: implement poly_mul_constant?
	mapbase(e0, a);
	element_mul(e0, e0, Q->x);
	mapbase(e1, b);
	element_mul(e1, e1, Q->y);
	element_add(e0, e0, e1);
	mapbase(e1, c);
	element_add(e0, e0, e1);
	element_mul(e, e, e0);
    }

    //e *= line through Z and P evaluated at Q.  Uses the tangent when
    //Z == P and the vertical line when Z == -P or Z is O.
    void do_line(element_ptr e)
    {
	//a = -(B.y - A.y) / (B.x - A.x);
	//b = 1;
	//c = -(A.y + a * A.x);
	//but we'll multiply by B.x - A.x to avoid division

	element_ptr Ax = Z->x;
	element_ptr Ay = Z->y;
	element_ptr Bx = P->x;
	element_ptr By = P->y;

	//we assume B is never O
	if (point_is_inf(Z)) {
	    do_vertical(e);
	    return;
	}

	if (!element_cmp(Ax, Bx)) {
	    if (!element_cmp(Ay, By)) {
		do_tangent(e);
		return;
	    }
	    do_vertical(e);
	    return;
	}

	element_sub(b, Bx, Ax);
	element_sub(a, Ay, By);
	element_mul(t0, b, Ay);
	element_mul(c, a, Ax);
	element_add(c, c, t0);
	element_neg(c, c);

	//lift a, b, c into res->field and form a Q.x + b Q.y + c
	mapbase(e0, a);
	element_mul(e0, e0, Q->x);
	mapbase(e1, b);
	element_mul(e1, e1, Q->y);
	element_add(e0, e0, e1);
	mapbase(e1, c);
	element_add(e0, e0, e1);
	element_mul(e, e, e0);
    }

    //line coefficients and their scratch live in P's (base) field,
    //everything that touches Q lives in res->field
    element_init(a, P->curve->field);
    element_init(b, P->curve->field);
    element_init(c, P->curve->field);
    element_init(t0, P->curve->field);
    element_init(e0, res->field);
    element_init(e1, res->field);

    element_init(v, res->field);
    element_init(vd, res->field);
    point_init(Z, P->curve);

    point_set(Z, P);

    element_set1(v);
    element_set1(vd);
    //start at the bit below the most significant one of q
    m = mpz_sizeinbase(q, 2) - 2;

    while(m >= 0) {
	//doubling step
	element_mul(v, v, v);
	element_mul(vd, vd, vd);
	do_tangent(v);
	point_double(Z, Z);
	do_vertical(vd);
	//addition step when bit m of q is set
	if (mpz_tstbit(q, m)) {
	    do_line(v);
	    point_add(Z, Z, P);
	    do_vertical(vd);
	}
	m--;
    }

    //res = v / vd
    element_invert(vd, vd);
    element_mul(res, v, vd);

    element_clear(v);
    element_clear(vd);
    point_clear(Z);
    element_clear(a);
    element_clear(b);
    element_clear(c);
    element_clear(t0);
    element_clear(e0);
    element_clear(e1);
}

//Reduced Tate pairing for the common-curve setup built by
//pairing_init_cc_param(): out = cc_miller(in1, in2) ^ tateexp.
static void cc_pairing(element_ptr out, element_ptr in1, element_ptr in2,
	pairing_t pairing)
{
    mnt_pairing_data_ptr data = pairing->data;
    point_ptr P = in1->data;
    point_ptr Q = in2->data;

    cc_miller(out, pairing->r, P, Q, data->mapbase);
    //final exponentiation
    element_pow(out, out, data->tateexp);
}

//Test whether (a, b, c, d) is "almost" a co-DDH tuple, i.e. whether
//e(a, d) == e(b, c) or e(a, d) == e(b, c)^-1, where e is this pairing's
//map.  Returns 1 if either relation holds, 0 otherwise.
//
//The pairing is evaluated through pairing->map rather than by calling
//cc_miller() directly: this function is also installed by
//pairing_init_solinas_param(), whose pairing->data is a
//solinas_pairing_data_t and must not be read as mnt_pairing_data.
//For the cc pairing, pairing->map (cc_pairing) performs exactly the same
//Miller loop plus final exponentiation as before.
static int cc_is_almost_coddh(element_ptr a, element_ptr b,
	element_ptr c, element_ptr d,
	pairing_t pairing)
{
    int res = 0;
    element_t t0, t1, t2;

    element_init(t0, pairing->GT);
    element_init(t1, pairing->GT);
    element_init(t2, pairing->GT);
    pairing->map(t0, a, d, pairing);
    pairing->map(t1, b, c, pairing);
    element_mul(t2, t0, t1);
    if (element_is1(t2)) {
	//g, g^x, h, h^-x case
	res = 1;
    } else {
	element_invert(t1, t1);
	element_mul(t2, t0, t1);
	if (element_is1(t2)) {
	    //g, g^x, h, h^x case
	    res = 1;
	}
    }
    element_clear(t0);
    element_clear(t1);
    element_clear(t2);
    return res;
}

//The pairing's phi map for the cc setup: out = sum of in and its k-1
//successive Frobenius images (cc_frobenius with the base field order),
//where k is the embedding degree stored in the pairing data.
static void trace(element_ptr out, element_ptr in, pairing_ptr pairing)
{
    mnt_pairing_data_ptr data = pairing->data;
    point_ptr src = in->data;
    point_ptr sum = out->data;
    point_t conj;
    int remaining;

    point_init(conj, src->curve);
    point_set(conj, src);
    point_set(sum, src);

    for (remaining = data->k - 1; remaining > 0; remaining--) {
	cc_frobenius(conj, conj, data->Fq->order);
	point_add(sum, sum, conj);
    }

    point_clear(conj);
}

//Initialise pairing from common-curve parameters:
//  - Zr = Z_r for the group order r,
//  - Eq = curve with coefficients (param->a, param->b) over F_q,
//  - Fqk = F_q[x] / (random monic irreducible polynomial of degree k),
//  - Eqk = Eq carried over to Fqk,
//  - G1 = Eq with cofactor param->h, G2 = Eqk with cofactor 1, GT = Fqk,
//  - tateexp = (|Fqk| - 1) / r for the final exponentiation.
//NOTE(review): the mallocs below are not checked for failure.
void pairing_init_cc_param(pairing_t pairing, cc_param_t param)
{
    mnt_pairing_data_ptr data;
    element_t curve_a, curve_b;
    element_t irred;
    mpz_t one;

    mpz_init(pairing->r);
    mpz_set(pairing->r, param->r);
    field_init_fp(pairing->Zr, pairing->r);
    pairing->map = cc_pairing;
    pairing->is_almost_coddh = cc_is_almost_coddh;

    data = malloc(sizeof(mnt_pairing_data_t));
    pairing->data = data;
    data->k = param->k;

    //base curve over F_q
    field_init_fp(data->Fq, param->q);
    element_init(curve_a, data->Fq);
    element_init(curve_b, data->Fq);
    element_set_mpz(curve_a, param->a);
    element_set_mpz(curve_b, param->b);
    curve_init_cc_ab(data->Eq, curve_a, curve_b);

    //extension field F_q^k: retry random monic polynomials until one is
    //irreducible
    field_init_poly(data->Fqx, data->Fq);
    element_init(irred, data->Fqx);
    for (;;) {
	poly_random_monic(irred, param->k);
	if (poly_is_irred(irred)) {
	    break;
	}
    }
    field_init_polymod(data->Fqk, irred);
    element_clear(irred);

    //final exponent (|F_q^k| - 1) / r
    mpz_init(data->tateexp);
    mpz_sub_ui(data->tateexp, data->Fqk->order, 1);
    mpz_divexact(data->tateexp, data->tateexp, pairing->r);

    data->mapbase = ((polymod_field_data_ptr) data->Fqk->data)->mapbase;
    cc_switch_field(data->Eqk, data->Eq, data->Fqk, data->mapbase);

    pairing->G1 = malloc(sizeof(field_t));
    pairing->G2 = malloc(sizeof(field_t));
    field_init_curve_group(pairing->G1, data->Eq, param->h);
    mpz_init_set_si(one, 1);
    field_init_curve_group(pairing->G2, data->Eqk, one);
    mpz_clear(one);
    pairing->GT = data->Fqk;
    pairing->phi = trace;

    element_clear(curve_a);
    element_clear(curve_b);
}

//Identity phi map, used when G1 and G2 are the same group
//(see pairing_init_solinas_param): out = in.
static void phi_identity(element_ptr out, element_ptr in, pairing_ptr pairing)
{
    (void) pairing; //not needed for the identity map
    element_set(out, in);
}

//Reduced Tate pairing for the setup built by pairing_init_solinas_param().
//in1, in2 are from E(F_q), out from F_q^2.
//
//The Miller loop follows the shape of the group order, assumed to be
//r = 2^exp2 + sign1 * 2^exp1 + sign0:
//  f_{2^exp1} is computed by exp1 doublings and saved (inverted when
//  sign1 < 0, since f_{-m} = 1 / (f_m v_{mP})), the doublings continue up
//  to f_{2^exp2}, and the two are joined with one line evaluation.
//
//Q is the image of in2 under the distortion map (x, y) -> (-x, iy).  Every
//vertical line evaluated at such a Q lies in F_q, and the final exponent
//(q^2 - 1) / r is a multiple of q - 1, so all vertical-line factors are
//dropped.  For the same reason the sign0 term needs no correction:
//f_{r+1} = f_r exactly (sign0 < 0), and f_r = f_{r-1} * v_P (sign0 > 0).
//TODO: use Miller algorithm that's been specialized for Solinas primes
static void solinas_pairing(element_ptr out, element_ptr in1, element_ptr in2,
	pairing_t pairing)
{
    solinas_pairing_data_ptr p = pairing->data;
    point_ptr Qbase = in2->data;
    point_t V, V1;
    element_t f, f0, f1;
    element_t a, b, c;
    element_t e0;
    int i, n;

    //f *= tangent at V evaluated at Q.
    void do_tangent(void) {
	//a = -slope_tangent(V.x, V.y);
	//b = 1;
	//c = -(V.y + aV.x);
	//but we multiply by 2*V.y to avoid division.  The curve has a4 = 1
	//(see pairing_init_solinas_param), so:
	//a = -(Vx (Vx + Vx + Vx) + 1)
	//b = 2 * Vy
	//c = -(2 Vy^2 + a Vx);
	element_ptr Vx = V->x;
	element_ptr Vy = V->y;
	element_add(a, Vx, Vx);
	element_add(a, a, Vx);
	element_mul(a, a, Vx);
	//b temporarily holds the constant a4 = 1; it is overwritten below
	element_set1(b);
	element_add(a, a, b);
	element_neg(a, a);

	element_add(b, Vy, Vy);

	element_mul(e0, b, Vy);
	element_mul(c, a, Vx);
	element_add(c, c, e0);
	element_neg(c, c);

	//we'll map Qbase via (x,y) --> (-x, iy)
	//hence a Q.x + c = -a Qbase.x + c is real while
	//(b Q.y) = b Qbase.y i is purely imaginary.
	element_mul(a, a, Qbase->x);
	element_sub(fi_re(f0), c, a);
	element_mul(fi_im(f0), b, Qbase->y);
	element_mul(f, f, f0);
    }

    //f *= line through A and B evaluated at Q (A.x != B.x assumed).
    void do_line(point_ptr A, point_ptr B) {
	//a = -(B.y - A.y) / (B.x - A.x);
	//b = 1;
	//c = -(A.y + a * A.x);
	//but we'll multiply by B.x - A.x to avoid division, so
	//a = -(By - Ay)
	//b = Bx - Ax
	//c = -(Ay b + a Ax);
	element_sub(a, A->y, B->y);
	element_sub(b, B->x, A->x);
	element_mul(e0, a, A->x);
	element_mul(c, b, A->y);
	element_add(c, c, e0);
	element_neg(c, c);

	//we'll map Qbase via (x,y) --> (-x, iy)
	//hence a Q.x + c = -a Qbase.x + c is real while
	//(b Q.y) = b Qbase.y i is purely imaginary.
	element_mul(a, a, Qbase->x);
	element_sub(fi_re(f0), c, a);
	element_mul(fi_im(f0), b, Qbase->y);
	element_mul(f, f, f0);
    }

    point_init(V, p->Eq);
    point_init(V1, p->Eq);
    point_set(V, in1->data);
    element_init(f, p->Fq2);
    element_init(f0, p->Fq2);
    element_init(f1, p->Fq2);
    element_set1(f);
    element_init(a, p->Fq);
    element_init(b, p->Fq);
    element_init(c, p->Fq);
    element_init(e0, p->Fq);

    //f = f_{2^exp1}, V = 2^exp1 P
    n = p->exp1;
    for (i=0; i<n; i++) {
	//f = f^2 g_V,V(Q)
	//where g_V,V = tangent at V
	//TODO: implement element_square?
	element_mul(f, f, f);
	do_tangent();
	point_double(V, V);
    }

    //save V1 = sign1 * 2^exp1 P and f1 = f_{sign1 * 2^exp1}
    if (p->sign1 < 0) {
	//f_{-m} = 1 / (f_m v_{mP}); v_{mP}(Q) is in F_q and is removed by
	//the final exponentiation
	point_neg(V1, V);
	element_invert(f1, f);
    } else {
	point_set(V1, V);
	element_set(f1, f);
    }

    //continue doubling: f = f_{2^exp2}, V = 2^exp2 P
    n = p->exp2;
    for (; i<n; i++) {
	element_mul(f, f, f);
	do_tangent();
	point_double(V, V);
    }

    //f = f_{2^exp2} f_{sign1 2^exp1} l_{V,V1}(Q) = f_{2^exp2 + sign1 2^exp1}
    element_mul(f, f, f1);
    do_line(V, V1);

    //No sign0 correction: see the comment at the top of this function.

    //Tate exponentiation
    //simpler but slower:
    //element_pow(out, f, p->tateexp);
    //use this trick instead: f^(q-1) = conj(f) / f, then raise to h
    element_invert(f0, f);
    element_neg(fi_im(f), fi_im(f));
    element_mul(f, f, f0);
    element_pow(out, f, p->h);

    element_clear(f);
    element_clear(f0);
    element_clear(f1);
    point_clear(V);
    point_clear(V1);
    element_clear(a);
    element_clear(b);
    element_clear(c);
    element_clear(e0);
}

//Initialise pairing from Solinas-style parameters:
//  - the exponent/sign fields that shape the Miller loop in solinas_pairing,
//  - Zr = Z_r for the group order r,
//  - Eq = curve with coefficients a = 1, b = 0 over F_q,
//  - Fq2 = F_q(i) and Eq2 = Eq carried over to Fq2,
//  - G1 = G2 = Eq with cofactor param->h, GT = Fq2, phi = identity,
//  - tateexp = (|Fq2| - 1) / r and h = param->h for the final exponentiation.
//NOTE(review): the mallocs below are not checked for failure.
void pairing_init_solinas_param(pairing_t pairing, solinas_param_t param)
{
    solinas_pairing_data_ptr data;
    element_t curve_a, curve_b;

    data = malloc(sizeof(solinas_pairing_data_t));
    pairing->data = data;
    data->exp2 = param->exp2;
    data->exp1 = param->exp1;
    data->sign1 = param->sign1;
    data->sign0 = param->sign0;

    mpz_init(pairing->r);
    mpz_set(pairing->r, param->r);
    field_init_fp(pairing->Zr, pairing->r);
    pairing->map = solinas_pairing;
    pairing->is_almost_coddh = cc_is_almost_coddh;

    //base curve over F_q with a = 1, b = 0
    field_init_fp(data->Fq, param->q);
    element_init(curve_a, data->Fq);
    element_init(curve_b, data->Fq);
    element_set1(curve_a);
    element_set0(curve_b);
    curve_init_cc_ab(data->Eq, curve_a, curve_b);
    element_clear(curve_a);
    element_clear(curve_b);

    //quadratic extension and the curve over it
    field_init_fi(data->Fq2, data->Fq);
    cc_switch_field(data->Eq2, data->Eq, data->Fq2, element_field_to_fieldi);

    //final exponent (|F_q^2| - 1) / r
    mpz_init(data->tateexp);
    mpz_sub_ui(data->tateexp, data->Fq2->order, 1);
    mpz_divexact(data->tateexp, data->tateexp, pairing->r);

    mpz_init(data->h);
    mpz_set(data->h, param->h);

    //symmetric setup: G2 shares G1's group
    pairing->G1 = malloc(sizeof(field_t));
    field_init_curve_group(pairing->G1, data->Eq, param->h);
    pairing->G2 = pairing->G1;
    pairing->phi = phi_identity;
    pairing->GT = data->Fq2;
}
