R/celda_CG.R
9f8e453d
 #' @title Cell and feature clustering with Celda
 #' @description Clusters the rows and columns of a count matrix containing
 #'  single-cell data into L modules and K subpopulations, respectively.
 #' @param counts Integer matrix. Rows represent features and columns represent
 #'  cells.
 #' @param sampleLabel Vector or factor. Denotes the sample label for each cell
 #'  (column) in the count matrix.
 #' @param K Integer. Number of cell populations.
 #' @param L Integer. Number of feature modules.
 #' @param alpha Numeric. Concentration parameter for Theta. Adds a pseudocount
 #'  to each cell population in each sample. Default 1.
 #' @param beta Numeric. Concentration parameter for Phi. Adds a pseudocount to
 #'  each feature module in each cell population. Default 1.
 #' @param delta Numeric. Concentration parameter for Psi. Adds a pseudocount to
 #'  each feature in each module. Default 1.
 #' @param gamma Numeric. Concentration parameter for Eta. Adds a pseudocount to
 #'  the number of features in each module. Default 1.
 #' @param algorithm String. Algorithm to use for clustering cell
 #'  subpopulations. One of 'EM' or 'Gibbs'. The EM algorithm for cell
 #'  clustering is faster, especially for larger numbers of cells. However,
 #'  more chains may be required to ensure a good solution is found.
 #'  Default 'EM'.
 #' @param stopIter Integer. Number of iterations without improvement in the
 #'  log likelihood to stop inference. Default 10.
 #' @param maxIter Integer. Maximum number of iterations of Gibbs sampling to
 #'  perform. Default 200.
 #' @param splitOnIter Integer. On every `splitOnIter` iteration, a heuristic
 #'  will be applied to determine if a cell population or feature module should
 #'  be reassigned and another cell population or feature module should be
 #'  split into two clusters. To disable splitting, set to -1. Default 10.
 #' @param splitOnLast Logical. After `stopIter` iterations have been
 #'  performed without improvement, a heuristic will be applied to determine if
 #'  a cell population or feature module should be reassigned and another cell
 #'  population or feature module should be split into two clusters. If a split
 #'  occurs, then `stopIter` will be reset. Default TRUE.
 #' @param seed Integer. Passed to \link[withr]{with_seed}. For reproducibility,
 #'  a default value of 12345 is used. If NULL, no calls to
 #'  \link[withr]{with_seed} are made.
 #' @param nchains Integer. Number of random cluster initializations. Default 3.
 #' @param zInitialize Character. One of 'random', 'split', or 'predefined'.
 #'  With 'random', cells are randomly assigned to a population. With 'split',
 #'  cells will be split into sqrt(K) populations and then each population will
 #'  be subsequently split into another sqrt(K) populations. With 'predefined',
 #'  values in `zInit` will be used to initialize `z`. Default 'split'.
 #' @param yInitialize Character. One of 'random', 'split', or 'predefined'.
 #'  With 'random', features are randomly assigned to a module. With 'split',
 #'  features will be split into sqrt(L) modules and then each module will be
 #'  subsequently split into another sqrt(L) modules. With 'predefined', values
 #'  in `yInit` will be used to initialize `y`. Default 'split'.
 #' @param countChecksum Character. An MD5 checksum for the `counts` matrix.
 #'  Default NULL.
 #' @param zInit Integer vector. Sets initial starting values of z. Must be
 #'  supplied, and is only used, when `zInitialize = 'predefined'`.
 #'  Default NULL.
 #' @param yInit Integer vector. Sets initial starting values of y. Must be
 #'  supplied, and is only used, when `yInitialize = 'predefined'`.
 #'  Default NULL.
 #' @param logfile Character. Messages will be redirected to a file named
 #'  `logfile`. If NULL, messages will be printed to stdout. Default NULL.
 #' @param verbose Logical. Whether to print log messages. Default TRUE.
 #' @return An object of class `celda_CG` with the cell population clusters
 #'  stored in `z` and feature module clusters stored in `y`.
 #' @seealso `celda_G()` for feature clustering and `celda_C()` for clustering
 #'  cells. `celdaGridSearch()` can be used to run multiple values of K/L and
 #'  multiple chains in parallel.
 #' @examples
 #' data(celdaCGSim)
 #' celdaMod <- celda_CG(celdaCGSim$counts,
 #'     K = celdaCGSim$K,
 #'     L = celdaCGSim$L,
 #'     sampleLabel = celdaCGSim$sampleLabel,
 #'     nchains = 1)
 #' @import Rcpp RcppEigen
 #' @rawNamespace import(gridExtra, except = c(combine))
 #' @export
 celda_CG <- function(counts,
     sampleLabel = NULL,
     K,
     L,
     alpha = 1,
     beta = 1,
     delta = 1,
     gamma = 1,
     algorithm = c("EM", "Gibbs"),
     stopIter = 10,
     maxIter = 200,
     splitOnIter = 10,
     splitOnLast = TRUE,
     seed = 12345,
     nchains = 3,
     zInitialize = c("split", "random", "predefined"),
     yInitialize = c("split", "random", "predefined"),
     countChecksum = NULL,
     zInit = NULL,
     yInit = NULL,
     logfile = NULL,
     verbose = TRUE) {

     .validateCounts(counts)

     ## One definition of the model fit, shared by the seeded and unseeded
     ## paths below. Reordering of the final labels is always requested.
     fitModel <- function() {
         .celda_CG(counts = counts,
             sampleLabel = sampleLabel,
             K = K,
             L = L,
             alpha = alpha,
             beta = beta,
             delta = delta,
             gamma = gamma,
             algorithm = algorithm,
             stopIter = stopIter,
             maxIter = maxIter,
             splitOnIter = splitOnIter,
             splitOnLast = splitOnLast,
             nchains = nchains,
             zInitialize = zInitialize,
             yInitialize = yInitialize,
             countChecksum = countChecksum,
             zInit = zInit,
             yInit = yInit,
             logfile = logfile,
             verbose = verbose,
             reorder = TRUE)
     }

     if (is.null(seed)) {
         ## No seed requested: run with the current RNG state.
         fit <- fitModel()
     } else {
         ## Run the whole fit inside with_seed() for reproducibility.
         fit <- with_seed(seed, fitModel())
     }

     return(fit)
 }
 
