#include <stdio.h>
#include <stdlib.h>
#include <gmp.h>
#include "field.h"
#include "random.h"
#include "utils.h"
#include "fp.h"
// wrappers around GMP mpz functions to implement Z

/* Allocate and initialise the GMP integer backing element e; mpz_init
 * leaves it holding 0.  The field init hook has no way to report an
 * error, so an allocation failure aborts instead of handing a NULL
 * pointer to mpz_init (undefined behaviour). */
static void z_init(element_ptr e)
{
    mpz_ptr z = malloc(sizeof(mpz_t));
    if (z == NULL) {
        fprintf(stderr, "z_init: out of memory\n");
        abort();
    }
    mpz_init(z);
    e->data = z;
}

/* Release the GMP integer owned by e, then free the storage that
 * z_init allocated for it. */
static void z_clear(element_ptr e)
{
    mpz_ptr storage = e->data;
    mpz_clear(storage);
    free(storage);
}

/* e = op, for a signed long op. */
static void z_set_si(element_ptr e, signed long int op)
{
    mpz_ptr dst = e->data;
    mpz_set_si(dst, op);
}

/* e = z, copying the value of the GMP integer z. */
static void z_set_mpz(element_ptr e, mpz_ptr z)
{
    mpz_ptr dst = e->data;
    mpz_set(dst, z);
}

/* e = 0. */
static void z_set0(element_ptr e)
{
    mpz_ptr dst = e->data;
    mpz_set_ui(dst, 0UL);
}

/* e = 1. */
static void z_set1(element_ptr e)
{
    mpz_ptr dst = e->data;
    mpz_set_ui(dst, 1UL);
}

/* Print e to stream in the given base; returns whatever mpz_out_str
 * reports (the number of bytes written, or 0 on error). */
static size_t z_out_str(FILE *stream, int base, element_ptr e)
{
    mpz_ptr value = e->data;
    size_t written = mpz_out_str(stream, base, value);
    return written;
}

/* Sign of a: negative, zero or positive result as given by mpz_sgn. */
static int z_sgn(element_ptr a)
{
    return mpz_sgn((mpz_ptr) a->data);
}

/* n = a + b. */
static void z_add(element_ptr n, element_ptr a, element_ptr b)
{
    mpz_ptr sum = n->data;
    mpz_ptr lhs = a->data;
    mpz_ptr rhs = b->data;
    mpz_add(sum, lhs, rhs);
}

/* n = a - b. */
static void z_sub(element_ptr n, element_ptr a, element_ptr b)
{
    mpz_ptr diff = n->data;
    mpz_ptr minuend = a->data;
    mpz_ptr subtrahend = b->data;
    mpz_sub(diff, minuend, subtrahend);
}

/* c = a * a. */
static void z_square(element_ptr c, element_ptr a)
{
    mpz_ptr base = a->data;
    mpz_mul(c->data, base, base);
}

/* n = 2 * a, computed as a one-bit left shift. */
static void z_double(element_ptr n, element_ptr a)
{
    mpz_ptr dst = n->data;
    mpz_ptr src = a->data;
    mpz_mul_2exp(dst, src, 1);
}

/* n = a / 2, truncated toward zero (e.g. 7 -> 3, -7 -> -3).
 * The shift count is an unsigned mp_bitcnt_t; the previous argument of
 * -1 converted to ULONG_MAX and shifted every bit out, so the result
 * was always 0.  Halving is a shift by exactly one bit. */
static void z_halve(element_ptr n, element_ptr a)
{
    mpz_tdiv_q_2exp(n->data, a->data, 1);
}

/* n = a * b. */
static void z_mul(element_ptr n, element_ptr a, element_ptr b)
{
    mpz_ptr product = n->data;
    mpz_ptr lhs = a->data;
    mpz_ptr rhs = b->data;
    mpz_mul(product, lhs, rhs);
}

/* n = a * z, where z is a plain GMP integer. */
static void z_mul_mpz(element_ptr n, element_ptr a, mpz_ptr z)
{
    mpz_ptr product = n->data;
    mpz_ptr factor = a->data;
    mpz_mul(product, factor, z);
}

/* n = a * z, where z is a signed long. */
static void z_mul_si(element_ptr n, element_ptr a, signed long int z)
{
    mpz_ptr product = n->data;
    mpz_ptr factor = a->data;
    mpz_mul_si(product, factor, z);
}

