#include "curve.h"

//I define "common" curves (abbr. "cc") to be those
//of the form y^2 = x^3 + ax + b defined over a field with
//characteristic > 3

//Per-curve data for a "common" curve y^2 = x^3 + ax + b:
//just the two coefficients.
struct common_curve_s {
    element_t a;	//coefficient of x
    element_t b;	//constant term
};
//one-element array type, so a variable of this type decays to a pointer
//when passed to a function
typedef struct common_curve_s common_curve_t[1];
typedef struct common_curve_s *common_curve_ptr;

//Sets p to a random finite point on its curve.
//Keeps drawing random x-coordinates until x^3 + ax + b is a square in
//the curve's field, then sets y to the root produced by element_sqrt().
//The point at infinity is never produced.
static void cc_random(point_ptr p)
{
    element_t t;
    common_curve_ptr cc = p->curve->data;

    element_init(t, p->curve->field);
    p->inf_flag = 0;
    do {
	element_random(p->x);
	//t = (x^2 + a)x + b = x^3 + ax + b
	element_mul(t, p->x, p->x);
	element_add(t, t, cc->a);
	element_mul(t, t, p->x);
	element_add(t, t, cc->b);
    } while (!element_is_sqr(t));
    element_sqrt(p->y, t);

    element_clear(t);
}

//Sets r = p + q on the curve y^2 = x^3 + ax + b (affine coordinates).
//r may be the same point as p and/or q: the new coordinates are computed
//in temporaries and stored only at the very end.
static void cc_add(point_ptr r, point_ptr p, point_ptr q)
{
    element_t lambda, e0, e1;
    int same_x;

    //O + q = q, p + O = p
    if (point_is_inf(p)) {
	point_set(r, q);
	return;
    }
    if (point_is_inf(q)) {
	point_set(r, p);
	return;
    }

    same_x = !element_cmp(p->x, q->x);
    if (same_x && (element_cmp(p->y, q->y) || element_is0(p->y))) {
	//points are inverses of each other
	//(this includes doubling a point with y = 0)
	point_set_inf(r);
	return;
    }

    element_init(lambda, p->curve->field);
    element_init(e0, p->curve->field);
    element_init(e1, p->curve->field);

    if (same_x) {
	//same point: double it
	common_curve_ptr cc = p->curve->data;

	//lambda = (3x^2 + a) / 2y
	element_set_si(e0, 3);
	element_mul(lambda, p->x, p->x);
	element_mul(lambda, lambda, e0);
	element_add(lambda, lambda, cc->a);
	element_add(e0, p->y, p->y);
	element_invert(e0, e0);
	element_mul(lambda, lambda, e0);
    } else {
	//lambda = (y2-y1)/(x2-x1)
	element_sub(e0, q->x, p->x);
	element_invert(e0, e0);
	element_sub(lambda, q->y, p->y);
	element_mul(lambda, lambda, e0);
    }

    //x3 = lambda^2 - x1 - x2  (x2 = x1 when doubling)
    element_mul(e0, lambda, lambda);
    element_sub(e0, e0, p->x);
    element_sub(e0, e0, q->x);
    //y3 = (x1-x3)lambda - y1
    element_sub(e1, p->x, e0);
    element_mul(e1, e1, lambda);
    element_sub(e1, e1, p->y);

    element_set(r->x, e0);
    element_set(r->y, e1);
    //the result is finite: a stale infinity flag on r must not survive
    r->inf_flag = 0;

    element_clear(lambda);
    element_clear(e0);
    element_clear(e1);
}

//Sets r = n * p by left-to-right double-and-add over the bits of |n|.
//n = 0 gives the point at infinity; a negative n gives -(|n| * p).
//r may be the same point as p: the result is accumulated in a temporary.
static void cc_mul(point_ptr r, mpz_ptr n, point_ptr p)
{
    long s;
    mpz_t k;
    point_t result;

    point_init(result, r->curve);
    point_set_inf(result);

    //mpz_sizeinbase() ignores the sign, but mpz_tstbit() reads a negative
    //value as two's complement, so the bits scanned must be those of |n|
    mpz_init(k);
    mpz_abs(k, n);

    if (mpz_sgn(k)) {
	for (s = (long) mpz_sizeinbase(k, 2) - 1; s >= 0; s--) {
	    point_add(result, result, result);
	    if (mpz_tstbit(k, s)) {
		point_add(result, result, p);
	    }
	}
    }

    if (mpz_sgn(n) < 0 && !point_is_inf(result)) {
	//-(x, y) = (x, -y) on y^2 = x^3 + ax + b
	element_neg(result->y, result->y);
    }

    point_set(r, result);
    point_clear(result);
    mpz_clear(k);
}

