R/moduleHeatmap.R
25d7d3d5
 #' @title Heatmap for featureModules
 #' @description Renders a heatmap for selected \code{featureModule}. Cells are
 #'  ordered from those with the lowest probability of the module on the left to
 #'  the highest probability on the right. Features are ordered from those
 #'  with the highest probability in the module on the top to the lowest
 #'  probability on the bottom. One heatmap is drawn per module and all of
 #'  them are arranged in a single multi-panel figure. Use of
 #'  \link[multipanelfigure]{save_multi_panel_figure} is recommended for
 #'  outputting figures in various formats.
 #' @param x A numeric \link{matrix} of counts or a
 #'  \linkS4class{SingleCellExperiment}
 #'  with the matrix located in the assay slot under \code{useAssay}.
 #'  Rows represent features and columns represent cells.
 #' @param useAssay A string specifying which \link{assay}
 #'  slot to use if \code{x} is a
 #'  \linkS4class{SingleCellExperiment} object. Default "counts".
 #' @param altExpName The name for the \link{altExp} slot
 #'  to use. Default "featureSubset".
 #' @param featureModule Integer vector. The featureModule(s) to display.
 #'  Multiple modules can be included in a vector. Default \code{NULL} which
 #'  plots all module heatmaps.
 #' @param col Passed to \link[ComplexHeatmap]{Heatmap}. Sets the color
 #'  boundaries and colors used for the expression values. Default
 #'  \code{circlize::colorRamp2(c(-2, 0, 2), c("#1E90FF", "#FFFFFF",
 #'  "#CD2626"))}.
 #' @param topCells Integer. Number of cells with the highest and lowest
 #'  probabilities for each module to include in the heatmap. For example, if
 #'  \code{topCells = 50}, the 50 cells with the lowest probabilities and
 #'  the 50 cells with the highest probabilities for each featureModule will be
 #'  included. If NULL, or if \code{2 * topCells} is not smaller than the total
 #'  number of cells, all cells will be plotted. Default 100.
 #' @param topFeatures Integer. Plot \code{topFeatures} features with the
 #'  highest probabilities in the module heatmap for each featureModule. If
 #'  \code{NULL}, plot all features in the module. Default \code{NULL}.
 #' @param normalizedCounts Numeric matrix. Rows represent features and columns
 #'  represent cells. If you have a normalized matrix result from
 #'  \link{normalizeCounts}, you can pass through the result here to
 #'  skip the normalization step in this function. Make sure the colnames and
 #'  rownames match the object in x. This matrix should
 #'  correspond to one generated from this count matrix
 #'  \code{assay(altExp(x, altExpName), i = useAssay)}. If \code{NA},
 #'  normalization will be carried out in the following form
 #'  \code{normalizeCounts(assay(altExp(x, altExpName), i = useAssay),
 #'  normalize = "proportion", transformationFun = sqrt)}.
 #'  Use of this parameter is particularly useful for plotting many
 #'  module heatmaps, where normalizing the counts matrix repeatedly would
 #'  be too time consuming. Default NA.
 #' @param normalize Character. Passed to \link{normalizeCounts} if
 #'  \code{normalizedCounts} is \code{NA}.
 #'  Divides counts by the library sizes for each cell. One of 'proportion',
 #'  'cpm', 'median', or 'mean'. 'proportion' uses the total counts for each
 #'  cell as the library size. 'cpm' divides the library size of each cell by
 #'  one million to produce counts per million. 'median' divides the library
 #'  size of each cell by the median library size across all cells. 'mean'
 #'  divides the library size of each cell by the mean library size across all
 #'  cells. Default "proportion".
 #' @param transformationFun Function. Passed to \link{normalizeCounts} if
 #'  \code{normalizedCounts} is \code{NA}. Applies a transformation such as
 #'  \link{sqrt}, \link{log}, \link{log2}, \link{log10}, or \link{log1p}.
 #'  If NULL, no transformation will be applied. Occurs after normalization.
 #'  Default \link{sqrt}.
 #' @param scaleRow Function. Which function to use to scale each individual
 #'  row. Set to NULL to disable. Occurs after normalization and
 #'  transformation. For example, \link{scale} will Z-score transform each row.
 #'  Default \link{scale}.
 #' @param showFeaturenames Logical. Whether feature names should be displayed.
 #'  Default TRUE.
 #' @param trim Numeric vector. Vector of length two that specifies the lower
 #'  and upper bounds for plotting the data. This threshold is applied
 #'  after row scaling. Set to NULL to disable. Default c(-2,2).
 #' @param rowFontSize Integer. Font size for feature names. Default 6.
 #' @param showHeatmapLegend Passed to \link[ComplexHeatmap]{Heatmap}. Show
 #'  legend for expression levels. Default FALSE.
 #' @param showTopAnnotationLegend Passed to
 #'  \link[ComplexHeatmap]{HeatmapAnnotation}. Show legend for cell annotation.
 #'  Default FALSE.
 #' @param showTopAnnotationName Passed to
 #'  \link[ComplexHeatmap]{HeatmapAnnotation}. Show heatmap top annotation name.
 #'  Default FALSE.
 #' @param topAnnotationHeight Passed to
 #'  \link[ComplexHeatmap]{HeatmapAnnotation}. Height of the cell (column)
 #'  annotation, in units of \code{unit}. Default 1.5.
 #' @param showLeftAnnotation Logical. Show the feature module annotation to
 #'  the left of each heatmap. Default \code{FALSE}.
 #' @param showLeftAnnotationLegend Passed to
 #'  \link[ComplexHeatmap]{rowAnnotation}. Show legend for feature module
 #'  annotation. Default FALSE.
 #' @param showLeftAnnotationName Passed to
 #'  \link[ComplexHeatmap]{rowAnnotation}. Show heatmap left annotation name.
 #'  Default FALSE.
 #' @param leftAnnotationWidth Passed to
 #'  \link[ComplexHeatmap]{rowAnnotation}. Width of the feature module (row)
 #'  annotation, in units of \code{unit}. Default 1.5.
 #' @param width Passed to \link[multipanelfigure]{multi_panel_figure}. The
 #'  width of the output figure. Default "auto".
 #' @param height Passed to \link[multipanelfigure]{multi_panel_figure}. The
 #'  height of the output figure. Default "auto".
 #' @param unit Passed to \link[multipanelfigure]{multi_panel_figure}. Single
 #'  character object defining the unit of all dimensions defined, including
 #'  \code{topAnnotationHeight} and \code{leftAnnotationWidth}. Default "mm".
 #' @param ModuleLabel Panel labels. "auto" (default) labels each panel with
 #'  its module number. Set to \code{""} to disable labels. \code{NULL} calls
 #'  \link[multipanelfigure]{fill_panel} without a label, leaving labelling to
 #'  its default. Otherwise, must be a vector with one label per feature
 #'  module, i.e. of the same length as
 #'  \code{length(unique(celdaModules(x, altExpName = altExpName)))}.
 #' @param labelJust Passed to \link[multipanelfigure]{fill_panel}.
 #'  Justification for the label within the interpanel spacing grob to the
 #'  top-left of the panel content grob. Default \code{c("right", "bottom")}.
 #' @param ... Additional parameters passed to \link[ComplexHeatmap]{Heatmap}.
 #' @return A \link[multipanelfigure]{multi_panel_figure} object.
 #' @importFrom methods .hasSlot
 #' @importFrom multipanelfigure multi_panel_figure
 #' @export
 setGeneric("moduleHeatmap", function(x, ...) {
     standardGeneric("moduleHeatmap")
 })
