Browse code

- Synchronise random forest survival and performancePlot orderingList fixes.

Dario Strbenac authored on 25/11/2022 13:55:04
Showing 3 changed files

... ...
@@ -3,8 +3,8 @@ Type: Package
3 3
 Title: A framework for cross-validated classification problems, with
4 4
        applications to differential variability and differential
5 5
        distribution testing
6
-Version: 3.2.2
7
-Date: 2022-11-22
6
+Version: 3.2.3
7
+Date: 2022-11-25
8 8
 Author: Dario Strbenac, Ellis Patrick, Sourish Iyengar, Harry Robertson, Andy Tran, John Ormerod, Graham Mann, Jean Yang
9 9
 Maintainer: Dario Strbenac <dario.strbenac@sydney.edu.au>
10 10
 VignetteBuilder: knitr
... ...
@@ -22,13 +22,20 @@ randomForestPredictInterface <- function(forest, measurementsTest, ..., returnTy
22 22
   if(verbose == 3)
23 23
     message("Predicting using random forest.")  
24 24
   measurementsTest <- as.data.frame(measurementsTest)
25
-  classPredictions <- predict(forest, measurementsTest)$predictions
26
-  classScores <- predict(forest, measurementsTest, predict.all = TRUE)[[1]]
27
-  classScores <- t(apply(classScores, 1, function(sampleRow) table(factor(classes[sampleRow], levels = classes)) / forest$forest$num.trees))
28
-  rownames(classScores) <- names(classPredictions) <- rownames(measurementsTest)
29
-  switch(returnType, class = classPredictions,
30
-         score = classScores,
31
-         both = data.frame(class = classPredictions, classScores, check.names = FALSE))
25
+  
26
+  predictions <- predict(forest, measurementsTest)
27
+  if(predictions$treetype == "Classification")
28
+  {
29
+    classPredictions <- predictions$predictions
30
+    classScores <- predict(forest, measurementsTest, predict.all = TRUE)[[1]]
31
+    classScores <- t(apply(classScores, 1, function(sampleRow) table(factor(classes[sampleRow], levels = classes)) / forest$forest$num.trees))
32
+    rownames(classScores) <- names(classPredictions) <- rownames(measurementsTest)
33
+    switch(returnType, class = classPredictions,
34
+           score = classScores,
35
+           both = data.frame(class = classPredictions, classScores, check.names = FALSE))
36
+  } else { # It is "Survival".
37
+      rowSums(predictions$survival)
38
+  }
32 39
 }
33 40
 
34 41
 ################################################################################
... ...
@@ -150,6 +150,7 @@ setMethod("performancePlot", "list",
150 150
                     }, results, 1:length(results), SIMPLIFY = FALSE))
151 151
   
152 152
   plotData <- plotData[, !duplicated(colnames(plotData))]
153
+  if(length(orderingList) > 0) plotData <- .addUserLevels(plotData, orderingList)
153 154
 
154 155
   # Fill in any missing variables needed for ggplot2 code.
155 156
   if("fillColour" %in% names(characteristicsList))