#include <math.h>
#include <stdbool.h>
#include <stdlib.h>
#include <stdio.h>
#include <stdint.h>
#include <string.h>
#include <unistd.h>
#include "conjugrad.h"

#include "dca_typ.h"

/** Callback for the conjugate-gradient optimizer: computes the objective value
 * and its gradient.
 *
 * The objective is the weighted negative pseudo-log-likelihood summed over all
 * rows i and columns k:
 *     weight_i * ( -PrC(x_ik, k) + log sum_a exp(PrC(a, k)) ).
 * It also includes L2 regularization of the single and pairwise parameters.
 *
 * Layout, as indexed below: x and g hold the single-site parameters in
 * [0, nsingle_padded) followed by the pairwise parameters in
 * [nsingle_padded, nvar_padded).
 *
 * @param[in] instance The user data (usrdata *) passed to the optimizer;
 *                     ud->extra must have been set up by init_cpu
 * @param[in] x The current variable assignments
 * @param[out] g The gradient at x (fully overwritten)
 * @param[in] nvar_padded The total number of (padded) variables in x and g
 * @return The objective function value at x
 */
conjugrad_float_t evaluate_cpu(void *instance, const conjugrad_float_t *x,
	conjugrad_float_t *g, const int nvar_padded)
{
	usrdata *ud = (usrdata *)instance;
	extra_usrdata *udx = (extra_usrdata *)ud->extra;

	int ncol = ud->ncol;
	int nrow = ud->nrow;
	int nsingle = ud->nsingle;
	// Note: this always adds between 1 and N_ALPHA_PAD elements of padding
	// (a full N_ALPHA_PAD when nsingle is already a multiple). Keep as-is:
	// the offset presumably has to match the layout the caller used for x/g.
	int nsingle_padded = nsingle + N_ALPHA_PAD - (nsingle % N_ALPHA_PAD);

	conjugrad_float_t lambda_single = ud->lambda_single;
	conjugrad_float_t lambda_pair = ud->lambda_pair;

	const conjugrad_float_t *x1 = x;
	const conjugrad_float_t *x2 = &x[nsingle_padded];

	conjugrad_float_t *g1 = g;
	conjugrad_float_t *g2l = &g[nsingle_padded];

	// scratch buffer for the un-symmetrized pairwise gradient (see init_cpu)
	conjugrad_float_t *g2 = udx->g2;

	// set fx and gradient to 0 initially
	conjugrad_float_t fx = F0; // 0.0

	memset(g, 0, sizeof(conjugrad_float_t) * nvar_padded);
	memset(g2, 0, sizeof(conjugrad_float_t) * (nvar_padded - nsingle_padded));

	for(int i = 0; i < nrow; i++) {
		conjugrad_float_t weight = ud->weights[i];

		// NOTE: the PrC/PrCN macros refer to these array names; do not rename.
		conjugrad_float_t precomp[N_ALPHA * ncol] __attribute__ ((aligned (32)));	// aka PrC(ncol,a,s)
		conjugrad_float_t precomp_sum[ncol] __attribute__ ((aligned (32)));
		conjugrad_float_t precomp_max[ncol] __attribute__ ((aligned (32)));
		conjugrad_float_t precomp_norm[N_ALPHA * ncol] __attribute__ ((aligned (32)));	// aka PrCN(ncol,a,s)

		// compute PrC(a,s) = V_s(a) + sum(k \in V_s) w_{sk}(a, X^i_k)
		for(int a = 0; a < N_ALPHA-1; a++) {
			for(int s = 0; s < ncol; s++) {
				PrC(ncol,a,s) = VV(x1,s,a,ncol);
			}
		}
		// the last state (a = N_ALPHA-1) has no single-site parameter
		for(int s = 0; s < ncol; s++) { PrC(ncol,N_ALPHA - 1, s)=0; }
		for(int k = 0; k < ncol; k++) {
			unsigned char xik = XX(ud,i,k,ncol);
			const conjugrad_float_t *w = &WW(x2,xik,k,0,0,ncol);
			conjugrad_float_t *p = &PrC(ncol,0, 0);
			for(int j = 0; j < N_ALPHA * ncol; j++) {
				*p++ += *w++;
			}
		}

		// compute precomp_sum(s) = log( sum(a) exp(PrC(ncol,a,s)) ) with the
		// log-sum-exp trick. Subtracting the per-column maximum before
		// exponentiating keeps fexp() from overflowing to inf when PrC is
		// large, which would otherwise turn fx and the gradient into inf/NaN.
		for(int s = 0; s < ncol; s++) {
			precomp_max[s] = PrC(ncol,0,s);
		}
		for(int a = 1; a < N_ALPHA; a++) {
			for(int s = 0; s < ncol; s++) {
				if(PrC(ncol,a,s) > precomp_max[s]) {
					precomp_max[s] = PrC(ncol,a,s);
				}
			}
		}
		memset(precomp_sum, 0, sizeof(conjugrad_float_t) * ncol);
		for(int a = 0; a < N_ALPHA; a++) {
			for(int s = 0; s < ncol; s++) {
				precomp_sum[s] += fexp(PrC(ncol,a,s) - precomp_max[s]);
			}
		}
		for(int s = 0; s < ncol; s++) {
			precomp_sum[s] = precomp_max[s] + flog(precomp_sum[s]);
		}
		// PrCN(a,s) = softmax over a of PrC(.,s)
		for(int a = 0; a < N_ALPHA; a++) {
			for(int s = 0; s < ncol; s++) {
				PrCN(ncol,a,s) = fexp(PrC(ncol,a, s) - precomp_sum[s]);
			}
		}

		// actually compute fx and gradient
		for(int k = 0; k < ncol; k++) {
			unsigned char xik = XX(ud,i,k,ncol);
			fx += weight * (-PrC(ncol,xik, k ) + precomp_sum[k]);
			if(xik < N_ALPHA - 1) {
				GG1(g1,k,xik,ncol) -= weight;
			}
			for(int a = 0; a < N_ALPHA - 1; a++) {
				GG1(g1,k,a,ncol) += weight * PrCN(ncol,a,k);
			}
		}
		for(int k = 0; k < ncol; k++) {
			unsigned char xik = XX(ud,i,k,ncol);
			for(int j = 0; j < ncol; j++) {
				int xij = XX(ud,i,j,ncol);
				GG2(g2,xik,k,xij,j,ncol) -= weight;
			}
			conjugrad_float_t *pg = &GG2(g2,xik,k,0,0,ncol);
			conjugrad_float_t *pp = &PrCN(ncol,0, 0);
			for(int j = 0; j < N_ALPHA * ncol; j++) {
				*pg++ += weight * *pp++;
			}
		}
	} // i

	// add transposed onto un-transposed
	for(int b = 0; b < N_ALPHA; b++) {
		for(int k = 0; k < ncol; k++) {
			for(int a = 0; a < N_ALPHA; a++) {
				for(int j = 0; j < ncol; j++) {
					GG2L(g2l,b,k,a,j,ncol) =
						GG2(g2,b,k,a,j,ncol) + GG2(g2,a,j,b,k,ncol);
				}
			}
		}
	}

	// set gradients to zero for self-edges
	for(int b = 0; b < N_ALPHA; b++) {
		for(int k = 0; k < ncol; k++) {
			for(int a = 0; a < N_ALPHA; a++) {
				GG2L(g2l,b,k,a,k,ncol)=0;
			}
		}
	}

	// regularization
	conjugrad_float_t reg = F0; // 0.0
	for(int v = 0; v < nsingle; v++) {
		reg += lambda_single * x[v] * x[v];
		g[v] += F2 * lambda_single * x[v]; // F2 is 2.0
	}
	for(int v = nsingle_padded; v < nvar_padded; v++) {
		reg += F05 * lambda_pair * x[v] * x[v]; // F05 is 0.5
		g[v] += F2 * lambda_pair * x[v]; // F2 is 2.0
	}
	fx += reg;
	return fx;
}

