// #include <Rcpp.h>
#include <cmath>
#include <unordered_map>

#include <RcppArmadillo.h>
#include "beachmat/numeric_matrix.h"
#include "beachmat/integer_matrix.h"

using namespace Rcpp;
// [[Rcpp::depends(RcppArmadillo)]]



// Factor that every Cox-Reid adjustment term in this file is multiplied with.
// This correction factor is necessary to avoid estimates of theta that are
// basically +Inf. The problem is that for some combinations of y, mu, and X
// the terms lgamma(1/theta) and log(det(t(X) %*% W %*% X)) with
// W = diag(1/(1/mu + theta)) cancelled each other out exactly for large theta.
// Scaling the adjustment by a value slightly below 1 is meant to break that
// exact cancellation.
const double cr_correction_factor = 0.99;


// Tabulate the distinct values of `x`, giving up early if there are too many.
//
//  x               values to tabulate
//  stop_if_larger  maximum number of distinct values that is still tabulated;
//                  negative values are treated like 0
//
// Returns an unnamed list with two numeric vectors of equal length: the
// distinct values and the number of times each one occurs. The order of the
// entries is unspecified, but it is the same in both vectors.
// Two empty vectors are returned if `x` has more than `stop_if_larger`
// distinct values or contains NA/NaN. The likelihood functions in this file
// treat an empty table as "no table available" and work on the raw values.
// [[Rcpp::export]]
List make_table_if_small(const NumericVector& x, int stop_if_larger){
  // A negative limit would wrap around to a huge value once it is converted
  // to the unsigned type used by the container.
  const std::size_t max_unique = stop_if_larger > 0 ? static_cast<std::size_t>(stop_if_larger) : 0;

  // The values themselves are the keys: casting to an integer type would merge
  // non-integer values into the wrong bucket and is undefined for values that
  // do not fit into the integer type.
  std::unordered_map<double, std::size_t> counts;
  counts.reserve(max_unique);
  for (double v : x){
    // NA/NaN never compares equal to itself and thus cannot be used as a key
    if(ISNAN(v)){
      return List::create(NumericVector::create(), NumericVector::create());
    }
    ++counts[v];
    if(counts.size() > max_unique){
      return List::create(NumericVector::create(), NumericVector::create());
    }
  }

  NumericVector keys(counts.size());
  NumericVector values(counts.size());
  // Fill both vectors in a single traversal so that the pairing of value and
  // frequency does not depend on two traversals visiting the same order.
  R_xlen_t pos = 0;
  for(const auto& entry : counts){
    keys[pos] = entry.first;
    values[pos] = static_cast<double>(entry.second);
    ++pos;
  }
  return List::create(keys, values);
}


//--------------------------------------------------------------------------------------------------//
// The following code was originally copied from https://github.com/mikelove/DESeq2/blob/master/src/DESeq2.cpp
// I adapted it to the needs of this project by:
//  * renaming alpha -> theta for consistency
//  * removing the part for the prior on theta
//  * renaming x -> model_matrix
//  * additional small changes
//  * adding capability to calculate digamma/trigamma only
//    on unique counts



/*
 * DESeq2 C++ functions
 *
 * Author: Michael I. Love, Constantin Ahlmann-Eltze
 * Last modified: May 21, 2020
 * License: LGPL (>= 3)
 *
 * Note: The canonical, up-to-date DESeq2.cpp lives in
 * the DESeq2 library, the development branch of which
 * can be viewed here:
 *
 * https://github.com/mikelove/DESeq2/blob/master/src/DESeq2.cpp
 */



// Log-likelihood of the negative binomial model as a function of the
// overdispersion theta = exp(log_theta), optionally with Cox-Reid adjustment.
// Terms that do not depend on theta are left out.
//
//  y, mu              counts and their expected means
//  log_theta          log of the overdispersion parameter
//  model_matrix       design matrix X, only used for the Cox-Reid adjustment
//  do_cr_adj          if true, add -0.5 * log(det(t(X) W X)) * cr_correction_factor
//                     with W = diag(1 / (1/mu + theta))
//  unique_counts,
//  count_frequencies  optional table of the distinct values of y and how often
//                     each occurs. If both are non-empty and of equal length,
//                     sum(lgamma(y + 1/theta)) is evaluated on the table
//                     instead of on every element of y.

