#include "field.h"

//TODO: provide an interface so that the programmer
//can provide their own source of random bytes?

// Default generator: stores in z a value produced by mpz_urandomm() for the
// given limit. One GMP random state is shared by all calls; it is created on
// first use with gmp_randinit_default() and is never explicitly seeded here.
// The data argument is unused.
static void deterministic_mpz_random(mpz_t z, mpz_t limit, void *data)
{
    static gmp_randstate_t shared_state;
    static int state_ready = 0;

    (void) data;

    if (state_ready == 0) {
        gmp_randinit_default(shared_state);
        state_ready = 1;
    }

    mpz_urandomm(z, shared_state, limit);
}

// Stores in r an integer in [0, limit) built from bytes read from the file
// whose path is passed through data (a char *, e.g. "/dev/urandom").
//
// Rejection sampling: enough bytes to cover the bit length of limit are read,
// interpreted most-significant byte first, the excess high bits of the first
// byte are masked off, and the candidate is retried until it is below limit.
//
// On any failure (non-positive limit, file cannot be opened, allocation
// failure, short read) a warning is written to stderr and r is left
// untouched. All resources are released on every path.
//
//TODO: also allow programmer to pick between /dev/random and /dev/urandom
static void file_mpz_random(mpz_t r, mpz_t limit, void *data)
{
    char *filename = (char *) data;
    FILE *fp;
    size_t bits, bytecount;
    unsigned int leftover;
    unsigned char *bytes;
    int found = 0;
    mpz_t z;

    // No value lies in [0, limit) when limit <= 0; the loop below would
    // never terminate.
    if (mpz_sgn(limit) <= 0) {
        fprintf(stderr, "file_mpz_random: limit must be positive\n");
        return;
    }

    //fp = fopen("/dev/urandom", "rb");
    fp = filename ? fopen(filename, "rb") : NULL;
    if (!fp) {
        fprintf(stderr, "file_mpz_random: cannot open %s\n",
                filename ? filename : "(null)");
        return;
    }

    bits = mpz_sizeinbase(limit, 2);
    bytecount = (bits + 7) / 8;
    leftover = (unsigned int) (bits % 8);

    bytes = malloc(bytecount);
    if (!bytes) {
        fprintf(stderr, "file_mpz_random: out of memory\n");
        fclose(fp);
        return;
    }

    mpz_init(z);
    for (;;) {
        // A short read (EOF or I/O error) would otherwise make the loop spin
        // forever on stale buffer contents.
        if (fread(bytes, 1, bytecount, fp) != bytecount) {
            fprintf(stderr, "file_mpz_random: short read from %s\n", filename);
            break;
        }
        if (leftover) {
            // Keep only the low `leftover` bits of the most significant byte.
            bytes[0] &= (unsigned char) ((1u << leftover) - 1u);
        }
        mpz_import(z, bytecount, 1, 1, 0, 0, bytes);
        if (mpz_cmp(z, limit) < 0) {
            found = 1;
            break;
        }
    }

    if (found) {
        mpz_set(r, z);
    }
    mpz_clear(z);
    free(bytes);
    fclose(fp);
}

// Generator used by pbc_mpz_random(); starts out as the deterministic one and
// is switched by random_set_deterministic() / random_set_file().
static void (*current_mpz_random)(mpz_t, mpz_t, void *) = deterministic_mpz_random;
// Opaque argument handed to the current generator as its third parameter
// (the file name when file_mpz_random is selected).
static void *current_random_data;

// Fills z using the currently selected generator, forwarding limit and the
// generator-specific data pointer unchanged.
void pbc_mpz_random(mpz_t z, mpz_t limit)
{
    (*current_mpz_random)(z, limit, current_random_data);
}

// Selects the built-in GMP-based generator for pbc_mpz_random() and drops any
// previously registered generator data.
void random_set_deterministic()
{
    current_random_data = NULL;
    current_mpz_random = deterministic_mpz_random;
}