5a6c5283
 # Internal worker behind celda_CG(). Runs `nchains` independent chains; each
 # chain initializes cell labels `z` and feature labels `y`, then alternates
 # feature-module and cell-population updates (with optional split heuristics)
 # until `maxIter` iterations or `stopIter` iterations without log-likelihood
 # improvement. The chain with the highest final log-likelihood is returned as
 # a `celda_CG` object, passed through .reorderCeldaCG() when `reorder` is
 # TRUE. Arguments have the same meaning as in celda_CG(). This function does
 # no seeding itself.
 .celda_CG <- function(counts,
     sampleLabel = NULL,
     K,
     L,
     alpha = 1,
     beta = 1,
     delta = 1,
     gamma = 1,
     algorithm = c("EM", "Gibbs"),
     stopIter = 10,
     maxIter = 200,
     splitOnIter = 10,
     splitOnLast = TRUE,
     nchains = 3,
     zInitialize = c("split", "random", "predefined"),
     yInitialize = c("split", "random", "predefined"),
     countChecksum = NULL,
     zInit = NULL,
     yInit = NULL,
     logfile = NULL,
     verbose = TRUE,
     reorder = TRUE) {

     .logMessages(paste(rep("-", 50), collapse = ""),
         logfile = logfile,
         append = FALSE,
         verbose = verbose)

     .logMessages("Starting Celda_CG: Clustering cells and genes.",
         logfile = logfile,
         append = TRUE,
         verbose = verbose)

     .logMessages(paste(rep("-", 50), collapse = ""),
         logfile = logfile,
         append = TRUE,
         verbose = verbose)

     startTime <- Sys.time()

     counts <- .processCounts(counts)
     if (is.null(countChecksum)) {
         countChecksum <- .createCountChecksum(counts)
     }

     sampleLabel <- .processSampleLabels(sampleLabel, ncol(counts))
     s <- as.integer(sampleLabel)

     algorithm <- match.arg(algorithm)
     ## Name of the cell-update routine; dispatched through do.call() below
     if (algorithm == "Gibbs") {
         algorithmFun <- ".cCCalcGibbsProbZ"
     } else {
         algorithmFun <- ".cCCalcEMProbZ"
     }
     zInitialize <- match.arg(zInitialize)
     yInitialize <- match.arg(yInitialize)

     allChains <- seq(nchains)

     # Pre-compute lgamma values
     lggamma <- lgamma(seq(0, nrow(counts) + L) + gamma)
     lgdelta <- c(NA, lgamma((seq(nrow(counts) + L) * delta)))

     ## Dimension names are identical for every chain, so build them once
     countNames <- list(row = rownames(counts),
         column = colnames(counts),
         sample = levels(sampleLabel))

     ## Log-likelihood of the current chain state. The closure reads the
     ## count summaries from this function's environment at call time, so it
     ## always reflects the latest assignments.
     calcLl <- function() {
         .cCGCalcLL(K = K,
             L = L,
             mCPByS = mCPByS,
             nTSByCP = nTSByCP,
             nByG = nByG,
             nByTS = nByTS,
             nGByTS = nGByTS,
             nS = nS,
             nG = nG,
             alpha = alpha,
             beta = beta,
             delta = delta,
             gamma = gamma)
     }

     bestChain <- NULL
     for (i in allChains) {
         ## Initialize cluster labels
         .logMessages(date(),
             ".. Initializing 'z' in chain",
             i,
             "with",
             paste0("'", zInitialize, "' "),
             logfile = logfile,
             append = TRUE,
             verbose = verbose)

         .logMessages(date(),
             ".. Initializing 'y' in chain",
             i,
             "with",
             paste0("'", yInitialize, "' "),
             logfile = logfile,
             append = TRUE,
             verbose = verbose)

         if (zInitialize == "predefined") {
             if (is.null(zInit)) {
                 stop("'zInit' needs to be specified when",
                     " zInitialize == 'predefined'.")
             }
             z <- .initializeCluster(K,
                 ncol(counts),
                 initial = zInit,
                 fixed = NULL)
         } else if (zInitialize == "split") {
             z <- .initializeSplitZ(
                 counts,
                 K = K,
                 alpha = alpha,
                 beta = beta)
         } else {
             z <- .initializeCluster(K,
                 ncol(counts),
                 initial = NULL,
                 fixed = NULL)
         }

         if (yInitialize == "predefined") {
             if (is.null(yInit)) {
                 stop("'yInit' needs to be specified when",
                     " yInitialize == 'predefined'.")
             }
             y <- .initializeCluster(L,
                     nrow(counts),
                     initial = yInit,
                     fixed = NULL)
         } else if (yInitialize == "split") {
             y <- .initializeSplitY(counts,
                     L,
                     beta = beta,
                     delta = delta,
                     gamma = gamma)
         } else {
             y <- .initializeCluster(L,
                     nrow(counts),
                     initial = NULL,
                     fixed = NULL)
         }

         zBest <- z
         yBest <- y

         ## Calculate counts one time up front
         p <- .cCGDecomposeCounts(counts, s, z, y, K, L)
         mCPByS <- p$mCPByS
         nTSByC <- p$nTSByC
         nTSByCP <- p$nTSByCP
         nCP <- p$nCP
         nByG <- p$nByG
         nByC <- p$nByC
         nByTS <- p$nByTS
         nGByTS <- p$nGByTS
         nGByCP <- p$nGByCP
         nM <- p$nM
         nG <- p$nG
         nS <- p$nS
         rm(p)

         ll <- calcLl()
         ## Start from the initial state so `llBest` is always defined for
         ## this chain, even if the sampling loop below never runs, and never
         ## carries over from a previous chain.
         llBest <- ll

         iter <- 1L
         numIterWithoutImprovement <- 0L
         doCellSplit <- TRUE
         doGeneSplit <- TRUE
         while (iter <= maxIter && numIterWithoutImprovement <= stopIter) {
             ## Gibbs sampling for each gene
             lgbeta <- lgamma(seq(0, max(nCP)) + beta)
             nextY <- .cGCalcGibbsProbY(counts = nGByCP,
                     nTSByC = nTSByCP,
                     nByTS = nByTS,
                     nGByTS = nGByTS,
                     nByG = nByG,
                     y = y,
                     L = L,
                     nG = nG,
                     beta = beta,
                     delta = delta,
                     gamma = gamma,
                     lgbeta = lgbeta,
                     lggamma = lggamma,
                     lgdelta = lgdelta)
             nTSByCP <- nextY$nTSByC
             nGByTS <- nextY$nGByTS
             nByTS <- nextY$nByTS
             nTSByC <- .rowSumByGroupChange(counts, nTSByC, nextY$y, y, L)
             y <- nextY$y

             ## Gibbs or EM sampling for each cell
             nextZ <- do.call(algorithmFun, list(counts = nTSByC,
                 mCPByS = mCPByS,
                 nGByCP = nTSByCP,
                 nCP = nCP,
                 nByC = nByC,
                 z = z,
                 s = s,
                 K = K,
                 nG = L,
                 nM = nM,
                 alpha = alpha,
                 beta = beta))
             mCPByS <- nextZ$mCPByS
             nTSByCP <- nextZ$nGByCP
             nCP <- nextZ$nCP
             nGByCP <- .colSumByGroupChange(counts, nGByCP, nextZ$z, z, K)
             z <- nextZ$z

             ## Log-likelihood before any split; used to decide whether the
             ## "split on last" heuristic should fire
             tempLl <- calcLl()

             ## Split feature modules either when the chain is about to stop
             ## without improvement (splitOnLast) or on every splitOnIter-th
             ## iteration while earlier attempts are still changing `y`
             if (L > 2 && iter != maxIter &&
                 ((numIterWithoutImprovement == stopIter &&
                     !all(tempLl > ll) && isTRUE(splitOnLast)) ||
                         (splitOnIter > 0 && iter %% splitOnIter == 0 &&
                             isTRUE(doGeneSplit)))) {
                 .logMessages(date(),
                     " .... Determining if any gene clusters should be split.",
                     logfile = logfile,
                     append = TRUE,
                     sep = "",
                     verbose = verbose)
                 res <- .cCGSplitY(counts,
                         y,
                         mCPByS,
                         nGByCP,
                         nTSByC,
                         nTSByCP,
                         nByG,
                         nByTS,
                         nGByTS,
                         nCP,
                         s,
                         z,
                         K,
                         L,
                         nS,
                         nG,
                         alpha,
                         beta,
                         delta,
                         gamma,
                         yProb = t(nextY$probs),
                         maxClustersToTry = max(L / 2, 10),
                         minCell = 3)
                 .logMessages(res$message,
                     logfile = logfile,
                     append = TRUE,
                     verbose = verbose)

                 # Reset convergence counter if a split occurred
                 if (!isTRUE(all.equal(y, res$y))) {
                     numIterWithoutImprovement <- 1L
                     doGeneSplit <- TRUE
                 } else {
                     doGeneSplit <- FALSE
                 }

                 ## Re-calculate variables
                 y <- res$y
                 nTSByCP <- res$nTSByCP
                 nByTS <- res$nByTS
                 nGByTS <- res$nGByTS
                 nTSByC <- .rowSumByGroup(counts, group = y, L = L)
             }

             ## Same heuristic for cell populations. The counter may have
             ## just been reset by the feature split above, so the condition
             ## is evaluated afresh here.
             if (K > 2 && iter != maxIter &&
                 ((numIterWithoutImprovement == stopIter &&
                     !all(tempLl > ll) && isTRUE(splitOnLast)) ||
                         (splitOnIter > 0 && iter %% splitOnIter == 0 &&
                             isTRUE(doCellSplit)))) {
                 .logMessages(date(),
                     " .... Determining if any cell clusters should be split.",
                     logfile = logfile,
                     append = TRUE,
                     sep = "",
                     verbose = verbose)
                 res <- .cCGSplitZ(counts,
                         mCPByS,
                         nTSByC,
                         nTSByCP,
                         nByG,
                         nByTS,
                         nGByTS,
                         nCP,
                         s,
                         z,
                         K,
                         L,
                         nS,
                         nG,
                         alpha,
                         beta,
                         delta,
                         gamma,
                         zProb = t(nextZ$probs),
                         maxClustersToTry = K,
                         minCell = 3)
                 .logMessages(res$message,
                     logfile = logfile,
                     append = TRUE,
                     verbose = verbose)

                 # Reset convergence counter if a split occurred
                 if (!isTRUE(all.equal(z, res$z))) {
                     numIterWithoutImprovement <- 0L
                     doCellSplit <- TRUE
                 } else {
                     doCellSplit <- FALSE
                 }

                 ## Re-calculate variables
                 z <- res$z
                 mCPByS <- res$mCPByS
                 nTSByCP <- res$nTSByCP
                 nCP <- res$nCP
                 nGByCP <- .colSumByGroup(counts, group = z, K = K)
             }

             ## Calculate complete likelihood
             tempLl <- calcLl()
             ## Keep this state if it beats every log-likelihood seen so far
             ## in the chain; the first iteration is always kept.
             if (all(tempLl > ll) || iter == 1) {
                 zBest <- z
                 yBest <- y
                 llBest <- tempLl
                 numIterWithoutImprovement <- 1L
             } else {
                 numIterWithoutImprovement <- numIterWithoutImprovement + 1L
             }
             ll <- c(ll, tempLl)

             .logMessages(date(),
                 " .... Completed iteration: ",
                 iter,
                 " | logLik: ",
                 tempLl,
                 logfile = logfile,
                 append = TRUE,
                 sep = "",
                 verbose = verbose)
             iter <- iter + 1L
         }

         ## Keep the chain with the highest final log-likelihood
         chainResult <- list(z = zBest,
             y = yBest,
             completeLogLik = ll,
             finalLogLik = llBest)

         if (is.null(bestChain) ||
                 chainResult$finalLogLik > bestChain$finalLogLik) {
             bestChain <- chainResult
         }

         .logMessages(date(),
             ".. Finished chain",
             i,
             logfile = logfile,
             append = TRUE,
             verbose = verbose)
     }

     ## Build the model object from the best chain. Using zBest/yBest/ll/llBest
     ## here would silently return whichever chain happened to run last.
     bestResult <- methods::new("celda_CG",
         clusters = list(z = bestChain$z, y = bestChain$y),
         params = list(K = as.integer(K),
             L = as.integer(L),
             alpha = alpha,
             beta = beta,
             delta = delta,
             gamma = gamma,
             countChecksum = countChecksum),
         completeLogLik = bestChain$completeLogLik,
         finalLogLik = bestChain$finalLogLik,
         sampleLabel = sampleLabel,
         names = countNames)

     ## Perform reordering on final Z and Y assignments
     if (isTRUE(reorder)) {
         bestResult <- .reorderCeldaCG(counts = counts, res = bestResult)
     }

     endTime <- Sys.time()
     .logMessages(paste(rep("-", 50), collapse = ""),
         logfile = logfile,
         append = TRUE,
         verbose = verbose)
     .logMessages("Completed Celda_CG. Total time:",
         format(difftime(endTime, startTime)),
         logfile = logfile,
         append = TRUE,
         verbose = verbose)
     .logMessages(paste(rep("-", 50), collapse = ""),
         logfile = logfile,
         append = TRUE,
         verbose = verbose)

     return(bestResult)
 }