/* n = a^z.  The exponent is narrowed with mpz_get_ui before calling
 * mpz_pow_ui.
 * NOTE(review): mpz_get_ui keeps only the low bits of |z|, so negative
 * or very large exponents are not handled faithfully — confirm callers
 * never pass them. */
static void z_pow_mpz(element_ptr n, element_ptr a, mpz_ptr z)
{
    unsigned long int exponent = mpz_get_ui(z);
    mpz_ptr result = n->data;
    mpz_ptr base = a->data;
    mpz_pow_ui(result, base, exponent);
}

/* n = a (value copy). */
static void z_set(element_ptr n, element_ptr a)
{
    mpz_ptr dst = n->data;
    mpz_ptr src = a->data;
    mpz_set(dst, src);
}

/* n = -a. */
static void z_neg(element_ptr n, element_ptr a)
{
    mpz_ptr dst = n->data;
    mpz_ptr src = a->data;
    mpz_neg(dst, src);
}

// Multiplicative inverse in Z. Only 1 and -1 are units, and each is its
// own inverse, so this simply copies a into n. The result is only
// correct for a == 1 or a == -1; other inputs are copied unchanged.
static void z_invert(element_ptr n, element_ptr a)
{
    mpz_ptr dst = n->data;
    mpz_ptr src = a->data;
    mpz_set(dst, src);
}

// Placeholder "random" element: uniform sampling is not meaningful when
// the order is infinite, so n is always set to 0.
static void z_random(element_ptr n)
{
    mpz_ptr dst = n->data;
    mpz_set_ui(dst, 0UL);
}

/* Interpret the len bytes at data as a non-negative little-endian
 * integer and store it in n.
 * TODO: something more sophisticated! */
static void z_from_hash(element_ptr n, int len, void *data)
{
    mpz_ptr dst = n->data;
    mpz_import(dst, len, -1, 1, -1, 0, data);
}

/* Nonzero iff n == 1. */
static int z_is1(element_ptr n)
{
    mpz_ptr value = n->data;
    return mpz_cmp_ui(value, 1) == 0;
}

/* Zero test, delegated to the mpz_is0 helper. */
static int z_is0(element_ptr n)
{
    mpz_ptr value = n->data;
    return mpz_is0(value);
}

/* Three-way comparison of a and b with mpz_cmp: negative, zero or
 * positive as a is less than, equal to, or greater than b. */
static int z_cmp(element_ptr a, element_ptr b)
{
    mpz_ptr lhs = a->data;
    mpz_ptr rhs = b->data;
    return mpz_cmp(lhs, rhs);
}

/* Nonzero iff a is a perfect square in Z (0 and 1 count; negatives never
 * do).  mpz_perfect_power_p was used previously, but it also accepts
 * higher powers such as 8 = 2^3 or -8 = (-2)^3, which are not squares
 * and would make a later z_sqrt return a truncated root. */
static int z_is_sqr(element_ptr a)
{
    return mpz_perfect_square_p((mpz_ptr) a->data);
}

/* c = floor(sqrt(a)), the truncated integer square root.
 * NOTE(review): a is assumed non-negative; GMP does not define a square
 * root of a negative integer — confirm callers check z_is_sqr first. */
static void z_sqrt(element_ptr c, element_ptr a)
{
    mpz_ptr root = c->data;
    mpz_ptr radicand = a->data;
    mpz_sqrt(root, radicand);
}

/* Z keeps no per-field state (field_init_z sets f->data to NULL), so
 * there is nothing to release here. */
static void z_field_clear(field_t f)
{
    UNUSED_VAR(f);
}

/* Serialise e into data using this layout:
 *   data[0..3]   little-endian byte length n of |e|; bit 7 of data[3]
 *                is set when e is negative;
 *   data[4..n+3] |e| as n little-endian bytes.
 * Returns n + 4, the number of bytes written (the same value
 * z_length_in_bytes reports).  data must have room for all of them. */