//Initializes c as a common curve built from the value j (presumably the
//desired j-invariant), over the field that j belongs to:
//  a = 3 j / (1728 - j),  b = 2 j / (1728 - j)
//Both coefficients are also printed to stdout.
//assumes j != 0, 1728
void curve_init_cc_j(curve_ptr c, element_ptr j)
{
    common_curve_ptr coeffs;

    c->field = j->field;
    c->random = cc_random;
    c->add = cc_add;
    c->mul = cc_mul;

    coeffs = malloc(sizeof(common_curve_t));
    c->data = coeffs;
    element_init(coeffs->a, c->field);
    element_init(coeffs->b, c->field);

    //a = j / (1728 - j) for the moment
    element_set_si(coeffs->a, 1728);
    element_sub(coeffs->a, coeffs->a, j);
    element_invert(coeffs->a, coeffs->a);
    element_mul(coeffs->a, coeffs->a, j);

    //b = 2 j / (1728 - j), then a = 3 j / (1728 - j)
    element_add(coeffs->b, coeffs->a, coeffs->a);
    element_add(coeffs->a, coeffs->a, coeffs->b);

    printf("a = ");
    element_out_str(stdout, coeffs->a);
    printf(", b = ");
    element_out_str(stdout, coeffs->b);
    printf("\n");
}

//TODO: curve_clear_cc_j

//Initializes c as the common curve y^2 = x^3 + ax + b over the field
//that a belongs to. The coefficients are copied into freshly allocated
//curve data, so a and b stay owned by the caller.
void curve_init_cc_ab(curve_ptr c, element_ptr a, element_ptr b)
{
    common_curve_ptr coeffs;

    //field and operation table
    c->field = a->field;
    c->random = cc_random;
    c->add = cc_add;
    c->mul = cc_mul;

    //private copies of the coefficients
    coeffs = malloc(sizeof(common_curve_t));
    c->data = coeffs;
    element_init(coeffs->a, c->field);
    element_set(coeffs->a, a);
    element_init(coeffs->b, c->field);
    element_set(coeffs->b, b);
}

//Rewrites the coefficients of c in place:
//  a <- a * nqr^2,  b <- b * nqr^3
//where nqr is the element stored in c's field (presumably a quadratic
//nonresidue, which would make the result a twist of the original curve).
void twist_curve(curve_ptr c)
{
    common_curve_ptr coeffs = c->data;
    element_ptr nqr = c->field->nqr;
    int i;

    for (i = 0; i < 2; i++) {
	element_mul(coeffs->a, coeffs->a, nqr);
    }
    for (i = 0; i < 3; i++) {
	element_mul(coeffs->b, coeffs->b, nqr);
    }
}

//TODO: curve_clear_cc_ab

//Initializes cnew as the image of curve c under the field map:
//cnew lives over map->dstfield, shares c's operation table, and gets
//newly allocated coefficients obtained by mapping c's a and b.
void cc_switch_field(curve_ptr cnew, curve_ptr c, fieldmap_ptr map)
{
    common_curve_ptr dst, src;

    cnew->field = map->dstfield;
    cnew->random = c->random;
    cnew->add = c->add;
    cnew->mul = c->mul;

    dst = malloc(sizeof(common_curve_t));
    cnew->data = dst;
    src = c->data;

    element_init(dst->a, cnew->field);
    element_init(dst->b, cnew->field);
    map->map(dst->a, src->a);
    map->map(dst->b, src->b);
}

//Initializes cnew as curve c moved into fext, using the base-field map
//stored in fext's polymod data.
//assumes fext is a field extension of the field c is defined over
void cc_init_extend(curve_ptr cnew, curve_ptr c, field_ptr fext)
{
    polymod_field_data_ptr ext = (polymod_field_data_ptr) fext->data;

    cc_switch_field(cnew, c, ext->mapbase);
}