int 	dca_typ::init_cpu(void* instance)
{

	usrdata *ud = (usrdata *)instance;
	extra_usrdata *udx = (extra_usrdata *)malloc(sizeof(extra_usrdata));
	ud->extra = udx;

	int ncol = ud->ncol;
	udx->g2 = conjugrad_malloc(ncol * ncol * N_ALPHA * N_ALPHA_PAD);
	if(udx->g2 == NULL) {
		die("ERROR: Not enough memory to allocate temp g2 matrix!");
	}
	if(ud->reweighting_threshold != F1) {
		calculate_weights(ud->weights, ud->msa, ncol, ud->nrow,
			ud->reweighting_threshold);
	} else { uniform_weights(ud->weights, ud->nrow); }
	return true;
}

/** Release the scratch state allocated by init_cpu.
 *
 * Frees the temporary g2 buffer and the extra_usrdata itself, then clears
 * ud->extra. A repeated call (or a call when ud->extra is NULL) is therefore
 * a harmless no-op instead of a double free or NULL dereference.
 *
 * @param[in,out] instance The usrdata * previously passed to init_cpu
 * @return EXIT_SUCCESS
 */
int	dca_typ::destroy_cpu(void* instance)
{
	usrdata *ud = (usrdata *)instance;
	extra_usrdata *udx = (extra_usrdata *)ud->extra;

	if(udx == NULL) {
		return EXIT_SUCCESS; // nothing allocated, or already destroyed
	}

	conjugrad_free(udx->g2);
	free(udx);
	ud->extra = NULL; // avoid leaving a dangling pointer behind
	return EXIT_SUCCESS;
}

