... | ... |
@@ -118,4 +118,15 @@ getL = function(celda.mod) { |
118 | 118 |
#' @export |
119 | 119 |
celda_heatmap <- function(celda.mod, counts, ...)
  UseMethod("celda_heatmap", celda.mod)  # S3 generic: dispatches on celda.mod
|
122 |
+ |
|
123 |
+ |
|
124 |
+#' Visualize various performance metrics as a function of K / L to aid parameter choice. |
|
125 |
+#' |
|
126 |
+#' @param celda.list A celda_list object as returned from *celda()* |
|
127 |
+#' @param method Which performance metric to visualize. One of ("perplexity", "harmonic", "loglik"). "perplexity" calculates the inverse of the geometric mean of the likelihoods from each iteration of Gibbs sampling. "harmonic" calculates the marginal likelihood as the harmonic mean of the likelihoods. "loglik" plots the highest log-likelihood during Gibbs iteration. |
|
128 |
+#' @return A ggplot object containing the requested plot(s) |
|
129 |
+#' @export |
|
130 |
visualize_performance <- function(celda.list, method, ...)
  UseMethod("visualize_performance", celda.list)  # S3 generic: dispatches on celda.list
122 | 133 |
\ No newline at end of file |
... | ... |
@@ -26,3 +26,43 @@ calculate_perplexity = function(completeLogLik) { |
26 | 26 |
perplexity = exp(Rmpfr::mean(mpfr_log_lik))^-1 |
27 | 27 |
return(perplexity) |
28 | 28 |
} |
29 |
+ |
|
30 |
+ |
|
31 |
#' Visualize various performance metrics as a function of K / L to aid parameter choice.
#'
#' @param celda.list A celda_list object as returned from *celda()*
#' @param method Which performance metric to visualize. One of ("perplexity", "harmonic", "loglik"). "perplexity" calculates the inverse of the geometric mean of the likelihoods from each iteration of Gibbs sampling. "harmonic" calculates the marginal likelihood as the harmonic mean of the likelihoods. "loglik" plots the highest log-likelihood during Gibbs iteration.
#' @return A ggplot object plotting the chosen metric against K when the celda_list holds celda_C models. Other content types are not plotted yet: a warning is raised and NULL is returned invisibly.
#' @export
visualize_performance = function(celda.list, method="perplexity") {
  # NOTE(review): an S3 generic that is also named visualize_performance is
  # defined elsewhere; if both definitions are loaded, the later one masks the
  # other. Confirm whether this should be registered as a method instead.

  # Validate the requested metric up front so that a bad (or non-scalar) value
  # fails with a clear message before any per-model work is done.
  valid.methods = c("perplexity", "harmonic", "loglik")
  if (!is.character(method) || length(method) != 1 ||
      !(method %in% valid.methods)) {
    stop("Invalid method specified")
  }

  # TODO use celda_list getter
  log.likelihoods = lapply(celda.list$res.list,
                           function(mod) mod$completeLogLik)

  if (method == "loglik") {
    # Best (largest) log-likelihood reached by each model.
    metric = unlist(lapply(log.likelihoods, max))
  } else {
    # Both remaining metrics are computed per model by a helper and then
    # gathered into a single mpfr vector.
    calculate = if (method == "perplexity") {
      calculate_perplexity
    } else {
      calculate_marginal_likelihood
    }
    metric = new("mpfr", unlist(lapply(log.likelihoods, calculate)))
  }

  #TODO ggplot table building below is vulnerable to error
  # if the user modifies the celda_list at all..
  if (!(celda.list$content.type == "celda_C")) {
    # Only celda_C results can be plotted so far; say so instead of silently
    # returning nothing.
    warning("visualize_performance: plotting is only implemented for ",
            "celda_C results; returning NULL", call. = FALSE)
    return(invisible(NULL))
  }

  Ks = unlist(lapply(celda.list$res.list, function(mod) mod$K))
  plot.df = data.frame(K=as.factor(Ks), metric=as.numeric(metric))

  # Every ggplot2 function is namespace-qualified so the plot can be built
  # without ggplot2 being attached.
  ggplot2::ggplot(plot.df, ggplot2::aes(x=K, y=metric)) +
    ggplot2::geom_point() +
    ggplot2::xlab("K") +
    ggplot2::ylab(method)
}
|
64 |
+ |
|
65 |
+ |
|
66 |
+ |
|
67 |
+ |
|
68 |
+ |