Browse code

BROKEN initial implementation of K/L plots. #30

Sean Corbett authored on 16/05/2017 00:27:49
Showing 3 changed files

... ...
@@ -10,6 +10,7 @@ Depends:
10 10
     R (>= 3.2.2)
11 11
 Imports:
12 12
     gtools,
13
+    ggplot2,
13 14
     entropy,
14 15
     RColorBrewer,
15 16
     pheatmap,
... ...
@@ -118,4 +118,15 @@ getL = function(celda.mod) {
118 118
 #' @export 
119 119
 celda_heatmap <- function(celda.mod, counts, ...) {
120 120
   UseMethod("celda_heatmap", celda.mod)
121
+}
122
+
123
+
124
+#' Visualize various performance metrics as a function of K / L to aid parameter choice.
125
+#' 
126
+#' @param celda.list A celda_list object as returned from *celda()*
127
+#' @param metric Which performance metric to visualize. One of ("perplexity", "harmonic", "loglik"). "perplexity" calculates the inverse of the geometric mean of the log likelihoods from each iteration of Gibbs sampling. "harmonic" calculates the marginal likelihood has the harmonic mean of the likelihoods. "loglik" plots the highest log-likelihood during Gibbs iteration.
128
+#' @return A ggplot object containing the requested plot(s)
129
+#' @export
130
+visualize_performance <- function(celda.list, method, ...) {
131
+  UseMethod("visualize_performance", celda.list)
121 132
 }
122 133
\ No newline at end of file
... ...
@@ -26,3 +26,43 @@ calculate_perplexity = function(completeLogLik) {
26 26
   perplexity = exp(Rmpfr::mean(mpfr_log_lik))^-1
27 27
   return(perplexity)
28 28
 }
29
+
30
+
31
+#' Visualize various performance metrics as a function of K / L to aid parameter choice.
32
+#' 
33
+#' @param celda.list A celda_list object as returned from *celda()*
34
+#' @param metric Which performance metric to visualize. One of ("perplexity", "harmonic", "loglik"). "perplexity" calculates the inverse of the geometric mean of the log likelihoods from each iteration of Gibbs sampling. "harmonic" calculates the marginal likelihood has the harmonic mean of the likelihoods. "loglik" plots the highest log-likelihood during Gibbs iteration.
35
+#' @return A ggplot object containing the requested plot(s), or a list of ggplots if the provided celda_list contains celda_CG models.
36
+#' @export
37
+visualize_performance = function(celda.list, method="perplexity") {
38
+  # TODO use celda_list getter
39
+  log.likelihoods = lapply(celda.list$res.list,
40
+                           function(mod) { return(mod$completeLogLik) })
41
+    
42
+  if (method == "perplexity") {
43
+    metric = lapply(log.likelihoods, calculate_perplexity)
44
+    metric = new("mpfr", unlist(metric))
45
+  } else if (method == "harmonic") {
46
+    metric = lapply(log.likelihoods, calculate_marginal_likelihood)
47
+    metric = new("mpfr", unlist(metric))
48
+  } else if (method == "loglik") {
49
+    # TODO use celda_list getter
50
+    metric = lapply(log.likelihoods, max)
51
+  } else stop("Invalid method specified")
52
+  
53
+  
54
+  #TODO ggplot table building below is vulnerable to error 
55
+  #     if the user modifies the celda_list at all..
56
+  if (celda.list$content.type == "celda_C") {
57
+    Ks = lapply(celda.list$res.list, function(mod) { mod$K })
58
+    plot.df = data.frame(K=as.factor(unlist(Ks)), metric=as.numeric(metric))
59
+    ggplot2::ggplot(plot.df, aes(x=K, y=metric)) + geom_point() +
60
+      xlab("K") + ylab(method)
61
+  }
62
+  
63
+}
64
+
65
+
66
+
67
+
68
+