Browse code

use free functions instead of class for gaps runner

Tom Sherman authored on 23/10/2018 21:37:09
Showing 49 changed files

... ...
@@ -8,9 +8,7 @@
8 8
 // this file contains the blueprint for creating a wrapper around the C++
9 9
 // interface used for running CoGAPS. It exposes some functions to R, has a
10 10
 // method for converting the R parameters to the standard GapsParameters
11
-// struct, and creates a GapsRunner object. The GapsRunner class manages
12
-// all information from a CoGAPS run and is used to set off the run
13
-// and get return data.
11
+// struct, and calls gaps::run
14 12
 
15 13
 ////////////////// functions for converting matrix types ///////////////////////
16 14
 
... ...
@@ -53,26 +51,20 @@ const Rcpp::Nullable<Rcpp::IntegerVector> &indices)
53 51
 {
54 52
     // check if subsetting data
55 53
     const Rcpp::S4 &gapsParams(allParams["gaps"]);
56
-    bool subsetData = false;
57
-    bool printThreadUsage = true;
58 54
     bool subsetGenes = false;
59
-    char whichFixedMatrix = 'N';
60 55
     std::vector<unsigned> subset;
61 56
     if (indices.isNotNull())
62 57
     {
63
-        subsetData = true;
64
-        printThreadUsage = false;
65 58
         std::string d(Rcpp::as<std::string>(gapsParams.slot("distributed")));
66 59
         subsetGenes = (d == "genome-wide");
67
-        whichFixedMatrix = (d == "genome-wide") ? 'P' : 'A';
68 60
         subset = Rcpp::as< std::vector<unsigned> >(Rcpp::IntegerVector(indices));
69 61
     }
70 62
 
71 63
     // create standard CoGAPS parameters struct
72
-    GapsParameters params(data, allParams["transposeData"], subsetData,
64
+    GapsParameters params(data, allParams["transposeData"], indices.isNotNull(),
73 65
         subsetGenes, subset);
74
-    params.printThreadUsage = printThreadUsage;
75
-    params.whichFixedMatrix = whichFixedMatrix;
66
+    params.printThreadUsage = !indices.isNotNull();
67
+    params.whichFixedMatrix = indices.isNotNull() ? (subsetGenes ? 'P' : 'A') : 'N';
76 68
 
77 69
     // get configuration parameters
78 70
     params.maxThreads = allParams["nThreads"];
... ...
@@ -104,13 +96,12 @@ const Rcpp::Nullable<Rcpp::IntegerVector> &indices)
104 96
     {
105 97
         params.checkpointFile = Rcpp::as<std::string>(allParams["checkpointInFile"]);
106 98
         params.useCheckPoint = true;
107
-        params.peekCheckpoint(params.checkpointFile);
108 99
     }
109 100
 
110 101
     return params;
111 102
 }
112 103
 
113
-////////// main function that creates a GapsRunner and runs CoGAPS /////////////
104
+////////////////////// main function that runs CoGAPS //////////////////////////
114 105
 
115 106
 // note uncertainty matrix gets special treatment since it's the same size as
116 107
 // the data (potentially large), so we want to avoid copying it into the 
... ...
@@ -122,23 +113,12 @@ const DataType &uncertainty, const Rcpp::Nullable<Rcpp::IntegerVector> &indices,
122 113
 const Rcpp::Nullable<Rcpp::NumericMatrix> &fixedMatrix, bool isMaster)
123 114
 {
124 115
     // convert R parameters to GapsParameters struct
125
-    GapsParameters gapsParams(getGapsParameters(data, allParams, isMaster,
116
+    GapsParameters params(getGapsParameters(data, allParams, isMaster,
126 117
         fixedMatrix, indices));
127 118
 
128 119
     // initialize the random state from the seed, then run CoGAPS
129
-    GapsRng::setSeed(gapsParams.seed);
130
-    gaps_printf("Loading Data...");
131
-    GapsRunner runner(data, gapsParams);
132
-
133
-    // set uncertainty
134
-    if (!uncertainty.empty())
135
-    {
136
-        runner.setUncertainty(uncertainty, gapsParams);
137
-    }
138
-    gaps_printf("Done!\n");
139
-    
140
-    // run cogaps
141
-    GapsResult result(runner.run());
120
+    GapsRandomState randState(params.seed);
121
+    GapsResult result(gaps::run(data, params, uncertainty, &randState));
142 122
 
143 123
     // write result to file if requested
144 124
     if (allParams["outputToFile"] != R_NilValue)
... ...
@@ -152,7 +132,7 @@ const Rcpp::Nullable<Rcpp::NumericMatrix> &fixedMatrix, bool isMaster)
152 132
         Rcpp::Named("Pmean") = createRMatrix(result.Pmean),
153 133
         Rcpp::Named("Asd") = createRMatrix(result.Asd),
154 134
         Rcpp::Named("Psd") = createRMatrix(result.Psd),
155
-        Rcpp::Named("seed") = gapsParams.seed,
135
+        Rcpp::Named("seed") = params.seed,
156 136
         Rcpp::Named("meanChiSq") = result.meanChiSq,
157 137
         Rcpp::Named("geneNames") = allParams["geneNames"],
158 138
         Rcpp::Named("sampleNames") = allParams["sampleNames"],
... ...
@@ -1,9 +1,19 @@
1 1
 #include "GapsParameters.h"
2 2
 
3
-void GapsParameters::peekCheckpoint(const std::string &file)
3
+Archive& operator<<(Archive &ar, const GapsParameters &p)
4 4
 {
5
-    Archive ar(file, ARCHIVE_READ);
6
-    ar >> nPatterns >> seed >> nIterations >> whichFixedMatrix;
5
+    ar << p.seed << p.nGenes << p.nSamples << p.nPatterns << p.nIterations
6
+        << p.alphaA << p.alphaP << p.maxGibbsMassA << p.maxGibbsMassP
7
+        << p.singleCell << p.useSparseOptimization;
8
+    return ar;
9
+}
10
+
11
+Archive& operator>>(Archive &ar, GapsParameters &p)
12
+{
13
+    ar >> p.seed >> p.nGenes >> p.nSamples >> p.nPatterns >> p.nIterations
14
+        >> p.alphaA >> p.alphaP >> p.maxGibbsMassA >> p.maxGibbsMassP
15
+        >> p.singleCell >> p.useSparseOptimization;
16
+    return ar;
7 17
 }
8 18
     
9 19
 void GapsParameters::calculateDataDimensions(const std::string &file)
... ...
@@ -16,8 +16,6 @@ public:
16 16
         bool t_subsetData=false, bool t_subsetGenes=false,
17 17
         const std::vector<unsigned> t_dataIndicesSubset=std::vector<unsigned>());
18 18
 
19
-    void peekCheckpoint(const std::string &file);
20
-
21 19
     Matrix fixedMatrix;
22 20
 
23 21
     std::vector<unsigned> dataIndicesSubset;
... ...
@@ -58,6 +56,9 @@ private:
58 56
     void calculateDataDimensions(const Matrix &mat);
59 57
 };
60 58
 
59
+Archive& operator<<(Archive &ar, const GapsParameters &p);
60
+Archive& operator>>(Archive &ar, GapsParameters &p);
61
+
61 62
 template <class DataType>
