R/model_performance.R
4d9072bf
 #' @title Calculate and visualize perplexity of all models in a celdaList, with
 #'  count resampling
 #' @description Computes the perplexity of every model stored in a celdaList on
 #'  resampled versions of the provided count matrix. Repeating the resampling
 #'  yields a distribution of perplexities per model, which gives a better sense
 #'  of the quality of a given K/L choice than a single value.
 #' @param counts Integer matrix. Rows represent features and columns represent
 #'  cells. This matrix should be the same as the one used to generate
 #'  `celdaList`.
 #' @param celdaList Object of class 'celdaList'.
 #' @param resample Integer. The number of times to resample the counts matrix
 #'  for evaluating perplexity. Default 5.
 #' @param seed Integer. Passed to \link[withr]{with_seed}. For reproducibility,
 #'  a default value of 12345 is used. If NULL, no calls to
 #'  \link[withr]{with_seed} are made.
 #' @return celdaList. Returns the provided `celdaList` with a `perplexity`
 #'  property, detailing the perplexity of all K/L combinations that appeared in
 #'  the celdaList's models.
 #' @examples
 #' data(celdaCGSim, celdaCGGridSearchRes)
 #' celdaCGGridSearchRes <- resamplePerplexity(
 #'   celdaCGSim$counts,
 #'   celdaCGGridSearchRes)
 #' plotGridSearchPerplexity(celdaCGGridSearchRes)
 #' @export
 resamplePerplexity <- function(counts,
     celdaList,
     resample = 5,
     seed = 12345) {

     ## Without a seed the worker runs against the current RNG state
     if (is.null(seed)) {
         res <- .resamplePerplexity(counts = counts,
             celdaList = celdaList,
             resample = resample)
         return(res)
     }

     ## with_seed() hands back the value of the evaluated expression, so the
     ## result is captured from its return value
     res <- with_seed(seed,
         .resamplePerplexity(counts = counts,
             celdaList = celdaList,
             resample = resample))
     return(res)
 }
 
27fcd64f
 
f6a2aa74
 # Worker for resamplePerplexity (no seed handling).
 # For each of `resample` rounds, draws one resampled count matrix with
 # .resampleCountMatrix() and evaluates perplexity() for every model in
 # resList(celdaList) against it.
 # @param counts Count matrix handed to perplexity() and .resampleCountMatrix().
 # @param celdaList Object of class 'celdaList'.
 # @param resample Single number >= 1. Number of resampling rounds.
 # @return The celdaList with its `perplexity` slot set to a models x rounds
 # matrix and the row means stored in `runParams$mean_perplexity`.
 .resamplePerplexity <- function(counts,
     celdaList,
     resample = 5) {

     if (!methods::is(celdaList, "celdaList")) {
         stop("celdaList parameter was not of class celdaList.")
     }
     if (!isTRUE(is.numeric(resample))) {
         stop("Provided resample parameter was not numeric.")
     }
     ## Reject values that previously failed later with an obscure subscript
     ## or matrix-extent error (vectors, NA, anything below one round)
     if (length(resample) != 1L || is.na(resample) || resample < 1) {
         stop("Provided resample parameter must be a single number >= 1.")
     }

     ## Fetch the model list once instead of on every inner iteration
     models <- resList(celdaList)
     nModels <- length(models)

     perpRes <- matrix(NA, nrow = nModels, ncol = resample)
     ## seq_len() yields an empty sequence for zero-length input, whereas
     ## seq(0) would iterate over c(1, 0)
     for (j in seq_len(ncol(perpRes))) {
         ## One resampled matrix per round, shared by all models
         newCounts <- .resampleCountMatrix(counts)
         for (i in seq_len(nModels)) {
             perpRes[i, j] <- perplexity(counts, models[[i]], newCounts)
         }
     }
     celdaList@perplexity <- perpRes

     ## Add mean perplexity to runParams
     perpMean <- apply(perpRes, 1, mean)
     celdaList@runParams$mean_perplexity <- perpMean

     return(celdaList)
 }
 
 
