Browse code

gibbs sampler interface

sherman5 authored on 19/03/2018 23:23:12
Showing16 changed files

... ...
@@ -5,8 +5,8 @@ cogaps_cpp <- function(D, S, nFactor, nEquil, nEquilCool, nSample, nOutputs, nSn
5 5
     .Call('_CoGAPS_cogaps_cpp', PACKAGE = 'CoGAPS', D, S, nFactor, nEquil, nEquilCool, nSample, nOutputs, nSnapshots, alphaA, alphaP, maxGibbmassA, maxGibbmassP, seed, messages, singleCellRNASeq, whichMatrixFixed, FP, checkpointInterval, cptFile, pumpThreshold, nPumpSamples)
6 6
 }
7 7
 
8
-cogapsFromCheckpoint_cpp <- function(D, S, fileName, cptFile) {
9
-    .Call('_CoGAPS_cogapsFromCheckpoint_cpp', PACKAGE = 'CoGAPS', D, S, fileName, cptFile)
8
+cogapsFromCheckpoint_cpp <- function(D, S, nFactor, nEquil, nSample, fileName, cptFile) {
9
+    .Call('_CoGAPS_cogapsFromCheckpoint_cpp', PACKAGE = 'CoGAPS', D, S, nFactor, nEquil, nSample, fileName, cptFile)
10 10
 }
11 11
 
12 12
 displayBuildReport_cpp <- function() {
... ...
@@ -1,208 +1,14 @@
1
-#include "GapsAssert.h"
2
-#include "GibbsSampler.h"
3
-#include "Matrix.h"
4 1
 #include "Archive.h"
5
-#include "InternalState.h"
2
+#include "GapsRunner.h"
3
+#include "Random.h"
6 4
 #include "SIMD.h"
7 5
 
8 6
 #include <Rcpp.h>
9
-#include <ctime>
10
-#include <fstream>
11
-#include <cstdio>
12
-#include <boost/date_time/posix_time/posix_time.hpp>
13
-#include <boost/archive/text_oarchive.hpp>
14
-#include <boost/archive/text_iarchive.hpp>
15 7
 
16 8
 // used to convert defined macro values into strings
17 9
 #define STR_HELPER(x) #x
18 10
 #define STR(x) STR_HELPER(x)
19 11
 
20
-// boost time helpers
21
-namespace bpt = boost::posix_time;
22
-#define bpt_now() bpt::microsec_clock::local_time()
23
-
24
-// keeps track of when checkpoints are made
25
-static bpt::ptime lastCheckpoint; 
26
-static std::string checkpointFile;
27
-
28
-// save the current internal state to a file
29
-static void createCheckpoint(GapsInternalState &state)
30
-{
31
-    // create backup file
32
-    std::rename(checkpointFile.c_str(), (checkpointFile + ".backup").c_str());
33
-
34
-    // record starting time
35
-    bpt::ptime start = bpt_now();
36
-
37
-    // save state to file, write magic number at beginning
38
-    std::string fname(checkpointFile);
39
-    Archive ar(fname, ARCHIVE_WRITE);
40
-    gaps::random::save(ar);
41
-    ar << state.nFactor << state.nEquil << state.nSample << state;
42
-    ar.close();
43
-
44
-    // display time it took to create checkpoint
45
-    bpt::time_duration diff = bpt_now() - start;
46
-    double elapsed = diff.total_milliseconds() / 1000.;
47
-    Rprintf("created checkpoint in %.3f seconds\n", elapsed);
48
-
49
-    // delete backup file
50
-    std::remove((checkpointFile + ".backup").c_str());
51
-}
52
-
53
-static void updateSampler(GapsInternalState &state)
54
-{
55
-    state.nUpdatesA += state.nIterA;
56
-    update(state.ASampler, state.nIterA);
57
-    state.PSampler.sync(state.ASampler);
58
-
59
-    state.nUpdatesP += state.nIterP;
60
-    update(state.PSampler, state.nIterP);
61
-    state.ASampler.sync(state.PSampler);
62
-}
63
-
64
-static void makeCheckpointIfNeeded(GapsInternalState &state)
65
-{
66
-    bpt::time_duration diff = bpt_now() - lastCheckpoint;
67
-    int diff_sec = diff.total_milliseconds() / 1000;
68
-    if (diff_sec > state.checkpointInterval && state.checkpointInterval > 0)
69
-    {
70
-        createCheckpoint(state);
71
-        lastCheckpoint = bpt_now();
72
-    }
73
-}
74
-
75
-static void storeSamplerInfo(GapsInternalState &state, Vector &atomsA,
76
-Vector &atomsP, Vector &chi2)
77
-{
78
-    chi2[state.iter] = state.ASampler.chi2();
79
-    atomsA[state.iter] = state.ASampler.nAtoms();
80
-    atomsP[state.iter] = state.PSampler.nAtoms();
81
-    state.nIterA = gaps::random::poisson(std::max(atomsA[state.iter], 10.f));
82
-    state.nIterP = gaps::random::poisson(std::max(atomsP[state.iter], 10.f));
83
-}
84
-
85
-static void displayStatus(GapsInternalState &state, const std::string &type,
86
-unsigned nIterTotal)
87
-{
88
-    if ((state.iter + 1) % state.nOutputs == 0 && state.messages)
89
-    {
90
-        Rprintf("%s %d of %d, Atoms:%d(%d) Chi2 = %.2f\n", type.c_str(),
91
-            state.iter + 1, nIterTotal, state.ASampler.nAtoms(),
92
-            state.PSampler.nAtoms(), state.ASampler.chi2());
93
-    }
94
-}
95
-
96
-static void takeSnapshots(GapsInternalState &state)
97
-{
98
-    /*if (state.nSnapshots && !((state.iter+1)%(state.nSample/state.nSnapshots)))
99
-    {
100
-        state.snapshotsA.push_back(state.sampler.normedAMatrix().rMatrix());
101
-        state.snapshotsP.push_back(state.sampler.normedPMatrix().rMatrix());
102
-    }*/   
103
-}
104
-
105
-static void runBurnPhase(GapsInternalState &state)
106
-{
107
-    for (; state.iter < state.nEquil; ++state.iter)
108
-    {
109
-        makeCheckpointIfNeeded(state);
110
-        float temp = ((float)state.iter + 2.f) / ((float)state.nEquil / 2.f);
111
-        state.ASampler.setAnnealingTemp(std::min(1.f,temp));
112
-        state.PSampler.setAnnealingTemp(std::min(1.f,temp));
113
-        updateSampler(state);
114
-        displayStatus(state, "Equil: ", state.nEquil);
115
-        storeSamplerInfo(state, state.nAtomsAEquil, state.nAtomsPEquil,
116
-            state.chi2VecEquil);
117
-    }
118
-}
119
-
120
-static void runCoolPhase(GapsInternalState &state)
121
-{
122
-    for (; state.iter < state.nEquilCool; ++state.iter)
123
-    {
124
-        makeCheckpointIfNeeded(state);
125
-        updateSampler(state);
126
-    }
127
-}
128
-
129
-static void runSampPhase(GapsInternalState &state)
130
-{
131
-    for (; state.iter < state.nSample; ++state.iter)
132
-    {
133
-        makeCheckpointIfNeeded(state);
134
-        updateSampler(state);
135
-        //state.sampler.updateStatistics();
136
-        //if (state.nPumpSamples && !((state.iter + 1) % (state.nSample / state.nPumpSamples)))
137
-        //{
138
-        //    state.sampler.updatePumpStatistics();
139
-       // }
140
-        takeSnapshots(state);
141
-        displayStatus(state, "Samp: ", state.nSample);
142
-        storeSamplerInfo(state, state.nAtomsASample, state.nAtomsPSample,
143
-            state.chi2VecSample);
144
-    }
145
-}
146
-
147
-// execute the steps of the algorithm, return list to R
148
-static Rcpp::List runCogaps(GapsInternalState &state)
149
-{
150
-    // reset the checkpoint timer
151
-    lastCheckpoint = bpt_now();
152
-
153
-    // cascade down the various phases of the algorithm
154
-    // this allows for starting in the middle of the algorithm
155
-    switch (state.phase)
156
-    {
157
-        case GAPS_BURN:
158
-            runBurnPhase(state);
159
-            state.iter = 0;
160
-            state.phase = GAPS_COOL;
161
-
162
-        case GAPS_COOL:
163
-            runCoolPhase(state);
164
-            state.iter = 0;
165
-            state.phase = GAPS_SAMP;
166
-
167
-        case GAPS_SAMP:
168
-            runSampPhase(state);
169
-    }
170
-
171
-    // combine chi2 vectors
172
-    Vector chi2Vec(state.chi2VecEquil);
173
-    chi2Vec.concat(state.chi2VecSample);
174
-
175
-    // print final chi-sq value
176
-    /*float meanChiSq = state.sampler.meanChiSq();
177
-    if (state.messages)
178
-    {
179
-        Rprintf("Chi-Squared of Mean: %.2f\n", meanChiSq);
180
-    }*/
181
-
182
-    return Rcpp::List::create(
183
-        //Rcpp::Named("Amean") = state.sampler.AMeanRMatrix(),
184
-        //Rcpp::Named("Asd") = state.sampler.AStdRMatrix(),
185
-        //Rcpp::Named("Pmean") = state.sampler.PMeanRMatrix(),
186
-        //Rcpp::Named("Psd") = state.sampler.PStdRMatrix(),
187
-        Rcpp::Named("Amean") = Rcpp::NumericMatrix(193,9),
188
-        Rcpp::Named("Asd") = Rcpp::NumericMatrix(193,9),
189
-        Rcpp::Named("Pmean") = Rcpp::NumericMatrix(9,193),
190
-        Rcpp::Named("Psd") = Rcpp::NumericMatrix(9,193),
191
-        Rcpp::Named("ASnapshots") = Rcpp::wrap(state.snapshotsA),
192
-        Rcpp::Named("PSnapshots") = Rcpp::wrap(state.snapshotsP),
193
-        Rcpp::Named("atomsAEquil") = state.nAtomsAEquil.rVec(),
194
-        Rcpp::Named("atomsASamp") = state.nAtomsASample.rVec(),
195
-        Rcpp::Named("atomsPEquil") = state.nAtomsPEquil.rVec(),
196
-        Rcpp::Named("atomsPSamp") = state.nAtomsPSample.rVec(),
197
-        Rcpp::Named("chiSqValues") = chi2Vec.rVec(),
198
-        Rcpp::Named("randSeed") = state.seed,
199
-        Rcpp::Named("numUpdates") = state.nUpdatesA + state.nUpdatesP
200
-        //Rcpp::Named("meanChi2") = meanChiSq,
201
-        //Rcpp::Named("pumpStats") = state.sampler.pumpMatrix(),
202
-        //Rcpp::Named("meanPatternAssignment") = state.sampler.meanPattern()
203
-    );
204
-}
205
-
206 12
 // [[Rcpp::export]]
207 13
 Rcpp::List cogaps_cpp(const Rcpp::NumericMatrix &D,
208 14
 const Rcpp::NumericMatrix &S, unsigned nFactor, unsigned nEquil,
... ...
@@ -212,7 +18,7 @@ bool messages, bool singleCellRNASeq, char whichMatrixFixed,
212 18
 const Rcpp::NumericMatrix &FP, unsigned checkpointInterval,
213 19
 const std::string &cptFile, unsigned pumpThreshold, unsigned nPumpSamples)
214 20
 {
215
-    // set seed
21
+    // get seed, TODO do this on R side, multiple benefits (same seed in R, C++)
216 22
     uint32_t seedUsed = static_cast<uint32_t>(seed);
217 23
     if (seed < 0)
218 24
     {
... ...
@@ -220,37 +26,23 @@ const std::string &cptFile, unsigned pumpThreshold, unsigned nPumpSamples)
220 26
         bpt::time_duration diff = bpt_now() - epoch;
221 27
         seedUsed = static_cast<uint32_t>(diff.total_milliseconds() % 1000);
222 28
     }
223
-    gaps::random::setSeed(seedUsed);
224 29
 
225 30
     // create internal state from parameters and run from there
226
-    GapsInternalState state(D, S, nFactor, nEquil, nEquilCool, nSample,
31
+    GapsRunner runner(D, S, nFactor, nEquil, nEquilCool, nSample,
227 32
         nOutputs, nSnapshots, alphaA, alphaP, maxGibbmassA, maxGibbmassP, seed,
228
-        messages, singleCellRNASeq, whichMatrixFixed, FP, checkpointInterval);
229
-        //static_cast<PumpThreshold>(pumpThreshold), nPumpSamples);
230
-    checkpointFile = cptFile;
231
-    return runCogaps(state);
33
+        messages, singleCellRNASeq,  checkpointInterval, cptFile,
34
+        whichMatrixFixed, FP);
35
+    return runner.run();
232 36
 }
233 37
 
234 38
 // TODO add checksum to verify D,S matrices
235 39
 // [[Rcpp::export]]
236 40
 Rcpp::List cogapsFromCheckpoint_cpp(const Rcpp::NumericMatrix &D,
237
-const Rcpp::NumericMatrix &S, const std::string &fileName,
238
-const std::string &cptFile)
41
+const Rcpp::NumericMatrix &S, unsigned nFactor, unsigned nEquil,
42
+unsigned nSample, const std::string &fileName, const std::string &cptFile)
239 43
 {   
240
-    Archive ar(fileName, ARCHIVE_READ);
241
-    gaps::random::load(ar);
242
-
243
-    // read parameters needed to calculate the size of the internal state
244
-    unsigned nFactor = 0, nEquil = 0, nSample = 0;
245
-    ar >> nFactor >> nEquil >> nSample;
246
-    
247
-    // construct empty state of the correct size, populate from file
248
-    GapsInternalState state(D, S, nFactor, nEquil, nSample);
249
-    ar >> state;
250
-
251
-    // run cogaps from this internal state
252
-    checkpointFile = cptFile;
253
-    return runCogaps(state);
44
+    GapsRunner runner(D, S, nFactor, nEquil, nSample, cptFile);
45
+    return runner.run();
254 46
 }