62 63
 GapsParameters::GapsParameters(const DataType &data, bool t_transposeData,
63 64
 bool t_subsetData, bool t_subsetGenes,
... ...
@@ -1,5 +1,7 @@
1 1
 #include "GapsRunner.h"
2 2
 
3
+#include "utils/Archive.h"
4
+
3 5
 #ifdef __GAPS_R_BUILD__
4 6
 #include <Rcpp.h>
5 7
 #endif
... ...
@@ -8,117 +10,61 @@
8 10
 #include <omp.h>
9 11
 #endif
10 12
 
11
-///////////////////////////// RAII wrapper /////////////////////////////////////
13
+// boost time helpers
14
+#include <boost/date_time/posix_time/posix_time.hpp>
15
+namespace bpt = boost::posix_time;
16
+#define bpt_now() bpt::microsec_clock::local_time()
12 17
 
13
-GapsRunner::~GapsRunner()
14
-{
15
-    delete mRunner;
16
-}
18
+// forward declaration
19
+template <class Sampler, class DataType>
20
+static GapsResult runCoGAPSAlgorithm(const DataType &data, GapsParameters &params,
21
+    const DataType &uncertainty, GapsRandomState *randState);
17 22
 
18
-GapsResult GapsRunner::run()
19
-{
20
-    return mRunner->run();
21
-}
23
+////////////////////////////////////////////////////////////////////////////////
22 24
 
23
-///////////////////////// Abstract Interface ///////////////////////////////////
24
-
25
-AbstractGapsRunner::AbstractGapsRunner(const GapsParameters &params)
26
-    :
27
-mStatistics(params.nGenes, params.nSamples, params.nPatterns),
28
-mCheckpointOutFile(params.checkpointOutFile),
29
-mCurrentIteration(0),
30
-mMaxIterations(params.nIterations),
31
-mMaxThreads(params.maxThreads),
32
-mOutputFrequency(params.outputFrequency),
33
-mCheckpointInterval(params.checkpointInterval),
34
-mNumPatterns(params.nPatterns),
35
-mNumUpdatesA(0),
36
-mNumUpdatesP(0),
37
-mSeed(params.seed),
38
-mPrintMessages(params.printMessages),
39
-mPrintThreadUsage(params.printThreadUsage),
40
-mPhase('C'),
41
-mFixedMatrix(params.whichFixedMatrix)
42
-{}
43
-
44
-GapsResult AbstractGapsRunner::run()
25
+// helper function, this dispatches the correct run function depending
26
+// on the type of GibbsSampler needed for the given parameters
27
+template <class DataType>
28
+static GapsResult run_helper(const DataType &data, GapsParameters &params,
29
+const DataType &uncertainty, GapsRandomState *randState)
45 30
 {
46
-    GAPS_ASSERT(mPhase == 'C' || mPhase == 'S');
47
-
48
-    mStartTime = bpt_now();
49
-
50
-    // check if running in debug mode
51
-    #ifdef GAPS_DEBUG
52
-    gaps_printf("Running in debug mode\n");
53
-    #endif
54
-
55
-    // calculate appropiate number of threads if compiled with openmp
56
-    #ifdef __GAPS_OPENMP__
57
-    if (mPrintMessages && mPrintThreadUsage)
31
+    // fetch parameters from checkpoint - some are used in initialization
32
+    if (params.useCheckPoint)
58 33
     {
59
-        unsigned availableThreads = omp_get_max_threads();
60
-        mMaxThreads = gaps::min(availableThreads, mMaxThreads);
61
-        gaps_printf("Running on %d out of %d available threads\n",
62
-            mMaxThreads, availableThreads);
34
+        Archive ar(params.checkpointFile, ARCHIVE_READ);
35
+        ar >> params;
36
+        ar >> *randState;
63 37
     }
64
-    #endif
65 38
 
66
-    // cascade through phases, allows algorithm to be resumed in either phase
67
-    switch (mPhase)
39
+    if (params.useSparseOptimization)
68 40
     {
69
-        case 'C':
70
-            if (mPrintMessages)
71
-            {
72
-                gaps_printf("-- Calibration Phase --\n");
73
-            }
74
-            runOnePhase();
75
-            mPhase = 'S';
76
-            mCurrentIteration = 0;
77
-
78
-        case 'S':
79
-            if (mPrintMessages)
80
-            {
81
-                gaps_printf("-- Sampling Phase --\n");
82
-            }
83
-            runOnePhase();
84
-            break;
41
+        return runCoGAPSAlgorithm<SparseGibbsSampler>(data, params,
42
+            uncertainty, randState);
43
+    }
44
+    else
45
+    {
46
+        return runCoGAPSAlgorithm<DenseGibbsSampler>(data, params,
47
+            uncertainty, randState);
85 48
     }
86
-    GapsResult result(mStatistics);
87
-    result.meanChiSq = meanChiSq();
88
-    return result;    
89 49
 }
90 50
 
91
-void AbstractGapsRunner::runOnePhase()
92
-{
93
-    for (; mCurrentIteration < mMaxIterations; ++mCurrentIteration)
94
-    {
95
-        createCheckpoint();
51
+// these two overloads are the top-level entry points of the C++ code,
52
+// they are what any language wrapper calls to run CoGAPS
96 53
 
97
-        #ifdef __GAPS_R_BUILD__
98
-        Rcpp::checkUserInterrupt();
99
-        #endif
100
-        
101
-        // set annealing temperature in calibration phase
102
-        if (mPhase == 'C')
103
-        {        
104
-            float temp = static_cast<float>(2 * mCurrentIteration)
105
-                / static_cast<float>(mMaxIterations);
106
-            setAnnealingTemp(gaps::min(1.f, temp));
107
-        }
108
-    
109
-        // number of updates per iteration is poisson 
110
-        unsigned nA = mRng.poisson(gaps::max(nAtoms('A'), 10));
111
-        unsigned nP = mRng.poisson(gaps::max(nAtoms('P'), 10));
112
-        updateSampler(nA, nP);
54
+GapsResult gaps::run(const Matrix &data, GapsParameters &params,
55
+const Matrix &uncertainty, GapsRandomState *randState)
56
+{
57
+    return run_helper(data, params, uncertainty, randState);
58
+}
113 59
 
114
-        if (mPhase == 'S')
115
-        {
116
-            updateStatistics();
117
-        }
118
-        displayStatus();
119
-    }
60
+GapsResult gaps::run(const std::string &data, GapsParameters &params,
61
+const std::string &uncertainty, GapsRandomState *randState)
62
+{
63
+    return run_helper(data, params, uncertainty, randState);
120 64
 }
121 65
 
66
+////////////////////////////////////////////////////////////////////////////////
67
+
122 68
 // sum coef * log(i) for i = 1 to total, fit coef from number of atoms
123 69
 // approximates sum of number of atoms (stirling approx to factorial)
124 70
 // this should be proportional to total number of updates
... ...
@@ -129,19 +75,21 @@ static double estimatedNumUpdates(double current, double total, float nAtoms)
129 75
         total * coef * std::log(total) - total * coef;
130 76
 }
131 77
 
132
-
133
-double AbstractGapsRunner::estimatedPercentComplete() const
78
+template <class Sampler>
79
+static double estimatedPercentComplete(const GapsParameters &params,
80
+const Sampler &ASampler, const Sampler &PSampler, bpt::ptime startTime,
81
+char phase, unsigned iter)
134 82
 {
135
-    double nIter = static_cast<double>(mCurrentIteration);
136
-    double nAtomsA = static_cast<double>(nAtoms('A'));
137
-    double nAtomsP = static_cast<double>(nAtoms('P'));
83
+    double nIter = static_cast<double>(iter);
84
+    double nAtomsA = static_cast<double>(ASampler.nAtoms());
85
+    double nAtomsP = static_cast<double>(PSampler.nAtoms());
138 86
     
139
-    if (mPhase == 'S')
87
+    if (phase == 'S')
140 88
     {
141
-        nIter += mMaxIterations;
89
+        nIter += params.nIterations;
142 90
     }
143 91
 
144
-    double totalIter = 2.0 * static_cast<double>(mMaxIterations);
92
+    double totalIter = 2.0 * static_cast<double>(params.nIterations);
145 93
 
146 94
     double estimatedCompleted = estimatedNumUpdates(nIter, nIter, nAtomsA) + 
147 95
         estimatedNumUpdates(nIter, nIter, nAtomsP);
... ...
@@ -152,13 +100,19 @@ double AbstractGapsRunner::estimatedPercentComplete() const
152 100
     return estimatedCompleted / estimatedTotal;
153 101
 }
154 102
 
155
-void AbstractGapsRunner::displayStatus()
103
+template <class Sampler>
104
+static void displayStatus(const GapsParameters &params,
105
+const Sampler &ASampler, const Sampler &PSampler, bpt::ptime startTime,
106
+char phase, unsigned iter)
156 107
 {
157
-    if (mPrintMessages && mOutputFrequency > 0 && ((mCurrentIteration + 1) % mOutputFrequency) == 0)
108
+    if (params.printMessages && params.outputFrequency > 0
109
+    && ((iter + 1) % params.outputFrequency) == 0)
158 110
     {
159
-        bpt::time_duration diff = bpt_now() - mStartTime;
111
+        bpt::time_duration diff = bpt_now() - startTime;
112
+        double perComplete = estimatedPercentComplete(params, ASampler,
113
+            PSampler, startTime, phase, iter);
160 114
         double nSecondsCurrent = diff.total_seconds();
161
-        double nSecondsTotal = nSecondsCurrent / estimatedPercentComplete();
115
+        double nSecondsTotal = nSecondsCurrent / perComplete;
162 116
 
163 117
         unsigned elapsedSeconds = static_cast<unsigned>(nSecondsCurrent);
164 118
         unsigned totalSeconds = static_cast<unsigned>(nSecondsTotal);
... ...
@@ -174,180 +128,201 @@ void AbstractGapsRunner::displayStatus()
174 128
         totalSeconds -= totalMinutes * 60;
175 129
 
176 130
         gaps_printf("%d of %d, Atoms: %lu(%lu), ChiSq: %.0f, Time: %02d:%02d:%02d / %02d:%02d:%02d\n",
177
-            mCurrentIteration + 1, mMaxIterations, nAtoms('A'),
178
-            nAtoms('P'), chiSq(), elapsedHours, elapsedMinutes,
131
+            iter + 1, params.nIterations, ASampler.nAtoms(),
132
+            PSampler.nAtoms(), PSampler.chiSq(), elapsedHours, elapsedMinutes,
179 133
             elapsedSeconds, totalHours, totalMinutes, totalSeconds);
180 134
         gaps_flush();
181 135
     }
182 136
 }
183 137
 
184
-void AbstractGapsRunner::createCheckpoint()
138
+template <class Sampler>
139
+static void updateSampler(const GapsParameters &params, Sampler &ASampler,
140
+Sampler &PSampler, unsigned nA, unsigned nP)
185 141
 {
186
-    if (mCheckpointInterval > 0 && ((mCurrentIteration + 1) % mCheckpointInterval) == 0)
142
+    if (params.whichFixedMatrix != 'A')
187 143
     {
188
-        // create backup file
189
-        std::rename(mCheckpointOutFile.c_str(), (mCheckpointOutFile + ".backup").c_str());
190
-    
191
-        // create checkpoint file
192
-        Archive ar(mCheckpointOutFile, ARCHIVE_WRITE);
193
-        ar << mNumPatterns << mSeed << mMaxIterations << mFixedMatrix << mPhase
194
-            << mCurrentIteration << mNumUpdatesA << mNumUpdatesP << mRng;
195
-        writeSamplers(ar);
196
-        GapsRng::save(ar);
197
-
198
-        // delete backup file
199
-        std::remove((mCheckpointOutFile + ".backup").c_str());
200
-    }
201
-}
202
-
203
-///////////////////// DenseGapsRunner Implementation ///////////////////////////
204
-
205
-float DenseGapsRunner::chiSq() const
206
-{
207
-    // doesn't matter which sampler is called
208
-    return mPSampler.chiSq();
209
-}
210
-
211
-float DenseGapsRunner::meanChiSq() const
212
-{
213
-    // need to pass P sampler (due to configuration of internal data)
214
-    return mStatistics.meanChiSq(mPSampler);
215
-}
216
-
217
-unsigned DenseGapsRunner::nAtoms(char which) const
218
-{
219
-    return which == 'A' ? mASampler.nAtoms() : mPSampler.nAtoms();
220
-}
221
-
222
-void DenseGapsRunner::setAnnealingTemp(float temp)
223
-{
224
-    mASampler.setAnnealingTemp(temp);
225
-    mPSampler.setAnnealingTemp(temp);
226
-}
227
-
228
-void DenseGapsRunner::updateStatistics()
229
-{
230
-    mStatistics.update(mASampler, mPSampler);
231
-}
232
-
233
-Archive& DenseGapsRunner::readSamplers(Archive &ar)
234
-{
235
-    ar >> mASampler >> mPSampler;
236
-    return ar;
237
-}
238
-
239
-Archive& DenseGapsRunner::writeSamplers(Archive &ar)
240
-{
241
-    ar << mASampler << mPSampler;
242
-    return ar;
243
-}
244
-
245
-void DenseGapsRunner::updateSampler(unsigned nA, unsigned nP)
246
-{
247
-    if (mFixedMatrix != 'A')
248
-    {
249
-        mNumUpdatesA += nA;
250
-        mASampler.update(nA, mMaxThreads);
251
-        if (mFixedMatrix != 'P')
144
+        ASampler.update(nA, params.maxThreads);
145
+        if (params.whichFixedMatrix != 'P')
252 146
         {
253
-            mPSampler.sync(mASampler, mMaxThreads);
147
+            PSampler.sync(ASampler, params.maxThreads);
254 148
         }
255 149
     }
256 150
 
257
-    if (mFixedMatrix != 'P')
151
+    if (params.whichFixedMatrix != 'P')
258 152
     {
259
-        mNumUpdatesP += nP;
260
-        mPSampler.update(nP, mMaxThreads);
261
-        if (mFixedMatrix != 'A')
153
+        PSampler.update(nP, params.maxThreads);
154
+        if (params.whichFixedMatrix != 'A')
262 155
         {
263
-            mASampler.sync(mPSampler, mMaxThreads);
156
+            ASampler.sync(PSampler, params.maxThreads);
264 157
         }
265 158
     }
266 159
 }
267 160
 
268
-void DenseGapsRunner::setUncertainty(const Matrix &unc, const GapsParameters &params)
161
+template <class Sampler>
162
+static void createCheckpoint(const GapsParameters &params,
163
+const Sampler &ASampler, const Sampler &PSampler, const GapsRandomState *randState,
164
+const GapsRng &rng, char phase, unsigned iter)
269 165
 {
270
-    mASampler.setUncertainty(unc, !params.transposeData, !params.subsetGenes, params);
271
-    mPSampler.setUncertainty(unc, params.transposeData, params.subsetGenes, params);
166
+    if (params.checkpointInterval > 0
167
+    && ((iter + 1) % params.checkpointInterval) == 0)
168
+    {
169
+        // create backup file
170
+        std::rename(params.checkpointOutFile.c_str(),
171
+            (params.checkpointOutFile + ".backup").c_str());
172
+    
173
+        // create checkpoint file
174
+        Archive ar(params.checkpointOutFile, ARCHIVE_WRITE);
175
+        ar << params;
176
+        ar << *randState;
177
+        ar << ASampler << PSampler << phase << iter << rng;
178
+        
179
+        // delete backup file
180
+        std::remove((params.checkpointOutFile + ".backup").c_str());
181
+    }
272 182
 }
273 183
 
274
-void DenseGapsRunner::setUncertainty(const std::string &unc, const GapsParameters &params)
184
+template <class Sampler>
185
+static void runOnePhase(const GapsParameters &params, Sampler &ASampler,
186
+Sampler &PSampler, GapsStatistics &stats, const GapsRandomState *randState,
187
+GapsRng &rng, bpt::ptime startTime, char phase, unsigned &currentIter)
275 188
 {
276
-    mASampler.setUncertainty(unc, !params.transposeData, !params.subsetGenes, params);
277
-    mPSampler.setUncertainty(unc, params.transposeData, params.subsetGenes, params);
278
-}
189
+    for (; currentIter < params.nIterations; ++currentIter)
190
+    {
191
+        #ifdef __GAPS_R_BUILD__
192
+        Rcpp::checkUserInterrupt();
193
+        #endif
279 194
 
280
-///////////////////// SparseGapsRunner Implementation //////////////////////////
195
+        createCheckpoint(params, ASampler, PSampler, randState, rng, phase,
196
+            currentIter);
197
+        
198
+        // set annealing temperature in calibration phase
199
+        if (phase == 'C')
200
+        {        
201
+            float temp = static_cast<float>(2 * currentIter)
202
+                / static_cast<float>(params.nIterations);
203
+            ASampler.setAnnealingTemp(gaps::min(1.f, temp));
204
+            PSampler.setAnnealingTemp(gaps::min(1.f, temp));
205
+        }
206
+    
207
+        // number of updates per iteration is poisson 
208
+        unsigned nA = rng.poisson(gaps::max(ASampler.nAtoms(), 10));
209
+        unsigned nP = rng.poisson(gaps::max(PSampler.nAtoms(), 10));
210
+        updateSampler(params, ASampler, PSampler, nA, nP);
281 211
 
282
-float SparseGapsRunner::chiSq() const
283
-{
284
-    // doesn't matter which sampler is called
285
-    return mPSampler.chiSq();
212
+        if (phase == 'S')
213
+        {
214
+            stats.update(ASampler, PSampler);
215
+        }
216
+        displayStatus(params, ASampler, PSampler, startTime, phase, currentIter);
217
+    }
286 218
 }
287 219
 
288
-float SparseGapsRunner::meanChiSq() const
220
+// here is the CoGAPS algorithm
221
+template <class Sampler, class DataType>
222
+static GapsResult runCoGAPSAlgorithm(const DataType &data, GapsParameters &params,
223
+const DataType &uncertainty, GapsRandomState *randState)
289 224
 {
290
-    // need to pass P sampler (due to configuration of internal data)
291
-    return mStatistics.meanChiSq(mPSampler);
292
-}
225
+    // check if running in debug mode
226
+    #ifdef GAPS_DEBUG
227
+    gaps_printf("Running in debug mode\n");
228
+    #endif
293 229
 
294
-unsigned SparseGapsRunner::nAtoms(char which) const
295
-{
296
-    return which == 'A' ? mASampler.nAtoms() : mPSampler.nAtoms();
297
-}
230
+    // load data into gibbs samplers
231
+    // we transpose the data in the A sampler so that the update step
232
+    // is symmetrical for each sampler, this simplifies the code 
233
+    // within the sampler, note the subsetting genes/samples flag must be
234
+    // flipped if we are flipping the transpose flag
235
+    gaps_printf("Loading Data...");
236
+    Sampler ASampler(data, !params.transposeData, !params.subsetGenes,
237
+        params.alphaA, params.maxGibbsMassA, params, randState);
238
+    Sampler PSampler(data, params.transposeData, params.subsetGenes,
239
+        params.alphaP, params.maxGibbsMassP, params, randState); // P sampler uses the P-specific alpha and max gibbs mass, not A's
240
+
241
+    // read in the uncertainty matrix if one is provided
242
+    if (!uncertainty.empty())
243
+    {
244
+        ASampler.setUncertainty(uncertainty, !params.transposeData,
245
+            !params.subsetGenes, params);
246
+        PSampler.setUncertainty(uncertainty, params.transposeData,
247
+            params.subsetGenes, params);
248
+    }
249
+    gaps_printf("Done!\n");
298 250
 
299
-void SparseGapsRunner::setAnnealingTemp(float temp)
300
-{
301
-    mASampler.setAnnealingTemp(temp);
302
-    mPSampler.setAnnealingTemp(temp);
303
-}
251
+    // these variables will get overwritten by checkpoint if provided
252
+    GapsRng rng(randState);
253
+    char phase = 'C';
254
+    unsigned currentIter = 0;
304 255
 
305
-void SparseGapsRunner::updateStatistics()
306
-{
307
-    mStatistics.update(mASampler, mPSampler);
308
-}
256
+    // check if we're fixing a matrix
257
+    switch (params.whichFixedMatrix)
258
+    {
259
+        case 'A' : ASampler.setMatrix(params.fixedMatrix); break;
260
+        case 'P' : PSampler.setMatrix(params.fixedMatrix); break;
261
+        default: break; // 'N' for none
262
+    }
263
+     
264
+    // check if running from checkpoint, get all saved data
265
+    if (params.useCheckPoint)
266
+    {
267
+        Archive ar(params.checkpointFile, ARCHIVE_READ);
268
+        ar >> params;
269
+        ar >> *randState;
270
+        ar >> ASampler >> PSampler >> phase >> currentIter >> rng;
271
+    }
309 272
 
310
-Archive& SparseGapsRunner::readSamplers(Archive &ar)
311
-{
312
-    ar >> mASampler >> mPSampler;
313
-    return ar;
314
-}
273
+    // sync samplers; the second argument is omitted, presumably meaning a full sync
274
+    ASampler.sync(PSampler);
275
+    PSampler.sync(ASampler);
315 276
 
316
-Archive& SparseGapsRunner::writeSamplers(Archive &ar)
317
-{
318
-    ar << mASampler << mPSampler;
319
-    return ar;
320
-}
277
+    // sampler may need to run additional initialization after parameters set
278
+    ASampler.extraInitialization();
279
+    PSampler.extraInitialization();
321 280
 
322
-void SparseGapsRunner::updateSampler(unsigned nA, unsigned nP)
323
-{
324
-    if (mFixedMatrix != 'A')
281
+    // calculate appropriate number of threads if compiled with openmp
282
+    #ifdef __GAPS_OPENMP__
283
+    if (params.printMessages && params.printThreadUsage)
325 284
     {
326
-        mNumUpdatesA += nA;
327
-        mASampler.update(nA, mMaxThreads);
328
-        if (mFixedMatrix != 'P')
329
-        {
330
-            mPSampler.sync(mASampler, mMaxThreads);
331
-        }
285
+        unsigned availableThreads = omp_get_max_threads();
286
+        params.maxThreads = gaps::min(availableThreads, params.maxThreads);
287
+        gaps_printf("Running on %d out of %d available threads\n",
288
+            params.maxThreads, availableThreads);
332 289
     }
290
+    #endif
291
+
292
+    // record start time
293
+    bpt::ptime startTime = bpt_now();
333 294
 
334
-    if (mFixedMatrix != 'P')
295
+    // cascade through phases, allows algorithm to be resumed in either phase
296
+    GapsStatistics stats(params.nGenes, params.nSamples, params.nPatterns);
297
+    GAPS_ASSERT(phase == 'C' || phase == 'S');
298
+    switch (phase)
335 299
     {
336
-        mNumUpdatesP += nP;
337
-        mPSampler.update(nP, mMaxThreads);
338
-        if (mFixedMatrix != 'A')
339
-        {
340
-            mASampler.sync(mPSampler, mMaxThreads);
341
-        }
342
-    }
343
-}
300
+        case 'C':
301
+            if (params.printMessages)
302
+            {
303
+                gaps_printf("-- Calibration Phase --\n");
304
+            }
305
+            runOnePhase(params, ASampler, PSampler, stats, randState, rng,
306
+                startTime, phase, currentIter);
307
+            phase = 'S';
308
+            currentIter = 0;
344 309
 
345
-void SparseGapsRunner::setUncertainty(const Matrix &unc, const GapsParameters &params)
346
-{
347
-    // nothing happens - SparseGibbsSampler assumes default uncertainty always
348
-}
349 310
 
350
-void SparseGapsRunner::setUncertainty(const std::string &unc, const GapsParameters &params)
351
-{
352
-    // nothing happens - SparseGibbsSampler assumes default uncertainty always
311
+
312
+        case 'S':
313
+            if (params.printMessages)
314
+            {
315
+                gaps_printf("-- Sampling Phase --\n");
316
+            }
317
+            runOnePhase(params, ASampler, PSampler, stats, randState, rng,
318
+                startTime, phase, currentIter);
319
+
320
+        default: break;
321
+    }
322
+    
323
+    // get result
324
+    GapsResult result(stats);
325
+    result.meanChiSq = stats.meanChiSq(PSampler);
326
+    return result;
353 327
 }
