#include <stdbool.h>
#include <stdio.h>
#include <limits.h>
#include <stdlib.h>
#include <stdint.h>
#include <libgen.h>
#include <string.h>
#include <locale.h>
#include <math.h>

#include "dca_typ.h"
#include "conjugrad.h"

#ifdef CUDA
#include <cuda.h>
#include <cuda_runtime_api.h>
// #include "evaluate_cuda.h"
#endif

#ifdef OPENMP
#include <omp.h>
// #include "evaluate_cpu_omp.h"
#endif

// #include "evaluate_cpu.h"

void	dca_typ::init_bias(conjugrad_float_t *x, usrdata *ud)
/** Seed the single-column entries of x from the per-column symbol frequencies of the MSA.
 * The whole padded parameter vector is zeroed first, so every entry that is not
 * written through VV() below stays 0.
 * @param[out] x  Parameter vector; must hold at least nvar_padded entries (computed below)
 * @param[in]  ud User data: ud->ncol and ud->nrow are read here, and ud itself is
 *                handed to the XX() macro to look up alignment cells
 *
 * NOTE(review): XX() and VV() are project macros not visible in this file;
 * presumably XX yields the symbol code at (row, column) and VV addresses the
 * single entry for (column, symbol) -- confirm against their definitions.
 */
{
	const int ncol = ud->ncol;
	const int nrow = ud->nrow;

	// length of the padded parameter vector: padded singles followed by padded pair terms
	const int nsingle = ncol * (N_ALPHA - 1);
	const int nsingle_padded = nsingle + N_ALPHA_PAD - (nsingle % N_ALPHA_PAD);
	const int nvar_padded = nsingle_padded + ncol * ncol * N_ALPHA * N_ALPHA_PAD;

	memset(x, 0, sizeof(conjugrad_float_t) * nvar_padded);

	// each column has nrow observations plus one pseudocount per symbol
	const int total = nrow + N_ALPHA;

	for(int col = 0; col < ncol; col++) {

		// start every symbol at 1 (pseudocount), then tally the column
		int counts[N_ALPHA];
		for(int sym = 0; sym < N_ALPHA; sym++) {
			counts[sym] = 1;
		}
		for(int row = 0; row < nrow; row++) {
			counts[XX(ud,row,col,ncol)]++;
		}

		// relative frequency of every symbol in this column
		conjugrad_float_t freq[N_ALPHA];
		for(int sym = 0; sym < N_ALPHA; sym++) {
			freq[sym] = ((conjugrad_float_t)counts[sym]) / total;
		}

		// log-frequencies are stored relative to the entry at index 20
		// (the gap symbol according to the original comment) to avoid degeneracy
		const conjugrad_float_t reference = flog(freq[20]);
		for(int sym = 0; sym < N_ALPHA - 1; sym++) {
			VV(x,col,sym,ncol) = flog( freq[sym] ) - reference;
		}
	}
}

void dca_typ::logo(bool color)
// modified by AFN: 7_27_2022.
// Print the start-up banner to stdout. A function-local static counter makes
// sure only the first call prints anything; every later call returns at once.
// color == true  : variant with ANSI colour escapes and the __VERSION string.
// color == false : plain-ASCII variant (no version string).
{
	static int calls = 0;
	if(calls != 0) { return; }

	// the top line is the same for both variants
	fputs(" _____ _____ _____               _ \n", stdout);

	if(!color) {
		fputs("|     |     |     |___ ___ ___ _| |\n"
		      "|   --|   --| | | | . |  _| -_| . |\n"
		      "|_____|_____|_|_|_|  _|_| |___|___|\n"
		      "                  |_|              \n\n", stdout);
	} else {
		fputs("|\x1b[30;42m     |     |     |\x1b[0m___ ___ ___ _|\x1b[30;44m |\x1b[0m\n"
		      "|\x1b[30;42m   --|   --| | | |\x1b[44m . |  _| -_| . |\x1b[0m\n", stdout);
		printf("|\x1b[30;42m_____|_____|_|_|_|\x1b[44m  _|_|\x1b[0m \x1b[30;44m|___|___|\x1b[0m version %s\n", __VERSION);
		fputs("                  |\x1b[30;44m_|\x1b[0m\n\n", stdout);
	}

	calls++;	// remember that the banner has been shown
}

