
- Typo in train.DataFrame variable fixed.
- .predict for DLDA renamed to predict.dlda to enable easy dispatch by the predict generic.
- randomForest wrapper now uses ranger as the underlying package instead of randomForest.

Dario Strbenac authored on 14/10/2022 00:30:12
Showing 9 changed files

... ...
@@ -3,8 +3,8 @@ Type: Package
 Title: A framework for cross-validated classification problems, with
        applications to differential variability and differential
        distribution testing
-Version: 3.1.20
-Date: 2022-10-07
+Version: 3.1.21
+Date: 2022-10-14
 Author: Dario Strbenac, Ellis Patrick, John Ormerod, Graham Mann, Jean Yang
 Maintainer: Dario Strbenac <dario.strbenac@sydney.edu.au>
 VignetteBuilder: knitr
... ...
@@ -912,18 +912,32 @@ train.data.frame <- function(x, outcomeTrain, ...)
 #' @rdname crossValidate
 #' @param assayIDs A character vector for assays to train with. Special value \code{"all"}
 #' uses all assays in the input object.
+#' @param performanceType Performance metric to optimise if classifier has any tuning parameters.
 #' @method train DataFrame
 #' @export
-train.DataFrame <- function(x, outcomeTrain, classifier = "randomForest", multiViewMethod = "none", assayIDs = "all", ...) # ... for prepareData.
+train.DataFrame <- function(x, outcomeTrain, classifier = "randomForest", performanceType = "auto",
+                            multiViewMethod = "none", assayIDs = "all", ...) # ... for prepareData.
                    {
              prepArgs <- list(x, outcomeTrain)
              extraInputs <- list(...)
              prepExtras <- numeric()
              if(length(extraInputs) > 0)
-                prepExtras <- which(names(extrasInputs) %in% .ClassifyRenvir[["prepareDataFormals"]])
+                prepExtras <- which(names(extraInputs) %in% .ClassifyRenvir[["prepareDataFormals"]])
              if(length(prepExtras) > 0)
                prepArgs <- append(prepArgs, extraInputs[prepExtras])
              measurementsAndOutcome <- do.call(prepareData, prepArgs)
+              
+              # Ensure performance type is one of the ones that can be calculated by the package.
+              if(!performanceType %in% c("auto", .ClassifyRenvir[["performanceTypes"]]))
+                stop(paste("performanceType must be one of", paste(c("auto", .ClassifyRenvir[["performanceTypes"]]), collapse = ", "), "but is", performanceType))
+              
+              if(performanceType == "auto")
+              {
+                if(is.character(outcomeTrain) && (length(outcomeTrain) == 1 || length(outcomeTrain) == nrow(x)) || is.factor(outcomeTrain))
+                  performanceType <- "Balanced Accuracy"
+                else performanceType <- "C-index"
+              }
+              
              measurements <- measurementsAndOutcome[["measurements"]]
              outcomeTrain <- measurementsAndOutcome[["outcome"]]
              
... ...
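The new "auto" branch above picks a default metric from the outcome type. A minimal standalone sketch of that decision rule, outside the package (choosePerformanceType is a hypothetical name, not part of ClassifyR):

    # Hypothetical helper mirroring the "auto" branch: categorical outcomes get
    # Balanced Accuracy, anything else (e.g. survival::Surv) gets C-index.
    choosePerformanceType <- function(outcome, nSamples)
    {
      categorical <- is.factor(outcome) ||
        (is.character(outcome) && (length(outcome) == 1 || length(outcome) == nSamples))
      if(categorical) "Balanced Accuracy" else "C-index"
    }
    choosePerformanceType(factor(c("healthy", "disease")), 2) # "Balanced Accuracy"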
@@ -939,11 +953,13 @@ train.DataFrame <- function(x, outcomeTrain, classifier = "randomForest", multiV
                          # Loop over assays
                          sapply(classifier[[assayIndex]], function(classifierForAssay) {
                              # Loop over classifiers
+                              
                                  measurementsUse <- measurements
                                  if(assayIndex != 1) measurementsUse <- measurements[, mcols(measurements)[, "assay"] == assayIndex, drop = FALSE]
                                  
                                  classifierParams <- .classifierKeywordToParams(classifierForAssay)
-                                  classifierParams$trainParams@tuneParams <- c(classifierParams$trainParams@tuneParams, performanceType = performanceType)
+                                  if(!is.null(classifierParams$trainParams@tuneParams))
+                                    classifierParams$trainParams@tuneParams <- c(classifierParams$trainParams@tuneParams, performanceType = performanceType)
                                  modellingParams <- ModellingParams(balancing = "none", selectParams = NULL,
                                                               trainParams = classifierParams$trainParams, predictParams = classifierParams$predictParams)
                                  
... ...
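The is.null guard added above matters because c() on NULL silently produces a non-NULL vector, so a classifier with no tunable parameters would otherwise gain a spurious tuneParams entry. A quick illustration in plain R:

    c(NULL, performanceType = "Balanced Accuracy")          # Named vector, not NULL.
    is.null(c(NULL, performanceType = "Balanced Accuracy")) # FALSE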
@@ -1019,7 +1035,6 @@ train.DataFrame <- function(x, outcomeTrain, classifier = "randomForest", multiV
 #' @rdname crossValidate
 #' @method train list
 #' @export
-# Each of the first four variables are named lists with names of assays.
 train.list <- function(x, outcomeTrain, ...)
               {
                 # Check data type is valid
... ...
@@ -1062,13 +1077,13 @@ train.MultiAssayExperiment <- function(x, outcomeColumns, ...)
              extraInputs <- list(...)
              prepExtras <- trainExtras <- numeric()
              if(length(extraInputs) > 0)
-                prepExtras <- which(names(extrasInputs) %in% .ClassifyRenvir[["prepareDataFormals"]])
+                prepExtras <- which(names(extraInputs) %in% .ClassifyRenvir[["prepareDataFormals"]])
              if(length(prepExtras) > 0)
                prepArgs <- append(prepArgs, extraInputs[prepExtras])
              measurementsAndOutcome <- do.call(prepareData, prepArgs)
              trainArgs <- list(measurementsAndOutcome[["measurements"]], measurementsAndOutcome[["outcome"]])
              if(length(extraInputs) > 0)
-                trainExtras <- which(!names(extrasInputs) %in% .ClassifyRenvir[["prepareDataFormals"]])
+                trainExtras <- which(!names(extraInputs) %in% .ClassifyRenvir[["prepareDataFormals"]])
              if(length(trainExtras) > 0)
                trainArgs <- append(trainArgs, extraInputs[trainExtras])
              do.call(train, trainArgs)
... ...
@@ -1101,10 +1116,8 @@ predict.trainedByClassifyR <- function(object, newData, ...)
              # Some classifiers dangerously use positional matching rather than column name matching.
              # newData columns are sorted so that the right column ordering is guaranteed.
            } else {stop("'newData' is not one of the valid data types. It is of type ", class(newData), '.')}
-  if(is(object, "ClassifyResult"))
-  {
-    object@modellingParams@predictParams@predictor(object@finalModel[[1]], newData)
-  } else if (is(object, "listOfModels")) { # Object is itself a trained model and it is assumed that a predict method is defined for it.
-    mapply(function(model, assay) predict(model, assay), object, newData, SIMPLIFY = FALSE)
-  } else predict(object, newData)
+
+    if (is(object, "listOfModels")) 
+         mapply(function(model, assay) predict(model, assay), object, newData, SIMPLIFY = FALSE)
+    else predict(object, newData) # Object is itself a trained model and it is assumed that a predict method is defined for it.
 }
... ...
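For the listOfModels branch, each trained model is paired with the test data of its own assay. A sketch of that pairing with stand-in lm models (the assay names clinical and RNA are illustrative only):

    # Stand-in models; ClassifyR would supply its own trained model objects.
    models <- list(clinical = lm(dist ~ speed, data = cars),
                   RNA = lm(mpg ~ wt, data = mtcars))
    testData <- list(clinical = cars[1:3, ], RNA = mtcars[1:3, ])
    # One prediction set per assay, names preserved.
    mapply(function(model, assay) predict(model, assay), models, testData, SIMPLIFY = FALSE)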
@@ -29,7 +29,7 @@ DLDApredictInterface <- function(model, measurementsTest, returnType = c("both",
     message("Predicting classes using trained DLDA classifier.")
   
   #predict(model, as.matrix(test))
-  predictions <- .predict(model, as.matrix(measurementsTest)) # Copy located in utilities.R.
+  predictions <- predict(model, as.matrix(measurementsTest)) # Copy located in utilities.R.
 
   switch(returnType, class = predictions[["class"]], # Factor vector.
          score = predictions[["posterior"]][, model[["groups"]]], # Numeric matrix.
... ...
@@ -1,27 +1,28 @@
-# An Interface for randomForest Package's randomForest Function
-randomForestTrainInterface <- function(measurementsTrain, classesTrain, mTryProportion = 0.5, ..., verbose = 3)
+# An Interface for ranger Package's ranger Function
+randomForestTrainInterface <- function(measurementsTrain, outcomeTrain, mTryProportion = 0.5, ..., verbose = 3)
 {
-  if(!requireNamespace("randomForest", quietly = TRUE))
-    stop("The package 'randomForest' could not be found. Please install it.")
+  if(!requireNamespace("ranger", quietly = TRUE))
+    stop("The package 'ranger' could not be found. Please install it.")
   if(verbose == 3)
-    message("Fitting random forest classifier to training data and making predictions on test
-            data.")
+    message("Fitting random forest classifier to training data.")
   mtry <- round(mTryProportion * ncol(measurementsTrain)) # Number of features to try.
       
   # Convert to base data.frame as randomForest doesn't understand DataFrame.
-  randomForest::randomForest(as(measurementsTrain, "data.frame"), classesTrain, mtry = mtry, keep.forest = TRUE, ...)
+  ranger::ranger(x = as(measurementsTrain, "data.frame"), y = outcomeTrain, mtry = mtry, importance = "impurity_corrected", ...)
 }
 attr(randomForestTrainInterface, "name") <- "randomForestTrainInterface"
     
-# forest is of class randomForest
+# forest is of class ranger
 randomForestPredictInterface <- function(forest, measurementsTest, ..., returnType = c("both", "class", "score"), verbose = 3)
 {
   returnType <- match.arg(returnType)
+  classes <- forest$forest$levels
   if(verbose == 3)
     message("Predicting using random forest.")  
   measurementsTest <- as.data.frame(measurementsTest)
-  classPredictions <- predict(forest, measurementsTest)
-  classScores <- predict(forest, measurementsTest, type = "vote")[, forest[["classes"]], drop = FALSE]
+  classPredictions <- predict(forest, measurementsTest)$predictions
+  classScores <- predict(forest, measurementsTest, predict.all = TRUE)[[1]]
+  classScores <- t(apply(classScores, 1, function(sampleRow) table(factor(classes[sampleRow], levels = classes)) / forest$forest$num.trees))
   switch(returnType, class = classPredictions,
          score = classScores,
          both = data.frame(class = classPredictions, classScores, check.names = FALSE))
... ...
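The two-step score calculation above first collects one class index per tree (predict.all = TRUE) and then converts each sample's votes into proportions. A sketch of the same idea on a toy forest, assuming ranger is installed:

    library(ranger)
    forest <- ranger(Species ~ ., data = iris)
    classes <- forest$forest$levels
    # Matrix of class indices: one row per sample, one column per tree.
    perTree <- predict(forest, iris[c(1, 51, 101), ], predict.all = TRUE)$predictions
    # Fraction of trees voting for each class; rows sum to 1.
    t(apply(perTree, 1, function(sampleRow)
      table(factor(classes[sampleRow], levels = classes)) / forest$forest$num.trees))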
@@ -35,7 +36,7 @@ randomForestPredictInterface <- function(forest, measurementsTest, ..., returnTy
 
 forestFeatures <- function(forest)
                   {
-                    rankedFeaturesIndices <- order(randomForest::importance(forest), decreasing = TRUE)
-                    selectedFeaturesIndices <- randomForest::varUsed(forest, count = FALSE)
+                    rankedFeaturesIndices <- order(ranger::importance(forest), decreasing = TRUE)
+                    selectedFeaturesIndices <- which(ranger::importance(forest) > 0)
                     list(rankedFeaturesIndices, selectedFeaturesIndices)
                   }
\ No newline at end of file
... ...
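forestFeatures now ranks by ranger's importance scores; varUsed has no direct ranger equivalent, so positive importance stands in for "feature was used". A sketch, assuming a forest trained with importance = "impurity_corrected" as in the interface above:

    forest <- ranger::ranger(Species ~ ., data = iris, importance = "impurity_corrected")
    importances <- ranger::importance(forest)
    order(importances, decreasing = TRUE) # Feature indices, best first.
    which(importances > 0)                # Indices treated as selected.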
@@ -1,5 +1,5 @@
 # An Interface for xgboost Package's xgboost Function
-extremeGradientBoostingTrainInterface <- function(measurementsTrain, outcomesTrain, mTryProportion = 0.5, nrounds = 10, ..., verbose = 3)
+extremeGradientBoostingTrainInterface <- function(measurementsTrain, outcomeTrain, mTryProportion = 0.5, nrounds = 10, ..., verbose = 3)
 {
   if(!requireNamespace("xgboost", quietly = TRUE))
     stop("The package 'xgboost' could not be found. Please install it.")
... ...
@@ -12,22 +12,22 @@ extremeGradientBoostingTrainInterface <- function(measurementsTrain, outcomesTra
   
   isClassification <- FALSE
   numClasses <- NULL
-  if(is(outcomesTrain, "Surv")) # xgboost only knows about numeric vectors.
+  if(is(outcomeTrain, "Surv")) # xgboost only knows about numeric vectors.
   {
-    time <- outcomesTrain[, "time"]
-    event <- as.numeric(outcomesTrain[, "status"])
+    time <- outcomeTrain[, "time"]
+    event <- as.numeric(outcomeTrain[, "status"])
     if(max(event) == 2) event <- event - 1
-    outcomesTrain <- time * ifelse(event == 1, 1, -1) # Negative for censoring.
+    outcomeTrain <- time * ifelse(event == 1, 1, -1) # Negative for censoring.
     objective <- "survival:cox"
   } else { # Classification task.
     isClassification <- TRUE
-    classes <- levels(outcomesTrain)
+    classes <- levels(outcomeTrain)
     numClasses <- length(classes)
     objective <- "multi:softprob"
-    outcomesTrain <- as.numeric(outcomesTrain) - 1 # Classes are represented as 0, 1, 2, ...
+    outcomeTrain <- as.numeric(outcomeTrain) - 1 # Classes are represented as 0, 1, 2, ...
   }
   
-  trained <- xgboost::xgboost(measurementsTrain, outcomesTrain, objective = objective, nrounds = nrounds,
+  trained <- xgboost::xgboost(measurementsTrain, outcomeTrain, objective = objective, nrounds = nrounds,
                               num_class = numClasses, colsample_bynode = mTryProportion, verbose = 0, ...)
   if(isClassification)
   {
... ...
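The survival branch above follows xgboost's survival:cox convention: censored observations are encoded as negative times. A worked sketch with the survival package:

    library(survival)
    outcome <- Surv(time = c(5, 8, 12), event = c(1, 0, 1))
    time <- outcome[, "time"]
    event <- as.numeric(outcome[, "status"])
    if(max(event) == 2) event <- event - 1 # Some Surv encodings use 1/2, not 0/1.
    time * ifelse(event == 1, 1, -1)       # 5 -8 12; the negative value is censored.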
@@ -79,4 +79,4 @@ XGBfeatures <- function(booster)
                     rankedFeaturesIndices <- order(gains, decreasing = TRUE)
                     selectedFeaturesIndices <- indicesUsed
                     list(rankedFeaturesIndices, selectedFeaturesIndices)
-                  }
\ No newline at end of file
+                  }
... ...
@@ -1,6 +1,6 @@
 # Random Forest
 RFparams <- function() {
-    trainParams <- TrainParams(randomForestTrainInterface, tuneParams = list(mTryProportion = c(0.25, 0.33, 0.50, 0.66, 0.75, 1.00), ntree = seq(100, 500, 100)),
+    trainParams <- TrainParams(randomForestTrainInterface, tuneParams = list(mTryProportion = c(0.25, 0.33, 0.50, 0.66, 0.75, 1.00)),
                                getFeatures = forestFeatures)
     predictParams <- PredictParams(randomForestPredictInterface)
     
... ...
@@ -167,7 +167,7 @@
           result <- runTest(measurementsTrain, outcomeTrain, measurementsTrain, outcomeTrain,
                             crossValParams = NULL, modellingParams = modellingParams,
                             verbose = verbose, .iteration = "internal")
-          
+
           predictions <- result[["predictions"]]
           # Classifiers will use a column "class" and survival models will use a column "risk".
           if(class(predictions) == "data.frame")
... ...
@@ -222,6 +222,7 @@
                             measurementsTrain, outcomeTrain,
                             crossValParams = NULL, modellingParams,
                             verbose = verbose, .iteration = "internal")
+
           predictions <- result[["predictions"]]
           if(class(predictions) == "data.frame")
             predictedOutcome <- predictions[, "class"]
... ...
@@ -275,7 +276,7 @@
         result <- runTest(measurementsTrain, outcomeTrain, measurementsTrain, outcomeTrain,
                           crossValParams = NULL, modellingParams,
                           verbose = verbose, .iteration = "internal")
-        
+
         predictions <- result[["predictions"]]
         if(class(predictions) == "data.frame")
           predictedOutcome <- predictions[, colnames(predictions) %in% c("class", "risk")]
... ...
@@ -580,7 +581,8 @@
   obj
 }
 
-.predict <- function(object, newdata, ...) { # Remove once sparsediscrim is reinstated to CRAN.
+#' @method predict dlda
+predict.dlda <- function(object, newdata, ...) { # Remove once sparsediscrim is reinstated to CRAN.
   if (!inherits(object, "dlda"))  {
     stop("object not of class 'dlda'")
   }
... ...
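Renaming .predict to predict.dlda is what lets plain predict(model, ...) reach it: S3 dispatch looks up predict.<class>. A minimal sketch of the mechanism with a hypothetical class:

    # Hypothetical class "demo"; predict() finds predict.demo automatically.
    predict.demo <- function(object, newdata, ...) rep(object$constant, nrow(newdata))
    model <- structure(list(constant = 42), class = "demo")
    predict(model, data.frame(x = 1:3)) # 42 42 42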
@@ -111,6 +111,7 @@ crossValidate(measurements, outcome, ...)
   x,
   outcomeTrain,
   classifier = "randomForest",
+  performanceType = "auto",
   multiViewMethod = "none",
   assayIDs = "all",
   ...
... ...
@@ -143,8 +144,7 @@ and performing multiview classification, the respective classification methods w
 
 \item{selectionOptimisation}{A character of "Resubstitution", "Nested CV" or "none" specifying the approach used to optimise \code{nFeatures}.}
 
-\item{performanceType}{Default: \code{"auto"}. If \code{"auto"}, then balanced accuracy for classification or C-index for survival. Any one of the
-options described in \code{\link{calcPerformance}} may otherwise be specified.}
+\item{performanceType}{Performance metric to optimise if classifier has any tuning parameters.}
 
 \item{classifier}{A character vector of classification methods to compare. If a named character vector with names corresponding to different assays, 
 and performing multiview classification, the respective classification methods will be used on each assay.}
... ...
@@ -31,8 +31,7 @@ and performing multiview classification, the respective classification methods w
 
 \item{selectionOptimisation}{A character of "Resubstitution", "Nested CV" or "none" specifying the approach used to optimise \code{nFeatures}.}
 
-\item{performanceType}{Default: \code{"auto"}. If \code{"auto"}, then balanced accuracy for classification or C-index for survival. Any one of the
-options described in \code{\link{calcPerformance}} may otherwise be specified.}
+\item{performanceType}{Performance metric to optimise if classifier has any tuning parameters.}
 
 \item{classifier}{A character vector of classification methods to compare. If a named character vector with names corresponding to different assays, 
 and performing multiview classification, the respective classification methods will be used on each assay.}