255 47
 
256 48
 // [[Rcpp::export]]
257 49
new file mode 100644
... ...
@@ -0,0 +1,197 @@
1
+#include "GapsRunner.h"
2
+
3
+GapsRunner::GapsRunner(const Rcpp::NumericMatrix &D, const Rcpp::NumericMatrix &S,
4
+unsigned nFactor, unsigned nEquil, unsigned nCool, unsigned nSample,
5
+unsigned nOutputs, unsigned nSnapshots, float alphaA, float alphaP,
6
+float maxGibbsMassA, float maxGibbsMassP, uint32_t seed, bool messages,
7
+bool singleCellRNASeq, unsigned cptInterval, const std::string &cptFile,
8
+char whichMatrixFixed, const Rcpp::NumericMatrix &FP)
9
+    :
10
+mChiSqEquil(nEquil), mNumAAtomsEquil(nEquil), mNumPAtomsEquil(nEquil),
11
+mChiSqSample(nSample), mNumAAtomsSample(nSample), mNumPAtomsSample(nSample),
12
+mIterA(10), mIterP(10), mEquilIter(nEquil), mCoolIter(nCool),
13
+mSampleIter(nSample), mNumPatterns(nFactor), mNumOutputs(nOutputs),
14
+mPrintMessages(messages), mCurrentIter(0), mPhase(GAPS_BURN), mSeed(seed),
15
+mCheckpointInterval(cptInterval), mCheckpointFile(cptFile),
16
+mNumUpdatesA(0), mNumUpdatesP(0),
17
+mASampler(D, S, nFactor, alphaA, maxGibbsMassA),
18
+mPSampler(D, S, nFactor, alphaP, maxGibbsMassP),
19
+mStatistics(D.nrow(), D.ncol(), nFactor)
20
+{
21
+    mASampler.sync(mPSampler);
22
+    mPSampler.sync(mASampler);
23
+    gaps::random::setSeed(seed);
24
+}
25
+
26
+GapsRunner::GapsRunner(const Rcpp::NumericMatrix &D, const Rcpp::NumericMatrix &S,
27
+unsigned nFactor, unsigned nEquil, unsigned nSample, const std::string &cptFile)
28
+    :
29
+mChiSqEquil(nEquil), mNumAAtomsEquil(nEquil), mNumPAtomsEquil(nEquil),
30
+mChiSqSample(nSample), mNumAAtomsSample(nSample), mNumPAtomsSample(nSample),
31
+mASampler(D, S, nFactor), mPSampler(D, S, nFactor),
32
+mStatistics(D.nrow(), D.ncol(), nFactor)
33
+{
34
+    Archive ar(cptFile, ARCHIVE_READ);
35
+    gaps::random::load(ar);
36
+
37
+   //ar >> mChiSqEquil >> mNumAAtomsEquil >> mNumPAtomsEquil >> mChiSqSample
38
+   //    >> mNumAAtomsSample >> mNumPAtomsSample >> mIterA >> mIterP
39
+   //    >> mEquilIter >> mCoolIter >> mSampleIter >> mNumPatterns >> mNumOutputs
40
+   //    >> mPrintMessages >> mCurrentIter >> mPhase >> mSeed
41
+   //    >> mCheckpointInterval >> mCheckpointFile >> mNumUpdatesA
42
+   //    >> mNumUpdatesP >> mASampler >> mPSampler >> mStatistics;
43
+}
44
+
45
+// execute the steps of the algorithm, return list to R
46
+Rcpp::List GapsRunner::run()
47
+{
48
+    // reset the checkpoint timer
49
+    mLastCheckpoint = bpt_now();
50
+
51
+    // cascade down the various phases of the algorithm
52
+    // this allows for starting in the middle of the algorithm
53
+    switch (mPhase)
54
+    {
55
+        case GAPS_BURN:
56
+            runBurnPhase();
57
+            mCurrentIter = 0;
58
+            mPhase = GAPS_COOL;
59
+        case GAPS_COOL:
60
+            runCoolPhase();
61
+            mCurrentIter = 0;
62
+            mPhase = GAPS_SAMP;
63
+        case GAPS_SAMP:
64
+            runSampPhase();
65
+    }
66
+
67
+    // combine chi2 vectors
68
+    Vector chi2Vec(mChiSqEquil);
69
+    chi2Vec.concat(mChiSqSample);
70
+
71
+    // print final chi-sq value
72
+    float meanChiSq = mStatistics.meanChiSq();
73
+    if (mPrintMessages)
74
+    {
75
+        Rprintf("Chi-Squared of Mean: %.2f\n", meanChiSq);
76
+    }
77
+
78
+    return Rcpp::List::create(
79
+        Rcpp::Named("Amean") = mStatistics.AMean(),
80
+        Rcpp::Named("Asd") = mStatistics.AStd(),
81
+        Rcpp::Named("Pmean") = mStatistics.PMean(),
82
+        Rcpp::Named("Psd") = mStatistics.PStd(),
83
+        Rcpp::Named("atomsAEquil") = mNumAAtomsEquil.rVec(),
84
+        Rcpp::Named("atomsASamp") = mNumAAtomsSample.rVec(),
85
+        Rcpp::Named("atomsPEquil") = mNumPAtomsEquil.rVec(),
86
+        Rcpp::Named("atomsPSamp") = mNumPAtomsSample.rVec(),
87
+        Rcpp::Named("chiSqValues") = chi2Vec.rVec(),
88
+        Rcpp::Named("randSeed") = mSeed,
89
+        Rcpp::Named("numUpdates") = mNumUpdatesA + mNumUpdatesP,
90
+        Rcpp::Named("meanChi2") = meanChiSq
91
+    );
92
+}
93
+
94
+void GapsRunner::runBurnPhase()
95
+{
96
+    for (; mCurrentIter < mEquilIter; ++mCurrentIter)
97
+    {
98
+        makeCheckpointIfNeeded();
99
+        float temp = ((float)mCurrentIter + 2.f) / ((float)mEquilIter / 2.f);
100
+        mASampler.setAnnealingTemp(std::min(1.f,temp));
101
+        mPSampler.setAnnealingTemp(std::min(1.f,temp));
102
+        updateSampler();
103
+        displayStatus("Equil: ", mEquilIter);
104
+        storeSamplerInfo(mNumAAtomsEquil, mNumPAtomsEquil, mChiSqEquil);
105
+    }
106
+}
107
+
108
+void GapsRunner::runCoolPhase()
109
+{
110
+    for (; mCurrentIter < mCoolIter; ++mCurrentIter)
111
+    {
112
+        makeCheckpointIfNeeded();
113
+        updateSampler();
114
+    }
115
+}
116
+
117
+void GapsRunner::runSampPhase()
118
+{
119
+    for (; mCurrentIter < mSampleIter; ++mCurrentIter)
120
+    {
121
+        makeCheckpointIfNeeded();
122
+        updateSampler();
123
+        mStatistics.update(mASampler, mPSampler);
124
+        displayStatus("Samp: ", mSampleIter);
125
+        storeSamplerInfo(mNumAAtomsSample, mNumPAtomsSample, mChiSqSample);
126
+    }
127
+}
128
+
129
+void GapsRunner::updateSampler()
130
+{
131
+    mNumUpdatesA += mIterA;
132
+    mASampler.update(mIterA);
133
+    mPSampler.sync(mASampler);
134
+
135
+    mNumUpdatesA += mIterP;
136
+    mPSampler.update(mIterP);
137
+    mASampler.sync(mPSampler);
138
+}
139
+
140
+void GapsRunner::storeSamplerInfo(Vector &atomsA, Vector &atomsP, Vector &chi2)
141
+{
142
+    chi2[mCurrentIter] = mASampler.chi2();
143
+    atomsA[mCurrentIter] = mASampler.nAtoms();
144
+    atomsP[mCurrentIter] = mPSampler.nAtoms();
145
+    mIterA = gaps::random::poisson(std::max(atomsA[mCurrentIter], 10.f));
146
+    mIterP = gaps::random::poisson(std::max(atomsP[mCurrentIter], 10.f));
147
+}
148
+
149
+void GapsRunner::displayStatus(const std::string &type, unsigned nIterTotal)
150
+{
151
+    if ((mCurrentIter + 1) % mNumOutputs == 0 && mPrintMessages)
152
+    {
153
+        Rprintf("%s %d of %d, Atoms:%d(%d) Chi2 = %.2f\n", type.c_str(),
154
+            mCurrentIter + 1, nIterTotal, mASampler.nAtoms(),
155
+            mPSampler.nAtoms(), mASampler.chi2());
156
+    }
157
+}
158
+
159
+// save the current internal state to a file
160
+void GapsRunner::createCheckpoint()
161
+{
162
+    // create backup file
163
+    std::rename(mCheckpointFile.c_str(), (mCheckpointFile + ".backup").c_str());
164
+
165
+    // record starting time
166
+    bpt::ptime start = bpt_now();
167
+
168
+    // save state to file, write magic number at beginning
169
+    Archive ar(mCheckpointFile, ARCHIVE_WRITE);
170
+    gaps::random::save(ar);
171
+    //ar << mChiSqEquil << mNumAAtomsEquil << mNumPAtomsEquil << mChiSqSample
172
+    //    << mNumAAtomsSample << mNumPAtomsSample << mIterA << mIterP
173
+    //    << mEquilIter << mCoolIter << mSampleIter << mNumPatterns << mNumOutputs
174
+    //    << mPrintMessages << mCurrentIter << mPhase << mSeed
175
+    //    << mCheckpointInterval << mNumUpdatesA << mNumUpdatesP << mASampler
176
+    //    << mPSampler << mStatistics;
177
+    ar.close();
178
+
179
+    // display time it took to create checkpoint
180
+    bpt::time_duration diff = bpt_now() - start;
181
+    double elapsed = diff.total_milliseconds() / 1000.;
182
+    Rprintf("created checkpoint in %.3f seconds\n", elapsed);
183
+
184
+    // delete backup file
185
+    std::remove((mCheckpointFile + ".backup").c_str());
186
+}
187
+
188
+void GapsRunner::makeCheckpointIfNeeded()
189
+{
190
+    bpt::time_duration diff = bpt_now() - mLastCheckpoint;
191
+    long sec = diff.total_milliseconds() / 1000;
192
+    if (sec > mCheckpointInterval && mCheckpointInterval > 0)
193
+    {
194
+        createCheckpoint();
195
+        mLastCheckpoint = bpt_now();
196
+    }
197
+}
0 198
\ No newline at end of file
1 199
new file mode 100644
... ...
@@ -0,0 +1,86 @@
1
+#ifndef __GAPS_GAPS_RUNNER_H__
2
+#define __GAPS_GAPS_RUNNER_H__
3
+
4
+#include "Archive.h"
5
+#include "Matrix.h"
6
+#include "GibbsSampler.h"
7
+#include "GapsStatistics.h"
8
+
9
+#include <Rcpp.h>
10
+
11
+// boost time helpers
12
+#include <boost/date_time/posix_time/posix_time.hpp>
13
+namespace bpt = boost::posix_time;
14
+#define bpt_now() bpt::microsec_clock::local_time()
15
+
16
+enum GapsPhase
17
+{
18
+    GAPS_BURN,
19
+    GAPS_COOL,
20
+    GAPS_SAMP
21
+};
22
+
23
+class GapsRunner
24
+{
25
+private:
26
+
27
+    Vector mChiSqEquil;
28
+    Vector mNumAAtomsEquil;
29
+    Vector mNumPAtomsEquil;
30
+
31
+    Vector mChiSqSample;
32
+    Vector mNumAAtomsSample;
33
+    Vector mNumPAtomsSample;
34
+
35
+    unsigned mIterA;
36
+    unsigned mIterP;
37
+    
38
+    unsigned mEquilIter;
39
+    unsigned mCoolIter;
40
+    unsigned mSampleIter;
41
+
42
+    unsigned mNumPatterns;
43
+    unsigned mNumOutputs;
44
+    bool mPrintMessages;
45
+
46
+    unsigned mCurrentIter;
47
+    GapsPhase mPhase;
48
+    uint32_t mSeed;
49
+
50
+    bpt::ptime mLastCheckpoint;
51
+    long mCheckpointInterval;
52
+    std::string mCheckpointFile;
53
+
54
+    unsigned mNumUpdatesA;
55
+    unsigned mNumUpdatesP;
56
+    
57
+    AmplitudeGibbsSampler mASampler;
58
+    PatternGibbsSampler mPSampler;
59
+    GapsStatistics mStatistics;
60
+
61
+    void createCheckpoint();
62
+    void makeCheckpointIfNeeded();
63
+    void displayStatus(const std::string &type, unsigned nIterTotal);
64
+    void storeSamplerInfo(Vector &atomA, Vector &atomsP, Vector &chi2);
65
+    void updateSampler();
66
+    void runBurnPhase();
67
+    void runCoolPhase();
68
+    void runSampPhase();
69
+
70
+public:
71
+
72
+    GapsRunner(const Rcpp::NumericMatrix &D, const Rcpp::NumericMatrix &S,
73
+        unsigned nFactor, unsigned nEquil, unsigned nCool, unsigned nSample,
74
+        unsigned nOutputs, unsigned nSnapshots, float alphaA, float alphaP,
75
+        float maxGibbsMassA, float maxGibbsMassP, uint32_t seed, bool messages,
76
+        bool singleCellRNASeq, unsigned cptInterval, const std::string &cptFile,
77
+        char whichMatrixFixed, const Rcpp::NumericMatrix &FP);
78
+
79
+    GapsRunner(const Rcpp::NumericMatrix &D, const Rcpp::NumericMatrix &S,
80
+        unsigned nFactor, unsigned nEquil, unsigned nSample,
81
+        const std::string &cptFile);
82
+
83
+    Rcpp::List run();
84
+};
85
+
86
+#endif
0 87
\ No newline at end of file
... ...
@@ -1,8 +1,10 @@
1 1
 #include "GapsStatistics.h"
