#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <vector>

#include "conjugrad.h"

#include "dca_typ.h"

/**
 * Print an error message to stderr and terminate the process.
 *
 * The message is preceded by a newline and followed by a blank line so it
 * stands out from surrounding output; the process exits with status 1.
 *
 * @param[in] message The error message to display
 */
void dca_typ::die(const char* message)
{
	const int exit_status = 1;
	fprintf(stderr, "\nERROR: %s\n\n", message);
	exit(exit_status);
}

/**
 * Collapse every amino-acid pairing submatrix into a single score per column pair.
 *
 * The coupling parameters start after the single-site block of
 * ncol * (N_ALPHA - 1) values. That block is padded to the next multiple of
 * N_ALPHA_PAD strictly above its length (an already aligned block still
 * receives a full N_ALPHA_PAD of padding). The offset must match how the
 * caller laid out x. Individual couplings are read through the WW macro.
 *
 * For each column pair (k, j):
 *  - take the mean of the full N_ALPHA x N_ALPHA submatrix;
 *  - subtract it from the leading (N_ALPHA - 1) x (N_ALPHA - 1) entries;
 *  - the score is the square root of the sum of squared deviations.
 * The diagonal of the result is set to 0.0.
 *
 * As a diagnostic, the norm over all blocks with k < j is printed to stdout
 * as "xnorm = ...".
 *
 * @param[in] x The parameter vector: padded single-site block followed by the N_ALPHAxLxN_ALPHAxL couplings
 * @param[out] out The LxL row-major matrix of column-pair scores
 * @param[in] ncol The number of columns in the output matrix (i.e. L)
 */
void dca_typ::sum_submatrices(conjugrad_float_t *x, conjugrad_float_t *out, int ncol)
{
	const int single_len = ncol * (N_ALPHA - 1);
	const int single_len_padded = single_len + N_ALPHA_PAD - (single_len % N_ALPHA_PAD);
	conjugrad_float_t *pair_params = x + single_len_padded;

	memset(out, 0, sizeof(conjugrad_float_t) * ncol * ncol);

	// Diagnostic only: norm of every coupling block between distinct columns (k < j).
	conjugrad_float_t sq_norm = 0;
	for(int k = 0; k < ncol; ++k) {
		for(int j = k + 1; j < ncol; ++j) {
			for(int a = 0; a < N_ALPHA; ++a) {
				for(int b = 0; b < N_ALPHA; ++b) {
					const conjugrad_float_t v = WW(pair_params, b, k, a, j, ncol);
					sq_norm += v * v;
				}
			}
		}
	}
	printf("xnorm = %g\n", sqrt(sq_norm));

	for(int k = 0; k < ncol; ++k) {
		conjugrad_float_t *row = out + k * ncol;

		for(int j = 0; j < ncol; ++j) {
			// The mean is taken over the complete N_ALPHA x N_ALPHA block ...
			conjugrad_float_t block_mean = 0;
			for(int a = 0; a < N_ALPHA; ++a) {
				for(int b = 0; b < N_ALPHA; ++b) {
					block_mean += WW(pair_params, b, k, a, j, ncol);
				}
			}
			block_mean /= (N_ALPHA * N_ALPHA);

			// ... but the centred sum of squares leaves out index N_ALPHA - 1.
			conjugrad_float_t sq_dev = 0;
			for(int a = 0; a < N_ALPHA - 1; ++a) {
				for(int b = 0; b < N_ALPHA - 1; ++b) {
					const conjugrad_float_t d = WW(pair_params, b, k, a, j, ncol) - block_mean;
					sq_dev += d * d;
				}
			}

			row[j] = sqrt(sq_dev);
		}

		row[k] = F0; // 0.0: a column is not scored against itself
	}
}

/** Average product correction (APC)
 *
 * Every element has the background term means[i] * means[j] / meansum
 * subtracted, where:
 *  - means[j] is the mean of column j;
 *  - meansum is the mean over all elements.
 *
 * Column means serve as both the row and the column factor, so the matrix is
 * presumably symmetric. NOTE(review): this holds for sum_submatrices output;
 * confirm it for other callers.
 *
 * Afterwards the matrix is shifted so that its smallest strictly
 * upper-triangular element becomes 0.0, and the diagonal is set to 0.0.
 * When meansum is zero the correction term is undefined, so only the
 * shift and the diagonal reset are applied.
 *
 * @param[in,out] mat The ncol x ncol row-major matrix to process
 * @param[in] ncol The number of columns in the matrix
 */
