```// #include <Rcpp.h>
#include "beachmat/numeric_matrix.h"
#include "beachmat/integer_matrix.h"

#include <deviance.h>
#include <fisher_scoring_steps.h>

using namespace Rcpp;

/**
 * Clamp every element of `v` to the closed interval [min, max], in place.
 *
 * Elements below `min` are set to `min`, elements above `max` are set to
 * `max`; every other element (including NaN, for which both comparisons are
 * false) is left unchanged.
 */
template<class NumericType>
void clamp_inplace(/*INOUT parameter*/ arma::Mat<NumericType>& v, double min, double max){
  // Range-for visits each element once and avoids comparing a signed `int`
  // counter against the unsigned `n_elem` (which could also overflow for
  // very large matrices).
  for(NumericType& value : v){
    if(value < min){
      value = min;
    }else if(value > max){
      value = max;
    }
  }
}

// Check how many unique rows are in a matrix and if this number is less than or equal to n
// This is important to determine if the model can be solved by group averages
// (ie. the numer of unique rows == number of columns)
// [[Rcpp::export]]
bool lte_n_equal_rows(const NumericMatrix& matrix, int n, double tolerance = 1e-10) {
NumericMatrix reference_matrix(n, matrix.ncol());
size_t n_matches = 0;
for(size_t row_idx = 0; row_idx < matrix.nrow(); row_idx++){
bool matched = false;
NumericMatrix::ConstRow vec = matrix(row_idx, _);
for(size_t ref_idx = 0; ref_idx < n_matches; ref_idx++){
NumericMatrix::Row ref_vec  = reference_matrix(ref_idx, _);
if(sum(abs(vec - ref_vec)) < tolerance){
matched = true;
break;
}
}
if(! matched){
++n_matches;
if(n_matches > n){
return false;
}
reference_matrix(n_matches - 1, _) = vec;
}
}
return true;
}

// Assign every row of `matrix` to a group of (approximately) identical rows.
// Two rows count as equal if the sum of the absolute element-wise differences
// is smaller than `tolerance`; each row is compared against the first row seen
// of every group found so far. Groups are numbered by first appearance,
// starting at 1 (R indexing).
// `n_groups` is the maximum number of distinct rows; if the matrix contains
// more distinct rows, an error is raised instead of writing past the end of
// the internal reference matrix.
// [[Rcpp::export]]
IntegerVector get_row_groups(const NumericMatrix& matrix, int n_groups, double tolerance = 1e-10) {
  const int n_rows = matrix.nrow();
  const int n_cols = matrix.ncol();
  // Representatives of the groups found so far; only the first `n_matches` rows are filled.
  NumericMatrix reference_matrix(n_groups, n_cols);
  IntegerVector groups(n_rows);
  int n_matches = 0;
  for(int row_idx = 0; row_idx < n_rows; row_idx++){
    int group = -1;
    for(int ref_idx = 0; ref_idx < n_matches && group < 0; ref_idx++){
      // Same quantity as sum(abs(row - reference)), without allocating a temporary vector
      double abs_diff_sum = 0.0;
      for(int col_idx = 0; col_idx < n_cols; col_idx++){
        abs_diff_sum += std::fabs(matrix(row_idx, col_idx) - reference_matrix(ref_idx, col_idx));
      }
      if(abs_diff_sum < tolerance){
        group = ref_idx;
      }
    }
    if(group < 0){
      if(n_matches >= n_groups){
        stop("get_row_groups: the matrix contains more distinct rows than n_groups");
      }
      reference_matrix(n_matches, _) = matrix(row_idx, _);
      group = n_matches;
      ++n_matches;
    }
    // Convert the 0-based group index to R's 1-based numbering
    groups(row_idx) = group + 1;
  }
  return groups;
}

// Expected values for one gene: mu = exp(model_matrix * beta_hat) multiplied
// element-wise with exp_off, clamped to [1e-50, 1e50]. The clamp presumably keeps
// later arithmetic finite (e.g. the mu_hat / mu_old ratio in the line search).
arma::vec calculate_mu(const arma::mat& model_matrix, const arma::vec& beta_hat, const arma::vec& exp_off){
  const arma::vec linear_predictor = model_matrix * beta_hat;
  arma::vec mu = exp(linear_predictor) % exp_off;
  clamp_inplace(mu, 1e-50, 1e50);
  return mu;
}

/**
 * Line search along a proposed step for the coefficients.
 *
 * The full step (beta_hat + step) is tried first. It is accepted if it
 * decreases the deviance and no element of mu grows by a factor of
 * `max_rel_mu_change` or more compared to the means before the step
 * (max(mu_new / mu_old) < max_rel_mu_change). If it is not accepted, the
 * method retries with half the step, then a quarter, and so on.
 *
 * A trial is also accepted if the relative deviance change
 * |dev - dev_old| / (|dev| + 0.1) is smaller than `tolerance`.
 *
 * If no trial is accepted after 100 halvings (the step is then scaled by
 * 0.5^100, approx. 7.9e-31), the search gives up and returns NaN.
 *
 * Note that the first two parameters are changed: on return, beta_hat holds
 * the last tried coefficients and mu_hat the corresponding means.
 *
 * @return the deviance at the returned beta_hat, or NaN if the search failed.
 */
template<class NumericType>
double decrease_deviance(/*In-Out Parameter*/ arma::vec& beta_hat,
                         /*In-Out Parameter*/ arma::vec& mu_hat,
                         const arma::vec& step,
                         const arma::mat& model_matrix,
                         // Taken as a vector (not a matrix) so that calculate_mu does
                         // not have to build a temporary copy on every trial.
                         const arma::vec& exp_off,
                         const arma::Col<NumericType>& counts,
                         const double theta, const double dev_old, const double tolerance, const double max_rel_mu_change){
  double speeding_factor = 1.0;
  int line_iter = 0;
  double dev = 0;
  beta_hat = beta_hat + step;
  // Means before the step; used to bound how much mu may grow in one step
  const arma::vec mu_old = mu_hat;
  while(true){
    mu_hat = calculate_mu(model_matrix, beta_hat, exp_off);
    dev = compute_gp_deviance_sum(counts, mu_hat, theta);
    double conv_test = fabs(dev - dev_old)/(fabs(dev) + 0.1);
    double mu_rel_change = max(mu_hat / mu_old);
    if((dev < dev_old && mu_rel_change < max_rel_mu_change) || conv_test < tolerance){
      break; // accept the current beta_hat
    }else if(line_iter >= 100){
      // speeding factor is very small, something is going wrong here
      dev = std::numeric_limits<double>::quiet_NaN();
      break;
    }else{
      // Halve the step: beta_hat moves from beta_old + s * step to beta_old + (s / 2) * step
      speeding_factor = speeding_factor / 2.0;
      beta_hat = beta_hat - step * speeding_factor;
    }
    line_iter++;
  }
  return dev;
}

/**
 * Line search along a proposed step for the coefficients, for the
 * ridge-penalized objective
 *
 *   deviance + n_samples * (beta - ridge_target)' * ridge_penalty_sq * (beta - ridge_target)
 *
 * where n_samples is the number of rows of `model_matrix`.
 *
 * The full step (beta_hat + step) is tried first. It is accepted if it
 * decreases the objective and max(mu_new / mu_old) < max_rel_mu_change, or if
 * the relative objective change |dev - dev_old| / (|dev| + 0.1) is smaller than
 * `tolerance`. Otherwise the step is halved repeatedly; after 100 halvings
 * the search gives up and returns NaN.
 *
 * beta_hat and mu_hat are updated in place to the last tried coefficients
 * and the corresponding means.
 *
 * @return the penalized objective at the returned beta_hat, or NaN if the
 *         search failed.
 */
template<class NumericType>
double decrease_deviance_plus_ridge(/*In-Out Parameter*/ arma::vec& beta_hat,
                                    /*In-Out Parameter*/ arma::vec& mu_hat,
                                    const arma::vec& step,
                                    const arma::mat& model_matrix,
                                    const arma::mat& ridge_penalty_sq,
                                    const arma::vec& ridge_target,
                                    // Taken as a vector (not a matrix) so that calculate_mu does
                                    // not have to build a temporary copy on every trial.
                                    const arma::vec& exp_off,
                                    const arma::Col<NumericType>& counts,
                                    const double theta, const double dev_old,
                                    const double tolerance, const double max_rel_mu_change){
  double speeding_factor = 1.0;
  int line_iter = 0;
  double dev = 0;
  int n_samples = model_matrix.n_rows;
  beta_hat = beta_hat + step;
  // Means before the step; used to bound how much mu may grow in one step
  const arma::vec mu_old = mu_hat;
  while(true){
    mu_hat = calculate_mu(model_matrix, beta_hat, exp_off);
    double pen_sum = n_samples * arma::as_scalar((beta_hat - ridge_target).t() * ridge_penalty_sq * (beta_hat - ridge_target));
    dev = compute_gp_deviance_sum(counts, mu_hat, theta) + pen_sum;
    double conv_test = fabs(dev - dev_old)/(fabs(dev) + 0.1);
    double mu_rel_change = max(mu_hat / mu_old);
    if((dev < dev_old && mu_rel_change < max_rel_mu_change) || conv_test < tolerance){
      break; // accept the current beta_hat
    }else if(line_iter >= 100){
      // speeding factor is very small, something is going wrong here
      dev = std::numeric_limits<double>::quiet_NaN();
      break;
    }else{
      // Halve the step: beta_hat moves from beta_old + s * step to beta_old + (s / 2) * step
      speeding_factor = speeding_factor / 2.0;
      beta_hat = beta_hat - step * speeding_factor;
    }
    line_iter++;
  }
  return dev;
}

//--------------------------------------------------------------------------------------------------//
// The following code was originally copied from https://github.com/mikelove/DESeq2/blob/master/src/DESeq2.cpp
// I adapted it to the needs of this project by:
//  * remove weights
//  * Calculate actual deviance (2 * (log(f_NB(y | y, theta)) - log(f_NB(y | mu, theta))))
//    instead of just -2 * log(f_NB(y | mu, theta)),
//  * Support DelayedArrays
//  * Remove unnecessary outputs: beta_mat_var, hat_diagonals, deviance
//  * Remove beta divergence check if abs(beta) very large
//  * Add line search that ensures that deviance is decreasing at every step
//  * Add optional ridge penalty

// Fit the Negative Binomial GLM for every gene (row of Y) with Fisher scoring.
// Note: the betas are on the natural log scale.
//
// Y: count matrix read one row (gene) at a time through beachmat; columns are
//    samples. NumericType is the element type of a row, BMNumericType the
//    beachmat matrix class.
// model_matrix: design matrix with one row per sample.
// exp_offset_matrix: read with the same row layout as Y; each row is multiplied
//    element-wise onto exp(model_matrix * beta) to form mu.
// thetas: one overdispersion value per gene.
// beta_matSEXP: starting coefficients, one row per gene; a copy is updated and returned.
// ridge_penalty_nl: optional penalty matrix P. If given, the minimized objective is
//    deviance + n_samples * (beta - target)' P'P (beta - target), where target is
//    the "target" attribute of P (zeros if absent).
// tolerance: convergence threshold on the relative change of the objective.
// max_rel_mu_change: forwarded to the line search (bound on max(mu_new / mu_old)).
// max_iter: maximum number of Fisher scoring iterations per gene.
// use_diagonal_approx: use fisher_scoring_diagonal_step instead of the QR-based
//    steps; the ridge penalty then only enters the objective, not the step.
//
// Returns a list with "beta_mat", "iter" (iterations per gene) and "deviance"
// (final objective per gene). Genes with NA theta or NaN starting values get
// NA coefficients, 0 iterations and NA deviance.
template<class NumericType, class BMNumericType>
List fitBeta_fisher_scoring_impl(RObject Y, const arma::mat& model_matrix, RObject exp_offset_matrix,
                                 NumericVector thetas, SEXP beta_matSEXP, Nullable<NumericMatrix> ridge_penalty_nl,
                                 double tolerance, double max_rel_mu_change, int max_iter, bool use_diagonal_approx) {
  auto Y_bm = beachmat::create_matrix<BMNumericType>(Y);
  auto exp_offsets_bm = beachmat::create_numeric_matrix(exp_offset_matrix);
  int n_samples = Y_bm->get_ncol();
  int n_genes = Y_bm->get_nrow();

  // the ridge penalty
  bool apply_ridge_penalty = ridge_penalty_nl.isNotNull();
  arma::mat ridge_penalty;
  arma::mat ridge_penalty_sq;
  arma::vec ridge_target;
  if(apply_ridge_penalty){
    NumericMatrix tmp = ridge_penalty_nl.get();
    ridge_penalty = arma::mat(tmp.cbegin(), tmp.nrow(), tmp.ncol());
    if(model_matrix.n_cols != ridge_penalty.n_cols){
      stop("Number of columns in model_matrix does not match the columns of the ridge_penalty");
    }
    ridge_penalty_sq = ridge_penalty.t() * ridge_penalty;

    if(tmp.hasAttribute("target")){
      NumericVector target = tmp.attr("target");
      ridge_target = arma::vec(target.begin(), target.size());
    }else{
      ridge_target = arma::zeros(tmp.ncol());
    }
    // Fail early with a clear message instead of an Armadillo size mismatch inside the gene loop
    if(ridge_target.n_elem != model_matrix.n_cols){
      stop("The length of the 'target' attribute of the ridge_penalty does not match the columns of the model_matrix");
    }
  }
  // The result: starts from the provided coefficients and is updated row by row
  arma::mat beta_mat = as<arma::mat>(beta_matSEXP);

  // deviance, convergence and tolerance
  NumericVector iterations(n_genes);
  NumericVector deviance(n_genes);
  for (int gene_idx = 0; gene_idx < n_genes; gene_idx++) {
    if (gene_idx % 100 == 0) checkUserInterrupt();
    // Fill count and offset vector from beachmat matrix
    arma::Col<NumericType> counts(n_samples);
    Y_bm->get_row(gene_idx, counts.begin());
    arma::Col<double> exp_off(n_samples);
    exp_offsets_bm->get_row(gene_idx, exp_off.begin());
    // Init beta
    arma::vec beta_hat = beta_mat.row(gene_idx).t();
    if(beta_hat.has_nan() || Rcpp::traits::is_na<REALSXP>(thetas(gene_idx))){
      // The gene cannot be fitted: report NA coefficients. The row must be written
      // here because `continue` skips the write-back at the end of the loop body.
      beta_mat.row(gene_idx).fill(NA_REAL);
      iterations(gene_idx) = 0;
      deviance(gene_idx) = NA_REAL;
      continue;
    }
    const double theta = thetas(gene_idx);
    arma::vec mu_hat = calculate_mu(model_matrix, beta_hat, exp_off);
    // Init objective: deviance plus the optional ridge penalty
    double dev_old = compute_gp_deviance_sum(counts, mu_hat, theta);
    if(apply_ridge_penalty){
      // pen = n_samples * ||P (b - target)||^2; for a diagonal P this is
      // n_samples * Sum (lambda_i (b_i - target_i))^2
      double pen_sum = n_samples * arma::as_scalar((beta_hat - ridge_target).t() * ridge_penalty_sq * (beta_hat - ridge_target));
      dev_old += pen_sum;
    }
    for (int t = 0; t < max_iter; t++) {
      iterations(gene_idx)++;
      // Find good direction to optimize beta
      arma::vec step;
      if(use_diagonal_approx){
        step = fisher_scoring_diagonal_step(model_matrix, counts, mu_hat, theta * mu_hat);
      }else if(apply_ridge_penalty){
        step = fisher_scoring_qr_ridge_step(model_matrix, counts, mu_hat, theta * mu_hat, ridge_penalty, ridge_target, beta_hat);
      }else{
        step = fisher_scoring_qr_step(model_matrix, counts, mu_hat, theta * mu_hat);
      }
      // Find step size that actually decreases the objective
      double dev = 0;
      if(apply_ridge_penalty){
        dev = decrease_deviance_plus_ridge(beta_hat, mu_hat, step, model_matrix, ridge_penalty_sq, ridge_target,
                                           exp_off, counts, theta, dev_old, tolerance, max_rel_mu_change);
      }else{
        dev = decrease_deviance(beta_hat, mu_hat, step, model_matrix,
                                exp_off, counts, theta, dev_old, tolerance, max_rel_mu_change);
      }
      double conv_test = fabs(dev - dev_old)/(fabs(dev) + 0.1);
      dev_old = dev;
      if (std::isnan(conv_test)) {
        // The line search failed (it returns NaN); give up on this gene
        beta_hat.fill(NA_REAL);
        iterations(gene_idx) = max_iter;
        break;
      }
      if (conv_test < tolerance) {
        break;
      }
    }
    beta_mat.row(gene_idx) = beta_hat.t();
    deviance(gene_idx) = dev_old;
  }

  return List::create(
    Named("beta_mat", beta_mat),
    Named("iter", iterations),
    Named("deviance", deviance));
}

// R entry point: Fisher scoring with the full (QR-based) step and an optional
// ridge penalty. Dispatches on the storage mode of Y (integer or double counts)
// to the matching template instantiation; any other type is rejected.
// [[Rcpp::export]]
List fitBeta_fisher_scoring(RObject Y, const arma::mat& model_matrix, RObject exp_offset_matrix,
                            NumericVector thetas, SEXP beta_matSEXP, Nullable<NumericMatrix> ridge_penalty_nl,
                            double tolerance, double max_rel_mu_change, int max_iter) {
  const auto storage_type = beachmat::find_sexp_type(Y);
  switch (storage_type) {
  case INTSXP:
    return fitBeta_fisher_scoring_impl<int, beachmat::integer_matrix>(
      Y, model_matrix, exp_offset_matrix, thetas, beta_matSEXP,
      /*ridge_penalty=*/ ridge_penalty_nl,
      tolerance, max_rel_mu_change, max_iter,
      /*use_diagonal_approx=*/ false);
  case REALSXP:
    return fitBeta_fisher_scoring_impl<double, beachmat::numeric_matrix>(
      Y, model_matrix, exp_offset_matrix, thetas, beta_matSEXP,
      /*ridge_penalty=*/ ridge_penalty_nl,
      tolerance, max_rel_mu_change, max_iter,
      /*use_diagonal_approx=*/ false);
  default:
    break;
  }
  throw std::runtime_error("unacceptable matrix type");
}

// R entry point: Fisher scoring with the diagonal approximation of the step
// and no ridge penalty. Dispatches on the storage mode of Y (integer or double
// counts) to the matching template instantiation; any other type is rejected.
// [[Rcpp::export]]
List fitBeta_diagonal_fisher_scoring(RObject Y, const arma::mat& model_matrix, RObject exp_offset_matrix,
                                     NumericVector thetas, SEXP beta_matSEXP,
                                     double tolerance, double max_rel_mu_change, int max_iter) {
  const auto storage_type = beachmat::find_sexp_type(Y);
  switch (storage_type) {
  case INTSXP:
    return fitBeta_fisher_scoring_impl<int, beachmat::integer_matrix>(
      Y, model_matrix, exp_offset_matrix, thetas, beta_matSEXP,
      /*ridge_penalty=*/ R_NilValue,
      tolerance, max_rel_mu_change, max_iter,
      /*use_diagonal_approx=*/ true);
  case REALSXP:
    return fitBeta_fisher_scoring_impl<double, beachmat::numeric_matrix>(
      Y, model_matrix, exp_offset_matrix, thetas, beta_matSEXP,
      /*ridge_penalty=*/ R_NilValue,
      tolerance, max_rel_mu_change, max_iter,
      /*use_diagonal_approx=*/ true);
  default:
    break;
  }
  throw std::runtime_error("unacceptable matrix type");
}

// Fit a single coefficient per gene for the case that all samples belong to
// one group (an intercept-only model). There is no need to do the full
// Fisher scoring; instead a one-dimensional Newton-Raphson iteration on
// mu = exp(beta + offset) is used. If it does not stop within `maxIter`
// iterations or beta becomes NaN, the R function
// `estimate_betas_group_wise_optimize_helper` from the glmGamPoi namespace is
// called as a fallback.
//
// NumericType: the beachmat matrix class used to read Y (the exported wrapper
//   instantiates it with beachmat::integer_matrix or beachmat::numeric_matrix).
// Y_SEXP: count matrix, read one row (gene) at a time; columns are samples.
// offsets_SEXP: matrix read with the same row layout; its values are added to
//   beta before exponentiating, i.e. they act as log-scale offsets.
// thetas, beta_start_values: one overdispersion / starting value per gene.
// tolerance: the iteration stops once |step| < tolerance.
//
// Returns a list with "beta" (fitted coefficient per gene), "iter" (index of
// the Newton-Raphson iteration in which the loop stopped; maxIter if it never
// stopped early) and "deviance" (summed over the samples of each gene).
template<class NumericType>
List fitBeta_one_group_internal(SEXP Y_SEXP, SEXP offsets_SEXP,
NumericVector thetas, NumericVector beta_start_values,
double tolerance, int maxIter) {
auto Y_bm = beachmat::create_matrix<NumericType>(Y_SEXP);

auto offsets_bm = beachmat::create_numeric_matrix(offsets_SEXP);
int n_samples = Y_bm->get_ncol();
int n_genes = Y_bm->get_nrow();
NumericVector result(n_genes);
IntegerVector iterations(n_genes);
NumericVector deviance(n_genes);

// R-level fallback optimizer, looked up once for all genes
Environment glmGamPoiEnv = Environment::namespace_env("glmGamPoi");
Function estimate_betas_group_wise_optimize_helper = glmGamPoiEnv["estimate_betas_group_wise_optimize_helper"];

for(int gene_idx = 0; gene_idx < n_genes; gene_idx++){
if (gene_idx % 100 == 0) checkUserInterrupt();

double beta = beta_start_values(gene_idx);
const double& theta = thetas(gene_idx);
if(Rcpp::traits::is_na<REALSXP>(beta) || Rcpp::traits::is_na<REALSXP>(theta)){
// Missing start value or overdispersion: report NA and continue with next gene
result(gene_idx) = NA_REAL;
iterations(gene_idx) = 0;
deviance(gene_idx) = NA_REAL;
continue;
}

typename NumericType::vector counts(n_samples);
Y_bm->get_row(gene_idx, counts.begin());
NumericVector off(n_samples);
offsets_bm->get_row(gene_idx, off.begin());
// Newton-Raphson on beta: dl is the first derivative of the NB log-likelihood
// with respect to beta, ddl the negative second derivative (observed information).
int iter = 0;
for(; iter < maxIter; iter++){
double dl = 0.0;
double ddl = 0.0;
bool all_zero = true;
for(int sample_iter = 0; sample_iter < n_samples; sample_iter++){
const auto count = counts[sample_iter];
all_zero = all_zero && count == 0;
const double mu = std::exp(beta + off[sample_iter]);
const double denom = 1.0 + mu * theta;
dl += (count - mu) / denom;
ddl += mu * (1.0 + count * theta) / denom / denom;
// ddl += mu / denom;           // This is what edgeR is using (expected instead of observed information)
}
if(all_zero){
// With only zero counts the likelihood increases as mu -> 0, so the estimate is -Inf
beta = R_NegInf;
break;
}
const double step = dl / ddl;
beta += step;
// Stop when the step is small enough or the iteration has broken down (NaN)
if(std::abs(step) < tolerance){
break;
}else if(Rcpp::traits::is_nan<REALSXP>(beta)){
break;
}
}
if(iter == maxIter || Rcpp::traits::is_nan<REALSXP>(beta)){
// Not converged -> try again with the R-level optimize() helper
beta =  Rcpp::as<double>(estimate_betas_group_wise_optimize_helper(counts, off, theta));
}
result(gene_idx) = beta;
iterations(gene_idx) = iter;
// Deviance of the final estimate, summed over the samples of this gene
double dev = 0.0;
for(int sample_iter = 0; sample_iter < n_samples; sample_iter++){
dev += compute_gp_deviance(counts[sample_iter], exp(beta + off[sample_iter]), theta);
}
deviance(gene_idx) = dev;
}
return List::create(
Named("beta", result),
Named("iter", iterations),
Named("deviance", deviance)
);
}

// R entry point for the one-group (intercept-only) fit. Picks the beachmat
// matrix class matching the storage mode of Y (integer or double counts);
// any other type is rejected.
// [[Rcpp::export(rng = false)]]
List fitBeta_one_group(RObject Y, RObject offset_matrix,
                       NumericVector thetas, NumericVector beta_start_values,
                       double tolerance, int maxIter) {
  const auto storage_type = beachmat::find_sexp_type(Y);
  switch (storage_type) {
  case INTSXP:
    return fitBeta_one_group_internal<beachmat::integer_matrix>(
      Y, offset_matrix, thetas, beta_start_values, tolerance, maxIter);
  case REALSXP:
    return fitBeta_one_group_internal<beachmat::numeric_matrix>(
      Y, offset_matrix, thetas, beta_start_values, tolerance, maxIter);
  default:
    break;
  }
  throw std::runtime_error("unacceptable matrix type");
}

```