328
+
... ...
@@ -1,265 +1,23 @@
1 1
 #ifndef __COGAPS_GAPS_RUNNER_H__
2 2
 #define __COGAPS_GAPS_RUNNER_H__
3 3
 
4
-#include "GapsParameters.h"
5 4
 #include "GapsResult.h"
6
-#include "GapsStatistics.h"
7
-#include "gibbs_sampler/GibbsSampler.h"
8
-#include "gibbs_sampler/DenseGibbsSampler.h"
9
-#include "gibbs_sampler/SparseGibbsSampler.h"
10
-
11
-#include <string>
12
-
13
-// boost time helpers
14
-#include <boost/date_time/posix_time/posix_time.hpp>
15
-namespace bpt = boost::posix_time;
16
-#define bpt_now() bpt::microsec_clock::local_time()
17
-
18
-// forward declarations
19
-class AbstractGapsRunner;
20
-
21
-///////////////////////////// RAII wrapper /////////////////////////////////////
22
-
23
-// This is the class that is exposed to the top-level CoGAPS routine - all 
24
-// aspects of CoGAPS can be managed through this class. The class itself is
25
-// just a lightweight wrapper around an abstract interface, which allows for
26
-// multiple types of GapsRunner to be declared. Which implementation is used
27
-// depends on the parameters passed to the GapsRunner constructor.
28
-class GapsRunner
29
-{
30
-public:
31
-
32
-    template <class DataType>
33
-    GapsRunner(const DataType &data, const GapsParameters &params);
34
-
35
-    ~GapsRunner();
36
-
37
-    template <class DataType>
38
-    void setUncertainty(const DataType &unc, const GapsParameters &params);
39
-
40
-    GapsResult run();
41
-
42
-private:
43
-
44
-    AbstractGapsRunner *mRunner;
45
-
46
-    GapsRunner(const GapsRunner &p); // don't allow copies
47
-    GapsRunner& operator=(const GapsRunner &p); // don't allow copies    
48
-};
49
-
50
-///////////////////////// Abstract Interface ///////////////////////////////////
51
-
52
-// This class is the abstract interface that any implementation of GapsRunner
53
-// must satisfy. It provides a factory method that will create the appropiate
54
-// derived class depending on the parameters passed in.
55
-class AbstractGapsRunner
56
-{
57
-public:
58
-
59
-    AbstractGapsRunner(const GapsParameters &params);
60
-    virtual ~AbstractGapsRunner() {}
61
-
62
-    template <class DataType>
63
-    static AbstractGapsRunner* create(const DataType &data, const GapsParameters &params);
64
-
65
-    // can't use template with virtual function
66
-    virtual void setUncertainty(const Matrix &unc, const GapsParameters &params) = 0;
67
-    virtual void setUncertainty(const std::string &unc, const GapsParameters &params) = 0;
68
-
69
-    GapsResult run();
70
-
71
-protected:
72
-
73
-    GapsStatistics mStatistics;
74
-
75
-    mutable GapsRng mRng;
76
-
77
-    std::string mCheckpointOutFile;
78
-
79
-    bpt::ptime mStartTime;
80
-
81
-    unsigned mCurrentIteration;
82
-    unsigned mMaxIterations;
83
-    unsigned mMaxThreads;
84
-    unsigned mOutputFrequency;
85
-    unsigned mCheckpointInterval;
86
-    unsigned mNumPatterns;
87
-    unsigned mNumUpdatesA;
88
-    unsigned mNumUpdatesP;
89
-    uint32_t mSeed;
90
-
91
-    bool mPrintMessages;
92
-    bool mPrintThreadUsage;
93
-
94
-    char mPhase;
95
-    char mFixedMatrix;
96
-        
97
-    void runOnePhase();
98
-    double estimatedPercentComplete() const;
99
-    void displayStatus();
100
-    void createCheckpoint();
101
-
102
-    virtual float chiSq() const = 0;
103
-    virtual float meanChiSq() const = 0;
104
-    virtual unsigned nAtoms(char which) const = 0;
105
-    virtual void setAnnealingTemp(float temp) = 0;
106
-    virtual void updateStatistics() = 0;
107
-    virtual Archive& readSamplers(Archive &ar) = 0;
108
-    virtual Archive& writeSamplers(Archive &ar) = 0;
109
-    virtual void updateSampler(unsigned nA, unsigned nP) = 0;
110
-};
111
-
112
-///////////////////// GapsRunner Implementations ///////////////////////////////
113
-
114
-// This implementation uses a DenseGibbsSampler internally
115
-class DenseGapsRunner : public AbstractGapsRunner
116
-{
117
-public:
118
-
119
-    ~DenseGapsRunner() {}
120
-
121
-    template <class DataType>
122
-    DenseGapsRunner(const DataType &data, const GapsParameters &params);
123
-
124
-    void setUncertainty(const Matrix &unc, const GapsParameters &params);
125
-    void setUncertainty(const std::string &unc, const GapsParameters &params);
126
-
127
-private:
128
-
129
-    DenseGibbsSampler mASampler;
130
-    DenseGibbsSampler mPSampler;
5
+#include "GapsParameters.h"
6
+#include "data_structures/Matrix.h"
7
+#include "math/Random.h"
131 8
 
132
-    float chiSq() const;
133
-    float meanChiSq() const;
134
-    unsigned nAtoms(char which) const;
135
-    void setAnnealingTemp(float temp);
136
-    void updateStatistics();
137
-    Archive& readSamplers(Archive &ar);
138
-    Archive& writeSamplers(Archive &ar);
139
-    void updateSampler(unsigned nA, unsigned nP);
140
-};
9
+// these two functions are the top-level entry points of the C++ code;
10
+// they are what the wrapper written for any given language calls
141 11
 