4d9072bf
 #' @title Visualize perplexity of a list of celda models
 #' @description Visualize perplexity of every model in a celdaList, by unique
 #'  K/L combinations. The plotting function is selected from the class of the
 #'  first model in the celdaList.
 #' @param celdaList Object of class 'celdaList'.
 #' @param sep Numeric. Breaks in the x axis of the resulting plot.
 #' @return A ggplot plot object showing perplexity as a function of clustering
 #'  parameters.
 #' @examples
 #' data(celdaCGSim, celdaCGGridSearchRes)
 #' ## Run various combinations of parameters with 'celdaGridSearch'
 #' celdaCGGridSearchRes <- resamplePerplexity(
 #'   celdaCGSim$counts,
 #'   celdaCGGridSearchRes)
 #' plotGridSearchPerplexity(celdaCGGridSearchRes)
 #' @export
 plotGridSearchPerplexity <- function(celdaList, sep = 1) {
     ## The class of the first model picks the model-specific plotting function
     firstModel <- resList(celdaList)[[1]]
     plotFunName <- paste0("plotGridSearchPerplexity",
         as.character(class(firstModel)))
     do.call(plotFunName, args = list(celdaList, sep))
 }
 
23cd7e9f
 
4d9072bf
 #' @title Plot perplexity as a function of K and L from celda_CG models
 #' @description This function plots perplexity as a function of the cell/gene
 #'  (K/L) clusters as generated by celdaGridSearch(). K is placed on the x axis
 #'  and L is used for color when more than one K is present; otherwise L is
 #'  placed on the x axis and K is used for color.
 #' @param celdaList Object of class 'celdaList'.
 #' @param sep Numeric. Breaks in the x axis of the resulting plot. Default 1.
 #' @return A ggplot plot object showing perplexity as a function of clustering
 #'  parameters.
 #' @examples
 #' data(celdaCGSim, celdaCGGridSearchRes)
 #' celdaCGGridSearchRes <- resamplePerplexity(
 #'   celdaCGSim$counts,
 #'   celdaCGGridSearchRes
 #' )
 #' plotGridSearchPerplexity(celdaCGGridSearchRes)
 #' @export
 plotGridSearchPerplexitycelda_CG <- function(celdaList, sep = 1) {
     params <- runParams(celdaList)
     if (!all(c("K", "L") %in% colnames(params))) {
         stop("runParams(celdaList) needs K and L columns.")
     }
     perp <- celdaPerplexity(celdaList)
     ## length() is 0 for NULL as well as for an empty matrix, so both cases
     ## are reported (an is.null() test lets an empty matrix through)
     if (length(perp) == 0) {
         stop("No perplexity measurements available. First run",
             " 'resamplePerplexity' with celdaList object.")
     }

     ## Long format: one row per (perplexity row, perplexity column) pair, with
     ## row i of the perplexity matrix paired with row i of runParams
     ix1 <- rep(seq_len(nrow(perp)), each = ncol(perp))
     ix2 <- rep(seq_len(ncol(perp)), nrow(perp))
     ## drop = FALSE keeps a data.frame (and its column names) in all cases
     df <- data.frame(params[ix1, , drop = FALSE],
         perplexity = perp[cbind(ix1, ix2)])
     df$K <- as.factor(df$K)
     df$L <- as.factor(df$L)

     lMeansByK <- stats::aggregate(df$perplexity, by = list(df$K, df$L),
         FUN = mean)
     colnames(lMeansByK) <- c("K", "L", "mean_perplexity")
     lMeansByK$K <- as.factor(lMeansByK$K)
     lMeansByK$L <- as.factor(lMeansByK$L)

     ## Put K on the x axis when it varies, otherwise fall back to L; the
     ## other parameter is used for color and for grouping the mean lines
     if (nlevels(df$K) > 1) {
         xVar <- "K"
         colorVar <- "L"
     } else {
         xVar <- "L"
         colorVar <- "K"
     }

     perplexityPlot <- ggplot2::ggplot(df,
         ggplot2::aes_string(x = xVar, y = "perplexity")) +
         ggplot2::geom_jitter(
             mapping = ggplot2::aes_string(color = colorVar),
             height = 0, width = 0.1) +
         ggplot2::scale_color_discrete(name = colorVar) +
         ggplot2::geom_path(data = lMeansByK,
             ggplot2::aes_string(x = xVar, y = "mean_perplexity",
                 group = colorVar, color = colorVar)) +
         ggplot2::ylab("Perplexity") +
         ggplot2::xlab(xVar) +
         ggplot2::scale_x_discrete(breaks = seq(min(params[[xVar]]),
             max(params[[xVar]]), sep)) +
         ggplot2::theme_bw()

     return(perplexityPlot)
 }
 
 
