#include "field.h"

/* Allocate the GMP integer that backs element e and initialise it. */
static void zp_init(element_ptr e)
{
    mpz_ptr value = malloc(sizeof(mpz_t));

    mpz_init(value);
    e->data = value;
}

/* Release the GMP integer behind e, then the heap block that held it. */
static void zp_clear(element_ptr e)
{
    mpz_ptr value = e->data;

    mpz_clear(value);
    free(value);
}

/* Set e to the signed machine integer op, reduced modulo the field order. */
static void zp_set_si(element_ptr e, signed long int op)
{
    mpz_ptr value = e->data;

    mpz_set_si(value, op);
    mpz_mod(value, value, e->field->order);
}

/* Set e to the integer z, reduced modulo the field order. */
static void zp_set_mpz(element_ptr e, mpz_ptr z)
{
    /* Reduce straight from z into the element; same result as copy-then-reduce. */
    mpz_mod(e->data, z, e->field->order);
}

/* Set e to the additive identity. */
static void zp_set0(element_ptr e)
{
    mpz_ptr value = e->data;

    mpz_set_si(value, 0);
}

/* Set e to the multiplicative identity. */
static void zp_set1(element_ptr e)
{
    mpz_ptr value = e->data;

    mpz_set_si(value, 1);
}

/*
 * Print the integer behind e to stream in the given base.
 * The value returned by mpz_out_str is not inspected.
 */
static void zp_out_str(FILE *stream, int base, element_ptr e)
{
    mpz_ptr value = e->data;

    mpz_out_str(stream, base, value);
}

/* n = (a + b) mod order, using the order of n's field. */
static void zp_add(element_ptr n, element_ptr a, element_ptr b)
{
    mpz_ptr sum = n->data;

    mpz_add(sum, a->data, b->data);
    mpz_mod(sum, sum, n->field->order);
}

/*
 * n = (a - b) mod order, using the order of n's field.
 * The raw difference may be negative; the mpz_mod step brings it back
 * into the non-negative range.
 */
static void zp_sub(element_ptr n, element_ptr a, element_ptr b)
{
    mpz_ptr diff = n->data;

    mpz_sub(diff, a->data, b->data);
    mpz_mod(diff, diff, n->field->order);
}

/* n = (a * b) mod order, using the order of n's field. */
static void zp_mul(element_ptr n, element_ptr a, element_ptr b)
{
    mpz_ptr product = n->data;

    mpz_mul(product, a->data, b->data);
    mpz_mod(product, product, n->field->order);
}

/* Copy the value of a into n; no reduction is performed. */
static void zp_set(element_ptr n, element_ptr a)
{
    mpz_ptr dst = n->data;

    mpz_set(dst, a->data);
}

/*
 * Set n to the additive inverse of a modulo the field order.
 *
 * For a non-zero a the result is order - a.  Zero is handled separately:
 * order - 0 would be the unreduced value `order`, so n is set to 0
 * explicitly.  (Previously n was left untouched in that case, which kept
 * a stale value whenever n and a were different elements.)
 */
static void zp_neg(element_ptr n, element_ptr a)
{
    if (mpz_sgn((mpz_ptr) a->data)) {
        mpz_sub(n->data, n->field->order, a->data);
    } else {
        mpz_set_ui(n->data, 0);
    }
}

/*
 * Set n to the multiplicative inverse of a modulo the order of n's field.
 * The status returned by mpz_invert is not checked here.
 */
static void zp_invert(element_ptr n, element_ptr a)
{
    mpz_ptr inverse = n->data;

    mpz_invert(inverse, a->data, n->field->order);
}

/*
 * Fill n with a random value by delegating to pbc_mpz_random, passing the
 * field order as the limit argument.
 * NOTE(review): presumably yields a value in [0, order) -- confirm against
 * pbc_mpz_random.
 */
static void zp_random(element_ptr n)
{
    mpz_ptr value = n->data;

    pbc_mpz_random(value, n->field->order);
}

/*
 * Map len bytes of hash output at data to a field element: the bytes are
 * read as one big-endian unsigned integer and reduced modulo the field
 * order.
 *
 * TODO: something more sophisticated!
 */
static void zp_from_hash(element_ptr n, int len, void *data)
{
    mpz_t digest;

    mpz_init(digest);
    /* len words of 1 byte each, most significant word first, no nail bits. */
    mpz_import(digest, len, 1, 1, 0, 0, data);
    /* Same effect as zp_set_mpz(n, digest): store the value reduced mod order. */
    mpz_mod(n->data, digest, n->field->order);
    mpz_clear(digest);
}

