R/celdaGridSearch.R
cab49138
 #' @title Run Celda in parallel with multiple parameters
4b1d5604
 #' @description Run Celda with different combinations of parameters and
 #'  multiple chains in parallel. The variable `availableModels` contains the
 #'  potential models that can be utilized. Different parameters to be tested
 #'  should be stored in a list and passed to the argument `paramsTest`. Fixed
 #'  parameters to be used in all models, such as `sampleLabel`, can be passed
 #'  as a list to the argument `paramsFixed`. When `verbose = TRUE`, output
 #'  from each chain will be sent to a log file but not be displayed in stdout.
 #' @param counts Integer matrix. Rows represent features and columns represent
 #'  cells.
 #' @param model Celda model. Options available in `celda::availableModels`.
 #' @param paramsTest List. A list denoting the combinations of parameters to
 #'  run in a celda model. For example, `list(K = seq(5, 10), L = seq(15, 20))`
 #'  will run all combinations of K from 5 to 10 and L from 15 to 20 in model
 #'  `celda_CG()`.
 #' @param paramsFixed List. A list denoting additional parameters to use in
 #'  each celda model. Default NULL.
 #' @param maxIter Integer. Maximum number of iterations of sampling to
 #'  perform. Default 200.
 #' @param nchains Integer. Number of random cluster initializations. Default 3.
 #' @param cores Integer. The number of cores to use for parallel estimation of
 #'  chains. Default 1.
 #' @param bestOnly Logical. Whether to return only the chain with the highest
 #'  log likelihood per combination of parameters or return all chains. Default
 #'  TRUE.
 #' @param perplexity Logical. Whether to calculate perplexity for each model.
 #'  If FALSE, then perplexity can be calculated later with
 #'  `resamplePerplexity()`. Default TRUE.
3cb45224
 #' @param verbose Logical. Whether to print log messages during celda chain
 #'  execution. Default TRUE.
 #' @param logfilePrefix Character. Prefix for log files from worker threads
 #'  and main process. Default "Celda".
4b1d5604
 #' @return Object of class `celdaList`, which contains results for all model
 #'  parameter combinations and summaries of the run parameters
 #' @seealso `celda_G()` for feature clustering, `celda_C()` for clustering of
 #'  cells, and `celda_CG()` for simultaneous clustering of features and cells.
 #'  `subsetCeldaList()` can subset the `celdaList` object. `selectBestModel()`
 #'  can get the best model for each combination of parameters.
a3823d60
 #' @import foreach
d7196f24
 #' @importFrom doParallel registerDoParallel
 #' @importFrom methods is
ff43602d
 #' @examples
 #' \donttest{
 #' data(celdaCGSim)
 #' ## Run various combinations of parameters with 'celdaGridSearch'
 #' celdaCGGridSearchRes <- celdaGridSearch(celdaCGSim$counts,
 #'     model = "celda_CG",
 #'     paramsTest = list(K = seq(4, 6), L = seq(9, 11)),
 #'     paramsFixed = list(sampleLabel = celdaCGSim$sampleLabel),
 #'     bestOnly = TRUE,
 #'     nchains = 1,
 #'     cores = 2)
 #' }
2aef131e
 #' @export