// [[Rcpp::export]]
double conventional_loglikelihood_fast(NumericVector y, NumericVector mu, double log_theta, const arma::mat& model_matrix, bool do_cr_adj,
                                       NumericVector unique_counts = NumericVector::create(),
                                       NumericVector count_frequencies = NumericVector::create()) {
  double theta = exp(log_theta);
  double cr_term = 0.0;
  if(do_cr_adj){
    arma::vec w_diag = 1.0 / (1.0 / mu + theta);
    arma::mat b = model_matrix.t() * (model_matrix.each_col() % w_diag);
    // log(det(b)) is accumulated from the LU factors (b = t(P) L U) instead of
    // calling log(det(b)), so that tiny pivots can be clamped one by one.
    arma::mat L, U, P;
    arma::lu(L, U, P, b);
    double ld = sum(log(arma::diagvec(L)));
    arma::vec u_diag = arma::diagvec(U);
    for(double e : u_diag){
      // Row pivoting can flip the sign of individual pivots although det(b) is
      // positive (the permutation carries the opposite sign). Hence, the
      // magnitude of the pivot is what contributes to log(det(b)).
      double pivot = std::abs(e);
      ld += pivot < 1e-50 ? log(1e-50) : log(pivot);
    }
    cr_term = -0.5 * ld * cr_correction_factor;
  }
  double theta_neg1 = R_pow_di(theta, -1);
  double lgamma_term = 0;
  // If summarized counts are available use those to calculate sum(lgamma(y + theta_neg1))
  if(unique_counts.size() > 0 && unique_counts.size() == count_frequencies.size()){
    for(R_xlen_t iter = 0; iter < count_frequencies.size(); ++iter){
      lgamma_term += count_frequencies[iter] * lgamma(unique_counts[iter] + theta_neg1);
    }
  }else{
    lgamma_term = sum(lgamma(y + theta_neg1));
  }
  lgamma_term -= y.size() * lgamma(theta_neg1);
  // sum(-(y + 1/theta) * log(mu + 1/theta)) - n * 1/theta * log(theta)
  double ll_part = 0.0;
  for(R_xlen_t i = 0; i < y.size(); ++i){
    ll_part += (-y[i] - theta_neg1) * log(mu[i] + theta_neg1);
  }
  ll_part -= y.size() * theta_neg1 * log(theta);
  return lgamma_term + ll_part + cr_term;
}


// Score function: first derivative of the (optionally Cox-Reid adjusted)
// negative binomial log-likelihood with respect to log_theta, evaluated at
// theta = exp(log_theta). It takes the same inputs as
// conventional_loglikelihood_fast():
//
//  y, mu              counts and their expected means
//  log_theta          log of the overdispersion parameter
//  model_matrix       design matrix X, only used for the Cox-Reid adjustment
//  do_cr_adj          if true, include the derivative of the Cox-Reid term
//  unique_counts,
//  count_frequencies  optional table of the distinct values of y and how often
//                     each occurs; used for the digamma sums if both are
//                     non-empty and of equal length
//
// Several terms are replaced by approximations or clamped to bounds where the
// direct formula is numerically unreliable (very small theta or mu * theta).

