Browse code

Remove marginal likelihood (harmonic mean estimator) performance metric. Fixes #139.

Former-commit-id: 800d3c83f3537457349b47f68ece7cf9cdc842ec

Sean Corbett authored on 19/09/2017 17:51:46
Showing 15 changed files

... ...
@@ -29,7 +29,8 @@ Imports:
29 29
     cluster,
30 30
     digest,
31 31
     gridExtra,
32
-    methods
32
+    methods,
33
+    lintr
33 34
 Suggests:
34 35
     testthat,
35 36
     knitr,
... ...
@@ -26,7 +26,6 @@ export(calculate_loglik_from_variables)
26 26
 export(calculate_loglik_from_variables.celda_C)
27 27
 export(calculate_loglik_from_variables.celda_CG)
28 28
 export(calculate_loglik_from_variables.celda_G)
29
-export(calculate_marginal_likelihood)
30 29
 export(calculate_perplexity)
31 30
 export(celda)
32 31
 export(celda_C)
... ...
@@ -405,7 +405,7 @@ celda_heatmap.celda_C = function(celda.mod, counts, ...) {
405 405
 
406 406
 #' visualize_model_performance for celda Cell clustering function
407 407
 #' @param celda.list A celda_list object returned from celda()
408
-#' @param method One of "perplexity", "harmonic", or "loglik"
408
+#' @param method One of "perplexity", "loglik"
409 409
 #' @param title Title for the plot
410 410
 #' @param log Currently not working for celda_C objects
411 411
 #' @import Rmpfr
... ...
@@ -423,7 +423,7 @@ visualize_model_performance.celda_C = function(celda.list, method="perplexity",
423 423
   
424 424
   # These methods return Rmpfr numbers that are extremely small and can't be 
425 425
   # plotted, so log 'em first
426
-  if (method %in% c("perplexity", "harmonic")) {
426
+  if (method %in% c("perplexity")) {
427 427
     performance.metric = lapply(performance.metric, log)
428 428
     performance.metric = methods::new("mpfr", unlist(performance.metric))
429 429
     performance.metric = as.numeric(performance.metric)
... ...
@@ -669,9 +669,9 @@ celda_heatmap.celda_CG = function(celda.mod, counts, ...) {
669 669
 #' L (cell clusters K), plot the performance of each number of cell
670 670
 #' clusters K (gene clusters L).
671 671
 #' @param celda.list A list of celda_CG objects returned from celda function
672
-#' @param method One of "perplexity", "harmonic", or "loglik"
672
+#' @param method One of "perplexity" or "loglik"
673 673
 #' @param title Title for the visualize_model_performance
674
-#' @param log Set log to TRUE to visualize the log(perplexity) of Celda_CG objects. Does not work for "harmonic" metric
674
+#' @param log Set log to TRUE to visualize the log(perplexity) of Celda_CG objects.
675 675
 #' @import Rmpfr
676 676
 #' @export
677 677
 visualize_model_performance.celda_CG = function(celda.list, method="perplexity",
... ...
@@ -696,11 +696,7 @@ visualize_model_performance.celda_CG = function(celda.list, method="perplexity",
696 696
       performance.metric = new("mpfr", unlist(performance.metric))
697 697
     }
698 698
     y.lab = paste0("Log(",method,")")
699
-  } else if (method == "harmonic") {
700
-    performance.metric = lapply(performance.metric, log)
701
-    performance.metric = new("mpfr", unlist(performance.metric))
702
-    y.lab = paste0("Log(",method,")")
703
-  }
699
+  } 
704 700
   
705 701
   performance.metric = as.numeric(performance.metric)
706 702
 
... ...
@@ -761,7 +757,7 @@ visualize_model_performance.celda_CG = function(celda.list, method="perplexity",
761 757
 #' run for each combination of K/L (cell/gene).
762 758
 #' 
763 759
 #' @param celda.list A list of celda_CG objects returned from celda function
764
-#' @param method One of "perplexity", "harmonic", or "loglik", passed through to calculate_performance_metric()
760
+#' @param method One of "perplexity" or "loglik", passed through to calculate_performance_metric()
765 761
 #' @param title The plot title
766 762
 #' @import Rmpfr 
767 763
 #' @export
... ...
@@ -783,7 +779,7 @@ render_interactive_kl_plot = function(celda.list,  method="perplexity",
783 779
   # The performance metric methods return Rmpfr numbers that are extremely small and can't be 
784 780
   # plotted via ggplot2, so log 'em first. 
785 781
   # TODO: celda_list getter that calculates these metrics.
786
-  if (method %in% c("perplexity", "harmonic")) {
782
+  if (method %in% c("perplexity")) {
787 783
     performance.metric = lapply(performance.metric, log)
788 784
     performance.metric = new("mpfr", unlist(performance.metric))
789 785
     performance.metric = as.numeric(performance.metric)
... ...
@@ -803,7 +799,7 @@ render_interactive_kl_plot = function(celda.list,  method="perplexity",
803 799
   
804 800
   # TODO: Return plot or nah?
805 801
   method.label = method 
806
-  if (method %in% c("perplexity, harmonic")) {
802
+  if (method %in% c("perplexity")) {
807 803
      method.label = paste("log(", method, ")", sep="")
808 804
   }
809 805
   k.l.plot = ggplot2::ggplot(figure.df, aes(x=key, y=metric, label=key)) +
... ...
@@ -824,7 +820,7 @@ validate_kl_plot_parameters = function(celda.list, method) {
824 820
     stop("celda.list argument must be of class 'celda_list'")
825 821
  } else if (celda.list$content.type != "celda_CG") {
826 822
     stop("celda.list must be a 'celda.list' of 'celda_CG' objects")
827
- } else if (!(method %in% c("perplexity","harmonic","loglik"))) {
828
-    stop("Invalid method, 'method' has to be either 'perplexity', 'harmonic', or 'loglik'")
823
+ } else if (!(method %in% c("perplexity", "loglik"))) {
824
+    stop("Invalid method, 'method' has to be either 'perplexity' or 'loglik'")
829 825
  } 
830 826
 }
... ...
@@ -481,7 +481,7 @@ celda_heatmap.celda_G = function(celda.mod, counts, ...) {
481 481
 # TODO DRYer implementation in concert with celda_C
482 482
 #' visualize_model_performance for the celda Gene function
483 483
 #' @param celda.list A celda_list object returned from celda()
484
-#' @param method One of "perplexity", "harmonic", or "loglik"
484
+#' @param method One of "perplexity" or "loglik"
485 485
 #' @param title Title for the plot
486 486
 #' @param log Currently not working for celda.G objects
487 487
 #' @import Rmpfr
... ...
@@ -499,7 +499,7 @@ visualize_model_performance.celda_G = function(celda.list, method="perplexity",
499 499
   
500 500
   # These methods return Rmpfr numbers that are extremely small and can't be 
501 501
   # plotted, so log 'em first
502
-  if (method %in% c("perplexity", "harmonic")) {
502
+  if (method %in% c("perplexity")) {
503 503
     performance.metric = lapply(performance.metric, log)
504 504
     performance.metric = methods::new("mpfr", unlist(performance.metric))
505 505
     performance.metric = as.numeric(performance.metric)
... ...
@@ -10,7 +10,7 @@
10 10
 #' @param K The K parameter for the desired model in the results list
11 11
 #' @param L The L parameter for the desired model in the results list
12 12
 #' @param chain The desired chain for the specified model
13
-#' @param best Method for choosing best chain automatically. Options are c("perplexity", "harmonic", "loglik"). See documentation for chooseBestModel for details. Overrides chain parameter if provided.
13
+#' @param best Method for choosing best chain automatically. Options are c("perplexity", "loglik"). See documentation for chooseBestModel for details. Overrides chain parameter if provided.
14 14
 #' @return A celda model object matching the provided parameters (of class "celda_C", "celda_G", "celda_CG" accordingly), or NA if one is not found.
15 15
 #' @export
16 16
 getModel = function(celda.list, K=NULL, L=NULL, chain=1, best=NULL) {
... ...
@@ -90,7 +90,7 @@ validate_get_model_params = function(celda.list, K, L, chain, best) {
90 90
 #' Determine the best chain among a set of celda_* objects with
91 91
 #' otherwise uniform K/L choices.
92 92
 #' @param celda.mods A list of celda class objects (celda_C, celda_CG, celda_G)
93
-#' @param method How to choose the best chain. Choices are c("perplexity", "harmonic", "loglik"). Defaults to perplexity. "perplexity" calculates each chain's perplexity as the inverse of the geometric mean, per the original LDA description. "harmonic" calculates each chain's marginal likelihood as the harmonic mean of each iteration of Gibbs sampling's log likelihoods. "loglik" chooses the chain which reached the maximal log likelihood during Gibbs sampling.
93
+#' @param method How to choose the best chain. Choices are c("perplexity", "loglik"). Defaults to perplexity. "perplexity" calculates each chain's perplexity as the inverse of the geometric mean, per the original LDA description. "loglik" chooses the chain which reached the maximal log likelihood during Gibbs sampling.
94 94
 chooseBestChain = function(celda.mods, method="perplexity") {
95 95
   # We want to get the *most negative* perplexity, as opposed to the *least* negative
96 96
   # for the other metrics...
... ...
@@ -101,10 +101,6 @@ chooseBestChain = function(celda.mods, method="perplexity") {
101 101
     return(celda.mods[[best]])
102 102
   } 
103 103
   
104
-  else if (method == "harmonic"){
105
-    metrics = lapply(celda.mods, function(mod) { calculate_perplexity(mod$completeLogLik) })
106
-    metrics = new("mpfr", unlist(metrics))
107
-   }
108 104
   else if (method == "loglik"){
109 105
     metrics = lapply(celda.mods, function(mod) { max(mod$completeLogLik) })
110 106
     metrics = unlist(metrics)
... ...
@@ -131,4 +127,4 @@ search_res_list = function(celda_list, K=NULL, L=NULL) {
131 127
     stop("K/L parameter(s) requested did not appear for any model in the celda_list. Did you modify the run.params?")
132 128
   }
133 129
   return(requested.chain)
134
-}
135 130
\ No newline at end of file
131
+}
... ...
@@ -1,19 +1,3 @@
1
-#' Calculate the marginal likelihood from a single celda model
2
-#' 
3
-#' Marginal likelihood is estimated as the harmonic mean of the 
4
-#' (non-log) likelihood over all iterations of Gibbs sampling.
5
-#' 
6
-#' @param completeLogLik The complete Gibbs sampling history of log-likelihoods for a single celda chain
7
-#' @return The estimated marginal likelihood as an mpfr number
8
-#' @export
9
-calculate_marginal_likelihood = function(completeLogLik) {
10
-  mpfr_log_lik = Rmpfr::mpfr(completeLogLik, 512)
11
-  complete_likelihood = exp(mpfr_log_lik)
12
-  marginal_likelihood = (Rmpfr::mean((1/complete_likelihood)))^-1
13
-  return(marginal_likelihood)
14
-}
15
-
16
-
17 1
 #' Calculate the perplexity from a single celda chain
18 2
 #' 
19 3
 #' Perplexity is defined as the inverse of the geometric mean of the log-likelihoods over all 
... ...
@@ -34,10 +18,8 @@ calculate_perplexity = function(completeLogLik, log = FALSE) {
34 18
 
35 19
 # Convenience function to calculate performance metrics by specifying a method. 
36 20
 calculate_performance_metric = function(log.likelihoods, method="perplexity", log = FALSE) {
37
-    if (method == "perplexity") {
21
+  if (method == "perplexity") {
38 22
     metric = calculate_perplexity(log.likelihoods, log)
39
-  } else if (method == "harmonic") {
40
-    metric = calculate_marginal_likelihood(log.likelihoods)
41 23
   } else if (method == "loglik") {
42 24
      metric = max(log.likelihoods)
43 25
   } else stop("Invalid method specified")
44 26
deleted file mode 100644
... ...
@@ -1,18 +0,0 @@
1
-% Generated by roxygen2: do not edit by hand
2
-% Please edit documentation in R/model_performance.R
3
-\name{calculate_marginal_likelihood}
4
-\alias{calculate_marginal_likelihood}
5
-\title{Calculate the marginal likelihood from a single celda model}
6
-\usage{
7
-calculate_marginal_likelihood(completeLogLik)
8
-}
9
-\arguments{
10
-\item{completeLogLik}{The complete Gibbs sampling history of log-likelihoods for a single celda chain}
11
-}
12
-\value{
13
-The estimated marginal likelihood as an mpfr number
14
-}
15
-\description{
16
-Marginal likelihood is estimated as the harmonic mean of the 
17
-(non-log) likelihood over all iterations of Gibbs sampling.
18
-}
... ...
@@ -10,7 +10,7 @@ chooseBestChain(celda.mods, method = "perplexity")
10 10
 \arguments{
11 11
 \item{celda.mods}{A list of celda class objects (celda_C, celda_CG, celda_G)}
12 12
 
13
-\item{method}{How to choose the best chain. Choices are c("perplexity", "harmonic", "loglik"). Defaults to perplexity. "perplexity" calculates each chain's perplexity as the inverse of the geometric mean, per the original LDA description. "harmonic" calculates each chain's marginal likelihood as the harmonic mean of each iteration of Gibbs sampling's log likelihoods. "loglik" chooses the chain which reached the maximal log likelihood during Gibbs sampling.}
13
+\item{method}{How to choose the best chain. Choices are c("perplexity", "loglik"). Defaults to perplexity. "perplexity" calculates each chain's perplexity as the inverse of the geometric mean, per the original LDA description. "loglik" chooses the chain which reached the maximal log likelihood during Gibbs sampling.}
14 14
 }
15 15
 \description{
16 16
 Determine the best chain among a set of celda_* objects with
... ...
@@ -17,7 +17,7 @@ getModel(celda.list, K = NULL, L = NULL, chain = 1, best = NULL)
17 17
 
18 18
 \item{chain}{The desired chain for the specified model}
19 19
 
20
-\item{best}{Method for choosing best chain automatically. Options are c("perplexity", "harmonic", "loglik"). See documentation for chooseBestModel for details. Overrides chain parameter if provided.}
20
+\item{best}{Method for choosing best chain automatically. Options are c("perplexity", "loglik"). See documentation for chooseBestModel for details. Overrides chain parameter if provided.}
21 21
 }
22 22
 \value{
23 23
 A celda model object matching the provided parameters (of class "celda_C", "celda_G", "celda_CG" accordingly), or NA if one is not found.
... ...
@@ -10,7 +10,7 @@ render_interactive_kl_plot(celda.list, method = "perplexity",
10 10
 \arguments{
11 11
 \item{celda.list}{A list of celda_CG objects returned from celda function}
12 12
 
13
-\item{method}{One of "perplexity", "harmonic", or "loglik", passed through to calculate_performance_metric()}
13
+\item{method}{One of "perplexity" or "loglik", passed through to calculate_performance_metric()}
14 14
 
15 15
 \item{title}{The plot title}
16 16
 }
... ...
@@ -7,6 +7,8 @@
7 7
 simulateCells(model, ...)
8 8
 }
9 9
 \arguments{
10
+\item{model}{The celda generative model to use (one of celda_C, celda_G, celda_CG)}
11
+
10 12
 \item{S}{Total number of samples (celda_C, celda_CG)}
11 13
 
12 14
 \item{C}{The number of cells (celda_G)}
... ...
@@ -11,7 +11,7 @@
11 11
 \arguments{
12 12
 \item{celda.list}{A celda_list object returned from celda()}
13 13
 
14
-\item{method}{One of "perplexity", "harmonic", or "loglik"}
14
+\item{method}{One of "perplexity", "loglik"}
15 15
 
16 16
 \item{title}{Title for the plot}
17 17
 
... ...
@@ -11,11 +11,11 @@
11 11
 \arguments{
12 12
 \item{celda.list}{A list of celda_CG objects returned from celda function}
13 13
 
14
-\item{method}{One of "perplexity", "harmonic", or "loglik"}
14
+\item{method}{One of "perplexity" or "loglik"}
15 15
 
16 16
 \item{title}{Title for the visualize_model_performance}
17 17
 
18
-\item{log}{Set log to TRUE to visualize the log(perplexity) of Celda_CG objects. Does not work for "harmonic" metric}
18
+\item{log}{Set log to TRUE to visualize the log(perplexity) of Celda_CG objects.}
19 19
 }
20 20
 \description{
21 21
 Plot the performance of a list of celda_CG models returned 
... ...
@@ -11,7 +11,7 @@
11 11
 \arguments{
12 12
 \item{celda.list}{A celda_list object returned from celda()}
13 13
 
14
-\item{method}{One of "perplexity", "harmonic", or "loglik"}
14
+\item{method}{One of "perplexity" or "loglik"}
15 15
 
16 16
 \item{title}{Title for the plot}
17 17