// Selects the file-backed generator for pbc_mpz_random().
// Only the pointer is stored, not a copy of the string, so filename must stay
// valid for as long as this generator remains selected.
void random_set_file(char *filename)
{
    current_random_data = filename;
    current_mpz_random = file_mpz_random;
}

// Allocates the element's backing mpz_t on the heap and initializes it.
// The allocation result is not checked.
static void zp_init(element_ptr e)
{
    mpz_ptr value = malloc(sizeof(mpz_t));

    e->data = value;
    mpz_init(value);
}

// Releases the GMP integer behind the element and then the heap block that
// held it. e->data is not reset afterwards.
static void zp_clear(element_ptr e)
{
    mpz_ptr value = e->data;

    mpz_clear(value);
    free(value);
}

// Sets e to the signed integer op reduced modulo the field order.
static void zp_set_si(element_ptr e, signed long int op)
{
    mpz_ptr value = e->data;

    mpz_set_si(value, op);
    mpz_mod(value, value, e->field->order);
}

// Sets e to z reduced modulo the field order.
static void zp_set_mpz(element_ptr e, mpz_ptr z)
{
    mpz_mod(e->data, z, e->field->order);
}

// Sets e to the additive identity.
static void zp_set0(element_ptr e)
{
    mpz_set_ui(e->data, 0);
}

// Sets e to the integer 1 (no reduction modulo the field order is applied).
static void zp_set1(element_ptr e)
{
    mpz_set_ui(e->data, 1);
}

// Writes the element's integer value to stream in the given base.
// The byte count returned by mpz_out_str() is discarded.
static void zp_out_str(FILE *stream, int base, element_ptr e)
{
    mpz_ptr value = e->data;

    (void) mpz_out_str(stream, base, value);
}

// n = (a + b) mod order, using the order of n's field.
static void zp_add(element_ptr n, element_ptr a, element_ptr b)
{
    mpz_ptr sum = n->data;

    mpz_add(sum, a->data, b->data);
    mpz_mod(sum, sum, n->field->order);
}

// n = (a - b) mod order, using the order of n's field. mpz_mod() brings a
// negative difference back into the non-negative range.
static void zp_sub(element_ptr n, element_ptr a, element_ptr b)
{
    mpz_ptr diff = n->data;

    mpz_sub(diff, a->data, b->data);
    mpz_mod(diff, diff, n->field->order);
}

// n = (a * b) mod order, using the order of n's field.
static void zp_mul(element_ptr n, element_ptr a, element_ptr b)
{
    mpz_ptr product = n->data;

    mpz_mul(product, a->data, b->data);
    mpz_mod(product, product, n->field->order);
}

// Copies the value of a into n without any reduction.
static void zp_set(element_ptr n, element_ptr a)
{
    mpz_ptr dst = n->data;
    mpz_ptr src = a->data;

    mpz_set(dst, src);
}

// n = -a in the field: order - a for nonzero a, and 0 for a == 0.
// The zero case must be written explicitly; otherwise n would keep whatever
// value it held before whenever n and a are different elements.
static void zp_neg(element_ptr n, element_ptr a)
{
    if (mpz_sgn((mpz_ptr) a->data)) {
        mpz_sub(n->data, n->field->order, a->data);
    } else {
        mpz_set_ui(n->data, 0);
    }
}

// n = a^-1 modulo the order of n's field.
// mpz_invert()'s return value (which reports whether an inverse exists) is
// discarded.
static void zp_invert(element_ptr n, element_ptr a)
{
    mpz_ptr inverse = n->data;

    (void) mpz_invert(inverse, a->data, n->field->order);
}

// Sets n to a random value obtained from pbc_mpz_random(), with the field
// order as the limit.
static void zp_random(element_ptr n)
{
    mpz_ptr value = n->data;

    pbc_mpz_random(value, n->field->order);
}

// Maps len bytes at data to a field element: the bytes are read as one
// integer, most significant byte first, and reduced modulo the field order.
//TODO: something more sophisticated!
static void zp_from_hash(element_ptr n, int len, void *data)
{
    mpz_ptr value = n->data;

    mpz_import(value, len, 1, 1, 0, 0, data);
    mpz_mod(value, value, n->field->order);
}