2
+#include "Algorithms.h"
2 3
 
3 4
 GapsStatistics::GapsStatistics(unsigned nRow, unsigned nCol, unsigned nFactor)
4 5
     : mAMeanMatrix(nRow, nFactor), mAStdMatrix(nRow, nFactor),
5
-        mPMeanMatrix(nFactor, nCol), mPStdMatrix(nFactor, nCol), mStatUpdates(0)
6
+        mPMeanMatrix(nFactor, nCol), mPStdMatrix(nFactor, nCol),
7
+        mNumPatterns(nFactor), mStatUpdates(0)
6 8
 {}
7 9
 
8 10
 void GapsStatistics::update(const AmplitudeGibbsSampler &ASampler,
... ...
@@ -10,9 +12,20 @@ const PatternGibbsSampler &PSampler)
10 12
 {
11 13
     mStatUpdates++;
12 14
 
13
-    mAMeanMatrix += ASampler.mAMatrix;
14
-    mPMeanMatrix += PSampler.mPMatrix;
15
-    
15
+    Vector normVec(mNumPatterns);
16
+    for (unsigned j = 0; j < mNumPatterns; ++j)
17
+    {
18
+        //normVec[j] = gaps::algo::sum(PSampler.mPMatrix.getRow(j));
19
+        //normVec[j] = normVec[j] == 0 ? 1.f : normVec[j];
20
+
21
+        //Vector quot(PSampler.mPMatrix.getRow(j) / normVec[j]);
22
+        //mPMeanMatrix.getRow(j) += quot;
23
+        //mPStdMatrix.getRow(j) += gaps::algo::elementSq(quot);
24
+
25
+        //Vector prod(ASampler.mAMatrix.getCol(j) * normVec[j]);
26
+        //mAMeanMatrix.getCol(j) += prod;
27
+        //mAStdMatrix.getCol(j) += gaps::algo::elementSq(prod); 
28
+    }
16 29
 }
