#include "poly.h"

// Implements the polynomial ring R[x] over a given ring R,
// and also the quotient ring R[x]/(f(x)) (the "polymod" functions below).

/* Resize e's coefficient array to exactly n entries.
 * Growing appends heap-allocated coefficients initialized (element_init) in
 * the base ring; shrinking clears and frees coefficients from the top down.
 * Leading zeros are NOT stripped here. */
void poly_alloc(element_ptr e, int n)
{
    poly_field_data_ptr fd = e->field->data;
    darray_ptr coeffs = ((poly_element_ptr) e->data)->coeff;
    int have = coeffs->count;

    for (; have < n; have++) {
        element_ptr fresh = malloc(sizeof(element_t));
        element_init(fresh, fd->field);
        darray_append(coeffs, fresh);
    }
    while (have > n) {
        element_ptr victim;
        have--;
        victim = darray_at(coeffs, have);
        element_clear(victim);
        free(victim);
        darray_remove_last(coeffs);
    }
}

/* Element initializer: attach a fresh, empty coefficient array to e
 * (an empty array represents the zero polynomial). */
static void poly_init(element_ptr e)
{
    poly_element_ptr rep = malloc(sizeof(poly_element_t));
    e->data = rep;
    darray_init(rep->coeff);
}

/* Element destructor: release every coefficient, then the coefficient
 * array itself, then the per-element representation. */
static void poly_clear(element_ptr e)
{
    poly_element_ptr rep = e->data;

    poly_alloc(e, 0);          /* clears and frees all coefficients */
    darray_clear(rep->coeff);
    free(rep);
}

/* Normalize e by dropping zero coefficients from the top, so the last stored
 * coefficient (if any) is nonzero. The zero polynomial ends up with an
 * empty coefficient array. */
void poly_remove_leading_zeroes(element_ptr e)
{
    poly_element_ptr rep = e->data;
    int top;

    for (top = rep->coeff->count - 1; top >= 0; top--) {
        element_ptr lead = rep->coeff->item[top];
        if (!element_is0(lead)) break;
        element_clear(lead);
        free(lead);
        darray_remove_last(rep->coeff);
    }
}

/* e = op, as a constant polynomial (empty if op maps to 0 in the base ring). */
static void poly_set_si(element_ptr e, signed long int op)
{
    element_ptr constant;

    poly_alloc(e, 1);
    constant = ((poly_element_ptr) e->data)->coeff->item[0];
    element_set_si(constant, op);
    poly_remove_leading_zeroes(e);
}

/* e = op, as a constant polynomial (empty if op maps to 0 in the base ring). */
static void poly_set_mpz(element_ptr e, mpz_ptr op)
{
    element_ptr constant;

    poly_alloc(e, 1);
    constant = ((poly_element_ptr) e->data)->coeff->item[0];
    element_set_mpz(constant, op);
    poly_remove_leading_zeroes(e);
}

/* Set the coefficient of x^n in e to a, growing e if needed (never shrinks).
 * Leading zeros are not stripped, so setting a top coefficient to 0 leaves
 * e unnormalized. */
void poly_set_coeff(element_ptr e, element_ptr a, int n)
{
    poly_element_ptr rep = e->data;

    if (rep->coeff->count <= n) poly_alloc(e, n + 1);
    element_set(rep->coeff->item[n], a);
}

/* f = x (coefficients [0, 1]). */
void poly_setx(element_ptr f)
{
    element_ptr c;

    poly_alloc(f, 2);
    c = poly_coeff(f, 0);
    element_set0(c);
    c = poly_coeff(f, 1);
    element_set1(c);
}

/* dst = src, copying coefficients one by one (dst is resized to match). */
static void poly_set(element_ptr dst, element_ptr src)
{
    darray_ptr from = ((poly_element_ptr) src->data)->coeff;
    darray_ptr to = ((poly_element_ptr) dst->data)->coeff;
    int k, total;

    poly_alloc(dst, from->count);
    total = from->count;
    for (k = 0; k < total; k++) element_set(to->item[k], from->item[k]);
}

/* sum = f + g. Coefficient counts are captured before sum is resized, so sum
 * may alias f or g. The result is normalized (leading zeros stripped). */
static void poly_add(element_ptr sum, element_ptr f, element_ptr g)
{
    int fn = poly_coeff_count(f);
    int gn = poly_coeff_count(g);
    int common = fn < gn ? fn : gn;
    int total = fn < gn ? gn : fn;
    element_ptr longer = fn > gn ? f : g;
    int k;

    poly_alloc(sum, total);
    /* Overlapping degrees: add termwise. */
    for (k = 0; k < common; k++) {
        element_add(poly_coeff(sum, k), poly_coeff(f, k), poly_coeff(g, k));
    }
    /* Remaining degrees come only from the longer operand. */
    for (k = common; k < total; k++) {
        element_set(poly_coeff(sum, k), poly_coeff(longer, k));
    }
    poly_remove_leading_zeroes(sum);
}