/* Return 1 when n holds the integer 1, and 0 otherwise. */
static int zp_is1(element_ptr n)
{
    mpz_ptr value = (mpz_ptr) n->data;

    return mpz_cmp_ui(value, 1) == 0;
}

/* Return 1 when n holds the integer 0, and 0 otherwise. */
static int zp_is0(element_ptr n)
{
    mpz_ptr value = (mpz_ptr) n->data;

    return mpz_cmp_ui(value, 0) == 0;
}

/*
 * Compare the stored integers of a and b.
 * Returns the mpz_cmp result: positive, zero or negative as a is greater
 * than, equal to or less than b.
 */
static int zp_cmp(element_ptr a, element_ptr b)
{
    mpz_ptr lhs = (mpz_ptr) a->data;
    mpz_ptr rhs = (mpz_ptr) b->data;

    return mpz_cmp(lhs, rhs);
}

/*
 * Return 1 when a is a square in the field, 0 otherwise.
 *
 * Zero is handled up front: 0 = 0^2 is a square, but mpz_legendre yields
 * 0 for it, which the `== 1` test below would report as "not a square".
 * That also made a zero element look like a valid quadratic non-residue
 * to field_get_nqr.
 */
static int zp_is_sqr(element_ptr a)
{
    if (!mpz_sgn((mpz_ptr) a->data)) {
        return 1;
    }
    return mpz_legendre(a->data, a->field->order) == 1;
}

/*
 * Square-root routine; field_init_fp installs it as the field's sqrt
 * operation.  Tonelli-style: it works from a fixed quadratic non-residue
 * obtained through field_get_nqr.
 *
 * x receives the result, a is the element whose root is wanted.  a is last
 * read before x is first written, and element_pow accumulates in a private
 * temporary, so x and a may refer to the same element.
 *
 * NOTE(review): nothing here checks that a is a quadratic residue; the
 * output is presumably only a square root when it is -- confirm that
 * callers test element_is_sqr first.
 *
 * Outline, with q the field order, q - 1 = 2^s * t (t odd) and g the
 * non-residue:
 *   1. build an exponent e one bit at a time (bits 1 .. s-1), each bit
 *      decided by whether (a * g^-e)^((q-1)/2^i) equals 1;
 *   2. output g^(e/2) * (a * g^-e)^((t+1)/2).
 */
static void fp_tonelli(element_ptr x, element_ptr a)
{
    int s;
    int i;
    mpz_t e;
    mpz_t t, t0;
    element_t ginv, e0;
    element_ptr nqr;

    mpz_init(t);
    mpz_init(e);
    mpz_init(t0);
    element_init(ginv, a->field);
    element_init(e0, a->field);
    /* nqr is owned and cached by the field; it is not cleared below. */
    nqr = field_get_nqr(a->field);

    /* ginv = g^-1, used to strip the g^e component from a. */
    element_invert(ginv, nqr); 

    // Let q be the order of the field.
    // Write q - 1 = 2^s * t with t odd: s is the index of the lowest set
    // bit of q - 1, and t is what remains after shifting those bits out.
    mpz_sub_ui(t, a->field->order, 1);
    s = mpz_scan1(t, 0);
    mpz_tdiv_q_2exp(t, t, s);
    mpz_set_ui(e, 0);
    /*
     * Determine bits 1 .. s-1 of e.  Round i raises a * g^-e to the power
     * (q-1)/2^i; if that is not 1, bit i-1 of e is set.
     * NOTE(review): the intended invariant appears to be that, after round
     * i, (a * g^-e)^((q-1)/2^i) == 1 -- confirm.
     */
    for (i=2; i<=s; i++) {
	mpz_sub_ui(t0, a->field->order, 1);
	mpz_tdiv_q_2exp(t0, t0, i);
	element_pow(e0, ginv, e);
	element_mul(e0, e0, a);
	element_pow(e0, e0, t0);
	if (!element_is1(e0)) mpz_setbit(e, i-1);
    }
    /* e0 = (a * g^-e)^((t+1)/2); t is odd, so t + 1 halves exactly. */
    element_pow(e0, ginv, e);
    element_mul(e0, e0, a);
    mpz_add_ui(t, t, 1);
    mpz_tdiv_q_2exp(t, t, 1);
    element_pow(e0, e0, t);
    /* x = g^(e/2) * e0; bit 0 of e is never set above, so e/2 is exact. */
    mpz_tdiv_q_2exp(e, e, 1);
    element_pow(x, nqr, e);
    /* TODO: this would be a good place to use element_pow2 ... -hs */
    element_mul(x, x, e0);
    mpz_clear(t);
    mpz_clear(e);
    mpz_clear(t0);
    element_clear(ginv);
    element_clear(e0);
}

