#' @title Generate marker decision tree from single-cell clustering output
#' @description Create a decision tree that identifies gene markers for given
#' cell populations. The algorithm uses a decision tree procedure to generate
#' a set of rules for each cell cluster defined by single-cell clustering.
#' Splits are determined by one of two metrics at each split: a one-off metric
#' to determine rules for identifying clusters by a single feature, and a
#' balanced metric to determine rules for identifying sets of similar clusters.
#' @param x A numeric \link{matrix} of counts or a
#' \linkS4class{SingleCellExperiment}
#' with the matrix located in the assay slot under \code{useAssay}.
#' Rows represent features and columns represent cells.
#' @param useAssay A string specifying which \link{assay}
#' slot to use if \code{x} is a
#' \link[SingleCellExperiment]{SingleCellExperiment} object. Default "counts".
#' @param altExpName The name for the \link{altExp} slot
#' to use. Default "featureSubset".
#' @param class Vector of cell cluster labels.
#' @param oneoffMetric A character string. What one-off metric to run, either
#' `modified F1` or `pairwise AUC`. Default is 'modified F1'.
#' @param metaclusters List where each element is a metacluster (e.g. known
#' cell type) and all the clusters within that metacluster (e.g. subtypes).
#' @param featureLabels Vector of feature assignments, e.g. which cluster
#' does each gene belong to? Useful when using clusters of features
#' (e.g. gene modules or Seurat PCs) and user wishes to expand tree results
#' to individual features (e.g. score individual genes within marker gene
#' modules).
#' @param counts Numeric counts matrix. Useful when using clusters
#' of features (e.g. gene modules) and user wishes to expand tree results to
#' individual features (e.g. score individual genes within marker gene
#' modules). Row names should be individual feature names. Ignored if
#' \code{x} is a \linkS4class{SingleCellExperiment} object.
#' @param celda A \emph{celda_CG} or \emph{celda_C} object.
#' Counts matrix has to be provided as well.
#' @param seurat A Seurat object. Note that the Seurat functions
#' \emph{RunPCA} and \emph{FindClusters} must have been run on the object.
#' @param threshold Numeric between 0 and 1. The threshold for the one-off
#' metric. Smaller values will result in more one-off splits. Default is 0.90.
#' @param reuseFeatures Logical. Whether or not a feature can be used more than
#' once on the same cluster. Default is FALSE.
#' @param altSplit Logical. Whether or not to force a marker for clusters that
#' are solely defined by the absence of markers. Default is TRUE.
#' @param consecutiveOneoff Logical. Whether or not to allow one-off splits at
#' consecutive branches. Default is FALSE.
#' @param autoMetaclusters Logical. Whether to identify metaclusters prior to
#' creating the tree based on the distance between clusters in a UMAP
#' dimensionality reduction projection. A metacluster is simply a large
#' cluster that includes several clusters within it. Default is TRUE.
#' @param seed Numeric. Seed used to enable reproducible UMAP results
#' for identifying metaclusters. Default is 12345.
#' @param ... Ignored. Placeholder to prevent check warning.
#' @return A named list with six elements:
#' \itemize{
#'   \item rules - A named list with one data frame for every label. Each
#'   data frame gives the set of rules for distinguishing each label and
#'   includes the following columns:
#'   \itemize{
#'     \item feature - Marker feature, e.g. marker gene name.
#'     \item direction - Relationship to feature value. -1 if cluster is
#'     down-regulated for this feature, 1 if cluster is up-regulated.
#'     \item stat - The performance value returned by the splitting metric for
#'     this split.
#'     \item statUsed - Which performance metric was used. "Split" if
#'     information gain and "One-off" if one-off.
#'     \item level - The level of the tree at which the rule was defined. 1 is
#'     the level of the first split of the tree.
#'     \item metacluster - Optional. If metaclusters were used, the metacluster
#'     this rule is applied to.
#'   }
#'   \item dendro - A dendrogram object of the decision tree output. Plot with
#'   plotMarkerDendro().
#'   \item classLabels - A vector of the class labels used in the model, i.e.
#'   cell cluster labels.
#'   \item metaclusterLabels - A vector of the metacluster labels
#'   used in the model.
#'   \item prediction - A character vector of label predictions for the
#'   training data using the final model. "MISSING" if label prediction was
#'   ambiguous.
#'   \item performance - A named list denoting the training performance of the
#'   model:
#'   \itemize{
#'     \item accuracy - (number correct/number of samples) for the whole set of
#'     samples.
#'     \item balAcc - mean sensitivity across all clusters
#'     \item meanPrecision - mean precision across all clusters
#'     \item correct - the number of correct predictions of each cluster
#'     \item sizes - the actual number of samples in each cluster
#'     \item sensitivity - the sensitivity of the prediction of each cluster
#'     \item precision - the precision of the prediction of each cluster
#'   }
#' }
#' @examples
#' \dontrun{
#' # Generate simulated single-cell dataset using celda
#' sim_counts <- simulateCells("celda_CG", K = 4, L = 10, G = 100)
#'
#' # Celda clustering into 5 clusters & 10 modules
#' cm <- celda_CG(sim_counts, K = 5, L = 10, verbose = FALSE)
#'
#' # Get features matrix and cluster assignments
#' factorized <- factorizeMatrix(cm)
#' features <- factorized$proportions$cell
#' class <- celdaClusters(cm)
#'
#' # Generate Decision Tree
#' DecTree <- findMarkersTree(features, class)
#'
#' # Plot dendrogram
#' plotMarkerDendro(DecTree)
#' }
#' @export
setGeneric("findMarkersTree", function(x, ...)
{
    standardGeneric("findMarkersTree")
})


#' @rdname findMarkersTree
#' @export
setMethod("findMarkersTree", signature(x = "SingleCellExperiment"),
    function(x,
        useAssay = "counts",
        altExpName = "featureSubset",
        class,
        oneoffMetric = c("modified F1", "pairwise AUC"),
        metaclusters,
        featureLabels,
        counts,
        seurat,
        threshold = 0.90,
        reuseFeatures = FALSE,
        altSplit = TRUE,
        consecutiveOneoff = FALSE,
        autoMetaclusters = TRUE,
        seed = 12345) {

        altExp <- SingleCellExperiment::altExp(x, altExpName)

        if ("celda_parameters" %in% names(S4Vectors::metadata(altExp))) {
            counts <- SummarizedExperiment::assay(altExp, i = useAssay)

            # factorize matrix (proportion of each module in each cell)
            features <- factorizeMatrix(x,
                useAssay = useAssay,
                altExpName = altExpName)$proportions$cell

            # get class labels
            class <- celdaClusters(x, altExpName = altExpName)

            # get feature labels
            featureLabels <- paste0("L",
                celdaModules(x, altExpName = altExpName))
        } else if (methods::hasArg(seurat)) {
            # get counts matrix from seurat object
            counts <- as.matrix(seurat@assays$RNA@data)

            # get class labels
            class <- as.character(Seurat::Idents(seurat))

            # get feature labels
            featureLabels <- unlist(apply(
                seurat@reductions$pca@feature.loadings, 1,
                function(x) {
                    return(names(x)[which(x == max(x))])
                }
            ))

            # sum counts for each PC in each cell
            features <- matrix(
                unlist(lapply(unique(featureLabels), function(pc) {
                    colSums(counts[featureLabels == pc, ])
                })),
                ncol = length(class),
                byrow = TRUE,
                dimnames = list(unique(featureLabels), colnames(counts))
            )

            # normalize column-wise (i.e. convert counts to proportions)
            features <- apply(features, 2, function(x) {
                x / sum(x)
            })
        }

        if (ncol(features) != length(class)) {
            stop("Number of columns of features must equal length of class")
        }

        if (any(is.na(class))) {
            stop("NA class values")
        }

        if (any(is.na(features))) {
            stop("NA feature values")
        }

        # Match the oneoffMetric argument
        oneoffMetric <- match.arg(oneoffMetric)

        branchPoints <- .findMarkersTree(features = features,
            class = class,
            oneoffMetric = oneoffMetric,
            metaclusters = metaclusters,
            featureLabels = featureLabels,
            counts = counts,
            seurat = seurat,
            threshold = threshold,
            reuseFeatures = reuseFeatures,
            altSplit = altSplit,
            consecutiveOneoff = consecutiveOneoff,
            autoMetaclusters = autoMetaclusters,
            seed = seed)

        return(branchPoints)
    }
)


#' @rdname findMarkersTree
#' @export
setMethod("findMarkersTree", signature(x = "matrix"),
    function(x,
        class,
        oneoffMetric = c("modified F1", "pairwise AUC"),
        metaclusters,
        featureLabels,
        counts,
        celda,
        seurat,
        threshold = 0.90,
        reuseFeatures = FALSE,
        altSplit = TRUE,
        consecutiveOneoff = FALSE,
        autoMetaclusters = TRUE,
        seed = 12345) {

        features <- x

        if (methods::hasArg(celda)) {
            # check that counts matrix is provided
            if (!methods::hasArg(counts)) {
                stop("Please provide counts matrix in addition to",
                    " celda object.")
            }

            # factorize matrix (proportion of each module in each cell)
            features <- factorizeMatrix(counts, celda)$proportions$cell

            # get class labels
            class <- celdaClusters(celda)$z

            # get feature labels
            featureLabels <- paste0("L", celdaClusters(celda)$y)
        } else if (methods::hasArg(seurat)) {
            # get counts matrix from seurat object
            counts <- as.matrix(seurat@assays$RNA@data)

            # get class labels
            class <- as.character(Seurat::Idents(seurat))

            # get feature labels
            featureLabels <- unlist(apply(
                seurat@reductions$pca@feature.loadings, 1,
                function(x) {
                    return(names(x)[which(x == max(x))])
                }
            ))

            # sum counts for each PC in each cell
            features <- matrix(
                unlist(lapply(unique(featureLabels), function(pc) {
                    colSums(counts[featureLabels == pc, ])
                })),
                ncol = length(class),
                byrow = TRUE,
                dimnames = list(unique(featureLabels), colnames(counts))
            )

            # normalize column-wise (i.e. convert counts to proportions)
            features <- apply(features, 2, function(x) {
                x / sum(x)
            })
        }

        if (ncol(features) != length(class)) {
            stop("Number of columns of features must equal length of class")
        }

        if (any(is.na(class))) {
            stop("NA class values")
        }

        if (any(is.na(features))) {
            stop("NA feature values")
        }

        # Match the oneoffMetric argument
        oneoffMetric <- match.arg(oneoffMetric)

        branchPoints <- .findMarkersTree(features = features,
            class = class,
            oneoffMetric = oneoffMetric,
            metaclusters = metaclusters,
            featureLabels = featureLabels,
            counts = counts,
            seurat = seurat,
            threshold = threshold,
            reuseFeatures = reuseFeatures,
            altSplit = altSplit,
            consecutiveOneoff = consecutiveOneoff,
            autoMetaclusters = autoMetaclusters,
            seed = seed)

        return(branchPoints)
    }
)


.findMarkersTree <- function(features,
    class,
    oneoffMetric,
    metaclusters,
    featureLabels,
    counts,
    seurat,
    threshold,
    reuseFeatures,
    altSplit,
    consecutiveOneoff,
    autoMetaclusters,
    seed) {

    # Transpose features
    features <- t(features)

    # If no detailed cell types are provided or to be identified
    if (!methods::hasArg(metaclusters) & (!autoMetaclusters)) {
        message("Building tree...")

        # Set class to factor
        class <- as.factor(class)

        # Generate list of tree levels
        tree <- .generateTreeList(
            features,
            class,
            oneoffMetric,
            threshold,
            reuseFeatures,
            consecutiveOneoff
        )

        # Add alternative node for the solely down-regulated leaf
        if (altSplit) {
            tree <- .addAlternativeSplit(tree, features, class)
        }

        message("Computing performance metrics...")

        # Format tree output for plotting and generate summary statistics
        DTsummary <- .summarizeTree(tree, features, class)

        # Remove confusing 'value' column
        DTsummary$rules <- lapply(DTsummary$rules, function(x) {
            x["value"] <- NULL
            x
        })

        # Add column to each rules table which specifies its class
        DTsummary$rules <- mapply(cbind,
            "class" =
                as.character(names(DTsummary$rules)),
            DTsummary$rules,
            SIMPLIFY = FALSE
        )

        # Generate table for each branch point in the tree
        DTsummary$branchPoints <- .createBranchPoints(DTsummary$rules)

        # Add class labels to output
        DTsummary$classLabels <- class

        return(DTsummary)
    } else {
        # If metaclusters are provided or to be identified

        # consecutive one-offs break the code
        # (tricky to find 1st balanced split)
        if (consecutiveOneoff) {
            stop(
                "Cannot use metaclusters if consecutive one-offs are allowed.",
                " Please set the consecutiveOneoff parameter to FALSE."
            )
        }

        # Check if need to identify metaclusters
        if (autoMetaclusters & !methods::hasArg(metaclusters)) {
            message("Identifying metaclusters...")

            # if seurat object then use seurat's UMAP parameters
            if (methods::hasArg(seurat)) {
                suppressMessages(seurat <- Seurat::RunUMAP(
                    seurat,
                    dims = seq(ncol(seurat@reductions$pca@feature.loadings))
                ))
                umap <- seurat@reductions$umap@cell.embeddings
            } else {
                if (is.null(seed)) {
                    umap <- uwot::umap(
                        t(sqrt(t(features))),
                        n_neighbors = 15,
                        min_dist = 0.01,
                        spread = 1,
                        n_sgd_threads = 1
                    )
                } else {
                    withr::with_seed(
                        seed,
                        umap <- uwot::umap(
                            t(sqrt(t(features))),
                            n_neighbors = 15,
                            min_dist = 0.01,
                            spread = 1,
                            n_sgd_threads = 1
                        )
                    )
                }
            }

            # dbscan to find metaclusters
            dbscan <- dbscan::dbscan(umap, eps = 1)

            # place each population in the correct metacluster
            mapping <- unlist(lapply(
                sort(as.integer(
                    unique(class)
                )),
                function(population) {
                    # get indexes of occurrences of this population
                    indexes <- which(class == population)

                    # get corresponding metaclusters
                    metaIndices <- dbscan$cluster[indexes]

                    # return corresponding metacluster with majority vote
                    return(names(sort(table(
                        metaIndices
                    ), decreasing = TRUE)[1]))
                }
            ))

            # create list which will contain subtypes of each metacluster
            metaclusters <- vector(mode = "list")

            # fill in list of populations for each metacluster
            for (i in unique(mapping)) {
                metaclusters[[i]] <-
                    sort(as.integer(unique(class)))[which(mapping == i)]
            }
            names(metaclusters) <- paste0("M",
                unique(mapping))

            message(paste("Identified", length(metaclusters), "metaclusters"))
        }

        # Check that cell types match class labels
        if (mean(unlist(metaclusters) %in% unique(class)) != 1) {
            stop(
                "Provided cell types do not match class labels. ",
                "Please check the 'metaclusters' argument."
            )
        }

        # Create vector with metacluster labels
        metaclusterLabels <- class
        for (i in names(metaclusters)) {
            metaclusterLabels[metaclusterLabels %in% metaclusters[[i]]] <- i
        }

        # Rename metaclusters with just one cluster
        oneCluster <- names(metaclusters[lengths(metaclusters) == 1])
        if (length(oneCluster) > 0) {
            oneClusterIndices <- which(metaclusterLabels %in% oneCluster)
            metaclusterLabels[oneClusterIndices] <- paste0(
                metaclusterLabels[oneClusterIndices],
                "(",
                class[oneClusterIndices],
                ")"
            )
            names(metaclusters[lengths(metaclusters) == 1]) <- paste0(
                names(metaclusters[lengths(metaclusters) == 1]),
                "(",
                unlist(metaclusters[lengths(metaclusters) == 1]),
                ")"
            )
        }

        # create temporary variables for top-level tree
        tmpThreshold <- threshold

        # create list to store split off classes at each threshold
        markerThreshold <- list()

        # Create top-level tree

        # while there is still a balanced split at the top-level
        while (TRUE) {
            # create tree
            message("Building top-level tree across all metaclusters...")
            tree <- .generateTreeList(
                features,
                as.factor(metaclusterLabels),
                oneoffMetric,
                tmpThreshold,
                reuseFeatures,
                consecutiveOneoff
            )

            # Add alternative node for the solely down-regulated leaf
            tree <- .addAlternativeSplit(
                tree, features,
                as.factor(metaclusterLabels)
            )

            # store clusters with markers at current threshold
            topLevel <- tree[[1]][[1]]
            if (topLevel$statUsed == "One-off") {
                markerThreshold[[as.character(tmpThreshold)]] <-
                    unlist(lapply(
                        topLevel[seq(length(topLevel) - 3)],
                        function(marker) {
                            return(marker$group1Consensus)
                        }
                    ))
            }

            # if no more balanced split
            if (length(tree) == 1) {
                # if all clusters have positive markers
                if (length(tree[[1]][[1]]) == (length(metaclusters) + 3)) {
                    break
                } else {
                    # decrease threshold by 10%
                    tmpThreshold <- tmpThreshold * 0.9
                    message("Decreasing classifier threshold to ",
                        tmpThreshold)
                    next
                }
            } else {
                # still balanced split
                # get up-regulated clusters at first balanced split
                upClass <- tree[[2]][[1]][[1]]$group1Consensus

                # if only 2 clusters at the balanced split then merge them
                if ((length(upClass) == 1) &&
                    (length(tree[[2]][[1]][[1]]$group2Consensus) == 1)) {
                    upClass <- c(upClass, tree[[2]][[1]][[1]]$group2Consensus)
                }

                # update metacluster label of each cell
                tmpMeta <- metaclusterLabels
                tmpMeta[tmpMeta %in% upClass] <-
                    paste(upClass, sep = "", collapse = "+")

                # create top-level tree again
                tmpTree <- .generateTreeList(
                    features,
                    as.factor(tmpMeta),
                    oneoffMetric,
                    tmpThreshold,
                    reuseFeatures,
                    consecutiveOneoff
                )

                # Add alternative node for the solely down-regulated leaf
                tmpTree <- .addAlternativeSplit(
                    tmpTree, features,
                    as.factor(tmpMeta)
                )

                # if new tree still has balanced split/no markers for some
                if ((length(tmpTree) > 1) ||
                    (length(tree[[1]][[1]]) != (length(metaclusters) + 3))) {
                    # decrease threshold by 10%
                    tmpThreshold <- tmpThreshold * 0.9
                    message("Decreasing classifier threshold to ",
                        tmpThreshold)
                } else {
                    # set final metacluster labels to new set of clusters
                    metaclusterLabels <- tmpMeta

                    # set final tree to current tree
                    tree <- tmpTree

                    ## update 'metaclusters' (list of metaclusters)
                    # get celda clusters in these metaclusters
                    newMetacluster <- unlist(metaclusters[upClass])
                    # remove old metaclusters
                    metaclusters[upClass] <- NULL
                    # add new metacluster to list of metaclusters
                    metaclusters[paste(upClass, sep = "", collapse = "+")] <-
                        list(unname(newMetacluster))

                    break
                }
            }
        }

        # re-format output
        finalTree <- tree
        tree <- list(rules = .mapClass2features(
            finalTree, features,
            as.factor(metaclusterLabels),
            topLevelMeta = TRUE
        )$rules)

        # keep markers at first threshold they reached only
        markersToRemove <- c()
        for (thresh in names(markerThreshold)) {
            thresholdClasses <- markerThreshold[[thresh]]
            for (cl in thresholdClasses) {
                curRules <-
                    tree$rules[[cl]]
                lowMarkerIndices <- which(curRules$direction == 1 &
                    curRules$stat < as.numeric(thresh))
                if (length(lowMarkerIndices) > 0 &
                    length(which(curRules$direction == 1)) > 1) {
                    markersToRemove <- c(
                        markersToRemove,
                        curRules[lowMarkerIndices, "feature"]
                    )
                }
            }
        }
        tree$rules <- lapply(tree$rules, function(rules) {
            return(rules[!rules$feature %in% markersToRemove, ])
        })

        # store final set of top-level markers
        topLevelMarkers <- unlist(lapply(tree$rules, function(cluster) {
            markers <- cluster[cluster$direction == 1, "feature"]
            return(paste(markers, collapse = ";"))
        }))

        # create tree dendrogram
        tree$dendro <- .convertToDendrogram(finalTree,
            as.factor(metaclusterLabels),
            splitNames = topLevelMarkers
        )

        # add metacluster label to rules table
        for (metacluster in names(tree$rules)) {
            tree$rules[[metacluster]]$metacluster <- metacluster
        }

        # Store tree's dendrogram in a separate variable
        dendro <- tree$dendro

        # Find which metaclusters have more than one cluster
        largeMetaclusters <- names(metaclusters[lengths(metaclusters) > 1])

        # Update subtype labels for large metaclusters
        subtypeLabels <- metaclusterLabels
        subtypeLabels[subtypeLabels %in% largeMetaclusters] <- paste0(
            subtypeLabels[subtypeLabels %in% largeMetaclusters],
            "(",
            class[subtypeLabels %in% largeMetaclusters],
            ")"
        )

        # Update metaclusters list
        for (metacluster in names(metaclusters)) {
            subtypes <- metaclusters[metacluster]
            subtypes <- lapply(subtypes, function(subtype) {
                paste0(metacluster, "(", subtype, ")")
            })
            metaclusters[metacluster] <- subtypes
        }

        # Create separate trees for each cell type with more than one cluster
        newTrees <- lapply(largeMetaclusters, function(metacluster) {
            # Print current status
            message("Building tree for metacluster ", metacluster)

            # Remove used features
            featUse <- colnames(features)
            if (!reuseFeatures) {
                tmpRules <- tree$rules[[metacluster]]
                featUse <- featUse[!featUse %in%
                    tmpRules[tmpRules$direction == 1, "feature"]]
            }

            # Create new tree
            newTree <- .generateTreeList(
                features[metaclusterLabels == metacluster, featUse],
                as.factor(subtypeLabels[metaclusterLabels == metacluster]),
                oneoffMetric,
                threshold,
                reuseFeatures,
                consecutiveOneoff
            )

            # Add alternative node for the solely down-regulated leaf
            if (altSplit) {
                newTree <- .addAlternativeSplit(
                    newTree,
                    features[metaclusterLabels == metacluster, featUse],
                    as.factor(subtypeLabels[metaclusterLabels == metacluster])
                )
            }

            newTree <- list(
                rules = .mapClass2features(
                    newTree,
                    features[metaclusterLabels == metacluster, ],
                    as.factor(subtypeLabels[metaclusterLabels == metacluster])
                )$rules,
                dendro = .convertToDendrogram(
                    newTree,
                    as.factor(subtypeLabels[metaclusterLabels == metacluster])
                )
            )

            # Adjust 'rules' table for new tree
            newTree$rules <- lapply(newTree$rules, function(rules) {
                rules$level <- rules$level +
                    max(tree$rules[[metacluster]]$level)
                rules$metacluster <- metacluster
                rules <- rbind(tree$rules[[metacluster]], rules)
            })

            return(newTree)
        })
        names(newTrees) <- largeMetaclusters

        # Fix max depth in original tree
        if (length(newTrees) > 0) {
            maxDepth <- max(unlist(lapply(newTrees, function(newTree) {
                lapply(newTree$rules, function(ruleDF) {
                    ruleDF$level
                })
            })))
            addDepth <- maxDepth - attributes(dendro)$height

            dendro <- stats::dendrapply(dendro, function(node, addDepth) {
                if (attributes(node)$height > 1) {
                    attributes(node)$height <- attributes(node)$height +
                        addDepth + 1
                }
                return(node)
            }, addDepth)
        }

        # Find indices of cell type nodes in tree
        indices <- lapply(
            largeMetaclusters,
            function(metacluster) {
                # Initialize sub trees, indices string, and flag
                dendSub <- dendro
                index <- ""
                flag <- TRUE

                while (flag) {
                    # Get the edge with the class of interest
                    whEdge <- which(unlist(
                        lapply(
                            dendSub,
                            function(edge) {
                                metacluster %in% attributes(edge)$classLabels
                            }
                        )
                    ))

                    # Add this as a string
                    index <- paste0(index, "[[", whEdge, "]]")

                    # Move to this branch
                    dendSub <- eval(parse(text = paste0("dendro", index)))

                    # Is this the only class in that branch
                    flag <- length(attributes(dendSub)$classLabels) >
                        1
                }

                return(index)
            }
        )
        names(indices) <- largeMetaclusters

        # Add each cell type tree
        for (metacluster in largeMetaclusters) {
            # Get current tree
            metaclusterDendro <- newTrees[[metacluster]]$dendro

            # Adjust labels, member count, and midpoint of nodes
            dendro <- stats::dendrapply(dendro, function(node) {
                # Check if in right branch
                if (metacluster %in%
                    as.character(attributes(node)$classLabels)) {
                    # Replace cell type label with subtype labels
                    labels <- attributes(node)$classLabels
                    labels <- as.character(labels)
                    labels <- labels[labels != metacluster]
                    labels <- c(
                        labels,
                        unique(subtypeLabels)[grep(metacluster,
                            unique(subtypeLabels))]
                    )
                    attributes(node)$classLabels <- labels

                    # Assign new member count for this branch
                    attributes(node)$members <-
                        length(attributes(node)$classLabels)

                    # Assign new midpoint for this branch
                    attributes(node)$midpoint <-
                        (attributes(node)$members - 1) / 2
                }
                return(node)
            })

            # Replace label at new tree's branch point
            branchPointAttr <- attributes(eval(parse(text = paste0(
                "dendro", indices[[metacluster]]
            ))))
            branchPointLabel <- branchPointAttr$label
            branchPointStatUsed <- branchPointAttr$statUsed

            if (!is.null(branchPointLabel)) {
                attributes(metaclusterDendro)$label <- branchPointLabel
                attributes(metaclusterDendro)$statUsed <- branchPointStatUsed
            }

            # Fix height
            indLoc <- gregexpr("\\[\\[", indices[[metacluster]])[[1]]
            indLoc <- indLoc[length(indLoc)]
            parentIndexString <- substr(
                indices[[metacluster]],
                0,
                indLoc - 1
            )
            parentHeight <- attributes(eval(parse(
                text = paste0("dendro", parentIndexString)
            )))$height
            metaclusterHeight <- attributes(metaclusterDendro)$height
            metaclusterDendro <- stats::dendrapply(
                metaclusterDendro,
                function(node, parentHeight, metaclusterHeight) {
                    if (attributes(node)$height > 1) {
                        attributes(node)$height <- parentHeight - 1 -
                            (metaclusterHeight - attributes(node)$height)
                    }
                    return(node)
                }, parentHeight, metaclusterHeight
            )

            # Add new tree to original tree
            eval(parse(text = paste0(
                "dendro", indices[[metacluster]], " <- metaclusterDendro"
            )))

            # Append new tree's 'rules' tables to original tree
            tree$rules <- append(tree$rules,
                newTrees[[metacluster]]$rules,
                after = which(names(tree$rules) == metacluster)
            )

            # Remove old tree's rules
            tree$rules <- tree$rules[-which(names(tree$rules) == metacluster)]
        }

        # Set final tree dendro
        tree$dendro <- dendro

        # Get performance statistics
        message("Computing performance statistics...")
        perfList <- .getPerformance(
            tree$rules,
            features,
            as.factor(subtypeLabels)
        )
        tree$prediction <- perfList$prediction
        tree$performance <- perfList$performance

        # Remove confusing 'value' column
        tree$rules <- lapply(tree$rules, function(x) {
            x["value"] <- NULL
            x
        })

        # add column to each rules table which specifies its class
        tree$rules <- mapply(cbind,
            "class" = as.character(names(tree$rules)),
            tree$rules,
            SIMPLIFY = FALSE
        )

        # create branch points table
        branchPoints <- .createBranchPoints(
            tree$rules,
            largeMetaclusters,
            metaclusters
        )

        # collapse all rules tables into one large table
        collapsed <- do.call("rbind", tree$rules)

        # get top-level rules
        topLevelRules <- collapsed[collapsed$level == 1, ]

        # add 'class' column
        topLevelRules$class <- topLevelRules$metacluster

        # add to branch point list
        branchPoints[["top_level"]] <- topLevelRules

        # check if need to expand features to gene-level
        if (methods::hasArg(featureLabels) && methods::hasArg(counts)) {
            message("Computing scores for individual genes...")

            # make sure feature labels match those in the tree
            if (!all(unique(collapsed$feature) %in% unique(featureLabels))) {
                m <- "Provided feature labels don't match those in count matrix."
                stop(m)
            }

            # iterate over branch points
            branchPoints <- lapply(branchPoints, function(branch) {
                # iterate over unique features
                featAUC <- lapply(
                    unique(branch$feature),
                    .getGeneAUC,
                    branch,
                    subtypeLabels,
                    metaclusterLabels,
                    featureLabels,
                    counts
                )

                # update branch table after merging genes data
                return(do.call("rbind", featAUC))
            })

            # simplify top-level in rules tables to only up-regulated markers
            tree$rules <- lapply(tree$rules, function(rule) {
                toDrop <- intersect(
                    which(rule$level == 1),
                    which(rule$direction == (-1))
                )
                # guard: a zero-length negative index would drop every row
                if (length(toDrop) > 0) {
                    rule <- rule[-toDrop, ]
                }
                return(rule)
            })

            ## add gene-level info to rules tables
            # collapse branch points tables into one
            collapsedBranches <- do.call("rbind", branchPoints)
            collapsedBranches$class <- as.character(collapsedBranches$class)

            # loop over rules tables and get relevant info
            tree$rules <- lapply(tree$rules, function(class) {
                # initialize table to return
                toReturn <- data.frame(NULL)

                # loop over rows of this class
                for (i in seq(nrow(class))) {
                    # extract relevant genes from branch points tables
                    genesAUC <- collapsedBranches[
                        collapsedBranches$feature == class$feature[i] &
                            collapsedBranches$level == class$level[i] &
                            collapsedBranches$class == class$class[i], ]

                    # don't forget top-level
                    if (class$level[i] == 1) {
                        genesAUC <- collapsedBranches[
                            collapsedBranches$feature == class$feature[i] &
                                collapsedBranches$level == class$level[i] &
                                collapsedBranches$class ==
                                    class$metacluster[i], ]
                    }

                    # merge table
                    toReturn <- rbind(toReturn, genesAUC)
                }
                return(toReturn)
            })

            # remove table row names
            tree$rules <- lapply(tree$rules, function(t) {
                rownames(t) <- NULL
                return(t)
            })

            # add feature labels to output
            tree$featureLabels <- featureLabels
        }

        # simplify top-level branch point to save memory
        branchPoints$top_level <-
            branchPoints$top_level[branchPoints$top_level$direction == 1, ]
        branchPoints$top_level <-
            branchPoints$top_level[!duplicated(branchPoints$top_level), ]

        # remove branch points row names
        branchPoints <- lapply(branchPoints, function(br) {
            rownames(br) <- NULL
            return(br)
        })

        # adjust subtype labels
        branchPoints <- lapply(branchPoints, function(br) {
            br$class <- as.character(br$class)
            br$class[grepl("\\(.*\\)", br$class)] <- regmatches(
                br$class[grepl("\\(.*\\)", br$class)],
                regexpr(
                    pattern = "(?<=\\().*?(?=\\)$)",
                    br$class[grepl("\\(.*\\)", br$class)],
                    perl = TRUE
                )
            )
            br$metacluster <- as.character(br$metacluster)
            br$metacluster[grepl("\\(.*\\)", br$metacluster)] <- gsub(
                "\\(.*\\)", "",
                br$metacluster[grepl("\\(.*\\)", br$metacluster)]
            )
            return(br)
        })

        # adjust subtype labels
        tree$rules <- suppressWarnings(lapply(tree$rules, function(r) {
            r$class <- as.character(r$class)
            r$class[grepl("\\(.*\\)", r$class)] <- regmatches(
                r$class[grepl("\\(.*\\)", r$class)],
                regexpr(
                    pattern = "(?<=\\().*?(?=\\)$)",
                    r$class[grepl("\\(.*\\)", r$class)],
                    perl = TRUE
                )
            )
            r$metacluster[grepl("\\(.*\\)", r$metacluster)] <- gsub(
                "\\(.*\\)", "",
                r$metacluster[grepl("\\(.*\\)", r$metacluster)]
            )
            return(r)
        }))

        # add to tree
        tree$branchPoints <- branchPoints

        # return class labels
        tree$classLabels <- regmatches(
            subtypeLabels,
            regexpr(
                pattern = "(?<=\\().*?(?=\\)$)",
                subtypeLabels,
                perl = TRUE
            )
        )
        tree$metaclusterLabels <- metaclusterLabels
        tree$metaclusterLabels[grepl("\\(.*\\)", metaclusterLabels)] <- gsub(
            "\\(.*\\)", "",
            metaclusterLabels[grepl("\\(.*\\)", metaclusterLabels)]
        )

        # Final return
        return(tree)
    }
}


# helper function to create table for each branch point in the tree
.createBranchPoints <- function(rules, largeMetaclusters, metaclusters) {
    # First step differs if metaclusters were used
    if (methods::hasArg(metaclusters) && (length(largeMetaclusters) > 0)) {
        # iterate over metaclusters and add the rules for each level
        branchPoints <- lapply(largeMetaclusters, function(metacluster) {
            # get names of subtypes
            subtypes <- metaclusters[[metacluster]]

            # collapse rules tables of subtypes
            subtypeRules <- do.call("rbind", rules[subtypes])

            # get rules at each level
            levels <- lapply(seq(2, max(subtypeRules$level)), function(level) {
                return(subtypeRules[subtypeRules$level == level, ])
            })
            names(levels) <- paste0(
                metacluster, "_level_",
                seq(max(subtypeRules$level) - 1)
            )

            return(levels)
        })
        branchPoints <- unlist(branchPoints, recursive = FALSE)
    } else {
        # collapse all rules into one table
        collapsed <- do.call("rbind", rules)

        # subset rules at each level
        branchPoints <- lapply(seq(max(collapsed$level)), function(level) {
            return(collapsed[collapsed$level == level, ])
        })
        names(branchPoints) <- paste0("level_", seq(max(collapsed$level)))
    }

    # split each level into its branch points
    branchPoints <- lapply(branchPoints, function(level) {
        # check if need to split
        firstFeat <- level$feature[1]
        firstStat <- level$stat[1]
        if (setequal(
            level[
                level$feature == firstFeat & level$stat == firstStat,
                "class"
            ],
            unique(level$class)
        )) {
            return(level)
        }

        # initialize lists for new tables
        bSplits <- NA
        oSplits <- NA

        # get balanced split rows by themselves
        balS <- level[level$statUsed == "Split", ]

        # return table for each unique value of 'stat'
        if (nrow(balS) > 0) {
            # get unique splits (based on stat)
            unS <- unique(balS$stat)

            # return table for each unique split
            bSplits <- lapply(unS, function(s) {
                balS[balS$stat == s, ]
            })
        }

        # get one-off rows by themselves
        oneS <- level[level$statUsed == "One-off", ]

        if (nrow(oneS) > 0) {
            # check if need to split
            firstFeat <- oneS$feature[1]
            if (setequal(
                oneS[oneS$feature == firstFeat, "class"],
                unique(oneS$class)
            )) {
                oSplits <- oneS
            }

            # get class groups for each marker
            markers <- oneS[oneS$direction == 1, "feature"]
            groups <- unique(unlist(lapply(markers, function(m) {
                return(paste(as.character(oneS[oneS$feature == m, "class"]),
                    collapse = " "
                ))
            })))

            # return table for each class group
            oSplits <- lapply(groups, function(x) {
                gr <- unlist(strsplit(x, split = " "))
                oneS[as.character(oneS$class) %in% gr, ]
            })
        }

        # rename new tables
        if (is.list(bSplits)) {
            names(bSplits) <- paste0(
                "split_",
                LETTERS[seq(length(bSplits), 1)]
            )
        }
        if (is.list(oSplits)) {
            names(oSplits) <- paste0(
                "one-off_",
                LETTERS[seq(length(oSplits), 1)]
            )
        }

        # return 2 sets of tables
        toReturn <- list(oSplits, bSplits)
        toReturn <- toReturn[!is.na(toReturn)]
        toReturn <- unlist(toReturn, recursive = FALSE)
        return(toReturn)
    })

    # adjust for new tables
    branchPoints <- lapply(branchPoints, function(br) {
        if (inherits(br, "list")) {
            return(br)
        } else {
            return(list(br))
        }
    })
    branchPoints <- unlist(branchPoints, recursive = FALSE)

    # replace dots in names of new branches with underscores
    names(branchPoints) <- gsub(
        pattern = "\\.([^\\.]*)$",
        replacement = "_\\1",
        names(branchPoints)
    )

    return(branchPoints)
}


# helper function to get AUC for individual genes within feature
.getGeneAUC <- function(marker,
    table,
    subtypeLabels,
    metaclusterLabels,
    featureLabels,
    counts) {
    # get up-regulated & down-regulated classes for this feature
    upClass <- as.character(table[table$feature == marker &
        table$direction == 1, "class"])
    downClasses <- as.character(table[table$feature == marker &
        table$direction == (-1), "class"])

    # subset counts matrix
    if (table$level[1] > 1) {
        subCounts <- counts[, which(subtypeLabels %in%
            c(upClass, downClasses))]
    } else {
        subCounts <- counts[, which(metaclusterLabels %in%
            c(upClass, downClasses))]
    }

    # subset class labels
    if (table$level[1] > 1) {
        subLabels <- subtypeLabels[which(subtypeLabels %in%
            c(upClass, downClasses))]
    } else {
        subLabels <- metaclusterLabels[which(metaclusterLabels %in%
            c(upClass, downClasses))]
    }

    # set label to 0 if not class of interest
    subLabels <- as.numeric(subLabels %in% upClass)

    # get individual features within this marker
    markers <- rownames(counts)[which(featureLabels == marker)]

    # get one-vs-all AUC for each gene
    auc <- unlist(lapply(markers, function(markerGene) {
        as.numeric(pROC::auc(
            pROC::roc(
                subLabels,
                subCounts[markerGene, ],
                direction = "<",
                quiet = TRUE
            )
        ))
    }))
    names(auc) <- markers

    # sort by AUC
    auc <- sort(auc, decreasing = TRUE)

    # create table for this marker
    featTable <- table[table$feature == marker, ]
    featTable <-
    featTable[rep(seq_len(nrow(featTable)), each = length(auc)), ]
  featTable$gene <- rep(names(auc), length(c(upClass, downClasses)))
  featTable$geneAUC <- rep(auc, length(c(upClass, downClasses)))

  # return table for merging with main table
  return(featTable)
}

# This function generates the decision tree by recursively separating classes.
.generateTreeList <- function(features,
                              class,
                              oneoffMetric,
                              threshold,
                              reuseFeatures,
                              consecutiveOneoff = FALSE) {
  # Initialize tree
  treeLevel <- tree <- list()

  # Initialize the first split
  treeLevel[[1]] <- list()

  # Generate the first split at the first level
  treeLevel[[1]] <- .wrapSplitHybrid(
    features,
    class,
    threshold,
    oneoffMetric
  )

  # Add set of features used at this split
  treeLevel[[1]]$fUsed <- unlist(lapply(
    treeLevel[[1]][names(treeLevel[[1]]) != "statUsed"],
    function(X) {
      X$featureName
    }
  ))

  # Initialize split directions
  treeLevel[[1]]$dirs <- 1

  # Add split list as first level
  tree[[1]] <- treeLevel

  # Initialize tree depth
  mDepth <- 1

  # Build tree until all leaves are of a single cluster
  while (length(unlist(treeLevel)) > 0) {
    # Create list of branches on this level
    outList <- lapply(treeLevel, function(split, features, class) {
      # Check for consecutive one-off splits
      tryOneoff <- TRUE
      if (!consecutiveOneoff & split$statUsed == "One-off") {
        tryOneoff <- FALSE
      }

      # If length(split) == 4, then this split is a binary node
      # (one split element plus "statUsed", "fUsed", and "dirs")
      if (length(split) == 4 & length(split[[1]]$group1Consensus) > 1) {
        # Create branch from this split.
        branch1 <- .wrapBranchHybrid(
          split[[1]]$group1,
          features,
          class,
          split$fUsed,
          threshold,
          reuseFeatures,
          oneoffMetric,
          tryOneoff
        )

        if (!is.null(branch1)) {
          # Add feature to list of used features.
          branch1$fUsed <- c(split$fUsed, unlist(lapply(
            branch1[names(branch1) != "statUsed"],
            function(X) {
              X$featureName
            }
          )))

          # Add the split direction (always 1 when splitting group 1)
          branch1$dirs <- c(split$dirs, 1)
        }
      } else {
        branch1 <- NULL
      }

      # If length(split) == 4, then this split is a binary node
      if (length(split) == 4 & length(split[[1]]$group2Consensus) > 1) {
        # Create branch from this split
        branch2 <- .wrapBranchHybrid(
          split[[1]]$group2,
          features,
          class,
          split$fUsed,
          threshold,
          reuseFeatures,
          oneoffMetric,
          tryOneoff
        )

        if (!is.null(branch2)) {
          # Add feature to list of used features.
          branch2$fUsed <- c(split$fUsed, unlist(lapply(
            branch2[names(branch2) != "statUsed"],
            function(X) {
              X$featureName
            }
          )))

          # Add the split direction (always 2 when splitting group 2)
          branch2$dirs <- c(split$dirs, 2)
        }

        # If length(split) > 4, then this split has more than 2 edges.
        # In this case group 1 will always denote leaves.
      } else if (length(split) > 4) {
        # Get samples that are never in group 1 in this split
        group1Samples <- unique(unlist(lapply(
          split[!names(split) %in% c("statUsed", "fUsed", "dirs")],
          function(X) {
            X$group1
          }
        )))
        group2Samples <- unique(unlist(lapply(
          split[!names(split) %in% c("statUsed", "fUsed", "dirs")],
          function(X) {
            X$group2
          }
        )))
        group2Samples <- group2Samples[!group2Samples %in% group1Samples]

        # Check that there is still more than one class
        group2Classes <- levels(droplevels(class[rownames(features) %in%
          group2Samples]))
        if (length(group2Classes) > 1) {
          # Create branch from this split
          branch2 <- .wrapBranchHybrid(
            group2Samples,
            features,
            class,
            split$fUsed,
            threshold,
            reuseFeatures,
            oneoffMetric,
            tryOneoff
          )

          if (!is.null(branch2)) {
            # Add multiple features
            branch2$fUsed <- c(split$fUsed, unlist(lapply(
              branch2[names(branch2) != "statUsed"],
              function(X) {
                X$featureName
              }
            )))

            # Instead of 2, this direction is 1 + the number of splits
            branch2$dirs <- c(
              split$dirs,
              sum(!names(split) %in% c("statUsed", "fUsed", "dirs")) + 1
            )
          }
        } else {
          branch2 <- NULL
        }
      } else {
        branch2 <- NULL
      }

      # Combine these branches
      outBranch <- list(branch1, branch2)

      # Only keep non-null branches
      outBranch <- outBranch[!unlist(lapply(outBranch, is.null))]
      if (length(outBranch) > 0) {
        return(outBranch)
      } else {
        return(NULL)
      }
    }, features, class)

    # Unlist outList so there is one list per 'treeLevel'
    treeLevel <- unlist(outList, recursive = FALSE)

    # Increase tree depth
    mDepth <- mDepth + 1

    # Add this level to the tree
    tree[[mDepth]] <- treeLevel
  }
  return(tree)
}

# Wrapper to subset the feature and class set for each split
.wrapBranchHybrid <- function(groups,
                              features,
                              class,
                              fUsed,
                              threshold = 0.95,
                              reuseFeatures = FALSE,
                              oneoffMetric,
                              tryOneoff) {
  # Subset for branch to run split
  gKeep <- rownames(features) %in% groups

  # Remove used features?
  if (reuseFeatures) {
    # drop = FALSE keeps fSub two-dimensional so ncol() below stays valid
    fSub <- features[gKeep, , drop = FALSE]
  } else {
    fSub <- features[gKeep, !colnames(features) %in% fUsed, drop = FALSE]
  }

  # Drop levels (classes that are no longer present)
  cSub <- droplevels(class[gKeep])

  # If multiple columns in fSub run split, else return null
  if (ncol(fSub) > 1) {
    return(.wrapSplitHybrid(fSub, cSub, threshold, oneoffMetric, tryOneoff))
  } else {
    return(NULL)
  }
}

# Wrapper function to perform split metrics
.wrapSplitHybrid <- function(features,
                             class,
                             threshold = 0.95,
                             oneoffMetric,
                             tryOneoff = TRUE) {
  # Get best one-to-one splits
  ## Use modified F1 or pairwise AUC?
if (tryOneoff) { if (oneoffMetric == "modified F1") { splitMetric <- .splitMetricModF1 } else { splitMetric <- .splitMetricPairwiseAUC } splitStats <- .splitMetricRecursive(features, class, splitMetric = splitMetric ) splitStats <- splitStats[splitStats >= threshold] statUsed <- "One-off" } else { splitStats <- integer(0) } # If no one-2-one split meets threshold, run semi-supervised clustering if (length(splitStats) == 0) { splitMetric <- .splitMetricIGpIGd splitStats <- .splitMetricRecursive(features, class, splitMetric = splitMetric )[1] # Use top statUsed <- "Split" } # Get split for best gene splitList <- lapply( names(splitStats), .getSplit, splitStats, features, class, splitMetric ) # Combine feature rules when same group1 class arises if (length(splitList) > 1) { group1Vec <- unlist(lapply(splitList, function(X) { X$group1Consensus }), recursive = FALSE) splitList <- lapply( unique(group1Vec), function(group1, splitList, group1Vec) { # Get subset with same group1 splitListSub <- splitList[group1Vec == group1] # Get feature, value, and stat for these splitFeature <- unlist(lapply( splitListSub, function(X) { X$featureName } )) splitValue <- unlist(lapply( splitListSub, function(X) { X$value } )) splitStat <- unlist(lapply( splitListSub, function(X) { X$stat } )) # Create a single object and add these splitSingle <- splitListSub[[1]] splitSingle$featureName <- splitFeature splitSingle$value <- splitValue splitSingle$stat <- splitStat return(splitSingle) }, splitList, group1Vec ) } names(splitList) <- unlist(lapply( splitList, function(X) { paste(X$featureName, collapse = ";") } )) # Add statUsed splitList$statUsed <- statUsed return(splitList) } # Recursively run split metric on every feature .splitMetricRecursive <- function(features, class, splitMetric) { splitStats <- vapply(colnames(features), function(feat, features, class, splitMetric) { splitMetric(feat, class, features, rPerf = TRUE) }, features, class, splitMetric, FUN.VALUE = double(1) ) 
  names(splitStats) <- colnames(features)
  splitStats <- sort(splitStats, decreasing = TRUE)

  return(splitStats)
}

# Run pairwise AUC metric on single feature
.splitMetricPairwiseAUC <- function(feat, class, features, rPerf = FALSE) {
  # Get current feature
  currentFeature <- features[, feat]

  # Get unique classes
  classUnique <- sort(unique(class))

  # Do one-to-all to determine top cluster
  # For each class K1 determine best AUC
  auc1toAll <- vapply(classUnique, function(k1, class, currentFeature) {
    # Set value to k1
    classK1 <- as.numeric(class == k1)

    # Get AUC value
    aucK1 <- pROC::auc(pROC::roc(
      classK1,
      currentFeature,
      direction = "<",
      quiet = TRUE
    ))

    # Return
    return(aucK1)
  }, class, currentFeature, FUN.VALUE = double(1))

  # Get class with best AUC (class with generally highest values)
  classMax <- as.character(classUnique[which.max(auc1toAll)])

  # Get other classes
  classRest <- as.character(classUnique[classUnique != classMax])

  # For each second cluster k2
  aucFram <- as.data.frame(do.call(
    rbind,
    lapply(
      classRest,
      function(k2, k1, class, currentFeature) {
        # Keep cells in k1 or k2 only
        obsKeep <- class %in% c(k1, k2)
        currentFeatureSubset <- currentFeature[obsKeep]

        # Update cluster assignments
        currentClusters <- class[obsKeep]

        # Label cells by whether they belong to k1 (0 or 1)
        currentLabels <- as.integer(currentClusters == k1)

        # Get AUC value for this feature-cluster pair
        rocK2 <- pROC::roc(currentLabels,
          currentFeatureSubset,
          direction = "<",
          quiet = TRUE
        )
        aucK2 <- rocK2$auc
        coordK2 <- pROC::coords(rocK2,
          "best",
          ret = "threshold",
          transpose = TRUE
        )[1]

        # Concatenate vectors
        statK2 <- c(threshold = coordK2, auc = aucK2)

        return(statK2)
      }, classMax, class, currentFeature
    )
  ))

  # Get Min Value
  aucMin <- min(aucFram$auc)

  # Get indices where this AUC occurs
  aucMinIndices <- which(aucFram$auc == aucMin)

  # Use maximum value if there are ties
  aucValue <- max(aucFram$threshold)

  # Return performance or value?
  if (rPerf) {
    return(aucMin)
  } else {
    return(aucValue)
  }
}

# Run modified F1 metric on single feature
.splitMetricModF1 <- function(feat, class, features, rPerf = FALSE) {
  # Get number of samples
  len <- length(class)

  # Get Values
  featValues <- features[, feat]

  # Get order of values
  ord <- order(featValues, decreasing = TRUE)

  # Get sorted class and values
  featValuesSort <- featValues[ord]
  classSort <- class[ord]

  # Keep splits of the data where the class changes
  keep <- c(
    classSort[seq(1, (len - 1))] != classSort[seq(2, (len))] &
      featValuesSort[seq(1, (len - 1))] != featValuesSort[seq(2, (len))],
    FALSE
  )

  # Create data.matrix
  X <- stats::model.matrix(~ 0 + classSort)

  # Get cumulative sums
  sRCounts <- apply(X, 2, cumsum)

  # Keep only values where the class changes
  sRCounts <- sRCounts[keep, , drop = FALSE]
  featValuesKeep <- featValuesSort[keep]

  # Number of each class
  Xsum <- colSums(X)

  # Remove impossible splits (no class has > 50% of its samples on one side)
  sRProbs <- sRCounts %*% diag(Xsum^-1)
  sKeepPossible <- rowSums(sRProbs >= 0.5) > 0 & rowSums(sRProbs < 0.5) > 0

  # Remove anything after a full prob (doesn't always happen)
  maxCheck <- min(c(which(apply(sRProbs, 1, max) == 1), nrow(sRProbs)))
  sKeepCheck <- seq(1, nrow(sRProbs)) %in% seq(1, maxCheck)

  # Combine logical vectors
  sKeep <- sKeepPossible & sKeepCheck

  if (sum(sKeep) > 0) {
    # Remove these if they exist
    sRCounts <- sRCounts[sKeep, , drop = FALSE]
    featValuesKeep <- featValuesKeep[sKeep]

    # Get left counts
    sLCounts <- t(Xsum - t(sRCounts))

    # Calculate the harmonic mean of Sens, Prec, and Worst Alt Sens
    statModF1 <- vapply(seq(nrow(sRCounts)),
      function(i, Xsum, sRCounts, sLCounts) {
        # Right Side
        sRRowSens <- sRCounts[i, ] / Xsum # Right sensitivities
        sRRowPrec <- sRCounts[i, ] / sum(sRCounts[i, ]) # Right precisions
        sRRowF1 <- 2 * (sRRowSens * sRRowPrec) / (sRRowSens + sRRowPrec)
        sRRowF1[is.nan(sRRowF1)] <- 0 # Get right F1
        bestF1Ind <- which.max(sRRowF1) # Which is the best?
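        # Worked example of the score computed below (illustrative numbers,
        # not taken from any dataset). The modified F1 is the harmonic mean
        # of the best class's sensitivity and precision on the right side and
        # the worst sensitivity among the remaining classes on the left side:
        #   HM = 3 / (1 / bestSens + 1 / bestPrec + 1 / worstSens)
        # Assuming bestSens = 0.90, bestPrec = 0.80, and worstSens = 0.95:
        #   HM = (3 * 0.90 * 0.80 * 0.95) / (0.72 + 0.76 + 0.855)
        #      = 2.052 / 2.335, or about 0.879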
bestSens <- sRRowSens[bestF1Ind] # The corresponding sensitivity bestPrec <- sRRowPrec[bestF1Ind] # The corresponding precision # Left Side sLRowSens <- sLCounts[i, ] / Xsum # Get left sensitivities worstSens <- min(sLRowSens[-bestF1Ind]) # Get the worst # Get harmonic mean of best sens, best prec, and worst sens HMout <- (3 * bestSens * bestPrec * worstSens) / (bestSens * bestPrec + bestPrec * worstSens + bestSens * worstSens) return(HMout) }, Xsum, sRCounts, sLCounts, FUN.VALUE = double(1) ) # Get Max Value ModF1Max <- max(statModF1) # Get indices where this value occurs (use minimum row) ModF1Index <- which.max(statModF1) # Get value at this point ValueCeiling <- featValuesKeep[ModF1Index] ValueWhich <- which(featValuesSort == ValueCeiling) ModF1Value <- mean(c( featValuesSort[ValueWhich], featValuesSort[ValueWhich + 1] )) } else { ModF1Max <- 0 ModF1Value <- NA } if (rPerf) { return(ModF1Max) } else { return(ModF1Value) } } # Run Information Gain (probability + density) on a single feature .splitMetricIGpIGd <- function(feat, class, features, rPerf = FALSE) { # Get number of samples len <- length(class) # Get Values featValues <- features[, feat] # Get order of values ord <- order(featValues, decreasing = TRUE) # Get sorted class and values featValuesSort <- featValues[ord] classSort <- class[ord] # Keep splits of the data where the class changes keep <- c( classSort[seq(1, (len - 1))] != classSort[seq(2, (len))] & featValuesSort[seq(1, (len - 1))] != featValuesSort[seq(2, (len))], FALSE ) # Create data.matrix X <- stats::model.matrix(~ 0 + classSort) # Get cumulative sums sRCounts <- apply(X, 2, cumsum) # Keep only values where the class changes sRCounts <- sRCounts[keep, , drop = FALSE] featValuesKeep <- featValuesSort[keep] # Number of each class Xsum <- colSums(X) # Remove impossible splits sRProbs <- sRCounts %*% diag(Xsum^-1) sKeep <- rowSums(sRProbs >= 0.5) > 0 & rowSums(sRProbs < 0.5) > 0 if (sum(sKeep) > 0) { # Remove these if they exist sRCounts <- 
      sRCounts[sKeep, , drop = FALSE]
    featValuesKeep <- featValuesKeep[sKeep]

    # Get left counts
    sLCounts <- t(Xsum - t(sRCounts))

    # Multiply them to get probabilities
    sRProbs <- t(t(sRCounts) %*% diag(rowSums(sRCounts)^-1,
      nrow = nrow(sRCounts)
    ))
    sLProbs <- t(t(sLCounts) %*% diag(rowSums(sLCounts)^-1,
      nrow = nrow(sLCounts)
    ))

    # Multiply them by their log
    sRTrans <- sRProbs * log(sRProbs)
    sRTrans[is.na(sRTrans)] <- 0
    sLTrans <- sLProbs * log(sLProbs)
    sLTrans[is.na(sLTrans)] <- 0

    # Get entropies
    HSR <- -rowSums(sRTrans)
    HSL <- -rowSums(sLTrans)

    # Get overall probabilities and entropy
    nProbs <- colSums(X) / len
    HS <- -sum(nProbs * log(nProbs))

    # Get split proportions
    sProps <- rowSums(sRCounts) / nrow(X)

    # Get information gain (Probability)
    IGprobs <- HS - (sProps * HSR + (1 - sProps) * HSL)
    IGprobs[is.nan(IGprobs)] <- 0
    IGprobsQuantile <- IGprobs / max(IGprobs)
    IGprobsQuantile[is.nan(IGprobsQuantile)] <- 0

    # Get proportions at each split
    classProps <- sRCounts %*% diag(Xsum^-1)
    classSplit <- classProps >= 0.5

    # Initialize information gain density vector
    splitIGdensQuantile <- rep(0, nrow(classSplit))

    # Get unique splits of the data
    classSplitUnique <- unique(classSplit)
    classSplitUnique <- classSplitUnique[!rowSums(classSplitUnique) %in%
      c(0, ncol(classSplitUnique)), , drop = FALSE]

    # Get density information gain
    if (nrow(classSplitUnique) > 0) {
      # Get log(determinant of full matrix)
      DET <- .psdet(stats::cov(features))

      # Information gain of every observation
      IGdens <- apply(
        classSplitUnique,
        1,
        .infoGainDensity,
        X,
        features,
        DET
      )
      names(IGdens) <- apply(
        classSplitUnique * 1,
        1,
        function(X) {
          paste(X, collapse = "")
        }
      )
      IGdens[is.nan(IGdens) | IGdens < 0] <- 0
      IGdensQuantile <- IGdens / max(IGdens)
      IGdensQuantile[is.nan(IGdensQuantile)] <- 0

      # Get ID of each class split
      splitsIDs <- apply(
        classSplit * 1,
        1,
        function(x) {
          paste(x, collapse = "")
        }
      )

      # Append information gain density vector
      for (ID in names(IGdens)) {
        splitIGdensQuantile[splitsIDs == ID] <-
IGdensQuantile[ID] } } # Add this to the other matrix IG <- IGprobsQuantile + splitIGdensQuantile # Get IG(probabilty) of maximum value IGreturn <- IGprobs[which.max(IG)[1]] # Get maximum value maxVal <- featValuesKeep[which.max(IG)] wMax <- max(which(featValuesSort == maxVal)) IGvalue <- mean(c(featValuesSort[wMax], featValuesSort[wMax + 1])) } else { IGreturn <- 0 IGvalue <- NA } # Report maximum ID or value at maximum IG if (rPerf) { return(IGreturn) } else { return(IGvalue) } } # Function to find pseudo-determinant .psdet <- function(x) { if (sum(is.na(x)) == 0) { svalues <- zapsmall(svd(x)$d) sum(log(svalues[svalues > 0])) } else { 0 } } # Function to calculate density information gain .infoGainDensity <- function(splitVector, X, features, DET) { # Get Subsets of the feature matrix sRFeat <- features[as.logical(rowSums(X[, splitVector, drop = FALSE])), , drop = FALSE ] sLFeat <- features[as.logical(rowSums(X[, !splitVector, drop = FALSE])), , drop = FALSE ] # Get pseudo-determinant of covariance matrices DETR <- .psdet(stats::cov(sRFeat)) DETL <- .psdet(stats::cov(sLFeat)) # Get relative sizes sJ <- nrow(features) sJR <- nrow(sRFeat) sJL <- nrow(sLFeat) IUout <- 0.5 * (DET - (sJR / sJ * DETR + sJL / sJ * DETL)) return(IUout) } # Wrapper function for getting split statistics .getSplit <- function(feat, splitStats, features, class, splitMetric) { stat <- splitStats[feat] splitVal <- splitMetric(feat, class, features, rPerf = FALSE) featValues <- features[, feat] # Get classes split to one node node1Class <- class[featValues > splitVal] # Get proportion of each class at each node group1Prop <- table(node1Class) / table(class) group2Prop <- 1 - group1Prop # Get class consensus group1Consensus <- names(group1Prop)[group1Prop >= 0.5] group2Consensus <- names(group1Prop)[group1Prop < 0.5] # Get group samples group1 <- rownames(features)[class %in% group1Consensus] group2 <- rownames(features)[class %in% group2Consensus] # Get class vector group1Class <- 
    droplevels(class[class %in% group1Consensus])
  group2Class <- droplevels(class[class %in% group2Consensus])

  return(
    list(
      featureName = feat,
      value = splitVal,
      stat = stat,
      group1 = group1,
      group1Class = group1Class,
      group1Consensus = group1Consensus,
      group1Prop = c(group1Prop),
      group2 = group2,
      group2Class = group2Class,
      group2Consensus = group2Consensus,
      group2Prop = c(group2Prop)
    )
  )
}

# Function to annotate an alternate split for solely down-regulated
# terminal nodes
.addAlternativeSplit <- function(tree, features, class) {
  # Unlist decision tree
  DecTree <- unlist(tree, recursive = FALSE)

  # Get leaves
  groupList <- lapply(DecTree, function(split) {
    # Remove directions
    split <- split[!names(split) %in% c("statUsed", "fUsed", "dirs")]

    # Get groups
    group1 <- unique(unlist(lapply(
      split,
      function(node) {
        node$group1Consensus
      }
    )))
    group2 <- unique(unlist(lapply(
      split,
      function(node) {
        node$group2Consensus
      }
    )))

    return(list(
      group1 = group1,
      group2 = group2
    ))
  })

  # Get vector of each group
  group1Vec <- unique(unlist(lapply(groupList, function(g) {
    g$group1
  })))
  group2Vec <- unique(unlist(lapply(groupList, function(g) {
    g$group2
  })))

  # Get group that is never up-regulated
  group2only <- group2Vec[!group2Vec %in% group1Vec]

  # Check whether there are solely down-regulated splits
  AltSplitInd <- which(unlist(lapply(groupList, function(g, group2only) {
    group2only %in% g$group2
  }, group2only)))

  if (length(AltSplitInd) > 0) {
    # Use the deepest split that contains the solely down-regulated group
    AltDec <- max(AltSplitInd)

    # Get split
    downSplit <- DecTree[[AltDec]]
    downNode <- downSplit[[1]]

    # Get classes to rerun
    branchClasses <- names(downNode$group1Prop)

    # Get samples from these classes and features from this cluster
    sampKeep <- class %in% branchClasses
    featKeep <- !colnames(features) %in% downSplit$fUsed

    # Subset class and features
    cSub <- droplevels(class[sampKeep])
    fSub <- features[sampKeep, featKeep, drop = FALSE]

    # Get best alternative split
    altStats <-
do.call( rbind, lapply( colnames(fSub), function(feat, splitMetric, features, class, cInt) { Val <- splitMetric(feat, cSub, fSub, rPerf = FALSE) # Get node1 classes node1Class <- class[features[, feat] > Val] # Get sensitivity/precision/altSens Sens <- sum(node1Class == cInt) / sum(class == cInt) Prec <- mean(node1Class == cInt) # Get Sensitivity of Alternate Classes AltClasses <- unique(class)[unique(class) != cInt] AltSizes <- vapply(AltClasses, function(cAlt, class) { sum(class == cAlt) }, class, FUN.VALUE = double(1) ) AltWrong <- vapply(AltClasses, function(cAlt, node1Class) { sum(node1Class == cAlt) }, node1Class, FUN.VALUE = double(1) ) AltSens <- min(1 - (AltWrong / AltSizes)) # Get harmonic mean HM <- (3 * Sens * Prec * AltSens) / (Sens * Prec + Prec * AltSens + Sens * AltSens) HM[is.nan(HM)] <- 0 # Return return(data.frame( feat = feat, val = Val, stat = HM, stringsAsFactors = FALSE )) }, .splitMetricModF1, fSub, cSub, group2only ) ) altStats <- altStats[order(altStats$stat, decreasing = TRUE), ] # Get alternative splits splitStats <- altStats$stat[1] names(splitStats) <- altStats$feat[1] altSplit <- .getSplit( altStats$feat[1], splitStats, fSub, cSub, .splitMetricModF1 ) # Check that this split out the group2 of interest if (length(altSplit$group1Consensus) == 1) { # Add it to split downSplit[[length(downSplit) + 1]] <- altSplit names(downSplit)[length(downSplit)] <- altStats$feat[1] downSplit <- downSplit[c( which(!names(downSplit) %in% c("statUsed", "fUsed", "dirs")), which(names(downSplit) %in% c("statUsed", "fUsed", "dirs")) )] # Get index of split to add it to branchLengths <- unlist(lapply(tree, length)) branchCum <- cumsum(branchLengths) wBranch <- min(which(branchCum >= AltDec)) if (wBranch == 1) { wSplit <- 1 } else { wSplit <- which(seq( (branchCum[(wBranch - 1)] + 1), branchCum[wBranch] ) == AltDec) } # Add it to decision tree tree[[wBranch]][[wSplit]] <- downSplit } else { cat( "No non-ambiguous rule to separate", group2only, "from", 
        branchClasses,
        ". No alternative split added."
      )
    }
  } else {
    # print("No solely down-regulated cluster to add alternative split.")
  }

  return(tree)
}

#' @title Gets cluster estimates using rules generated by
#'  `celda::findMarkersTree`
#' @description Get decisions for a matrix of features. Estimate cell
#'  cluster membership using feature matrix input.
#' @param rules List object. The `rules` element from `findMarkersTree`
#'  output.
#' @param features An L (features) by N (samples) numeric matrix.
#' @return A character vector of label predictions. An element is NA if
#'  cluster estimation was ambiguous for that sample.
getDecisions <- function(rules, features) {
  features <- t(features)
  votes <- apply(features, 1, .predictClass, rules)
  return(votes)
}

# Function to predict class from list of rules
.predictClass <- function(samp, rules) {
  # Initialize possible classes and level
  classes <- names(rules)
  level <- 1

  # Set maximum level possible to prevent an infinite loop
  maxLevel <- max(unlist(lapply(rules, function(ruleSet) {
    ruleSet$level
  })))

  while (length(classes) > 1 & level <= maxLevel) {
    # Get possible classes
    clLogical <- unlist(lapply(classes, function(cl, rules, level, samp) {
      # Get the rules for this class
      ruleClass <- rules[[cl]]

      # Get the rules for this level
      ruleClass <- ruleClass[ruleClass$level == level, , drop = FALSE]

      # Subset class for the features at this level
      ruleClass$sample <- samp[ruleClass$feature]

      # For multiple direction == 1, use one with the top stat
      if (sum(ruleClass$direction == 1) > 1) {
        ruleClass <- ruleClass[order(ruleClass$direction,
          decreasing = TRUE
        ), ]
        ruleClass <- ruleClass[c(
          which.max(ruleClass$stat[ruleClass$direction == 1]),
          which(ruleClass$direction == -1)
        ), , drop = FALSE]
      }

      # Check for followed rules
      ruleClass$check <- ruleClass$sample >= ruleClass$value
      ruleClass$check[ruleClass$direction == -1] <-
        !ruleClass$check[ruleClass$direction == -1]

      # Check that an up-regulated rule or all rules were followed
      ruleFollowed <- mean(ruleClass$check & ruleClass$direction == 1) > 0 |
mean(ruleClass$check) == 1 return(ruleFollowed) }, rules, level, samp)) # Subset possible classes classes <- classes[clLogical] # Add level level <- level + 1 } # Return if only one class selected if (length(classes) == 1) { return(classes) } else { return(NA) } } # Function to summarize and format tree list output by .generateTreeList .summarizeTree <- function(tree, features, class) { # Format tree into dendrogram object dendro <- .convertToDendrogram(tree, class) # Map classes to features class2features <- .mapClass2features(tree, features, class) # Get performance of the tree on training samples perfList <- .getPerformance(class2features$rules, features, class) return( list( rules = class2features$rules, dendro = dendro, prediction = perfList$prediction, performance = perfList$performance ) ) } # Function to reformat raw tree ouput to a dendrogram .convertToDendrogram <- function(tree, class, splitNames = NULL) { # Unlist decision tree (one element for each split) DecTree <- unlist(tree, recursive = FALSE) if (is.null(splitNames)) { # Name split by gene and threshold splitNames <- lapply(DecTree, function(split) { # Remove non-split elements dirs <- paste0(split$dirs, collapse = "_") split <- split[!names(split) %in% c("statUsed", "fUsed", "dirs")] # Get set of features and values for each featuresplits <- lapply(split, function(node) { nodeFeature <- node$featureName nodeStrings <- paste(nodeFeature, collapse = ";") }) # Get split directions names(featuresplits) <- paste(dirs, seq(length(featuresplits)), sep = "_" ) return(featuresplits) }) splitNames <- unlist(splitNames) names(splitNames) <- sub("1_", "", names(splitNames)) } else { names(splitNames) <- seq(length(DecTree[[1]]) - 3) } # Get Stat Used statUsed <- unlist(lapply(DecTree, function(split) { split$statUsed })) statRep <- unlist(lapply( DecTree, function(split) { length(split[!names(split) %in% c("statUsed", "fUsed", "dirs")]) } )) statUsed <- unlist(lapply( seq(length(statUsed)), function(i) { 
rep(statUsed[i], statRep[i]) } )) names(statUsed) <- names(splitNames) # Create Matrix of results mat <- matrix(0, nrow = length(DecTree), ncol = length(unique(class))) colnames(mat) <- unique(class) for (i in seq(1, length(DecTree))) { # If only one split than ezpz split <- DecTree[[i]] split <- split[!names(split) %in% c("statUsed", "fUsed", "dirs")] if (length(split) == 1) { mat[i, split[[1]]$group1Consensus] <- 1 mat[i, split[[1]]$group2Consensus] <- 2 # Otherwise we need to assign > 2 splits for different higher groups } else { # Get classes in group 1 group1classUnique <- unique(lapply( split, function(X) { X$group1Consensus } )) group1classVec <- unlist(group1classUnique) # Get classes always in group 2 group2classUnique <- unique(unlist(lapply( split, function(X) { X$group2Consensus } ))) group2classUnique <- group2classUnique[!group2classUnique %in% group1classVec] # Assign for (j in seq(length(group1classUnique))) { mat[i, group1classUnique[[j]]] <- j } mat[i, group2classUnique] <- j + 1 } } ## Collapse matrix to get set of direction to include in dendrogram matCollapse <- sort(apply( mat, 2, function(x) { paste(x[x != 0], collapse = "_") } )) matUnique <- unique(matCollapse) # Get branchlist bList <- c() j <- 1 for (i in seq(max(ncharX(matUnique)))) { sLength <- matUnique[ncharX(matUnique) >= i] sLength <- unique(subUnderscore(sLength, i)) for (k in sLength) { bList[j] <- k j <- j + 1 } } # Initialize dendrogram list val <- max(ncharX(matUnique)) + 1 dendro <- list() attributes(dendro) <- list( members = length(matCollapse), classLabels = unique(class), height = val, midpoint = (length(matCollapse) - 1) / 2, label = NULL, name = NULL ) for (i in bList) { # Add element iSplit <- unlist(strsplit(i, "_")) iPaste <- paste0( "dendro", paste(paste0("[[", iSplit, "]]"), collapse = "") ) eval(parse( text = paste0(iPaste, "<-list()") )) # Add attributes classLabels <- names(matCollapse[subUnderscore( matCollapse, ncharX(i) ) == i]) members <- length(classLabels) 
    # Add height, set to one if leaf
    height <- val - ncharX(i)

    # Check that this isn't a terminal split
    if (members == 1) {
      height <- 1
    }

    # Add labels and stat used
    if (i %in% names(splitNames)) {
      lab <- splitNames[i]
      statUsedI <- statUsed[i]
    } else {
      lab <- NULL
      statUsedI <- NULL
    }
    att <- list(
      members = members,
      classLabels = classLabels,
      edgetext = lab,
      height = height,
      midpoint = (members - 1) / 2,
      label = lab,
      statUsed = statUsedI,
      name = i
    )
    eval(parse(text = paste0("attributes(", iPaste, ") <- att")))

    # Add leaves
    leaves <- matCollapse[matCollapse == i]
    if (length(leaves) > 0) {
      for (l in seq(1, length(leaves))) {
        # Add element
        lPaste <- paste0(iPaste, "[[", l, "]]")
        eval(parse(text = paste0(lPaste, "<-list()")))

        # Add attributes
        members <- 1
        leaf <- names(leaves)[l]
        height <- 0
        att <- list(
          members = members,
          classLabels = leaf,
          height = height,
          label = leaf,
          leaf = TRUE,
          name = i
        )
        eval(parse(text = paste0("attributes(", lPaste, ") <- att")))
      }
    }
  }
  class(dendro) <- "dendrogram"
  return(dendro)
}

# Function to count the underscore-separated elements in a string
ncharX <- function(x) {
  unlist(lapply(strsplit(x, "_"), length))
}

# Function to keep the first n underscore-separated elements of a string
subUnderscore <- function(x, n) {
  unlist(lapply(
    strsplit(x, "_"),
    function(y) {
      paste(y[seq(n)], collapse = "_")
    }
  ))
}

# Function to calculate performance statistics
.getPerformance <- function(rules, features, class) {
  # Get classification accuracy, balanced accuracy, and per class sensitivity
  ## Get predictions
  votes <- getDecisions(rules, t(features))
  votes[is.na(votes)] <- "MISSING"

  ## Calculate accuracy statistics and per class sensitivity
  class <- as.character(class)
  acc <- mean(votes == as.character(class))
  classCorrect <- vapply(unique(class), function(x) {
    sum(votes == x & class == x)
  },
  FUN.VALUE = double(1)
  )
  classCount <- c(table(class))[unique(class)]
  sens <- classCorrect / classCount

  ## Calculate balanced accuracy
  balacc <- mean(sens)

  ## Calculate per class and mean precision
  voteCount <- c(table(votes))[unique(class)]
  prec <- classCorrect / voteCount
  meanPrecision <- mean(prec)

  ## Add performance metrics
  performance <- list(
    accuracy = acc,
    balAcc = balacc,
    meanPrecision = meanPrecision,
    correct = classCorrect,
    sizes = classCount,
    sensitivity = sens,
    precision = prec
  )

  return(list(
    prediction = votes,
    performance = performance
  ))
}

# Create rules mapping classes to feature sequences
.mapClass2features <- function(tree, features, class, topLevelMeta = FALSE) {
  # Get class to feature indices
  class2featuresIndices <- do.call(rbind, lapply(
    seq(length(tree)),
    function(i) {
      treeLevel <- tree[[i]]
      c2fsub <- as.data.frame(do.call(rbind, lapply(
        treeLevel,
        function(split) {
          # Keep track of stat used for rule list
          statUsed <- split$statUsed

          # Keep only split information
          split <- split[!names(split) %in% c("statUsed", "fUsed", "dirs")]

          # Create data frame of split rules
          edgeFram <- do.call(rbind, lapply(split, function(edge) {
            # Create data.frame of groups, split-dirs, feature IDs
            groups <- c(edge$group1Consensus, edge$group2Consensus)
            sdir <- c(
              rep(1, length(edge$group1Consensus)),
              rep(-1, length(edge$group2Consensus))
            )
            feat <- edge$featureName
            val <- edge$value
            stat <- edge$stat
            data.frame(
              class = rep(groups, length(feat)),
              feature = rep(feat, each = length(groups)),
              direction = rep(sdir, length(feat)),
              value = rep(val, each = length(groups)),
              stat = rep(stat, each = length(groups)),
              stringsAsFactors = FALSE
            )
          }))

          # Add stat used
          edgeFram$statUsed <- statUsed

          return(edgeFram)
        }
      )))
      c2fsub$level <- i
      return(c2fsub)
    }
  ))
  rownames(class2featuresIndices) <- NULL

  # Generate list of rules for each class
  if (topLevelMeta) {
    orderedClass <- unique(class2featuresIndices[
      class2featuresIndices$direction == 1,
      "class"
    ])
  } else {
    orderedClass <- levels(class)
  }

  rules <- lapply(orderedClass, function(cl, class2featuresIndices) {
    class2featuresIndices[
      class2featuresIndices$class == cl,
      colnames(class2featuresIndices)
      != "class"
    ]
  }, class2featuresIndices)
  names(rules) <- orderedClass

  return(list(rules = rules))
}

#' @title Plots dendrogram of \emph{findMarkersTree} output
#' @description Generates a dendrogram of the rules and performance
#' (optional) of the decision tree generated by findMarkersTree().
#' @param tree List object. The output of findMarkersTree().
#' @param classLabel A character value. The name of a specific label to draw
#'  the path and rules. If NULL (default), the tree for all clusters is shown.
#' @param addSensPrec Logical. Print training sensitivities and precisions
#'  for each cluster below leaf label? Default is FALSE.
#' @param maxFeaturePrint Numeric value. Maximum number of markers to print
#'  at a given split. Default is 4.
#' @param leafSize Numeric value. Size of text below each leaf. Default is 10.
#' @param boxSize Numeric value. Size of rule labels. Default is 2.
#' @param boxColor Character value. Color of rule labels. Default is "black".
#' @examples
#' \dontrun{
#' # Generate simulated single-cell dataset using celda
#' sim_counts <- celda::simulateCells("celda_CG", K = 4, L = 10, G = 100)
#'
#' # Celda clustering into 5 clusters & 10 modules
#' cm <- celda_CG(sim_counts$counts, K = 5, L = 10, verbose = FALSE)
#'
#' # Get features matrix and cluster assignments
#' factorized <- factorizeMatrix(sim_counts$counts, cm)
#' features <- factorized$proportions$cell
#' class <- celdaClusters(cm)
#'
#' # Generate Decision Tree
#' DecTree <- findMarkersTree(features, class, threshold = 1)
#'
#' # Plot dendrogram
#' plotMarkerDendro(DecTree)
#' }
#' @return A ggplot2 object
#' @export
plotMarkerDendro <- function(tree,
                             classLabel = NULL,
                             addSensPrec = FALSE,
                             maxFeaturePrint = 4,
                             leafSize = 10,
                             boxSize = 2,
                             boxColor = "black") {
  # Get necessary elements
  dendro <- tree$dendro

  # Get performance information (training or CV based)
  if (addSensPrec) {
    performance <- tree$performance

    # Create vector of per class performance
    perfVec <- paste0(
      "Sens.
", format(round(performance$sensitivity, 2), nsmall = 2), "\n Prec. ", format(round(performance$precision, 2), nsmall = 2) ) names(perfVec) <- names(performance$sensitivity) } # Get dendrogram segments dendSegs <- ggdendro::dendro_data(dendro, type = "rectangle")$segments # Get necessary coordinates to add labels to # These will have y > 1 dendSegs <- unique(dendSegs[dendSegs$y > 1, c("x", "y", "yend", "xend")]) # Labeled splits will be vertical (x != xend) or # Length 0 (x == xend & y == yend) dendSegsAlt <- dendSegs[ dendSegs$x != dendSegs$xend | (dendSegs$x == dendSegs$xend & dendSegs$y == dendSegs$yend), c("x", "xend", "y") ] colnames(dendSegsAlt)[1] <- "xalt" # Label names will be at nodes, these will # Occur at the end of segments segs <- as.data.frame(dendextend::get_nodes_xy(dendro)) colnames(segs) <- c("xend", "yend") # Add labels to nodes segs$label <- gsub(";", "\n", dendextend::get_nodes_attr(dendro, "label")) # Subset for max segs$label <- sapply(segs$label, function(lab, maxFeaturePrint) { loc <- gregexpr("\n", lab)[[1]][maxFeaturePrint] if (!is.na(loc)) { lab <- substr(lab, 1, loc - 1) } return(lab) }, maxFeaturePrint) segs$statUsed <- dendextend::get_nodes_attr(dendro, "statUsed") # If highlighting a class label, remove non-class specific rules if (!is.null(classLabel)) { if (!classLabel %in% names(tree$rules)) { stop("classLabel not a valid class ID.") } dendro <- .highlightClassLabel(dendro, classLabel) keepLabel <- dendextend::get_nodes_attr(dendro, "keepLabel") keepLabel[is.na(keepLabel)] <- FALSE segs$label[!keepLabel] <- NA } # Remove non-labelled nodes & # leaf nodes (yend == 0) segs <- segs[!is.na(segs$label) & segs$yend != 0, ] # Merge to full set of coordinates dendSegsLabelled <- merge(dendSegs, segs) # Remove duplicated labels dendSegsLabelled <- dendSegsLabelled[order(dendSegsLabelled$y, decreasing = TRUE ), ] dendSegsLabelled <- dendSegsLabelled[!duplicated(dendSegsLabelled[ , c( "xend", "x", "yend", "label", "statUsed" ) ]), ] # Merge 
with alternative x-coordinates for alternative split dendSegsLabelled <- merge(dendSegsLabelled, dendSegsAlt) # Order by height and coordinates dendSegsLabelled <- dendSegsLabelled[order(dendSegsLabelled$x), ] # Find information gain splits igSplits <- dendSegsLabelled$statUsed == "Split" & !duplicated(dendSegsLabelled[, c("xalt", "y")]) # Set xend for IG splits dendSegsLabelled$xend[igSplits] <- dendSegsLabelled$xalt[igSplits] # Set y for non-IG splits dendSegsLabelled$y[!igSplits] <- dendSegsLabelled$y[!igSplits] - 0.2 # Get index of leaf labels leafLabels <- dendextend::get_leaves_attr(dendro, "label") # Adjust leaf labels if there are metacluster labels if (!is.null(tree$metaclusterLabels)) { leafLabels <- regmatches( leafLabels, regexpr( pattern = "(?<=\\().*?(?=\\)$)", leafLabels, perl = TRUE ) ) } # Add sensitivity and precision measurements if (addSensPrec) { leafLabels <- paste(leafLabels, perfVec, sep = "\n") leafAngle <- 0 leafHJust <- 0.5 leafVJust <- -1 } else { leafAngle <- 90 leafHJust <- 1 leafVJust <- 0.5 } # Create plot of dendrogram suppressMessages( dendroP <- ggdendro::ggdendrogram(dendro) + ggplot2::geom_label( data = dendSegsLabelled, ggplot2::aes( x = dendSegsLabelled$xend, y = dendSegsLabelled$y, label = dendSegsLabelled$label ), size = boxSize, label.size = 1, fontface = "bold", vjust = 1, nudge_y = 0.1, color = boxColor ) + ggplot2::theme_bw() + ggplot2::scale_x_reverse( breaks = seq(length(leafLabels)), label = leafLabels ) + ggplot2::scale_y_continuous(expand = c(0, 0)) + ggplot2::theme( panel.grid.major.y = ggplot2::element_blank(), legend.position = "none", panel.grid.minor.y = ggplot2::element_blank(), panel.grid.minor.x = ggplot2::element_blank(), panel.grid.major.x = ggplot2::element_blank(), panel.border = ggplot2::element_blank(), axis.title = ggplot2::element_blank(), axis.ticks = ggplot2::element_blank(), axis.text.x = ggplot2::element_text( hjust = leafHJust, angle = leafAngle, size = leafSize, family = "Palatino", face = 
"bold", vjust = leafVJust ), axis.text.y = ggplot2::element_blank() ) ) # Check if need to add metacluster labels if (!is.null(tree$metaclusterLabels)) { # store metacluster labels to add newLabels <- unique(tree$branchPoints$top_level$metacluster) # adjust labels for metaclusters of size one newLabels <- unlist(lapply(newLabels, function(curMeta) { if (substr(curMeta, nchar(curMeta), nchar(curMeta)) == ")") { return(gsub( pattern = "\\(.*\\)$", replacement = "", x = curMeta )) } else { return(curMeta) } })) # Create table for metacluster labels metaclusterText <- dendSegsLabelled[ dendSegsLabelled$y == max(dendSegsLabelled$y), c("xend", "y", "label") ] metaclusterText$label <- newLabels # Add metacluster labels to top of plot dendroP <- dendroP + ggplot2::geom_text( data = metaclusterText, ggplot2::aes( x = metaclusterText$xend, y = metaclusterText$y, label = metaclusterText$label, fontface = 2 ), angle = 90, nudge_y = 0.5, family = "Palatino", size = leafSize / 3 ) # adjust coordinates of plot to show labels dendroP <- dendroP + ggplot2::coord_cartesian( ylim = c( 0, max(dendSegsLabelled$y + 1) ) ) } # Increase line width slightly for aesthetic purposes dendroP$layers[[2]]$aes_params$size <- 1.3 return(dendroP) } # Function to reformat the dendrogram to draw path to a specific class .highlightClassLabel <- function(dendro, classLabel) { # Reorder dendrogram flag <- TRUE bIndexString <- "" # Get branch branch <- eval(parse(text = paste0("dendro", bIndexString))) while (flag) { # Get attributes att <- attributes(branch) # Get split with the label of interest labList <- lapply(branch, function(split) { attributes(split)$classLabels }) wSplit <- which(unlist(lapply( labList, function(vec) { classLabel %in% vec } ))) # Keep labels for this branch branch <- lapply(branch, function(edge) { attributes(edge)$keepLabel <- TRUE return(edge) }) # Make a dendrogram class again class(branch) <- "dendrogram" attributes(branch) <- att # Add branch to dendro eval(parse(text = 
paste0("dendro", bIndexString, "<- branch"))) # Create new bIndexString bIndexString <- paste0(bIndexString, "[[", wSplit, "]]") # Get branch branch <- eval(parse(text = paste0("dendro", bIndexString))) # Add flag flag <- attributes(branch)$members > 1 } return(dendro) } #' @title Generate heatmap for a marker decision tree #' @description Creates heatmap for a specified branch point in a marker tree. #' @param tree A decision tree returned from \link{findMarkersTree} function. #' @param counts Numeric matrix. Gene-by-cell counts matrix. #' @param branchPoint Character. Name of branch point to plot heatmap for. #' Name should match those in \emph{tree$branchPoints}. #' @param featureLabels List of feature cluster assignments. Length should #' be equal to number of rows in counts matrix, and formatting should match #' that used in \emph{findMarkersTree()}. Required when using clusters #' of features and not previously provided to \emph{findMarkersTree()} #' @param topFeatures Integer. Number of genes to plot per marker module. #' Genes are sorted based on their AUC for their respective cluster. #' Default is 10. #' @param silent Logical. Whether to avoid plotting heatmap to screen. #' Default is FALSE. #' @return A heatmap visualizing the counts matrix for the cells and genes at #' the specified branch point. 
#' @examples
#' \dontrun{
#' # Generate simulated single-cell dataset using celda
#' sim_counts <- simulateCells("celda_CG", K = 4, L = 10, G = 100)
#'
#' # Celda clustering into 5 clusters & 10 modules
#' cm <- celda_CG(sim_counts, K = 5, L = 10, verbose = FALSE)
#'
#' # Get features matrix and cluster assignments
#' factorized <- factorizeMatrix(cm)
#' features <- factorized$proportions$cell
#' class <- celdaClusters(cm)
#'
#' # Generate Decision Tree
#' DecTree <- findMarkersTree(features, class, threshold = 1)
#'
#' # Plot example heatmap
#' plotMarkerHeatmap(DecTree, assay(sim_counts),
#'   branchPoint = "top_level",
#'   featureLabels = paste0("L", celdaModules(cm)))
#' }
#' @export
plotMarkerHeatmap <- function(tree,
                              counts,
                              branchPoint,
                              featureLabels,
                              topFeatures = 10,
                              silent = FALSE) {
  # get branch point to plot
  branch <- tree$branchPoints[[branchPoint]]

  # check that user entered valid branch point name
  if (is.null(branch)) {
    stop(
      "Invalid branch point.",
      " Branch point name should match one of those in tree$branchPoints."
    )
  }

  # convert counts matrix to matrix (e.g. from dgCMatrix)
  counts <- as.matrix(counts)

  # get marker features
  marker <- unique(branch$feature)

  # feature labels stored in the tree take precedence over the argument
  if ("featureLabels" %in% names(tree)) {
    featureLabels <- tree$featureLabels
  }

  # check that feature labels are provided
  if (missing(featureLabels) && !("featureLabels" %in% names(tree))) {
    if (!all(marker %in% rownames(counts))) {
      stop("Please provide feature labels, i.e. gene cluster labels")
    }
    # every marker is a row of counts, so the row names serve as labels
    featureLabels <- rownames(counts)
  }

  # make sure feature labels match the table
  if (!all(branch$feature %in% featureLabels)) {
    stop(
      "Provided feature labels don't match those in the tree.",
      " Please check the feature names in the tree's rules' table."
) } # if top-level in metaclusters tree if (branchPoint == "top_level") { # get unique metaclusters metaclusters <- unique(branch$metacluster) # list which will contain final set of genes for heatmap whichFeatures <- c() # loop over unique metaclusters for (meta in metaclusters) { # subset table curMeta <- branch[branch$metacluster == meta, ] # if we have gene-level info in the tree if ("gene" %in% names(branch)) { # sort by gene AUC score curMeta <- curMeta[order(curMeta$geneAUC, decreasing = TRUE), ] # get genes genes <- unique(curMeta$gene) # keep top N features genes <- utils::head(genes, topFeatures) # get gene indices markerGenes <- which(rownames(counts) %in% genes) # get features with non-zero variance to avoid clustering error markerGenes <- .removeZeroVariance( counts, cells = which( tree$metaclusterLabels %in% unique(curMeta$metacluster) ), markers = markerGenes ) # add to list of features whichFeatures <- c(whichFeatures, markerGenes) } else { # current markers curMarker <- unique(curMeta$feature) # get marker gene indices markerGenes <- which(featureLabels %in% curMarker) # get features with non-zero variance to avoid error markerGenes <- .removeZeroVariance( counts, cells = which( tree$metaclusterLabels %in% unique(curMeta$metacluster) ), markers = markerGenes ) # add to list of features whichFeatures <- c(whichFeatures, markerGenes) } } # order the metaclusters by size colOrder <- data.frame( groupName = names(sort( table(tree$metaclusterLabels), decreasing = TRUE )), groupIndex = seq_along(unique(tree$metaclusterLabels)) ) # order the markers for metaclusters allMarkers <- stats::setNames(as.list(colOrder$groupName), colOrder$groupName) allMarkers <- lapply(allMarkers, function(x) { unique(branch[branch$metacluster == x, "feature"]) }) rowOrder <- data.frame( groupName = unlist(allMarkers), groupIndex = seq_along(unlist(allMarkers)) ) toRemove <- which(!rowOrder$groupName %in% featureLabels[whichFeatures]) if (length(toRemove) > 0) { rowOrder <- 
rowOrder[-toRemove, ] } # sort cells according to metacluster size x <- tree$metaclusterLabels y <- colOrder$groupName sortedCells <- seq(ncol(counts))[order(match(x, y))] # create heatmap with only the markers return( plotHeatmap( counts = counts, z = tree$metaclusterLabels, y = featureLabels, featureIx = whichFeatures, cellIx = sortedCells, showNamesFeature = TRUE, main = "Top-level", silent = silent, treeheightFeature = 0, colGroupOrder = colOrder, rowGroupOrder = rowOrder, treeheightCell = 0 ) ) } # if balanced split if (branch$statUsed[1] == "Split") { # keep entries for balanced split only (in case of alt. split) split <- branch$feature[1] branch <- branch[branch$feature == split, ] # get up-regulated and down-regulated classes upClasses <- unique(branch[branch$direction == 1, "class"]) downClasses <- unique(branch[branch$direction == (-1), "class"]) # re-order cells to keep up and down separate on the heatmap reorderedCells <- c( (which(tree$classLabels %in% upClasses) [order(tree$classLabels[tree$classLabels %in% upClasses])]), (which(tree$classLabels %in% downClasses) [order(tree$classLabels[tree$classLabels %in% downClasses])]) ) # cell annotation based on split cellAnno <- data.frame( split = rep("Down-regulated", ncol(counts)), stringsAsFactors = FALSE ) cellAnno$split[which(tree$classLabels %in% upClasses)] <- "Up-regulated" rownames(cellAnno) <- colnames(counts) # if we have gene-level info in the tree if (("gene" %in% names(branch))) { # get genes genes <- unique(branch$gene) # keep top N features genes <- utils::head(genes, topFeatures) # get gene indices whichFeatures <- which(rownames(counts) %in% genes) # get features with non-zero variance to avoid error whichFeatures <- .removeZeroVariance(counts, cells = which(tree$classLabels %in% unique(branch$class)), markers = whichFeatures ) # create heatmap with only the split feature and split classes return( plotHeatmap( counts = counts, z = tree$classLabels, y = featureLabels, featureIx = 
whichFeatures, cellIx = reorderedCells, clusterCell = FALSE, showNamesFeature = TRUE, main = branchPoint, silent = silent, treeheightFeature = 0, treeheightCell = 0, annotationCell = cellAnno ) ) } else { # get features with non-zero variance to avoid error whichFeatures <- .removeZeroVariance( counts, cells = reorderedCells, markers = which(featureLabels == branch$feature[1]) ) # create heatmap with only the split feature and split classes return( plotHeatmap( counts = counts, z = tree$classLabels, y = featureLabels, featureIx = whichFeatures, cellIx = reorderedCells, clusterCell = FALSE, showNamesFeature = TRUE, main = branchPoint, silent = silent, treeheightFeature = 0, treeheightCell = 0, annotationCell = cellAnno ) ) } } # if one-off split if (branch$statUsed[1] == "One-off") { # get unique classes classes <- unique(branch$class) # list which will contain final set of genes for heatmap whichFeatures <- c() # loop over unique classes for (class in classes) { # subset table curClass <- branch[branch$class == class & branch$direction == 1, ] # if we have gene-level info in the tree if (("gene" %in% names(branch))) { # get genes genes <- unique(curClass$gene) # keep top N features genes <- utils::head(genes, topFeatures) # get gene indices markerGenes <- which(rownames(counts) %in% genes) # get features with non-zero variance to avoid error markerGenes <- .removeZeroVariance( counts, cells = which(tree$classLabels %in% unique(curClass$class)), markers = markerGenes ) # add to list of features whichFeatures <- c(whichFeatures, markerGenes) } else { # get features with non-zero variance to avoid error markerGenes <- .removeZeroVariance( counts, cells = which(tree$classLabels %in% unique(curClass$class)), markers = which(featureLabels %in% unique(curClass$feature)) ) # add to list of features whichFeatures <- c(whichFeatures, markerGenes) } } # order the clusters such that up-regulated come first colOrder <- data.frame( groupName = unique(branch[ 
        order(branch$direction, decreasing = TRUE), "class"
      ]),
      groupIndex = seq_along(unique(branch$class))
    )

    # order the markers for clusters
    allMarkers <- stats::setNames(
      as.list(colOrder$groupName),
      colOrder$groupName
    )
    allMarkers <- lapply(allMarkers, function(x) {
      unique(branch[branch$class == x & branch$direction == 1, "feature"])
    })
    rowOrder <- data.frame(
      groupName = unlist(allMarkers),
      groupIndex = seq_along(unlist(allMarkers))
    )
    toRemove <- which(!rowOrder$groupName %in% featureLabels[whichFeatures])
    if (length(toRemove) > 0) {
      rowOrder <- rowOrder[-toRemove, ]
    }

    # sort cells to follow the class order in colOrder; cells from classes
    # outside this branch get NA from match(), sort last, and are dropped
    x <- tree$classLabels
    # [tree$classLabels %in% unique(branch$class)]
    y <- colOrder$groupName
    sortedCells <- seq_len(ncol(counts))[order(match(x, y))]
    sortedCells <- sortedCells[seq_len(sum(tree$classLabels %in% classes))]

    # create heatmap with only the split features and split classes
    return(
      plotHeatmap(
        counts = counts,
        z = tree$classLabels,
        y = featureLabels,
        featureIx = whichFeatures,
        cellIx = sortedCells,
        showNamesFeature = TRUE,
        main = branchPoint,
        silent = silent,
        treeheightFeature = 0,
        colGroupOrder = colOrder,
        rowGroupOrder = rowOrder,
        treeheightCell = 0
      )
    )
  }
}


# helper function to drop marker genes that have zero variance across the
# selected cells of a counts matrix
.removeZeroVariance <- function(counts, cells, markers) {
  # subset counts matrix, keeping it a matrix even for a single cell
  counts <- counts[, cells, drop = FALSE]

  # scale rows; rows with zero variance become NaN
  counts <- t(scale(t(counts)))

  # get indices of genes which have NA
  zeroVarianceGenes <- which(!stats::complete.cases(counts))

  # find overlap between zero-variance genes and marker genes
  zeroVarianceMarkers <- intersect(zeroVarianceGenes, markers)

  # return indices of marker genes without zero-variance
  if (length(zeroVarianceMarkers) > 0) {
    return(markers[-which(markers %in% zeroVarianceMarkers)])
  } else {
    return(markers)
  }
}