/* diff = f - g. Coefficient counts are captured before diff is resized, so
 * diff may alias f or g. The result is normalized (leading zeros stripped). */
static void poly_sub(element_ptr diff, element_ptr f, element_ptr g)
{
    int fn = poly_coeff_count(f);
    int gn = poly_coeff_count(g);
    int common = fn < gn ? fn : gn;
    int total = fn < gn ? gn : fn;
    int f_is_longer = fn > gn;
    int k;

    poly_alloc(diff, total);
    for (k = 0; k < common; k++) {
        element_sub(poly_coeff(diff, k), poly_coeff(f, k), poly_coeff(g, k));
    }
    /* Tail: f's extra coefficients carry over unchanged, g's are negated. */
    for (k = common; k < total; k++) {
        if (f_is_longer) {
            element_set(poly_coeff(diff, k), poly_coeff(f, k));
        } else {
            element_neg(poly_coeff(diff, k), poly_coeff(g, k));
        }
    }
    poly_remove_leading_zeroes(diff);
}

/* f = -g, coefficient by coefficient (f is resized to g's length). */
static void poly_neg(element_ptr f, element_ptr g)
{
    darray_ptr src = ((poly_element_ptr) g->data)->coeff;
    darray_ptr dst = ((poly_element_ptr) f->data)->coeff;
    int total = src->count;
    int k;

    poly_alloc(f, total);
    for (k = 0; k < total; k++) element_neg(dst->item[k], src->item[k]);
}

/* r = f * g in R[x] (schoolbook multiplication).
 * The product is accumulated in a temporary, so r may alias f or g.
 * Leading zeros are stripped at the end (they can only appear if the
 * coefficient ring has zero divisors). */
static void poly_mul(element_ptr r, element_ptr f, element_ptr g)
{
    poly_element_ptr pprod;
    poly_element_ptr pf = f->data;
    poly_element_ptr pg = g->data;
    poly_field_data_ptr pdp = r->field->data;
    int fcount = pf->coeff->count;
    int gcount = pg->coeff->count;
    int i, j, jlo, jhi, n;
    element_t prod;
    element_t e0;

    if (!fcount || !gcount) {
        element_set0(r);
        return;
    }
    element_init(prod, r->field);
    pprod = prod->data;
    n = fcount + gcount - 1;
    poly_alloc(prod, n);
    element_init(e0, pdp->field);
    for (i = 0; i < n; i++) {
        element_ptr x = pprod->coeff->item[i];

        /* Only j with 0 <= j < fcount and 0 <= i - j < gcount contribute
         * to x^i; iterate exactly that range (same order as before)
         * instead of bounds-testing every j <= i. */
        jlo = i - gcount + 1 > 0 ? i - gcount + 1 : 0;
        jhi = i < fcount - 1 ? i : fcount - 1;
        element_set0(x);
        for (j = jlo; j <= jhi; j++) {
            element_mul(e0, pf->coeff->item[j], pg->coeff->item[i - j]);
            element_add(x, x, e0);
        }
    }
    poly_remove_leading_zeroes(prod);
    element_set(r, prod);
    element_clear(e0);
    element_clear(prod);
}

/* res = a * poly, where a lies in the coefficient ring R and poly in R[x].
 * Computed in a temporary, so res may alias poly; the result is normalized. */
void poly_const_mul(element_ptr res, element_ptr a, element_ptr poly)
{
    element_t scaled;
    int k, total = poly_coeff_count(poly);

    element_init(scaled, res->field);
    poly_alloc(scaled, total);
    for (k = 0; k < total; k++) {
        element_mul(poly_coeff(scaled, k), poly_coeff(poly, k), a);
    }
    poly_remove_leading_zeroes(scaled);
    element_set(res, scaled);
    element_clear(scaled);
}

/* f = random monic polynomial of degree deg: coefficients 0..deg-1 come from
 * element_random and the coefficient of x^deg is set to 1. */
void poly_random_monic(element_ptr f, int deg)
{
    int k = 0;

    poly_alloc(f, deg + 1);
    while (k < deg) {
        element_random(poly_coeff(f, k));
        k++;
    }
    element_set1(poly_coeff(f, k));
}

/* f = random residue modulo the defining polynomial: every coefficient below
 * its degree is drawn with element_random, then f is normalized. */
static void polymod_random(element_ptr f)
{
    polymod_field_data_ptr mod = f->field->data;
    int deg = poly_degree(mod->poly);
    int k;

    poly_alloc(f, deg);
    for (k = 0; k < deg; k++) {
        element_random(poly_coeff(f, k));
    }
    poly_remove_leading_zeroes(f);
}