//compute trace of Frobenius at q^n given trace at q
//see p.105 of Blake, Seroussi and Smart
//
//Uses the recurrence
//  c_0 = 2, c_1 = trace, c_i = trace * c_{i-1} - q * c_{i-2}
//and stores c_n in res. res may be the same variable as q or trace,
//since it is only written once everything else has been read.
//n is expected to be >= 0; a negative n is not supported and gives the
//same result as n = 1.
void compute_trace_n(mpz_t res, mpz_t q, mpz_t trace, int n)
{
    int i;
    mpz_t c0, c1, c2;
    mpz_t t0;

    if (n == 0) {
	//c_0 = 2; the loop below can only produce c_1 and later terms
	mpz_set_ui(res, 2);
	return;
    }

    mpz_init(c0);
    mpz_init(c1);
    mpz_init(c2);
    mpz_init(t0);
    mpz_set_ui(c2, 2);
    mpz_set(c1, trace);
    for (i=2; i<=n; i++) {
	//c0 = trace * c_{i-1} - q * c_{i-2}
	mpz_mul(c0, trace, c1);
	mpz_mul(t0, q, c2);
	mpz_sub(c0, c0, t0);
	//shift the window: (c2, c1) <- (c1, c0)
	mpz_set(c2, c1);
	mpz_set(c1, c0);
    }
    mpz_set(res, c1);
    mpz_clear(t0);
    mpz_clear(c2);
    mpz_clear(c1);
    mpz_clear(c0);
}

//Sets R to the image of P under map: both coordinates are mapped
//individually, and the point at infinity maps to the point at infinity.
void point_map(point_t R, fieldmap_t map, point_t P)
{
    if (!point_is_inf(P)) {
	map->map(R->x, P->x);
	map->map(R->y, P->y);
	R->inf_flag = 0;
    } else {
	point_set_inf(R);
    }
}

//Sets R to P carried into f, using the base-field map stored in f's
//polymod data.
void point_extend(point_t R, point_t P, field_t f)
{
    polymod_field_data_ptr ext = (polymod_field_data_ptr) f->data;

    point_map(R, ext->mapbase, P);
}

//These slow_miller_*() functions are more general
//e.g. can use to compute Weil pairing not just Tate pairing.
//TODO: test these
//Multiplies v by the vertical line through Z evaluated at Q,
//i.e. v *= (Q.x - Z.x). Does nothing when Z is the point at infinity.
//vd is not used; it is accepted so that all slow_miller_*() helpers
//share the same leading arguments.
void slow_miller_vertical(element_t v, element_t vd, point_t Q, point_t Z)
{
    if (!point_is_inf(Z)) {
	element_t dx;

	element_init(dx, v->field);
	element_sub(dx, Q->x, Z->x);
	element_mul(v, v, dx);
	element_clear(dx);
    }
}

//Multiplies v by the tangent line at Z evaluated at Q, and vd by the
//factor the line was scaled with.
//
//The tangent  la*X + lb*Y + lc  is scaled by 2*Z.y to avoid a division:
//  la = -(3*Z.x^2 + a4)      (common curves: a2 = 0)
//  lb = 2*Z.y
//  lc = -(la*Z.x + lb*Z.y)
//Z at infinity: nothing happens. Z of order 2 (Z.y = 0): the tangent is
//the vertical line through Z.
void slow_miller_tangent(element_t v, element_t vd, point_t Q, point_t Z)
{
    element_t la, lb, lc;
    element_t tmp;
    common_curve_ptr curve = Z->curve->data;

    if (point_is_inf(Z)) {
	return;
    }
    if (element_is0(Z->y)) {
	//order 2
	slow_miller_vertical(v, vd, Q, Z);
	return;
    }

    element_init(la, v->field);
    element_init(lb, v->field);
    element_init(lc, v->field);
    element_init(tmp, v->field);

    //la = -(3*Z.x^2 + a4)
    element_add(la, Z->x, Z->x);
    element_add(la, la, Z->x);
    element_mul(la, la, Z->x);
    element_add(la, la, curve->a);
    element_neg(la, la);

    //lb = 2*Z.y
    element_add(lb, Z->y, Z->y);

    //lc = -(la*Z.x + lb*Z.y)
    element_mul(tmp, lb, Z->y);
    element_mul(lc, la, Z->x);
    element_add(lc, lc, tmp);
    element_neg(lc, lc);

    //the denominator takes the scale factor lb;
    //this must happen before lb is overwritten below
    element_mul(vd, vd, lb);

    //v *= la*Q.x + lb*Q.y + lc
    element_mul(la, la, Q->x);
    element_mul(lb, lb, Q->y);
    element_add(lc, lc, la);
    element_add(lc, lc, lb);
    element_mul(v, v, lc);

    element_clear(la);
    element_clear(lb);
    element_clear(lc);
    element_clear(tmp);
}

