R/plotSCEHeatmap.R
a9233cb6
 #' Plot heatmap using data stored in a SingleCellExperiment object
 #' @rdname plotSCEHeatmap
 #' @param inSCE \linkS4class{SingleCellExperiment} inherited object.
 #' @param useAssay character. A string indicating the assay name that
 #' provides the expression level to plot. Only for \code{plotSCEHeatmap}.
 #' Default \code{"logcounts"}.
 #' @param useReducedDim character. A string indicating the reducedDim name that
 #' provides the expression level to plot. Only for
 #' \code{plotSCEDimReduceHeatmap}. Default \code{NULL}.
 #' @param doLog Logical scalar. Whether to do \code{log(assay + 1)}
 #' transformation on the assay indicated by \code{useAssay}. Default
 #' \code{FALSE}.
 #' @param featureIndex A vector that can subset the input SCE object by rows
 #' (features). Alternatively, it can be a vector identifying features in
 #' another feature list indicated by \code{featureIndexBy}. Default \code{NULL}.
 #' @param cellIndex A vector that can subset the input SCE object by columns
 #' (cells). Alternatively, it can be a vector identifying cells in another
 #' cell list indicated by \code{cellIndexBy}. Default \code{NULL}.
 #' @param scale Whether to perform z-score scaling on each row. Default
 #' \code{TRUE}.
 #' @param trim A 2-element numeric vector. Values outside of this range will be
 #' trimmed to their nearest bound. Use \code{NULL} to skip trimming. Default
 #' \code{c(-2, 2)}
 #' @param featureIndexBy A single character specifying a column name of
 #' \code{rowData(inSCE)}, or a vector of the same length as \code{nrow(inSCE)},
 #' where we search for the non-rowname feature indices. Not applicable for
 #' \code{plotSCEDimReduceHeatmap}. Default \code{"rownames"}.
 #' @param cellIndexBy A single character specifying a column name of
 #' \code{colData(inSCE)}, or a vector of the same length as \code{ncol(inSCE)},
 #' where we search for the non-rowname cell indices. Default \code{"rownames"}.
 #' @param rowDataName character. The column name(s) in \code{rowData} that need
 #' to be added to the annotation. Not applicable for
 #' \code{plotSCEDimReduceHeatmap}. Default \code{NULL}.
 #' @param colDataName character. The column name(s) in \code{colData} that need
 #' to be added to the annotation. Default \code{NULL}.
 #' @param aggregateRow Feature variable for aggregating the heatmap by row. Can
 #' be a vector or a \code{rowData} column name for feature variable. Multiple
 #' variables are allowed. Default \code{NULL}.
 #' @param aggregateCol Cell variable for aggregating the heatmap by column. Can
 #' be a vector or a \code{colData} column name for cell variable. Multiple
 #' variables are allowed. Default \code{NULL}.
 #' @param featureAnnotations \code{data.frame}, with \code{rownames} containing
 #' all the features going to be plotted. Character columns should be factors.
 #' Default \code{NULL}.
 #' @param cellAnnotations \code{data.frame}, with \code{rownames} containing
 #' all the cells going to be plotted. Character columns should be factors.
 #' Default \code{NULL}.
 #' @param featureAnnotationColor A named list. Customized color settings for
 #' feature labeling. Should match the entries in the \code{featureAnnotations}
 #' or \code{rowDataName}. For each entry, there should be a list/vector of
 #' colors named with categories. Default \code{NULL}.
 #' @param cellAnnotationColor A named list. Customized color settings for
 #' cell labeling. Should match the entries in the \code{cellAnnotations} or
 #' \code{colDataName}. For each entry, there should be a list/vector of colors
 #' named with categories. Default \code{NULL}.
 #' @param palette Choose from \code{"ggplot"}, \code{"celda"} or \code{"random"}
 #' to generate unique category colors. Default \code{"ggplot"}.
 #' @param rowSplitBy character. Do semi-heatmap based on the grouping of
 #' this(these) annotation(s). Should exist in either \code{rowDataName} or
 #' \code{names(featureAnnotations)}. Default \code{NULL}.
 #' @param colSplitBy character. Do semi-heatmap based on the grouping of
 #' this(these) annotation(s). Should exist in either \code{colDataName} or
 #' \code{names(cellAnnotations)}. Default \code{NULL}.
 #' @param rowLabel Use a logical for whether to display all the feature names,
 #' a single character to display a column of \code{rowData(inSCE)} annotation,
 #' a vector of the same length as full/subset \code{nrow(inSCE)} to display
 #' customized info. Default \code{FALSE}.
 #' @param colLabel Use a logical for whether to display all the cell names, a
 #' single character to display a column of \code{colData(inSCE)} annotation,
 #' a vector of the same length as full/subset \code{ncol(inSCE)} to display
 #' customized info. Default \code{FALSE}.
 #' @param rowLabelSize A number for the font size of feature names. Default
 #' \code{6}
 #' @param colLabelSize A number for the font size of cell names. Default
 #' \code{6}
 #' @param rowDend Whether to display row dendrogram. Default \code{TRUE}.
 #' @param colDend Whether to display column dendrogram. Default \code{TRUE}.
 #' @param title The main title of the whole plot. Default \code{NULL}. Note
 #' that this argument is currently accepted but not applied to the returned
 #' heatmap object.
 #' @param rowTitle The subtitle for the rows. Default \code{"Features"}.
 #' @param colTitle The subtitle for the columns. Default \code{"Cells"}.
 #' @param rowGap A \code{\link[grid]{unit}} object. For the gap size between
 #' rows of the splitted heatmap. \code{NULL} falls back to the default.
 #' Default \code{grid::unit(0, 'mm')}.
 #' @param colGap A \code{\link[grid]{unit}} object. For the gap size between
 #' columns of the splitted heatmap. \code{NULL} falls back to the default.
 #' Default \code{grid::unit(0, 'mm')}.
 #' @param border A logical scalar. Whether to show the border of the heatmap or
 #' splitted heatmaps. Default \code{FALSE}.
 #' @param colorScheme function. A function that generates color code by giving
 #' a value. Can be generated by \code{\link[circlize]{colorRamp2}}. When
 #' \code{trim} is also given, the first and last breaks of the function have to
 #' match the range of \code{trim}. Default \code{NULL}.
 #' @param ... Other arguments passed to \code{\link[ComplexHeatmap]{Heatmap}}.
 #' @examples
 #' data(scExample, package = "singleCellTK")
 #' plotSCEHeatmap(sce[1:3,1:3], useAssay = "counts")
 #' @return A \code{\link[ComplexHeatmap]{Heatmap}} object.
 #' @export
 #' @author Yichen Wang
 #' @importFrom scuttle aggregateAcrossCells aggregateAcrossFeatures
 #' @importFrom SingleCellExperiment SingleCellExperiment
 #' @importFrom SummarizedExperiment colData assayNames<-
 plotSCEHeatmap <- function(inSCE, useAssay = 'logcounts', useReducedDim = NULL,
                            doLog = FALSE, featureIndex = NULL, cellIndex = NULL,
                            scale = TRUE, trim = c(-2, 2),
                            featureIndexBy = 'rownames',
                            cellIndexBy = 'rownames',
                            rowDataName = NULL, colDataName = NULL,
                            aggregateRow = NULL, aggregateCol = NULL,
                            featureAnnotations = NULL, cellAnnotations = NULL,
                            featureAnnotationColor = NULL,
                            cellAnnotationColor = NULL,
                            palette = c("ggplot", "celda", "random"),
                            rowSplitBy = NULL, colSplitBy = NULL,
                            rowLabel = FALSE, colLabel = FALSE,
                            rowLabelSize = 6, colLabelSize = 6,
                            rowDend = TRUE, colDend = TRUE,
                            title = NULL, rowTitle = 'Features',
                            colTitle = 'Cells',
                            rowGap = grid::unit(0, 'mm'),
                            colGap = grid::unit(0, 'mm'),
                            border = FALSE, colorScheme = NULL, ...){
     palette <- match.arg(palette)
     # STAGE 1: Create clean SCE object with only needed information ####
     ## .selectSCEMatrix, .manageCellVar and .manageFeatureVar perform checks
     useMat <- .selectSCEMatrix(inSCE, useAssay = useAssay,
                                useReducedDim = useReducedDim,
                                returnMatrix = TRUE, cellAsCol = TRUE)
     useAssay <- useMat$names$useAssay
     useReducedDim <- useMat$names$useReducedDim
     # Name of the data actually plotted; reused as assay name and legend name
     useData <- if (!is.null(useAssay)) useAssay else useReducedDim
     ### cell annotation
     # Aggregation variables must also be present as annotation columns
     colDataName <- unique(c(colDataName, aggregateCol))
     colDataAnns <- lapply(colDataName, function(x) .manageCellVar(inSCE, x))
     if (length(colDataName) > 0)
         colDataAnns <- data.frame(colDataAnns, row.names = colnames(inSCE))
     else
         colDataAnns <- data.frame(row.names = colnames(inSCE))
     colnames(colDataAnns) <- colDataName
     cellAnnotations <- .mergeAnnotationDF(colDataAnns, cellAnnotations)
     if (!is.null(colSplitBy) &&
         any(!colSplitBy %in% colnames(cellAnnotations)))
         stop('Specified `colSplitBy` variables not found.')
     # After this branch `colLabel` is a logical flag and `colLabelName`
     # holds the labels to show (or NULL)
     if (isTRUE(colLabel)) {
         colLabelName <- colnames(inSCE)
     } else if (isFALSE(colLabel)) {
         colLabelName <- NULL
     } else {
         colLabelName <- .manageCellVar(inSCE, colLabel)
         colLabel <- TRUE
     }
     ### feature annotation
     rowDataAnns <- data.frame(row.names = rownames(useMat$mat))
     if (!is.null(useAssay)) {
         # When using reducedDim, no rowData can be applied
         rowDataName <- unique(c(rowDataName, aggregateRow))
         rowDataAnns <- lapply(rowDataName,
                               function(x) .manageFeatureVar(inSCE, x))
         if (length(rowDataName) > 0)
             rowDataAnns <- data.frame(rowDataAnns, row.names = rownames(inSCE))
         else
             rowDataAnns <- data.frame(row.names = rownames(inSCE))
         colnames(rowDataAnns) <- rowDataName
     }
     # But customized featureAnnotations should work
     featureAnnotations <- .mergeAnnotationDF(rowDataAnns, featureAnnotations)
     if (!is.null(rowSplitBy) &&
         any(!rowSplitBy %in% colnames(featureAnnotations)))
         stop('Specified `rowSplitBy` variables not found.')
     # Same normalization as for `colLabel` above
     if (isTRUE(rowLabel)) {
         rowLabelName <- rownames(useMat$mat)
     } else if (isFALSE(rowLabel)) {
         rowLabelName <- NULL
     } else {
         if (!is.null(useAssay)) {
             rowLabelName <- .manageFeatureVar(inSCE, rowLabel)
             rowLabel <- TRUE
         } else {
             # Using customized rowLabel for reducedDim
             if (length(rowLabel) != nrow(useMat$mat))
                 stop("Length of `rowLabel` does not match nrow of specified ",
                      "`useReducedDim`")
             rowLabelName <- rowLabel
             rowLabel <- TRUE
         }
     }
     ### create SCE object
     SCE <- SingleCellExperiment(assay = list(useMat$mat),
                                 colData = cellAnnotations,
                                 rowData = featureAnnotations)
     assayNames(SCE) <- useData
     # STAGE 2: Subset SCE object as needed ####
     # Manage cell subsetting
     if (is.null(cellIndex)) {
         # seq_len() stays empty for zero columns, unlike seq()
         cellIndex <- seq_len(ncol(SCE))
     } else if (is.character(cellIndex)) {
         # cellIndexBy not necessarily included in new "SCE"
         cellIndex <- retrieveSCEIndex(inSCE, cellIndex, axis = "col",
                                       by = cellIndexBy)
     } else if (is.logical(cellIndex)) {
         if (length(cellIndex) != ncol(inSCE)) {
             stop("Logical index length does not match ncol(inSCE)")
         }
         cellIndex <- which(cellIndex)
     }
     # Manage feature subsetting
     if (is.null(featureIndex)) {
         featureIndex <- seq_len(nrow(SCE))
     } else if (is.character(featureIndex)) {
         if (!is.null(useAssay))
             featureIndex <- retrieveSCEIndex(inSCE, featureIndex, axis = "row",
                                              by = featureIndexBy)
         else
             # When using reducedDim, can only go with "PC" names
             # or customized "by"
             featureIndex <- retrieveSCEIndex(SCE, featureIndex, axis = "row",
                                              by = featureIndexBy)
     } else if (is.logical(featureIndex)) {
         if (length(featureIndex) != nrow(SCE)) {
             stop("Logical index length does not match nrow(inSCE)")
         }
         featureIndex <- which(featureIndex)
     }
     # Labels are applied as dimnames before subsetting so that they are
     # subset together with the data; indices are positional at this point
     colnames(SCE) <- colLabelName
     rownames(SCE) <- rowLabelName
     SCE <- SCE[featureIndex, cellIndex]
     ### Scaling should be done before aggregating
     if (isTRUE(doLog)) assay(SCE) <- log1p(assay(SCE))
     if (isTRUE(scale)) assay(SCE) <- as.matrix(computeZScore(assay(SCE)))
     if (!is.null(trim)) assay(SCE) <- trimCounts(assay(SCE), trim)
     # STAGE 3: Aggregate As needed ####
     if (!is.null(aggregateCol)) {
         # TODO: whether to also aggregate numeric variable that users want
         # Might need to use "coldata.merge" in aggregate function
         colIDS <- colData(SCE)[, aggregateCol]
         origRowData <- rowData(SCE)
         SCE <- aggregateAcrossCells(SCE, ids = colIDS,
                                     use.assay.type = useData,
                                     store.number = NULL, statistics = "mean")
         # TODO: `aggregateAcrossCells` produce duplicated variables in colData
         # and unwanted "ncell" variable even if I set `store.number = NULL`.
         colData(SCE) <- colData(SCE)[,aggregateCol,drop=FALSE]
         newColnames <- do.call(paste, c(colData(SCE), list(sep = "_")))
         colnames(SCE) <- newColnames
         rowData(SCE) <- origRowData
     }
     if (!is.null(aggregateRow)) {
         # `aggregateAcrossFeatures` doesn't work by with multi-var
         # Remake one single variable vector
         rowIDS <- rowData(SCE)[, aggregateRow, drop = FALSE]
         rowIDS <- do.call(paste, c(rowIDS, list(sep = "_")))
         origColData <- colData(SCE)
         SCE <- aggregateAcrossFeatures(SCE, ids = rowIDS, average = TRUE,
                                        use.assay.type = useData)
         colData(SCE) <- origColData
     }
     # STAGE 4: Other minor preparation for plotting ####
     mat <- assay(SCE)
     if (is.null(colorScheme)) {
         if (!is.null(trim))
             colorScheme <- circlize::colorRamp2(c(trim[1], 0, trim[2]),
                                                 c('blue', 'white', 'red'))
         else
             colorScheme <- circlize::colorRamp2(c(min(mat),
                                                   (max(mat) + min(mat))/2,
                                                   max(mat)),
                                                 c('blue', 'white', 'red'))
     } else {
         if (!is.function(colorScheme))
             stop('`colorScheme` must be a function generated by ',
                  'circlize::colorRamp2')
         breaks <- attr(colorScheme, 'breaks')
         # A function without the "breaks" attribute cannot be validated;
         # report it clearly rather than failing on a zero-length condition
         if (is.null(breaks))
             stop('`colorScheme` must be a function generated by ',
                  'circlize::colorRamp2')
         # Only compare against `trim` when trimming is requested. With
         # `trim = NULL`, min()/max() would give Inf/-Inf and always mismatch.
         if (!is.null(trim) &&
             (breaks[1] != min(trim) || breaks[length(breaks)] != max(trim)))
             stop('Breaks of `colorScheme` do not match with `trim`.')
     }
     ### Generate HeatmapAnnotation object
     ca <- NULL
     cellAnnotationColor <- .heatmapAnnColor(SCE, slot = "colData",
                                             custom = cellAnnotationColor,
                                             palette = palette)
     if (ncol(colData(SCE)) > 0)
         ca <- ComplexHeatmap::HeatmapAnnotation(df = colData(SCE),
                                                 col = cellAnnotationColor)
     ra <- NULL
     featureAnnotationColor <- .heatmapAnnColor(SCE, slot = "rowData",
                                                custom = featureAnnotationColor,
                                                palette = palette)
     if (ncol(rowData(SCE)) > 0)
         ra <- ComplexHeatmap::rowAnnotation(df = rowData(SCE),
                                             col = featureAnnotationColor)
     ### Set split variable
     cs <- NULL
     if (!is.null(colSplitBy)) cs <- colData(SCE)[colSplitBy]
     rs <- NULL
     if (!is.null(rowSplitBy)) rs <- rowData(SCE)[rowSplitBy]
     ### Validate gap sizes; NULL falls back to no gap
     if (!is.null(colGap)) {
         if (!inherits(colGap, "unit"))
             stop("`colGap` has to be 'unit' object. Try `grid::unit(", colGap,
                  ", 'mm')`.")
     }
     else colGap <- grid::unit(0, 'mm')
     if (!is.null(rowGap)) {
         if (!inherits(rowGap, "unit"))
             stop("`rowGap` has to be 'unit' object. Try `grid::unit(", rowGap,
                  ", 'mm')`.")
     }
     else rowGap <- grid::unit(0, 'mm')

     # The legend is named after the plotted assay/reducedDim
     hm <- ComplexHeatmap::Heatmap(mat, name = useData, left_annotation = ra,
                                   top_annotation = ca, col = colorScheme,
                                   row_split = rs, column_split = cs,
                                   row_title = rowTitle, column_title = colTitle,
                                   show_row_names = rowLabel,
                                   row_names_gp = grid::gpar(fontsize = rowLabelSize),
                                   show_row_dend = rowDend,
                                   show_column_names = colLabel,
                                   column_names_gp = grid::gpar(fontsize = colLabelSize),
                                   show_column_dend = colDend,
                                   row_gap = rowGap, column_gap = colGap,
                                   border = border,
                                   ...)
     return(hm)
 }