/* Map a hash (len bytes at data) to a residue modulo the defining polynomial
 * by hashing into each coefficient below its degree, then normalizing.
 * NOTE(review): every coefficient is derived from the same (len, data), so
 * if the base field's from_hash is deterministic all coefficients come out
 * equal — confirm this is intended before relying on it. */
static void polymod_from_hash(element_ptr f, int len, void *data)
{
    polymod_field_data_ptr mod = f->field->data;
    int deg = poly_degree(mod->poly);
    int k;

    poly_alloc(f, deg);
    for (k = 0; k < deg; k++) {
        element_from_hash(poly_coeff(f, k), len, data);
    }
    poly_remove_leading_zeroes(f);
}

/* res = f * g modulo the defining polynomial of res's field.
 * The full product is formed with poly_mul, then each term c * x^k with
 * k >= d (d = modulus degree) is folded back by adding c times the
 * precomputed reduction of x^k (xpwr entry k - d); finally coefficients
 * d and above are discarded. Works in a temporary, so res may alias f or g. */
static void polymod_mul(element_ptr res, element_ptr f, element_ptr g)
{
    polymod_field_data_ptr mod = res->field->data;
    element_t full, term;
    int d, top, k;

    element_init(full, res->field);
    element_init(term, res->field);
    d = poly_degree(mod->poly);

    poly_mul(full, f, g);

    top = poly_degree(full);
    k = d;
    while (k <= top) {
        /* Adding term (degree < d) leaves coefficient k itself untouched. */
        poly_const_mul(term, poly_coeff(full, k), mod->xpwr->item[k - d]);
        element_add(full, full, term);
        k++;
    }
    poly_alloc(full, d);
    poly_remove_leading_zeroes(full);

    element_set(res, full);
    element_clear(full);
    element_clear(term);
}

/* e = 0. The zero polynomial is stored as an empty coefficient array. */
static void poly_set0(element_ptr e)
{
    const int no_coefficients = 0;
    poly_alloc(e, no_coefficients);
}

/* e = 1, i.e. a single constant coefficient equal to 1. */
static void poly_set1(element_ptr e)
{
    element_ptr constant;

    poly_alloc(e, 1);
    constant = poly_coeff(e, 0);
    element_set1(constant);
}

/* Nonzero iff e has no stored coefficients (the representation of 0).
 * Relies on e being normalized (no leading zero coefficients). */
static int poly_is0(element_ptr e)
{
    return poly_coeff_count(e) == 0;
}

/* Nonzero iff e is the constant polynomial 1 (exactly one coefficient,
 * which is 1 in the base ring). */
static int poly_is1(element_ptr e)
{
    darray_ptr coeffs = ((poly_element_ptr) e->data)->coeff;

    if (coeffs->count != 1) return 0;
    return element_is1(coeffs->item[0]);
}

/* Write e to stream as "[c0 c1 ... c_{n-1}]" (constant term first); the
 * zero polynomial prints as "[0]". base is forwarded to element_out_str for
 * each coefficient (it was previously ignored and 0 was always used). */
static void poly_out_str(FILE *stream, int base, element_ptr e)
{
    int i;
    int n = poly_coeff_count(e);

    if (!n) {
        fputs("[0]", stream);
        return;
    }
    fputc('[', stream);
    for (i = 0; i < n; i++) {
        if (i) fputc(' ', stream);
        element_out_str(stream, base, poly_coeff(e, i));
    }
    fputc(']', stream);
}

/* Polynomial long division: computes quot and rem such that
 *   a = quot * b + rem,  deg(rem) < deg(b).
 * The leading coefficient of b is inverted with element_invert, so it must
 * be invertible in the coefficient ring (always the case over a field).
 * If b is zero, prints a message to stderr and calls exit(1).
 * The work is done in the temporaries q and r, and quot/rem are only
 * written at the end, so the outputs can presumably alias a or b. */