17 30
 
18 31
 Rcpp::NumericMatrix GapsStatistics::AMean() const
... ...
@@ -36,3 +49,12 @@ Rcpp::NumericMatrix GapsStatistics::PStd() const
36 49
     return gaps::algo::computeStdDev(mPStdMatrix, mPMeanMatrix,
37 50
         mStatUpdates).rMatrix();
38 51
 }
52
+
53
+float GapsStatistics::meanChiSq() const
54
+{
55
+    //ColMatrix A = mAMeanMatrix / mStatUpdates;
56
+    //RowMatrix P = mPMeanMatrix / mStatUpdates;
57
+    //RowMatrix M(A * P);
58
+    return 0.f;
59
+}
60
+
... ...
@@ -13,16 +13,19 @@ private:
13 13
     RowMatrix mPMeanMatrix;
14 14
     RowMatrix mPStdMatrix;
15 15
     
16
-    unsigned mStatUpdates;    
16
+    unsigned mStatUpdates;
17
+    unsigned mNumPatterns;
17 18
 
18 19
 public:
19 20
 
20 21
     GapsStatistics(unsigned nRow, unsigned nCol, unsigned nFactor);
21 22
 
22
-    Rcpp::NumericMatrix AMean();
23
-    Rcpp::NumericMatrix AStd();
24
-    Rcpp::NumericMatrix PMean();
25
-    Rcpp::NumericMatrix PStd();
23
+    Rcpp::NumericMatrix AMean() const;
24
+    Rcpp::NumericMatrix AStd() const;
25
+    Rcpp::NumericMatrix PMean() const;
26
+    Rcpp::NumericMatrix PStd() const;
27
+
28
+    float meanChiSq() const;
26 29
 
27 30
     void update(const AmplitudeGibbsSampler &ASampler,
28 31
         const PatternGibbsSampler &PSampler);
... ...
@@ -2,15 +2,8 @@
2 2
 
3 3
 AmplitudeGibbsSampler::AmplitudeGibbsSampler(const Rcpp::NumericMatrix &D,
4 4
 const Rcpp::NumericMatrix &S, unsigned nFactor, float alpha, float maxGibbsMass)