// [[Rcpp::export]]
double conventional_score_function_fast(NumericVector y, NumericVector mu, double log_theta, const arma::mat& model_matrix, bool do_cr_adj,
                                        NumericVector unique_counts = NumericVector::create(),
                                        NumericVector count_frequencies = NumericVector::create()) {
  double theta = exp(log_theta);
  double theta_neg1 = 1.0 / theta;

  // Derivative of -0.5 * log(det(b)) with respect to theta, i.e.
  // -0.5 * trace(b^-1 * db). It is multiplied with theta in the return
  // statement to turn it into a derivative with respect to log(theta).
  double cr_term = 0.0;
  if(do_cr_adj){
    arma::vec w_diag = 1.0 / (1.0 / mu + theta);
    // d/dtheta of w = 1 / (1/mu + theta) is -w^2
    arma::vec dw_diag = -1 * w_diag % w_diag;
    arma::mat b = model_matrix.t() * (model_matrix.each_col() % w_diag);
    arma::mat db = model_matrix.t() * (model_matrix.each_col() % dw_diag);
    // The diag(1e-6) protects against singular matrices
    arma::mat b_inv = inv_sympd(b + arma::eye(b.n_rows, b.n_cols) * 1e-6);
    cr_term = -0.5 * trace(b_inv * db) * cr_correction_factor;
  }


  // digamma_term becomes
  //   theta_neg1 * (sum(digamma(y + theta_neg1)) - length(y) * digamma(theta_neg1))
  // or an approximation / upper bound of it, see below.
  double digamma_term = 0;
  // If summarized counts are available use those to calculate sum(digamma(y + theta_neg1))
  if(unique_counts.size() > 0 && unique_counts.size() == count_frequencies.size()){
    double max_y = 0.0;
    double sum_y = 0.0;
    double sum_prod_y = 0.0;
    for(size_t iter = 0; iter < count_frequencies.size(); ++iter){
      digamma_term += count_frequencies[iter] * Rf_digamma(unique_counts[iter] + theta_neg1);
      sum_y += count_frequencies[iter] * unique_counts[iter];
      sum_prod_y += count_frequencies[iter] * (unique_counts[iter] - 1) * unique_counts[iter];
      max_y = std::max(max_y, unique_counts[iter]);
    }
    // Correction sum(y * (y - 1)) / (2 * theta_neg1); only applied for theta_neg1 > 1e5
    double corr = theta_neg1 > 1e5 ? sum_prod_y / (2 * theta_neg1) : 0.0;
    if(max_y * 1e6 < theta_neg1){
      // This approximation is based on the fact that for large x
      // (sum(digamma(y + x))  - length(y) * digamma(x)) * x \approx sum(y)
      // Due to numerical imprecision the digamma_term reaches sum(y) sometimes
      // quicker than the ll_term, thus I subtract the first term of the
      // Laurent series expansion at x -> inf
      digamma_term = sum_y - corr;
    }else{
      digamma_term -= y.size() * Rf_digamma(theta_neg1);
      digamma_term *= theta_neg1;
      // The directly calculated value is capped at the approximation from above
      digamma_term = std::min(digamma_term, sum_y - corr);
    }
  }else{
    // Same calculation as in the branch above, but on every element of y
    double max_y = 0.0;
    double sum_y = 0.0;
    double sum_prod_y = 0.0;
    for(size_t iter = 0; iter < y.size(); ++iter){
      digamma_term += Rf_digamma(y[iter] + theta_neg1);
      sum_y += y[iter];
      sum_prod_y += (y[iter] - 1) * y[iter];
      max_y = std::max(max_y, y[iter]);
    }
    double corr = theta_neg1 > 1e5 ? sum_prod_y / (2 * theta_neg1) : 0.0;
    if(max_y * 1e6 < theta_neg1){
      digamma_term = sum_y - corr;
    }else{
      digamma_term -= y.size() * Rf_digamma(theta_neg1);
      digamma_term *= theta_neg1;

      digamma_term = std::min(digamma_term, sum_y - corr);
    }
  }

  // ll_part becomes
  //   theta_neg1 * sum(log(1 + mu * theta) - mu / (mu + theta_neg1) + y / (mu + theta_neg1))
  // Note that mu / (mu + theta_neg1) equals x / (1 + x) with x = mu * theta.
  double ll_part = 0.0;
  for(size_t i = 0; i < y.size(); ++i){
    double mu_theta = (mu[i] * theta);
    if(mu_theta < 1e-10){
      // For very small mu * theta the lower bound (see next branch) is used directly
      ll_part += mu_theta * mu_theta * (1 / (1 + mu_theta) - 0.5);
    }else if(mu_theta < 1e-4){
      // The bounds are based on the Taylor expansion of log(1 + x) for x = 0.
      double inv = 1 / (1 + mu_theta);
      double upper_bound = mu_theta * mu_theta * inv;
      double lower_bound = mu_theta * mu_theta * (inv - 0.5);
      double suggest = (log(1 + mu_theta) - mu[i] / (mu[i] + theta_neg1)) ;
      // Clamp the directly calculated value to [lower_bound, upper_bound]
      ll_part +=  std::max(std::min(suggest, upper_bound), lower_bound);
    }else{
      ll_part += log(1 + mu_theta)  - mu[i] / (mu[i] + theta_neg1);
    }
    ll_part += y[i] / (mu[i] + theta_neg1);
  }
  ll_part *= theta_neg1;
  return ll_part - digamma_term + cr_term * theta;
}