af260337
 
 
63a280d7
 #' @title Simulate cells from the celda_CG model
 #' @description Generates a simulated counts matrix, cell subpopulation
 #'  clusters, sample labels, and feature module clusters according to the
 #'  generative process of the celda_CG model.
 #' @param model Character. Options available in `celda::availableModels`.
 #' @param S Integer. Number of samples to simulate. Default 5.
 #' @param CRange Integer vector. A vector of length 2 that specifies the lower
 #'  and upper bounds of the number of cells to be generated in each sample.
 #'  Default c(50, 100).
 #' @param NRange Integer vector. A vector of length 2 that specifies the lower
 #'  and upper bounds of the number of counts generated for each cell. Default
 #'  c(500, 1000).
 #' @param G Integer. The total number of features to be simulated. Default 100.
 #' @param K Integer. Number of cell populations. Default 5.
 #' @param L Integer. Number of feature modules. Default 10.
 #' @param alpha Numeric. Concentration parameter for Theta. Adds a pseudocount
 #'  to each cell population in each sample. Default 1.
 #' @param beta Numeric. Concentration parameter for Phi. Adds a pseudocount to
 #'  each feature module in each cell population. Default 1.
 #' @param gamma Numeric. Concentration parameter for Eta. Adds a pseudocount to
 #'  the number of features in each module. Default 5.
 #' @param delta Numeric. Concentration parameter for Psi. Adds a pseudocount to
 #'  each feature in each module. Default 1.
 #' @param seed Integer. Passed to \link[withr]{with_seed}. For reproducibility,
 #'  a default value of 12345 is used. If NULL, no calls to
 #'  \link[withr]{with_seed} are made.
 #' @param ... Additional parameters.
 #' @return List. Contains the simulated matrix `counts`, cell population
 #'  clusters `z`, feature module clusters `y`, sample assignments `sampleLabel`,
 #'  and input parameters.
 #' @seealso `celda_C()` for simulating cell subpopulations and `celda_G()` for
 #'  simulating feature modules.
 #' @examples
 #' celdaCGSim <- simulateCells(model = "celda_CG")
 #' @export
 simulateCellscelda_CG <- function(model,
     S = 5,
     CRange = c(50, 100),
     NRange = c(500, 1000),
     G = 100,
     K = 5,
     L = 10,
     alpha = 1,
     beta = 1,
     gamma = 5,
     delta = 1,
     seed = 12345,
     ...) {

     ## Without a seed, simulate directly from the current RNG state
     if (is.null(seed)) {
         return(.simulateCellscelda_CG(model = model,
             S = S,
             CRange = CRange,
             NRange = NRange,
             G = G,
             K = K,
             L = L,
             alpha = alpha,
             beta = beta,
             gamma = gamma,
             delta = delta,
             ...))
     }

     ## Otherwise evaluate the same call inside with_seed()
     sim <- with_seed(seed,
         .simulateCellscelda_CG(model = model,
             S = S,
             CRange = CRange,
             NRange = NRange,
             G = G,
             K = K,
             L = L,
             alpha = alpha,
             beta = beta,
             gamma = gamma,
             delta = delta,
             ...))
     return(sim)
 }
 
 
 # Internal worker behind simulateCellscelda_CG(). Draws a count matrix from
 # the celda_CG generative process and returns a list with the (reordered)
 # cell labels `z`, feature labels `y`, `sampleLabel`, `counts`, and the input
 # parameters. `model` and `...` are accepted but not used in the body. This
 # function does no seeding itself.
 .simulateCellscelda_CG <- function(model = model,
     S = S,
     CRange = CRange,
     NRange = NRange,
     G = G,
     K = K,
     L = L,
     alpha = alpha,
     beta = beta,
     gamma = gamma,
     delta = delta,
     ...) {

     ## Number of cells per sample.
     ## Index with sample.int() rather than calling sample() on the range:
     ## sample(x, ...) with a length-one numeric `x` draws from 1:x, so a
     ## degenerate range such as c(50, 50) would not yield 50. For longer
     ## ranges this consumes the RNG exactly as sample() does.
     cellRange <- seq(CRange[1], CRange[2])
     nC <- cellRange[sample.int(length(cellRange), size = S, replace = TRUE)]
     nCSum <- sum(nC)
     cellSampleLabel <- rep(seq(S), nC)

     ## Select number of transcripts per cell (same length-one guard)
     countRange <- seq(NRange[1], NRange[2])
     nN <- countRange[sample.int(length(countRange),
         size = length(cellSampleLabel),
         replace = TRUE)]

     ## Generate cell population distribution for each sample
     theta <- t(.rdirichlet(S, rep(alpha, K)))

     ## Assign cells to cellular subpopulations
     z <- unlist(lapply(seq(S), function(i) {
         sample(seq(K),
             size = nC[i],
             prob = theta[, i],
             replace = TRUE)
     }))

     ## Generate transcriptional state distribution for each cell subpopulation
     phi <- .rdirichlet(K, rep(beta, L))

     ## Assign genes to gene modules
     eta <- .rdirichlet(1, rep(gamma, L))
     y <- sample(seq(L),
         size = G,
         prob = eta,
         replace = TRUE)
     if (length(table(y)) < L) {
         warning("Some gene modules did not receive any genes after sampling.",
             " Try increasing G and/or making gamma larger.")
         ## Shrink L to the modules actually used and relabel them 1..L
         L <- length(table(y))
         y <- as.integer(as.factor(y))
     }

     ## Per-module feature distribution; zero outside the module's features
     psi <- matrix(0, nrow = G, ncol = L)
     for (i in seq(L)) {
         ind <- y == i
         psi[ind, i] <- .rdirichlet(1, rep(delta, sum(ind)))
     }

     ## Select transcript distribution for each cell
     cellCounts <- matrix(0, nrow = G, ncol = nCSum)
     for (i in seq(nCSum)) {
         ## Split the cell's total count across modules ...
         transcriptionalStateDist <- as.integer(stats::rmultinom(1,
             size = nN[i], prob = phi[z[i], ]))
         ## ... then spread each module's share across its features
         for (j in seq(L)) {
             if (transcriptionalStateDist[j] > 0) {
                 cellCounts[, i] <- cellCounts[, i] + stats::rmultinom(1,
                     size = transcriptionalStateDist[j], prob = psi[, j])
             }
         }
     }

     ## Ensure that there are no all-0 rows in the counts matrix, which
     ## violates a celda modeling constraint (columns are guaranteed at least
     ## one count):
     zeroRowIdx <- which(rowSums(cellCounts) == 0)
     if (length(zeroRowIdx) > 0) {
         ## drop = FALSE keeps a matrix even if a single row remains
         cellCounts <- cellCounts[-zeroRowIdx, , drop = FALSE]
         y <- y[-zeroRowIdx]
     }

     ## Assign gene/cell/sample names
     rownames(cellCounts) <- paste0("Gene_", seq(nrow(cellCounts)))
     colnames(cellCounts) <- paste0("Cell_", seq(ncol(cellCounts)))
     cellSampleLabel <- paste0("Sample_", seq(S))[cellSampleLabel]
     cellSampleLabel <- factor(cellSampleLabel,
         levels = paste0("Sample_", seq(S)))

     ## Perform reordering on final Z and Y assignments:
     cellCounts <- .processCounts(cellCounts)
     names <- list(row = rownames(cellCounts),
         column = colnames(cellCounts),
         sample = unique(cellSampleLabel))
     countChecksum <- .createCountChecksum(cellCounts)
     result <- methods::new("celda_CG",
         clusters = list(z = z, y = y),
         params = list(K = as.integer(K),
             L = as.integer(L),
             alpha = alpha,
             beta = beta,
             delta = delta,
             gamma = gamma,
             countChecksum = countChecksum),
         sampleLabel = cellSampleLabel,
         names = names
     )

     result <- .reorderCeldaCG(counts = cellCounts, res = result)

     return(list(z = clusters(result)$z,
         y = clusters(result)$y,
         sampleLabel = cellSampleLabel,
         counts = cellCounts,
         K = K,
         L = L,
         CRange = CRange,
         NRange = NRange,
         S = S,
         alpha = alpha,
         beta = beta,
         gamma = gamma,
         delta = delta)
     )
 }
 
 
