#' @include pca.R
#' @include HermesData-methods.R
NULL

#' Calculation of R2 between Sample Variable and Principal Components
#'
#' @description `r lifecycle::badge("stable")`
#'
#' This helper function calculates R2 values between one sample variable from [`AnyHermesData`]
#' and all Principal Components (PCs) separately (one linear model is fit for each PC).
#'
#' @details
#'   - Samples with missing values in `x` are removed before the linear models are fit.
#'   - The total sum of squares is computed around the mean of each PC over the remaining
#'     samples. This keeps the R2 values valid also when samples had to be removed, because
#'     the PCs are only centered over the full set of samples.
#'   - A warning is given if `x` is constant (as judged by `is_constant()`) after removing
#'     missing values.
#'   - Note that in case there are estimation problems for any of the PCs, then `NA` will
#'     be returned for those.
#'
#' @param pca (`matrix`)\cr principal components matrix generated by [calc_pca()], with
#'   samples in rows and PCs in columns. Each column must be centered (mean zero).
#' @param x (`vector`)\cr values of one sample variable from a [`AnyHermesData`] object.
#'
#' @return A numeric vector with R2 values for each principal component, named by the
#'   column names of `pca`.
#'
#' @export
#'
#' @examples
#' object <- hermes_data %>%
#'   add_quality_flags() %>%
#'   filter() %>%
#'   normalize()
#'
#' # Obtain the principal components.
#' pca <- calc_pca(object)$x
#'
#' # Obtain the sample variable.
#' x <- colData(object)$AGE18
#'
#' # Correlate them.
#' r2 <- h_pca_var_rsquared(pca, x)
h_pca_var_rsquared <- function(pca, x) {
  assert_that(
    is.matrix(pca),
    is.numeric(x) || is.factor(x) || is.character(x) || is.logical(x),
    identical(length(x), nrow(pca)),
    all(abs(colMeans(pca)) < 1e-10)
  )
  # Drop samples with missing values in the sample variable.
  use_sample <- !is.na(x)
  x <- x[use_sample]
  if (is_constant(x)) {
    warning("sample variable is constant and R2 values cannot be calculated")
  }
  # `drop = FALSE` keeps the matrix shape (and the PC names) even when there is
  # only a single PC column or a single remaining sample.
  pca <- pca[use_sample, , drop = FALSE]
  design <- stats::model.matrix(~x)
  # Transpose such that PCs are in rows, and samples in columns.
  y0 <- t(pca)
  # Suppress any console output printed during the fit.
  utils::capture.output(fit <- limma::lmFit(y0, design = design))
  had_problems <- rowSums(is.na(fit$coefficients)) > 0L
  # The PCs are centered over all samples, but not necessarily over the subset
  # kept above. Therefore center each PC again before computing the total sum
  # of squares, otherwise R2 would be overestimated when samples were dropped.
  y0_centered <- y0 - rowMeans(y0)
  sst <- rowSums(y0_centered^2)
  # Residual sum of squares equals residual degrees of freedom times sigma^2.
  ssr <- sst - fit$df.residual * fit$sigma^2
  result <- ssr / sst
  result[had_problems] <- NA
  result
}

#' Calculation of R2 Matrix between Sample Variables and Principal Components
#'
#' @description `r lifecycle::badge("stable")`
#'
#' This function processes sample variables from [`AnyHermesData`] and the
#' corresponding principal components matrix, and then generates the matrix of R2 values.
#'
#' @details
#'   - Note that only the `df` columns which are `numeric`, `character`, `factor` or
#'     `logical` are included in the resulting matrix, because other variable types are not
#'     supported.
#'   - In addition, `df` columns which are constant, all `NA`, or `character` or `factor`
#'     columns with too many levels (more than half the number of samples) are also
#'     dropped before the analysis.
#'
#' @param pca (`matrix`)\cr comprises principal components generated by [calc_pca()].
#' @param df (`data.frame`)\cr from the [SummarizedExperiment::colData()] of a
#'   [`AnyHermesData`] object.
#'
#' @return A matrix with R2 values for all combinations of sample variables (in columns)
#'   and principal components (in rows).
#'
#' @seealso [h_pca_var_rsquared()] which is used internally to calculate the R2 for one
#'   sample variable.
#'
#' @export
#'
#' @examples
#' object <- hermes_data %>%
#'   add_quality_flags() %>%
#'   filter() %>%
#'   normalize()
#'
#' # Obtain the principal components.
#' pca <- calc_pca(object)$x
#'
#' # Obtain the `colData` as a `data.frame`.
#' df <- as.data.frame(colData(object))
#'
#' # Correlate them.
#' r2_all <- h_pca_df_r2_matrix(pca, df)
#' str(r2_all)
#'
#' # We can see that only about half of the columns from `df` were
#' # used for the correlations.
#' ncol(r2_all)
#' ncol(df)
h_pca_df_r2_matrix <- function(pca, df) {
  assert_that(
    is.matrix(pca),
    is.data.frame(df),
    identical(nrow(pca), nrow(df))
  )
  n_samples <- nrow(df)
  # Sequentially filter down the columns in `df`. `drop = FALSE` is essential:
  # without it, `df` collapses to a plain vector as soon as a single column is
  # left, and the following `vapply()` calls would iterate over its elements.
  # Sample variable must be numeric, character, factor or logical.
  is_accepted_type <- vapply(df, function(x) {
    is.numeric(x) || is.character(x) || is.factor(x) || is.logical(x)
  }, logical(1L))
  df <- df[, is_accepted_type, drop = FALSE]
  # Sample variable cannot be completely `NA`.
  is_all_na <- vapply(df, all_na, logical(1L))
  df <- df[, !is_all_na, drop = FALSE]
  # Sample variable cannot have a constant value.
  is_all_constant <- vapply(df, is_constant, logical(1L))
  df <- df[, !is_all_constant, drop = FALSE]
  # Filter character or factor sample variable that has too many (more than half the
  # number of samples) unique values.
  too_many_levels <- vapply(df, function(x) {
    (is.character(x) || is.factor(x)) && (length(unique(x)) > n_samples / 2)
  }, logical(1L))
  df <- df[, !too_many_levels, drop = FALSE]
  # On all remaining columns, run R2 analysis vs. all principal components.
  # Result: one row per PC, one column per retained sample variable.
  vapply(
    X = df,
    FUN = h_pca_var_rsquared,
    pca = pca,
    FUN.VALUE = numeric(ncol(pca))
  )
}