void poly_div(element_ptr quot, element_ptr rem,
	element_ptr a, element_ptr b)
{
    poly_element_ptr pq, pr;
    poly_field_data_ptr pdp = a->field->data;
    element_t q, r;
    element_t binv, e0;
    element_ptr qe;
    int m, n;
    int i, k;

    if (element_is0(b)) {
	fprintf(stderr, "BUG! division by zero!\n");
	exit(1);
    }
    n = poly_degree(b);
    m = poly_degree(a);
    // deg b > deg a: quotient is 0 and a is already the remainder.
    if (n > m) {
	element_set(rem, a);
	element_set0(quot);
	return;
    }
    element_init(r, a->field);
    element_init(q, a->field);
    element_init(binv, pdp->field);
    element_init(e0, pdp->field);
    pq = q->data;
    pr = r->data;
    // r starts as a copy of a and is reduced in place; q gets deg a - deg b + 1 slots.
    element_set(r, a);
    k = m - n;
    poly_alloc(q, k + 1);
    element_invert(binv, poly_coeff(b, n));
    // Each pass chooses quotient coefficient k so that coefficient m of r
    // cancels, then subtracts qe * x^k * b from r.
    while (k >= 0) {
	qe = pq->coeff->item[k];
	element_mul(qe, binv, pr->coeff->item[m]);
	for (i=0; i<=n; i++) {
	    element_mul(e0, qe, poly_coeff(b, i));
	    element_sub(pr->coeff->item[i + k], pr->coeff->item[i + k], e0);
	}
	k--;
	m--;
    }
    // Coefficients n..deg a of r are now zero; strip them (and any further leading zeros).
    poly_remove_leading_zeroes(r);
    // q is not normalized: its top coefficient binv * lc(a) is nonzero
    // whenever the coefficient ring has no zero divisors.
    element_set(quot, q);
    element_set(rem, r);

    element_clear(q);
    element_clear(r);
    element_clear(e0);
    element_clear(binv);
}

/* res = f^{-1} modulo m, via the extended Euclidean algorithm.
 * Invariant: scur * f == rcur (mod m), and likewise for sprev/rprev.
 * When the remainder sequence reaches 0, rcur is gcd(f, m); the result is
 * scur scaled by the inverse of rcur's constant coefficient.
 * NOTE(review): correct only when gcd(f, m) is a nonzero constant (e.g. m
 * irreducible and f != 0 mod m); no check is made for other cases, and
 * f == 0 makes poly_div abort. */
void poly_invert(element_ptr res, element_ptr f, element_ptr m)
{
    element_t quot, rprev, rcur, rnext;
    element_t sprev, scur, snext;
    element_t lead_inv;

    element_init(sprev, res->field);
    element_init(scur, res->field);
    element_init(snext, res->field);
    element_init(quot, res->field);
    element_init(rprev, res->field);
    element_init(rcur, res->field);
    element_init(rnext, res->field);
    element_init(lead_inv, poly_base_field(res));

    element_set0(sprev);        /* m == 0 * f (mod m) */
    element_set1(scur);         /* f == 1 * f */
    element_set(rprev, m);
    element_set(rcur, f);

    poly_div(quot, rnext, rprev, rcur);
    while (!element_is0(rnext)) {
        /* s_next = s_prev - quot * s_cur, then shift both sequences. */
        element_mul(snext, scur, quot);
        element_sub(snext, sprev, snext);
        element_set(sprev, scur);
        element_set(scur, snext);
        element_set(rprev, rcur);
        element_set(rcur, rnext);
        poly_div(quot, rnext, rprev, rcur);
    }

    element_invert(lead_inv, poly_coeff(rcur, 0));
    poly_const_mul(res, lead_inv, scur);

    element_clear(lead_inv);
    element_clear(quot);
    element_clear(rprev);
    element_clear(rcur);
    element_clear(rnext);
    element_clear(sprev);
    element_clear(scur);
    element_clear(snext);
}

/* r = f^{-1} in R[x]/(poly), where poly is the modulus stored in r's field. */
static void polymod_invert(element_ptr r, element_ptr f)
{
    element_ptr modulus = ((polymod_field_data_ptr) r->field->data)->poly;
    poly_invert(r, f, modulus);
}

/* Embed a base-ring element g into R[x] as the constant polynomial f = g
 * (normalized, so g == 0 yields the empty/zero polynomial). */
void element_field_to_poly(element_ptr f, element_ptr g)
{
    element_ptr constant;

    poly_alloc(f, 1);
    constant = poly_coeff(f, 0);
    element_set(constant, g);
    poly_remove_leading_zeroes(f);
}

/* Equality test: returns 0 if f and g have the same coefficients, 1
 * otherwise (no ordering is implied). Assumes both are normalized. */
static int poly_cmp(element_ptr f, element_ptr g)
{
    int total = poly_coeff_count(f);
    int k = 0;

    if (total != poly_coeff_count(g)) return 1;
    while (k < total) {
        if (element_cmp(poly_coeff(f, k), poly_coeff(g, k))) return 1;
        k++;
    }
    return 0;
}

/* Field destructor for R[x]: release the per-field data block. The base
 * field is not owned and is left untouched. */
static void field_clear_poly(field_ptr f)
{
    free(f->data);
}

/* Initialize f as the polynomial ring over base_field and install the
 * R[x] element operations. base_field is referenced, not copied. */