4b1d5604
 celdaGridSearch <- function(counts,
     model,
     paramsTest,
     paramsFixed = NULL,
     maxIter = 200,
     nchains = 3,
     cores = 1,
     bestOnly = TRUE,
     perplexity = TRUE,
     verbose = TRUE,
     logfilePrefix = "Celda") {
 
     ## Check parameters
     .validateCounts(counts)
 
     ## Every name supplied in 'paramsTest' / 'paramsFixed' must be a formal
     ## argument of the function named by 'model'.
     modelParams <- as.list(formals(model))
     badParams <- setdiff(names(paramsTest), names(modelParams))
     if (length(badParams) > 0) {
         stop("The following elements in 'paramsTest' are not arguments of '",
             model,
             "': ",
             paste(badParams, collapse = ","))
     }
 
     ## names(NULL) is NULL, so a NULL 'paramsFixed' yields no bad names.
     badParams <- setdiff(names(paramsFixed), names(modelParams))
     if (length(badParams) > 0) {
         stop("The following arguments in 'paramsFixed' are not arguments",
             " of '",
             model,
             "': ",
             paste(badParams, collapse = ","))
     }
 
     ## Formals without a default value (other than 'counts') have to be
     ## supplied through either 'paramsTest' or 'paramsFixed'.
     modelParamsRequired <- setdiff(names(modelParams[modelParams == ""]),
         "counts")
     suppliedParams <- c(names(paramsTest), names(paramsFixed))
     missingParams <- setdiff(modelParamsRequired, suppliedParams)
     if (length(missingParams) > 0) {
         stop("The following arguments are not in 'paramsTest' or 'paramsFixed'",
             " but are required for '",
             model,
             "': ",
             paste(missingParams, collapse = ","))
     }
 
     if (any(c("z.init", "y.init", "sampleLabel") %in% names(paramsTest))) {
         stop("Setting parameters such as 'z.init', 'y.init', and 'sampleLabel'",
             " in 'paramsTest' is not currently supported.")
     }
 
     ## The number of chains is controlled by the 'nchains' argument of this
     ## function; each individual model call below runs exactly one chain.
     if ("nchains" %in% names(paramsTest)) {
         warning("Parameter 'nchains' should not be used within the paramsTest",
             " list")
         paramsTest[["nchains"]] <- NULL
     }
 
     # Set up parameter combinations for each individual chain: one row per
     # (chain, parameter combination), with a running 'index' column.
     runParams <- base::expand.grid(c(chain = list(seq_len(nchains)),
         paramsTest))
     runParams <- cbind(index = seq_len(nrow(runParams)), runParams)
 
     separator <- paste(rep("-", 50), collapse = "")
     .logMessages(separator,
         logfile = NULL,
         append = FALSE,
         verbose = verbose)
     .logMessages("Starting celdaGridSearch with",
         model,
         logfile = NULL,
         append = TRUE,
         verbose = verbose)
     .logMessages("Number of cores:",
         cores,
         logfile = NULL,
         append = TRUE,
         verbose = verbose)
     .logMessages(separator,
         logfile = NULL,
         append = TRUE,
         verbose = verbose)
 
     startTime <- Sys.time()
 
     # An MD5 checksum of the count matrix. Passed to models so
     # later on, we can check on celda_* model objects which
     # count matrix was used.
     counts <- .processCounts(counts)
     countChecksum <- .createCountChecksum(counts)
 
     ## Use DoParallel to loop through each combination of parameters.
     ## The 'finally' handler shuts the workers down even when a chain fails,
     ## so an error no longer leaves orphaned worker processes behind.
     cl <- parallel::makeCluster(cores)
     i <- NULL # Setting visible binding for R CMD CHECK
     resList <- tryCatch({
         doParallel::registerDoParallel(cl)
         foreach(i = seq_len(nrow(runParams)),
             .export = model,
             .combine = c,
             .multicombine = TRUE) %dopar% {
 
                 ## Set up chain parameter list from row i of 'runParams'
                 currentRun <- c(runParams[i, ])
                 chainParams <- list()
                 for (j in names(paramsTest)) {
                     chainParams[[j]] <- currentRun[[j]]
                 }
                 chainParams$counts <- counts
                 chainParams$maxIter <- maxIter
                 ## Full argument name rather than the partially matched
                 ## 'nchain'; assumes the model formal is 'nchains', as the
                 ## 'paramsTest' check above implies.
                 chainParams$nchains <- 1
                 chainParams$countChecksum <- countChecksum
                 chainParams$verbose <- verbose
                 ## One log file per run, named after the row's column
                 ## names and values, e.g. "<prefix>_index-1_chain-1_..."
                 chainParams$logfile <- paste0(logfilePrefix,
                     "_",
                     paste(paste(
                         colnames(runParams), runParams[i, ], sep = "-"
                     ), collapse = "_"),
                     "_log.txt")
 
                 ## Run model; wrapped in a list so that '.combine = c'
                 ## yields a list with one model per run.
                 res <- do.call(model, c(chainParams, paramsFixed))
                 list(res)
             }
     }, finally = parallel::stopCluster(cl))
 
     ## Record the best log likelihood of every run next to its parameters
     logliks <- vapply(resList, function(mod) {
         bestLogLikelihood(mod)
     }, double(1))
     runParams <- cbind(runParams, logLikelihood = logliks)
 
     celdaRes <- methods::new(
         "celdaList",
         runParams = runParams,
         resList = resList,
         countChecksum = countChecksum
     )
 
     if (isTRUE(bestOnly)) {
         celdaRes <- selectBestModel(celdaRes, asList = TRUE)
     }
 
     if (isTRUE(perplexity)) {
         .logMessages(
             date(),
             ".. Calculating perplexity",
             append = TRUE,
             verbose = verbose,
             logfile = NULL
         )
         celdaRes <- resamplePerplexity(counts, celdaRes)
     }
 
     endTime <- Sys.time()
     .logMessages(separator,
         logfile = NULL,
         append = TRUE,
         verbose = verbose)
     .logMessages("Completed celdaGridSearch. Total time:",
         format(difftime(endTime, startTime)),
         logfile = NULL,
         append = TRUE,
         verbose = verbose)
     .logMessages(separator,
         logfile = NULL,
         append = TRUE,
         verbose = verbose)
 
     return(celdaRes)
 }