4d9072bf
 #' @title Plot perplexity as a function of K from celda_C models
 #' @description Plots perplexity as a function of the cell (K) clusters as
 #'   generated by celdaGridSearch(). Individual perplexity values are shown as
 #'   jittered points and the mean per K is drawn as a line.
 #' @param celdaList Object of class 'celdaList'.
 #' @param sep Numeric. Breaks in the x axis of the resulting plot. Default 1.
 #' @return A ggplot plot object showing perplexity as a function of clustering
 #'   parameters.
 #' @examples
 #' data(celdaCGSim, celdaCGGridSearchRes)
 #' celdaCGGridSearchRes <- resamplePerplexity(
 #'   celdaCGSim$counts,
 #'   celdaCGGridSearchRes
 #' )
 #' plotGridSearchPerplexity(celdaCGGridSearchRes)
 #' @export
 plotGridSearchPerplexitycelda_C <- function(celdaList, sep = 1) {
     params <- runParams(celdaList)
     if (!("K" %in% colnames(params))) {
         stop("runParams(celdaList) needs the column K.")
     }
     perp <- celdaPerplexity(celdaList)
     ## length() is 0 for NULL as well as for an empty matrix, so both cases
     ## are reported (an is.null() test lets an empty matrix through)
     if (length(perp) == 0) {
         stop("No perplexity measurements available. First run",
             " 'resamplePerplexity' with celdaList object.")
     }

     ## Long format: one row per (perplexity row, perplexity column) pair, with
     ## row i of the perplexity matrix paired with row i of runParams
     ix1 <- rep(seq_len(nrow(perp)), each = ncol(perp))
     ix2 <- rep(seq_len(ncol(perp)), nrow(perp))
     ## drop = FALSE keeps the K column name even if K is the only column
     df <- data.frame(params[ix1, , drop = FALSE],
         perplexity = perp[cbind(ix1, ix2)])
     df$K <- as.factor(df$K)

     meansByK <- stats::aggregate(df$perplexity, by = list(df$K), FUN = mean)
     colnames(meansByK) <- c("K", "mean_perplexity")
     meansByK$K <- as.factor(meansByK$K)

     perplexityPlot <-
         ggplot2::ggplot(df, ggplot2::aes_string(x = "K", y = "perplexity")) +
         ggplot2::geom_jitter(height = 0, width = 0.1) +
         ggplot2::geom_path(data = meansByK,
             ggplot2::aes_string(x = "K", y = "mean_perplexity", group = 1)) +
         ggplot2::ylab("Perplexity") +
         ggplot2::xlab("K") +
         ggplot2::scale_x_discrete(breaks = seq(min(params$K),
             max(params$K), sep)) +
         ggplot2::theme_bw()

     return(perplexityPlot)
 }
 
 