void field_init_poly(field_ptr f, field_ptr base_field)
{
    poly_field_data_ptr fd;

    field_init(f);
    fd = malloc(sizeof(poly_field_data_t));
    fd->field = base_field;
    fd->mapbase = element_field_to_poly;
    f->data = fd;

    /* lifecycle */
    f->field_clear = field_clear_poly;
    f->init = poly_init;
    f->clear = poly_clear;
    /* assignment */
    f->set = poly_set;
    f->set0 = poly_set0;
    f->set1 = poly_set1;
    f->set_si = poly_set_si;
    f->set_mpz = poly_set_mpz;
    /* arithmetic */
    f->add = poly_add;
    f->sub = poly_sub;
    f->neg = poly_neg;
    f->mul = poly_mul;
    /* predicates, comparison, output */
    f->is0 = poly_is0;
    f->is1 = poly_is1;
    f->cmp = poly_cmp;
    f->out_str = poly_out_str;
}

/* darray_forall callback for field_clear_polymod: clear and free one
 * heap-allocated x-power element. This is a file-scope function because
 * nested functions are a GNU C extension that other compilers (e.g. Clang)
 * reject. */
static void polymod_xpwr_free(void *e)
{
    element_clear(e);
    free(e);
}

/* Field destructor for R[x]/(poly): release the precomputed x-power table,
 * the stored copy of the modulus, and the per-field data block. */
static void field_clear_polymod(field_ptr f)
{
    polymod_field_data_ptr p = f->data;

    darray_forall(p->xpwr, polymod_xpwr_free);
    darray_clear(p->xpwr);
    element_clear(p->poly);
    free(f->data);
}

/* Euler's criterion: returns nonzero iff e^((q-1)/2) == 1, where q is the
 * order of e's field (the exact division by 2 assumes q is odd).
 * e == 0 yields 0. */
static int polymod_is_sqr(element_ptr e)
{
    int res;
    mpz_t z;
    element_t e0;

    element_init(e0, e->field);
    mpz_init(z);
    mpz_sub_ui(z, e->field->order, 1);
    mpz_divexact_ui(z, z, 2);

    element_pow(e0, e, z);
    res = element_is1(e0);
    element_clear(e0);
    mpz_clear(z);   /* was missing: z leaked on every call */
    return res;
}

/* Reduce s (a polynomial over the field of a, of degree at most 2) modulo
 * x^2 - a by replacing its x^2 term c2 * x^2 with the constant c2 * a.
 * tmp is scratch space in the same field as a. */
static void polymod_sqrt_reduce(element_ptr s, element_ptr a, element_ptr tmp)
{
    element_ptr c0, c2;

    if (poly_degree(s) != 2) return;
    c0 = poly_coeff(s, 0);
    c2 = poly_coeff(s, 2);
    element_mul(tmp, c2, a);
    element_add(c0, c0, tmp);
    poly_alloc(s, 2);
    poly_remove_leading_zeroes(s);
}

/* res = a square root of a, using Cantor-Zassenhaus on x^2 - a:
 * for random t, compute s = (x + t)^((q-1)/2) mod (x^2 - a), where q is the
 * order of a's field. When s has degree 1, (s0 + 1) / s1 is tried as a root
 * and accepted only if its square equals a.
 * NOTE: a must be a square; otherwise the loop never terminates. */
static void polymod_sqrt(element_ptr res, element_ptr a)
{
    field_t kx;
    element_t r, s;
    element_t e0;
    mpz_t z;
    int i;

    /* 0 is its own square root; answer directly instead of relying on the
     * randomized search below to reach it. */
    if (element_is0(a)) {
        element_set0(res);
        return;
    }

    field_init_poly(kx, a->field);
    mpz_init(z);
    element_init(r, kx);
    element_init(s, kx);
    element_init(e0, a->field);

    /* z = (q - 1) / 2 */
    mpz_sub_ui(z, a->field->order, 1);
    mpz_divexact_ui(z, z, 2);
    for (;;) {
        element_ptr t, c0, c1;

        /* r = x + t for a random t; t itself might already be a root. */
        poly_alloc(r, 2);
        element_set1(poly_coeff(r, 1));
        t = poly_coeff(r, 0);
        element_random(t);
        element_mul(e0, t, t);
        if (!element_cmp(e0, a)) {
            element_set(res, t);
            break;
        }

        /* s = r^z mod (x^2 - a), left-to-right square-and-multiply.
         * TODO: this can be optimized greatly since r has the form x + t. */
        element_set1(s);
        for (i = mpz_sizeinbase(z, 2) - 1; i >= 0; i--) {
            element_mul(s, s, s);
            polymod_sqrt_reduce(s, a, e0);
            if (mpz_tstbit(z, i)) {
                element_mul(s, s, r);
                polymod_sqrt_reduce(s, a, e0);
            }
        }
        if (poly_degree(s) < 1) continue;

        /* s = c1 x + c0: candidate root (c0 + 1) / c1, verified by squaring. */
        element_set1(e0);
        c0 = poly_coeff(s, 0);
        c1 = poly_coeff(s, 1);
        element_add(c0, c0, e0);
        element_invert(e0, c1);
        element_mul(e0, e0, c0);
        element_mul(c1, e0, e0);
        if (!element_cmp(c1, a)) {
            element_set(res, e0);
            break;
        }
    }

    mpz_clear(z);
    element_clear(r);
    element_clear(s);
    element_clear(e0);
    field_clear(kx);
}