// Second derivative of the (optionally Cox-Reid adjusted) negative binomial
// log-likelihood with respect to log_theta, evaluated at
// theta = exp(log_theta). It takes the same inputs as
// conventional_loglikelihood_fast():
//
//  y, mu              counts and their expected means
//  log_theta          log of the overdispersion parameter
//  model_matrix       design matrix X, only used for the Cox-Reid adjustment
//  do_cr_adj          if true, include the derivatives of the Cox-Reid term
//  unique_counts,
//  count_frequencies  optional table of the distinct values of y and how often
//                     each occurs; used for the digamma / trigamma sums if
//                     both are non-empty and of equal length
//
// With f the log-likelihood as a function of theta, the result follows the
// chain rule: d^2 f / d log(theta)^2 = theta^2 * f'' + theta * f'.

// [[Rcpp::export]]
double conventional_deriv_score_function_fast(NumericVector y, NumericVector mu, double log_theta, const arma::mat& model_matrix, bool do_cr_adj,
                                              NumericVector unique_counts = NumericVector::create(),
                                              NumericVector count_frequencies = NumericVector::create()) {
  double theta = exp(log_theta);
  // cr_term:  second derivative of -0.5 * log(det(b)) with respect to theta
  // cr_term2: first derivative of -0.5 * log(det(b)) with respect to theta
  // (both scaled with cr_correction_factor)
  double cr_term = 0.0;
  double cr_term2 = 0.0;
  if(do_cr_adj){
    arma::vec w_diag = 1/(1/mu + theta);
    // First and second derivative of w with respect to theta: -w^2 and 2 * w^3
    arma::vec dw_diag = -1 * w_diag % w_diag;
    arma::vec d2w_diag = -2 * dw_diag % w_diag;

    arma::mat b = model_matrix.t() * (model_matrix.each_col() % w_diag);
    arma::mat db = model_matrix.t() * (model_matrix.each_col() % dw_diag);
    arma::mat d2b = model_matrix.t() * (model_matrix.each_col() % d2w_diag);
    // The diag(1e-6) protects against singular matrices
    arma::mat b_inv = inv_sympd(b + arma::eye(b.n_rows, b.n_cols) * 1e-6);
    arma::mat d_i_db = b_inv * db;
    // ddetb = trace(b^-1 db) is the derivative of log(det(b))
    double ddetb = trace(d_i_db);
    double d2detb = ((R_pow_di(ddetb, 2) - trace(d_i_db * d_i_db) + trace(b_inv * d2b)) );
    // The ddetb^2 parts cancel, which leaves
    // -0.5 * (trace(b^-1 d2b) - trace((b^-1 db)^2))
    cr_term = (0.5 * R_pow_di(ddetb, 2) - 0.5 * d2detb)  * cr_correction_factor;
    cr_term2 = -0.5 * ddetb * cr_correction_factor;
  }

  double theta_neg1 = R_pow_di(theta, -1);
  double theta_neg2 = R_pow_di(theta, -2);
  // digamma_term:  sum(digamma(y + theta_neg1)) - length(y) * digamma(theta_neg1)
  // trigamma_term: theta_neg2 * (sum(trigamma(y + theta_neg1)) - length(y) * trigamma(theta_neg1))
  double digamma_term = 0.0;
  double trigamma_term = 0.0;

  // If summarized counts are available use those to calculate sum(digamma()) and sum(trigamma())
  if(unique_counts.size() > 0 && unique_counts.size() == count_frequencies.size()){
    for(size_t iter = 0; iter < count_frequencies.size(); ++iter){
      digamma_term += count_frequencies[iter] * Rf_digamma(unique_counts[iter] + theta_neg1);
      trigamma_term += count_frequencies[iter] * Rf_trigamma(unique_counts[iter] + theta_neg1);
    }
    trigamma_term *= theta_neg2;

    digamma_term -= y.size() * Rf_digamma(theta_neg1);
    trigamma_term -=  theta_neg2 * y.size() * Rf_trigamma(theta_neg1);
  }else{
    digamma_term = sum(digamma(y + theta_neg1));
    digamma_term -= y.size() * Rf_digamma(theta_neg1);

    trigamma_term = theta_neg2 * sum(trigamma(y + theta_neg1));
    trigamma_term -=  theta_neg2 * y.size() * Rf_trigamma(theta_neg1);
  }

  // ll_part_1: sum(log(1 + mu * theta) + (y - mu) / (mu + theta_neg1))
  // ll_part_2: sum((mu^2 * theta + y) / (1 + mu * theta)^2), the derivative of
  //            ll_part_1 with respect to theta
  double ll_part_1 = 0.0;
  double ll_part_2 = 0.0;
  for(size_t i = 0; i < y.size(); ++i){
    ll_part_1 += log(1 + mu[i] * theta) + (y[i] - mu[i]) / (mu[i] + theta_neg1);
    ll_part_2 += (mu[i] * mu[i] * theta + y[i]) / (1 + mu[i] * theta) / (1 + mu[i] * theta);
  }
  // theta^2 * f'' for the part of the log-likelihood without Cox-Reid adjustment
  double ll_part = -2 * theta_neg1 * (ll_part_1 - digamma_term) + (ll_part_2 + trigamma_term);

  // theta^2 * f'' (ll_part, cr_term) plus theta * f' ((ll_part_1 - digamma_term) * theta_neg1, cr_term2)
  double res = ll_part + cr_term * R_pow_di(theta, 2) + (ll_part_1 - digamma_term) * theta_neg1 + cr_term2 * theta;
  return res;
}


