#include "field.h"

#include <stdio.h>
#include <stdlib.h>

/*
 * Shared GMP random state, read by zp_random and field_init_fp.
 * It is never initialized in this file. NOTE(review): presumably the caller
 * must run a gmp_randinit_* function (and seed it) before any field is set up;
 * TODO confirm.
 */
gmp_randstate_t rs;

/*
 * Element constructor: give e a freshly allocated GMP integer, initialized
 * to 0. An allocation failure aborts with a message instead of passing a NULL
 * pointer on to mpz_init.
 */
static void zp_init(element_ptr e)
{
    mpz_ptr z = malloc(sizeof *z);

    if (z == NULL) {
        fprintf(stderr, "zp_init: out of memory\n");
        abort();
    }
    mpz_init(z);
    e->data = z;
}

/*
 * Element destructor: release the GMP integer created by zp_init and the heap
 * block that holds it. e->data is reset to NULL afterwards, so the element
 * never keeps a dangling pointer to freed memory.
 */
static void zp_clear(element_ptr e)
{
    mpz_clear(e->data);
    free(e->data);
    e->data = NULL;
}

/*
 * Assign a machine integer to e, reduced modulo the field order.
 * mpz_mod yields a non-negative remainder, so negative op values are
 * mapped into the canonical range as well.
 */
static void zp_set_si(element_ptr e, signed long int op)
{
    mpz_ptr dst = e->data;

    mpz_set_si(dst, op);
    mpz_mod(dst, dst, e->field->order);
}

/*
 * Assign an arbitrary GMP integer z to e, reduced modulo the field order
 * (non-negative remainder). z itself is not modified.
 */
static void zp_set_mpz(element_ptr e, mpz_ptr z)
{
    mpz_mod(e->data, z, e->field->order);
}

/* Set e to the additive identity 0. */
static void zp_set0(element_ptr e)
{
    mpz_set_ui(e->data, 0);
}

/*
 * Write e's integer representative to stream in decimal.
 * The byte count returned by mpz_out_str is deliberately ignored.
 */
static void zp_out_str(FILE *stream, element_ptr e)
{
    const int radix = 10;

    (void) mpz_out_str(stream, radix, e->data);
}

/* n = (a + b) mod p, where p is n's field order. */
static void zp_add(element_ptr n, element_ptr a, element_ptr b)
{
    mpz_ptr sum = n->data;

    mpz_add(sum, a->data, b->data);
    mpz_mod(sum, sum, n->field->order);
}

/*
 * n = (a - b) mod p, where p is n's field order. mpz_mod's non-negative
 * remainder folds a negative raw difference back into range.
 */
static void zp_sub(element_ptr n, element_ptr a, element_ptr b)
{
    mpz_ptr diff = n->data;

    mpz_sub(diff, a->data, b->data);
    mpz_mod(diff, diff, n->field->order);
}

/* n = (a * b) mod p: full integer product, then a single reduction. */
static void zp_mul(element_ptr n, element_ptr a, element_ptr b)
{
    mpz_ptr prod = n->data;

    mpz_mul(prod, a->data, b->data);
    mpz_mod(prod, prod, n->field->order);
}

/* Copy a's value into n (no reduction; a is assumed already reduced). */
static void zp_set(element_ptr n, element_ptr a)
{
    mpz_ptr dst = n->data;
    mpz_srcptr src = a->data;

    mpz_set(dst, src);
}

/*
 * n = -a mod p.
 * For nonzero a the canonical negation is p - a. For a == 0 the result is 0,
 * and it must be written explicitly; otherwise n would keep whatever value
 * it held before the call.
 */
static void zp_neg(element_ptr n, element_ptr a)
{
    if (mpz_sgn((mpz_ptr) a->data) != 0) {
        mpz_sub(n->data, n->field->order, a->data);
    } else {
        mpz_set_ui(n->data, 0);
    }
}

/*
 * n = a^-1 mod p.
 * mpz_invert returns 0 when no inverse exists (a == 0 for a prime modulus),
 * and in that case its output is undefined. n is then set to 0, so it always
 * holds a well-defined value in [0, p).
 */
static void zp_invert(element_ptr n, element_ptr a)
{
    if (!mpz_invert(n->data, a->data, n->field->order)) {
        mpz_set_ui(n->data, 0);
    }
}

/*
 * Set n to a uniformly random value in [0, p), drawn from the global random
 * state rs.
 */
static void zp_random(element_ptr n)
{
    mpz_ptr out = n->data;

    mpz_urandomm(out, rs, n->field->order);
}

/* Return 1 if n's representative equals 1, else 0. */
static int zp_is1(element_ptr n)
{
    mpz_srcptr z = n->data;

    return mpz_cmp_ui(z, 1) == 0;
}

/* Return 1 if n's representative is 0, else 0. */
static int zp_is0(element_ptr n)
{
    mpz_srcptr z = n->data;

    return mpz_sgn(z) == 0;
}

/*
 * Three-way comparison of the integer representatives of a and b.
 * Returns mpz_cmp's raw result: negative, zero or positive.
 */
static int zp_cmp(element_ptr a, element_ptr b)
{
    mpz_srcptr lhs = a->data;
    mpz_srcptr rhs = b->data;

    return mpz_cmp(lhs, rhs);
}

/*
 * Return 1 if a is a square in F_p, else 0.
 * 0 = 0^2 counts as a square, but its Legendre symbol is 0 rather than 1, so
 * it is tested separately. Any other element is a square exactly when its
 * Legendre symbol is 1.
 */
static int zp_is_sqr(element_ptr a)
{
    mpz_srcptr z = a->data;

    return mpz_sgn(z) == 0 || mpz_legendre(z, a->field->order) == 1;
}