4d9072bf
 #' @title Plot perplexity as a function of L from a celda_G model
 #' @description Plots perplexity as a function of the gene (L) clusters as
 #'   generated by celdaGridSearch(). Individual perplexity values are shown as
 #'   jittered points and the mean per L is drawn as a line.
 #' @param celdaList Object of class 'celdaList'.
 #' @param sep Numeric. Breaks in the x axis of the resulting plot. Default 1.
 #' @return A ggplot plot object showing perplexity as a function of clustering
 #'   parameters.
 #' @examples
 #' data(celdaCGSim, celdaCGGridSearchRes)
 #' celdaCGGridSearchRes <- resamplePerplexity(
 #'   celdaCGSim$counts,
 #'   celdaCGGridSearchRes)
 #' plotGridSearchPerplexity(celdaCGGridSearchRes)
 #' @export
 plotGridSearchPerplexitycelda_G <- function(celdaList, sep = 1) {
     params <- runParams(celdaList)
     if (!("L" %in% colnames(params))) {
         stop("runParams(celdaList) needs the column L.")
     }
     perp <- celdaPerplexity(celdaList)
     ## length() is 0 for NULL as well as for an empty matrix
     if (length(perp) == 0) {
         stop("No perplexity measurements available. First run",
             " 'resamplePerplexity' with celdaList object.")
     }

     ## Long format: one row per (perplexity row, perplexity column) pair, with
     ## row i of the perplexity matrix paired with row i of runParams
     ix1 <- rep(seq_len(nrow(perp)), each = ncol(perp))
     ix2 <- rep(seq_len(ncol(perp)), nrow(perp))
     ## drop = FALSE keeps the L column name even if L is the only column
     df <- data.frame(params[ix1, , drop = FALSE],
         perplexity = perp[cbind(ix1, ix2)])
     df$L <- as.factor(df$L)

     meansByL <- stats::aggregate(df$perplexity, by = list(df$L), FUN = mean)
     colnames(meansByL) <- c("L", "mean_perplexity")
     meansByL$L <- as.factor(meansByL$L)

     perplexityPlot <-
         ggplot2::ggplot(df, ggplot2::aes_string(x = "L", y = "perplexity")) +
         ggplot2::geom_jitter(height = 0, width = 0.1) +
         ggplot2::geom_path(data = meansByL,
             ggplot2::aes_string(x = "L", y = "mean_perplexity", group = 1)) +
         ggplot2::ylab("Perplexity") +
         ggplot2::xlab("L") +
         ggplot2::scale_x_discrete(breaks = seq(min(params$L),
             max(params$L), sep)) +
         ggplot2::theme_bw()

     return(perplexityPlot)
 }
 
 
 # Resample a counts matrix for evaluating perplexity
 # Normalizes each column (cell) of a countMatrix by the column sum to
 # create a distribution of observing a given number of counts for a given
 # gene in that cell, then draws a multinomial sample of the same total size
 # for every column.
 # This is primarily used to evaluate the stability of the perplexity for
 # a given K/L combination.
 # Columns whose counts sum to zero are returned as all-zero columns.
 # @param countMatrix Count matrix; each column is resampled independently.
 # @return An integer matrix with the same dimensions as `countMatrix` whose
 # column sums equal those of `countMatrix`.
 .resampleCountMatrix <- function(countMatrix) {
     nFeatures <- nrow(countMatrix)
     nCells <- ncol(countMatrix)
     colsums <- colSums(countMatrix)
     prob <- t(t(countMatrix) / colsums)
     ## seq_len() gives an empty sequence for a matrix without columns,
     ## whereas seq(0) would iterate over c(1, 0)
     resample <- vapply(seq_len(nCells), function(idx) {
         ## A column without counts has NaN probabilities, which rmultinom()
         ## rejects; nothing can be drawn for it, so it stays all zero
         if (colsums[idx] == 0) {
             return(integer(nFeatures))
         }
         stats::rmultinom(n = 1,
             size = colsums[idx],
             prob = prob[, idx])
     }, integer(nFeatures))
     ## vapply() returns a plain vector when there is a single feature;
     ## always hand back a matrix of the input's shape
     resample <- matrix(resample, nrow = nFeatures, ncol = nCells)
     return(resample)
 }