1ac452ae
 
 
 #' @rdname moduleHeatmap
 #' @examples
 #' data(sceCeldaCG)
 #' moduleHeatmap(sceCeldaCG, width = 250, height = 250)
 #' @export
 setMethod("moduleHeatmap",
     signature(x = "SingleCellExperiment"),
     function(x,
         useAssay = "counts",
         altExpName = "featureSubset",
         featureModule = NULL,
         col = circlize::colorRamp2(c(-2, 0, 2),
             c("#1E90FF", "#FFFFFF", "#CD2626")),
         topCells = 100,
         topFeatures = NULL,
         normalizedCounts = NA,
         normalize = "proportion",
         transformationFun = sqrt,
         scaleRow = scale,
         showFeaturenames = TRUE,
         trim = c(-2, 2),
         rowFontSize = 6,
         showHeatmapLegend = FALSE,
         showTopAnnotationLegend = FALSE,
         showTopAnnotationName = FALSE,
         topAnnotationHeight = 1.5,
         showLeftAnnotation = FALSE,
         showLeftAnnotationLegend = FALSE,
         showLeftAnnotationName = FALSE,
         leftAnnotationWidth = 1.5,
         width = "auto",
         height = "auto",
         unit = "mm",
         ModuleLabel = "auto",
         labelJust = c("right", "bottom"),
         ...) {

         altExp <- SingleCellExperiment::altExp(x, altExpName)
         counts <- SummarizedExperiment::assay(altExp, i = useAssay)

         # Row and column names are required: filtered rows/columns are
         # matched back to module and cluster labels by name when plotting.
         if (is.null(colnames(counts))) {
             stop("colnames(altExp(x, altExpName)) is NULL!",
                 " Please assign column names to x and",
                 " try again.")
         }

         if (is.null(rownames(counts))) {
             stop("rownames(altExp(x, altExpName)) is NULL!",
                 " Please assign row names to x and",
                 " try again.")
         }

         # Only celda_G and celda_CG results carry feature modules. isTRUE()
         # also rejects a missing (NULL) model, which would otherwise make
         # the `if` condition zero-length and fail with an obscure error.
         model <- S4Vectors::metadata(altExp)$celda_parameters$model
         if (!isTRUE(model %in% c("celda_G", "celda_CG"))) {
             stop("metadata(altExp(x, altExpName))$",
                 "celda_parameters$model must be 'celda_G' or",
                 " 'celda_CG'")
         }

         # Module label of every feature and cluster label of every cell,
         # both read from the requested altExp.
         y <- celdaModules(x, altExpName = altExpName)
         z <- celdaClusters(x, altExpName = altExpName)
         nModules <- length(unique(y))

         if (is.null(featureModule)) {
             featureModule <- sort(unique(y))
         }
         if (length(featureModule) == 0L) {
             stop("No feature modules to plot.")
         }

         # Resolve one panel label per plotted module. identical() keeps each
         # `if` condition scalar when a vector of labels is supplied.
         if (!is.null(ModuleLabel)) {
             if (identical(ModuleLabel, "auto")) {
                 ModuleLabel <- as.character(featureModule)
             } else if (identical(ModuleLabel, "")) {
                 ModuleLabel <- rep("", length.out = length(featureModule))
             } else if (length(ModuleLabel) != nModules) {
                 stop("Invalid 'ModuleLabel' length!")
             } else {
                 # Labels are supplied per module, so keep those of the
                 # modules actually plotted, in plotting order.
                 ModuleLabel <- ModuleLabel[featureModule]
             }
         }

         # factorize counts matrix
         factorizedMatrix <- factorizeMatrix(x,
             useAssay = useAssay,
             altExpName = altExpName,
             type = "proportion")
         allCellStates <- factorizedMatrix$proportions$cell
         moduleProbs <- factorizedMatrix$proportions$module

         # NA (the default) means "normalize here"; anything else is used as
         # the precomputed normalized matrix. The length check keeps the
         # condition scalar when a matrix is supplied.
         if (length(normalizedCounts) == 1L &&
                 isTRUE(is.na(normalizedCounts))) {
             normCounts <- normalizeCounts(counts,
                 normalize = normalize,
                 transformationFun = transformationFun)
         } else {
             normCounts <- normalizedCounts
         }

         # Rank features within each module; without a numeric
         # `topFeatures`, every feature is kept.
         nTop <- if (is.numeric(topFeatures)) {
             topFeatures
         } else {
             nrow(moduleProbs)
         }
         topRanked <- topRank(matrix = moduleProbs, n = nTop)

         # Feature indices of each requested module, in topRank() order.
         featureIndices <- lapply(featureModule,
             function(module) topRanked$index[[module]])

         # Build one heatmap per module and capture it as a grob so it can
         # be placed into a multipanelfigure panel.
         plts <- vector("list", length = length(featureModule))
         for (i in seq_along(featureModule)) {
             hm <- .plotModuleHeatmap(normCounts = normCounts,
                 col = col,
                 allCellStates = allCellStates,
                 featureIndices = featureIndices[[i]],
                 featureModule = featureModule[i],
                 z = z,
                 y = y,
                 topCells = topCells,
                 altExpName = altExpName,
                 scaleRow = scaleRow,
                 showFeaturenames = showFeaturenames,
                 trim = trim,
                 rowFontSize = rowFontSize,
                 showHeatmapLegend = showHeatmapLegend,
                 showTopAnnotationLegend = showTopAnnotationLegend,
                 showTopAnnotationName = showTopAnnotationName,
                 topAnnotationHeight = topAnnotationHeight,
                 showLeftAnnotation = showLeftAnnotation,
                 showLeftAnnotationLegend = showLeftAnnotationLegend,
                 showLeftAnnotationName = showLeftAnnotationName,
                 leftAnnotationWidth = leftAnnotationWidth,
                 unit = unit,
                 ...)
             plts[[i]] <- grid::grid.grabExpr(ComplexHeatmap::draw(hm),
                 wrap.grobs = TRUE)
         }

         # Near-square layout: floor(sqrt(n)) columns, enough rows for all.
         ncol <- floor(sqrt(length(plts)))
         nrow <- ceiling(length(plts) / ncol)

         figure <- multipanelfigure::multi_panel_figure(columns = ncol,
             rows = nrow,
             width = width,
             height = height,
             unit = unit)

         for (i in seq_along(plts)) {
             if (is.null(ModuleLabel)) {
                 # No label given: fill_panel() applies its own default.
                 figure <- suppressMessages(multipanelfigure::fill_panel(figure,
                     plts[[i]], label_just = labelJust))
             } else {
                 figure <- suppressMessages(multipanelfigure::fill_panel(figure,
                     plts[[i]], label = ModuleLabel[i], label_just = labelJust))
             }
         }
         figure
     }
 )
 
 