/*
 * Field teardown hook, installed as field_clear by field_init_fp.
 * field_init_fp sets f->data to NULL for this field type, so there is
 * nothing for this hook to release.
 */
static void zp_field_clear(field_t f)
{
    (void) f;
}

/*
 * Serialise e into data as exactly fixed_length_in_bytes bytes, least
 * significant byte first.  The caller must supply a buffer of at least
 * that size.  Returns the number of bytes written.
 */
static int zp_to_bytes(unsigned char *data, element_t e)
{
    const int len = e->field->fixed_length_in_bytes;
    mpz_t rest;
    int k;

    /* Work on a copy so the element itself is left untouched. */
    mpz_init(rest);
    mpz_set(rest, e->data);

    for (k = 0; k < len; k++) {
        /* Emit the low 8 bits, then shift them out. */
        data[k] = (unsigned char) mpz_get_ui(rest);
        mpz_tdiv_q_2exp(rest, rest, 8);
    }

    mpz_clear(rest);
    return len;
}

/*
 * Load e from fixed_length_in_bytes bytes at data, least significant byte
 * first (the layout zp_to_bytes writes).  The value is stored as read; no
 * reduction modulo the field order is applied.  Returns the number of
 * bytes consumed.
 */
static int zp_from_bytes(element_t e, unsigned char *data)
{
    const int len = e->field->fixed_length_in_bytes;
    mpz_ptr acc = e->data;
    mpz_t byte;
    int k;

    mpz_init(byte);
    mpz_set_ui(acc, 0);

    /* Horner form: walk from the most significant byte down to byte 0. */
    for (k = len - 1; k >= 0; k--) {
        mpz_mul_2exp(acc, acc, 8);
        mpz_set_ui(byte, data[k]);
        mpz_add(acc, acc, byte);
    }

    mpz_clear(byte);
    return len;
}

/*
 * Set up f as the field of integers modulo prime.
 *
 * Runs the generic field_init, installs the zp_* / fp_* operations, copies
 * prime into f->order and records the serialised element size (the bit
 * length of prime rounded up to whole bytes).  This field type keeps no
 * private data, so f->data is NULL.
 */
void field_init_fp(field_ptr f, mpz_t prime)
{
    field_init(f);

    /* Field parameters. */
    mpz_set(f->order, prime);
    f->fixed_length_in_bytes = (mpz_sizeinbase(prime, 2) + 7) / 8;
    f->data = NULL;

    /* Element and field lifetime. */
    f->init = zp_init;
    f->clear = zp_clear;
    f->field_clear = zp_field_clear;

    /* Assignment. */
    f->set = zp_set;
    f->set0 = zp_set0;
    f->set1 = zp_set1;
    f->set_si = zp_set_si;
    f->set_mpz = zp_set_mpz;
    f->random = zp_random;
    f->from_hash = zp_from_hash;

    /* Arithmetic. */
    f->add = zp_add;
    f->sub = zp_sub;
    f->mul = zp_mul;
    f->neg = zp_neg;
    f->invert = zp_invert;
    f->sqrt = fp_tonelli;

    /* Predicates and comparison. */
    f->is0 = zp_is0;
    f->is1 = zp_is1;
    f->is_sqr = zp_is_sqr;
    f->cmp = zp_cmp;

    /* Input / output. */
    f->out_str = zp_out_str;
    f->to_bytes = zp_to_bytes;
    f->from_bytes = zp_from_bytes;
}

/*
 * x = a^n by left-to-right square-and-multiply.
 *
 * The product is accumulated in a temporary and copied to x at the end, so
 * x may be the same element as a.  A zero exponent gives 1.
 * NOTE(review): the bit scan looks written for non-negative n -- confirm
 * that callers never pass a negative exponent.
 */
void element_pow(element_ptr x, element_ptr a, mpz_ptr n)
{
    element_t acc;
    int bit;

    if (mpz_sgn(n) == 0) {
        element_set1(x);
        return;
    }

    element_init(acc, x->field);
    element_set1(acc);

    /* Visit the bits of n from the most significant down to bit 0. */
    bit = mpz_sizeinbase(n, 2);
    while (bit-- > 0) {
        element_mul(acc, acc, acc);
        if (mpz_tstbit(n, bit)) {
            element_mul(acc, acc, a);
        }
    }

    element_set(x, acc);
    element_clear(acc);
}

