#' @title Gets cluster estimates using rules generated by
#'  \link{findMarkersTree}
#' @description Get decisions for a matrix of features. Estimate cell
#'  cluster membership using feature matrix input.
#' @param rules List object. The \code{rules} element from
#'  \code{findMarkersTree} output.
#' @param features An L (features) by N (samples) numeric matrix. Row names
#'  must match the feature names used in \code{rules}.
#' @return A character vector of label predictions, one per sample (column of
#'  \code{features}) and named by the column names of \code{features}.
#'  \code{NA} is returned for a sample whose cluster estimation was
#'  ambiguous (no class, or more than one class, satisfied the rules).
#' @examples
#' \dontrun{
#' library(M3DExampleData)
#' counts <- M3DExampleData::Mmus_example_list$data
#' # Subset 500 genes for fast clustering
#' counts <- as.matrix(counts[seq(1501, 2000), ])
#' # Cluster genes and samples each into 10 modules
#' sce <- celda_CG(counts = counts, L = 10, K = 5, verbose = FALSE)
#' # Get features matrix and cluster assignments
#' factorized <- factorizeMatrix(sce)
#' features <- factorized$proportions$cell
#' class <- celdaClusters(sce)
#' # Generate Decision Tree
#' DecTree <- findMarkersTree(features,
#'     class,
#'     oneoffMetric = "modified F1",
#'     threshold = 1,
#'     consecutiveOneoff = FALSE)
#' # Get sample estimates in training data
#' getDecisions(DecTree$rules, features)
#' }
#' @export
getDecisions <- function(rules, features) {
  features <- as.matrix(features)
  featureNames <- rownames(features)
  votes <- vapply(
    seq_len(ncol(features)),
    function(j) {
      samp <- features[, j]
      # Re-attach names explicitly: dropping a 1-row matrix column that has
      # both row and column names loses the names.
      names(samp) <- featureNames
      .predictClass(samp, rules)
    },
    character(1)
  )
  names(votes) <- colnames(features)
  votes
}

# Predict the class of a single sample by walking the decision-tree rules
# level by level, discarding candidate classes whose rules are not followed.
#
# samp:  numeric vector of feature values for one sample, named by feature.
# rules: named list (one element per class) of rule data frames with columns
#        `feature`, `direction`, `value`, `stat` and `level`.
#
# Returns the single surviving class name, or NA_character_ when no class or
# more than one class survives (ambiguous estimate).
.predictClass <- function(samp, rules) {
  # Every class starts as a candidate
  classes <- names(rules)
  level <- 1
  # Deepest level present in any rule set; bounds the loop
  allLevels <- unlist(lapply(rules, function(ruleSet) ruleSet$level))
  maxLevel <- if (length(allLevels) > 0) max(allLevels) else 0

  while (length(classes) > 1 && level <= maxLevel) {
    keep <- vapply(
      classes,
      function(cl) .followsLevelRules(rules[[cl]], samp, level),
      logical(1),
      USE.NAMES = FALSE
    )
    classes <- classes[keep]
    level <- level + 1
  }

  if (length(classes) == 1) classes else NA_character_
}

# Decide whether a sample follows one class's rules at a given tree level.
#
# A class passes when at least one "up" rule (direction == 1) holds, or when
# every rule at this level holds. If several up rules exist, only the one with
# the largest `stat` is used. A comparison that is NA (e.g. the sample's value
# for that feature is missing) counts as not satisfied. A class with no rules
# at this level cannot be ruled out by this level and is kept.
.followsLevelRules <- function(ruleSet, samp, level) {
  ruleSet <- ruleSet[ruleSet$level == level, , drop = FALSE]
  if (nrow(ruleSet) == 0) {
    return(TRUE)
  }

  isUp <- ruleSet$direction == 1
  if (sum(isUp) > 1) {
    # Keep only the top-stat up rule plus all down rules
    topUp <- which(isUp)[which.max(ruleSet$stat[isUp])]
    ruleSet <- ruleSet[c(topUp, which(ruleSet$direction == -1)), ,
      drop = FALSE
    ]
  }

  # as.character() guards against factor features indexing by level code
  sampleValues <- samp[as.character(ruleSet$feature)]
  followed <- sampleValues >= ruleSet$value
  isDown <- ruleSet$direction == -1
  followed[isDown] <- !followed[isDown]
  # Unverifiable (NA) comparisons count as not followed
  followed <- followed %in% TRUE

  any(followed & ruleSet$direction == 1) || all(followed)
}