// Returns nonzero when the element's value equals 1, zero otherwise.
static int zp_is1(element_ptr n)
{
    mpz_ptr value = (mpz_ptr) n->data;

    return mpz_cmp_ui(value, 1) == 0;
}

// Returns nonzero when the element's value equals 0, zero otherwise.
static int zp_is0(element_ptr n)
{
    mpz_ptr value = (mpz_ptr) n->data;

    return mpz_cmp_ui(value, 0) == 0;
}

// Compares the integer values of a and b; returns mpz_cmp()'s result
// (positive, zero or negative).
static int zp_cmp(element_ptr a, element_ptr b)
{
    mpz_ptr lhs = (mpz_ptr) a->data;
    mpz_ptr rhs = (mpz_ptr) b->data;

    return mpz_cmp(lhs, rhs);
}

// Returns 1 when the Legendre symbol of a with respect to the field order is
// exactly 1, and 0 for any other symbol value.
static int zp_is_sqr(element_ptr a)
{
    int symbol = mpz_legendre(a->data, a->field->order);

    return symbol == 1;
}

// Square root in the prime field by Tonelli's method: computes x such that
// x*x == a, using the field's stored non-residue (field->nqr, called g below).
//
// Outline, with q the field order and q - 1 = 2^s * t, t odd:
//   1. Build an exponent e bit by bit so that (a * g^-e)^((q-1)/2^i) == 1 for
//      i = 2..s; after the loop this gives (a * g^-e)^t == 1.
//   2. Return x = g^(e/2) * (a * g^-e)^((t+1)/2).
//
// NOTE(review): bit 0 of e is never set (the loop starts at i = 2), which
// presumably relies on a being a nonzero square; confirm that callers check
// is_sqr first.
// x is only written after the last read of a, so x and a may be the same
// element.
static void fp_tonelli(element_ptr x, element_ptr a)
{
    int s;
    int i;
    mpz_t e;
    mpz_t t, t0;
    element_t ginv, e0;

    mpz_init(t);
    mpz_init(e);
    mpz_init(t0);
    element_init(ginv, a->field);
    element_init(e0, a->field);

    // ginv = g^-1, where g is the field's non-residue
    element_invert(ginv, a->field->nqr); 

    //let q be the order of the field
    //q - 1 = 2^s t, t odd
    mpz_sub_ui(t, a->field->order, 1);
    s = mpz_scan1(t, 0);
    mpz_tdiv_q_2exp(t, t, s);
    mpz_set_ui(e, 0);
    for (i=2; i<=s; i++) {
	// t0 = (q - 1) / 2^i
	mpz_sub_ui(t0, a->field->order, 1);
	mpz_tdiv_q_2exp(t0, t0, i);
	// e0 = (a * g^-e)^t0, recomputed from scratch with the current e
	element_pow(e0, ginv, e);
	element_mul(e0, e0, a);
	element_pow(e0, e0, t0);
	// a result other than 1 is corrected by setting bit i-1 of e
	if (!element_is1(e0)) mpz_setbit(e, i-1);
    }
    // e0 = (a * g^-e)^((t + 1) / 2)
    element_pow(e0, ginv, e);
    element_mul(e0, e0, a);
    mpz_add_ui(t, t, 1);
    mpz_tdiv_q_2exp(t, t, 1);
    element_pow(e0, e0, t);
    // x = g^(e / 2) * e0
    mpz_tdiv_q_2exp(e, e, 1);
    element_pow(x, x->field->nqr, e);
    element_mul(x, x, e0);
    mpz_clear(t);
    mpz_clear(e);
    mpz_clear(t0);
    element_clear(ginv);
    element_clear(e0);
}

// Releases what field_init_fp() allocated for the field: the heap-allocated
// value of the stored non-residue and the field order.
static void zp_field_clear(field_t f)
{
    mpz_ptr nqr_value = f->nqr->data;

    mpz_clear(nqr_value);
    free(nqr_value);
    mpz_clear(f->order);
}