static int z_to_bytes(unsigned char *data, element_t e)
{
    mpz_ptr z = e->data;
    size_t n = (mpz_sizeinbase(z, 2) + 7) / 8;
    size_t i;

    for (i = 0; i < 4; i++) {
        data[i] = (unsigned char) (n >> (8 * i));
    }

    if (mpz_sgn(z) < 0) {
        data[3] |= 128;
    }

    if (mpz_sgn(z) == 0) {
        /* mpz_sizeinbase reports one digit for zero, so n == 1 here, but
         * mpz_export writes nothing for zero.  Emit the byte explicitly so
         * the advertised payload is never left uninitialised. */
        data[4] = 0;
    } else {
        mpz_export(data + 4, NULL, -1, 1, -1, 0, z);
    }

    return (int) (n + 4);
}

/* Deserialise an integer written by z_to_bytes:
 *   data[0..3]   little-endian byte length n (bit 7 of data[3] is the
 *                sign flag, not part of n);
 *   data[4..n+3] magnitude as n little-endian bytes.
 * Returns n + 4, the number of bytes consumed.  This matches what
 * z_to_bytes returns and z_length_in_bytes reports; the old code
 * returned only n, so callers advancing a cursor would fall short by
 * the 4-byte header.
 * NOTE(review): the length header is trusted as-is; data must really
 * hold n + 4 bytes. */
static int z_from_bytes(element_t e, unsigned char *data)
{
    mpz_ptr z = e->data;
    int neg = (data[3] & 128) != 0;
    size_t n = 0;
    size_t i;

    for (i = 0; i < 4; i++) {
        unsigned int byte = data[i];
        if (i == 3) {
            byte &= 127;  /* strip the sign flag; avoids the UB of 1 << 31 on int */
        }
        n |= (size_t) byte << (8 * i);
    }

    /* Mirror of the mpz_export in z_to_bytes.  It runs in linear time
     * (the old shift-and-add loop was quadratic) and yields 0 when n is 0. */
    mpz_import(z, n, -1, 1, -1, 0, data + 4);
    if (neg) {
        mpz_neg(z, z);
    }
    return (int) (n + 4);
}

/* Copy the value of element a into the GMP integer z. */
static void z_to_mpz(mpz_ptr z, element_ptr a)
{
    mpz_ptr src = a->data;
    mpz_set(z, src);
}

/* Size of a's serialised form: a 4-byte length/sign header plus the
 * number of bytes needed for |a| (computed from its bit length). */
static int z_length_in_bytes(element_ptr a)
{
    mpz_ptr value = a->data;
    size_t magnitude_bytes = (mpz_sizeinbase(value, 2) + 7) / 8;
    return magnitude_bytes + 4;
}

/* Write a one-line description of this field to out; f is not needed. */
static void z_print_info(FILE *out, field_ptr f)
{
    UNUSED_VAR(f);
    fputs("Z: wrapped GMP\n", out);
}

/* Configure f as the ring of integers Z, backed by GMP mpz_t values.
 * Element operations are routed to the z_* wrappers in this file.  The
 * order is recorded as 0, no per-field data is kept, and elements have
 * no fixed serialised length (-1). */
void field_init_z(field_ptr f)
{
    field_init(f);

    /* Element lifetime and assignment. */
    f->init = z_init;
    f->clear = z_clear;
    f->set = z_set;
    f->set0 = z_set0;
    f->set1 = z_set1;
    f->set_si = z_set_si;
    f->set_mpz = z_set_mpz;
    f->random = z_random;
    f->from_hash = z_from_hash;

    /* Arithmetic. */
    f->add = z_add;
    f->sub = z_sub;
    f->neg = z_neg;
    f->doub = z_double;
    f->halve = z_halve;
    f->mul = z_mul;
    f->mul_mpz = z_mul_mpz;
    f->mul_si = z_mul_si;
    f->square = z_square;
    f->pow_mpz = z_pow_mpz;
    f->invert = z_invert;
    f->sqrt = z_sqrt;

    /* Predicates and comparison. */
    f->sign = z_sgn;
    f->cmp = z_cmp;
    f->is0 = z_is0;
    f->is1 = z_is1;
    f->is_sqr = z_is_sqr;

    /* Conversion, serialisation and output. */
    f->out_str = z_out_str;
    f->to_mpz = z_to_mpz;
    f->to_bytes = z_to_bytes;
    f->from_bytes = z_from_bytes;
    f->length_in_bytes = z_length_in_bytes;

    /* Field-level hooks and properties. */
    f->field_clear = z_field_clear;
    f->print_info = z_print_info;
    mpz_set_ui(f->order, 0);
    f->data = NULL;
    f->fixed_length_in_bytes = -1;
}