367fa0a7
 #' @title Matrix factorization for results from celda_CG
 #' @description Generates factorized matrices showing the contribution of each
 #'  feature in each module, each module in each cell and/or cell population,
 #'  and each cell population in each sample.
 #' @param counts Integer matrix. Rows represent features and columns represent
 #'  cells. This matrix should be the same as the one used to generate
 #'  `celdaMod`.
 #' @param celdaMod Celda model object of class "celda_CG".
 #' @param type Character vector. A vector containing one or more of "counts",
 #'  "proportion", or "posterior". "counts" returns the raw number of counts for
 #'  each factorized matrix. "proportion" returns the normalized probabilities
 #'  for each factorized matrix, which are calculated by dividing the raw counts
 #'  in each factorized matrix by the total counts in each column. "posterior"
 #'  returns the posterior estimates. Default
 #'  `c("counts", "proportion", "posterior")`.
 #' @return A list with elements for `counts`, `proportions`, or `posterior`
 #'  probabilities. Each element will be a list containing factorized matrices
 #'  for `module`, `cellPopulation`, and `sample`, plus a `geneDistribution`
 #'  entry. Additionally, the contribution of each module in each individual
 #'  cell will be included in the `cell` element of `counts` and `proportions`
 #'  elements.
 #' @seealso `celda_CG()` for clustering features and cells
 #' @examples
 #' data(celdaCGSim, celdaCGMod)
 #' factorizedMatrices <- factorizeMatrix(celdaCGSim$counts,
 #'     celdaCGMod, "posterior")
 #' @export
 setMethod("factorizeMatrix", signature(celdaMod = "celda_CG"),
     function(counts,
         celdaMod,
         type = c("counts", "proportion", "posterior")) {

         counts <- .processCounts(counts)
         compareCountMatrix(counts, celdaMod)

         K <- params(celdaMod)$K
         L <- params(celdaMod)$L
         z <- clusters(celdaMod)$z
         y <- clusters(celdaMod)$y
         alpha <- params(celdaMod)$alpha
         beta <- params(celdaMod)$beta
         delta <- params(celdaMod)$delta
         gamma <- params(celdaMod)$gamma
         sampleLabel <- sampleLabel(celdaMod)
         s <- as.integer(sampleLabel)

         ## Calculate counts one time up front
         p <- .cCGDecomposeCounts(counts, s, z, y, K, L)
         nG <- p$nG
         mCPByS <- p$mCPByS
         nTSByC <- p$nTSByC
         nTSByCP <- p$nTSByCP

         ## Feature-by-module matrix: row g holds p$nByG[g] in the column of
         ## the module feature g is assigned to, and zero elsewhere.
         ## (p$nByG is presumably the total count per feature -- confirm
         ## against .cCGDecomposeCounts.)
         nGByTS <- matrix(0, nrow = length(y), ncol = L)
         nGByTS[cbind(seq(nG), y)] <- p$nByG

         LNames <- paste0("L", seq(L))
         KNames <- paste0("K", seq(K))
         colnames(nTSByC) <- matrixNames(celdaMod)$column
         rownames(nTSByC) <- LNames
         colnames(nGByTS) <- LNames
         rownames(nGByTS) <- matrixNames(celdaMod)$row
         rownames(mCPByS) <- KNames
         colnames(mCPByS) <- matrixNames(celdaMod)$sample
         colnames(nTSByCP) <- KNames
         rownames(nTSByCP) <- LNames

         res <- list()

         ## NOTE(review): `geneDistribution` is derived from the
         ## feature-by-module matrix in all three outputs below. That keeps
         ## the existing values; whether it was meant to summarise the number
         ## of features per module instead should be confirmed.

         if ("counts" %in% type) {
             countsList <- list(sample = mCPByS,
                 cellPopulation = nTSByCP,
                 cell = nTSByC,
                 module = nGByTS,
                 geneDistribution = nGByTS)
             res <- c(res, list(counts = countsList))
         }

         if ("proportion" %in% type) {
             ## Need to avoid normalizing cell/gene states with zero
             ## cells/genes, so only columns that are in use are normalized
             uniqueZ <- sort(unique(z))
             propNTSByCP <- nTSByCP
             propNTSByCP[, uniqueZ] <- normalizeCounts(
                 propNTSByCP[, uniqueZ],
                 normalize = "proportion")

             ## Per-module (column-wise) proportions. This result used to be
             ## overwritten by nGByTS / sum(nGByTS), so `module` was scaled by
             ## the grand total rather than by each column as documented.
             uniqueY <- sort(unique(y))
             propNGByTS <- nGByTS
             propNGByTS[, uniqueY] <- normalizeCounts(
                 propNGByTS[, uniqueY],
                 normalize = "proportion")

             propList <- list(sample = normalizeCounts(mCPByS,
                 normalize = "proportion"),
                 cellPopulation = propNTSByCP,
                 cell = normalizeCounts(nTSByC, normalize = "proportion"),
                 module = propNGByTS,
                 geneDistribution = nGByTS / sum(nGByTS))
             res <- c(res, list(proportions = propList))
         }

         if ("posterior" %in% type) {
             ## Add the delta pseudocount only to each feature's own module
             ## entry, then normalize
             gs <- nGByTS
             gs[cbind(seq(nG), y)] <- gs[cbind(seq(nG), y)] + delta
             gs <- normalizeCounts(gs, normalize = "proportion")

             postList <- list(sample = normalizeCounts(mCPByS + alpha,
                 normalize = "proportion"),
                 cellPopulation = normalizeCounts(nTSByCP + beta,
                     normalize = "proportion"),
                 module = gs,
                 geneDistribution = (nGByTS + gamma) / sum(nGByTS + gamma))
             res <- c(res, list(posterior = postList))
         }

         return(res)
     })
af260337
 
 # Calculate the loglikelihood for the celda_CG model
5a6c5283
 .cCGCalcLL <- function(K,
     L,
     mCPByS,
     nTSByCP,
     nByG,
     nByTS,
     nGByTS,
     nS,
     nG,
     alpha,
     beta,
     delta,
     gamma) {

     ## Sums four log-likelihood components (Theta, Phi, Psi, Eta) built
     ## from lgamma terms. The supplied `nG` is not used as passed: it is
     ## replaced by the total of `nGByTS` before the Psi term is computed.
     nG <- sum(nGByTS)

     ## "Theta" component: cell populations within samples
     thetaLl <- nS * lgamma(K * alpha) +
         sum(lgamma(mCPByS + alpha)) -
         nS * K * lgamma(alpha) -
         sum(lgamma(colSums(mCPByS + alpha)))

     ## "Phi" component: modules within cell populations
     phiLl <- K * lgamma(L * beta) +
         sum(lgamma(nTSByCP + beta)) -
         K * L * lgamma(beta) -
         sum(lgamma(colSums(nTSByCP + beta)))

     ## "Psi" component: features within modules
     psiLl <- sum(lgamma(nGByTS * delta)) +
         sum(lgamma(nByG + delta)) -
         nG * lgamma(delta) -
         sum(lgamma(nByTS + (nGByTS * delta)))

     ## "Eta" component: number of features per module
     etaLl <- lgamma(L * gamma) +
         sum(lgamma(nGByTS + gamma)) -
         L * lgamma(gamma) -
         lgamma(sum(nGByTS + gamma))

     return(thetaLl + phiLl + psiLl + etaLl)
 }
 
 