142
-// This implementation uses a SparseGibbsSampler internally
143
-class SparseGapsRunner : public AbstractGapsRunner
12
+namespace gaps
144 13
 {
145
-public:
146
-
147
-    ~SparseGapsRunner() {}
148
-
149
-    template <class DataType>
150
-    SparseGapsRunner(const DataType &data, const GapsParameters &params);
151
-
152
-    void setUncertainty(const Matrix &unc, const GapsParameters &params);
153
-    void setUncertainty(const std::string &unc, const GapsParameters &params);
14
+    // data stored in matrix
15
+    GapsResult run(const Matrix &data, GapsParameters &params,
16
+        const Matrix &uncertainty, GapsRandomState *randState);
154 17
 
155
-private:
156
-
157
-    SparseGibbsSampler mASampler;
158
-    SparseGibbsSampler mPSampler;
159
-
160
-    float chiSq() const;
161
-    float meanChiSq() const;
162
-    unsigned nAtoms(char which) const;
163
-    void setAnnealingTemp(float temp);
164
-    void updateStatistics();
165
-    Archive& readSamplers(Archive &ar);
166
-    Archive& writeSamplers(Archive &ar);
167
-    void updateSampler(unsigned nA, unsigned nP);
18
+    // data stored in file
19
+    GapsResult run(const std::string &data, GapsParameters &params,
20
+        const std::string &uncertainty, GapsRandomState *randState);
168 21
 };