ac78a723
 
 
6dd3b624
 ################################################################################
4b1d5604
 # Methods for manipulating celdaList objects                                  #
6dd3b624
 ################################################################################
4b1d5604
 #' @title Subset celdaList object from celdaGridSearch
 #' @description Select a subset of models from a `celdaList` object generated
 #'  by `celdaGridSearch()` that match the criteria in the argument `params`.
 #' @param celdaList Object of class `celdaList`. An object containing celda
 #'  models returned from `celdaGridSearch()`.
 #' @param params List. List of parameters used to subset celdaList.
 #' @return A new `celdaList` object containing all models matching the
 #'  provided criteria in `params`. If only one item in the `celdaList` matches
 #'  the given criteria, the matching model will be returned directly instead of
 #'  a `celdaList` object.
 #' @seealso `celdaGridSearch()` can run Celda with multiple parameters and
 #'  chains in parallel. `selectBestModel()` can get the best model for each
 #'  combination of parameters.
6dd3b624
 #' @examples
a49fff03
 #' data(celdaCGGridSearchRes)
4b1d5604
 #' resK5L10 <- subsetCeldaList(celdaCGGridSearchRes, params = list(K = 5,
 #'     L = 10))
6dd3b624
 #' @export
4b1d5604
 subsetCeldaList <- function(celdaList, params) {
     if (!methods::is(celdaList, "celdaList")) {
         stop("celdaList parameter was not of class celdaList.")
     }
 
     ## Check for bad parameter names
     badParams <- setdiff(names(params), colnames(runParams(celdaList)))
     if (length(badParams) > 0) {
         stop("The following elements in 'params' are not columns in runParams",
             " (celdaList) ",
             paste(badParams, collapse = ","))
     }
 
     ## Subset 'runParams' based on items in 'params'. Each criterion narrows
     ## the rows kept by the previous ones. Standard `[` indexing is used
     ## instead of subset(): subset() evaluates its condition with the data
     ## frame's columns masking local variables, so a column named e.g. 'i'
     ## or 'params' would silently change what the filter refers to.
     newRunParams <- runParams(celdaList)
     for (i in names(params)) {
         keep <- newRunParams[[i]] %in% params[[i]]
         newRunParams <- newRunParams[keep, , drop = FALSE]
 
         if (nrow(newRunParams) == 0) {
             stop("No runs matched the criteria given in 'params'. Check",
                 " 'runParams(celdaList)' for complete list of parameters used",
                 " to generate 'celdaList'.")
         }
     }
 
     ## Get index of selected models, subset celdaList, and return.
     ## A single match is returned as the bare model rather than a celdaList.
     ix <- match(newRunParams$index, runParams(celdaList)$index)
     if (length(ix) == 1) {
         return(resList(celdaList)[[ix]])
     }
 
     celdaList@runParams <- as.data.frame(newRunParams)
     celdaList@resList <- resList(celdaList)[ix]
     celdaList
 }
 