a82d1665
 #' @title Calculate Celda_CG log likelihood
5a6c5283
 #' @description Calculates the log likelihood for user-provided cell population
 #'  and feature module clusters using the `celda_CG()` model.
 #' @param counts Integer matrix. Rows represent features and columns represent
 #'  cells.
 #' @param sampleLabel Vector or factor. Denotes the sample label for each cell
 #'  (column) in the count matrix.
bf9aff71
 #' @param z Numeric vector. Denotes cell population labels.
 #' @param y Numeric vector. Denotes feature module labels.
 #' @param K Integer. Number of cell populations.
 #' @param L Integer. Number of feature modules.
5a6c5283
 #' @param alpha Numeric. Concentration parameter for Theta. Adds a pseudocount
 #'  to each cell population in each sample. Default 1.
 #' @param beta Numeric. Concentration parameter for Phi. Adds a pseudocount to
 #'  each feature module in each cell population. Default 1.
 #' @param delta Numeric. Concentration parameter for Psi. Adds a pseudocount to
 #'  each feature in each module. Default 1.
 #' @param gamma Numeric. Concentration parameter for Eta. Adds a pseudocount to
 #'  the number of features in each module. Default 1.
a82d1665
 #' @return The log likelihood for the given cluster assignments
 #' @seealso `celda_CG()` for clustering features and cells
6839068a
 #' @examples
a49fff03
 #' data(celdaCGSim)
804d499b
 #' loglik <- logLikelihoodcelda_CG(celdaCGSim$counts,
5a6c5283
 #'     sampleLabel = celdaCGSim$sampleLabel,
 #'     z = celdaCGSim$z,
 #'     y = celdaCGSim$y,
 #'     K = celdaCGSim$K,
 #'     L = celdaCGSim$L,
 #'     alpha = celdaCGSim$alpha,
 #'     beta = celdaCGSim$beta,
 #'     gamma = celdaCGSim$gamma,
 #'     delta = celdaCGSim$delta)
bf9aff71
 #'
5a6c5283
 #' loglik <- logLikelihood(celdaCGSim$counts,
 #'     model = "celda_CG",
 #'     sampleLabel = celdaCGSim$sampleLabel,
 #'     z = celdaCGSim$z,
 #'     y = celdaCGSim$y,
 #'     K = celdaCGSim$K,
 #'     L = celdaCGSim$L,
 #'     alpha = celdaCGSim$alpha,
 #'     beta = celdaCGSim$beta,
 #'     gamma = celdaCGSim$gamma,
 #'     delta = celdaCGSim$delta)
af260337
 #' @export
804d499b
 logLikelihoodcelda_CG <- function(counts,
     sampleLabel,
     z,
     y,
     K,
     L,
     alpha,
     beta,
     delta,
     gamma) {

     ## Reject labels that point past the declared number of clusters
     if (sum(z > K) > 0) {
         stop("An entry in z contains a value greater than the provided K.")
     }
     if (sum(y > L) > 0) {
         stop("An entry in y contains a value greater than the provided L.")
     }

     ## Summarize the counts once, then hand the pieces to the likelihood
     sampleLabel <- .processSampleLabels(sampleLabel, ncol(counts))
     decomposed <- .cCGDecomposeCounts(counts,
         as.integer(sampleLabel),
         z,
         y,
         K,
         L)

     logLik <- .cCGCalcLL(K = K,
         L = L,
         mCPByS = decomposed$mCPByS,
         nTSByCP = decomposed$nTSByCP,
         nByG = decomposed$nByG,
         nByTS = decomposed$nByTS,
         nGByTS = decomposed$nGByTS,
         nS = decomposed$nS,
         nG = decomposed$nG,
         alpha = alpha,
         beta = beta,
         delta = delta,
         gamma = gamma)
     return(logLik)
 }
 
 
5a6c5283
 # Takes raw counts matrix and converts it to a series of matrices needed for
 # log likelihood calculation
 # @param counts Integer matrix. Rows represent features and columns represent
 # cells.
 # @param s Integer vector. Contains the sample label for each cell (column) in
 # the count matrix.
bf9aff71
 # @param z Numeric vector. Denotes cell population labels.
 # @param y Numeric vector. Denotes feature module labels.
 # @param K Integer. Number of cell populations.
 # @param L Integer. Number of feature modules.
5a6c5283
 .cCGDecomposeCounts <- function(counts, s, z, y, K, L) {
     ## Pre-computes the summaries of `counts` that the celda_CG routines
     ## in this file read back by name from the returned list.
     sampleCount <- length(unique(s))

     ## Cross-tabulation of z (levels 1..K) against the sample labels
     popBySample <- table(factor(z, levels = seq(K)), s)

     ## Counts collapsed by module (rows), then also by population (columns)
     moduleByCell <- .rowSumByGroup(counts, group = y, L = L)
     moduleByPop <- .colSumByGroup(moduleByCell, group = z, K = K)

     featureTotals <- as.integer(rowSums(counts))
     moduleTotals <- as.integer(.rowSumByGroup(
         matrix(featureTotals, ncol = 1),
         group = y,
         L = L))

     ## Add pseudogene to each module
     featuresPerModule <- tabulate(y, L) + 1

     return(list(mCPByS = matrix(as.integer(popBySample), ncol = sampleCount),
         nTSByC = moduleByCell,
         nTSByCP = moduleByPop,
         nCP = as.integer(colSums(moduleByPop)),
         nByG = featureTotals,
         nByC = as.integer(colSums(counts)),
         nByTS = moduleTotals,
         nGByTS = featuresPerModule,
         nGByCP = .colSumByGroup(counts, group = z, K = K),
         nM = ncol(counts),
         nG = nrow(counts),
         nS = sampleCount))
 }
af260337
 
 
5a6c5283
 #' @title Conditional probabilities for cells and features from a Celda_CG
 #'  model
 #' @description Calculates the conditional probability of each cell belonging
 #'  to each subpopulation given all other cell cluster assignments as well as
 #'  each feature belonging to each module given all other feature cluster
 #'  assignments in a `celda_CG()` result.
 #' @param celdaMod Celda object of class `celda_CG`.
 #' @param counts Integer matrix. Rows represent features and columns represent
 #'  cells. This matrix should be the same as the one used to generate
 #'  `celdaMod`.
 #' @param log Logical. If FALSE, then the normalized conditional probabilities
 #'  will be returned. If TRUE, then the unnormalized log probabilities will be
 #'  returned. Default FALSE.
af260337
 #' @param ... Additional parameters.
5a6c5283
 #' @return A list containing a matrix for the conditional cell and feature
 #'  cluster probabilities.
a82d1665
 #' @seealso `celda_CG()` for clustering features and cells
24363d24
 #' @examples
a49fff03
 #' data(celdaCGSim, celdaCGMod)
5a6c5283
 #' clusterProb <- clusterProbability(celdaCGSim$counts, celdaCGMod)
af260337
 #' @export