169 22
 
170
-/////////////////////// GapsRunner - templated functions ///////////////////////
171
-
172
-template <class DataType>
173
-GapsRunner::GapsRunner(const DataType &data, const GapsParameters &params)
174
-    : mRunner(AbstractGapsRunner::create(data, params))
175
-{}
176
-
177
-template <class DataType>
178
-void GapsRunner::setUncertainty(const DataType &unc, const GapsParameters &params)
179
-{
180
-    mRunner->setUncertainty(unc, params);
181
-}
182
-
183
-/////////////////// AbstractGapsRunner - templated functions ///////////////////
184
-
185
-template <class DataType>
186
-AbstractGapsRunner* AbstractGapsRunner::create(const DataType &data,
187
-const GapsParameters &params)
188
-{
189
-    if (params.useSparseOptimization)
190
-    {
191
-        return new SparseGapsRunner(data, params);
192
-    }
193
-    return new DenseGapsRunner(data, params);
194
-}
195
-
196
-//////////////////// DenseGapsRunner - templated functions /////////////////////
197
-
198
-template <class DataType>
199
-DenseGapsRunner::DenseGapsRunner(const DataType &data,
200
-const GapsParameters &params)
201
-    :
202
-AbstractGapsRunner(params),
203
-mASampler(data, !params.transposeData, !params.subsetGenes, params.alphaA, params.maxGibbsMassA, params),
204
-mPSampler(data, params.transposeData, params.subsetGenes, params.alphaP, params.maxGibbsMassP, params)
205
-{
206
-    switch (mFixedMatrix)
207
-    {
208
-        case 'A' : mASampler.setMatrix(params.fixedMatrix); break;
209
-        case 'P' : mPSampler.setMatrix(params.fixedMatrix); break;
210
-        default: break; // 'N' for none
211
-    }
212
-
213
-    // overwrite with info from checkpoint file
214
-    if (params.useCheckPoint)
215
-    {
216
-        Archive ar(params.checkpointFile, ARCHIVE_READ);
217
-        ar >> mNumPatterns >> mSeed >> mMaxIterations >> mFixedMatrix >> mPhase
218
-            >> mCurrentIteration >> mNumUpdatesA >> mNumUpdatesP >> mRng;
219
-        readSamplers(ar);
220
-        GapsRng::load(ar);
221
-    }
222
-
223
-    mASampler.sync(mPSampler);
224
-    mPSampler.sync(mASampler);
225
-
226
-    // AP matrix not stored in checkpoint
227
-    if (params.useCheckPoint)
228
-    {
229
-        mASampler.recalculateAPMatrix();
230
-        mPSampler.recalculateAPMatrix();
231
-    }
232
-}
233
-
234
-//////////////////// SparseGapsRunner - templated functions ////////////////////
235
-
236
-template <class DataType>
237
-SparseGapsRunner::SparseGapsRunner(const DataType &data,
238
-const GapsParameters &params)
239
-    :
240
-AbstractGapsRunner(params),
241
-mASampler(data, !params.transposeData, !params.subsetGenes, params.alphaA, params.maxGibbsMassA, params),
242
-mPSampler(data, params.transposeData, params.subsetGenes, params.alphaP, params.maxGibbsMassP, params)
243
-{
244
-    switch (mFixedMatrix)
245
-    {
246
-        case 'A' : mASampler.setMatrix(params.fixedMatrix); break;
247
-        case 'P' : mPSampler.setMatrix(params.fixedMatrix); break;
248
-        default: break;
249
-    }
250
-
251
-    // overwrite with info from checkpoint file
252
-    if (params.useCheckPoint)
253
-    {
254
-        Archive ar(params.checkpointFile, ARCHIVE_READ);
255
-        ar >> mNumPatterns >> mSeed >> mMaxIterations >> mFixedMatrix >> mPhase
256
-            >> mCurrentIteration >> mNumUpdatesA >> mNumUpdatesP >> mRng;
257
-        readSamplers(ar);
258
-        GapsRng::load(ar);
259
-    }
260
-
261
-    mASampler.sync(mPSampler);
262
-    mPSampler.sync(mASampler);
263
-}
264
-
265
-#endif // __COGAPS_GAPS_RUNNER_H__
23
+#endif // __COGAPS_GAPS_RUNNER_H__
266 24
\ No newline at end of file
... ...
@@ -78,7 +78,7 @@ float GapsStatistics::meanChiSq(const SparseGibbsSampler &PSampler) const
78 78
     return 0.f; // TODO
