Browse code

- performancePlot no longer ignores orderingList. - Default random forest function now seamlessly works when outcome is a Surv object so classifier doesn't need to be changed from default setting if input data is survival.

Dario Strbenac authored on 25/11/2022 11:40:05
Showing 3 changed files

... ...
@@ -3,8 +3,8 @@ Type: Package
3 3
 Title: A framework for cross-validated classification problems, with
4 4
        applications to differential variability and differential
5 5
        distribution testing
6
-Version: 3.3.4
7
-Date: 2022-11-22
6
+Version: 3.3.5
7
+Date: 2022-11-25
8 8
 Authors@R:
9 9
     c(
10 10
     person(given = "Dario", family = "Strbenac", email = "dario.strbenac@sydney.edu.au", role = c("aut", "cre")),
... ...
@@ -22,13 +22,20 @@ randomForestPredictInterface <- function(forest, measurementsTest, ..., returnTy
22 22
   if(verbose == 3)
23 23
     message("Predicting using random forest.")  
24 24
   measurementsTest <- as.data.frame(measurementsTest)
25
-  classPredictions <- predict(forest, measurementsTest)$predictions
26
-  classScores <- predict(forest, measurementsTest, predict.all = TRUE)[[1]]
27
-  classScores <- t(apply(classScores, 1, function(sampleRow) table(factor(classes[sampleRow], levels = classes)) / forest$forest$num.trees))
28
-  rownames(classScores) <- names(classPredictions) <- rownames(measurementsTest)
29
-  switch(returnType, class = classPredictions,
30
-         score = classScores,
31
-         both = data.frame(class = classPredictions, classScores, check.names = FALSE))
25
+  
26
+  predictions <- predict(forest, measurementsTest)
27
+  if(predictions$treetype == "Classification")
28
+  {
29
+    classPredictions <- predictions$predictions
30
+    classScores <- predict(forest, measurementsTest, predict.all = TRUE)[[1]]
31
+    classScores <- t(apply(classScores, 1, function(sampleRow) table(factor(classes[sampleRow], levels = classes)) / forest$forest$num.trees))
32
+    rownames(classScores) <- names(classPredictions) <- rownames(measurementsTest)
33
+    switch(returnType, class = classPredictions,
34
+           score = classScores,
35
+           both = data.frame(class = classPredictions, classScores, check.names = FALSE))
36
+  } else { # It is "Survival".
37
+      rowSums(predictions$survival)
38
+  }
32 39
 }
33 40
 
34 41
 ################################################################################
... ...
@@ -150,6 +150,7 @@ setMethod("performancePlot", "list",
150 150
                     }, results, 1:length(results), SIMPLIFY = FALSE))
151 151
   
152 152
   plotData <- plotData[, !duplicated(colnames(plotData))]
153
+  if(length(orderingList) > 0) plotData <- .addUserLevels(plotData, orderingList)
153 154
 
154 155
   # Fill in any missing variables needed for ggplot2 code.
155 156
   if("fillColour" %in% names(characteristicsList))