5
-    :
6
-mAMatrix(D.nrow(), nFactor), mDMatrix(D), mSMatrix(S),
7
-mAPMatrix(D.nrow(), D.ncol()), mQueue(D.nrow() * nFactor, alpha),
8
-mAnnealingTemp(0.f), mNumRows(D.nrow()), mNumCols(nFactor)
5
+    : GibbsSampler(D, S, D.nrow(), nFactor, alpha)
9 6
 {
10
-    mBinSize = std::numeric_limits<uint64_t>::max() / (mNumRows * mNumCols);
11
-    uint64_t remain = std::numeric_limits<uint64_t>::max() % (mNumRows * mNumCols);
12
-    mQueue.setDomainSize(std::numeric_limits<uint64_t>::max() - remain);
13
-
14 7
     float meanD = gaps::algo::mean(mDMatrix);
15 8
     mLambda = alpha * std::sqrt(nFactor / meanD);
16 9
     mMaxGibbsMass = maxGibbsMass / mLambda;
... ...
@@ -18,65 +11,47 @@ mAnnealingTemp(0.f), mNumRows(D.nrow()), mNumCols(nFactor)
18 11
 
19 12
 unsigned AmplitudeGibbsSampler::getRow(uint64_t pos) const
20 13
 {
21
-    return pos / (mBinSize * mNumCols);
14
+    unsigned row = pos / (mBinSize * mNumCols);
15
+    GAPS_ASSERT(row >= 0 && row < mNumRows);
16
+    return row;
22 17
 }
23 18
 
24 19
 unsigned AmplitudeGibbsSampler::getCol(uint64_t pos) const
25 20
 {
26
-    return (pos / mBinSize) % mNumCols;
21
+    unsigned col = (pos / mBinSize) % mNumCols;
22
+    GAPS_ASSERT(col >= 0 && col < mNumCols);
23
+    return col;
27 24
 }
28 25
 
29 26
 bool AmplitudeGibbsSampler::canUseGibbs(unsigned row, unsigned col) const
30 27
 {
31
-    return !gaps::algo::isRowZero(*mPMatrix, col);
28
+    return !gaps::algo::isRowZero(*mOtherMatrix, col);
32 29
 }
33 30
 
34 31
 bool AmplitudeGibbsSampler::canUseGibbs(unsigned r1, unsigned c1, unsigned r2, unsigned c2) const
35 32
 {
36
-    return !gaps::algo::isRowZero(*mPMatrix, c1)
37
-        && !gaps::algo::isRowZero(*mPMatrix, c2);
38
-}
39
-
40
-void AmplitudeGibbsSampler::updateAPMatrix(unsigned row, unsigned col, float delta)
41
-{
42
-    for (unsigned j = 0; j < mAPMatrix.nCol(); ++j)
43
-    {
44
-        mAPMatrix(row,j) += delta * (*mPMatrix)(j,col);
45
-    }
33
+    return !gaps::algo::isRowZero(*mOtherMatrix, c1)
34
+        && !gaps::algo::isRowZero(*mOtherMatrix, c2);
46 35
 }
47 36
 
48 37
 void AmplitudeGibbsSampler::sync(PatternGibbsSampler &sampler)
49 38
 {
50
-    mPMatrix = &(sampler.mPMatrix);
39
+    mOtherMatrix = &sampler.mMatrix;
51 40
     mAPMatrix = sampler.mAPMatrix;
52 41
 }
53 42
 
54
-float AmplitudeGibbsSampler::nAtoms() const
55
-{
56
-    return mDomain.size();
57
-}
58
-
59
-void AmplitudeGibbsSampler::setAnnealingTemp(float temp)
60
-{
61
-    mAnnealingTemp = temp;
62
-}
63
-
64
-float AmplitudeGibbsSampler::chi2() const
43
+void AmplitudeGibbsSampler::updateAPMatrix(unsigned row, unsigned col, float delta)
65 44
 {
66
-    return 2.f * gaps::algo::loglikelihood(mDMatrix, mSMatrix, mAPMatrix);   
45
+    for (unsigned j = 0; j < mAPMatrix.nCol(); ++j)
46
+    {
47
+        mAPMatrix(row,j) += delta * (*mOtherMatrix)(j,col);
48
+    }
67 49
 }
68 50
 
69 51
 PatternGibbsSampler::PatternGibbsSampler(const Rcpp::NumericMatrix &D,
70 52
 const Rcpp::NumericMatrix &S, unsigned nFactor, float alpha, float maxGibbsMass)
71
-    :
72
-mPMatrix(nFactor, D.ncol()), mDMatrix(D), mSMatrix(S),
73
-mAPMatrix(D.nrow(), D.ncol()), mQueue(nFactor * D.ncol(), alpha),
74
-mAnnealingTemp(0.f), mNumRows(nFactor), mNumCols(D.nrow())
53
+    : GibbsSampler(D, S, nFactor, D.ncol(), alpha)
75 54
 {
76
-    mBinSize = std::numeric_limits<uint64_t>::max() / (mNumRows * mNumCols);
77
-    uint64_t remain = std::numeric_limits<uint64_t>::max() % (mNumRows * mNumCols);
78
-    mQueue.setDomainSize(std::numeric_limits<uint64_t>::max() - remain);
79
-
80 55
     float meanD = gaps::algo::mean(mDMatrix);
81 56
     mLambda = alpha * std::sqrt(nFactor / meanD);
82 57
     mMaxGibbsMass = maxGibbsMass / mLambda;
... ...
@@ -84,50 +59,39 @@ mAnnealingTemp(0.f), mNumRows(nFactor), mNumCols(D.nrow())
84 59
 
85 60
 unsigned PatternGibbsSampler::getRow(uint64_t pos) const
86 61
 {
87
-    return (pos / mBinSize) % mNumRows;
62
+    unsigned row = (pos / mBinSize) % mNumRows;
63
+    GAPS_ASSERT(row >= 0 && row < mNumRows);
64
+    return row;
88 65
 }
89 66
 
90 67
 unsigned PatternGibbsSampler::getCol(uint64_t pos) const
91 68
 {
92
-    return pos / (mBinSize * mNumRows);
69
+    unsigned col = pos / (mBinSize * mNumRows);
70
+    GAPS_ASSERT(col >= 0 && col < mNumCols);
71
+    return col;
93 72
 }
94 73
 
95 74
 bool PatternGibbsSampler::canUseGibbs(unsigned row, unsigned col) const
96 75
 {
97
-    return !gaps::algo::isColZero(*mAMatrix, row);
76
+    return !gaps::algo::isColZero(*mOtherMatrix, row);
98 77
 }
99 78
 
100 79
 bool PatternGibbsSampler::canUseGibbs(unsigned r1, unsigned c1, unsigned r2, unsigned c2) const
101 80
 {
102
-    return !gaps::algo::isColZero(*mAMatrix, r1)
103
-        && !gaps::algo::isColZero(*mAMatrix, r2);
104
-}
105
-
106
-void PatternGibbsSampler::updateAPMatrix(unsigned row, unsigned col, float delta)
107
-{
108
-    for (unsigned i = 0; i < mAPMatrix.nRow(); ++i)
109
-    {
110
-        mAPMatrix(i,col) += delta * (*mAMatrix)(i,row);
111
-    }
81
+    return !gaps::algo::isColZero(*mOtherMatrix, r1)
82
+        && !gaps::algo::isColZero(*mOtherMatrix, r2);
112 83
 }
113 84
 
114 85
 void PatternGibbsSampler::sync(AmplitudeGibbsSampler &sampler)
115 86
 {
116
-    mAMatrix = &(sampler.mAMatrix);
87
+    mOtherMatrix = &sampler.mMatrix;
117 88
     mAPMatrix = sampler.mAPMatrix;
118 89
 }
119 90
 
120
-float PatternGibbsSampler::nAtoms() const
121
-{
122
-    return mDomain.size();
123
-}
124
-
125
-void PatternGibbsSampler::setAnnealingTemp(float temp)
126
-{
127
-    mAnnealingTemp = temp;
128
-}
129
-
130
-float PatternGibbsSampler::chi2() const
91
+void PatternGibbsSampler::updateAPMatrix(unsigned row, unsigned col, float delta)
131 92
 {
132
-    return 2.f * gaps::algo::loglikelihood(mDMatrix, mSMatrix, mAPMatrix);   
93
+    for (unsigned i = 0; i < mAPMatrix.nRow(); ++i)
94
+    {
95
+        mAPMatrix(i,col) += delta * (*mOtherMatrix)(i,row);
96
+    }
133 97
 }
134 98
\ No newline at end of file
... ...
@@ -1,8 +1,7 @@
1
-#ifndef __GAPS_GIBBS_SAMPLER_H__
2
-#define __GAPS_GIBBS_SAMPLER_H__
1
+#ifndef __COGAPS_GIBBS_SAMPLER_H__
2
+#define __COGAPS_GIBBS_SAMPLER_H__
3 3
 
4 4
 #include "GapsAssert.h"
5
-#include "GapsStatistics.h"
6 5
 #include "Archive.h"
7 6
 #include "Matrix.h"
8 7
 #include "Random.h"
... ...
@@ -12,28 +11,30 @@
12 11
 
13 12
 #include <Rcpp.h>
14 13
 
15
-// stats should be friend class
16
-// CRTP not really needed here, free friend functions are clear enough,
17
-// it would reduce some repeated code - but only a few lines and probably
18
-// not worth the headache
19
-// maybe convert to CRTP once everything is working cleanly
20
-
14
+// forward declarations needed for friend classes
21 15
 class AmplitudeGibbsSampler;
22 16
 class PatternGibbsSampler;
23 17
 class GapsStatistics;
24 18
 
25
-class AmplitudeGibbsSampler
19
+/************************** GIBBS SAMPLER INTERFACE **************************/
20
+
21
+template <class T, class MatA, class MatB>
22
+class GibbsSampler
26 23
 {
27 24
 private:
28 25
 
29
-    friend PatternGibbsSampler;
30
-    friend GapsStatistics;
26
+    friend T; // prevent incorrect inheritance - only T can construct
27
+
28
+    GibbsSampler(const Rcpp::NumericMatrix &D, const Rcpp::NumericMatrix &S,
29
+        unsigned nrow, unsigned ncol, float alpha);
31 30
 
32
-    ColMatrix mAMatrix;
33
-    RowMatrix *mPMatrix;
34
-    RowMatrix mDMatrix;
35
-    RowMatrix mSMatrix;
36
-    RowMatrix mAPMatrix;
31
+protected:
32
+
33
+    MatA mMatrix;
34
+    MatB* mOtherMatrix;
35
+    MatB mDMatrix;
36
+    MatB mSMatrix;
37
+    MatB mAPMatrix;
37 38
 
38 39
     ProposalQueue mQueue;
39 40
     AtomicDomain mDomain;
... ...
@@ -44,7 +45,36 @@ private:
44 45
     
45 46
     unsigned mNumRows;
46 47
     unsigned mNumCols;
47
-    unsigned mBinSize;
48
+    uint64_t mBinSize;
49
+
50
+    T* impl();
51
+
52
+    void processProposal(const AtomicProposal &prop);
53
+    void birth(uint64_t pos);
54
+    void death(uint64_t pos, float mass);
55
+    void move(uint64_t src, float mass, uint64_t dest);
56
+    void exchange(uint64_t p1, float mass1, uint64_t p2, float mass2);
57
+    float gibbsMass(unsigned row, unsigned col);
58
+
59
+public:
60
+
61
+    void update(unsigned nSteps);
62
+    void setAnnealingTemp(float temp);
63
+    
64
+    float chi2() const;
65
+    float nAtoms() const;
66
+
67
+    // serialization
68
+    //friend Archive& operator<<(Archive &ar, GibbsSampler &sampler);
69
+    //friend Archive& operator>>(Archive &ar, GibbsSampler &sampler);
70
+};
71
+
72
+class AmplitudeGibbsSampler : public GibbsSampler<AmplitudeGibbsSampler, ColMatrix, RowMatrix>
73
+{
74
+private:
75
+
76
+    friend GibbsSampler;
77
+    friend PatternGibbsSampler;
48 78
 
49 79
     unsigned getRow(uint64_t pos) const;
50 80
     unsigned getCol(uint64_t pos) const;
... ...
@@ -59,49 +89,14 @@ public:
59 89
         float maxGibbsMass=0.f);
60 90
 
61 91
     void sync(PatternGibbsSampler &sampler);
62
-    float nAtoms() const;
63
-    void setAnnealingTemp(float temp);
64
-    float chi2() const;
65
-
66
-    template <class Sampler>
67
-    friend void update(Sampler&, unsigned);
68
-
69
-    template <class Sampler>
70
-    friend void processProposal(Sampler &sampler, const AtomicProposal &prop);
71
-
72
-    template <class Sampler>
73
-    friend void birth(Sampler &sampler, uint64_t pos);
74
-
75
-    template <class Sampler>
76
-    friend void death(Sampler &sampler, uint64_t pos, float mass);
77
-
78
-    template <class Sampler>
79
-    friend void exchange(Sampler &sampler, uint64_t p1, float mass1, uint64_t p2, float mass2);
80 92
 };
81 93
 
82
-class PatternGibbsSampler
94
+class PatternGibbsSampler : public GibbsSampler<PatternGibbsSampler, RowMatrix, ColMatrix>
83 95
 {
84 96
 private:
85 97
 
98
+    friend GibbsSampler;
86 99
     friend AmplitudeGibbsSampler;
87
-    friend GapsStatistics;
88
-
89
-    RowMatrix mPMatrix;
90
-    ColMatrix *mAMatrix;
91
-    ColMatrix mDMatrix;
92
-    ColMatrix mSMatrix;
93
-    ColMatrix mAPMatrix;
94
-
95
-    ProposalQueue mQueue;
96
-    AtomicDomain mDomain;
97
-
98
-    float mLambda;
99
-    float mMaxGibbsMass;
100
-    float mAnnealingTemp;
101
-    
102
-    unsigned mNumRows;
103
-    unsigned mNumCols;
104
-    unsigned mBinSize;
105 100
 
106 101
     unsigned getRow(uint64_t pos) const;
107 102
     unsigned getCol(uint64_t pos) const;
... ...
@@ -116,128 +111,196 @@ public:
116 111
         float maxGibbsMass=0.f);
117 112
 
118 113
     void sync(AmplitudeGibbsSampler &sampler);
119
-    float nAtoms() const;
120
-    void setAnnealingTemp(float temp);
121
-    float chi2() const;
122
-
123
-    template <class Sampler>
124
-    friend void update(Sampler&, unsigned);
125
-
126
-    template <class Sampler>
127
-    friend void processProposal(Sampler &sampler, const AtomicProposal &prop);
114
+};
128 115
 
129
-    template <class Sampler>
130
-    friend void birth(Sampler &sampler, uint64_t pos);
116
+/******************* IMPLEMENTATION OF TEMPLATED FUNCTIONS *******************/
131 117
 
132
-    template <class Sampler>
133
-    friend void death(Sampler &sampler, uint64_t pos, float mass);
118
+template <class T, class MatA, class MatB>
119
+GibbsSampler<T, MatA, MatB>::GibbsSampler(const Rcpp::NumericMatrix &D,
120
+const Rcpp::NumericMatrix &S, unsigned nrow, unsigned ncol, float alpha)
121
+    :
122
+mMatrix(nrow, ncol), mOtherMatrix(NULL), mDMatrix(D), mSMatrix(S),
123
+mAPMatrix(D.nrow(), D.ncol()), mQueue(nrow * ncol, alpha),
124
+mAnnealingTemp(0.f), mNumRows(nrow), mNumCols(ncol)
125
+{
126
+    mBinSize = std::numeric_limits<uint64_t>::max() / static_cast<uint64_t>(mNumRows * mNumCols);
127
+    uint64_t remain = std::numeric_limits<uint64_t>::max() % (mNumRows * mNumCols);
128
+    mQueue.setDomainSize(std::numeric_limits<uint64_t>::max() - remain);
129
+}
134 130
 