char* dca_typ::concatenate(const char *s1, const char *s2)
// Build a new heap string consisting of s1 immediately followed by s2.
// The buffer comes from malloc(); the caller owns it and releases it with free().
// An allocation failure is reported through die().
{
	const size_t head_len = strlen(s1);
	const size_t tail_len = strlen(s2);

	char *joined = (char*) malloc(head_len + tail_len + 1);
	if(!joined) {
		die("Cannot malloc new string!");
	}

	char *tail_dst = joined + head_len;
	memcpy(joined, s1, head_len);
	// tail_len + 1 also copies the terminating NUL of s2
	memcpy(tail_dst, s2, tail_len + 1);
	return joined;
}

void 	dca_typ::usage(char* exename, int long_usage)
/**
 * Print a usage message to stdout and terminate the process.
 * This function never returns: it always ends with exit(1), also when the
 * user explicitly asked for help.
 * @param[in] exename The name of the executable
 * @param[in] long_usage 0: print only the one-line synopsis, non-zero: also list all options
 */
{
	printf("Usage: %s [options] input.aln output.mat\n\n", exename);

	if(long_usage) {
		printf("Options:\n");
#ifdef CUDA
		printf("\t-d DEVICE \tCalculate on CUDA device number DEVICE (set to -1 to use CPU) [default: 0]\n");
#endif
#ifdef OPENMP
		printf("\t-t THREADS\tCalculate using THREADS threads on the CPU (automatically disables CUDA if available) [default: 1]\n");
#endif
		// The default shown here must match the initial value of numiter in Run(), which is 250.
		printf("\t-n NUMITER\tCompute a maximum of NUMITER operations [default: 250]\n");
		printf("\t-e EPSILON\tSet convergence criterion for minimum decrease in the last K iterations to EPSILON [default: 0.01]\n");
		printf("\t-k LASTK  \tSet K parameter for convergence criterion to LASTK [default: 5]\n");

		printf("\n");
		printf("\t-i INIFILE\tRead initial weights from INIFILE\n");
		printf("\t-r RAWFILE\tStore raw prediction matrix in RAWFILE\n");

		printf("\n");
		printf("\t-w IDTHRES\tSet sequence reweighting identity threshold to IDTHRES [default: 0.8]\n");
		printf("\t-l LFACTOR\tSet pairwise regularization coefficients to LFACTOR * (L-1) [default: 0.2]\n");
		printf("\t-A        \tDisable average product correction (APC)\n");
		printf("\t-R        \tRe-normalize output matrix to [0,1]\n");
		printf("\t-h        \tDisplay help\n");

		printf("\n");
		printf("\n\n");
	}
	exit(1);
}