/* f = g divided by its leading coefficient (so f is monic); g == 0 gives
 * f == 0. The leading slot of f holds the inverse while scaling, then is
 * set to 1, which keeps the f == g aliasing case correct. */
void poly_make_monic(element_t f, element_t g)
{
    int total = poly_coeff_count(g);
    element_ptr lead_inv;
    int k;

    poly_alloc(f, total);
    if (total == 0) return;

    lead_inv = poly_coeff(f, total - 1);
    element_invert(lead_inv, poly_coeff(g, total - 1));
    for (k = 0; k + 1 < total; k++) {
        element_mul(poly_coeff(f, k), poly_coeff(g, k), lead_inv);
    }
    element_set1(lead_inv);
}

/* Serialize f as its d = deg(modulus) coefficients, lowest degree first,
 * each encoded with element_to_bytes. Returns the number of bytes written.
 * f is normalized and may store fewer than d coefficients (0 stores none),
 * so the missing high coefficients are written as explicit zeros instead of
 * reading past the end of the coefficient array. */
static int polymod_to_bytes(unsigned char *data, element_t f)
{
    polymod_field_data_ptr p = f->field->data;
    int n = poly_degree(p->poly);
    int count = poly_coeff_count(f);
    int i;
    int len = 0;
    element_t zero;

    element_init(zero, p->field);
    element_set0(zero);
    for (i = 0; i < n; i++) {
        element_ptr e = i < count ? poly_coeff(f, i) : zero;
        len += element_to_bytes(data + len, e);
    }
    element_clear(zero);
    return len;
}

/* Number of bytes polymod_to_bytes writes for f: the sum of the encoded
 * lengths of its d = deg(modulus) coefficients. Coefficients not stored in
 * the normalized f count as the encoding length of zero, rather than being
 * read out of bounds. */
static int polymod_length_in_bytes(element_t f)
{
    polymod_field_data_ptr p = f->field->data;
    int n = poly_degree(p->poly);
    int count = poly_coeff_count(f);
    int res = 0;
    int i;
    element_t zero;

    element_init(zero, p->field);
    element_set0(zero);
    for (i = 0; i < n; i++) {
        element_ptr e = i < count ? poly_coeff(f, i) : zero;
        res += element_length_in_bytes(e);
    }
    element_clear(zero);

    return res;
}

/* Deserialize f from d = deg(modulus) consecutive coefficient encodings
 * (lowest degree first). Returns the number of bytes consumed.
 * f is first resized to d coefficients (previously the coefficients were
 * written without allocation, overrunning the array whenever f held fewer),
 * and is normalized afterwards, as polymod_random/polymod_from_hash do. */
static int polymod_from_bytes(element_t f, unsigned char *data)
{
    polymod_field_data_ptr p = f->field->data;
    int n = poly_degree(p->poly);
    int len = 0;
    int i;

    poly_alloc(f, n);
    for (i = 0; i < n; i++) {
        len += element_from_bytes(poly_coeff(f, i), data + len);
    }
    poly_remove_leading_zeroes(f);
    return len;
}

/* Precompute x^d, x^(d+1), ..., x^(2d-1) reduced modulo poly, where
 * d = deg(poly), and store them in xpwr in that order (d entries); entry k
 * is the reduction of x^(d+k), as used by polymod_mul.
 * xpwr is initialized here with darray_init. Each entry is a heap-allocated
 * element of poly's field that the caller must eventually clear and free.
 * Assumes poly is monic (x^d is taken as -(poly - x^d)).
 * If deg(poly) < 1 there is nothing to reduce and xpwr is left empty. */
void compute_x_powers(darray_ptr xpwr, element_ptr poly)
{
    element_t p0;
    element_ptr pwr, pwr1, pwrn;
    int i, j, k;
    int n = poly_coeff_count(poly);

    darray_init(xpwr);
    /* A constant or zero modulus has no x^d; bail out before reading
     * xpwr->item[0] from an empty array. */
    if (n < 2) return;

    element_init(p0, poly->field);
    for (i = 0; i < n - 1; i++) {
        pwr = malloc(sizeof(element_t));
        element_init(pwr, poly->field);
        darray_append(xpwr, pwr);
    }
    /* x^d == -(poly minus its leading term), since poly is monic. */
    pwrn = xpwr->item[0];
    element_neg(pwrn, poly);
    poly_alloc(pwrn, n - 1);
    poly_remove_leading_zeroes(pwrn);
    for (i = 1; i < n - 1; i++) {
        /* x^(d+i) = x * x^(d+i-1): shift coefficients up by one... */
        pwr = xpwr->item[i - 1];
        pwr1 = xpwr->item[i];
        k = poly_coeff_count(pwr);
        for (j = 0; j < k; j++) {
            poly_set_coeff(pwr1, poly_coeff(pwr, j), j + 1);
        }
        /* ...and if that produced an x^d term, replace it by its
         * coefficient times the reduction of x^d. */
        if (k == n - 1) {
            element_ptr e1 = poly_coeff(pwr1, k);
            poly_const_mul(p0, e1, pwrn);
            poly_alloc(pwr1, k);
            element_add(pwr1, pwr1, p0);
        }
    }
    element_clear(p0);
}