// ------------------------------------------------------------------------------------------------

// Estimate one overdispersion per row of Y by handing each row of Y together
// with the matching row of mean_matrix to the R function
// glmGamPoi:::overdispersion_mle_impl.
//
// NumericType is the beachmat matrix class used to read Y.
// Returns a list with the elements "estimate", "iterations" and "message",
// each holding one entry per row of Y. Rows whose first mean is NA are not
// passed to R; they get an NA estimate and an explanatory message.
// Throws std::runtime_error if Y and mean_matrix differ in their dimensions.
template<class NumericType>
List estimate_overdispersions_fast_internal(RObject Y, RObject mean_matrix, NumericMatrix model_matrix, bool do_cox_reid_adjustment,
                                            double n_subsamples, int max_iter){
  auto counts_mat = beachmat::create_matrix<NumericType>(Y);
  auto means_mat = beachmat::create_numeric_matrix(mean_matrix);
  int n_cols = counts_mat->get_ncol();
  int n_rows = counts_mat->get_nrow();
  NumericVector theta_hat(n_rows);
  NumericVector n_iter(n_rows);
  CharacterVector notes(n_rows);
  const bool same_dims = n_rows == means_mat->get_nrow() && n_cols == means_mat->get_ncol();
  if(!same_dims){
    throw std::runtime_error("Dimensions of Y and mean_matrix do not match");
  }

  // The actual optimization lives in R; calling back keeps this code short
  Environment pkg_ns = Environment::namespace_env("glmGamPoi");
  Function mle_impl = pkg_ns["overdispersion_mle_impl"];
  for(int row = 0; row < n_rows; ++row){
    if (row % 100 == 0) checkUserInterrupt();
    typename NumericType::vector row_counts(n_cols);
    counts_mat->get_row(row, row_counts.begin());
    NumericVector row_means(n_cols);
    means_mat->get_row(row, row_means.begin());

    // Only the first mean is inspected: if it is NA, all of them will be
    const bool mean_is_na = n_cols > 0 && Rcpp::traits::is_na<REALSXP>(row_means[0]);
    if(mean_is_na){
      theta_hat(row) = NA_REAL;
      n_iter(row) = max_iter;
      notes(row) = "Mean estimate was NA. Cannot estimate overdispersion";
      continue;
    }

    List fit = Rcpp::as<List>(mle_impl(row_counts, row_means, model_matrix, do_cox_reid_adjustment, n_subsamples, max_iter));
    theta_hat(row) = Rcpp::as<double>(fit["estimate"]);
    n_iter(row) = Rcpp::as<double>(fit["iterations"]);
    notes(row) = Rcpp::as<String>(fit["message"]);
  }

  return List::create(
    Named("estimate") = theta_hat,
    Named("iterations") = n_iter,
    Named("message") = notes);
}