79 79
 }
80 80
 
81
-Archive& operator<<(Archive &ar, GapsStatistics &stat)
81
+Archive& operator<<(Archive &ar, const GapsStatistics &stat)
82 82
 {
83 83
     ar << stat.mAMeanMatrix << stat.mAStdMatrix << stat.mPMeanMatrix
84 84
         << stat.mPStdMatrix << stat.mStatUpdates << stat.mNumPatterns;
... ...
@@ -38,7 +38,7 @@ public:
38 38
     float meanChiSq(const SparseGibbsSampler &PSampler) const;
39 39
 
40 40
     // serialization
41
-    friend Archive& operator<<(Archive &ar, GapsStatistics &stat);
41
+    friend Archive& operator<<(Archive &ar, const GapsStatistics &stat);
42 42
     friend Archive& operator>>(Archive &ar, GapsStatistics &stat);
43 43
 };
44 44
 
... ...
@@ -49,22 +49,29 @@ void GapsStatistics::update(const Sampler &ASampler, const Sampler &PSampler)
49 49
 
50 50
     // update     
51 51
     // precision loss? use double?
52
+    DEBUG_PING
53
+    GAPS_ASSERT(mNumPatterns == ASampler.mMatrix.nCol());
54
+    GAPS_ASSERT(mNumPatterns == PSampler.mMatrix.nCol());
55
+
52 56
     for (unsigned j = 0; j < mNumPatterns; ++j)