4b1d5604
 
cab49138
 #' @title Select best chain within each combination of parameters
4b1d5604
 #' @description Select the chain with the best log likelihood for each
 #'  combination of tested parameters from a `celdaList` object generated by
 #'  `celdaGridSearch()`.
 #' @param celdaList Object of class `celdaList`. An object containing celda
 #'  models returned from `celdaGridSearch()`.
b41b5de0
 #' @param asList `TRUE` or `FALSE`. Whether to return the best model as a
 #'  `celdaList` object or not. If `FALSE`, return the best model as a
 #'  corresponding `celda_C`, `celda_G` or `celda_CG` object.
4b1d5604
 #' @return A new `celdaList` object containing one model with the best log
 #'  likelihood for each set of parameters. If only one set of parameters is in
 #'  the `celdaList` and `asList` is `FALSE`, the best model will be returned
 #'  directly instead of a `celdaList` object.
 #' @seealso `celdaGridSearch()` can run Celda with multiple parameters and
 #'  chains in parallel. `subsetCeldaList()` can subset the `celdaList` object.
6dd3b624
 #' @examples
a49fff03
 #' data(celdaCGGridSearchRes)
4b1d5604
 #' ## Returns same result as running celdaGridSearch with "bestOnly = TRUE"
 #' cgsBest <- selectBestModel(celdaCGGridSearchRes)
d7196f24
 #' @importFrom data.table as.data.table
6dd3b624
 #' @export
b41b5de0
 selectBestModel <- function(celdaList, asList = FALSE) {
     if (!methods::is(celdaList, "celdaList")) {
         stop("celdaList parameter was not of class celdaList.")
     }
 
     ## Local placeholder for the column name used inside the data.table
     ## expression below (presumably to give R CMD check a visible binding).
     logLikelihood <- NULL
 
     ## Group by every run parameter except the bookkeeping columns, then keep
     ## the row with the highest log likelihood within each group.
     group <- setdiff(colnames(runParams(celdaList)),
         c("index", "chain", "logLikelihood", "mean_perplexity"))
     dt <- data.table::as.data.table(runParams(celdaList))
     newRunParams <- as.data.frame(dt[, .SD[which.max(logLikelihood)],
         by = group])
     ## Restore the original column order (grouping moves 'by' columns first)
     newRunParams <- newRunParams[, colnames(runParams(celdaList))]
 
     ix <- match(newRunParams$index, runParams(celdaList)$index)
     ## Scalar condition: use short-circuit '&&' rather than elementwise '&'.
     if (nrow(newRunParams) == 1 && !asList) {
         return(resList(celdaList)[[ix]])
     }
 
     celdaList@runParams <- as.data.frame(newRunParams)
     celdaList@resList <- resList(celdaList)[ix]
     celdaList
 }
ea5b392d
 
 
4b1d5604
 #' @title Celda models
 #' @description List of available Celda models with corresponding descriptions.
d1ae35f0
 #' @export
db5762c5
 #' @examples
 #' celda()
d1154f0b
 #' @return None
4b1d5604
 celda <- function() {
     ## One-line description per entry point, keyed by function name.
     descriptions <- c(
         celda_C = paste0("Clusters the columns of a count matrix containing",
             " single-cell data into K subpopulations."),
         celda_G = paste0("Clusters the rows of a count matrix containing",
             " single-cell data into L modules."),
         celda_CG = paste0("Clusters the rows and columns of a count matrix",
             " containing single-cell data into L modules and K subpopulations,",
             " respectively."),
         celdaGridSearch = paste0("Run Celda with different combinations of",
             " parameters and multiple chains in parallel.")
     )
 
     ## Emit each entry as its own message, formatted "<name>: <description>"
     for (entry in names(descriptions)) {
         message(entry, ": ", descriptions[[entry]])
     }
 }