Browse code

Make celda_C a usable model from celda(). Fixes #5.

Josh optimized celda_C in the same fashion as celda_CG, as well as
rearranged things so it can be toggled as the desired model from
the celda() 'frontend' function.

Sean Corbett authored on 07/04/2017 03:56:23
Showing 2 changed files

... ...
@@ -19,15 +19,13 @@ Imports:
19 19
     grDevices,
20 20
     graphics
21 21
 Suggests:
22
-	testthat,
23
-	knitr,
24
-	roxygen2,
22
+    testthat,
23
+    knitr,
24
+    roxygen2,
25 25
     rmarkdown
26
-VignetteBuilder:
27
-	knitr
26
+VignetteBuilder: knitr
28 27
 License: MIT
29 28
 Encoding: UTF-8
30 29
 LazyData: true
31
-RoxygenNote: 6.0.1
32
-BugReports:
33
-	https://github.com/definitelysean/celda/issues
30
+RoxygenNote: 5.0.1
31
+BugReports: https://github.com/definitelysean/celda/issues
34 32
old mode 100644
35 33
new mode 100755
... ...
@@ -1,32 +1,46 @@
1
-calcGibbsProb = function(ix, r, s, z, k, a, b) {
2
-  phi <- c()
3
-  for(j in 1:k) {
4
-  z[ix] <- j
5
-  n <- t(rowsum(t(r), group=z, reorder=TRUE))
6
-  phi <- c(phi, ll.phi.gibbs(n, b))    
7
-  }  
8
-  
9
-  z <- factor(z, levels=1:k)
10
-  m <- table(z[-ix], s[-ix])  
11
-  s.ix <- colnames(m) == s[ix]
12
-  
13
-  final <- log(m[,s.ix] + a) + phi
14
-  return(final)
15
-}
16
-
17
-
18
-generateCells = function(S=10, C.Range=c(10, 100), N.Range=c(100,5000), 
19
-                         G=5000, k=5, a=1, b=0.1) {
20
-  
21
-  phi <- gtools::rdirichlet(k, rep(b, G))
22
-  theta <- gtools::rdirichlet(S, rep(a, k))
1
+# -----------------------------------
2
+# Variable description
3
+# -----------------------------------
4
+# C = Cell
5
+# S or s = Sample
6
+# G = Gene
7
+# CP = Cell population
8
+# n = counts of transcripts
9
+# m = counts of cells
10
+# K = Total number of cell populations
11
+# nM = Number of cells
12
+# nG = Number of genes
13
+# nS = Number of samples
14
+
15
+# -----------------------------------
16
+# Count matrices descriptions
17
+# -----------------------------------
18
+
19
+# All n.* variables contain counts of transcripts
20
+# n.CP.by.TS = Number of counts in each Cellular Population per Transcriptional State
21
+# n.TS.by.C = Number of counts in each Transcriptional State per Cell 
22
+# n.CP.by.G = Number of counts in each Cellular Population per Gene
23
+# n.by.G = Number of counts per gene (i.e. rowSums)
24
+# n.by.TS = Number of counts per Transcriptional State
25
+
26
+## All m.* variables contain counts of cells
27
+# m.CP.by.S = Number of cells in each Cellular Population per Sample
28
+
29
+# nG.by.TS = Number of genes in each Transcriptional State
30
+
31
+#' @export
32
+simulateCells.celda_C = function(S=10, C.Range=c(10, 100), N.Range=c(100,5000), 
33
+                         G=500, K=5, alpha=1, beta=1) {
34
+  
35
+  phi <- gtools::rdirichlet(K, rep(beta, G))
36
+  theta <- gtools::rdirichlet(S, rep(alpha, K))
23 37
   
24 38
   ## Select the number of cells per sample
25 39
   nC <- sample(C.Range[1]:C.Range[2], size=S, replace=TRUE)  
26 40
   cell.sample <- rep(1:S, nC)
27 41
   
28 42
   ## Select state of the cells  
29
-  cell.state <- unlist(lapply(1:S, function(i) sample(1:k, size=nC[i], prob=theta[i,], replace=TRUE)))
43
+  cell.state <- unlist(lapply(1:S, function(i) sample(1:K, size=nC[i], prob=theta[i,], replace=TRUE)))
30 44
   
31 45
   ## Select number of transcripts per cell
32 46
   nN <- sample(N.Range[1]:N.Range[2], size=length(cell.sample), replace=TRUE)
... ...
@@ -34,80 +48,89 @@ generateCells = function(S=10, C.Range=c(10, 100), N.Range=c(100,5000),
34 48
   ## Select transcript distribution for each cell
35 49
   cell.counts <- sapply(1:length(cell.sample), function(i) rmultinom(1, size=nN[i], prob=phi[cell.state[i],]))
36 50
   
37
-  return(list(z=cell.state, counts=cell.counts, sample=cell.sample, k=k, a=a, b=b))
51
+  return(list(z=cell.state, counts=cell.counts, sample=cell.sample, K=K, alpha=alpha, beta=beta))
38 52
 }
39 53
 
40
-
41
-
42
-celda_C = function(counts, sample, k, a=1, b=0.1, max.iter=25, min.cell=5, 
43
-                   seed=12345, best=TRUE, kick=TRUE, converge=1e-5) {
54
+#' @export
55
+celda_C = function(counts, sample.label, K, alpha=1, beta=1, max.iter=25, min.cell=5, 
56
+                   seed=12345, best=TRUE, kick=TRUE) {
57
+  
58
+  if(is.factor(sample.label)) {
59
+    s = as.numeric(sample.label)
60
+  }
61
+  else {
62
+    s = as.numeric(as.factor(sample.label))
63
+  }  
44 64
   
45 65
   set.seed(seed)
46
-  require(entropy)
47 66
   cat(date(), "... Starting Gibbs sampling\n")
48 67
   
49
-  co = counts
50
-  s = sample
51
-  
52
-  z = sample(1:k, ncol(co), replace=TRUE)
68
+  z = sample(1:K, ncol(counts), replace=TRUE)
53 69
   z.all = z
54
-  ll = calcLL(counts=co, s=s, z=z, k=k, alpha=a, beta=b)
70
+  z.stability = c(NA)
71
+  z.probs = matrix(NA, nrow=ncol(counts), ncol=K)
55 72
   
56
-  z.probs = matrix(NA, nrow=ncol(co), ncol=k)
57
-    
73
+  ## Calculate counts one time up front
74
+  m.CP.by.S = table(factor(z, levels=1:K), s)
75
+  n.CP.by.G = rowsum(t(counts), group=z, reorder=TRUE)
76
+
77
+  ll = cC.calcLL(m.CP.by.S=m.CP.by.S, n.CP.by.G=n.CP.by.G, s=s, K=K, alpha=alpha, beta=beta)
78
+
58 79
   iter = 1
59 80
   continue = TRUE
60 81
   while(iter <= max.iter & continue == TRUE) {
61 82
     
62
-    ## Determine if any clusters are below the minimum threshold 
63
-    ## and if a kick needs to be performed
64
-    z.ta = table(factor(z, levels=1:k))
65
-    if(min(z.ta) < min.cell & kick==TRUE) {
66 83
 
67
-      all.k.to.kick = which(z.ta < min.cell)
68
-      
69
-      for(j in all.k.to.kick) { 
70
-        all.k.to.test = which(z.ta > 2*min.cell)
71
-        z = kick.z(co, s=s, z=z, k=k, k.to.kick=j, k.to.test=all.k.to.test, 
72
-                   min.cell=min.cell, a=a, b=b)
73
-        z.ta = table(factor(z, levels=1:k))
74
-      }
75
-      
76
-    }
77
-    
78 84
     ## Begin process of Gibbs sampling for each cell
79
-    ix = sample(1:ncol(co))
85
+    ix = sample(1:ncol(counts))
80 86
     for(i in ix) {
81
-      probs = calcGibbsProb(i, r=co, s=s, z=z, k=k, a=a, b=b)
82 87
       
83
-      z[i] = sample.ll(probs)
84
-      z.probs[i,] = probs
85
-    }
88
+      if(sum(z == z[i]) > 1) {
89
+        
90
+        ## Subtract current cell counts from matrices
91
+        m.CP.by.S[z[i],s[i]] = m.CP.by.S[z[i],s[i]] - 1
92
+        n.CP.by.G[z[i],] = n.CP.by.G[z[i],] - counts[,i]
93
+        
94
+        ## Calculate probabilities for each state
96
+        probs = rep(NA, K)
97
+        for(j in 1:K) {
98
+          temp.n.CP.by.G = n.CP.by.G
99
+          temp.n.CP.by.G[j,] = temp.n.CP.by.G[j,] + counts[,i]
100
+          probs[j] = cC.calcGibbsProbZ(m.CP.by.S=m.CP.by.S[j,s[i]], n.CP.by.G=temp.n.CP.by.G, alpha=alpha, beta=beta)
101
+        }  
102
+
103
+        ## Sample next state and add back counts
104
+        z[i] = sample.ll(probs)
105
+        m.CP.by.S[z[i],s[i]] = m.CP.by.S[z[i],s[i]] + 1
106
+        n.CP.by.G[z[i],] = n.CP.by.G[z[i],] + counts[,i]
107
+      
108
+      } else {
109
+        probs = rep(0, K)
110
+        probs[z[i]] = 1
111
+      }
112
+    }  
113
+    #z.probs[i,] = probs
86 114
 
87
-    ## Save Z history
115
+    ## Save history
88 116
     z.all = cbind(z.all, z)
89
-    
117
+
118
+    ## Normalize Z and Y marginal probabilties and calculate stability
119
+    z.probs = normalizeLogProbs(z.probs)
120
+    z.stability = c(z.stability, stability(z.probs))
121
+
90 122
     ## Calculate complete likelihood
91
-    temp.ll = calcLL(counts=co, s=s, z=z, k=k, alpha=a, beta=b)
123
+    temp.ll = cC.calcLL(m.CP.by.S=m.CP.by.S, n.CP.by.G=n.CP.by.G, s=s, K=K, alpha=alpha, beta=beta)
124
+    if((best == TRUE & all(temp.ll > ll)) | iter == 1) {
125
+      z.probs.final = z.probs
126
+    }
92 127
     ll = c(ll, temp.ll)
93
-
94
-    cat(date(), "... Completed iteration:", iter, "| logLik:", temp.ll, "\n")
95 128
     
96
-    ## Normalize Z probabilties and test for convergence
97
-    z.probs = exp(sweep(z.probs, 1, apply(z.probs, 1, max), "-"))
98
-    z.probs = sweep(z.probs, 1, rowSums(z.probs), "/")
99
-    f = function(v) sort(v, decreasing=TRUE)[2]
100
-    z.probs.second = max(apply(z.probs, 1, f))
101
-    z.ta = table(z)
102
-    if (z.probs.second < converge & (min(z.ta) >= min.cell | kick==FALSE)) {
103
-      continue = FALSE
104
-      cat("Maximum probability of a cell changing its state is ", z.probs.second, ". Exiting at iteration ", iter, ".", sep="")
105
-    }
129
+    cat(date(), "... Completed iteration:", iter, "| logLik:", temp.ll, "\n")
106 130
     
107 131
     iter = iter + 1    
108 132
   }
109 133
   
110
-  
111 134
   if (best == TRUE) {
112 135
     ix = which.max(ll)
113 136
     z.final = z.all[,ix]
... ...
@@ -122,40 +145,43 @@ celda_C = function(counts, sample, k, a=1, b=0.1, max.iter=25, min.cell=5,
122 145
 }
123 146
 
124 147
 
125
-
126
-ll.phi.gibbs = function(n, beta) {
127
-  ng = nrow(n)
128
-  nk = ncol(n)
148
+cC.calcGibbsProbZ = function(m.CP.by.S, n.CP.by.G, alpha, beta) {
149
+  
150
+  ## Calculate for "Theta" component
151
+  theta.ll = log(m.CP.by.S + alpha)
152
+  
153
+  ## Calculate for "Phi" component
154
+  b = sum(lgamma(n.CP.by.G + beta))
155
+  d = -sum(lgamma(rowSums(n.CP.by.G + beta)))
129 156
   
130
-  b = sum(lgamma(n+beta))
131
-  d = -sum(lgamma(colSums(n + beta)))
157
+  phi.ll = b + d
132 158
   
133
-  ll = b + d
134
-  return(ll)
159
+  final = theta.ll + phi.ll 
160
+  return(final)
135 161
 }
136 162
 
137
-
138
-calcLL = function(counts, s, z, k, alpha, beta) {
139
-  m = table(z, s)
140
-  nk = nrow(m)
141
-  ns = ncol(m)
163
+#' @export
164
+cC.calcLLFromVariables = function(counts, s, z, K, alpha, beta) {
142 165
   
143
-  a = ns*lgamma(nk*alpha)
144
-  b = sum(lgamma(m+alpha))
145
-  c = -ns*nk*lgamma(alpha)
146
-  d = -sum(lgamma(apply(m + alpha, 2, sum)))
166
+  ## Calculate for "Theta" component
167
+  m.CP.by.S = table(z, s)
168
+  nS = length(unique(s))
169
+  
170
+  a = nS * lgamma(K*alpha)
171
+  b = sum(lgamma(m.CP.by.S + alpha))
172
+  c = -nS * K * lgamma(alpha)
173
+  d = -sum(lgamma(colSums(m.CP.by.S + alpha)))
147 174
   
148 175
   theta.ll = a + b + c + d
149 176
  
150
- 
151
-  n = sapply(1:k, function(i) apply(counts[,z == i,drop=FALSE], 1, sum))
152
-  ng = nrow(n)
153
-  nk = ncol(n)
177
+  ## Calculate for "Phi" component
178
+  n.CP.by.G = rowsum(t(counts), group=z, reorder=TRUE)
179
+  nG = ncol(n.CP.by.G)
154 180
   
155
-  a = nk*lgamma(ng*beta)
156
-  b = sum(lgamma(n+beta))
157
-  c = -nk*ng*lgamma(beta)
158
-  d = -sum(lgamma(apply(n + beta, 2, sum)))
181
+  a = K * lgamma(nG * beta)
182
+  b = sum(lgamma(n.CP.by.G + beta))
183
+  c = -K * nG * lgamma(beta)
184
+  d = -sum(lgamma(rowSums(n.CP.by.G + beta)))
159 185
   
160 186
   phi.ll = a + b + c + d
161 187
 
... ...
@@ -163,71 +189,29 @@ calcLL = function(counts, s, z, k, alpha, beta) {
163 189
   return(final)
164 190
 }
165 191
 
166
-
167
-
168
-
169
-cosineDist <- function(x){
170
-  x = t(x)
171
-  y = as.dist(1 - x%*%t(x)/(sqrt(rowSums(x^2) %*% t(rowSums(x^2))))) 
172
-  return(y)
173
-}
174
-
175
-kick.z = function(counts, s, z, k, k.to.kick, k.to.test, min.cell=5, a, b) {
176
-  require(cluster)
177
-  cat(date(), "... Cluster", k.to.kick, "has fewer than", min.cell, "cells. Performing kick by ")
192
+cC.calcLL = function(m.CP.by.S, n.CP.by.G, s, z, K, alpha, beta) {
178 193
   
179
-  counts.norm = sweep(counts, 2, colSums(counts), "/")
180
-  z.kick = matrix(z, ncol=length(k.to.test), nrow=length(z))
194
+  ## Calculate for "Theta" component
195
+  nS = length(unique(s))
181 196
   
182
-  ## Randomly assign clusters to cells with cluster to kick
183
-  z.k.to.kick = sample(1:k, size=sum(z == k.to.kick), replace=TRUE)
184
-  z.kick[z==k.to.kick,] = z.k.to.kick
197
+  a = nS * lgamma(K * alpha)
198
+  b = sum(lgamma(m.CP.by.S + alpha))
199
+  c = -nS * K * lgamma(alpha)
200
+  d = -sum(lgamma(colSums(m.CP.by.S + alpha)))
185 201
   
186
-  ## Loop through each cluster, split, and determine logLik
187
-  k.kick.ll = rep(NA, length(k.to.test))
188
-  for(i in 1:length(k.to.test)) {
189
-    k.dist = cosineDist(counts.norm[,z==k.to.test[i]])/2
190
-    k.pam = pam(x=k.dist, k=2)$clustering
191
-    
192
-    ## If PAM split is too small, perform secondary hclust procedure to split into equal groups
193
-    if(min(table(k.pam)) < min.cell) {
194
-      k.hc = hclust(k.dist, method="ward.D")
195
-      
196
-      ## Get maximum sample size of each subcluster
197
-      k.hc.size = sapply(1:length(k.hc$height), function(i) max(table(cutree(k.hc, h=k.hc$height[i]))))
198
-      
199
-      ## Find the height of the dendrogram that best splits the samples in half
200
-      sample.size = round(length(k.hc$order)/ 2)
201
-      k.hc.select = which.min(abs(k.hc.size - sample.size))
202
-      k.hc.cut = cutree(k.hc, h=k.hc$height[k.hc.select])
203
-      k.hc.cluster = which.max(table(k.hc.cut))
204
-      
205
-      k.hc.final = ifelse(k.hc.cut == k.hc.cluster, k.to.test[i], k.to.kick)
206
-      
207
-      ix = (z == k.to.test[i])
208
-      z.kick[ix,i] = k.hc.final
209
-      
210
-    } else {
211
-      
212
-      k.pam.final = ifelse(k.pam == 1, k.to.test[i], k.to.kick)
213
-      ix = (z == k.to.test[i])
214
-      z.kick[ix,i] = k.pam.final
215
-      
216
-    }
217
-    k.kick.ll[i] = calcLL(counts=counts, s=s, z=z.kick[,i], k=k, alpha=a, beta=b)
218
-  }
219
-
220
-  k.to.test.select = sample.ll(k.kick.ll)
202
+  theta.ll = a + b + c + d
221 203
   
222
-  cat("splitting Cluster", k.to.test[k.to.test.select], "\n")
223
-  return(z.kick[,k.to.test.select])
224
-}
225
-
226
-
227
-sample.ll = function(ll.probs) {
228
-  probs.sub = exp(ll.probs - max(ll.probs))
229
-  probs.norm = probs.sub / sum(probs.sub)
230
-  probs.select = sample(1:length(ll.probs), size=1, prob=probs.norm)
231
-  return(probs.select)
204
+  ## Calculate for "Phi" component
205
+  nG = ncol(n.CP.by.G)
206
+  
207
+  a = K * lgamma(nG * beta)
208
+  b = sum(lgamma(n.CP.by.G + beta))
209
+  c = -K * nG * lgamma(beta)
210
+  d = -sum(lgamma(rowSums(n.CP.by.G + beta)))
211
+  
212
+  phi.ll = a + b + c + d
213
+  
214
+  final = theta.ll + phi.ll
215
+  return(final)
232 216
 }
233 217