caf80c57
 
ef608c1e
 # Append user-supplied annotation columns to an annotation data.frame.
 # `origin`   data.frame whose rownames define the row order of the result.
 # `external` data.frame (or NULL) whose rows are matched to `origin` by
 #            rowname before being column-bound; NULL returns `origin` as is.
 # Rows of `origin` that have no match in `external` receive NA values.
 .mergeAnnotationDF <- function(origin, external) {
     # Nothing to merge
     if (is.null(external)) return(origin)
     # Reorder the external rows so that they line up with `origin`
     rowOrder <- match(rownames(origin), rownames(external))
     aligned <- external[rowOrder, , drop = FALSE]
     merged <- cbind(origin, aligned)
     return(merged)
 }
 
 # Complete the annotation color list for a heatmap.
 # `inSCE`   object whose colData/rowData columns are the annotations to color.
 # `slot`    which annotation table to use: "colData" or "rowData".
 # `custom`  NULL or a named list of user-defined colors; entries named after
 #           an annotation column are kept untouched.
 # `palette` palette name forwarded to `discreteColorPalette()`.
 # Returns `custom` extended with one named color vector for every annotation
 # column that had no user-defined entry.
 .heatmapAnnColor <- function(inSCE, slot = c("colData", "rowData"),
                           custom = NULL, palette = "ggplot") {
     # NOTE: the default of `palette` used to be `palette = palette`, a
     # self-referencing promise that errors as soon as the argument is omitted.
     slot <- match.arg(slot)
     if (!is.null(custom) && !is.list(custom))
         stop("'cellAnnotationColor' or 'featureAnnotationColor' must be a list.")
     if (is.null(custom)) custom <- list()
     data <- switch(slot,
                    colData = SummarizedExperiment::colData(inSCE),
                    rowData = SummarizedExperiment::rowData(inSCE))
     # Only generate colors for variables the user did not customize
     todoNames <- colnames(data)
     todoNames <- todoNames[!todoNames %in% names(custom)]
     newColor <- lapply(todoNames, function(n) {
         var <- data[[n]]
         # Factors keep their level order; other types use order of appearance
         if (is.factor(var)) categories <- levels(var)
         else categories <- unique(var)
         colors <- discreteColorPalette(length(categories), palette = palette)
         names(colors) <- categories
         return(colors)
     })
     names(newColor) <- todoNames
     custom <- c(custom, newColor)
     return(custom)
 }
 # Test
 #logcounts(sceBatches) <- log1p(counts(sceBatches))
 #plotSCEHeatmap2(sceBatches, "logcounts",
 #                featureIndex = c("GCG1", "COX11", "INS1", "ND41"),
 #                featureIndexBy = rowData(sceBatches)$feature_name,
 #                cellIndex = c("reads.16087_", "Sample_1073_",
 #                              "reads.29330_", "Sample_801_"),
 #                cellIndexBy = paste0(colnames(sceBatches), "_"),
 #                rowLabel = "feature_name", rowDend = FALSE,
 #                cluster_rows = FALSE, colLabel = TRUE, cluster_columns = FALSE,
 #                colDataName = c("batch", "cell_type"), aggregateCol = c("cell_type", "batch"))
 #sce <-plotSCEHeatmap2(sceBatches, aggregateCol = "batch")
 #plotSCEHeatmap2(sceBatches, aggregateCol = c("cell_type", "batch"))
 #plotFindMarkerHeatmap(sce, log2fcThreshold = 0, minClustExprPerc = 0.4,
 #                      maxCtrlExprPerc = 0.5)
 #plotFindMarkerHeatmap(sce, log2fcThreshold = 0, minClustExprPerc = 0.4,
 #                      maxCtrlExprPerc = 0.5,
 #                      aggregateRow = "marker")
 #plotSCEDimReduceColData(sce, "cluster", "UMAP")
 CellVarColor <- function(inSCE, var,
                             palette = c("ggplot", "random", "celda"),
                             seed = 12345, ...) {
     var <- .manageCellVar(inSCE, var = var)
     palette <- match.arg(palette)
     if (is.factor(var)) uniqVar <- levels(var)
     else uniqVar <- unique(var)
     colors <- discreteColorPalette(length(uniqVar), palette = palette, seed = seed, ...)
     names(colors) <- uniqVar
     return(colors)
caf80c57
 }