/*
 * Square root in F_q by Tonelli's algorithm: set x so that x^2 = a, using
 * the field's quadratic non-residue g = a->field->nqr.
 *
 * Write q - 1 = 2^s * t with t odd. The loop builds an even exponent e for
 * which a * g^-e has order dividing t. The root is then
 *     x = g^(e/2) * (a * g^-e)^((t+1)/2).
 *
 * x: output. It is written only after the last read of a, so x may alias a.
 * a: input. It must be a square (see zp_is_sqr). Nothing here checks this,
 *    and a non-residue has no square root, so x is meaningless in that case.
 */
static void fp_tonelli(element_ptr x, element_ptr a)
{
    int s;
    int i;
    mpz_t e;
    mpz_t t, t0;
    element_t ginv, e0;

    mpz_init(t);
    mpz_init(e);
    mpz_init(t0);
    element_init(ginv, a->field);
    element_init(e0, a->field);

    // ginv = g^-1
    element_invert(ginv, a->field->nqr); 

    //let q be the order of the field
    //q - 1 = 2^s t, t odd
    mpz_sub_ui(t, a->field->order, 1);
    s = mpz_scan1(t, 0);
    mpz_tdiv_q_2exp(t, t, s);
    mpz_set_ui(e, 0);
    // Build e bit by bit. At step i, if (a * g^-e)^((q-1)/2^i) != 1, add
    // 2^(i-1) to e. Only bits below i-1 can be set at that point, so
    // mpz_setbit(e, i-1) performs exactly that addition.
    for (i=2; i<=s; i++) {
	mpz_sub_ui(t0, a->field->order, 1);
	mpz_tdiv_q_2exp(t0, t0, i);
	element_pow(e0, ginv, e);
	element_mul(e0, e0, a);
	element_pow(e0, e0, t0);
	if (!element_is1(e0)) mpz_setbit(e, i-1);
    }
    // e0 = (a * g^-e)^((t+1)/2). t is odd, so the halving is exact.
    element_pow(e0, ginv, e);
    element_mul(e0, e0, a);
    mpz_add_ui(t, t, 1);
    mpz_tdiv_q_2exp(t, t, 1);
    element_pow(e0, e0, t);
    // The loop never sets bit 0 of e, so e/2 is exact.
    mpz_tdiv_q_2exp(e, e, 1);
    // x = g^(e/2) * e0
    element_pow(x, x->field->nqr, e);
    element_mul(x, x, e0);
    mpz_clear(t);
    mpz_clear(e);
    mpz_clear(t0);
    element_clear(ginv);
    element_clear(e0);
}

/*
 * Initialize f as the prime field F_p, with p = prime.
 *
 * The function:
 *   - installs the zp_* element operations, with fp_tonelli for square roots;
 *   - copies prime into f->order;
 *   - picks a random quadratic non-residue f->nqr, which fp_tonelli needs.
 *
 * prime: NOTE(review): assumed to be prime; TODO confirm that callers
 *        guarantee this. mpz_legendre is only defined for odd primes, and
 *        F_2 has no non-residue, so p = 2 is special-cased below.
 * The search draws from the global random state rs, which must already be
 * initialized.
 */
void field_init_fp(field_ptr f, mpz_t prime)
{
    f->init = zp_init;
    f->clear = zp_clear;
    f->set_si = zp_set_si;
    f->set_mpz = zp_set_mpz;
    f->out_str = zp_out_str;
    f->add = zp_add;
    f->sub = zp_sub;
    f->set = zp_set;
    f->mul = zp_mul;
    f->neg = zp_neg;
    f->cmp = zp_cmp;
    f->invert = zp_invert;
    f->random = zp_random;
    f->is1 = zp_is1;
    f->is0 = zp_is0;
    f->set0 = zp_set0;
    f->is_sqr = zp_is_sqr;
    f->sqrt = fp_tonelli;

    mpz_init_set(f->order, prime);
    f->data = NULL;

    /* nqr does not go through element_init. Link it to this field by hand,
     * then use zp_init so its allocation is checked. */
    f->nqr->field = f;
    zp_init(f->nqr);

    /* F_2 has no quadratic non-residue, so the search below would never end.
     * nqr stays 0. With p = 2 we get s = 0, so fp_tonelli only raises nqr
     * and its inverse to the power 0. */
    if (mpz_cmp_ui(prime, 2) == 0) {
        return;
    }

    /* Draw uniformly from [0, p) until the Legendre symbol is -1. */
    do {
        zp_random(f->nqr);
    } while (mpz_legendre(f->nqr->data, f->order) != -1);
}

// TODO: add field_clear_fp(), the counterpart of field_init_fp(), to release
// f->order and f->nqr.

/*
 * x = a^n by left-to-right binary exponentiation (square-and-multiply).
 *
 * n == 0 yields 1. The result is built in a temporary and copied into x only
 * at the end, so x may alias a.
 * NOTE(review): negative n is not treated as an inverse power. mpz_tstbit
 * reads two's-complement bits of negative values, so callers should pass
 * n >= 0.
 */
void element_pow(element_ptr x, element_ptr a, mpz_ptr n)
{
    element_t acc;
    size_t bit;

    element_init(acc, x->field);
    element_set_si(acc, 1);

    if (mpz_sgn(n) != 0) {
        /* Scan from the most significant bit down to bit 0. */
        bit = mpz_sizeinbase(n, 2);
        while (bit-- > 0) {
            element_mul(acc, acc, acc);
            if (mpz_tstbit(n, bit)) {
                element_mul(acc, acc, a);
            }
        }
    }

    element_set(x, acc);
    element_clear(acc);
}