// Entry point called from R: picks the beachmat reader that matches the
// storage type of Y (integer or double) and forwards all arguments to
// estimate_overdispersions_fast_internal().
// Throws std::runtime_error for any other storage type.
// [[Rcpp::export]]
List estimate_overdispersions_fast(RObject Y, RObject mean_matrix, NumericMatrix model_matrix, bool do_cox_reid_adjustment,
                              double n_subsamples, int max_iter){
  const auto sexp_type = beachmat::find_sexp_type(Y);
  if (sexp_type == INTSXP) {
    return estimate_overdispersions_fast_internal<beachmat::integer_matrix>(
      Y, mean_matrix, model_matrix, do_cox_reid_adjustment, n_subsamples, max_iter);
  }
  if (sexp_type == REALSXP) {
    return estimate_overdispersions_fast_internal<beachmat::numeric_matrix>(
      Y, mean_matrix, model_matrix, do_cox_reid_adjustment, n_subsamples, max_iter);
  }
  throw std::runtime_error("unacceptable matrix type");
}


// Evaluate the log-likelihood summed over all rows of Y for each candidate
// value in log_thetas.
//
// NumericType is the beachmat matrix class used to read Y.
//
//  Y, mean_matrix          counts and their expected means (same dimensions)
//  model_matrix            design matrix, forwarded to conventional_loglikelihood_fast()
//  do_cox_reid_adjustment  forwarded to conventional_loglikelihood_fast()
//  log_thetas              log overdispersion values at which to evaluate
//
// Returns a vector with one summed log-likelihood per element of log_thetas.
// Throws std::runtime_error if Y and mean_matrix differ in their dimensions.
template<class NumericType>
NumericVector estimate_global_overdispersions_fast_internal(RObject Y, RObject mean_matrix, const arma::mat model_matrix, const bool do_cox_reid_adjustment,
                                                            const NumericVector log_thetas){
  const auto Y_bm = beachmat::create_matrix<NumericType>(Y);
  const auto mean_mat_bm = beachmat::create_numeric_matrix(mean_matrix);
  int n_samples = Y_bm->get_ncol();
  int n_genes = Y_bm->get_nrow();
  int n_spline_points = log_thetas.size();
  // The row buffers below are sized from Y only, so a mean_matrix with a
  // different shape has to be rejected before any row is copied. This is the
  // same check that estimate_overdispersions_fast_internal() performs.
  if(n_genes != mean_mat_bm->get_nrow() || n_samples != mean_mat_bm->get_ncol()){
    throw std::runtime_error("Dimensions of Y and mean_matrix do not match");
  }

  NumericVector log_likelihoods(n_spline_points);

  for(int gene_idx = 0; gene_idx < n_genes; gene_idx++){
    if (gene_idx % 100 == 0) checkUserInterrupt();
    NumericVector counts(n_samples);
    Y_bm->get_row(gene_idx, counts.begin());
    NumericVector mu(n_samples);
    mean_mat_bm->get_row(gene_idx, mu.begin());

    // The table of distinct counts is built once per row and reused for every
    // candidate theta. It is empty if the row has too many distinct values.
    ListOf<NumericVector> tab = make_table_if_small(counts, /*stop_if_larger = */ n_samples / 2);

    for(int point_idx = 0; point_idx < n_spline_points; point_idx++){
      log_likelihoods[point_idx] += conventional_loglikelihood_fast(counts, mu, log_thetas[point_idx], model_matrix,
                                                                    do_cox_reid_adjustment, tab[0], tab[1]);
    }
  }
  return log_likelihoods;
}


// Entry point called from R: picks the beachmat reader that matches the
// storage type of Y (integer or double) and forwards all arguments to
// estimate_global_overdispersions_fast_internal().
// Throws std::runtime_error for any other storage type.
// [[Rcpp::export]]
NumericVector estimate_global_overdispersions_fast(RObject Y, RObject mean_matrix, const arma::mat model_matrix, const bool do_cox_reid_adjustment,
                                                   const NumericVector log_thetas){
  const auto sexp_type = beachmat::find_sexp_type(Y);
  if (sexp_type == INTSXP) {
    return estimate_global_overdispersions_fast_internal<beachmat::integer_matrix>(
      Y, mean_matrix, model_matrix, do_cox_reid_adjustment, log_thetas);
  }
  if (sexp_type == REALSXP) {
    return estimate_global_overdispersions_fast_internal<beachmat::numeric_matrix>(
      Y, mean_matrix, model_matrix, do_cox_reid_adjustment, log_thetas);
  }
  throw std::runtime_error("unacceptable matrix type");
}