int	dca_typ::Run( )
// int argc, char **argv) // now globally defined for object dca_typ
{
	char *rawfilename = NULL;
	int numiter = 250;
	int use_apc = 1;
	int use_normalization = 0;
	conjugrad_float_t lambda_single = F001; // 0.01
	conjugrad_float_t lambda_pair = FInf;
	conjugrad_float_t lambda_pair_factor = F02; // 0.2
	int conjugrad_k = 5;
	conjugrad_float_t conjugrad_eps = 0.01;

	parse_options *optList, *thisOpt;

	char *optstr;
	char *old_optstr = (char*) malloc(1);
	old_optstr[0] = 0;
	optstr = this->concatenate("r:i:n:w:k:e:l:ARh?", old_optstr);
	free(old_optstr);

#ifdef OPENMP
	int numthreads = 1;
	old_optstr = optstr;
	optstr = this->concatenate("t:", optstr);
	free(old_optstr);
#endif

#ifdef CUDA
	int use_def_gpu = 0;
	old_optstr = optstr;
	optstr = this->concatenate("d:", optstr);
	free(old_optstr);
#endif

	optList = parseopt(argc, argv, optstr);
	free(optstr);

	char* msafilename = NULL;
	char* matfilename = NULL;
	char* initfilename = NULL;

	conjugrad_float_t reweighting_threshold = F08; // 0.8

	while(optList != NULL) {
		thisOpt = optList;
		optList = optList->next;

		switch(thisOpt->option) {
#ifdef OPENMP
			case 't':
				numthreads = atoi(thisOpt->argument);

#ifdef CUDA
				use_def_gpu = -1; // automatically disable GPU if number of threads specified
#endif
				break;
#endif
#ifdef CUDA
			case 'd':
				use_def_gpu = atoi(thisOpt->argument);
				break;
#endif
			case 'r':
				rawfilename = thisOpt->argument;
				break;
			case 'i':
				initfilename = thisOpt->argument;
				break;
			case 'n':
				numiter = atoi(thisOpt->argument);
				break;
			case 'w':
				reweighting_threshold = (conjugrad_float_t)atof(thisOpt->argument);
				break;
			case 'l':
				lambda_pair_factor = (conjugrad_float_t)atof(thisOpt->argument);
				break;
			case 'k':
				conjugrad_k = (int)atoi(thisOpt->argument);
				break;
			case 'e':
				conjugrad_eps = (conjugrad_float_t)atof(thisOpt->argument);
				break;
			case 'A':
				use_apc = 0;
				break;
			case 'R': use_normalization = 1; break;
			case 'h':
			case '?': usage(argv[0], 1); break;
			case 0:
				if(msafilename == NULL) {
					msafilename = thisOpt->argument;
				} else if(matfilename == NULL) {
					matfilename = thisOpt->argument;
				} else { usage(argv[0], 0); }
				break;
			default:
				die("Unknown argument"); 
		} free(thisOpt);
	}
	if(msafilename == NULL || matfilename == NULL) { usage(argv[0], 0); }
	FILE *msafile = fopen(msafilename, "r");
	if( msafile == NULL) {
		printf("Cannot open %s!\n\n", msafilename);
		return 2;
	}

	int ncol, nrow;
	unsigned char* msa = read_msa(msafile, &ncol, &nrow);
	fclose(msafile);

	int nsingle = ncol * (N_ALPHA - 1);
	int nvar = nsingle + ncol * ncol * N_ALPHA * N_ALPHA;
	int nsingle_padded = nsingle + N_ALPHA_PAD - (nsingle % N_ALPHA_PAD);
	int nvar_padded = nsingle_padded + ncol * ncol * N_ALPHA * N_ALPHA_PAD;

	bool color = false;
	logo(color);

#ifdef CUDA
	int num_devices, dev_ret;
	struct cudaDeviceProp prop;
	dev_ret = cudaGetDeviceCount(&num_devices);
	if(dev_ret != CUDA_SUCCESS) { num_devices = 0; }

#if 1	// AFN: 7_27_2022.
static int calls=0;
#endif

	if(num_devices == 0) {
		printf("No CUDA devices available, ");
		use_def_gpu = -1;
	} else if (use_def_gpu < -1 || use_def_gpu >= num_devices) {
		printf("Error: %d is not a valid device number. ",use_def_gpu);
		printf("Please choose a number between 0 and %d\n",num_devices - 1);
		exit(1);
	} else {
#if 1	// AFN: 7_27_2022.
	if(calls==0) printf("Found %d CUDA devices, ", num_devices);
#else
		printf("Found %d CUDA devices, ", num_devices);
#endif
	}

	if (use_def_gpu != -1) {
	   cudaError_t err = cudaSetDevice(use_def_gpu);
	   if(cudaSuccess != err) {
		printf("Error setting device: %d\n", err);
		exit(1);
	   }
	   cudaGetDeviceProperties(&prop, use_def_gpu);
	   printf("using device #%d: %s\n", use_def_gpu, prop.name);

	   size_t mem_free, mem_total;
	   err = cudaMemGetInfo(&mem_free, &mem_total);
	   if(cudaSuccess != err) {
		printf("Error getting memory info: %d\n", err);
		exit(1);
	   }

	   size_t mem_needed = nrow * ncol * 2 + // MSAs
	   sizeof(conjugrad_float_t) * nrow * ncol * 2 + // PC, PCS
	   sizeof(conjugrad_float_t) * nrow * ncol * N_ALPHA_PAD + // PCN
	   sizeof(conjugrad_float_t) * nrow + // Weights
	   (sizeof(conjugrad_float_t) * ((N_ALPHA - 1) * ncol + ncol * ncol * N_ALPHA * N_ALPHA_PAD)) * 4;

	   setlocale(LC_NUMERIC, "");
#if 1 // AFN: 7_27_2022.
	   if(calls == 0){
	     printf("Total GPU RAM:  %'17lu\n", mem_total);
	     printf("Free GPU RAM:   %'17lu\n", mem_free);
	     printf("Needed GPU RAM: %'17lu ", mem_needed);
	     calls++;
   	   }
#endif

	   if(mem_needed <= mem_free) { printf("✓\n"); }
	   else { printf("⚠\n"); }

	} else {
		printf("using CPU");
#ifdef OPENMP
		printf(" (%d thread(s))", numthreads);
#endif
		printf("\n");

	}
#else // end if(CUDA)
	printf("using CPU");
#ifdef OPENMP
	printf(" (%d thread(s))\n", numthreads);
#endif // OPENMP
	printf("\n");
#endif // end else CUDA

	conjugrad_float_t *x = conjugrad_malloc(nvar_padded);
	if( x == NULL) {
		die("ERROR: Not enough memory to allocate variables!");
	}
	memset(x, 0, sizeof(conjugrad_float_t) * nvar_padded);

	// Auto-set lambda_pair
	if(isnan(lambda_pair)) {
		lambda_pair = lambda_pair_factor * (ncol - 1);
	}

	// fill up user data struct for passing to evaluate
	usrdata *ud = (usrdata *)malloc( sizeof(usrdata) );
	if(ud == 0) { die("Cannot allocate memory for user data!"); }
	ud->msa = msa;
	ud->ncol = ncol;
	ud->nrow = nrow;
	ud->nsingle = nsingle;
	ud->nvar = nvar;
	ud->lambda_single = lambda_single;
	ud->lambda_pair = lambda_pair;
	ud->weights = conjugrad_malloc(nrow);
	ud->reweighting_threshold = reweighting_threshold;

	if(initfilename == NULL) {
		// Initialize emissions to pwm
		init_bias(x, ud);
	} else {
		// Load potentials from file
		read_raw(initfilename, ud, x);
	}

	// optimize with default parameters
	conjugrad_parameter_t *param = conjugrad_init();

	param->max_iterations = numiter;
	param->epsilon = conjugrad_eps;
	param->k = conjugrad_k;
	param->max_linesearch = 5;
	param->alpha_mul = F05;
	param->ftol = 1e-4;
	param->wolfe = F02;
	char	mode='C';	// cpu
#ifdef OPENMP
	omp_set_dynamic(0);
	omp_set_num_threads(numthreads);
	if(numthreads > 1) {
		mode='O';	// omp
	}
#endif
#ifdef CUDA
	if(use_def_gpu != -1) {
		mode='G';	// gpu
	}
#endif
	switch(mode){
	   case 'C': init_cpu(ud); break;
	   case 'G': init_cuda(ud); break;
	   case 'O': init_cpu_omp(ud); break;
	   default: fprintf(stderr,"dca_typ mode error\n"); exit(1); break;
	}
	printf("\nWill optimize %d %ld-bit variables\n\n", nvar, 
			sizeof(conjugrad_float_t) * 8);

	if(color) { printf("\x1b[1m"); }
	printf("iter\teval\tf(x)    \t║x║     \t║g║     \tstep\n");
	if(color) { printf("\x1b[0m"); }

	conjugrad_float_t fx;
	int ret;
#ifdef CUDA
	if(use_def_gpu != -1) {
		conjugrad_float_t *d_x;
		cudaError_t err = cudaMalloc((void **) &d_x, 
				sizeof(conjugrad_float_t) * nvar_padded);
		if (cudaSuccess != err) {
			printf("CUDA error No. %d while allocation memory for d_x\n",
				err);
			exit(1);
		}
		err = cudaMemcpy(d_x, x,
			sizeof(conjugrad_float_t) * nvar_padded, cudaMemcpyHostToDevice);
		if (cudaSuccess != err) {
		    printf("CUDA error No. %d while copying parameters to GPU\n",
				err); exit(1);
		}
		ret = conjugrad_gpu(nvar_padded, d_x, &fx, evaluate_cuda, ud, param);
		err = cudaMemcpy(x, d_x, sizeof(conjugrad_float_t) * nvar_padded,
					cudaMemcpyDeviceToHost);
		if (cudaSuccess != err) {
		   printf("CUDA error No. %d while copying parameters back to CPU\n",
			err); exit(1);
		}
		err = cudaFree(d_x);
		if (cudaSuccess != err) {
		   printf("CUDA error No. %d while freeing memory for d_x\n", err);
		   exit(1);
		}
	} else if(mode=='O'){
long time1 = time(NULL);
	   ret=conjugrad(nvar_padded, x, &fx, evaluate_cpu_omp, ud, param);
fprintf(stderr, "\ttime conjugrad(): %d seconds (%0.2f minutes)\n",
                        time(NULL)-time1,(float)(time(NULL)-time1)/60.0);
	} else {
long time1 = time(NULL);
	   ret=conjugrad(nvar_padded, x, &fx, evaluate_cpu, ud, param);
fprintf(stderr, "\ttime one cpu: %d seconds (%0.2f minutes)\n",
                        time(NULL)-time1,(float)(time(NULL)-time1)/60.0);
	}
#else	// core DCA routine (in conjugrad.c) absent CUDA: afn5_15_23.
	if(mode=='O'){
	   ret=conjugrad(nvar_padded, x, &fx, evaluate_cpu_omp, ud, param);
	} else {
	   ret=conjugrad(nvar_padded, x, &fx, evaluate_cpu, ud, param);
	}
#endif
	printf("\n");
	printf("%s with status code %d - ", (ret < 0 ? "Exit" : "Done"), ret);
	if(ret == CONJUGRAD_SUCCESS) {
		printf("Success!\n");
	} else if(ret == CONJUGRAD_ALREADY_MINIMIZED) {
		printf("Already minimized!\n");
	} else if(ret == CONJUGRADERR_MAXIMUMITERATION) {
		printf("Maximum number of iterations reached.\n");
	} else {
		printf("Unknown status code!\n");
	}
	printf("\nFinal fx = %f\n\n", fx);
	FILE* out = fopen(matfilename, "w");
	if(out == NULL) {
		printf("Cannot open %s for writing!\n\n", matfilename);
		return 3;
	}
	conjugrad_float_t *outmat = conjugrad_malloc(ncol * ncol);
	FILE *rawfile = NULL;
	if(rawfilename != NULL) {
		printf("Writing raw output to %s\n", rawfilename);
		rawfile = fopen(rawfilename, "w");
		if(rawfile == NULL) {
			printf("Cannot open %s for writing!\n\n", rawfilename);
			return 4;
		}
		write_raw(rawfile, x, ncol);
	}
	sum_submatrices(x, outmat, ncol);
	if(use_apc) { apc(outmat, ncol); }
	if(use_normalization) { normalize(outmat, ncol); }
	write_matrix(out, outmat, ncol, ncol);
	if(rawfile != NULL) { fclose(rawfile); }
	fflush(out); fclose(out);
	switch(mode){
	   case 'C': destroy_cpu(ud); break;
	   case 'G': destroy_cuda(ud); break;
	   case 'O': destroy_cpu_omp(ud); break;
	   default: fprintf(stderr,"dca_typ mode error\n"); exit(1); break;
	}
	conjugrad_free(outmat); conjugrad_free(x); conjugrad_free(ud->weights);
	free(ud); free(msa); free(param);
	printf("Output can be found in %s\n", matfilename);
	return 0;
}

