18 | 18 |
new file mode 100644 |
... | ... |
@@ -0,0 +1,529 @@ |
1 |
+#include "GibbsSampler.h" |
|
2 |
+#include "Algorithms.h" |
|
3 |
+ |
|
4 |
+#include <Rcpp.h> |
|
5 |
+ |
|
6 |
+static const float EPSILON = 1.e-10; |
|
7 |
+ |
|
8 |
+GibbsSampler::GibbsSampler(const Rcpp::NumericMatrix &D, |
|
9 |
+const Rcpp::NumericMatrix &S, unsigned nFactor) |
|
10 |
+ : |
|
11 |
+mDMatrix(D), mSMatrix(S), mAPMatrix(D.nrow(), D.ncol()), |
|
12 |
+mAMatrix(D.nrow(), nFactor), mPMatrix(nFactor, D.ncol()), |
|
13 |
+mADomain('A', D.nrow(), nFactor), mPDomain('P', nFactor, D.ncol()), |
|
14 |
+mAMeanMatrix(D.nrow(), nFactor), mAStdMatrix(D.nrow(), nFactor), |
|
15 |
+mPMeanMatrix(nFactor, D.ncol()), mPStdMatrix(nFactor, D.ncol()), |
|
16 |
+mPumpMatrix(D.nrow(), nFactor) |
|
17 |
+{} |
|
18 |
+ |
|
19 |
+GibbsSampler::GibbsSampler(const Rcpp::NumericMatrix &D, |
|
20 |
+const Rcpp::NumericMatrix &S, unsigned nFactor, float alphaA, float alphaP, |
|
21 |
+float maxGibbmassA, float maxGibbmassP, bool singleCellRNASeq, |
|
22 |
+char whichMatrixFixed, const Rcpp::NumericMatrix &FP, PumpThreshold pumpThreshold) |
|
23 |
+ : |
|
24 |
+mDMatrix(D), mSMatrix(S), mAPMatrix(D.nrow(), D.ncol()), |
|
25 |
+mAMatrix(D.nrow(), nFactor), mPMatrix(nFactor, D.ncol()), |
|
26 |
+mADomain('A', D.nrow(), nFactor), mPDomain('P', nFactor, D.ncol()), |
|
27 |
+mAMeanMatrix(D.nrow(), nFactor), mAStdMatrix(D.nrow(), nFactor), |
|
28 |
+mPMeanMatrix(nFactor, D.ncol()), mPStdMatrix(nFactor, D.ncol()), |
|
29 |
+mPumpMatrix(D.nrow(), nFactor), mPumpThreshold(pumpThreshold), mStatUpdates(0), |
|
30 |
+mPumpStatUpdates(0), mMaxGibbsMassA(maxGibbmassA), mMaxGibbsMassP(maxGibbmassP), |
|
31 |
+mAnnealingTemp(1.0), mSingleCellRNASeq(singleCellRNASeq), mNumFixedPatterns(0), |
|
32 |
+mFixedMat(whichMatrixFixed) |
|
33 |
+{ |
|
34 |
+ float meanD = mSingleCellRNASeq ? gaps::algo::nonZeroMean(mDMatrix) |
|
35 |
+ : gaps::algo::mean(mDMatrix); |
|
36 |
+ |
|
37 |
+ mADomain.setAlpha(alphaA); |
|
38 |
+ mADomain.setLambda(alphaA * std::sqrt(nFactor / meanD)); |
|
39 |
+ mPDomain.setAlpha(alphaP); |
|
40 |
+ mPDomain.setLambda(alphaP * std::sqrt(nFactor / meanD)); |
|
41 |
+ |
|
42 |
+ mMaxGibbsMassA /= mADomain.lambda(); |
|
43 |
+ mMaxGibbsMassP /= mPDomain.lambda(); |
|
44 |
+ |
|
45 |
+ if (mFixedMat == 'A') |
|
46 |
+ { |
|
47 |
+ mNumFixedPatterns = FP.ncol(); |
|
48 |
+ ColMatrix temp(FP); |
|
49 |
+ for (unsigned j = 0; j < mNumFixedPatterns; ++j) |
|
50 |
+ { |
|
51 |
+ mAMatrix.getCol(j) = temp.getCol(j) / gaps::algo::sum(temp.getCol(j)); |
|
52 |
+ } |
|
53 |
+ } |
|
54 |
+ else if (mFixedMat == 'P') |
|
55 |
+ { |
|
56 |
+ mNumFixedPatterns = FP.nrow(); |
|
57 |
+ RowMatrix temp(FP); |
|
58 |
+ for (unsigned i = 0; i < mNumFixedPatterns; ++i) |
|
59 |
+ { |
|
60 |
+ mPMatrix.getRow(i) = temp.getRow(i) / gaps::algo::sum(temp.getRow(i)); |
|
61 |
+ } |
|
62 |
+ } |
|
63 |
+ gaps::algo::matrixMultiplication(mAPMatrix, mAMatrix, mPMatrix); |
|
64 |
+} |
|
65 |
+ |
|
66 |
+float GibbsSampler::getGibbsMass(const MatrixChange &change) |
|
67 |
+{ |
|
68 |
+ // check if this change is death (only called in birth/death) |
|
69 |
+ bool death = change.delta1 < 0; |
|
70 |
+ |
|
71 |
+ // get s and su |
|
72 |
+ AlphaParameters alphaParam = gaps::algo::alphaParameters(change, mDMatrix, |
|
73 |
+ mSMatrix, mAMatrix, mPMatrix, mAPMatrix); |
|
74 |
+ |
|
75 |
+ // calculate mean and standard deviation |
|
76 |
+ alphaParam.s *= mAnnealingTemp / 2.0; |
|
77 |
+ alphaParam.su *= mAnnealingTemp / 2.0; |
|
78 |
+ float lambda = change.label == 'A' ? mADomain.lambda() : mPDomain.lambda(); |
|
79 |
+ float mean = (2.0 * alphaParam.su - lambda) / (2.0 * alphaParam.s); |
|
80 |
+ float sd = 1.0 / std::sqrt(2.0 * alphaParam.s); |
|
81 |
+ |
|
82 |
+ // note: is bounded below by zero so have to use inverse sampling! |
|
83 |
+ // based upon algorithm in DistScalarRmath.cc (scalarRandomSample) |
|
84 |
+ float plower = gaps::random::p_norm(0.f, mean, sd); |
|
85 |
+ |
|
86 |
+ // if the likelihood is flat and nonzero, sample strictly from the prior |
|
87 |
+ float newMass = 0.f; |
|
88 |
+ if (plower == 1.f || alphaParam.s < 0.00001f) |
|
89 |
+ { |
|
90 |
+ newMass = death ? std::abs(change.delta1) : 0.f; |
|
91 |
+ } |
|
92 |
+ else if (plower >= 0.99f) // what is this? |
|
93 |
+ { |
|
94 |
+ float tmp1 = gaps::random::d_norm(0.f, mean, sd); |
|
95 |
+ float tmp2 = gaps::random::d_norm(10.f * lambda, mean, sd); |
|
96 |
+ if (tmp1 > EPSILON && std::abs(tmp1 - tmp2) < EPSILON) |
|
97 |
+ { |
|
98 |
+ return death ? 0.0 : change.delta1; |
|
99 |
+ } |
|
100 |
+ } |
|
101 |
+ else |
|
102 |
+ { |
|
103 |
+ newMass = gaps::random::inverseNormSample(plower, 1.f, mean, sd); |
|
104 |
+ } |
|
105 |
+ |
|
106 |
+ newMass = (change.label == 'A' ? std::min(newMass, mMaxGibbsMassA) |
|
107 |
+ : std::min(newMass, mMaxGibbsMassP)); |
|
108 |
+ |
|
109 |
+ return std::max(newMass, 0.f); |
|
110 |
+} |
|
111 |
+ |
|
112 |
+float GibbsSampler::computeDeltaLL(const MatrixChange &change) |
|
113 |
+{ |
|
114 |
+ return gaps::algo::deltaLL(change, mDMatrix, mSMatrix, mAMatrix, |
|
115 |
+ mPMatrix, mAPMatrix); |
|
116 |
+} |
|
117 |
+ |
|
118 |
+void GibbsSampler::update(char matrixLabel, unsigned nUpdates) |
|
119 |
+{ |
|
120 |
+ AtomicSupport &domain(matrixLabel == 'A' ? mADomain : mPDomain); |
|
121 |
+ for (unsigned i = 0; i < nUpdates; ++i) |
|
122 |
+ { |
|
123 |
+ assert(nUpdates - i - domain.size() >= 0); |
|
124 |
+ domain.populateQueue(nUpdates - i - domain.size()); |
|
125 |
+ AtomicProposal proposal = domain.popQueue(); |
|
126 |
+ switch (proposal.type) |
|
127 |
+ { |
|
128 |
+ case 'D': death(domain, proposal); break; |
|
129 |
+ case 'B': birth(domain, proposal); break; |
|
130 |
+ case 'M': move(domain, proposal); break; |
|
131 |
+ case 'E': exchange(domain, proposal); break; |
|
132 |
+ } |
|
133 |
+ } |
|
134 |
+} |
|
135 |
+ |
|
136 |
+uint64_t GibbsSampler::totalNumAtoms(char matrixLabel) const |
|
137 |
+{ |
|
138 |
+ return matrixLabel == 'A' ? mADomain.numAtoms() : mPDomain.numAtoms(); |
|
139 |
+} |
|
140 |
+ |
|
141 |
+float GibbsSampler::chi2() const |
|
142 |
+{ |
|
143 |
+ return 2.f * gaps::algo::loglikelihood(mDMatrix, mSMatrix, mAPMatrix); |
|
144 |
+} |
|
145 |
+ |
|
146 |
+void GibbsSampler::setAnnealingTemp(float temp) |
|
147 |
+{ |
|
148 |
+ mAnnealingTemp = temp; |
|
149 |
+} |
|
150 |
+ |
|
151 |
+void GibbsSampler::evaluateChange(AtomicSupport &domain, |
|
152 |
+const AtomicProposal &proposal, MatrixChange &change, float threshold, |
|
153 |
+bool accept) |
|
154 |
+{ |
|
155 |
+ float delLL = accept ? 0.f : computeDeltaLL(change); |
|
156 |
+ if (accept || delLL * mAnnealingTemp >= threshold) |
|
157 |
+ { |
|
158 |
+ change = domain.acceptProposal(proposal, change); |
|
159 |
+ change.label == 'A' ? mAMatrix.update(change) : mPMatrix.update(change); |
|
160 |
+ updateAPMatrix(change); |
|
161 |
+ } |
|
162 |
+} |
|
163 |
+ |
|
164 |
+// simd? |
|
165 |
+void GibbsSampler::updateAPMatrix_A(unsigned row, unsigned col, float delta) |
|
166 |
+{ |
|
167 |
+ const Vector &APvec = mAPMatrix.getRow(row); |
|
168 |
+ const Vector &Pvec = mPMatrix.getRow(col); |
|
169 |
+ for (unsigned j = 0; j < mAPMatrix.nCol(); ++j) |
|
170 |
+ { |
|
171 |
+ mAPMatrix.set(row, j, APvec[j] + delta * Pvec[j]); |
|
172 |
+ } |
|
173 |
+} |
|
174 |
+ |
|
175 |
+// simd? |
|
176 |
+void GibbsSampler::updateAPMatrix_P(unsigned row, unsigned col, float delta) |
|
177 |
+{ |
|
178 |
+ const Vector &APvec = mAPMatrix.getCol(col); |
|
179 |
+ const Vector &Avec = mAMatrix.getCol(row); |
|
180 |
+ for (unsigned i = 0; i < mAPMatrix.nRow(); ++i) |
|
181 |
+ { |
|
182 |
+ mAPMatrix.set(i, col, APvec[i] + delta * Avec[i]); |
|
183 |
+ } |
|
184 |
+} |
|
185 |
+ |
|
186 |
+void GibbsSampler::updateAPMatrix(const MatrixChange &change) |
|
187 |
+{ |
|
188 |
+ if (change.label == 'A') |
|
189 |
+ { |
|
190 |
+ updateAPMatrix_A(change.row1, change.col1, change.delta1); |
|
191 |
+ if (change.nChanges > 1) |
|
192 |
+ { |
|
193 |
+ updateAPMatrix_A(change.row2, change.col2, change.delta2); |
|
194 |
+ } |
|
195 |
+ } |
|
196 |
+ else |
|
197 |
+ { |
|
198 |
+ updateAPMatrix_P(change.row1, change.col1, change.delta1); |
|
199 |
+ if (change.nChanges > 1) |
|
200 |
+ { |
|
201 |
+ updateAPMatrix_P(change.row2, change.col2, change.delta2); |
|
202 |
+ } |
|
203 |
+ } |
|
204 |
+} |
|
205 |
+ |
|
206 |
+bool GibbsSampler::canUseGibbs(const MatrixChange &ch) |
|
207 |
+{ |
|
208 |
+ bool check1 = (ch.label == 'A' && gaps::algo::isRowZero(mPMatrix, ch.col1)) |
|
209 |
+ || (ch.label == 'P' && gaps::algo::isColZero(mAMatrix, ch.row1)); |
|
210 |
+ |
|
211 |
+ if (ch.nChanges > 1) |
|
212 |
+ { |
|
213 |
+ bool check2 = (ch.label == 'A' && gaps::algo::isRowZero(mPMatrix, ch.col2)) |
|
214 |
+ || (ch.label == 'P' && gaps::algo::isColZero(mAMatrix, ch.row2)); |
|
215 |
+ return !(check1 && check2); |
|
216 |
+ } |
|
217 |
+ return !check1; |
|
218 |
+} |
|
219 |
+ |
|
220 |
+// accept automatically, try to rebirth |
|
221 |
+// TODO consolidate to single proposal |
|
222 |
+void GibbsSampler::death(AtomicSupport &domain, AtomicProposal &prop) |
|
223 |
+{ |
|
224 |
+ // automaticallly accept death |
|
225 |
+ MatrixChange change(prop.label, domain.getRow(prop.pos1), |
|
226 |
+ domain.getCol(prop.pos1), prop.delta1); |
|
227 |
+ evaluateChange(domain, prop, change, 0.f, true); |
|
228 |
+ |
|
229 |
+ // rebirth, label as birth |
|
230 |
+ float newMass = canUseGibbs(change) ? getGibbsMass(change) : 0.f; |
|
231 |
+ prop.delta1 = newMass < EPSILON ? -prop.delta1 : newMass; |
|
232 |
+ change.delta1 = prop.delta1; |
|
233 |
+ |
|
234 |
+ // attempt to accept rebirth |
|
235 |
+ evaluateChange(domain, prop, change, std::log(gaps::random::uniform())); |
|
236 |
+} |
|
237 |
+ |
|
238 |
+void GibbsSampler::birth(AtomicSupport &domain, AtomicProposal &prop) |
|
239 |
+{ |
|
240 |
+ // attempt gibbs |
|
241 |
+ MatrixChange change(prop.label, domain.getRow(prop.pos1), |
|
242 |
+ domain.getCol(prop.pos1), prop.delta1); |
|
243 |
+ prop.delta1 = canUseGibbs(change) ? getGibbsMass(change) : prop.delta1; |
|
244 |
+ change.delta1 = prop.delta1; |
|
245 |
+ |
|
246 |
+ // accept birth |
|
247 |
+ evaluateChange(domain, prop, change, 0.f, true); |
|
248 |
+} |
|
249 |
+ |
|
250 |
+void GibbsSampler::move(AtomicSupport &domain, AtomicProposal &prop) |
|
251 |
+{ |
|
252 |
+ MatrixChange change(prop.label, domain.getRow(prop.pos1), |
|
253 |
+ domain.getCol(prop.pos1), prop.delta1, domain.getRow(prop.pos2), |
|
254 |
+ domain.getCol(prop.pos2), prop.delta2); |
|
255 |
+ if (change.row1 != change.row2 || change.col1 != change.col2) |
|
256 |
+ { |
|
257 |
+ evaluateChange(domain, prop, change, std::log(gaps::random::uniform())); |
|
258 |
+ } |
|
259 |
+} |
|
260 |
+ |
|
261 |
+void GibbsSampler::exchange(AtomicSupport &domain, AtomicProposal &prop) |
|
262 |
+{ |
|
263 |
+ MatrixChange change(prop.label, domain.getRow(prop.pos1), |
|
264 |
+ domain.getCol(prop.pos1), prop.delta1, domain.getRow(prop.pos2), |
|
265 |
+ domain.getCol(prop.pos2), prop.delta2); |
|
266 |
+ if (change.row1 == change.row2 && change.col1 == change.col2) |
|
267 |
+ { |
|
268 |
+ return; |
|
269 |
+ } |
|
270 |
+ |
|
271 |
+ float mass1 = domain.at(prop.pos1); |
|
272 |
+ float mass2 = domain.at(prop.pos2); |
|
273 |
+ float newMass1 = mass1 + prop.delta1; |
|
274 |
+ float newMass2 = mass2 + prop.delta2; |
|
275 |
+ |
|
276 |
+ if (canUseGibbs(change)) |
|
277 |
+ { |
|
278 |
+ AlphaParameters alphaParam = gaps::algo::alphaParameters(change, |
|
279 |
+ mDMatrix, mSMatrix, mAMatrix, mPMatrix, mAPMatrix); |
|
280 |
+ alphaParam.s *= mAnnealingTemp; |
|
281 |
+ alphaParam.su *= mAnnealingTemp; |
|
282 |
+ |
|
283 |
+ if (alphaParam.s > EPSILON) |
|
284 |
+ { |
|
285 |
+ float mean = alphaParam.su / alphaParam.s; |
|
286 |
+ float sd = 1.f / std::sqrt(alphaParam.s); |
|
287 |
+ float plower = gaps::random::p_norm(-mass1, mean, sd); |
|
288 |
+ float pupper = gaps::random::p_norm(mass2, mean, sd); |
|
289 |
+ |
|
290 |
+ if (!(plower > 0.95f || pupper < 0.05f)) |
|
291 |
+ { |
|
292 |
+ float u = gaps::random::uniform(plower, pupper); |
|
293 |
+ prop.delta1 = gaps::random::q_norm(u, mean, sd); |
|
294 |
+ prop.delta1 = std::max(prop.delta1, -mass1); |
|
295 |
+ prop.delta1 = std::min(prop.delta1, mass2); |
|
296 |
+ prop.delta2 = -prop.delta1; |
|
297 |
+ change.delta1 = prop.delta1; |
|
298 |
+ change.delta2 = prop.delta2; |
|
299 |
+ evaluateChange(domain, prop, change, 0.f, true); |
|
300 |
+ return; |
|
301 |
+ } |
|
302 |
+ } |
|
303 |
+ } |
|
304 |
+ |
|
305 |
+ float pnewMass = mass1 > mass2 ? newMass1 : newMass2; |
|
306 |
+ float poldMass = newMass1 > newMass2 ? mass1 : mass2; |
|
307 |
+ |
|
308 |
+ float pnew = gaps::random::d_gamma(pnewMass, 2.0, 1.f / domain.lambda()); |
|
309 |
+ float pold = gaps::random::d_gamma(poldMass, 2.0, 1.f / domain.lambda()); |
|
310 |
+ |
|
311 |
+ if (pold == 0.f && pnew != 0.f) |
|
312 |
+ { |
|
313 |
+ evaluateChange(domain, prop, change, 0.f, true); |
|
314 |
+ } |
|
315 |
+ else |
|
316 |
+ { |
|
317 |
+ float priorLL = (pold == 0.f) ? 0.f : log(pnew / pold); |
|
318 |
+ float rejectProb = std::log(gaps::random::uniform()) - priorLL; |
|
319 |
+ evaluateChange(domain, prop, change, rejectProb); |
|
320 |
+ } |
|
321 |
+} |
|
322 |
+ |
|
323 |
+Rcpp::NumericMatrix GibbsSampler::AMeanRMatrix() const |
|
324 |
+{ |
|
325 |
+ return (mAMeanMatrix / mStatUpdates).rMatrix(); |
|
326 |
+} |
|
327 |
+ |
|
328 |
+Rcpp::NumericMatrix GibbsSampler::AStdRMatrix() const |
|
329 |
+{ |
|
330 |
+ return gaps::algo::computeStdDev(mAStdMatrix, mAMeanMatrix, |
|
331 |
+ mStatUpdates).rMatrix(); |
|
332 |
+} |
|
333 |
+ |
|
334 |
+Rcpp::NumericMatrix GibbsSampler::PMeanRMatrix() const |
|
335 |
+{ |
|
336 |
+ return (mPMeanMatrix / mStatUpdates).rMatrix(); |
|
337 |
+} |
|
338 |
+ |
|
339 |
+Rcpp::NumericMatrix GibbsSampler::PStdRMatrix() const |
|
340 |
+{ |
|
341 |
+ return gaps::algo::computeStdDev(mPStdMatrix, mPMeanMatrix, |
|
342 |
+ mStatUpdates).rMatrix(); |
|
343 |
+} |
|
344 |
+ |
|
345 |
+Rcpp::NumericMatrix GibbsSampler::pumpMatrix() const |
|
346 |
+{ |
|
347 |
+ unsigned denom = mPumpStatUpdates ? mPumpStatUpdates : 1.f; |
|
348 |
+ return (mPumpMatrix / denom).rMatrix(); |
|
349 |
+} |
|
350 |
+ |
|
351 |
+Rcpp::NumericMatrix GibbsSampler::meanPattern() |
|
352 |
+{ |
|
353 |
+ ColMatrix Amean(mAMeanMatrix / (float)mStatUpdates); |
|
354 |
+ RowMatrix Pmean(mPMeanMatrix / (float)mStatUpdates); |
|
355 |
+ ColMatrix mat(mAMatrix.nRow(), mAMatrix.nCol()); |
|
356 |
+ patternMarkers(Amean, Pmean, mat); |
|
357 |
+ return mat.rMatrix(); |
|
358 |
+} |
|
359 |
+ |
|
360 |
+float GibbsSampler::meanChiSq() const |
|
361 |
+{ |
|
362 |
+ ColMatrix Amean = mAMeanMatrix / (float)mStatUpdates; |
|
363 |
+ RowMatrix Pmean = mPMeanMatrix / (float)mStatUpdates; |
|
364 |
+ TwoWayMatrix Mmean(Amean.nRow(), Pmean.nCol()); |
|
365 |
+ gaps::algo::matrixMultiplication(Mmean, Amean, Pmean); |
|
366 |
+ return 2.f * gaps::algo::loglikelihood(mDMatrix, mSMatrix, Mmean); |
|
367 |
+} |
|
368 |
+ |
|
369 |
+void GibbsSampler::updateStatistics() |
|
370 |
+{ |
|
371 |
+ mStatUpdates++; |
|
372 |
+ unsigned nPatterns = mAMatrix.nCol(); |
|
373 |
+ |
|
374 |
+ Vector normVec(nPatterns); |
|
375 |
+ for (unsigned j = 0; j < nPatterns; ++j) |
|
376 |
+ { |
|
377 |
+ normVec[j] = gaps::algo::sum(mPMatrix.getRow(j)); |
|
378 |
+ normVec[j] = normVec[j] == 0 ? 1.f : normVec[j]; |
|
379 |
+ |
|
380 |
+ Vector quot(mPMatrix.getRow(j) / normVec[j]); |
|
381 |
+ mPMeanMatrix.getRow(j) += quot; |
|
382 |
+ mPStdMatrix.getRow(j) += gaps::algo::elementSq(quot); |
|
383 |
+ |
|
384 |
+ Vector prod(mAMatrix.getCol(j) * normVec[j]); |
|
385 |
+ mAMeanMatrix.getCol(j) += prod; |
|
386 |
+ mAStdMatrix.getCol(j) += gaps::algo::elementSq(prod); |
|
387 |
+ } |
|
388 |
+} |
|
389 |
+ |
|
390 |
+void GibbsSampler::updatePumpStatistics() |
|
391 |
+{ |
|
392 |
+ if (mFixedMat != 'A') |
|
393 |
+ { |
|
394 |
+ mPumpStatUpdates++; |
|
395 |
+ patternMarkers(normedAMatrix(), normedPMatrix(), mPumpMatrix); |
|
396 |
+ } |
|
397 |
+} |
|
398 |
+ |
|
399 |
+ColMatrix GibbsSampler::normedAMatrix() const |
|
400 |
+{ |
|
401 |
+ ColMatrix normedA(mAMatrix); |
|
402 |
+ for (unsigned j = 0; j < normedA.nCol(); ++j) |
|
403 |
+ { |
|
404 |
+ float factor = gaps::algo::sum(mPMatrix.getRow(j)); |
|
405 |
+ factor = (factor == 0) ? 1.f : factor; |
|
406 |
+ normedA.getCol(j) = normedA.getCol(j) * factor; |
|
407 |
+ } |
|
408 |
+ return normedA; |
|
409 |
+} |
|
410 |
+ |
|
411 |
+RowMatrix GibbsSampler::normedPMatrix() const |
|
412 |
+{ |
|
413 |
+ RowMatrix normedP(mPMatrix); |
|
414 |
+ for (unsigned i = 0; i < normedP.nRow(); ++i) |
|
415 |
+ { |
|
416 |
+ float factor = gaps::algo::sum(mPMatrix.getRow(i)); |
|
417 |
+ factor = (factor == 0) ? 1.f : factor; |
|
418 |
+ normedP.getRow(i) = normedP.getRow(i) / factor; |
|
419 |
+ } |
|
420 |
+ return normedP; |
|
421 |
+} |
|
422 |
+ |
|
423 |
+static unsigned geneThreshold(const ColMatrix &rankMatrix, unsigned pat) |
|
424 |
+{ |
|
425 |
+ float cutRank = rankMatrix.nRow(); |
|
426 |
+ for (unsigned i = 0; i < rankMatrix.nRow(); ++i) |
|
427 |
+ { |
|
428 |
+ for (unsigned j = 0; j < rankMatrix.nCol(); ++j) |
|
429 |
+ { |
|
430 |
+ if (j != pat && rankMatrix(i,j) <= rankMatrix(i,pat)) |
|
431 |
+ { |
|
432 |
+ cutRank = std::min(cutRank, std::max(0.f, rankMatrix(i,pat)-1)); |
|
433 |
+ } |
|
434 |
+ } |
|
435 |
+ } |
|
436 |
+ return static_cast<unsigned>(cutRank); |
|
437 |
+} |
|
438 |
+ |
|
439 |
+void GibbsSampler::patternMarkers(RowMatrix normedA, RowMatrix normedP, |
|
440 |
+ColMatrix &statMatrix) |
|
441 |
+{ |
|
442 |
+ // helpful notation |
|
443 |
+ unsigned nGenes = normedA.nRow(); |
|
444 |
+ unsigned nPatterns = normedA.nCol(); |
|
445 |
+ |
|
446 |
+ // scale A matrix |
|
447 |
+ for (unsigned j = 0; j < nPatterns; ++j) |
|
448 |
+ { |
|
449 |
+ float scale = gaps::algo::max(normedP.getRow(j)); |
|
450 |
+ for (unsigned i = 0; i < nGenes; ++i) |
|
451 |
+ { |
|
452 |
+ normedA(i,j) *= scale; |
|
453 |
+ } |
|
454 |
+ } |
|
455 |
+ |
|
456 |
+ // compute sstat |
|
457 |
+ TwoWayMatrix sStat(nGenes, nPatterns); |
|
458 |
+ Vector lp(nPatterns), diff(nPatterns); |
|
459 |
+ for (unsigned j = 0; j < nPatterns; ++j) |
|
460 |
+ { |
|
461 |
+ lp[j] = 1.f; |
|
462 |
+ for (unsigned i = 0; i < nGenes; ++i) |
|
463 |
+ { |
|
464 |
+ float geneMax = gaps::algo::max(normedA.getRow(i)); |
|
465 |
+ diff = geneMax > 0.f ? normedA.getRow(i) / geneMax - lp : lp * -1.f; |
|
466 |
+ sStat.set(i, j, std::sqrt(gaps::algo::dot(diff, diff))); |
|
467 |
+ } |
|
468 |
+ lp[j] = 0.f; |
|
469 |
+ } |
|
470 |
+ |
|
471 |
+ // update PUMP matrix |
|
472 |
+ if (mPumpThreshold == PUMP_UNIQUE) |
|
473 |
+ { |
|
474 |
+ for (unsigned i = 0; i < nGenes; ++i) |
|
475 |
+ { |
|
476 |
+ unsigned minNdx = gaps::algo::whichMin(sStat.getRow(i)); |
|
477 |
+ statMatrix(i,minNdx)++; |
|
478 |
+ } |
|
479 |
+ } |
|
480 |
+ else if (mPumpThreshold == PUMP_CUT) |
|
481 |
+ { |
|
482 |
+ ColMatrix rankMatrix(nGenes, nPatterns); |
|
483 |
+ for (unsigned j = 0; j < nPatterns; ++j) |
|
484 |
+ { |
|
485 |
+ rankMatrix.getCol(j) = gaps::algo::rank(sStat.getCol(j)); |
|
486 |
+ } |
|
487 |
+ |
|
488 |
+ for (unsigned j = 0; j < nPatterns; ++j) |
|
489 |
+ { |
|
490 |
+ unsigned cutRank = geneThreshold(rankMatrix, j); |
|
491 |
+ for (unsigned i = 0; i < nGenes; ++i) |
|
492 |
+ { |
|
493 |
+ if (rankMatrix(i,j) <= cutRank) |
|
494 |
+ { |
|
495 |
+ statMatrix(i,j)++; |
|
496 |
+ } |
|
497 |
+ } |
|
498 |
+ } |
|
499 |
+ } |
|
500 |
+} |
|
501 |
+ |
|
502 |
+Archive& operator<<(Archive &ar, GibbsSampler &sampler) |
|
503 |
+{ |
|
504 |
+ ar << sampler.mAMatrix << sampler.mPMatrix << sampler.mADomain |
|
505 |
+ << sampler.mPDomain << sampler.mAMeanMatrix << sampler.mAStdMatrix |
|
506 |
+ << sampler.mPMeanMatrix << sampler.mPStdMatrix << sampler.mStatUpdates |
|
507 |
+ << sampler.mPumpMatrix << sampler.mPumpThreshold << sampler.mPumpStatUpdates |
|
508 |
+ << sampler.mMaxGibbsMassA << sampler.mMaxGibbsMassP |
|
509 |
+ << sampler.mAnnealingTemp << sampler.mSingleCellRNASeq |
|
510 |
+ << sampler.mNumFixedPatterns << sampler.mFixedMat; |
|
511 |
+ |
|
512 |
+ return ar; |
|
513 |
+} |
|
514 |
+ |
|
515 |
+Archive& operator>>(Archive &ar, GibbsSampler &sampler) |
|
516 |
+{ |
|
517 |
+ ar >> sampler.mAMatrix >> sampler.mPMatrix >> sampler.mADomain |
|
518 |
+ >> sampler.mPDomain >> sampler.mAMeanMatrix >> sampler.mAStdMatrix |
|
519 |
+ >> sampler.mPMeanMatrix >> sampler.mPStdMatrix >> sampler.mStatUpdates |
|
520 |
+ >> sampler.mPumpMatrix >> sampler.mPumpThreshold >> sampler.mPumpStatUpdates |
|
521 |
+ >> sampler.mMaxGibbsMassA >> sampler.mMaxGibbsMassP |
|
522 |
+ >> sampler.mAnnealingTemp >> sampler.mSingleCellRNASeq |
|
523 |
+ >> sampler.mNumFixedPatterns >> sampler.mFixedMat; |
|
524 |
+ |
|
525 |
+ gaps::algo::matrixMultiplication(sampler.mAPMatrix, sampler.mAMatrix, |
|
526 |
+ sampler.mPMatrix); |
|
527 |
+ |
|
528 |
+ return ar; |
|
529 |
+} |
|
0 | 530 |
\ No newline at end of file |
1 | 531 |
new file mode 100644 |
... | ... |
@@ -0,0 +1,104 @@ |
1 |
+#ifndef __COGAPS_GIBBSSAMPLER_H__ |
|
2 |
+#define __COGAPS_GIBBSSAMPLER_H__ |
|
3 |
+ |
|
4 |
+#include "AtomicSupport.h" |
|
5 |
+#include "Matrix.h" |
|
6 |
+ |
|
7 |
+#include <vector> |
|
8 |
+ |
|
9 |
// Selects how GibbsSampler::patternMarkers() updates the PUMP statistic.
enum PumpThreshold
{
    PUMP_UNIQUE = 1, // each gene counts only for its minimum-statistic pattern
    PUMP_CUT = 2     // genes ranked at or below a pattern's cut rank count
};
|
14 |
+ |
|
15 |
+class GibbsSampler |
|
16 |
+{ |
|
17 |
+private: |
|
18 |
+#ifdef GAPS_INTERNAL_TESTS |
|
19 |
+public: |
|
20 |
+#endif |
|
21 |
+ |
|
22 |
+ TwoWayMatrix mDMatrix, mSMatrix, mAPMatrix; |
|
23 |
+ |
|
24 |
+ ColMatrix mAMatrix; |
|
25 |
+ RowMatrix mPMatrix; |
|
26 |
+ |
|
27 |
+ AtomicSupport mADomain, mPDomain; |
|
28 |
+ |
|
29 |
+ ColMatrix mAMeanMatrix, mAStdMatrix; |
|
30 |
+ RowMatrix mPMeanMatrix, mPStdMatrix; |
|
31 |
+ unsigned mStatUpdates; |
|
32 |
+ |
|
33 |
+ ColMatrix mPumpMatrix; |
|
34 |
+ PumpThreshold mPumpThreshold; |
|
35 |
+ unsigned mPumpStatUpdates; |
|
36 |
+ |
|
37 |
+ float mMaxGibbsMassA; |
|
38 |
+ float mMaxGibbsMassP; |
|
39 |
+ |
|
40 |
+ float mAnnealingTemp; |
|
41 |
+ |
|
42 |
+ bool mSingleCellRNASeq; |
|
43 |
+ |
|
44 |
+ unsigned mNumFixedPatterns; |
|
45 |
+ char mFixedMat; |
|
46 |
+ |
|
47 |
+ void death(AtomicSupport &domain, AtomicProposal &proposal); |
|
48 |
+ void birth(AtomicSupport &domain, AtomicProposal &proposal); |
|
49 |
+ void move(AtomicSupport &domain, AtomicProposal &prop); |
|
50 |
+ void exchange(AtomicSupport &domain, AtomicProposal &proposal); |
|
51 |
+ |
|
52 |
+ void evaluateChange(AtomicSupport &domain, const AtomicProposal &proposal, |
|
53 |
+ MatrixChange &change, float threshold, bool accept=false); |
|
54 |
+ |
|
55 |
+ float computeDeltaLL(const MatrixChange &change); |
|
56 |
+ |
|
57 |
+ float getGibbsMass(const MatrixChange &change); |
|
58 |
+ |
|
59 |
+ void updateAPMatrix_A(unsigned row, unsigned col, float delta); |
|
60 |
+ void updateAPMatrix_P(unsigned row, unsigned col, float delta); |
|
61 |
+ void updateAPMatrix(const MatrixChange &change); |
|
62 |
+ |
|
63 |
+ bool canUseGibbs(const MatrixChange &ch); |
|
64 |
+ |
|
65 |
+public: |
|
66 |
+ |
|
67 |
+ GibbsSampler(const Rcpp::NumericMatrix &D, const Rcpp::NumericMatrix &S, |
|
68 |
+ unsigned nFactor); |
|
69 |
+ GibbsSampler(const Rcpp::NumericMatrix &D, const Rcpp::NumericMatrix &S, |
|
70 |
+ unsigned nFactor, float alphaA, float alphaP, float maxGibbmassA, |
|
71 |
+ float maxGibbmassP, bool singleCellRNASeq, char whichMatrixFixed, |
|
72 |
+ const Rcpp::NumericMatrix &FP, PumpThreshold pumpThreshold); |
|
73 |
+ |
|
74 |
+ void update(char matrixLabel); |
|
75 |
+ |
|
76 |
+ uint64_t totalNumAtoms(char matrixLabel) const; |
|
77 |
+ void setAnnealingTemp(float temp); |
|
78 |
+ float chi2() const; |
|
79 |
+ |
|
80 |
+ ColMatrix normedAMatrix() const; |
|
81 |
+ RowMatrix normedPMatrix() const; |
|
82 |
+ |
|
83 |
+ unsigned nRow() const {return mDMatrix.nRow();} |
|
84 |
+ unsigned nCol() const {return mDMatrix.nCol();} |
|
85 |
+ unsigned nFactor() const {return mAMatrix.nCol();} |
|
86 |
+ |
|
87 |
+ // statistics |
|
88 |
+ void updateStatistics(); |
|
89 |
+ void updatePumpStatistics(); |
|
90 |
+ Rcpp::NumericMatrix AMeanRMatrix() const; |
|
91 |
+ Rcpp::NumericMatrix AStdRMatrix() const; |
|
92 |
+ Rcpp::NumericMatrix PMeanRMatrix() const; |
|
93 |
+ Rcpp::NumericMatrix PStdRMatrix() const; |
|
94 |
+ Rcpp::NumericMatrix pumpMatrix() const; |
|
95 |
+ Rcpp::NumericMatrix meanPattern(); |
|
96 |
+ void patternMarkers(RowMatrix normedA, RowMatrix normedP, ColMatrix &statMatrix); |
|
97 |
+ float meanChiSq() const; |
|
98 |
+ |
|
99 |
+ // serialization |
|
100 |
+ friend Archive& operator<<(Archive &ar, GibbsSampler &sampler); |
|
101 |
+ friend Archive& operator>>(Archive &ar, GibbsSampler &sampler); |
|
102 |
+}; |
|
103 |
+ |
|
104 |
+#endif |
... | ... |
@@ -94,7 +94,7 @@ bool gaps::algo::isColZero(const ColMatrix &mat, unsigned col) |
94 | 94 |
{ |
95 | 95 |
return gaps::algo::sum(mat.getCol(col)) == 0; |
96 | 96 |
} |
97 |
- |
|
97 |
+/* |
|
98 | 98 |
// horribly slow, don't call often |
99 | 99 |
void gaps::algo::matrixMultiplication(TwoWayMatrix &C, const ColMatrix &A, |
100 | 100 |
const RowMatrix &B) |
... | ... |
@@ -313,4 +313,5 @@ const RowMatrix &P, const TwoWayMatrix &AP) |
313 | 313 |
} |
314 | 314 |
return a1 + a2; |
315 | 315 |
} |
316 |
-} |
|
317 | 316 |
\ No newline at end of file |
317 |
+} |
|
318 |
+*/ |
|
318 | 319 |
\ No newline at end of file |
... | ... |
@@ -52,7 +52,7 @@ namespace algo |
52 | 52 |
// specific matrix algorithms |
53 | 53 |
bool isRowZero(const RowMatrix &mat, unsigned row); |
54 | 54 |
bool isColZero(const ColMatrix &mat, unsigned col); |
55 |
- void matrixMultiplication(TwoWayMatrix &C, const ColMatrix &A, |
|
55 |
+ /*void matrixMultiplication(TwoWayMatrix &C, const ColMatrix &A, |
|
56 | 56 |
const RowMatrix &B); |
57 | 57 |
|
58 | 58 |
// chiSq / 2 |
... | ... |
@@ -67,7 +67,7 @@ namespace algo |
67 | 67 |
// alpha parameters used in exchange and gibbsMass calculation |
68 | 68 |
AlphaParameters alphaParameters(const MatrixChange &ch, |
69 | 69 |
const TwoWayMatrix &D, const TwoWayMatrix &S, const ColMatrix &A, |
70 |
- const RowMatrix &P, const TwoWayMatrix &AP); |
|
70 |
+ const RowMatrix &P, const TwoWayMatrix &AP);*/ |
|
71 | 71 |
} // namespace algo |
72 | 72 |
} // namespace gaps |
73 | 73 |
|
74 | 74 |
new file mode 100644 |
... | ... |
@@ -0,0 +1,92 @@ |
1 |
+#include "AtomicDomain.h" |
|
2 |
+#include "Random.h" |
|
3 |
+ |
|
4 |
+#include <stdint.h> |
|
5 |
+#include <utility> |
|
6 |
+ |
|
7 |
+// O(1) |
|
8 |
+Atom AtomicDomain::front() const |
|
9 |
+{ |
|
10 |
+ return mAtoms[mAtomPositions.begin()->second]; |
|
11 |
+} |
|
12 |
+ |
|
13 |
+// O(1) |
|
14 |
+Atom AtomicDomain::randomAtom() const |
|
15 |
+{ |
|
16 |
+ uint64_t num = gaps::random::uniform64(0, mAtoms.size() - 1); |
|
17 |
+ return mAtoms[num]; |
|
18 |
+} |
|
19 |
+ |
|
20 |
+// O(logN) - keep hash of positions to fix, need this O(1) |
|
21 |
+uint64_t AtomicDomain::randomFreePosition() const |
|
22 |
+{ |
|
23 |
+ uint64_t pos = 0; |
|
24 |
+ do |
|
25 |
+ { |
|
26 |
+ pos = gaps::random::uniform64(); |
|
27 |
+ } while (mAtomPositions.count(pos) > 0); // count is O(logN) |
|
28 |
+ return pos; |
|
29 |
+} |
|
30 |
+ |
|
31 |
+// O(logN) |
|
32 |
+void AtomicDomain::insert(uint64_t pos, float mass) |
|
33 |
+{ |
|
34 |
+ // insert position into map |
|
35 |
+ std::map<uint64_t, uint64_t>::iterator iter, iterLeft, iterRight; |
|
36 |
+ iter = mAtomPositions.insert(std::pair<uint64_t, uint64_t>(pos, mAtoms.size())).first; |
|
37 |
+ iterLeft = iter; |
|
38 |
+ iterRight = iter; |
|
39 |
+ |
|
40 |
+ // find neighbors |
|
41 |
+ Atom atom(pos, mass); |
|
42 |
+ if (iter != mAtomPositions.begin()) |
|
43 |
+ { |
|
44 |
+ --iterLeft; |
|
45 |
+ atom.left = &(mAtoms[iterLeft->second]); |
|
46 |
+ } |
|
47 |
+ if (iter != mAtomPositions.end()) |
|
48 |
+ { |
|
49 |
+ ++iterRight; |
|
50 |
+ atom.right = &(mAtoms[iterRight->second]); |
|
51 |
+ } |
|
52 |
+ |
|
53 |
+ // add atom to vector |
|
54 |
+ mAtoms.push_back(atom); |
|
55 |
+} |
|
56 |
+ |
|
57 |
+// O(logN) |
|
58 |
+void AtomicDomain::erase(uint64_t pos) |
|
59 |
+{ |
|
60 |
+ // get vector index of this atom and erase it |
|
61 |
+ uint64_t index = mAtomPositions.at(pos); |
|
62 |
+ mAtomPositions.erase(pos); |
|
63 |
+ |
|
64 |
+ // update key of object about to be moved (last one doesn't need to move) |
|
65 |
+ if (index < mAtoms.size() - 1) |
|
66 |
+ { |
|
67 |
+ mAtomPositions.erase(mAtoms.back().pos); |
|
68 |
+ mAtomPositions.insert(std::pair<uint64_t, uint64_t>(mAtoms.back().pos, |
|
69 |
+ index)); |
|
70 |
+ } |
|
71 |
+ |
|
72 |
+ // update neighbors |
|
73 |
+ if (mAtoms[index].left) |
|
74 |
+ { |
|
75 |
+ mAtoms[index].left->right = mAtoms[index].right; |
|
76 |
+ } |
|
77 |
+ if (mAtoms[index].right) |
|
78 |
+ { |
|
79 |
+ mAtoms[index].right->left = mAtoms[index].left; |
|
80 |
+ } |
|
81 |
+ |
|
82 |
+ // delete atom from vector in O(1) |
|
83 |
+ mAtoms[index] = mAtoms.back(); |
|
84 |
+ mAtoms.pop_back(); |
|
85 |
+} |
|
86 |
+ |
|
87 |
+// O(logN) |
|
88 |
+void AtomicDomain::updateMass(uint64_t pos, float newMass) |
|
89 |
+{ |
|
90 |
+ uint64_t index = mAtomPositions.at(pos); |
|
91 |
+ mAtoms[index].mass = newMass; |
|
92 |
+} |
... | ... |
@@ -1,24 +1,12 @@ |
1 | 1 |
#ifndef __GAPS_ATOMIC_DOMAIN_H__ |
2 | 2 |
#define __GAPS_ATOMIC_DOMAIN_H__ |
3 | 3 |
|
4 |
-// data structure that holds atoms |
|
5 |
-class AtomicDomain |
|
6 |
-{ |
|
7 |
-private: |
|
8 |
- |
|
9 |
-public: |
|
10 |
- |
|
11 |
- AtomicDomain(); |
|
12 |
- |
|
13 |
- uint64_t randomAtomPosition(); |
|
14 |
- uint64_t randomFreePosition(); |
|
15 |
- |
|
16 |
- float updateMass(uint64_t pos, float delta); |
|
17 |
- |
|
18 |
-}; |
|
19 |
- |
|
20 |
-#endif |
|
4 |
+#include "Archive.h" |
|
21 | 5 |
|
6 |
+#include <stdint.h> |
|
7 |
+#include <cstddef> |
|
8 |
+#include <vector> |
|
9 |
+#include <map> |
|
22 | 10 |
|
23 | 11 |
struct Atom |
24 | 12 |
{ |
... | ... |
@@ -29,32 +17,38 @@ struct Atom |
29 | 17 |
Atom* right; |
30 | 18 |
|
31 | 19 |
Atom(uint64_t p, float m) |
32 |
- : pos(p), mass(m), left(nullptr), right(nullptr) |
|
20 |
+ : pos(p), mass(m), left(NULL), right(NULL) |
|
33 | 21 |
{} |
34 |
-}; |
|
35 | 22 |
|
36 |
-void insertAtom(uint64_t p, float m) |
|
37 |
-{ |
|
38 |
- std::map<uint64_t, Atom>::const_iterator it, left, right; |
|
39 |
- it = mAtoms.insert(std::pair<uint64_t, Atom>(p, Atom(p,m))).first; |
|
40 |
- |
|
41 |
- std::map<uint64_t, Atom>::const_iterator left(it), right(it); |
|
42 |
- if (it != mAtoms.begin()) |
|
23 |
+ bool operator==(const Atom &other) const |
|
43 | 24 |
{ |
44 |
- --left; |
|
25 |
+ return pos == other.pos; |
|
45 | 26 |
} |
46 |
- if (++it != mAtoms.end()) |
|
47 |
- { |
|
48 |
- ++right; |
|
49 |
- } |
|
50 |
- it->left = &left; |
|
51 |
- it->right = &right; |
|
52 |
-} |
|
27 |
+}; |
|
53 | 28 |
|
29 |
+// data structure that holds atoms |
|
30 |
+class AtomicDomain |
|
31 |
+{ |
|
32 |
+private: |
|
54 | 33 |
|
34 |
+ // domain storage |
|
35 |
+ std::vector<Atom> mAtoms; |
|
36 |
+ std::map<uint64_t, uint64_t> mAtomPositions; |
|
55 | 37 |
|
56 |
-void removeAtom(const Atom &atom) |
|
57 |
-{ |
|
58 |
- atom.left.right = atom.right; |
|
59 |
- atom.right.left = atom.left; |
|
60 |
-} |
|
38 |
+public: |
|
39 |
+ |
|
40 |
+ Atom front() const; |
|
41 |
+ Atom randomAtom() const; |
|
42 |
+ uint64_t randomFreePosition() const; |
|
43 |
+ |
|
44 |
+ // modify domain |
|
45 |
+ void insert(uint64_t pos, float mass); |
|
46 |
+ void erase(uint64_t pos); |
|
47 |
+ void updateMass(uint64_t pos, float newMass); |
|
48 |
+ |
|
49 |
+ // serialization |
|
50 |
+ friend Archive& operator<<(Archive &ar, AtomicDomain &domain); |
|
51 |
+ friend Archive& operator>>(Archive &ar, AtomicDomain &domain); |
|
52 |
+}; |
|
53 |
+ |
|
54 |
+#endif |
... | ... |
@@ -37,7 +37,7 @@ static void createCheckpoint(GapsInternalState &state) |
37 | 37 |
std::string fname(checkpointFile); |
38 | 38 |
Archive ar(fname, ARCHIVE_WRITE); |
39 | 39 |
gaps::random::save(ar); |
40 |
- ar << state.sampler.nFactor() << state.nEquil << state.nSample << state; |
|
40 |
+ ar << state.nFactor << state.nEquil << state.nSample << state; |
|
41 | 41 |
ar.close(); |
42 | 42 |
|
43 | 43 |
// display time it took to create checkpoint |
... | ... |
@@ -52,8 +52,12 @@ static void createCheckpoint(GapsInternalState &state) |
52 | 52 |
static void updateSampler(GapsInternalState &state) |
53 | 53 |
{ |
54 | 54 |
state.nUpdatesA += state.nIterA; |
55 |
+ state.ASampler.update(state.nIterA); |
|
56 |
+ state.PSampler.syncAP(state.ASampler.APMatrix()); |
|
57 |
+ |
|
55 | 58 |
state.nUpdatesP += state.nIterP; |
56 |
- state.runner.update(state.nIterA, state.nIterP); |
|
59 |
+ state.PSampler.update(state.nIterP); |
|
60 |
+ state.ASampler.syncAP(state.PSampler.APMatrix()); |
|
57 | 61 |
} |
58 | 62 |
|
59 | 63 |
static void makeCheckpointIfNeeded(GapsInternalState &state) |
... | ... |
@@ -70,9 +74,9 @@ static void makeCheckpointIfNeeded(GapsInternalState &state) |
70 | 74 |
static void storeSamplerInfo(GapsInternalState &state, Vector &atomsA, |
71 | 75 |
Vector &atomsP, Vector &chi2) |
72 | 76 |
{ |
73 |
- chi2[state.iter] = state.sampler.chi2(); |
|
74 |
- atomsA[state.iter] = state.sampler.totalNumAtoms('A'); |
|
75 |
- atomsP[state.iter] = state.sampler.totalNumAtoms('P'); |
|
77 |
+ chi2[state.iter] = state.ASampler.chi2(); |
|
78 |
+ atomsA[state.iter] = state.ASampler.nAtoms(); |
|
79 |
+ atomsP[state.iter] = state.PSampler.nAtoms(); |
|
76 | 80 |
state.nIterA = gaps::random::poisson(std::max(atomsA[state.iter], 10.f)); |
77 | 81 |
state.nIterP = gaps::random::poisson(std::max(atomsP[state.iter], 10.f)); |
78 | 82 |
} |
... | ... |
@@ -83,18 +87,18 @@ unsigned nIterTotal) |
83 | 87 |
if ((state.iter + 1) % state.nOutputs == 0 && state.messages) |
84 | 88 |
{ |
85 | 89 |
Rprintf("%s %d of %d, Atoms:%d(%d) Chi2 = %.2f\n", type.c_str(), |
86 |
- state.iter + 1, nIterTotal, state.sampler.totalNumAtoms('A'), |
|
87 |
- state.sampler.totalNumAtoms('P'), state.sampler.chi2()); |
|
90 |
+ state.iter + 1, nIterTotal, state.ASampler.nAtoms(), |
|
91 |
+ state.PSampler.nAtoms(), state.ASampler.chi2()); |
|
88 | 92 |
} |
89 | 93 |
} |
90 | 94 |
|
91 | 95 |
static void takeSnapshots(GapsInternalState &state) |
92 | 96 |
{ |
93 |
- if (state.nSnapshots && !((state.iter+1)%(state.nSample/state.nSnapshots))) |
|
97 |
+ /*if (state.nSnapshots && !((state.iter+1)%(state.nSample/state.nSnapshots))) |
|
94 | 98 |
{ |
95 | 99 |
state.snapshotsA.push_back(state.sampler.normedAMatrix().rMatrix()); |
96 | 100 |
state.snapshotsP.push_back(state.sampler.normedPMatrix().rMatrix()); |
97 |
- } |
|
101 |
+ }*/ |
|
98 | 102 |
} |
99 | 103 |
|
100 | 104 |
static void runBurnPhase(GapsInternalState &state) |
... | ... |
@@ -103,7 +107,8 @@ static void runBurnPhase(GapsInternalState &state) |
103 | 107 |
{ |
104 | 108 |
makeCheckpointIfNeeded(state); |
105 | 109 |
float temp = ((float)state.iter + 2.f) / ((float)state.nEquil / 2.f); |
106 |
- state.sampler.setAnnealingTemp(std::min(1.f,temp)); |
|
110 |
+ state.ASampler.setAnnealingTemp(std::min(1.f,temp)); |
|
111 |
+ state.PSampler.setAnnealingTemp(std::min(1.f,temp)); |
|
107 | 112 |
updateSampler(state); |
108 | 113 |
displayStatus(state, "Equil: ", state.nEquil); |
109 | 114 |
storeSamplerInfo(state, state.nAtomsAEquil, state.nAtomsPEquil, |
... | ... |
@@ -126,11 +131,11 @@ static void runSampPhase(GapsInternalState &state) |
126 | 131 |
{ |
127 | 132 |
makeCheckpointIfNeeded(state); |
128 | 133 |
updateSampler(state); |
129 |
- state.sampler.updateStatistics(); |
|
130 |
- if (state.nPumpSamples && !((state.iter + 1) % (state.nSample / state.nPumpSamples))) |
|
131 |
- { |
|
132 |
- state.sampler.updatePumpStatistics(); |
|
133 |
- } |
|
134 |
+ //state.sampler.updateStatistics(); |
|
135 |
+ //if (state.nPumpSamples && !((state.iter + 1) % (state.nSample / state.nPumpSamples))) |
|
136 |
+ //{ |
|
137 |
+ // state.sampler.updatePumpStatistics(); |
|
138 |
+ // } |
|
134 | 139 |
takeSnapshots(state); |
135 | 140 |
displayStatus(state, "Samp: ", state.nSample); |
136 | 141 |
storeSamplerInfo(state, state.nAtomsASample, state.nAtomsPSample, |
... | ... |
@@ -167,17 +172,17 @@ static Rcpp::List runCogaps(GapsInternalState &state) |
167 | 172 |
chi2Vec.concat(state.chi2VecSample); |
168 | 173 |
|
169 | 174 |
// print final chi-sq value |
170 |
- float meanChiSq = state.sampler.meanChiSq(); |
|
175 |
+ /*float meanChiSq = state.sampler.meanChiSq(); |
|
171 | 176 |
if (state.messages) |
172 | 177 |
{ |
173 | 178 |
Rprintf("Chi-Squared of Mean: %.2f\n", meanChiSq); |
174 |
- } |
|
179 |
+ }*/ |
|
175 | 180 |
|
176 | 181 |
return Rcpp::List::create( |
177 |
- Rcpp::Named("Amean") = state.sampler.AMeanRMatrix(), |
|
178 |
- Rcpp::Named("Asd") = state.sampler.AStdRMatrix(), |
|
179 |
- Rcpp::Named("Pmean") = state.sampler.PMeanRMatrix(), |
|
180 |
- Rcpp::Named("Psd") = state.sampler.PStdRMatrix(), |
|
182 |
+ //Rcpp::Named("Amean") = state.sampler.AMeanRMatrix(), |
|
183 |
+ //Rcpp::Named("Asd") = state.sampler.AStdRMatrix(), |
|
184 |
+ //Rcpp::Named("Pmean") = state.sampler.PMeanRMatrix(), |
|
185 |
+ //Rcpp::Named("Psd") = state.sampler.PStdRMatrix(), |
|
181 | 186 |
Rcpp::Named("ASnapshots") = Rcpp::wrap(state.snapshotsA), |
182 | 187 |
Rcpp::Named("PSnapshots") = Rcpp::wrap(state.snapshotsP), |
183 | 188 |
Rcpp::Named("atomsAEquil") = state.nAtomsAEquil.rVec(), |
... | ... |
@@ -186,10 +191,10 @@ static Rcpp::List runCogaps(GapsInternalState &state) |
186 | 191 |
Rcpp::Named("atomsPSamp") = state.nAtomsPSample.rVec(), |
187 | 192 |
Rcpp::Named("chiSqValues") = chi2Vec.rVec(), |
188 | 193 |
Rcpp::Named("randSeed") = state.seed, |
189 |
- Rcpp::Named("numUpdates") = state.nUpdatesA + state.nUpdatesP, |
|
190 |
- Rcpp::Named("meanChi2") = meanChiSq, |
|
191 |
- Rcpp::Named("pumpStats") = state.sampler.pumpMatrix(), |
|
192 |
- Rcpp::Named("meanPatternAssignment") = state.sampler.meanPattern() |
|
194 |
+ Rcpp::Named("numUpdates") = state.nUpdatesA + state.nUpdatesP |
|
195 |
+ //Rcpp::Named("meanChi2") = meanChiSq, |
|
196 |
+ //Rcpp::Named("pumpStats") = state.sampler.pumpMatrix(), |
|
197 |
+ //Rcpp::Named("meanPatternAssignment") = state.sampler.meanPattern() |
|
193 | 198 |
); |
194 | 199 |
} |
195 | 200 |
|
... | ... |
@@ -215,8 +220,8 @@ const std::string &cptFile, unsigned pumpThreshold, unsigned nPumpSamples) |
215 | 220 |
// create internal state from parameters and run from there |
216 | 221 |
GapsInternalState state(D, S, nFactor, nEquil, nEquilCool, nSample, |
217 | 222 |
nOutputs, nSnapshots, alphaA, alphaP, maxGibbmassA, maxGibbmassP, seed, |
218 |
- messages, singleCellRNASeq, whichMatrixFixed, FP, checkpointInterval, |
|
219 |
- static_cast<PumpThreshold>(pumpThreshold), nPumpSamples); |
|
223 |
+ messages, singleCellRNASeq, whichMatrixFixed, FP, checkpointInterval); |
|
224 |
+ //static_cast<PumpThreshold>(pumpThreshold), nPumpSamples); |
|
220 | 225 |
checkpointFile = cptFile; |
221 | 226 |
return runCogaps(state); |
222 | 227 |
} |
223 | 228 |
deleted file mode 100644 |
... | ... |
@@ -1,81 +0,0 @@ |
1 |
-#ifndef __GAPS_GAPS_RUNNER_H__ |
|
2 |
-#define __GAPS_GAPS_RUNNER_H__ |
|
3 |
- |
|
4 |
-// holds the data and dispatches the top level jobs |
|
5 |
-class GapsRunner |
|
6 |
-{ |
|
7 |
-private: |
|
8 |
- |
|
9 |
- // Amplitude and Pattern matrices |
|
10 |
- ColMatrix mAMatrix; |
|
11 |
- RowMatrix mPMatrix; |
|
12 |
- |
|
13 |
- // used when updating A matrix |
|
14 |
- RowMatrix mDMatrix; |
|
15 |
- RowMatrix mSMatrix; |
|
16 |
- RowMatrix mAPMatrix_A; |
|
17 |
- |
|
18 |
- // used when upating P matrix |
|
19 |
- ColMatrix mDMatrix; |
|
20 |
- ColMatrix mSMatrix; |
|
21 |
- ColMatrix mAPMatrix_P; |
|
22 |
- |
|
23 |
- // gibbs sampler |
|
24 |
- AmplitudeGibbsSampler mAGibbsSampler; |
|
25 |
- PatternGibbsSampler mPGibbsSampler; |
|
26 |
- |
|
27 |
- // proposal queue |
|
28 |
- ProposalQueue mAQueue; |
|
29 |
- ProposalQueue mPQueue; |
|
30 |
- |
|
31 |
- // atomic domain |
|
32 |
- AtomicDomain mADomain; |
|
33 |
- AtomicDomina mPDomain; |
|
34 |
- |
|
35 |
- // number of cores available for jobs |
|
36 |
- unsigned mNumCores; |
|
37 |
- |
|
38 |
-public: |
|
39 |
- |
|
40 |
- GapsRunner() {} |
|
41 |
- |
|
42 |
- void run(unsigned nASteps, unsigned nPSteps) |
|
43 |
- { |
|
44 |
- update(mADomain, mAQueue, mAGibbsSampler, nASteps); |
|
45 |
- mAPMatrix_P = mAPMatrix_A; |
|
46 |
- |
|
47 |
- update(mPDomain, mPQueue, mPGibbsSampler, nPSteps); |
|
48 |
- mAPMatrix_A = mAPMatrix_P; |
|
49 |
- } |
|
50 |
- |
|
51 |
- // Performance Metrics |
|
52 |
- // 1) % of cores used in each iteration |
|
53 |
- // 2) given fixed nCores, how does speed get better with matrix size, |
|
54 |
- // i.e. when does the overhead of parallelization start paying off |
|
55 |
- // 3) % of program spent in parallel portion, i.e. not in populate queue |
|
56 |
- void update(AtomicDomain domain, ProposalQueue queue, GibbsSampler sampler, |
|
57 |
- unsigned nUpdates) |
|
58 |
- { |
|
59 |
- unsigned n = 0; |
|
60 |
- while (n < nSteps) |
|
61 |
- { |
|
62 |
- // want this to be as quick as possible - otherwise there would be |
|
63 |
- // a large speed up to making this run concurrently along with the |
|
64 |
- // processProposal jobs, but that is much, much more complicated |
|
65 |
- // to implement |
|
66 |
- assert(nSteps - (queue.size() + n) >= 0); |
|
67 |
- queue.populate(domain, nSteps - (queue.size() + n)) |
|
68 |
- |
|
69 |
- unsigned nJobs = std::min(queue.size(), mNumCores); |
|
70 |
- for (unsigned i = 0; i < nJobs; ++i) // can be run in parallel |
|
71 |
- { |
|
72 |
- sampler.processProposal(domain, queue[i]); |
|
73 |
- } |
|
74 |
- queue.clear(nJobs); |
|
75 |
- n += nJobs; |
|
76 |
- assert(n <= nSteps); |
|
77 |
- } |
|
78 |
- } |
|
79 |
-}; |
|
80 |
- |
|
81 |
-#endif |
|
82 | 0 |
\ No newline at end of file |
... | ... |
@@ -1,529 +1,93 @@ |
1 | 1 |
#include "GibbsSampler.h" |
2 |
-#include "Algorithms.h" |
|
3 | 2 |
|
4 |
-#include <Rcpp.h> |
|
5 |
- |
|
6 |
-static const float EPSILON = 1.e-10; |
|
7 |
- |
|
8 |
-GibbsSampler::GibbsSampler(const Rcpp::NumericMatrix &D, |
|
3 |
+/* |
|
4 |
+AmplitudeGibbsSampler::AmplitudeGibbsSampler(const Rcpp::NumericMatrix &D, |
|
9 | 5 |
const Rcpp::NumericMatrix &S, unsigned nFactor) |
10 | 6 |
: |
11 |
-mDMatrix(D), mSMatrix(S), mAPMatrix(D.nrow(), D.ncol()), |
|
12 |
-mAMatrix(D.nrow(), nFactor), mPMatrix(nFactor, D.ncol()), |
|
13 |
-mADomain('A', D.nrow(), nFactor), mPDomain('P', nFactor, D.ncol()), |
|
14 |
-mAMeanMatrix(D.nrow(), nFactor), mAStdMatrix(D.nrow(), nFactor), |
|
15 |
-mPMeanMatrix(nFactor, D.ncol()), mPStdMatrix(nFactor, D.ncol()), |
|
16 |
-mPumpMatrix(D.nrow(), nFactor) |
|
7 |
+mMatrix(D.nrow(), nFactor), mDMatrix(D), mSMatrix(S), mAPMatrix(D.nrow(), D.ncol()), |
|
8 |
+mNumRows(D.nrow()), mNumCols(nFactor) |
|
17 | 9 |
{} |
18 | 10 |
|
19 |
-GibbsSampler::GibbsSampler(const Rcpp::NumericMatrix &D, |
|
20 |
-const Rcpp::NumericMatrix &S, unsigned nFactor, float alphaA, float alphaP, |
|
21 |
-float maxGibbmassA, float maxGibbmassP, bool singleCellRNASeq, |
|
22 |
-char whichMatrixFixed, const Rcpp::NumericMatrix &FP, PumpThreshold pumpThreshold) |
|
11 |
+AmplitudeGibbsSampler::AmplitudeGibbsSampler(const Rcpp::NumericMatrix &D, |
|
12 |
+const Rcpp::NumericMatrix &S, unsigned nFactor, float alpha, float maxGibbsmass) |
|
23 | 13 |
: |
24 |
-mDMatrix(D), mSMatrix(S), mAPMatrix(D.nrow(), D.ncol()), |
|
25 |
-mAMatrix(D.nrow(), nFactor), mPMatrix(nFactor, D.ncol()), |
|
26 |
-mADomain('A', D.nrow(), nFactor), mPDomain('P', nFactor, D.ncol()), |
|
27 |
-mAMeanMatrix(D.nrow(), nFactor), mAStdMatrix(D.nrow(), nFactor), |
|
28 |
-mPMeanMatrix(nFactor, D.ncol()), mPStdMatrix(nFactor, D.ncol()), |
|
29 |
-mPumpMatrix(D.nrow(), nFactor), mPumpThreshold(pumpThreshold), mStatUpdates(0), |
|
30 |
-mPumpStatUpdates(0), mMaxGibbsMassA(maxGibbmassA), mMaxGibbsMassP(maxGibbmassP), |
|
31 |
-mAnnealingTemp(1.0), mSingleCellRNASeq(singleCellRNASeq), mNumFixedPatterns(0), |
|
32 |
-mFixedMat(whichMatrixFixed) |
|
33 |
-{ |
|
34 |
- float meanD = mSingleCellRNASeq ? gaps::algo::nonZeroMean(mDMatrix) |
|
35 |
- : gaps::algo::mean(mDMatrix); |
|
36 |
- |
|
37 |
- mADomain.setAlpha(alphaA); |
|
38 |
- mADomain.setLambda(alphaA * std::sqrt(nFactor / meanD)); |
|
39 |
- mPDomain.setAlpha(alphaP); |
|
40 |
- mPDomain.setLambda(alphaP * std::sqrt(nFactor / meanD)); |
|
41 |
- |
|
42 |
- mMaxGibbsMassA /= mADomain.lambda(); |
|
43 |
- mMaxGibbsMassP /= mPDomain.lambda(); |
|
44 |
- |
|
45 |
- if (mFixedMat == 'A') |
|
46 |
- { |
|
47 |
- mNumFixedPatterns = FP.ncol(); |
|
48 |
- ColMatrix temp(FP); |
|
49 |
- for (unsigned j = 0; j < mNumFixedPatterns; ++j) |
|
50 |
- { |
|
51 |
- mAMatrix.getCol(j) = temp.getCol(j) / gaps::algo::sum(temp.getCol(j)); |
|
52 |
- } |
|
53 |
- } |
|
54 |
- else if (mFixedMat == 'P') |
|
55 |
- { |
|
56 |
- mNumFixedPatterns = FP.nrow(); |
|
57 |
- RowMatrix temp(FP); |
|
58 |
- for (unsigned i = 0; i < mNumFixedPatterns; ++i) |
|
59 |
- { |
|
60 |
- mPMatrix.getRow(i) = temp.getRow(i) / gaps::algo::sum(temp.getRow(i)); |
|
61 |
- } |
|
62 |
- } |
|
63 |
- gaps::algo::matrixMultiplication(mAPMatrix, mAMatrix, mPMatrix); |
|
64 |
-} |
|
65 |
- |
|
66 |
-float GibbsSampler::getGibbsMass(const MatrixChange &change) |
|
67 |
-{ |
|
68 |
- // check if this change is death (only called in birth/death) |
|
69 |
- bool death = change.delta1 < 0; |
|
70 |
- |
|
71 |
- // get s and su |
|
72 |
- AlphaParameters alphaParam = gaps::algo::alphaParameters(change, mDMatrix, |
|
73 |
- mSMatrix, mAMatrix, mPMatrix, mAPMatrix); |
|
74 |
- |
|
75 |
- // calculate mean and standard deviation |
|
76 |
- alphaParam.s *= mAnnealingTemp / 2.0; |
|
77 |
- alphaParam.su *= mAnnealingTemp / 2.0; |
|
78 |
- float lambda = change.label == 'A' ? mADomain.lambda() : mPDomain.lambda(); |
|
79 |
- float mean = (2.0 * alphaParam.su - lambda) / (2.0 * alphaParam.s); |
|
80 |
- float sd = 1.0 / std::sqrt(2.0 * alphaParam.s); |
|
81 |
- |
|
82 |
- // note: is bounded below by zero so have to use inverse sampling! |
|
83 |
- // based upon algorithm in DistScalarRmath.cc (scalarRandomSample) |
|
84 |
- float plower = gaps::random::p_norm(0.f, mean, sd); |
|
85 |
- |
|
86 |
- // if the likelihood is flat and nonzero, sample strictly from the prior |
|
87 |
- float newMass = 0.f; |
|
88 |
- if (plower == 1.f || alphaParam.s < 0.00001f) |
|
89 |
- { |
|
90 |
- newMass = death ? std::abs(change.delta1) : 0.f; |
|
91 |
- } |
|
92 |
- else if (plower >= 0.99f) // what is this? |
|
93 |
- { |
|
94 |
- float tmp1 = gaps::random::d_norm(0.f, mean, sd); |
|
95 |
- float tmp2 = gaps::random::d_norm(10.f * lambda, mean, sd); |
|
96 |
- if (tmp1 > EPSILON && std::abs(tmp1 - tmp2) < EPSILON) |
|
97 |
- { |
|
98 |
- return death ? 0.0 : change.delta1; |
|
99 |
- } |
|
100 |
- } |
|
101 |
- else |
|
102 |
- { |
|
103 |
- newMass = gaps::random::inverseNormSample(plower, 1.f, mean, sd); |
|
104 |
- } |
|
105 |
- |
|
106 |
- newMass = (change.label == 'A' ? std::min(newMass, mMaxGibbsMassA) |
|
107 |
- : std::min(newMass, mMaxGibbsMassP)); |
|
108 |
- |
|
109 |
- return std::max(newMass, 0.f); |
|
110 |
-} |
|
111 |
- |
|
112 |
-float GibbsSampler::computeDeltaLL(const MatrixChange &change) |
|
113 |
-{ |
|
114 |
- return gaps::algo::deltaLL(change, mDMatrix, mSMatrix, mAMatrix, |
|
115 |
- mPMatrix, mAPMatrix); |
|
116 |
-} |
|
117 |
- |
|
118 |
-void GibbsSampler::update(char matrixLabel, unsigned nUpdates) |
|
119 |
-{ |
|
120 |
- AtomicSupport &domain(matrixLabel == 'A' ? mADomain : mPDomain); |
|
121 |
- for (unsigned i = 0; i < nUpdates; ++i) |
|
122 |
- { |
|
123 |
- assert(nUpdates - i - domain.size() >= 0); |
|
124 |
- domain.populateQueue(nUpdates - i - domain.size()); |
|
125 |
- AtomicProposal proposal = domain.popQueue(); |
|
126 |
- switch (proposal.type) |
|
127 |
- { |
|
128 |
- case 'D': death(domain, proposal); break; |
|
129 |
- case 'B': birth(domain, proposal); break; |
|
130 |
- case 'M': move(domain, proposal); break; |
|
131 |
- case 'E': exchange(domain, proposal); break; |
|
132 |
- } |
|
133 |
- } |
|
134 |
-} |
|
14 |
+mMatrix(D.nrow(), nFactor), mDMatrix(D), mSMatrix(S), mAPMatrix(D.nrow(), D.ncol()), |
|
15 |
+mQueue(alpha, D.ncol(), D.nrow() * nFactor), mMaxGibbsMass(maxGibbsmass), |
|
16 |
+mAnnealingTemp(0), mNumRows(D.nrow()), mNumCols(nFactor) |
|
17 |
+{} |
|
18 |
+*/ |
|
135 | 19 |
|
136 |
-uint64_t GibbsSampler::totalNumAtoms(char matrixLabel) const |
|
20 |
+unsigned AmplitudeGibbsSampler::getRow(uint64_t pos) const |
|
137 | 21 |
{ |
138 |
- return matrixLabel == 'A' ? mADomain.numAtoms() : mPDomain.numAtoms(); |
|
22 |
+ return pos / (mBinSize * mNumCols); |
|
139 | 23 |
} |
140 | 24 |
|
141 |
-float GibbsSampler::chi2() const |
|
25 |
+unsigned AmplitudeGibbsSampler::getCol(uint64_t pos) const |
|
142 | 26 |
{ |
143 |
- return 2.f * gaps::algo::loglikelihood(mDMatrix, mSMatrix, mAPMatrix); |
|
27 |
+ return (pos / mBinSize) % mNumCols; |
|
144 | 28 |
} |
145 | 29 |
|
146 |
-void GibbsSampler::setAnnealingTemp(float temp) |
|
30 |
+bool AmplitudeGibbsSampler::canUseGibbs(unsigned row, unsigned col) const |
|
147 | 31 |
{ |
148 |
- mAnnealingTemp = temp; |
|
32 |
+ return !gaps::algo::isRowZero(*mOtherMatrix, col); |
|
149 | 33 |
} |
150 | 34 |
|
151 |
-void GibbsSampler::evaluateChange(AtomicSupport &domain, |
|
152 |
-const AtomicProposal &proposal, MatrixChange &change, float threshold, |
|
153 |
-bool accept) |
|
35 |
+bool AmplitudeGibbsSampler::canUseGibbs(unsigned r1, unsigned c1, unsigned r2, unsigned c2) const |
|
154 | 36 |
{ |
155 |
- float delLL = accept ? 0.f : computeDeltaLL(change); |
|
156 |
- if (accept || delLL * mAnnealingTemp >= threshold) |
|
157 |
- { |
|
158 |
- change = domain.acceptProposal(proposal, change); |
|
159 |
- change.label == 'A' ? mAMatrix.update(change) : mPMatrix.update(change); |
|
160 |
- updateAPMatrix(change); |
|
161 |
- } |
|
37 |
+ return !gaps::algo::isRowZero(*mOtherMatrix, c1) |
|
38 |
+ && !gaps::algo::isRowZero(*mOtherMatrix, c2); |
|
162 | 39 |
} |
163 | 40 |
|
164 |
-// simd? |
|
165 |
-void GibbsSampler::updateAPMatrix_A(unsigned row, unsigned col, float delta) |
|
41 |
+void AmplitudeGibbsSampler::updateAPMatrix(unsigned row, unsigned col, float delta) |
|
166 | 42 |
{ |
167 |
- const Vector &APvec = mAPMatrix.getRow(row); |
|
168 |
- const Vector &Pvec = mPMatrix.getRow(col); |
|
169 | 43 |
for (unsigned j = 0; j < mAPMatrix.nCol(); ++j) |
170 | 44 |
{ |
171 |
- mAPMatrix.set(row, j, APvec[j] + delta * Pvec[j]); |
|
172 |
- } |
|
173 |
-} |
|
174 |
- |
|
175 |
-// simd? |
|
176 |
-void GibbsSampler::updateAPMatrix_P(unsigned row, unsigned col, float delta) |
|
177 |
-{ |
|
178 |
- const Vector &APvec = mAPMatrix.getCol(col); |
|
179 |
- const Vector &Avec = mAMatrix.getCol(row); |
|
180 |
- for (unsigned i = 0; i < mAPMatrix.nRow(); ++i) |
|
181 |
- { |
|
182 |
- mAPMatrix.set(i, col, APvec[i] + delta * Avec[i]); |
|
183 |
- } |
|
184 |
-} |
|
185 |
- |
|
186 |
-void GibbsSampler::updateAPMatrix(const MatrixChange &change) |
|
187 |
-{ |
|
188 |
- if (change.label == 'A') |
|
189 |
- { |
|
190 |
- updateAPMatrix_A(change.row1, change.col1, change.delta1); |
|
191 |
- if (change.nChanges > 1) |
|
192 |
- { |
|
193 |
- updateAPMatrix_A(change.row2, change.col2, change.delta2); |
|
194 |
- } |
|
195 |
- } |
|
196 |
- else |
|
197 |
- { |
|
198 |
- updateAPMatrix_P(change.row1, change.col1, change.delta1); |
|
199 |
- if (change.nChanges > 1) |
|
200 |
- { |
|
201 |
- updateAPMatrix_P(change.row2, change.col2, change.delta2); |
|
202 |
- } |
|
203 |
- } |
|
204 |
-} |
|
205 |
- |
|
206 |
-bool GibbsSampler::canUseGibbs(const MatrixChange &ch) |
|
207 |
-{ |
|
208 |
- bool check1 = (ch.label == 'A' && gaps::algo::isRowZero(mPMatrix, ch.col1)) |
|
209 |
- || (ch.label == 'P' && gaps::algo::isColZero(mAMatrix, ch.row1)); |
|
210 |
- |
|
211 |
- if (ch.nChanges > 1) |
|
212 |
- { |
|
213 |
- bool check2 = (ch.label == 'A' && gaps::algo::isRowZero(mPMatrix, ch.col2)) |
|
214 |
- || (ch.label == 'P' && gaps::algo::isColZero(mAMatrix, ch.row2)); |
|
215 |
- return !(check1 && check2); |
|
216 |
- } |
|
217 |
- return !check1; |
|
218 |
-} |
|
219 |
- |
|
220 |
-// accept automatically, try to rebirth |
|
221 |
-// TODO consolidate to single proposal |
|
222 |
-void GibbsSampler::death(AtomicSupport &domain, AtomicProposal &prop) |
|
223 |
-{ |
|
224 |
- // automaticallly accept death |
|
225 |
- MatrixChange change(prop.label, domain.getRow(prop.pos1), |
|
226 |
- domain.getCol(prop.pos1), prop.delta1); |
|
227 |
- evaluateChange(domain, prop, change, 0.f, true); |
|
228 |
- |
|
229 |
- // rebirth, label as birth |
|
230 |
- float newMass = canUseGibbs(change) ? getGibbsMass(change) : 0.f; |
|
231 |
- prop.delta1 = newMass < EPSILON ? -prop.delta1 : newMass; |
|
232 |
- change.delta1 = prop.delta1; |
|
233 |
- |
|
234 |
- // attempt to accept rebirth |
|
235 |
- evaluateChange(domain, prop, change, std::log(gaps::random::uniform())); |
|
236 |
-} |
|
237 |
- |
|
238 |
-void GibbsSampler::birth(AtomicSupport &domain, AtomicProposal &prop) |
|
239 |
-{ |
|
240 |
- // attempt gibbs |
|
241 |
- MatrixChange change(prop.label, domain.getRow(prop.pos1), |
|
242 |
- domain.getCol(prop.pos1), prop.delta1); |
|
243 |
- prop.delta1 = canUseGibbs(change) ? getGibbsMass(change) : prop.delta1; |
|
244 |
- change.delta1 = prop.delta1; |
|
245 |
- |
|
246 |
- // accept birth |
|
247 |
- evaluateChange(domain, prop, change, 0.f, true); |
|
248 |
-} |
|
249 |
- |
|
250 |
-void GibbsSampler::move(AtomicSupport &domain, AtomicProposal &prop) |
|
251 |
-{ |
|
252 |
- MatrixChange change(prop.label, domain.getRow(prop.pos1), |
|
253 |
- domain.getCol(prop.pos1), prop.delta1, domain.getRow(prop.pos2), |
|
254 |
- domain.getCol(prop.pos2), prop.delta2); |
|
255 |
- if (change.row1 != change.row2 || change.col1 != change.col2) |
|
256 |
- { |
|
257 |
- evaluateChange(domain, prop, change, std::log(gaps::random::uniform())); |
|
258 |
- } |
|
259 |
-} |
|
260 |
- |
|
261 |
-void GibbsSampler::exchange(AtomicSupport &domain, AtomicProposal &prop) |
|
262 |
-{ |
|
263 |
- MatrixChange change(prop.label, domain.getRow(prop.pos1), |
|
264 |
- domain.getCol(prop.pos1), prop.delta1, domain.getRow(prop.pos2), |
|
265 |
- domain.getCol(prop.pos2), prop.delta2); |
|
266 |
- if (change.row1 == change.row2 && change.col1 == change.col2) |
|
267 |
- { |
|
268 |
- return; |
|
269 |
- } |
|
270 |
- |
|
271 |
- float mass1 = domain.at(prop.pos1); |
|
272 |
- float mass2 = domain.at(prop.pos2); |
|
273 |
- float newMass1 = mass1 + prop.delta1; |
|
274 |
- float newMass2 = mass2 + prop.delta2; |
|
275 |
- |
|
276 |
- if (canUseGibbs(change)) |
|
277 |
- { |
|
278 |
- AlphaParameters alphaParam = gaps::algo::alphaParameters(change, |
|
279 |
- mDMatrix, mSMatrix, mAMatrix, mPMatrix, mAPMatrix); |
|
280 |
- alphaParam.s *= mAnnealingTemp; |
|