
- Random forest's survival mode now calculates risk values by multiplying the sum of survival probabilities by negative one.
- performancePlot and samplesMetricMap now use a parameter named metric to consistently refer to the performance metric (usage sketched below).
- samplesMetricMap now has a default value of "auto" for metric. Also, if the metric has not previously been calculated using calcCVperformance, it will now be calculated rather than stopping with an error.
- samplesMetricMap's default metric is accuracy, to be consistent with performancePlot.
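
A minimal usage sketch of the renamed and automated arguments (not part of this commit; it assumes result1 and result2 are existing ClassifyResult objects from a classification analysis, such as those constructed in the package's documentation examples):

```r
library(ClassifyR)

results <- list(result1, result2)   # assumed ClassifyResult objects

# The argument formerly called performanceName is now metric.
performancePlot(results, metric = "Balanced Accuracy")

# metric = "auto" resolves to Sample Accuracy for classification results
# (Sample C-index for survival results) and, if it was not already stored by
# calcCVperformance(), it is calculated with a warning instead of an error.
samplesMetricMap(results, metric = "auto")
```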

Dario Strbenac authored on 27/11/2022 10:45:33
Showing 7 changed files

... ...
@@ -3,8 +3,8 @@ Type: Package
 Title: A framework for cross-validated classification problems, with
        applications to differential variability and differential
        distribution testing
-Version: 3.2.3
-Date: 2022-11-25
+Version: 3.2.4
+Date: 2022-11-27
 Author: Dario Strbenac, Ellis Patrick, Sourish Iyengar, Harry Robertson, Andy Tran, John Ormerod, Graham Mann, Jean Yang
 Maintainer: Dario Strbenac <dario.strbenac@sydney.edu.au>
 VignetteBuilder: knitr
... ...
@@ -27,7 +27,7 @@ Description: The software formalises a framework for classification in R.
              may be developed by the user, by creating an interface to the framework.
 License: GPL-3
 Packaged: 2014-10-18 11:16:55 UTC; dario
-RoxygenNote: 7.2.1
+RoxygenNote: 7.2.2
 NeedsCompilation: yes
 Collate:
     'ROCplot.R'
... ...
@@ -34,7 +34,7 @@ randomForestPredictInterface <- function(forest, measurementsTest, ..., returnTy
           score = classScores,
           both = data.frame(class = classPredictions, classScores, check.names = FALSE))
  } else { # It is "Survival".
-      rowSums(predictions$survival)
+      -1 * rowSums(predictions$survival) # Make it a risk score.
  }
 }
 
... ...
@@ -15,7 +15,7 @@
 #' \code{characteristicsList['x']} to aggregate to a single number by taking
 #' the mean. This is particularly meaningful when the cross-validation is
 #' leave-k-out, when k is small.
-#' @param performanceName Default: \code{"auto"}. The name of the
+#' @param metric Default: \code{"auto"}. The name of the
 #' performance measure or "auto". If the results are classification then
 #' balanced accuracy will be displayed. Otherwise, the results would be survival risk
 #' predictions and then C-index will be displayed. This is one of the names printed
... ...
@@ -81,7 +81,7 @@
 #'                             list(function(oracle){}), NULL, predicted, actual)
 #'   result2 <- calcCVperformance(result2, "Macro F1")
 #'   
-#'   performancePlot(list(result1, result2), performanceName = "Macro F1",
+#'   performancePlot(list(result1, result2), metric = "Macro F1",
 #'                   title = "Comparison")
 #' 
 #' @importFrom rlang sym
... ...
@@ -99,7 +99,7 @@ setMethod("performancePlot", "ClassifyResult", function(results, ...) {
 #' @rdname performancePlot
 #' @export
 setMethod("performancePlot", "list",
-          function(results, performanceName = "auto",
+          function(results, metric = "auto",
                    characteristicsList = list(x = "auto"), aggregate = character(), coloursList = list(), orderingList = list(),
                    densityStyle = c("box", "violin"), yLimits = NULL, fontSizes = c(24, 16, 12, 12), title = NULL,
                    margin = grid::unit(c(1, 1, 1, 1), "lines"), rotate90 = FALSE, showLegend = TRUE)
... ...
@@ -118,34 +118,34 @@ setMethod("performancePlot", "list",
    else
      stop("No characteristic is present for all results but must be.")
  }
-  if(performanceName == "auto")
-      performanceName <- ifelse("risk" %in% colnames(results[[1]]@predictions), "C-index", "Balanced Accuracy")
+  if(metric == "auto")
+      metric <- ifelse("risk" %in% colnames(results[[1]]@predictions), "C-index", "Balanced Accuracy")
            
  ggplot2::theme_set(ggplot2::theme_classic() + ggplot2::theme(panel.border = ggplot2::element_rect(fill = NA)))
-  performanceNames <- unlist(lapply(results, function(result)
+  metrics <- unlist(lapply(results, function(result)
    if(!is.null(result@performance)) names(result@performance)))
-  namesCounts <- table(performanceNames)
+  namesCounts <- table(metrics)
  commonNames <- names(namesCounts)[namesCounts == length(results)]
-  if(!performanceName %in% commonNames)
+  if(!metric %in% commonNames)
  {
-    warning(paste(performanceName, "not found in all elements of results. Calculating it now."))
-    results <- lapply(results, function(result) calcCVperformance(result, performanceName))
+    warning(paste(metric, "not found in all elements of results. Calculating it now."))
+    results <- lapply(results, function(result) calcCVperformance(result, metric))
  }
  
-  ifelse(performanceName == "Matthews Correlation Coefficient", baseline <- 0, baseline <- 0.5)
+  ifelse(metric == "Matthews Correlation Coefficient", baseline <- 0, baseline <- 0.5)
 
  plotData <- do.call(rbind, mapply(function(result, index)
                    {
-                      if(!performanceName %in% names(result@performance))
-                        stop(performanceName, " not calculated for element ", index, " of results list.")
+                      if(!metric %in% names(result@performance))
+                        stop(metric, " not calculated for element ", index, " of results list.")
                      row <- result@characteristics[, "characteristic"] == characteristicsList[["x"]] 
                      if(any(row) && result@characteristics[row, "value"] %in% aggregate)
-                        performance <- mean(result@performance[[performanceName]])
+                        performance <- mean(result@performance[[metric]])
                      else
-                        performance <- result@performance[[performanceName]]
+                        performance <- result@performance[[metric]]
                      rows <- match(unlist(characteristicsList), result@characteristics[, "characteristic"])
                      summaryTable <- data.frame(as.list(result@characteristics[rows, "value"]), performance)
-                      colnames(summaryTable) <- c(characteristicsList, performanceName)
+                      colnames(summaryTable) <- c(characteristicsList, metric)
                      summaryTable
                    }, results, 1:length(results), SIMPLIFY = FALSE))
  
... ...
@@ -182,12 +182,12 @@ setMethod("performancePlot", "list",
  if(any(analysisGroupSizes > 1))
  {
    multiPlotData <- do.call(rbind, analysisGrouped[analysisGroupSizes > 1])
-    performancePlot <- performancePlot + densityStyle(data = multiPlotData, ggplot2::aes(x = !!characteristicsList[['x']], y = !!(rlang::sym(performanceName)), fill = !!fillVariable, colour = !!lineVariable))
+    performancePlot <- performancePlot + densityStyle(data = multiPlotData, ggplot2::aes(x = !!characteristicsList[['x']], y = !!(rlang::sym(metric)), fill = !!fillVariable, colour = !!lineVariable))
  }
  if(any(analysisGroupSizes == 1))
  {
    singlePlotData <- do.call(rbind, analysisGrouped[analysisGroupSizes == 1])
-    performancePlot <- performancePlot + ggplot2::geom_bar(data = singlePlotData, stat = "identity", ggplot2::aes(x = !!characteristicsList[['x']], y = !!(rlang::sym(performanceName)), fill = !!fillVariable, colour = !!lineVariable))
+    performancePlot <- performancePlot + ggplot2::geom_bar(data = singlePlotData, stat = "identity", ggplot2::aes(x = !!characteristicsList[['x']], y = !!(rlang::sym(metric)), fill = !!fillVariable, colour = !!lineVariable))
  }
  
  if(!is.null(yLimits)) yLimits = c(0, 1)
... ...
@@ -15,7 +15,12 @@
 #' same length as the number of columns that \code{results} has.
 #' @param comparison Default: "auto". The aspect of the experimental
 #' design to compare. Can be any characteristic that all results share.
-#' @param metric Default: "Sample Error". The sample-wise metric to plot.
+#' @param metric Default: \code{"auto"}. The name of the
+#' performance measure or "auto". If the results are classification then
+#' sample accuracy will be displayed. Otherwise, the results would be survival risk
+#' predictions and then a sample C-index will be displayed. Valid values are \code{"Sample Error"},
+#' \code{"Sample Accuracy"} or \code{"Sample C-index"}. If the metric is not stored in the
+#' results list, the performance metric will be calculated automatically.
 #' @param featureValues If not NULL, can be a named factor or named numeric
 #' vector specifying some variable of interest to plot above the heatmap.
 #' @param featureName A label describing the information in
... ...
@@ -66,8 +71,8 @@
 #'                             value = c("Example", "Bartlett Test", "Differential Variability", "2 Permutations, 2 Folds")),
 #'                             LETTERS[1:10], features, list(1:100), list(sample(10, 10)),
 #'                             list(function(oracle){}), NULL, predicted, actual)
-#'   result1 <- calcCVperformance(result1, "Sample Error")
-#'   result2 <- calcCVperformance(result2, "Sample Error")
+#'   result1 <- calcCVperformance(result1)
+#'   result2 <- calcCVperformance(result2)
 #'   groups <- factor(rep(c("Male", "Female"), length.out = 10))
 #'   names(groups) <- LETTERS[1:10]
 #'   cholesterol <- c(4.0, 5.5, 3.9, 4.9, 5.7, 7.1, 7.9, 8.0, 8.5, 7.2)
... ...
@@ -95,10 +100,10 @@ setMethod("samplesMetricMap", "ClassifyResult", function(results, ...) {
 setMethod("samplesMetricMap", "list", 
           function(results,
                    comparison = "auto",
-                   metric = c("Sample Error", "Sample Accuracy", "Sample C-index"),
+                   metric = "auto",
                    featureValues = NULL, featureName = NULL,
-                   metricColours = list(c("#3F48CC", "#6F75D8", "#9FA3E5", "#CFD1F2", "#FFFFFF"),
-                                        c("#880015", "#A53F4F", "#C37F8A", "#E1BFC4", "#FFFFFF")),
+                   metricColours = list(c("#FFFFFF", "#CFD1F2", "#9FA3E5", "#6F75D8", "#3F48CC"),
+                                        c("#FFFFFF", "#E1BFC4", "#C37F8A", "#A53F4F", "#880015")),
                    classColours = c("#3F48CC", "#880015"), groupColours = c("darkgreen", "yellow2"),
                    fontSizes = c(24, 16, 12, 12, 12),
                    mapHeight = 4, title = switch(metric, `Sample Error` = "Error Comparison", `Sample Accuracy` = "Accuracy Comparison", `Sample C-index` = "Risk Score Comparison"),
... ...
@@ -125,6 +130,17 @@ setMethod("samplesMetricMap", "list",
      stop("No characteristic is present for all results but must be.")
    }
  }
+  isSurvival <- "risk" %in% colnames(results[[1]]@predictions)
+  validMetrics <- c("Sample Error", "Sample Accuracy", "Sample C-index")
+  if(metric == "auto")
+    metric <- ifelse(isSurvival, "Sample C-index", "Sample Accuracy")
+  else
+    if(!metric %in% validMetrics) stop("metric must be one of ", paste(validMetrics, collapse = ", "), " but is ", metric, '.')
+  if(isSurvival && is.list(metricColours)) metricColours <- metricColours[[1]]
+  metricText <- gsub("Sample ", '', metric) # For legend labelling.
+  if(showXtickLabels == FALSE && xAxisLabel == "Sample Name") xAxisLabel <- "Sample"
+
+
  resultsWithComparison <- sum(sapply(results, function(result) any(result@characteristics[, "characteristic"] == comparison)))
  if(resultsWithComparison < length(results))
    stop("Not all results have comparison characteristic ", comparison, ' but need to.')
... ...
@@ -135,13 +151,17 @@ setMethod("samplesMetricMap", "list",
    compareFactor <- sapply(results, function(result) {
                     useRow <- result@characteristics[, "characteristic"] == comparison
                     result@characteristics[useRow, "value"]
-                    })  
-  metric <- match.arg(metric)
-  metricText <- switch(metric, `Sample Error` = "Error", `Sample Accuracy` = "Accuracy", `Sample C-index` = "C-index")
+                    })
  
-  allCalculated <- all(sapply(results, function(result) metric %in% names(performance(result))))
-  if(!allCalculated)
-    stop("One or more classification results lack the calculated sample-specific metric.")
+  metrics <- unlist(lapply(results, function(result)
+    if(!is.null(result@performance)) names(result@performance)))
+  namesCounts <- table(metrics)
+  commonNames <- names(namesCounts)[namesCounts == length(results)]
+  if(!metric %in% commonNames)
+  {
+    warning(paste(metric, "not found in all elements of results. Calculating it now."))
+    results <- lapply(results, function(result) calcCVperformance(result, metric))
+  }
  if(!is.null(featureValues) && is.null(featureName))
    stop("featureValues is specified by featureNames isn't. Specify both.")
  if(!is.null(featureValues) && is.null(names(featureValues)))
... ...
@@ -10,7 +10,7 @@
 
 \S4method{performancePlot}{list}(
   results,
-  performanceName = "auto",
+  metric = "auto",
   characteristicsList = list(x = "auto"),
   aggregate = character(),
   coloursList = list(),
... ...
@@ -29,7 +29,7 @@
 
 \item{...}{Not used by end user.}
 
-\item{performanceName}{Default: \code{"auto"}. The name of the
+\item{metric}{Default: \code{"auto"}. The name of the
 performance measure or "auto". If the results are classification then
 balanced accuracy will be displayed. Otherwise, the results would be survival risk
 predictions and then C-index will be displayed. This is one of the names printed
... ...
@@ -122,7 +122,7 @@ calculated, and a barchart is plotted.
                             list(function(oracle){}), NULL, predicted, actual)
   result2 <- calcCVperformance(result2, "Macro F1")
   
-  performancePlot(list(result1, result2), performanceName = "Macro F1",
+  performancePlot(list(result1, result2), metric = "Macro F1",
                   title = "Comparison")
 
 }
... ...
@@ -12,11 +12,11 @@
 \S4method{samplesMetricMap}{list}(
   results,
   comparison = "auto",
-  metric = c("Sample Error", "Sample Accuracy", "Sample C-index"),
+  metric = "auto",
   featureValues = NULL,
   featureName = NULL,
-  metricColours = list(c("#3F48CC", "#6F75D8", "#9FA3E5", "#CFD1F2", "#FFFFFF"),
-    c("#880015", "#A53F4F", "#C37F8A", "#E1BFC4", "#FFFFFF")),
+  metricColours = list(c("#FFFFFF", "#CFD1F2", "#9FA3E5", "#6F75D8", "#3F48CC"),
+    c("#FFFFFF", "#E1BFC4", "#C37F8A", "#A53F4F", "#880015")),
  classColours = c("#3F48CC", "#880015"),
  groupColours = c("darkgreen", "yellow2"),
  fontSizes = c(24, 16, 12, 12, 12),
... ...
@@ -64,7 +64,12 @@ list-packaging but used by the main \code{list} method.}
 \item{comparison}{Default: "auto". The aspect of the experimental
 design to compare. Can be any characteristic that all results share.}
 
-\item{metric}{Default: "Sample Error". The sample-wise metric to plot.}
+\item{metric}{Default: \code{"auto"}. The name of the
+performance measure or "auto". If the results are classification then
+sample accuracy will be displayed. Otherwise, the results would be survival risk
+predictions and then a sample C-index will be displayed. Valid values are \code{"Sample Error"},
+\code{"Sample Accuracy"} or \code{"Sample C-index"}. If the metric is not stored in the
+results list, the performance metric will be calculated automatically.}
 
 \item{featureValues}{If not NULL, can be a named factor or named numeric
 vector specifying some variable of interest to plot above the heatmap.}
... ...
@@ -142,8 +147,8 @@ values will be discretised to.
                             value = c("Example", "Bartlett Test", "Differential Variability", "2 Permutations, 2 Folds")),
                             LETTERS[1:10], features, list(1:100), list(sample(10, 10)),
                             list(function(oracle){}), NULL, predicted, actual)
-  result1 <- calcCVperformance(result1, "Sample Error")
-  result2 <- calcCVperformance(result2, "Sample Error")
+  result1 <- calcCVperformance(result1)
+  result2 <- calcCVperformance(result2)
   groups <- factor(rep(c("Male", "Female"), length.out = 10))
   names(groups) <- LETTERS[1:10]
   cholesterol <- c(4.0, 5.5, 3.9, 4.9, 5.7, 7.1, 7.9, 8.0, 8.5, 7.2)
... ...
@@ -258,14 +258,11 @@ DDresults
 
 The naive Bayes kernel classifier by default uses the vertical distance between class densities but it can instead use the horizontal distance to the nearest non-zero density cross-over point to confidently classify samples in the tails of the densities.
 
-Now, the classification error for each sample is also calculated for both the differential means and differential distribution classifiers and both *ClassifyResult* objects generated so far are plotted with *samplesMetricMap*.
+The per-sample classification accuracy is automatically calculated for both the differential means and differential distribution classifiers and plotted with *samplesMetricMap*.
 
 ```{r, fig.width = 10, fig.height = 7}
-DMresults <- calcCVperformance(DMresults, "Sample Error")
-DDresults <- calcCVperformance(DDresults, "Sample Error")
 resultsList <- list(Abundance = DMresults, Distribution = DDresults)
-samplesMetricMap(resultsList, metric = "Sample Error", xAxisLabel = "Sample",
-                              showXtickLabels = FALSE)
+samplesMetricMap(resultsList, showXtickLabels = FALSE)
 ```
 
 The benefit of this plot is that it allows the easy identification of samples which are hard to classify and could be explained by considering additional information about them. Differential distribution class prediction appears to be biased to the majority class (No Asthma).