//Multiplies v by the line through A and B evaluated at Q, and vd by the
//factor the line was scaled with.
//
//The line  la*X + lb*Y + lc  is scaled by (B.x - A.x) to avoid a division:
//  la = A.y - B.y
//  lb = B.x - A.x
//  lc = -(la*A.x + lb*A.y)
//A = B is handed to slow_miller_tangent(); A at infinity, or A and B
//sharing x but not y, is handed to slow_miller_vertical() with A.
//we assume B is never O
void slow_miller_line(element_t v, element_t vd, point_t Q, point_t A, point_t B)
{
    element_t la, lb, lc;
    element_t tmp;

    if (point_is_inf(A)) {
	slow_miller_vertical(v, vd, Q, A);
	return;
    }

    if (!element_cmp(A->x, B->x)) {
	if (element_cmp(A->y, B->y)) {
	    slow_miller_vertical(v, vd, Q, A);
	} else {
	    slow_miller_tangent(v, vd, Q, A);
	}
	return;
    }

    element_init(la, v->field);
    element_init(lb, v->field);
    element_init(lc, v->field);
    element_init(tmp, v->field);

    //lb = B.x - A.x, la = A.y - B.y
    element_sub(lb, B->x, A->x);
    element_sub(la, A->y, B->y);

    //lc = -(la*A.x + lb*A.y)
    element_mul(tmp, lb, A->y);
    element_mul(lc, la, A->x);
    element_add(lc, lc, tmp);
    element_neg(lc, lc);

    //the denominator takes the scale factor lb;
    //this must happen before lb is overwritten below
    element_mul(vd, vd, lb);

    //v *= la*Q.x + lb*Q.y + lc
    element_mul(la, la, Q->x);
    element_mul(lb, lb, Q->y);
    element_add(lc, lc, la);
    element_add(lc, lc, lb);
    element_mul(v, v, lc);

    element_clear(la);
    element_clear(lb);
    element_clear(lc);
    element_clear(tmp);
}

//Miller loop over the bits of q below its top bit, from high to low.
//A numerator v and a denominator vd are accumulated separately and
//divided only once at the end (collated divisions): res = v / vd.
//Z starts at P and is doubled every step, with P added on set bits.
void slow_miller(element_t res, mpz_t q, point_t P, point_t Q)
{
    int bit;
    element_t v, vd;
    point_t Z;

    element_init(v, res->field);
    element_init(vd, res->field);
    point_init(Z, P->curve);

    point_set(Z, P);
    element_set_si(v, 1);
    element_set_si(vd, 1);

    for (bit = mpz_sizeinbase(q, 2) - 2; bit >= 0; bit--) {
	//doubling step
	element_mul(v, v, v);
	element_mul(vd, vd, vd);
	slow_miller_tangent(v, vd, Q, Z);
	point_add(Z, Z, Z);
	//the vertical goes into the denominator, hence vd first
	slow_miller_vertical(vd, v, Q, Z);

	if (!mpz_tstbit(q, bit)) {
	    continue;
	}

	//addition step
	slow_miller_line(v, vd, Q, Z, P);
	point_add(Z, Z, P);
	slow_miller_vertical(vd, v, Q, Z);
    }

    element_invert(vd, vd);
    element_mul(res, v, vd);

    element_clear(v);
    element_clear(vd);
    point_clear(Z);
}

//Multiplies v by the vertical line through Z evaluated at Q,
//i.e. v *= (Q.x - Z.x). Does nothing when Z is the point at infinity.
void cc_miller_vertical(element_t v, point_t Q, point_t Z)
{
    if (!point_is_inf(Z)) {
	element_t dx;

	element_init(dx, v->field);
	element_sub(dx, Q->x, Z->x);
	element_mul(v, v, dx);
	element_clear(dx);
    }
}