// Serializes e into data as exactly fixed_length_in_bytes bytes, least
// significant byte first; higher bytes are zero when the value is short.
// Returns the number of bytes written. data must have room for that many.
static int zp_to_bytes(unsigned char *data, element_t e)
{
    int len = e->field->fixed_length_in_bytes;
    int k;
    mpz_t rest;

    // Work on a copy so the element itself is not modified.
    mpz_init(rest);
    mpz_set(rest, e->data);

    for (k = 0; k < len; k++) {
        // Emit the low 8 bits, then shift them out.
        data[k] = (unsigned char) mpz_get_ui(rest);
        mpz_tdiv_q_2exp(rest, rest, 8);
    }

    mpz_clear(rest);
    return len;
}

// Deserializes e from fixed_length_in_bytes bytes at data, least significant
// byte first (the layout written by zp_to_bytes). Returns the number of bytes
// consumed. The value is stored as read; it is not reduced modulo the order.
static int zp_from_bytes(element_t e, unsigned char *data)
{
    int n = e->field->fixed_length_in_bytes;

    if (n > 0) {
        // order = -1: least significant byte first; 1-byte words, no nails.
        mpz_import(e->data, (size_t) n, -1, 1, 0, 0, data);
    } else {
        mpz_set_ui(e->data, 0);
    }
    return n;
}

// Initializes f as the field of integers modulo prime: installs the zp_*
// operation table, copies prime into f->order, records the serialized element
// length, and picks a random quadratic non-residue for use by the square-root
// routine.
// NOTE(review): the non-residue search only ends once mpz_legendre() yields
// -1, so prime is presumably expected to be an odd prime; confirm with
// callers.
void field_init_fp(field_ptr f, mpz_t prime)
{
    // Element lifecycle and assignment
    f->init = zp_init;
    f->clear = zp_clear;
    f->set = zp_set;
    f->set_si = zp_set_si;
    f->set_mpz = zp_set_mpz;
    f->set0 = zp_set0;
    f->set1 = zp_set1;
    f->random = zp_random;
    f->from_hash = zp_from_hash;

    // Arithmetic
    f->add = zp_add;
    f->sub = zp_sub;
    f->mul = zp_mul;
    f->neg = zp_neg;
    f->invert = zp_invert;
    f->sqrt = fp_tonelli;

    // Predicates and comparison
    f->cmp = zp_cmp;
    f->is0 = zp_is0;
    f->is1 = zp_is1;
    f->is_sqr = zp_is_sqr;

    // I/O, serialization and teardown
    f->out_str = zp_out_str;
    f->to_bytes = zp_to_bytes;
    f->from_bytes = zp_from_bytes;
    f->field_clear = zp_field_clear;

    f->data = NULL;

    mpz_init(f->order);
    mpz_set(f->order, prime);
    f->fixed_length_in_bytes = (mpz_sizeinbase(prime, 2) + 7) / 8;

    // Draw random values until one has Legendre symbol -1.
    f->nqr->data = malloc(sizeof(mpz_t));
    mpz_init(f->nqr->data);
    do {
        pbc_mpz_random(f->nqr->data, prime);
    } while (mpz_legendre(f->nqr->data, prime) != -1);
}

// x = a^n by left-to-right square-and-multiply over the bits of n.
// The result is accumulated in a temporary of x's field and copied out at the
// end, so x may be the same element as a. n == 0 yields 1.
// NOTE(review): the bit scan looks written for non-negative n; confirm that
// callers never pass a negative exponent.
void element_pow(element_ptr x, element_ptr a, mpz_ptr n)
{
    element_t acc;
    int bit;

    element_init(acc, x->field);
    element_set1(acc);

    if (mpz_sgn(n) != 0) {
        bit = mpz_sizeinbase(n, 2);
        // Walk from the most significant bit down to bit 0.
        while (bit-- > 0) {
            element_mul(acc, acc, acc);
            if (mpz_tstbit(n, bit)) {
                element_mul(acc, acc, a);
            }
        }
    }

    element_set(x, acc);
    element_clear(acc);
}