/* Initialize f as the quotient ring R[x]/(poly), where R is the coefficient
 * field of poly's polynomial ring. f keeps its own copy of poly and a table
 * of reduced powers of x (see compute_x_powers) used by multiplication.
 * Elements share the R[x] coefficient-array representation; the field's
 * order is set to |R|^deg(poly). */
void field_init_polymod(field_ptr f, element_ptr poly)
    //assumes poly is monic
{
    poly_field_data_ptr pdp = poly->field->data;
    poly_element_ptr pd = poly->data;
    polymod_field_data_ptr p;
    int n;

    field_init(f);
    f->data = malloc(sizeof(polymod_field_data_t));
    p = f->data;
    p->field = pdp->field;
    p->mapbase = element_field_to_poly;
    element_init(p->poly, poly->field);
    element_set(p->poly, poly);
    f->field_clear = field_clear_polymod;
    f->init = poly_init;
    f->clear = poly_clear;
    f->set_si = poly_set_si;
    f->set_mpz = poly_set_mpz;
    f->out_str = poly_out_str;
    f->set = poly_set;
    f->add = poly_add;
    f->is0 = poly_is0;
    f->is1 = poly_is1;
    f->set0 = poly_set0;
    f->set1 = poly_set1;
    f->sub = poly_sub;
    f->neg = poly_neg;
    f->cmp = poly_cmp;
    f->mul = polymod_mul;
    f->random = polymod_random;
    f->from_hash = polymod_from_hash;
    f->invert = polymod_invert;
    f->is_sqr = polymod_is_sqr;
    f->sqrt = polymod_sqrt;
    f->to_bytes = polymod_to_bytes;
    f->from_bytes = polymod_from_bytes;
    if (pdp->field->fixed_length_in_bytes < 0) {
        f->fixed_length_in_bytes = -1;
        f->length_in_bytes = polymod_length_in_bytes;
    } else {
        f->fixed_length_in_bytes = pdp->field->fixed_length_in_bytes * poly_degree(poly);
    }
    n = pd->coeff->count;
    mpz_pow_ui(f->order, p->field->order, n - 1);

    /* compute_x_powers calls darray_init on p->xpwr itself; initializing it
     * here as well re-initialized the darray and could leak whatever the
     * first darray_init allocated. */
    compute_x_powers(p->xpwr, poly);
}

/* Factor n by trial division. For each distinct prime p dividing n, appends
 * a heap-allocated mpz_t holding p to factor and one holding its
 * multiplicity to mult; the caller owns (must mpz_clear and free) them.
 * Once the unfactored part passes mpz_probab_prime_p, it is taken as the
 * final prime directly.
 * n <= 1 has no prime factors and leaves both arrays untouched. Previously
 * n == 0 looped forever (0 is divisible by every p) and negative n never
 * reached 1. */
void trial_divide(darray_ptr factor, darray_ptr mult, mpz_t n)
{
    mpz_t p, m;
    mpz_ptr fac, mul;

    if (mpz_cmp_ui(n, 1) <= 0) return;

    mpz_init(p);
    mpz_init(m);
    mpz_set(m, n);
    mpz_set_ui(p, 2);

    while (mpz_cmp_ui(m, 1)) {
        /* Shortcut: the remaining cofactor is (probably) prime. */
        if (mpz_probab_prime_p(m, 10)) {
            mpz_set(p, m);
        }
        if (mpz_divisible_p(m, p)) {
            fac = malloc(sizeof(mpz_t));
            mul = malloc(sizeof(mpz_t));
            mpz_init(fac);
            mpz_init(mul);
            mpz_set(fac, p);
            darray_append(factor, fac);
            darray_append(mult, mul);
            /* Divide out every copy of p, counting them. */
            do {
                mpz_divexact(m, m, p);
                mpz_add_ui(mul, mul, 1);
            } while (mpz_divisible_p(m, p));
        }
        mpz_nextprime(p, p);
    }

    mpz_clear(m);
    mpz_clear(p);
}