void dca_typ::apc(conjugrad_float_t *mat, int ncol)
{
	if(ncol <= 0) {
		return;
	}

	// Heap storage: variable-length arrays are not standard C++, and a large
	// ncol could overflow the stack. Elements are value-initialized to 0.
	std::vector<conjugrad_float_t> means(static_cast<size_t>(ncol));

	conjugrad_float_t meansum = 0;
	for(int i = 0; i < ncol; i++) {
		for(int j = 0; j < ncol; j++) {
			const conjugrad_float_t w = mat[i * ncol + j];
			means[j] += w / ncol;
			meansum += w;
		}
	}
	// Multiply in floating point so that ncol * ncol cannot overflow int.
	meansum /= static_cast<conjugrad_float_t>(ncol) * ncol;

	// Dividing by a zero overall mean would fill the matrix with inf/NaN.
	if(meansum != 0) {
		for(int i = 0; i < ncol; i++) {
			for(int j = 0; j < ncol; j++) {
				mat[i * ncol + j] -= (means[i] * means[j]) / meansum;
			}
		}
	}

	// Smallest element strictly above the diagonal. It stays INFINITY only
	// when ncol == 1; the single element is then the diagonal, which is reset
	// to 0 below.
	conjugrad_float_t min_wo_diag = INFINITY;
	for(int i = 0; i < ncol; i++) {
		for(int j = i + 1; j < ncol; j++) {
			if(mat[i * ncol + j] < min_wo_diag) {
				min_wo_diag = mat[i * ncol + j];
			}
		}
	}

	for(int i = 0; i < ncol; i++) {
		for(int j = 0; j < ncol; j++) {
			mat[i * ncol + j] -= min_wo_diag;
		}

		mat[i * ncol + i] = 0;
	}
}

/** Linearly re-scale the matrix
 *
 * All off-diagonal elements are rescaled so that the maximum becomes 1.0 and
 * the minimum becomes 0.0. The diagonal is ignored when determining the value
 * range and is set to 0.0 afterwards.
 *
 * Edge cases:
 *  - A matrix with fewer than two columns has no off-diagonal elements, so
 *    only its diagonal is zeroed.
 *  - If all off-diagonal elements are equal the range is zero. Debug builds
 *    fail the assertion; NDEBUG builds map every off-diagonal element to 0.0
 *    instead of producing NaN.
 *
 * @param[in,out] mat The ncol x ncol row-major matrix to process
 * @param[in] ncol The number of columns in the matrix
 */
void dca_typ::normalize(conjugrad_float_t *mat, int ncol)
{
	if(ncol < 2) {
		// No off-diagonal elements: nothing to rescale. Do not read past the
		// end of the matrix.
		for(int i = 0; i < ncol; i++) {
			mat[i * ncol + i] = F0; // 0.0
		}
		return;
	}

	conjugrad_float_t min = INFINITY;
	conjugrad_float_t max = -INFINITY;
	for(int i = 0; i < ncol; i++) {
		for(int j = 0; j < ncol; j++) {
			if(i == j) { continue; }

			const conjugrad_float_t x = mat[i * ncol + j];
			if(x < min) { min = x; }
			if(x > max) { max = x; }
		}
	}

	conjugrad_float_t range = max - min;
	assert(range != 0);
	if(range == 0) {
		// Without assertions (NDEBUG) avoid 0/0 = NaN: (x - min) is 0 for every
		// off-diagonal element, so they all map to 0.0.
		range = 1;
	}

	for(int i = 0; i < ncol; i++) {
		for(int j = 0; j < ncol; j++) {
			if(i == j) {
				mat[i * ncol + j] = F0; // 0.0
			} else {
				mat[i * ncol + j] = (mat[i * ncol + j] - min) / range;
			}
		}
	}
}