5a6c5283
 setMethod("clusterProbability", signature(celdaMod = "celda_CG"),
     function(counts, celdaMod, log = FALSE, ...) {
         ## Read the labels and hyperparameters stored on the model
         modelParams <- params(celdaMod)
         modelClusters <- clusters(celdaMod)
         s <- as.integer(sampleLabel(celdaMod))
         z <- modelClusters$z
         y <- modelClusters$y
         K <- modelParams$K
         L <- modelParams$L
         alpha <- modelParams$alpha
         beta <- modelParams$beta
         delta <- modelParams$delta
         gamma <- modelParams$gamma

         decomp <- .cCGDecomposeCounts(counts, s, z, y, K, L)

         ## lgamma lookup tables handed to the feature-side calculation
         tableSize <- nrow(counts) + L
         lgbeta <- lgamma(seq(0, max(decomp$nCP)) + beta)
         lggamma <- lgamma(seq(0, tableSize) + gamma)
         lgdelta <- c(NA, lgamma((seq(tableSize) * delta)))

         ## Cell side: the module-by-cell counts are passed as the "counts",
         ## so L takes the place of the feature count
         cellSide <- .cCCalcGibbsProbZ(counts = decomp$nTSByC,
             mCPByS = decomp$mCPByS,
             nGByCP = decomp$nTSByCP,
             nCP = decomp$nCP,
             nByC = decomp$nByC,
             z = z,
             s = s,
             K = K,
             nG = L,
             nM = decomp$nM,
             alpha = alpha,
             beta = beta,
             doSample = FALSE)

         ## Feature side: the feature-by-population counts are passed as the
         ## "counts"
         featureSide <- .cGCalcGibbsProbY(counts = decomp$nGByCP,
             nTSByC = decomp$nTSByCP,
             nByTS = decomp$nByTS,
             nGByTS = decomp$nGByTS,
             nByG = decomp$nByG,
             y = y,
             L = L,
             nG = decomp$nG,
             lgbeta = lgbeta,
             lgdelta = lgdelta,
             lggamma = lggamma,
             delta = delta,
             doSample = FALSE)

         zProb <- t(cellSide$probs)
         yProb <- t(featureSide$probs)

         ## Unless log output was requested, normalize the log probabilities
         if (!isTRUE(log)) {
             zProb <- .normalizeLogProbs(zProb)
             yProb <- .normalizeLogProbs(yProb)
         }

         return(list(zProbability = zProb, yProbability = yProb))
     })
af260337
 
 
91de8c6d
 #' @title Calculate the perplexity on new data with a celda_CG model
5a6c5283
 #' @description Perplexity is a statistical measure of how well a probability
 #'  model can predict new data. Lower perplexity indicates a better model.
 #' @param counts Integer matrix. Rows represent features and columns represent
 #'  cells. This matrix should be the same as the one used to generate
 #'  `celdaMod`.
 #' @param celdaMod Celda object of class "celda_C", "celda_G" or "celda_CG".
 #' @param newCounts A new counts matrix used to calculate perplexity. If NULL,
 #'  perplexity will be calculated for the 'counts' matrix. Default NULL.
91f4ba11
 #' @return Numeric. The perplexity for the provided count data and model.
91de8c6d
 #' @seealso `celda_CG()` for clustering features and cells
91f4ba11
 #' @examples
a49fff03
 #' data(celdaCGSim, celdaCGMod)
5a6c5283
 #' perplexity <- perplexity(celdaCGSim$counts, celdaCGMod)
d7196f24
 #' @importFrom matrixStats logSumExp
af260337
 #' @export
5a6c5283
 setMethod("perplexity", signature(celdaMod = "celda_CG"),
     function(counts, celdaMod, newCounts = NULL) {

         if (!("celda_CG" %in% class(celdaMod))) {
             stop("The celdaMod provided was not of class celda_CG.")
         }

         counts <- .processCounts(counts)
         compareCountMatrix(counts, celdaMod)

         ## Score the training matrix itself when no new data is given
         newCounts <- if (is.null(newCounts)) {
             counts
         } else {
             .processCounts(newCounts)
         }
         if (nrow(newCounts) != nrow(counts)) {
             stop("newCounts should have the same number of rows as counts.")
         }

         factorized <- factorizeMatrix(counts = counts,
             celdaMod = celdaMod,
             type = c("posterior", "counts"))
         posterior <- factorized$posterior

         logTheta <- log(posterior$sample)
         s <- as.integer(sampleLabel(celdaMod))

         ## Eta term: computed but currently left out of the total below
         etaProb <- log(posterior$geneDistribution) *
             factorized$counts$geneDistribution

         ## Log probability of each feature under each cell population
         geneByPopProb <- log(posterior$module %*% posterior$cellPopulation)

         ## Same as (t(geneByPopProb) %*% newCounts) + logTheta[, s]
         innerLogProb <- eigenMatMultInt(geneByPopProb, newCounts) +
             logTheta[, s]

         ## Sum of the per-column logSumExp values; sum(etaProb) not added
         logPx <- sum(apply(innerLogProb, 2, matrixStats::logSumExp))

         return(exp(- (logPx / sum(newCounts))))
     })
 
 
 .reorderCeldaCG <- function(counts, res) {
     ## Relabels the cell populations and then the feature modules of `res`
     ## in the leaf order of a complete-linkage hierarchical clustering of
     ## their posterior profiles. Returns the recoded model. Each step is
     ## skipped unless more than two clusters were requested and more than
     ## one label is actually in use.

     # Reorder K
     ## `&&` is used because the condition of `if` must be a single value
     if (params(res)$K > 2 && isTRUE(length(unique(clusters(res)$z)) > 1)) {

         ## Compact the labels in use to consecutive integers starting at 1
         res@clusters$z <- as.integer(as.factor(clusters(res)$z))
         fm <- factorizeMatrix(counts = counts,
             celdaMod = res,
             type = "posterior")
         uniqueZ <- sort(unique(clusters(res)$z))
         d <- .cosineDist(fm$posterior$cellPopulation[, uniqueZ])
         h <- stats::hclust(d, method = "complete")

         res <- recodeClusterZ(res,
             from = h$order,
             to = seq_along(h$order))
     }

     # Reorder L
     if (params(res)$L > 2 && isTRUE(length(unique(clusters(res)$y)) > 1)) {

         res@clusters$y <- as.integer(as.factor(clusters(res)$y))
         fm <- factorizeMatrix(counts = counts,
             celdaMod = res,
             type = "posterior")
         uniqueY <- sort(unique(clusters(res)$y))

         ## Rows in use are transposed into columns and rescaled to sum to
         ## one before distances are taken
         cs <- prop.table(t(fm$posterior$cellPopulation[uniqueY, ]), 2)
         d <- .cosineDist(cs)
         h <- stats::hclust(d, method = "complete")

         res <- recodeClusterY(res,
             from = h$order,
             to = seq_along(h$order))
     }
     return(res)
 }
 
 
5533bb39
 #' @title Heatmap for celda_CG
5a6c5283
 #' @description Renders an expression heatmap to visualize `celda_CG()` results.
 #'  The top `nfeatures` for each module will be included in the heatmap.
 #' @param counts Integer matrix. Rows represent features and columns represent
 #'  cells. This matrix should be the same as the one used to generate
 #'  `celdaMod`.
 #' @param celdaMod Celda object of class `celda_CG`.
 #' @param nfeatures Integer. Maximum number of features to select for each
 #'  module. Default 25.
af260337
 #' @param ... Additional parameters.
5a6c5283
 #' @seealso `celda_CG()` for clustering features and cells and `celdaTsne()`
 #'  for generating 2-dimensional coordinates.
bf9aff71
 #' @examples
a49fff03
 #' data(celdaCGSim, celdaCGMod)
5a6c5283
 #' celdaHeatmap(celdaCGSim$counts, celdaCGMod)
5533bb39
 #' @return A list containing dendrograms and the heatmap grob
af260337
 #' @export
5a6c5283
 setMethod("celdaHeatmap", signature(celdaMod = "celda_CG"),
     function(counts, celdaMod, nfeatures = 25, ...) {
         ## Pick the top-ranked features of every module from the
         ## module proportions
         moduleProps <- factorizeMatrix(counts,
             celdaMod,
             type = "proportion")$proportions$module
         ranked <- celda::topRank(moduleProps, n = nfeatures)
         featureIx <- unlist(ranked$index)

         ## Heatmap values: proportion-normalized, square-root transformed
         normCounts <- normalizeCounts(counts,
             normalize = "proportion",
             transformationFun = sqrt)

         labels <- clusters(celdaMod)
         plotHeatmap(normCounts[featureIx, ],
             z = labels$z,
             y = labels$y[featureIx],
             ...)
     })
af260337
 
 
2ced6125
 #' @title tSNE for celda_CG
5a6c5283
 #' @description Embeds cells in two dimensions using tSNE based on a `celda_CG`
 #'  model. tSNE is run on module probabilities to reduce the number of features
 #'  instead of using PCA. Module probabilities are square-root transformed
 #'  before applying tSNE.
 #' @param counts Integer matrix. Rows represent features and columns represent
 #'  cells. This matrix should be the same as the one used to generate
 #'  `celdaMod`.
 #' @param celdaMod Celda object of class `celda_CG`.
 #' @param maxCells Integer. Maximum number of cells to plot. Cells will be
 #'  randomly subsampled if ncol(counts) > maxCells. Larger numbers of cells
 #'  requires more memory. Default 25000.
 #' @param minClusterSize Integer. Do not subsample cell clusters below this
 #'  threshold. Default 100.
 #' @param initialDims Integer. PCA will be used to reduce the dimensionality
 #'  of the dataset. The top 'initialDims' principal components will be used
 #'  for tSNE. Default 20.
 #' @param modules Integer vector. Determines which features modules to use for
 #'  tSNE. If NULL, all modules will be used. Default NULL.