135
-    template <class Sampler>
136
-    friend void exchange(Sampler &sampler, uint64_t p1, float mass1, uint64_t p2, float mass2);
137
-};
131
+template <class T, class MatA, class MatB>
132
+T* GibbsSampler<T, MatA, MatB>::impl()
133
+{
134
+    return static_cast<T*>(this);
135
+}
138 136
 
139
-template <class Sampler>
140
-void update(Sampler &sampler, unsigned nSteps)
137
+template <class T, class MatA, class MatB>
138
+void GibbsSampler<T, MatA, MatB>::update(unsigned nSteps)
141 139
 {
142 140
     unsigned n = 0;
143 141
     while (n < nSteps)
144 142
     {
145 143
         /*
146
-        assert(nSteps - (queue.size() + n) >= 0);
144
+        GAPS_ASSERT(nSteps - (queue.size() + n) >= 0);
147 145
         mQueue.populate(mDomain, nSteps - (mQueue.size() + n))
148 146
 
149 147
         // would making this a mulitple of nCores be better?
150 148
         unsigned nJobs = mQueue.size();
151 149
         for (unsigned i = 0; i < nJobs; ++i) // can be run in parallel
152 150
         {
153
-            processProposal(mDomain, mQueue[i]);
151
+            processProposal(mQueue[i]);
154 152
         }
155 153
         mQueue.clear();
156 154
         n += nJobs;
157
-        assert(n <= nSteps);
155
+        GAPS_ASSERT(n <= nSteps);
158 156
         */
159
-        sampler.mQueue.populate(sampler.mDomain, 1);
160
-        GAPS_ASSERT(sampler.mQueue.size() == 1);
161
-        processProposal(sampler, sampler.mQueue[0]);
162
-        sampler.mQueue.clear();
157
+        mQueue.populate(mDomain, 1);
158
+        GAPS_ASSERT(mQueue.size() == 1);
159
+        processProposal(mQueue[0]);
160
+        mQueue.clear(1);
161
+        GAPS_ASSERT(mQueue.size() == 0);
163 162
         n++;
164 163
     }
165 164
 }
166 165
 
167
-template <class Sampler>
168
-void processProposal(Sampler &sampler, const AtomicProposal &prop)
166
+template <class T, class MatA, class MatB>
167
+void GibbsSampler<T, MatA, MatB>::processProposal(const AtomicProposal &prop)
169 168
 {
170 169
     GAPS_ASSERT(prop.type == 'B' || prop.type == 'D' || prop.type == 'M' || prop.type == 'E');
171 170
     switch (prop.type)
172 171
     {
173
-        case 'B': birth(sampler, prop.pos1); break;
174
-        case 'D': death(sampler, prop.pos1, prop.mass1); break;
175
-        //case 'M': move(prop.pos1, prop.mass1, prop.pos2); break;
176
-        case 'E': exchange(sampler, prop.pos1, prop.mass1, prop.pos2, prop.mass2); break;
172
+        case 'B': birth(prop.pos1); break;
173
+        case 'D': death(prop.pos1, prop.mass1); break;
174
+        case 'M': move(prop.pos1, prop.mass1, prop.pos2); break;
175
+        case 'E': exchange(prop.pos1, prop.mass1, prop.pos2, prop.mass2); break;
177 176
     }
178 177
 }
179 178
 
180
-template <class Sampler>
181
-void birth(Sampler &sampler, uint64_t pos)
179
+template <class T, class MatA, class MatB>
180
+void GibbsSampler<T, MatA, MatB>::birth(uint64_t pos)
182 181
 {
183
-    unsigned row = sampler.getRow(pos);
184
-    unsigned col = sampler.getCol(pos);
185
-    float mass = gaps::random::exponential(sampler.mLambda);
186
-    sampler.mDomain.insert(pos, mass);
182
+    unsigned row = impl()->getRow(pos);
183
+    unsigned col = impl()->getCol(pos);
184
+    float mass = impl()->canUseGibbs(row, col) ? gibbsMass(row, col)
185
+        : gaps::random::exponential(mLambda);
186
+
187
+    mDomain.insert(pos, mass);
188
+    mMatrix(row, col) += mass;
189
+    //impl()->updateAPMatrix(row, col, mass);
187 190
 }
188 191
 
189
-template <class Sampler>
190
-void death(Sampler &sampler, uint64_t pos, float mass)
192
+template <class T, class MatA, class MatB>
193
+void GibbsSampler<T, MatA, MatB>::death(uint64_t pos, float mass)
191 194
 {
192
-    sampler.mQueue.rejectDeath();
195
+    mQueue.rejectDeath();
193 196
 }
194 197
 
195
-template <class Sampler>
196
-void exchange(Sampler &sampler, uint64_t p1, float mass1, uint64_t p2, float mass2)
198
+template <class T, class MatA, class MatB>
199
+void GibbsSampler<T, MatA, MatB>::move(uint64_t src, float mass, uint64_t dest)
197 200
 {
198
-    sampler.mQueue.rejectDeath();
201
+    unsigned r1 = impl()->getRow(src);
202
+    unsigned c1 = impl()->getCol(src);
203
+    unsigned r2 = impl()->getRow(dest);
204
+    unsigned c2 = impl()->getCol(dest);
205
+    if (r1 == r2 && c1 == c2)
206
+    {
207
+        mDomain.erase(src);
208
+        mDomain.insert(dest, mass);
209
+    }
210
+    else
211
+    {
212
+/*
213
+        if (deltaLL * mAnnealingTemp >= std::log(gaps::random::uniform()))
214
+        {
215
+            mDomain.deleteAtom(p1);
216
+            mDomain.addAtom(p2, mass);
217
+            mMatrix(r1, c1) += -mass;
218
+            mMatrix(r2, c2) += mass;
219
+            impl()->updateAPMatrix(r1, c1, -mass);
220
+            impl()->updateAPMatrix(r2, c2, mass);
221
+        }
222
+*/
223
+    }
199 224
 }
200 225
 
201
-//  template <class T, class MatA, class MatB>
202
-//  void GibbsSampler<T, MatA, MatB>::processProposal(const AtomicProposal &prop)
203
-//  {
204
-//      GAPS_ASSERT(prop.type == 'B' || prop.type == 'D' || prop.type == 'M' || prop.type == 'E');
205
-//      switch (prop.type)
206
-//      {
207
-//          case 'B': birth(prop.pos1); break;
208
-//          //case 'D': death(prop.pos1, prop.mass1); break;
209
-//          //case 'M': move(prop.pos1, prop.mass1, prop.pos2); break;
210
-//          //case 'E': exchange(prop.pos1, prop.mass1, prop.pos2, prop.mass2); break;
211
-//      }
212
-//  }
213
-//  
214
-//  template <class T, class MatA, class MatB>
215
-//  void GibbsSampler<T, MatA, MatB>::birth(uint64_t pos)
216
-//  {
217
-//      //GAPS_ASSERT(impl());
218
-//      //unsigned row = impl()->getRow(pos);
219
-//      //unsigned col = impl()->getCol(pos);
220
-//      //float mass = impl()->canUseGibbs(row, col) ? gibbsMass(row, col)
221
-//      //    : gaps::random::exponential(mLambda);
222
-//      /*float mass = gaps::random::exponential(mLambda);
223
-//  
224
-//      mDomain.insert(pos, mass);
225
-//      mMatrix(row, col) += mass;
226
-//      impl()->updateAPMatrix(row, col, mass);*/
227
-//  }
228
-//  
229
-//  template <class T, class MatA, class MatB>
230
-//  void GibbsSampler<T, MatA, MatB>::death(uint64_t pos, float mass)
231
-//  {
232
-//      /*unsigned row = impl()->getRow(pos);
233
-//      unsigned col = impl()->getCol(pos);
234
-//      mMatrix(row, col) += -mass;
235
-//      impl()->updateAPMatrix(row, col, -mass);
236
-//      mDomain.erase(pos);
237
-//      mQueue.acceptDeath();
238
-//  
239
-//      float newMass = impl()->canUseGibbs(row, col) ? gibbsMass(row, col)
240
-//      : mass;
226
+template <class T, class MatA, class MatB>
227
+void GibbsSampler<T, MatA, MatB>::exchange(uint64_t p1, float mass1, uint64_t p2, float mass2)
228
+{
229
+    mQueue.rejectDeath();
230
+}
231
+
232
+template <class T, class MatA, class MatB>
233
+float GibbsSampler<T, MatA, MatB>::gibbsMass(unsigned row, unsigned col)
234
+{        
235
+    //AlphaParameters alpha = impl()->alphaParameters(row, col);
236
+    AlphaParameters alpha(10.f, 10.f);
237
+    alpha.s *= mAnnealingTemp / 2.f;
238
+    alpha.su *= mAnnealingTemp / 2.f;
239
+    float mean  = (2.f * alpha.su - mLambda) / (2.f * alpha.s);
240
+    float sd = 1.f / std::sqrt(2.f * alpha.s);
241
+
242
+    float plower = gaps::random::p_norm(0.f, mean, sd);
243
+
244
+    //float newMass = death ? mass : 0.f;
245
+    float newMass = 0.f;
246
+    if (plower < 1.f && alpha.s > 0.00001f)
247
+    {
248
+        newMass = gaps::random::inverseNormSample(plower, 1.f, mean, sd);
249
+    }
250
+    return std::max(0.f, std::min(newMass, mMaxGibbsMass));
251
+}
252
+
253
+template <class T, class MatA, class MatB>
254
+void GibbsSampler<T, MatA, MatB>::setAnnealingTemp(float temp)
255
+{
256
+    mAnnealingTemp = temp;
257
+}
258
+  
259
+template <class T, class MatA, class MatB>
260
+float GibbsSampler<T, MatA, MatB>::chi2() const
261
+{   
262
+    return 2.f * gaps::algo::loglikelihood(mDMatrix, mSMatrix, mAPMatrix);
263
+}
264
+  
265
+template <class T, class MatA, class MatB>
266
+float GibbsSampler<T, MatA, MatB>::nAtoms() const
267
+{   
268
+    return mDomain.size();
269
+}
270
+
271
+
272
+//
273
+
274
+//
275
+
276
+//
277
+////  template <class T, class MatA, class MatB>
278
+////  void GibbsSampler<T, MatA, MatB>::processProposal(const AtomicProposal &prop)
279
+////  {
280
+////      GAPS_ASSERT(prop.type == 'B' || prop.type == 'D' || prop.type == 'M' || prop.type == 'E');
281
+////      switch (prop.type)
282
+////      {
283
+////          case 'B': birth(prop.pos1); break;
284
+////          //case 'D': death(prop.pos1, prop.mass1); break;
285
+////          //case 'M': move(prop.pos1, prop.mass1, prop.pos2); break;
286
+////          //case 'E': exchange(prop.pos1, prop.mass1, prop.pos2, prop.mass2); break;
287
+////      }
288
+////  }
289
+////  
290
+//
291
+////  
292
+////  template <class T, class MatA, class MatB>
293
+////  void GibbsSampler<T, MatA, MatB>::death(uint64_t pos, float mass)
294
+////  {
295
+////      /*unsigned row = impl()->getRow(pos);
296
+////      unsigned col = impl()->getCol(pos);
297
+////      mMatrix(row, col) += -mass;
298
+////      impl()->updateAPMatrix(row, col, -mass);
299
+////      mDomain.erase(pos);
300
+////      mQueue.acceptDeath();
301
+////  
302
+////      float newMass = impl()->canUseGibbs(row, col) ? gibbsMass(row, col)
303
+////      : mass;
241 304
 //      float deltaLL = impl()->computeDeltaLL(row, col, newMass);