# correlate-HermesDataPca ----

#' Correlation of Principal Components with Sample Variables
#'
#' @description `r lifecycle::badge("stable")`
#'
#' This `correlate()` method computes, for every sample variable in a [`AnyHermesData`]
#' object, the R2 values with each of the principal components of the samples.
#'
#' The matching `autoplot()` method can then show these results as a heatmap.
#'
#' @rdname pca_cor_samplevar
#' @aliases pca_cor_samplevar
#'
#' @param object (`HermesDataPca`)\cr input. It can be generated using [calc_pca()] function
#'   on [`AnyHermesData`].
#' @param data (`AnyHermesData`)\cr input that was used originally for the PCA. Its column
#'   names (the sample names) must be identical to the row names of the principal
#'   components matrix in `object`.
#'
#' @return A [`HermesDataPcaCor`] object with R2 values for all sample variables.
#'
#' @seealso [h_pca_df_r2_matrix()] which is used internally for the details.
#'
#' @export
#'
#' @examples
#' object <- hermes_data %>%
#'   add_quality_flags() %>%
#'   filter() %>%
#'   normalize()
#'
#' # Perform PCA and then correlate the principal components with the sample variables.
#' object_pca <- calc_pca(object)
#' result <- correlate(object_pca, object)
setMethod(
  f = "correlate",
  signature = c(object = "HermesDataPca"),
  definition = function(object, data) {
    # Principal component scores, with one row per sample.
    pc_scores <- object$x
    assert_that(
      is_hermes_data(data),
      is(object, "HermesDataPca"),
      identical(rownames(pc_scores), colnames(data))
    )
    # Sample variables, one column per variable.
    sample_vars <- as.data.frame(colData(data))
    .HermesDataPcaCor(h_pca_df_r2_matrix(pc_scores, sample_vars))
  }
)

# HermesDataPcaCor ----

#' @rdname pca_cor_samplevar
#' @aliases HermesDataPcaCor
#' @exportClass HermesDataPcaCor
.HermesDataPcaCor <- setClass( # nolint
  "HermesDataPcaCor",
  # Holds the R2 matrix produced by `h_pca_df_r2_matrix()` in a dedicated
  # class, so that methods such as `autoplot()` can dispatch on it.
  contains = "matrix"
)

# autoplot-HermesDataPcaCor ----

#' @describeIn pca_cor_samplevar This plot method uses the [ComplexHeatmap::Heatmap()] function
#'   to visualize a [`HermesDataPcaCor`] object.
#'
#' @param cor_colors (`function`)\cr color scale function for the R2 values in the heatmap,
#'   produced by [circlize::colorRamp2()]. By default -1, 0 and 1 are mapped to blue, white
#'   and red, respectively.
#' @param ... other arguments to be passed to [ComplexHeatmap::Heatmap()].
#'
#' @export
#'
#' @examples
#'
#' # Visualize the correlations in a heatmap.
#' autoplot(result)
#'
#' # We can also choose to not reorder the columns.
#' autoplot(result, cluster_columns = FALSE)
#'
#' # We can also choose break-points for color customization.
#' autoplot(
#'   result,
#'   cor_colors = circlize::colorRamp2(
#'     c(-0.5, -0.25, 0, 0.25, 0.5, 0.75, 1),
#'     c("blue", "green", "purple", "yellow", "orange", "red", "brown")
#'   )
#' )
setMethod(
  f = "autoplot",
  signature = c(object = "HermesDataPcaCor"),
  definition = function(object,
                        cor_colors = circlize::colorRamp2(
                          c(-1, 0, 1),
                          c("blue", "white", "red")
                        ),
                        ...) {
    # Strip the class and transpose, so that the heatmap rows correspond to the
    # columns of the stored matrix.
    heatmap_values <- t(as(object, "matrix"))
    ComplexHeatmap::Heatmap(
      matrix = heatmap_values,
      col = cor_colors,
      name = "R2",
      ...
    )
  }
)