3df37a6d
 #' @param perplexity Numeric. Perplexity parameter for tSNE. Default 20.
5a6c5283
 #' @param maxIter Integer. Maximum number of iterations in tSNE generation.
 #'  Default 2500.
7365c4e1
 #' @param seed Integer. Passed to \link[withr]{with_seed}. For reproducibility,
 #'  a default value of 12345 is used. If NULL, no calls to
 #'  \link[withr]{with_seed} are made.
5a6c5283
 #' @seealso `celda_CG()` for clustering features and cells  and `celdaHeatmap()`
 #'  for displaying expression
3855c560
 #' @examples
a49fff03
 #' data(celdaCGSim, celdaCGMod)
2e877ffe
 #' tsneRes <- celdaTsne(celdaCGSim$counts, celdaCGMod)
0bd3cd14
 #' @return A two column matrix of t-SNE coordinates
af260337
 #' @export
5a6c5283
 setMethod("celdaTsne", signature(celdaMod = "celda_CG"),
     function(counts,
         celdaMod,
         maxCells = 25000,
         minClusterSize = 100,
         initialDims = 20,
         modules = NULL,
         perplexity = 20,
         maxIter = 2500,
         seed = 12345) {

         ## Single place where the embedding is actually computed
         computeTsne <- function() {
             .celdaTsneCG(counts = counts,
                 celdaMod = celdaMod,
                 maxCells = maxCells,
                 minClusterSize = minClusterSize,
                 initialDims = initialDims,
                 modules = modules,
                 perplexity = perplexity,
                 maxIter = maxIter)
         }

         ## A NULL seed skips with_seed() entirely; otherwise the
         ## computation is evaluated inside with_seed()
         if (is.null(seed)) {
             res <- computeTsne()
         } else {
             res <- with_seed(seed, computeTsne())
         }

         return(res)
     })
ef11fcec
 
 
c358fce6
 .celdaTsneCG <- function(counts,
     celdaMod,
     maxCells = 25000,
     minClusterSize = 100,
     initialDims = 20,
     modules = NULL,
     perplexity = 20,
     maxIter = 2500) {

     ## Normalized module-level input plus the indices of the cells kept
     prepared <- .prepareCountsForDimReductionCeldaCG(counts,
         celdaMod,
         maxCells,
         minClusterSize,
         modules)

     embedding <- .calculateTsne(prepared$norm,
         doPca = FALSE,
         perplexity = perplexity,
         maxIter = maxIter,
         initialDims = initialDims)

     ## One row per column of `counts`; rows for cells that were not kept
     ## stay NA
     coords <- matrix(NA,
         nrow = ncol(counts),
         ncol = 2,
         dimnames = list(colnames(counts), c("tsne_1", "tsne_2")))
     coords[prepared$cellIx, ] <- embedding
     return(coords)
 }
 
 
8b60a355
 #' @title umap for celda_CG
5a6c5283
 #' @description Embeds cells in two dimensions using umap based on a `celda_CG`
 #'  model. umap is run on module probabilities to reduce the number of features
 #'  instead of using PCA. Module probabilities are square-root transformed
 #'  before applying umap.
bf9aff71
 #'
5a6c5283
 #' @param counts Integer matrix. Rows represent features and columns represent
 #'  cells. This matrix should be the same as the one used to generate
 #'  `celdaMod`.
 #' @param celdaMod Celda object of class `celda_CG`.
 #' @param maxCells Integer. Maximum number of cells to plot. Cells will be
 #'  randomly subsampled if ncol(counts) > maxCells. Larger numbers of cells
 #'  requires more memory. Default 25000.
 #' @param minClusterSize Integer. Do not subsample cell clusters below this
 #'  threshold. Default 100.
 #' @param modules Integer vector. Determines which feature modules to use for
 #'  umap. If NULL, all modules will be used. Default NULL.
7365c4e1
 #' @param seed Integer. Passed to \link[withr]{with_seed}. For reproducibility,
 #'  a default value of 12345 is used. If NULL, no calls to
 #'  \link[withr]{with_seed} are made.
5a6c5283
 #' @param umapConfig Object of class `umap.config`. Configures parameters for
 #'  umap. Default `umap::umap.defaults`.
 #' @seealso `celda_CG()` for clustering features and cells and `celdaHeatmap()`
 #'  for displaying expression.
8b60a355
 #' @examples
a49fff03
 #' data(celdaCGSim, celdaCGMod)
5a6c5283
 #' umapRes <- celdaUmap(celdaCGSim$counts, celdaCGMod)
6e069679
 #' @return A two column matrix of umap coordinates
8b60a355
 #' @export
 setMethod("celdaUmap",
     signature(celdaMod = "celda_CG"), function(counts,
         celdaMod,
         maxCells = 25000,
         minClusterSize = 100,
         modules = NULL,
         seed = 12345,
         umapConfig = umap::umap.defaults) {

         ## Single place where the embedding is actually computed
         computeUmap <- function() {
             .celdaUmapCG(counts = counts,
                 celdaMod = celdaMod,
                 maxCells = maxCells,
                 minClusterSize = minClusterSize,
                 modules = modules,
                 umapConfig = umapConfig)
         }

         ## A NULL seed skips with_seed() entirely; otherwise the
         ## computation is evaluated inside with_seed()
         if (is.null(seed)) {
             res <- computeUmap()
         } else {
             res <- with_seed(seed, computeUmap())
         }

         return(res)
     })
 
 
c358fce6
 .celdaUmapCG <- function(counts,
     celdaMod,
     maxCells = 25000,
     minClusterSize = 100,
     modules = NULL,
     umapConfig = umap::umap.defaults) {

     ## Normalized module-level input plus the indices of the cells kept
     prepared <- .prepareCountsForDimReductionCeldaCG(counts,
         celdaMod,
         maxCells,
         minClusterSize,
         modules)
     embedding <- .calculateUmap(prepared$norm, umapConfig)

     ## One row per column of `counts`; rows for cells that were not kept
     ## stay NA
     coords <- matrix(NA,
         nrow = ncol(counts),
         ncol = 2,
         dimnames = list(colnames(counts), c("umap_1", "umap_2")))
     coords[prepared$cellIx, ] <- embedding
     return(coords)
 }
 
 
5a6c5283
 .prepareCountsForDimReductionCeldaCG <- function(counts,
     celdaMod,
     maxCells = 25000,
     minClusterSize = 100,
     modules = NULL) {

     ## Builds the input for tSNE/umap: module-by-cell counts restricted to
     ## the requested modules and to at most `maxCells` cells, normalized to
     ## proportions, square-root transformed and transposed.
     ## Returns list(norm = <matrix>, cellIx = <indices of the cells kept>).
     K <- params(celdaMod)$K

     ## Checking if maxCells and minClusterSize will work
     ## (`&&` because the condition of `if` must be a single value)
     if ((maxCells < ncol(counts)) && (maxCells / minClusterSize < K)) {
         stop("Cannot distribute ",
             maxCells,
             " cells among ",
             K,
             " clusters while maintaining a minumum of ",
             minClusterSize,
             " cells per cluster. Try increasing 'maxCells' or",
             " decreasing 'minClusterSize'.")
     }

     fm <- factorizeMatrix(counts = counts,
         celdaMod = celdaMod,
         type = "counts")
     nModules <- nrow(fm$counts$cell)
     modulesToUse <- seq_len(nModules)
     if (!is.null(modules)) {
         if (!all(modules %in% modulesToUse)) {
             ## Report the upper bound only; pasting the whole index vector
             ## produced a run-together number in the message
             stop("'modules' must be a vector of numbers between 1 and ",
                 nModules,
                 ".")
         }
         modulesToUse <- modules
     }

     ## Select a subset of cells to sample if greater than 'maxCells'
     totalCellsToRemove <- ncol(counts) - maxCells
     zInclude <- rep(TRUE, ncol(counts))

     if (totalCellsToRemove > 0) {
         z <- clusters(celdaMod)$z
         zTa <- tabulate(z, K)

         ## Number of cells that can be sampled from each cluster without
         ## going below the minimum threshold
         clusterCellsToSample <- zTa - minClusterSize
         clusterCellsToSample[clusterCellsToSample < 0] <- 0

         ## Number of cells to sample after excluding smaller clusters
         ## Rounding can cause number to be off by a few, so ceiling is used
         ## with a second round of subtraction
         clusterNToSample <- ceiling((clusterCellsToSample /
             sum(clusterCellsToSample)) * totalCellsToRemove)
         excess <- sum(clusterNToSample) - totalCellsToRemove
         largest <- which.max(clusterNToSample)
         clusterNToSample[largest] <- clusterNToSample[largest] - excess

         ## Perform sampling for each cluster
         for (i in which(clusterNToSample > 0)) {
             clusterCells <- which(z == i)
             ## Draw positions rather than calling sample() on the indices:
             ## sample(x, n) with a single number x draws from seq_len(x),
             ## which would pick cells outside this cluster
             removeIx <- clusterCells[sample.int(length(clusterCells),
                 clusterNToSample[i])]
             zInclude[removeIx] <- FALSE
         }
     }
     cellIx <- which(zInclude)

     ## drop = FALSE keeps a matrix when one module or one cell is selected
     norm <- t(normalizeCounts(
         fm$counts$cell[modulesToUse, cellIx, drop = FALSE],
         normalize = "proportion",
         transformationFun = sqrt))
     return(list(norm = norm, cellIx = cellIx))
 }
 