//Multiplies v by the tangent line at Z evaluated at Q.
//
//The tangent  la*X + lb*Y + lc  is scaled by 2*Z.y to avoid a division:
//  la = -(3*Z.x^2 + a4)      (common curves: a2 = 0)
//  lb = 2*Z.y
//  lc = -(la*Z.x + lb*Z.y)
//Unlike slow_miller_tangent(), no denominator is updated here.
//Z at infinity: nothing happens. Z of order 2 (Z.y = 0): the tangent is
//the vertical line through Z.
void cc_miller_tangent(element_t v, point_t Q, point_t Z)
{
    element_t la, lb, lc;
    element_t tmp;
    common_curve_ptr curve = Z->curve->data;

    if (point_is_inf(Z)) {
	return;
    }
    if (element_is0(Z->y)) {
	//order 2
	cc_miller_vertical(v, Q, Z);
	return;
    }

    element_init(la, v->field);
    element_init(lb, v->field);
    element_init(lc, v->field);
    element_init(tmp, v->field);

    //la = -(3*Z.x^2 + a4)
    element_add(la, Z->x, Z->x);
    element_add(la, la, Z->x);
    element_mul(la, la, Z->x);
    element_add(la, la, curve->a);
    element_neg(la, la);

    //lb = 2*Z.y
    element_add(lb, Z->y, Z->y);

    //lc = -(la*Z.x + lb*Z.y)
    element_mul(tmp, lb, Z->y);
    element_mul(lc, la, Z->x);
    element_add(lc, lc, tmp);
    element_neg(lc, lc);

    //v *= la*Q.x + lb*Q.y + lc
    element_mul(la, la, Q->x);
    element_mul(lb, lb, Q->y);
    element_add(lc, lc, la);
    element_add(lc, lc, lb);
    element_mul(v, v, lc);

    element_clear(la);
    element_clear(lb);
    element_clear(lc);
    element_clear(tmp);
}

//Multiplies v by the line through A and B evaluated at Q.
//
//The line  la*X + lb*Y + lc  is scaled by (B.x - A.x) to avoid a division:
//  la = A.y - B.y
//  lb = B.x - A.x
//  lc = -(la*A.x + lb*A.y)
//A = B is handed to cc_miller_tangent(); A at infinity, or A and B
//sharing x but not y, is handed to cc_miller_vertical() with A.
//we assume B is never O
void cc_miller_line(element_t v, point_t Q, point_t A, point_t B)
{
    element_t la, lb, lc;
    element_t tmp;

    if (point_is_inf(A)) {
	cc_miller_vertical(v, Q, A);
	return;
    }

    if (!element_cmp(A->x, B->x)) {
	if (element_cmp(A->y, B->y)) {
	    cc_miller_vertical(v, Q, A);
	} else {
	    cc_miller_tangent(v, Q, A);
	}
	return;
    }

    element_init(la, v->field);
    element_init(lb, v->field);
    element_init(lc, v->field);
    element_init(tmp, v->field);

    //lb = B.x - A.x, la = A.y - B.y
    element_sub(lb, B->x, A->x);
    element_sub(la, A->y, B->y);

    //lc = -(la*A.x + lb*A.y)
    element_mul(tmp, lb, A->y);
    element_mul(lc, la, A->x);
    element_add(lc, lc, tmp);
    element_neg(lc, lc);

    //v *= la*Q.x + lb*Q.y + lc
    element_mul(la, la, Q->x);
    element_mul(lb, lb, Q->y);
    element_add(lc, lc, la);
    element_add(lc, lc, lb);
    element_mul(v, v, lc);

    element_clear(la);
    element_clear(lb);
    element_clear(lc);
    element_clear(tmp);
}

//Miller loop for the Tate pairing computation.
//assumes Pbase is in the base field and Q in some field extension
//
//Pbase is first carried onto Q's curve; the loop then walks the bits of
//q below its top bit, from high to low. A numerator v and a denominator
//vd are accumulated separately and divided only once at the end
//(collated divisions): res = v / vd.
//TODO: optimize using the fact that P lies in the base field
//(we don't need Q until the last minute in miller_*()'s)
void cc_miller(element_t res, mpz_t q, point_t Pbase, point_t Q)
{
    int bit;
    element_t v, vd;
    point_t P;
    point_t Z;

    element_init(v, res->field);
    element_init(vd, res->field);
    point_init(Z, Q->curve);
    point_init(P, Q->curve);

    //work with a copy of Pbase that lives over Q's field
    point_extend(P, Pbase, (Q->curve->field));
    point_set(Z, P);

    element_set_si(v, 1);
    element_set_si(vd, 1);

    for (bit = mpz_sizeinbase(q, 2) - 2; bit >= 0; bit--) {
	//doubling step
	element_mul(v, v, v);
	element_mul(vd, vd, vd);
	cc_miller_tangent(v, Q, Z);
	point_add(Z, Z, Z);
	cc_miller_vertical(vd, Q, Z);

	if (!mpz_tstbit(q, bit)) {
	    continue;
	}

	//addition step
	cc_miller_line(v, Q, Z, P);
	point_add(Z, Z, P);
	cc_miller_vertical(vd, Q, Z);
    }

    element_invert(vd, vd);
    element_mul(res, v, vd);

    element_clear(v);
    element_clear(vd);
    point_clear(Z);
    point_clear(P);
}