/*
 * x = a1^n1 * a2^n2 with one shared squaring chain.
 *
 * The product a1*a2 is precomputed; at each bit position the accumulator
 * is squared and then multiplied by a1, a2, a1*a2 or nothing, depending on
 * which of the two exponents has that bit set.  The result is built in a
 * temporary, so x may be the same element as a1 or a2.  If both exponents
 * are zero the result is 1.
 */
void element_pow2(element_ptr x, element_ptr a1, mpz_ptr n1,
                                 element_ptr a2, mpz_ptr n2)
{
    element_t acc, both;
    int len1, len2;
    int bit;

    if (!(mpz_sgn(n1) || mpz_sgn(n2))) {
        element_set1(x);
        return;
    }

    element_init(acc, x->field);
    element_set1(acc);

    element_init(both, x->field);
    element_mul(both, a1, a2);

    /* Start from the top bit of the longer exponent. */
    len1 = mpz_sizeinbase(n1, 2);
    len2 = mpz_sizeinbase(n2, 2);
    bit = (len1 > len2) ? len1 : len2;

    while (bit-- > 0) {
        element_mul(acc, acc, acc);
        /* Bit 0 of the selector comes from n1, bit 1 from n2. */
        switch (mpz_tstbit(n1, bit) | (mpz_tstbit(n2, bit) << 1)) {
        case 3:
            element_mul(acc, acc, both);
            break;
        case 1:
            element_mul(acc, acc, a1);
            break;
        case 2:
            element_mul(acc, acc, a2);
            break;
        default:
            break;
        }
    }

    element_set(x, acc);
    element_clear(acc);
    element_clear(both);
}

/*
 * x = a1^n1 * a2^n2 * a3^n3 with one shared squaring chain.
 *
 * An 8-entry table holds every product of a subset of {a1, a2, a3},
 * indexed by a 3-bit mask (bit 0 = a1, bit 1 = a2, bit 2 = a3).  At each
 * bit position the accumulator is squared and multiplied by the table
 * entry selected by the corresponding bits of n1, n2 and n3 (entry 0 is
 * 1).  The result is built in a temporary, so x may be the same element as
 * any of the bases.  If all three exponents are zero the result is 1.
 */
void element_pow3(element_ptr x, element_ptr a1, mpz_ptr n1,
                                 element_ptr a2, mpz_ptr n2,
                                 element_ptr a3, mpz_ptr n3)
{
    element_t acc;
    element_t table[8];
    int top, cand;
    int mask;
    int k;

    if (!(mpz_sgn(n1) || mpz_sgn(n2) || mpz_sgn(n3))) {
        element_set1(x);
        return;
    }

    element_init(acc, x->field);
    element_set1(acc);

    for (k = 0; k < 8; k++) {
        element_init(table[k], x->field);
    }

    /* Empty subset and single bases. */
    element_set1(table[0]);
    element_set(table[1], a1);
    element_set(table[2], a2);
    element_set(table[4], a3);
    /* Pairs, then the triple built from the a2*a3 entry. */
    element_mul(table[3], a1, a2);
    element_mul(table[5], a1, a3);
    element_mul(table[6], a2, a3);
    element_mul(table[7], table[6], a1);

    /* Index of the highest bit among the three exponents. */
    top = mpz_sizeinbase(n1, 2) - 1;
    cand = mpz_sizeinbase(n2, 2) - 1;
    if (cand > top) {
        top = cand;
    }
    cand = mpz_sizeinbase(n3, 2) - 1;
    if (cand > top) {
        top = cand;
    }

    while (top >= 0) {
        element_mul(acc, acc, acc);
        mask = mpz_tstbit(n1, top)
             | (mpz_tstbit(n2, top) << 1)
             | (mpz_tstbit(n3, top) << 2);
        element_mul(acc, acc, table[mask]);
        top--;
    }

    element_set(x, acc);
    element_clear(acc);
    for (k = 0; k < 8; k++) {
        element_clear(table[k]);
    }
}

/*
 * Return the field's cached quadratic non-residue, creating it on first
 * use by drawing random elements until one is not a square.
 *
 * Zero is rejected explicitly.  A Legendre-symbol based is_sqr reports 0
 * as "not a square", so a random draw of 0 used to be cached as the
 * non-residue; fp_tonelli then tries to invert it.  The field's is0 hook
 * is called directly for this check.
 *
 * The returned element is owned by the field; callers must not clear it.
 */
element_ptr field_get_nqr(field_ptr f)
{
    if (!f->nqr) {
        f->nqr = malloc(sizeof(element_t));
        element_init(f->nqr, f);
        do {
            element_random(f->nqr);
        } while (f->is0(f->nqr) || element_is_sqr(f->nqr));
    }
    return f->nqr;
}