1d563e1f
 # Build the (undrawn) ComplexHeatmap::Heatmap object for one feature module.
 #
 # Cells come from `allCellStates[featureModule, ]`. They are optionally
 # limited to the `topCells` highest- and lowest-probability cells and ordered
 # from the lowest probability (left) to the highest (right). Rows are the
 # features in `featureIndices`, minus those with no positive value in the
 # selected cells. Values are optionally row-scaled with `scaleRow` and then
 # clipped to `trim`. `altExpName` is accepted for interface compatibility
 # but is not used here. Extra arguments in `...` go to Heatmap().
 .plotModuleHeatmap <- function(normCounts,
     col,
     allCellStates,
     featureIndices,
     featureModule,
     z,
     y,
     topCells,
     altExpName,
     scaleRow,
     showFeaturenames,
     trim,
     rowFontSize,
     showHeatmapLegend,
     showTopAnnotationLegend,
     showTopAnnotationName,
     topAnnotationHeight,
     showLeftAnnotation,
     showLeftAnnotationLegend,
     showLeftAnnotationName,
     leftAnnotationWidth,
     unit,
     ...) {

     # Order cells by this module's probability, highest first.
     cellStates <- allCellStates[featureModule, , drop = TRUE]
     singleModuleOrdered <- order(cellStates, decreasing = TRUE)

     # Keep the `topCells` highest and lowest cells only when that is a
     # proper subset; otherwise (or when topCells is NULL) keep every cell.
     if (!is.null(topCells) && topCells * 2 < ncol(allCellStates)) {
         cellIndices <- c(
             utils::head(singleModuleOrdered, n = topCells),
             utils::tail(singleModuleOrdered, n = topCells))
     } else {
         cellIndices <- singleModuleOrdered
     }

     # Reverse so the lowest-probability cells end up on the left.
     cellIndices <- rev(cellIndices)

     filteredNormCounts <-
         normCounts[featureIndices, cellIndices, drop = FALSE]

     # Drop features with no positive value in any of the selected cells.
     filteredNormCounts <-
         filteredNormCounts[rowSums(filteredNormCounts > 0) > 0, ,
             drop = FALSE]

     # Positions of the surviving rows/columns in `normCounts`, used to look
     # up their module (y) and cluster (z) labels.
     geneIx <- match(rownames(filteredNormCounts), rownames(normCounts))
     cellIx <- match(colnames(filteredNormCounts), colnames(normCounts))

     zToPlot <- z[cellIx]
     uniquezToPlot <- sort(unique(zToPlot))
     ccols <- distinctColors(length(unique(z)))[uniquezToPlot]
     names(ccols) <- uniquezToPlot

     # Scale individual rows with scaleRow (e.g. `scale` gives z-scores).
     if (!is.null(scaleRow)) {
         if (!is.function(scaleRow)) {
             stop("'scaleRow' needs to be of class 'function'")
         }
         cn <- colnames(filteredNormCounts)
         filteredNormCounts <- t(base::apply(filteredNormCounts,
             1, scaleRow))
         colnames(filteredNormCounts) <- cn
     }

     # Clip values to the [lower, upper] range given by `trim`.
     if (!is.null(trim)) {
         if (length(trim) != 2) {
             stop(
                 "'trim' should be a 2 element vector specifying the lower",
                 " and upper boundaries"
             )
         }
         trim <- sort(trim)
         filteredNormCounts[filteredNormCounts < trim[1]] <- trim[1]
         filteredNormCounts[filteredNormCounts > trim[2]] <- trim[2]
     }

     topAnnotation <- ComplexHeatmap::HeatmapAnnotation(
         cell = factor(zToPlot,
             levels = stringr::str_sort(unique(zToPlot), numeric = TRUE)),
         show_legend = showTopAnnotationLegend,
         show_annotation_name = showTopAnnotationName,
         col = list(cell = ccols),
         simple_anno_size = grid::unit(topAnnotationHeight, unit))

     # Shared Heatmap() call. The optional left annotation and the caller's
     # extra arguments are forwarded through `...`, so `left_annotation` is
     # only passed when it is actually shown.
     buildHeatmap <- function(...) {
         ComplexHeatmap::Heatmap(matrix = filteredNormCounts,
             col = col,
             show_column_names = FALSE,
             show_row_names = showFeaturenames,
             row_names_gp = grid::gpar(fontsize = rowFontSize),
             cluster_rows = FALSE,
             cluster_columns = FALSE,
             heatmap_legend_param = list(title = "Expression"),
             show_heatmap_legend = showHeatmapLegend,
             top_annotation = topAnnotation,
             ...)
     }

     if (!isTRUE(showLeftAnnotation)) {
         return(buildHeatmap(...))
     }

     yToPlot <- y[geneIx]
     uniqueyToPlot <- sort(unique(yToPlot))
     # One colour per module (not per feature), mirroring the cluster palette.
     rcols <- distinctColors(length(unique(y)))[uniqueyToPlot]
     names(rcols) <- uniqueyToPlot

     leftAnnotation <- ComplexHeatmap::rowAnnotation(
         module = factor(yToPlot,
             levels = stringr::str_sort(unique(yToPlot), numeric = TRUE)),
         show_legend = showLeftAnnotationLegend,
         show_annotation_name = showLeftAnnotationName,
         col = list(module = rcols),
         simple_anno_size = grid::unit(leftAnnotationWidth, unit))

     buildHeatmap(left_annotation = leftAnnotation, ...)
 }