R/celdaProbabilityMap.R
f3b143d0
 #' @title Probability map for a celda model
 #' @description Renders probability and relative expression heatmaps to
 #'  visualize the relationship between features and cell populations (or cell
 #'  populations and samples).
 #' @param sce A \link[SingleCellExperiment]{SingleCellExperiment} object
 #'  returned by \link{celda_C} or \link{celda_CG}.
82cc7a6a
 #' @param useAssay A string specifying which \link{assay}
f3b143d0
 #'  slot to use. Default "counts".
82cc7a6a
 #' @param altExpName The name for the \link{altExp} slot
898ab9e0
 #'  to use. Default "featureSubset".
f3b143d0
 #' @param level Character. One of "cellPopulation" or "sample".
 #'  "cellPopulation" will display the absolute probabilities and relative
 #'  normalized expression of each module in each cell population.
 #'  \strong{\code{level = "cellPopulation"} only works for celda_CG \code{sce}
 #'  objects}; for celda_C objects \code{level} is reset to "sample" with a
 #'  warning. "sample" will display the absolute probabilities and relative
 #'  normalized abundance of each cell population in each sample. Default
 #'  "cellPopulation".
d5d18fd3
 #' @param ncols The number of colors (>1) to be in the color palette of
16389352
 #'  the absolute probability heatmap.
 #' @param col2 Passed to \code{col} argument of \link[ComplexHeatmap]{Heatmap}.
 #'  Set color boundaries and colors for the relative expression heatmap.
 #' @param title1 Passed to \code{column_title} argument of
 #'  \link[ComplexHeatmap]{Heatmap}. Figure title for the absolute probability
 #'  heatmap.
 #' @param title2 Passed to \code{column_title} argument of
 #'  \link[ComplexHeatmap]{Heatmap}. Figure title for the relative expression
 #'  heatmap.
 #' @param showColumnNames Passed to \code{show_column_names} argument of
 #'  \link[ComplexHeatmap]{Heatmap}. Show column names.
 #' @param showRowNames Passed to \code{show_row_names} argument of
 #'  \link[ComplexHeatmap]{Heatmap}. Show row names.
 #' @param rowNamesgp Passed to \code{row_names_gp} argument of
 #'  \link[ComplexHeatmap]{Heatmap}. Set row name font.
 #' @param colNamesgp Passed to \code{column_names_gp} argument of
 #'  \link[ComplexHeatmap]{Heatmap}. Set column name font.
 #' @param clusterRows Passed to \code{cluster_rows} argument of
 #'  \link[ComplexHeatmap]{Heatmap}. Cluster rows.
 #' @param clusterColumns Passed to \code{cluster_columns} argument of
 #'  \link[ComplexHeatmap]{Heatmap}. Cluster columns.
 #' @param showHeatmapLegend Passed to \code{show_heatmap_legend} argument of
 #'  \link[ComplexHeatmap]{Heatmap}. Show heatmap legend.
 #' @param heatmapLegendParam Passed to \code{heatmap_legend_param} argument of
 #'  \link[ComplexHeatmap]{Heatmap}. Heatmap legend parameters.
 #' @param ... Additional parameters passed to \link[ComplexHeatmap]{Heatmap}.
f3b143d0
 #' @seealso \link{celda_C} for clustering cells. \link{celda_CG} for
 #'  clustering features and cells.
16389352
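 #' @return A \link[ComplexHeatmap]{HeatmapList} object containing 2
 #'  \link[ComplexHeatmap]{Heatmap-class} objects (absolute probability and
 #'  relative expression). When \code{level = "sample"} and there is only one
 #'  sample, a single \link[ComplexHeatmap]{Heatmap-class} object with the
 #'  absolute probabilities is returned instead.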
f3b143d0
 #' @export
 # S4 generic: dispatch happens on `sce`; all display options are declared
 # here with their defaults so methods share one user-facing signature.
 setGeneric(
     name = "celdaProbabilityMap",
     def = function(sce,
             # data selection
             useAssay = "counts", altExpName = "featureSubset",
             level = c("cellPopulation", "sample"),
             # colors and titles
             ncols = 100,
             col2 = circlize::colorRamp2(c(-2, 0, 2),
                 c("#1E90FF", "#FFFFFF", "#CD2626")),
             title1 = "Absolute probability",
             title2 = "Relative expression",
             # labels and clustering
             showColumnNames = TRUE, showRowNames = TRUE,
             rowNamesgp = grid::gpar(fontsize = 8),
             colNamesgp = grid::gpar(fontsize = 12),
             clusterRows = FALSE, clusterColumns = FALSE,
             # legend
             showHeatmapLegend = TRUE,
             heatmapLegendParam = list(title = NULL,
                 legend_height = grid::unit(6, "cm")),
             ...) {
         standardGeneric("celdaProbabilityMap")
     }
 )
 
 
 #' @rdname celdaProbabilityMap
 #' @importFrom RColorBrewer brewer.pal
 #' @importFrom grDevices colorRampPalette
f8826e1e
 #' @examples
 #' data(sceCeldaCG)
 #' celdaProbabilityMap(sceCeldaCG)
f3b143d0
 #' @export
 setMethod("celdaProbabilityMap", signature(sce = "SingleCellExperiment"),
     function(sce,
         useAssay = "counts",
         altExpName = "featureSubset",
         level = c("cellPopulation", "sample"),
         ncols = 100,
         col2 = circlize::colorRamp2(c(-2, 0, 2),
             c("#1E90FF", "#FFFFFF", "#CD2626")),
         title1 = "Absolute probability",
         title2 = "Relative expression",
         showColumnNames = TRUE,
         showRowNames = TRUE,
         rowNamesgp = grid::gpar(fontsize = 8),
         colNamesgp = grid::gpar(fontsize = 12),
         clusterRows = FALSE,
         clusterColumns = FALSE,
         showHeatmapLegend = TRUE,
         heatmapLegendParam = list(title = NULL,
             legend_height = grid::unit(6, "cm")),
         ...) {

         level <- match.arg(level)

         # Look the model type up once. isTRUE() lets a NULL / NA /
         # non-scalar result fall through to the informative stop() below
         # instead of failing inside if() with "argument is of length zero".
         model <- celdaModel(sce, altExpName = altExpName)

         if (isTRUE(model == "celda_C")) {
             # level = "cellPopulation" is documented as celda_CG-only, so
             # celda_C models always fall back to the sample-level map.
             if (level == "cellPopulation") {
                 warning("'level' has been set to 'sample'")
             }
             pm <- .celdaProbabilityMapC(sce = sce,
                 useAssay = useAssay,
                 altExpName = altExpName,
                 level = "sample",
                 ncols = ncols,
                 col2 = col2,
                 title1 = title1,
                 title2 = title2,
                 showColumnNames = showColumnNames,
                 showRowNames = showRowNames,
                 rowNamesgp = rowNamesgp,
                 colNamesgp = colNamesgp,
                 clusterRows = clusterRows,
                 clusterColumns = clusterColumns,
                 showHeatmapLegend = showHeatmapLegend,
                 heatmapLegendParam = heatmapLegendParam,
                 ...)
         } else if (isTRUE(model == "celda_CG")) {
             pm <- .celdaProbabilityMapCG(sce = sce,
                 useAssay = useAssay,
                 altExpName = altExpName,
                 level = level,
                 ncols = ncols,
                 col2 = col2,
                 title1 = title1,
                 title2 = title2,
                 showColumnNames = showColumnNames,
                 showRowNames = showRowNames,
                 rowNamesgp = rowNamesgp,
                 colNamesgp = colNamesgp,
                 clusterRows = clusterRows,
                 clusterColumns = clusterColumns,
                 showHeatmapLegend = showHeatmapLegend,
                 heatmapLegendParam = heatmapLegendParam,
                 ...)
         } else {
             stop("S4Vectors::metadata(altExp(sce,",
                 " altExpName))$celda_parameters$model must be",
                 " one of 'celda_C', or 'celda_CG'!")
         }
         pm
     }
 )
 
 
16389352
 # Sample-level probability map for a celda_C model.
 #
 # Draws the proportion of each non-empty cell cluster in each sample
 # ("absolute probability") and, when there is more than one sample, a second
 # heatmap of the sqrt-transformed, scaled proportions ("relative
 # expression"). `level` is accepted for signature parity with the CG helper
 # but is not used here. Extra arguments in `...` are forwarded to
 # ComplexHeatmap::Heatmap().
 #
 # Returns a HeatmapList (g1 + g2), or a single Heatmap when there is only
 # one sample.
 .celdaProbabilityMapC <- function(sce,
     useAssay,
     altExpName,
     level,
     ncols,
     col2,
     title1,
     title2,
     showColumnNames,
     showRowNames,
     rowNamesgp,
     colNamesgp,
     clusterRows,
     clusterColumns,
     showHeatmapLegend,
     heatmapLegendParam,
     ...) {

     # Draw one heatmap panel with the display options shared by every
     # panel; only the matrix, colors and title differ between panels.
     makeHeatmap <- function(mat, col, title, ...) {
         ComplexHeatmap::Heatmap(matrix = mat,
             col = col,
             column_title = title,
             show_column_names = showColumnNames,
             show_row_names = showRowNames,
             row_names_gp = rowNamesgp,
             column_names_gp = colNamesgp,
             cluster_rows = clusterRows,
             cluster_columns = clusterColumns,
             show_heatmap_legend = showHeatmapLegend,
             heatmap_legend_param = heatmapLegendParam,
             ...)
     }

     altExp <- SingleCellExperiment::altExp(sce, altExpName)

     # Indices of cell clusters (out of K) that contain at least one cell;
     # empty clusters are dropped from the heatmaps.
     zInclude <- which(tabulate(
         SummarizedExperiment::colData(altExp)$celda_cell_cluster,
         S4Vectors::metadata(altExp)$celda_parameters$K) > 0)

     # Pass altExpName through (as the CG helper does) so a non-default
     # alternative experiment is honored.
     factorized <- factorizeMatrix(x = sce, useAssay = useAssay,
         altExpName = altExpName,
         type = "proportion")

     samp <- factorized$proportions$sample[zInclude, , drop = FALSE]

     # Palette size follows the documented `ncols` argument (default 100).
     col1 <- grDevices::colorRampPalette(c("white",
         "blue",
         "midnightblue",
         "springgreen4",
         "yellowgreen",
         "yellow",
         "orange",
         "red"))(ncols)
     breaks <- seq(0, 1, length.out = length(col1))

     g1 <- makeHeatmap(samp, circlize::colorRamp2(breaks, col1), title1, ...)

     # The relative heatmap is only added when there is more than one sample.
     if (ncol(samp) > 1) {
         sampNorm <- normalizeCounts(samp,
             normalize = "proportion",
             transformationFun = sqrt,
             scaleFun = base::scale)
         g2 <- makeHeatmap(sampNorm, col2, title2, ...)
         return(g1 + g2)
     }
     g1
 }
 
 
16389352
 # Probability map for a celda_CG model.
 #
 # level == "cellPopulation": module-by-cluster proportions (absolute) and
 #   their sqrt-transformed, scaled version (relative), restricted to
 #   non-empty feature modules and cell clusters.
 # level == "sample": cluster-by-sample proportions (absolute) and, when
 #   there is more than one sample, a relative heatmap computed from the
 #   cluster-by-sample counts; restricted to non-empty cell clusters.
 # Extra arguments in `...` are forwarded to ComplexHeatmap::Heatmap().
 #
 # Returns a HeatmapList (g1 + g2), or a single Heatmap when level is
 # "sample" and there is only one sample.
 .celdaProbabilityMapCG <- function(sce,
     useAssay,
     altExpName,
     level,
     ncols,
     col2,
     title1,
     title2,
     showColumnNames,
     showRowNames,
     rowNamesgp,
     colNamesgp,
     clusterRows,
     clusterColumns,
     showHeatmapLegend,
     heatmapLegendParam,
     ...) {

     # Draw one heatmap panel with the display options shared by every
     # panel; only the matrix, colors and title differ between panels.
     makeHeatmap <- function(mat, col, title, ...) {
         ComplexHeatmap::Heatmap(matrix = mat,
             col = col,
             column_title = title,
             show_column_names = showColumnNames,
             show_row_names = showRowNames,
             row_names_gp = rowNamesgp,
             column_names_gp = colNamesgp,
             cluster_rows = clusterRows,
             cluster_columns = clusterColumns,
             show_heatmap_legend = showHeatmapLegend,
             heatmap_legend_param = heatmapLegendParam,
             ...)
     }

     altExp <- SingleCellExperiment::altExp(sce, altExpName)

     factorized <- factorizeMatrix(x = sce, useAssay = useAssay,
         altExpName = altExpName,
         type = c("counts", "proportion"))

     # Indices of non-empty cell clusters (out of K) and feature modules
     # (out of L); empty ones are dropped from every heatmap.
     zInclude <- which(tabulate(
         SummarizedExperiment::colData(altExp)$celda_cell_cluster,
         S4Vectors::metadata(altExp)$celda_parameters$K) > 0)
     yInclude <- which(tabulate(
         SummarizedExperiment::rowData(altExp)$celda_feature_module,
         S4Vectors::metadata(altExp)$celda_parameters$L) > 0)

     if (level == "cellPopulation") {
         pop <- factorized$proportions$cellPopulation[yInclude, zInclude,
             drop = FALSE]
         popNorm <- normalizeCounts(pop,
             normalize = "proportion",
             transformationFun = sqrt,
             scaleFun = base::scale)

         # Split the `ncols`-color palette at the 90th percentile of `pop`.
         # Breaks are evenly spaced on [0, 1], so values up to roughly that
         # percentile get the white-to-blue ramp and the upper tail gets the
         # midnightblue-to-red ramp. The split is an integer share of `ncols`
         # clamped to [0, ncols], so the two ramps always total `ncols`
         # colors (a non-integral or negative count would otherwise be
         # rounded up or rejected by colorRampPalette()).
         q90 <- unname(stats::quantile(pop, 0.9))
         nLow <- min(max(round(q90 * ncols), 0), ncols)
         cols11 <- grDevices::colorRampPalette(c("white",
             RColorBrewer::brewer.pal(n = 9, name = "Blues")))(nLow)
         cols12 <- grDevices::colorRampPalette(c("midnightblue",
             c("springgreen4", "Yellowgreen", "Yellow", "Orange",
                 "Red")))(ncols - nLow)
         col1 <- c(cols11, cols12)
         breaks <- seq(0, 1, length.out = length(col1))

         g1 <- makeHeatmap(pop, circlize::colorRamp2(breaks, col1), title1,
             ...)
         g2 <- makeHeatmap(popNorm, col2, title2, ...)
         return(g1 + g2)
     }

     # level == "sample": drop empty clusters, as the celda_C helper and the
     # cellPopulation branch above do.
     samp <- factorized$proportions$sample[zInclude, , drop = FALSE]
     col1 <- grDevices::colorRampPalette(c(
         "white",
         "blue",
         "#08306B",
         "#006D2C",
         "yellowgreen",
         "yellow",
         "orange",
         "red"
     ))(ncols)
     breaks <- seq(0, 1, length.out = length(col1))

     g1 <- makeHeatmap(samp, circlize::colorRamp2(breaks, col1), title1, ...)

     # The relative heatmap is only added when there is more than one sample.
     if (ncol(samp) > 1) {
         sampNorm <- normalizeCounts(
             factorized$counts$sample[zInclude, , drop = FALSE],
             normalize = "proportion",
             transformationFun = sqrt,
             scaleFun = base::scale)
         g2 <- makeHeatmap(sampNorm, col2, title2, ...)
         return(g1 + g2)
     }
     # Single sample: only the absolute-probability heatmap exists.
     g1
 }