242 305
 //  
243 306
 //      if (deltaLL * mAnnealingTemp >= std::log(gaps::random::uniform()))
... ...
@@ -327,37 +390,5 @@ void exchange(Sampler &sampler, uint64_t p1, float mass1, uint64_t p2, float mas
327 390
 //      return std::max(0.f, std::min(newMax, mMaxGibbsMass));
328 391
 //  */
329 392
 //  }
330
-//  
331
-//  template <class T, class MatA, class MatB>
332
-//  void GibbsSampler<T, MatA, MatB>::syncAP(const MatA &otherAP)
333
-//  {   
334
-//      mAPMatrix = otherAP;
335
-//  }
336
-//  
337
-//  template <class T, class MatA, class MatB>
338
-//  const MatB& GibbsSampler<T, MatA, MatB>::APMatrix() const
339
-//  {   
340
-//      return mAPMatrix;
341
-//  }
342
-//  
343
-//  template <class T, class MatA, class MatB>
344
-//  void GibbsSampler<T, MatA, MatB>::setAnnealingTemp(float temp)
345
-//  {
346
-//      mAnnealingTemp = temp;
347
-//  }
348
-//  
349
-//  template <class Sampler>
350
-//  float chi2(const Sampler &sampler)
351
-//  {   
352
-//      //return 2.f * gaps::algo::loglikelihood(mDMatrix, mSMatrix, mAPMatrix);
353
-//      return 0.f;
354
-//  }
355
-//  
356
-//  template <class T, class MatA, class MatB>
357
-//  float GibbsSampler<T, MatA, MatB>::nAtoms() const
358
-//  {   
359
-//      //return 2.f * gaps::algo::loglikelihood(mDMatrix, mSMatrix, mAPMatrix);
360
-//      return 0.f;
361
-//  }
362 393
 
363 394
 #endif
364 395
\ No newline at end of file
365 396
deleted file mode 100644
... ...
@@ -1,114 +0,0 @@
1
-#ifndef __COGAPS_INTERNAL_STATE_H__
2
-#define __COGAPS_INTERNAL_STATE_H__
3
-
4
-#include "Archive.h"
5
-#include "Matrix.h"
6
-#include "GibbsSampler.h"
7
-
8
-#include <Rcpp.h>
9
-
10
-typedef std::vector<Rcpp::NumericMatrix> SnapshotList;
11
-
12
-enum GapsPhase
13
-{
14
-    GAPS_BURN,
15
-    GAPS_COOL,
16
-    GAPS_SAMP
17
-};
18
-
19
-// maybe cleaner if this was a class with member functions? GapsRunner
20
-struct GapsInternalState
21
-{
22
-    Vector chi2VecEquil;
23
-    Vector nAtomsAEquil;
24
-    Vector nAtomsPEquil;
25
-
26
-    Vector chi2VecSample;
27
-    Vector nAtomsASample;
28
-    Vector nAtomsPSample;
29
-
30
-    unsigned nIterA;
31
-    unsigned nIterP;
32
-    
33
-    unsigned nEquil;
34
-    unsigned nEquilCool;
35
-    unsigned nSample;
36
-    unsigned nFactor;
37
-
38
-    unsigned nSnapshots;
39
-    unsigned nOutputs;
40
-    bool messages;
41
-
42
-    unsigned iter;
43
-    GapsPhase phase;
44
-    uint32_t seed;
45
-
46
-    long checkpointInterval;
47
-
48
-    unsigned nUpdatesA;
49
-    unsigned nUpdatesP;
50
-    
51
-    unsigned nPumpSamples;
52
-
53
-    AmplitudeGibbsSampler ASampler;
54
-    PatternGibbsSampler PSampler;
55
-    
56
-    SnapshotList snapshotsA;
57
-    SnapshotList snapshotsP;
58
-
59
-    GapsInternalState(const Rcpp::NumericMatrix &D,
60
-        const Rcpp::NumericMatrix &S, unsigned nF, unsigned nE, unsigned nEC,
61
-        unsigned nS, unsigned nOut, unsigned nSnap, float alphaA, float alphaP,
62
-        float maxGibbsMassA, float maxGibbsMassP, int sd, bool msgs,
63
-        bool singleCellRNASeq, char whichMatrixFixed,
64
-        const Rcpp::NumericMatrix &FP, unsigned cptInterval)
65
-        //PumpThreshold pumpThreshold, unsigned numPumpSamples)
66
-            :
67
-        chi2VecEquil(nE), nAtomsAEquil(nE), nAtomsPEquil(nE),
68
-        chi2VecSample(nS), nAtomsASample(nS), nAtomsPSample(nS),
69
-        nIterA(10), nIterP(10), nEquil(nE), nEquilCool(nEC), nSample(nS),
70
-        nSnapshots(nSnap), nOutputs(nOut), messages(msgs), iter(0),
71
-        phase(GAPS_BURN), seed(sd), checkpointInterval(cptInterval),
72
-        nUpdatesA(0), nUpdatesP(0), //nPumpSamples(numPumpSamples),
73
-        ASampler(D, S, nF, alphaA, maxGibbsMassA),
74
-        PSampler(D, S, nF, alphaP, maxGibbsMassP)
75
-    {}
76
-
77
-    GapsInternalState(const Rcpp::NumericMatrix &D,
78
-        const Rcpp::NumericMatrix &S, unsigned nF, unsigned nE, unsigned nS)
79
-            :
80
-        chi2VecEquil(nE), nAtomsAEquil(nE), nAtomsPEquil(nE),
81
-        chi2VecSample(nS), nAtomsASample(nS), nAtomsPSample(nS),
82
-        ASampler(D, S, nF), PSampler(D, S, nF)
83
-    {}
84
-};
85
-
86
-inline Archive& operator<<(Archive &ar, GapsInternalState &state)
87
-{
88
-/*
89
-    ar << state.chi2VecEquil << state.nAtomsAEquil << state.nAtomsPEquil
90
-        << state.chi2VecSample << state.nAtomsASample << state.nAtomsPSample
91
-        << state.nIterA << state.nIterP << state.nEquil << state.nEquilCool
92
-        << state.nSample << state.nSnapshots << state.nOutputs << state.messages
93
-        << state.iter << state.phase << state.seed << state.checkpointInterval
94
-        << state.nUpdatesA << state.nUpdatesP << state.nPumpSamples;
95
-        //<< state.sampler;
96
-*/
97
-    return ar;
98
-}
99
-
100
-inline Archive& operator>>(Archive &ar, GapsInternalState &state)
101
-{
102
-/*  
103
-    ar >> state.chi2VecEquil >> state.nAtomsAEquil >> state.nAtomsPEquil
104
-        >> state.chi2VecSample >> state.nAtomsASample >> state.nAtomsPSample
105
-        >> state.nIterA >> state.nIterP >> state.nEquil >> state.nEquilCool
106
-        >> state.nSample >> state.nSnapshots >> state.nOutputs >> state.messages
107
-        >> state.iter >> state.phase >> state.seed >> state.checkpointInterval
108
-        >> state.nUpdatesA >> state.nUpdatesP >> state.nPumpSamples;
109
-        //>> state.sampler;
110
-*/
111
-    return ar;
112
-}
113
-
114
-#endif
115 0
\ No newline at end of file
... ...
@@ -2,6 +2,8 @@ PKG_CPPFLAGS = -DGAPS_DEBUG -DBOOST_MATH_PROMOTE_DOUBLE_POLICY=0 -DSIMD -msse4.1
2 2
 OBJECTS =   Algorithms.o \
3 3
             AtomicDomain.o \
4 4
             Cogaps.o \
5
+            GapsRunner.o \
6
+            GapsStatistics.o \
5 7
             GibbsSampler.o \
6 8
             Matrix.o \
7 9
             ProposalQueue.o \
... ...
@@ -81,6 +81,8 @@ private:
81 81
     std::vector<Vector> mRows;
82 82
     unsigned mNumRows, mNumCols;
83 83
 
84
+    RowMatrix() {}
85
+
84 86
 public:
85 87
 
86 88
     RowMatrix(unsigned nrow, unsigned ncol);
... ...
@@ -117,6 +119,8 @@ private:
117 119
     std::vector<Vector> mCols;
118 120
     unsigned mNumRows, mNumCols;
119 121
 
122
+    ColMatrix() {}
123
+
120 124
 public:
121 125
 
122 126
     ColMatrix(unsigned nrow, unsigned ncol);
... ...
@@ -7,10 +7,8 @@ const double atomicSize = static_cast<double>(atomicEnd);
7 7
 
8 8
 void ProposalQueue::populate(const AtomicDomain &domain, unsigned limit)
9 9
 {
10
-    //unsigned nIter = 0;
11
-    //while (nIter++ < limit && makeProposal(domain));
12
-    GAPS_ASSERT(makeProposal(domain));
13
-    //Rprintf("%c\n", mQueue[0].type);
10
+    unsigned nIter = 0;
11
+    while (nIter++ < limit && makeProposal(domain));
14 12
 }
15 13
 
16 14
 void ProposalQueue::setNumBins(unsigned nBins)
... ...
@@ -33,14 +31,12 @@ void ProposalQueue::setAlpha(float alpha)
33 31
     mAlpha = alpha;
34 32
 }