ef11fcec
 
2ced6125
 #' @title Probability map for a celda_CG model
5a6c5283
 #' @description Renders probability and relative expression heatmaps to
 #'  visualize the relationship between features and cell populations or cell
 #'  populations and samples.
 #' @param counts Integer matrix. Rows represent features and columns represent
 #'  cells. This matrix should be the same as the one used to generate
 #'  `celdaMod`.
 #' @param celdaMod Celda object of class `celda_CG`.
 #' @param level Character. One of 'cellPopulation' or 'sample'.
 #'  'cellPopulation' will display the absolute probabilities and relative
 #'  normalized expression of each module in each cell population. 'sample'
 #'  will display the absolute probabilities and relative normalized abundance
 #'  of each cell population in each sample. Default 'cellPopulation'.
3af3bb76
 #' @param ... Additional parameters.
c1a794b4
 #' @examples
a49fff03
 #' data(celdaCGSim, celdaCGMod)
5a6c5283
 #' celdaProbabilityMap(celdaCGSim$counts, celdaCGMod)
0bd3cd14
 #' @return A grob containing the specified plots
d7196f24
 #' @importFrom gridExtra grid.arrange
 #' @importFrom RColorBrewer brewer.pal
 #' @importFrom grDevices colorRampPalette
91de8c6d
 #' @seealso `celda_CG()` for clustering features and cells
bf9aff71
 #' @export
5a6c5283
 setMethod("celdaProbabilityMap", signature(celdaMod = "celda_CG"),
     function(counts, celdaMod, level = c("cellPopulation", "sample"), ...) {

         counts <- .processCounts(counts)
         compareCountMatrix(counts, celdaMod)

         level <- match.arg(level)
         factorized <- factorizeMatrix(celdaMod = celdaMod, counts = counts)

         if (level == "sample") {
             ## Left panel: cell population proportions within each sample
             sampleProps <- factorized$proportions$sample
             heatColors <- grDevices::colorRampPalette(c("white",
                 "blue",
                 "#08306B",
                 "#006D2C",
                 "yellowgreen",
                 "yellow",
                 "orange",
                 "red"))(100)
             colorBreaks <- seq(0, 1, length.out = length(heatColors))
             absPanel <- plotHeatmap(sampleProps,
                 colorScheme = "sequential",
                 scaleRow = NULL,
                 clusterCell = FALSE,
                 clusterFeature = FALSE,
                 showNamesCell = TRUE,
                 showNamesFeature = TRUE,
                 breaks = colorBreaks,
                 col = heatColors,
                 main = "Absolute Probability",
                 silent = TRUE)

             if (ncol(sampleProps) > 1) {
                 ## Right panel is only drawn when there are several samples
                 sampleScaled <- normalizeCounts(factorized$counts$sample,
                     normalize = "proportion",
                     transformationFun = sqrt,
                     scaleFun = base::scale)
                 relPanel <- plotHeatmap(sampleScaled,
                     colorScheme = "divergent",
                     clusterCell = FALSE,
                     clusterFeature = FALSE,
                     showNamesCell = TRUE,
                     showNamesFeature = TRUE,
                     main = "Relative Abundance",
                     silent = TRUE)
                 gridExtra::grid.arrange(absPanel$gtable,
                     relPanel$gtable,
                     ncol = 2)
             } else {
                 gridExtra::grid.arrange(absPanel$gtable)
             }
         } else {
             ## Only populations and modules that have members are shown
             zInclude <- which(tabulate(clusters(celdaMod)$z,
                 params(celdaMod)$K) > 0)
             yInclude <- which(tabulate(clusters(celdaMod)$y,
                 params(celdaMod)$L) > 0)

             popProps <- factorized$proportions$cellPopulation[yInclude,
                 zInclude, drop = FALSE]
             popScaled <- normalizeCounts(popProps,
                 normalize = "proportion",
                 transformationFun = sqrt,
                 scaleFun = base::scale)

             ## Palette is split at the 90th percentile of the proportions:
             ## that many blue shades first, the remaining colors after
             percentile9 <- round(stats::quantile(popProps, .9),
                 digits = 2) * 100
             lowColors <- grDevices::colorRampPalette(c("#FFFFFF",
                 RColorBrewer::brewer.pal(n = 9, name = "Blues")))(percentile9)
             highColors <- grDevices::colorRampPalette(c("#08306B",
                 c("#006D2C", "Yellowgreen", "Yellow", "Orange",
                     "Red")))(100 - percentile9)
             heatColors <- c(lowColors, highColors)
             colorBreaks <- seq(0, 1, length.out = length(heatColors))

             absPanel <- plotHeatmap(popProps,
                 colorScheme = "sequential",
                 scaleRow = NULL,
                 clusterCell = FALSE,
                 clusterFeature = FALSE,
                 showNamesCell = TRUE,
                 showNamesFeature = TRUE,
                 breaks = colorBreaks,
                 col = heatColors,
                 main = "Absolute Probability",
                 silent = TRUE)
             relPanel <- plotHeatmap(popScaled,
                 colorScheme = "divergent",
                 clusterCell = FALSE,
                 clusterFeature = FALSE,
                 showNamesCell = TRUE,
                 showNamesFeature = TRUE,
                 main = "Relative Expression",
                 silent = TRUE)
             gridExtra::grid.arrange(absPanel$gtable,
                 relPanel$gtable,
                 ncol = 2)
         }
     })
ef11fcec
 
 
91de8c6d
 #' @title Lookup the module of a feature
5a6c5283
 #' @description Finds the module assignments of given features in a
 #'  `celda_CG()` model.
 #' @param counts Integer matrix. Rows represent features and columns represent
 #'  cells. This matrix should be the same as the one used to generate
 #'  `celdaMod`.
 #' @param celdaMod Model of class `celda_CG`.
 #' @param feature Character vector. The module assignments will be found for
 #'  feature names in this vector.
 #' @param exactMatch Logical. Whether an exact match or a partial match using
 #'  `grep()` is used to look up the feature in the rownames of the counts
 #'  matrix. Default TRUE.
2ced6125
 #' @return List. Each element contains the module of the provided feature.
91de8c6d
 #' @seealso `celda_CG()` for clustering features and cells
27448271
 #' @examples
a49fff03
 #' data(celdaCGSim, celdaCGMod)
5a6c5283
 #' module <- featureModuleLookup(celdaCGSim$counts,
 #'     celdaCGMod,
 #'     c("Gene_1", "Gene_XXX"))
9e57de61
 #' @export
5a6c5283
 setMethod("featureModuleLookup", signature(celdaMod = "celda_CG"),
     function(counts, celdaMod, feature, exactMatch = TRUE) {
         ## Returns a list named by feature. Each element is the module label
         ## of that feature, or a "No feature was identified" message when
         ## the name is not among the row names of `counts`.
         featureNames <- rownames(counts)

         if (!isTRUE(exactMatch)) {
             ## Expand every pattern into the row names it matches via grep()
             feature <- unlist(lapply(feature, function(pattern) {
                 grep(pattern, featureNames, value = TRUE)
             }), use.names = FALSE)
         }

         y <- clusters(celdaMod)$y
         modules <- vector("list", length(feature))

         ## seq_along() yields no iterations for an empty `feature` (e.g. a
         ## partial match with no hits), so an empty list is returned instead
         ## of failing on a zero-length `if` condition
         for (i in seq_along(feature)) {
             ix <- match(feature[i], featureNames)
             if (is.na(ix)) {
                 modules[[i]] <- paste0("No feature was identified matching '",
                     feature[i],
                     "'.")
             } else {
                 modules[[i]] <- y[ix]
             }
         }
         names(modules) <- feature
         return(modules)
     })