/* d = gcd(f, g) by the Euclidean algorithm (not normalized to be monic).
 * Works on copies, so d may alias f or g. gcd(f, 0) = f. Previously a zero
 * g was passed straight to poly_div, which aborts the process. */
void poly_gcd(element_ptr d, element_ptr f, element_ptr g)
{
    element_t a, b, q, r;
    element_init(a, d->field);
    element_init(b, d->field);
    element_init(q, d->field);
    element_init(r, d->field);

    element_set(a, f);
    element_set(b, g);
    /* (a, b) <- (b, a mod b) until b vanishes; the gcd is then a. */
    while (!element_is0(b)) {
        //TODO: don't care about q
        poly_div(q, r, a, b);
        element_set(a, b);
        element_set(b, r);
    }
    element_set(d, a);
    element_clear(a);
    element_clear(b);
    element_clear(q);
    element_clear(r);
}

/* Reinterpret f, an element of K[x]/(p(x)), as a plain polynomial in K[x] by
 * switching its field pointer to the polynomial ring the modulus lives in.
 * The coefficient data is reused as-is; nothing is copied or reduced. */
void poly_unmod(element_ptr f)
{
    element_ptr modulus = ((polymod_field_data_ptr) f->field->data)->poly;
    f->field = modulus->field;
}

/* Rabin's irreducibility test for f of degree d over a finite field of
 * order q (the order of f's coefficient field basef):
 *   f is irreducible iff x^(q^d) == x (mod f) and, for each entry p of
 *   factor, gcd(x^(q^(d/p)) - x, f) is a nonzero constant.
 * factor is expected to hold the distinct prime divisors of d as mpz_t
 * pointers (poly_is_irred obtains them from trial_divide).
 * Returns 1 if f passes, 0 otherwise. The powers of x are computed in a
 * temporary quotient ring F_q[x]/(f) created and destroyed here. */
int poly_is_irred_degfac(element_ptr f, darray_t factor)
    //called by poly_is_irred
    //needs to be passed a list of the factors of deg f
{
    int res;
    element_t xpow, x;
    mpz_t deg, z;
    int i, n;
    field_ptr basef = poly_base_field(f);
    field_t rxmod;

    mpz_init(deg);
    mpz_init(z);

    field_init_polymod(rxmod, f);

    mpz_set_ui(deg, poly_degree(f));
    element_init(xpow, rxmod);
    element_init(x, rxmod);
    n = factor->count;
    poly_setx(x);

    res = 0;
    // For each listed prime p | d: x^(q^(d/p)) - x must be coprime to f.
    for (i=0; i<n; i++) {
	mpz_divexact(z, deg, factor->item[i]);
	mpz_pow_ui(z, basef->order, mpz_get_ui(z));
	element_pow(xpow, x, z);
	element_sub(xpow, xpow, x);
	// x^(q^(d/p)) == x mod f: the gcd is f itself, so f is reducible.
	if (element_is0(xpow)) {
	    goto done;
	}
	poly_gcd(xpow, f, xpow);
	// Nonconstant common factor: f has a factor of degree dividing d/p.
	if (poly_degree(xpow) != 0) goto done;
    }

    // Finally require x^(q^d) == x (mod f).
    mpz_pow_ui(z, basef->order, poly_degree(f));
    element_pow(xpow, x, z);
    element_sub(xpow, xpow, x);
    if (element_is0(xpow)) res = 1;

done:
    element_clear(xpow);
    element_clear(x);
    mpz_clear(deg);
    mpz_clear(z);
    field_clear(rxmod);
    return res;
}

/* darray_forall callback for poly_is_irred: clear and free one
 * heap-allocated mpz_t produced by trial_divide. This is a file-scope
 * function because nested functions are a GNU C extension that other
 * compilers (e.g. Clang) reject. */
static void poly_is_irred_free_mpz(void *z)
{
    mpz_clear(z);
    free(z);
}

/* Returns 1 if f is irreducible over its finite coefficient field, else 0.
 * Factors deg f by trial division and runs poly_is_irred_degfac.
 * NOTE(review): every f with poly_degree(f) <= 1 is reported irreducible,
 * which includes constants (and presumably 0) even though those are not
 * irreducible mathematically; kept for compatibility with existing callers. */
int poly_is_irred(element_ptr f)
{
    darray_t fac, mul;
    mpz_t deg;
    int res;

    if (poly_degree(f) <= 1) return 1;

    darray_init(fac);
    darray_init(mul);
    mpz_init(deg);

    mpz_set_ui(deg, poly_degree(f));

    trial_divide(fac, mul, deg);
    res = poly_is_irred_degfac(f, fac);

    darray_forall(fac, poly_is_irred_free_mpz);
    darray_forall(mul, poly_is_irred_free_mpz);

    darray_clear(fac);
    darray_clear(mul);
    mpz_clear(deg);
    return res;
}