35 33
 
36
-//void ProposalQueue::clear(unsigned n)
37
-void ProposalQueue::clear()
34
+void ProposalQueue::clear(unsigned n)
38 35
 {
39
-    //mQueue.erase(mQueue.end() - n, mQueue.end());
40
-    //assert(mMaxAtoms - mMinAtoms <= mQueue.size());
41
-    mQueue.clear();
42
-    mUsedIndices.clear();
43
-    mUsedPositions.clear();
36
+    mQueue.erase(mQueue.end() - n, mQueue.end());
37
+    mUsedIndices.erase(mUsedIndices.end() - n, mUsedIndices.end());
38
+    mUsedPositions.erase(mUsedPositions.end() - n, mUsedPositions.end());
39
+    GAPS_ASSERT(mMaxAtoms - mMinAtoms <= mQueue.size());
44 40
 }
45 41
 
46 42
 unsigned ProposalQueue::size() const
... ...
@@ -63,7 +63,7 @@ public:
63 63
 
64 64
     // modify/access queue
65 65
     void populate(const AtomicDomain &domain, unsigned limit);
66
-    void clear();
66
+    void clear(unsigned n);
67 67
     unsigned size() const;
68 68
     const AtomicProposal& operator[](int n) const;
69 69
 
... ...
@@ -22,6 +22,7 @@ typedef boost::random::mt19937 RNGType;
22 22
 //typedef boost::random::mt11213b RNGType; // should be faster
23 23
 
24 24
 static RNGType rng;
25
+static boost::random::uniform_01<RNGType&> u01_dist(rng);
25 26
 
26 27
 void gaps::random::save(Archive &ar)
27 28
 {
... ...
@@ -56,13 +57,12 @@ float gaps::random::exponential(float lambda)
56 57
     return dist(rng);
57 58
 }
58 59
 
60
+// open interval
59 61
 float gaps::random::uniform()
60 62
 {
61
-    boost::random::uniform_01<RNGType&> dist(rng); // could be moved out
62
-    return dist();
63
+    return u01_dist();
63 64
 }
64 65
 
65
-// open interval
66 66
 float gaps::random::uniform(float a, float b)
67 67
 {
68 68
     if (a == b)
... ...
@@ -72,12 +72,7 @@ float gaps::random::uniform(float a, float b)
72 72
     else
73 73
     {
74 74
         boost::random::uniform_real_distribution<> dist(a,b);
75
-        float result = dist(rng);
76
-        while (result == b)
77
-        {
78
-            result = dist(rng);
79
-        }
80
-        return result;
75
+        return dist(rng);
81 76
     }
82 77
 }
83 78
 
... ...
@@ -143,3 +138,23 @@ float gaps::random::p_norm(float p, float mean, float sd)
143 138
     boost::math::normal_distribution<> norm(mean, sd);
144 139
     return cdf(norm, p);
145 140
 }
141
+
142
+float gaps::random::inverseNormSample(float a, float b, float mean, float sd)
143
+{
144
+    float u = gaps::random::uniform(a, b);
145
+    while (u == 0.f || u == 1.f)
146
+    {
147
+        u = gaps::random::uniform(a, b);
148
+    }
149
+    return gaps::random::q_norm(u, mean, sd);
150
+}
151
+
152
+float gaps::random::inverseGammaSample(float a, float b, float mean, float sd)
153
+{
154
+    float u = gaps::random::uniform(a, b);
155
+    while (u == 0.f || u == 1.f)
156
+    {
157
+        u = gaps::random::uniform(a, b);
158
+    }
159
+    return gaps::random::q_gamma(u, mean, sd);
160
+}
146 161
\ No newline at end of file
... ...
@@ -10,7 +10,6 @@
10 10
 namespace gaps
11 11
 {
12 12
 
13
-// rename to math for dist functions
14 13
 namespace random
15 14
 {
16 15
     void setSeed(uint32_t seed);
... ...
@@ -31,10 +30,13 @@ namespace random
31 30
     float q_norm(float q, float mean, float sd);
32 31
     float p_norm(float p, float mean, float sd);
33 32
 
33
+    float inverseNormSample(float a, float b, float mean, float sd);
34
+    float inverseGammaSample(float a, float b, float mean, float sd);
35
+
34 36
     void save(Archive &ar);
35 37
     void load(Archive &ar);
36 38
 }
37 39
 
38 40
 }
39 41
 
40
-#endif
42
+#endif
41 43
\ No newline at end of file
... ...
@@ -37,16 +37,19 @@ BEGIN_RCPP
37 37
 END_RCPP
38 38
 }
39 39
 // cogapsFromCheckpoint_cpp
40
-Rcpp::List cogapsFromCheckpoint_cpp(const Rcpp::NumericMatrix& D, const Rcpp::NumericMatrix& S, const std::string& fileName, const std::string& cptFile);
41
-RcppExport SEXP _CoGAPS_cogapsFromCheckpoint_cpp(SEXP DSEXP, SEXP SSEXP, SEXP fileNameSEXP, SEXP cptFileSEXP) {
40
+Rcpp::List cogapsFromCheckpoint_cpp(const Rcpp::NumericMatrix& D, const Rcpp::NumericMatrix& S, unsigned nFactor, unsigned nEquil, unsigned nSample, const std::string& fileName, const std::string& cptFile);
41
+RcppExport SEXP _CoGAPS_cogapsFromCheckpoint_cpp(SEXP DSEXP, SEXP SSEXP, SEXP nFactorSEXP, SEXP nEquilSEXP, SEXP nSampleSEXP, SEXP fileNameSEXP, SEXP cptFileSEXP) {
42 42
 BEGIN_RCPP
43 43
     Rcpp::RObject rcpp_result_gen;
44 44
     Rcpp::RNGScope rcpp_rngScope_gen;
45 45
     Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type D(DSEXP);
46 46
     Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type S(SSEXP);
47
+    Rcpp::traits::input_parameter< unsigned >::type nFactor(nFactorSEXP);
48
+    Rcpp::traits::input_parameter< unsigned >::type nEquil(nEquilSEXP);
49
+    Rcpp::traits::input_parameter< unsigned >::type nSample(nSampleSEXP);
47 50
     Rcpp::traits::input_parameter< const std::string& >::type fileName(fileNameSEXP);
48 51
     Rcpp::traits::input_parameter< const std::string& >::type cptFile(cptFileSEXP);
49
-    rcpp_result_gen = Rcpp::wrap(cogapsFromCheckpoint_cpp(D, S, fileName, cptFile));
52
+    rcpp_result_gen = Rcpp::wrap(cogapsFromCheckpoint_cpp(D, S, nFactor, nEquil, nSample, fileName, cptFile));
50 53
     return rcpp_result_gen;
51 54
 END_RCPP
52 55
 }
... ...
@@ -72,7 +75,7 @@ END_RCPP
72 75
 
73 76
 static const R_CallMethodDef CallEntries[] = {
74 77
     {"_CoGAPS_cogaps_cpp", (DL_FUNC) &_CoGAPS_cogaps_cpp, 21},
75
-    {"_CoGAPS_cogapsFromCheckpoint_cpp", (DL_FUNC) &_CoGAPS_cogapsFromCheckpoint_cpp, 4},
78
+    {"_CoGAPS_cogapsFromCheckpoint_cpp", (DL_FUNC) &_CoGAPS_cogapsFromCheckpoint_cpp, 7},
76 79
     {"_CoGAPS_displayBuildReport_cpp", (DL_FUNC) &_CoGAPS_displayBuildReport_cpp, 0},
77 80
     {"_CoGAPS_run_catch_unit_tests", (DL_FUNC) &_CoGAPS_run_catch_unit_tests, 0},
78 81
     {NULL, NULL, 0}