53 57
     {
54 58
         float norm = gaps::max(PSampler.mMatrix.getCol(j));
55 59
         norm = norm == 0.f ? 1.f : norm;
56 60
         GAPS_ASSERT(norm > 0.f);
57 61
 
62
+        DEBUG_PING
58 63
         Vector quot(PSampler.mMatrix.getCol(j) / norm);
59 64
         GAPS_ASSERT(gaps::min(quot) >= 0.f);
60 65
         mPMeanMatrix.getCol(j) += quot;
61 66
         mPStdMatrix.getCol(j) += gaps::elementSq(quot);
62 67
 
68
+        DEBUG_PING
63 69
         Vector prod(ASampler.mMatrix.getCol(j) * norm);
64 70
         GAPS_ASSERT(gaps::min(prod) >= 0.f);
65 71
         mAMeanMatrix.getCol(j) += prod;
66 72
         mAStdMatrix.getCol(j) += gaps::elementSq(prod);
67 73
     }
74
+    DEBUG_PING
68 75
 }
69 76
 
70 77
 #endif
71 78
\ No newline at end of file
72 79
new file mode 100644
... ...
@@ -0,0 +1,157 @@
1
+#include "GapsRunner.h"
2
+
3
+#ifdef __GAPS_R_BUILD__
4
+#include <Rcpp.h>
5
+#endif
6
+
7
+#ifdef __GAPS_OPENMP__
8
+#include <omp.h>
9
+#endif
10
+
11
+///////////////////////////// RAII wrapper /////////////////////////////////////
12
+
13
+GapsRunner::~GapsRunner()
14
+{
15
+    delete mRunner;
16
+}
17
+
18
+GapsResult GapsRunner::run()
19
+{
20
+    return mRunner->run();
21
+}
22
+
23
+///////////////////////// Abstract Interface ///////////////////////////////////
24
+
25
+AbstractGapsRunner::AbstractGapsRunner(const GapsParameters &params,
26
+GapsRandomState *randState)
27
+    :
28
+mStatistics(params.nGenes, params.nSamples, params.nPatterns),
29
+mRandState(randState),
30
+mRng(randState),
31
+mCheckpointOutFile(params.checkpointOutFile),
32
+mCurrentIteration(0),
33
+mMaxIterations(params.nIterations),
34
+mMaxThreads(params.maxThreads),
35
+mOutputFrequency(params.outputFrequency),
36
+mCheckpointInterval(params.checkpointInterval),
37
+mNumPatterns(params.nPatterns),
38
+mNumUpdatesA(0),
39
+mNumUpdatesP(0),
40
+mSeed(params.seed),
41
+mPrintMessages(params.printMessages),
42
+mPrintThreadUsage(params.printThreadUsage),
43
+mPhase('C'),
44
+mFixedMatrix(params.whichFixedMatrix)
45
+{}
46
+
47
+///////////////////// DenseGapsRunner Implementation ///////////////////////////
48
+
49
+void DenseGapsRunner::updateSampler(unsigned nA, unsigned nP)
50
+{
51
+    if (mFixedMatrix != 'A')
52
+    {
53
+        mNumUpdatesA += nA;
54
+        mASampler.update(nA, mMaxThreads);
55
+        if (mFixedMatrix != 'P')
56
+        {
57
+            mPSampler.sync(mASampler, mMaxThreads);
58
+        }
59
+    }
60
+
61
+    if (mFixedMatrix != 'P')
62
+    {
63
+        mNumUpdatesP += nP;
64
+        mPSampler.update(nP, mMaxThreads);
65
+        if (mFixedMatrix != 'A')
66
+        {
67
+            mASampler.sync(mPSampler, mMaxThreads);
68
+        }
69
+    }
70
+}
71
+
72
+void DenseGapsRunner::setUncertainty(const Matrix &unc, const GapsParameters &params)
73
+{
74
+    mASampler.setUncertainty(unc, !params.transposeData, !params.subsetGenes, params);
75
+    mPSampler.setUncertainty(unc, params.transposeData, params.subsetGenes, params);
76
+}
77
+
78
+void DenseGapsRunner::setUncertainty(const std::string &unc, const GapsParameters &params)
79
+{
80
+    mASampler.setUncertainty(unc, !params.transposeData, !params.subsetGenes, params);
81
+    mPSampler.setUncertainty(unc, params.transposeData, params.subsetGenes, params);
82
+}
83
+
84
+///////////////////// SparseGapsRunner Implementation //////////////////////////
85
+
86
+float SparseGapsRunner::chiSq() const
87
+{
88
+    // doesn't matter which sampler is called
89
+    return mPSampler.chiSq();
90
+}
91
+
92
+float SparseGapsRunner::meanChiSq() const
93
+{
94
+    // need to pass P sampler (due to configuration of internal data)
95
+    return mStatistics.meanChiSq(mPSampler);
96
+}
97
+
98
+unsigned SparseGapsRunner::nAtoms(char which) const
99
+{
100
+    return which == 'A' ? mASampler.nAtoms() : mPSampler.nAtoms();
101
+}
102
+
103
+void SparseGapsRunner::setAnnealingTemp(float temp)
104
+{
105
+    mASampler.setAnnealingTemp(temp);
106
+    mPSampler.setAnnealingTemp(temp);
107
+}
108
+
109
+void SparseGapsRunner::updateStatistics()
110
+{
111
+    mStatistics.update(mASampler, mPSampler);
112
+}
113
+
114
+Archive& SparseGapsRunner::readSamplers(Archive &ar)
115
+{
116
+    ar >> mASampler >> mPSampler;
117
+    return ar;
118
+}
119
+
120
+Archive& SparseGapsRunner::writeSamplers(Archive &ar)
121
+{
122
+    ar << mASampler << mPSampler;
123
+    return ar;
124
+}
125
+
126
+void SparseGapsRunner::updateSampler(unsigned nA, unsigned nP)
127
+{
128
+    if (mFixedMatrix != 'A')
129
+    {
130
+        mNumUpdatesA += nA;
131
+        mASampler.update(nA, mMaxThreads);
132
+        if (mFixedMatrix != 'P')
133
+        {
134
+            mPSampler.sync(mASampler, mMaxThreads);
135
+        }
136
+    }
137
+
138
+    if (mFixedMatrix != 'P')
139
+    {
140
+        mNumUpdatesP += nP;
141
+        mPSampler.update(nP, mMaxThreads);
142
+        if (mFixedMatrix != 'A')
143
+        {
144
+            mASampler.sync(mPSampler, mMaxThreads);
145
+        }
146
+    }
147
+}
148
+
149
+void SparseGapsRunner::setUncertainty(const Matrix &unc, const GapsParameters &params)
150
+{
151
+    // nothing happens - SparseGibbsSampler assumes default uncertainty always
152
+}
153
+
154
+void SparseGapsRunner::setUncertainty(const std::string &unc, const GapsParameters &params)
155
+{
156
+    // nothing happens - SparseGibbsSampler assumes default uncertainty always
157
+}
0 158
new file mode 100644
... ...
@@ -0,0 +1,275 @@
1
+#ifndef __COGAPS_GAPS_RUNNER_H__
2
+#define __COGAPS_GAPS_RUNNER_H__
3
+
4
+#include "GapsParameters.h"
5
+#include "GapsResult.h"
6
+#include "GapsStatistics.h"
7
+#include "gibbs_sampler/GibbsSampler.h"
8
+#include "gibbs_sampler/DenseGibbsSampler.h"
9
+#include "gibbs_sampler/SparseGibbsSampler.h"
10
+
11
+#include <string>
12
+
13
+// boost time helpers
14
+#include <boost/date_time/posix_time/posix_time.hpp>
15
+namespace bpt = boost::posix_time;
16
+#define bpt_now() bpt::microsec_clock::local_time()
17
+
18
+// forward declarations
19
+class AbstractGapsRunner;
20
+
21
+///////////////////////////// RAII wrapper /////////////////////////////////////
22
+
23
+// This is the class that is exposed to the top-level CoGAPS routine - all 
24
+// aspects of CoGAPS can be managed through this class. The class itself is
25
+// just a lightweight wrapper around an abstract interface, which allows for
26
+// multiple types of GapsRunner to be declared. Which implementation is used
27
+// depends on the parameters passed to the GapsRunner constructor.
28
+class GapsRunner
29
+{
30
+public:
31
+
32
+    template <class DataType>
33
+    GapsRunner(const DataType &data, const GapsParameters &params,
34
+        GapsRandomState *randState);
35
+
36
+    ~GapsRunner();
37
+
38
+    template <class DataType>
39
+    void setUncertainty(const DataType &unc, const GapsParameters &params);
40
+
41
+    GapsResult run();
42
+
43
+private:
44
+
45
+    AbstractGapsRunner *mRunner;
46
+
47
+    GapsRunner(const GapsRunner &p); // don't allow copies
48
+    GapsRunner& operator=(const GapsRunner &p); // don't allow copies    
49
+};
50
+
51
+///////////////////////// Abstract Interface ///////////////////////////////////
52
+
53
+// This class is the abstract interface that any implementation of GapsRunner
54
+// must satisfy. It provides a factory method that will create the appropriate
55
+// derived class depending on the parameters passed in.
56
+class AbstractGapsRunner
57
+{
58
+public:
59
+
60
+    AbstractGapsRunner(const GapsParameters &params, GapsRandomState *randState);
61
+    virtual ~AbstractGapsRunner() {}
62
+
63
+    template <class DataType>
64
+    static AbstractGapsRunner* create(const DataType &data,
65
+        const GapsParameters &params, GapsRandomState *randState);
66
+
67
+    // can't use template with virtual function
68
+    virtual void setUncertainty(const Matrix &unc, const GapsParameters &params) = 0;
69
+    virtual void setUncertainty(const std::string &unc, const GapsParameters &params) = 0;
70
+
71
+    GapsResult run();
72
+
73
+protected:
74
+
75
+    GapsStatistics mStatistics;
76
+
77
+    const GapsRandomState *mRandState; // used for writing state to checkpoint
78
+    mutable GapsRng mRng;
79
+
80
+    std::string mCheckpointOutFile;
81
+
82
+    bpt::ptime mStartTime;
83
+
84
+    unsigned mCurrentIteration;
85
+    unsigned mMaxIterations;
86
+    unsigned mMaxThreads;
87
+    unsigned mOutputFrequency;
88
+    unsigned mCheckpointInterval;
89
+    unsigned mNumPatterns;
90
+    unsigned mNumUpdatesA;
91
+    unsigned mNumUpdatesP;
92
+    uint32_t mSeed;
93
+
94
+    bool mPrintMessages;
95
+    bool mPrintThreadUsage;
96
+
97
+    char mPhase;
98
+    char mFixedMatrix;
99
+        
100
+    void runOnePhase();
101
+    double estimatedPercentComplete() const;
102
+    void displayStatus();
103
+    void createCheckpoint();
104
+
105
+    virtual float chiSq() const = 0;
106
+    virtual float meanChiSq() const = 0;
107
+    virtual unsigned nAtoms(char which) const = 0;
108
+    virtual void setAnnealingTemp(float temp) = 0;
109
+    virtual void updateStatistics() = 0;
110
+    virtual Archive& readSamplers(Archive &ar) = 0;
111
+    virtual Archive& writeSamplers(Archive &ar) = 0;
112
+    virtual void updateSampler(unsigned nA, unsigned nP) = 0;
113
+};
114
+
115
+///////////////////// GapsRunner Implementations ///////////////////////////////
116
+
117
+// This implementation uses a DenseGibbsSampler internally
118
+class DenseGapsRunner : public AbstractGapsRunner
119
+{
120
+public:
121
+
122
+    ~DenseGapsRunner() {}
123
+
124
+    template <class DataType>
125
+    DenseGapsRunner(const DataType &data, const GapsParameters &params,
126
+        GapsRandomState *randState);
127
+
128
+    void setUncertainty(const Matrix &unc, const GapsParameters &params);
129
+    void setUncertainty(const std::string &unc, const GapsParameters &params);
130
+
131
+private:
132
+
133
+    DenseGibbsSampler mASampler;
134
+    DenseGibbsSampler mPSampler;
135
+
136
+    float chiSq() const;
137
+    float meanChiSq() const;
138
+    unsigned nAtoms(char which) const;
139
+    void setAnnealingTemp(float temp);
140
+    void updateStatistics();
141
+    Archive& readSamplers(Archive &ar);
142
+    Archive& writeSamplers(Archive &ar);
143
+    void updateSampler(unsigned nA, unsigned nP);
144
+};
145
+
146
+// This implementation uses a SparseGibbsSampler internally
147
+class SparseGapsRunner : public AbstractGapsRunner
148
+{
149
+public:
150
+
151
+    ~SparseGapsRunner() {}
152
+
153
+    template <class DataType>
154
+    SparseGapsRunner(const DataType &data, const GapsParameters &params,
155
+        GapsRandomState *randState);
156
+
157
+    void setUncertainty(const Matrix &unc, const GapsParameters &params);
158
+    void setUncertainty(const std::string &unc, const GapsParameters &params);
159
+
160
+private:
161
+
162
+    SparseGibbsSampler mASampler;
163
+    SparseGibbsSampler mPSampler;
164
+
165
+    float chiSq() const;
166
+    float meanChiSq() const;
167
+    unsigned nAtoms(char which) const;
168
+    void setAnnealingTemp(float temp);
169
+    void updateStatistics();
170
+    Archive& readSamplers(Archive &ar);
171
+    Archive& writeSamplers(Archive &ar);
172
+    void updateSampler(unsigned nA, unsigned nP);
173
+};
174
+
175
+/////////////////////// GapsRunner - templated functions ///////////////////////
176
+
177
+template <class DataType>
178
+GapsRunner::GapsRunner(const DataType &data, const GapsParameters &params,
179
+GapsRandomState *randState)
180
+    : mRunner(AbstractGapsRunner::create(data, params, randState))
181
+{}
182
+
183
+template <class DataType>
184
+void GapsRunner::setUncertainty(const DataType &unc, const GapsParameters &params)
185
+{
186
+    mRunner->setUncertainty(unc, params);
187
+}
188
+
189
+/////////////////// AbstractGapsRunner - templated functions ///////////////////
190
+
191
+template <class DataType>
192
+AbstractGapsRunner* AbstractGapsRunner::create(const DataType &data,
193
+const GapsParameters &params, GapsRandomState *randState)
194
+{
195
+    if (params.useSparseOptimization)
196
+    {
197
+        return new SparseGapsRunner(data, params, randState);
198
+    }
199
+    return new DenseGapsRunner(data, params, randState);
200
+}
201
+
202
+//////////////////// DenseGapsRunner - templated functions /////////////////////
203
+
204
+template <class DataType>
205
+DenseGapsRunner::DenseGapsRunner(const DataType &data,
206
+const GapsParameters &params, GapsRandomState *randState)
207
+    :
208
+AbstractGapsRunner(params, randState),
209
+mASampler(data, !params.transposeData, !params.subsetGenes, params.alphaA,
210
+    params.maxGibbsMassA, params, randState),