Browse code

polymorphic structure

Tom Sherman authored on 02/10/2018 20:54:21
Showing 13 changed files

... ...
@@ -1,159 +1,137 @@
1 1
 #include "GapsRunner.h"
2
-#include "math/SIMD.h"
3 2
 #include "utils/GlobalConfig.h"
4 3
 
5 4
 #include <Rcpp.h>
6 5
 #include <string>
7 6
 #include <sstream>
8 7
 
9
-// these are helper functions for converting matrix/vector types
10
-// to and from R objects
8
+// this file contains the blueprint for creating a wrapper around the C++
9
+// interface used for running CoGAPS. It exposes some functions to R, has a
10
+// method for converting the R parameters to the standard GapsParameters
11
+// struct, and creates a GapsRunner object. The GapsRunner class manages
12
+// all information from a CoGAPS run and is used to set off the run
13
+// and get return data.
11 14
 
12
-static Matrix convertRMatrix(const Rcpp::NumericMatrix &rmat, bool transpose=false)
15
+////////////////// functions for converting matrix types ///////////////////////
16
+
17
+// convert R to C++ data type
18
+static Matrix convertRMatrix(const Rcpp::NumericMatrix &rmat)
13 19
 {
14
-    unsigned nr = transpose ? rmat.ncol() : rmat.nrow();
15
-    unsigned nc = transpose ? rmat.nrow() : rmat.ncol();
16
-    Matrix mat(nr, nc);
17
-    for (unsigned i = 0; i < nr; ++i)
20
+    Matrix mat(rmat.nrow(), rmat.ncol());
21
+    for (unsigned i = 0; i < rmat.nrow(); ++i)
18 22
     {
19
-        for (unsigned j = 0; j < nc; ++j)
23
+        for (unsigned j = 0; j < rmat.ncol(); ++j)
20 24
         {
21
-            mat(i,j) = transpose ? rmat(j,i) : rmat(i,j);
25
+            mat(i,j) = rmat(i,j);
22 26
         }
23 27
     }
24 28
     return mat;
25 29
 }
26 30
 
27
-template <class Matrix>
28
-static Rcpp::NumericMatrix createRMatrix(const Matrix &mat, bool transpose=false)
31
+// convert C++ to R data type
32
+template <class GenericMatrix>
33
+static Rcpp::NumericMatrix createRMatrix(const GenericMatrix &mat)
29 34
 {
30
-    unsigned nr = transpose ? mat.nCol() : mat.nRow();
31
-    unsigned nc = transpose ? mat.nRow() : mat.nCol();
32
-    Rcpp::NumericMatrix rmat(nr, nc);
33
-    for (unsigned i = 0; i < nr; ++i)
35
+    Rcpp::NumericMatrix rmat(mat.nRow(), mat.nCol());
36
+    for (unsigned i = 0; i < mat.nRow(); ++i)
34 37
     {
35
-        for (unsigned j = 0; j < nc; ++j)
38
+        for (unsigned j = 0; j < mat.nCol(); ++j)
36 39
         {
37
-            rmat(i,j) = transpose ? mat(j,i) : mat(i,j);
40
+            rmat(i,j) = mat(i,j);
38 41
         }
39 42
     }
40 43
     return rmat;
41 44
 }
42 45
 
43
-// these helper functions provide an abtracted way for communicating which
44
-// parameters are null between R and C++
46
+////////// converts R parameters to single GapsParameters struct ///////////////
45 47
 
46
-static bool isNull(const std::string &file)
47
-{
48
-    return file.empty();
49
-}
50
-
51
-static bool isNull(const Matrix &mat)
52
-{
53
-    return mat.nRow() == 1 && mat.nCol() == 1;
54
-}
55
-
56
-// needed to create proper size of GapsRunner
57
-unsigned getNumPatterns(const Rcpp::List &allParams)
48
+GapsParameters getGapsParameters(const Rpp::List &allParams, bool isMaster,
49
+const Rcpp::Nullable<Rcpp::NumericMatrix> &fixedMatrix,
50
+const Rcpp::Nullable<Rcpp::IntegerVector> &indices)
58 51
 {
52
+    // Standard CoGAPS parameters struct
53
+    GapsParameters params;
54
+
55
+    // get configuration parameters
56
+    params.maxThreads = allParams["nThreads"];
57
+    params.printMessages = allParams["messages"] && isMaster;
58
+    params.transposeData = allParams["transposeData"];
59
+    params.outputFrequency = allParams["outputFrequency"];
60
+    params.checkpointOutFile = allParams["checkpointOutFile"];
61
+    params.checkpointInterval = allParams["checkpointInterval"];
62
+
63
+    // extract model specific parameters from list
59 64
     const Rcpp::S4 &gapsParams(allParams["gaps"]);
60
-    unsigned nPatterns = gapsParams.slot("nPatterns");
61
-    if (!Rf_isNull(allParams["checkpointInFile"]))
65
+    params.seed = gapsParams.slot("seed");
66
+    params.nPatterns = gapsParams.slot("nPatterns");
67
+    params.nIterations = gapsParams.slot("nIterations");
68
+    params.alphaA = gapsParams.slot("alphaA");
69
+    params.alphaP = gapsParams.slot("alphaP");
70
+    params.maxGibbsMassA = gapsParams.slot("maxGibbsMassA");
71
+    params.maxGibbsMassP = gapsParams.slot("maxGibbsMassP");
72
+    params.singleCell = gapsParams.slot("singleCell");
73
+
74
+    // check if using fixed matrix
75
+    if (fixedMatrix.isNotNull())
62 76
     {
63
-        std::string file(Rcpp::as<std::string>(allParams["checkpointInFile"]));
64
-        Archive ar(file, ARCHIVE_READ);
65
-        GapsRng::load(ar);
66
-        ar >> nPatterns;
67
-        ar.close();
77
+        params.useFixedMatrix = true;
78
+        params.fixedMatrix = convertRMatrix(Rcpp::NumericMatrix(fixedMatrix));
68 79
     }
69
-    return nPatterns;
70
-}
71 80
 
72
-std::vector<unsigned> getSubsetIndices(const Rcpp::Nullable<Rcpp::IntegerVector> &indices)
73
-{
81
+    // check if subsetting data
74 82
     if (indices.isNotNull())
75 83
     {
76
-        return Rcpp::as< std::vector<unsigned> >(Rcpp::IntegerVector(indices));
84
+        params.subsetData = true;
85
+        params.printThreadUsage = false;
86
+
87
+        std::string d(Rcpp::as<std::string>(gapsParams.slot("distributed")));
88
+        params.subsetGenes = (d == "genome-wide");
89
+        params.whichFixedMatrix = (d == "genome-wide") ? 'P' : 'A';
90
+
91
+        params.dataIndicesSubset =
92
+            Rcpp::as< std::vector<unsigned> >(Rcpp::IntegerVector(indices));
77 93
     }
78
-    return std::vector<unsigned>(1); // interpreted as null, i.e. will be ignored
79
-}
80 94
 
81
-// return if running distributed, and if so, are we partitioning rows/cols
82
-std::pair<bool, bool> processDistributedParameters(const Rcpp::List &allParams)
83
-{
84
-    const Rcpp::S4 &gapsParams(allParams["gaps"]);
85
-    if (!Rf_isNull(gapsParams.slot("distributed")))
95
+    // check if using checkpoint file, peek at the saved parameters
96
+    if (!Rf_isNull(allParams["checkpointInFile"]))
86 97
     {
87
-        std::string d(Rcpp::as<std::string>(gapsParams.slot("distributed")));
88
-        GAPS_ASSERT(d == "genome-wide" || d == "single-cell");
89
-        return std::pair<bool, bool>(true, d == "genome-wide");
98
+        params.checkpointFile = Rcpp::as<std::string>(allParams["checkpointInFile"]);
99
+        params.useCheckPoint = true;
100
+        params.peekCheckpoint(params.checkpointFile);
90 101
     }
91
-    return std::pair<bool, bool>(false, false);
102
+
103
+    return params;
92 104
 }
93 105
 
94
-// this is the main function that creates a GapsRunner and runs CoGAPS
106
+
107
+////////// main function that creates a GapsRunner and runs CoGAPS /////////////
108
+
109
+// note uncertainty matrix gets special treatment since it's the same size as
110
+// the data (potentially large), so we want to avoid copying it into the 
111
+// GapsParameters struct temporarily
95 112
 
96 113
 template <class DataType>
97 114
 static Rcpp::List cogapsRun(const DataType &data, const Rcpp::List &allParams,
98 115
 const DataType &uncertainty, const Rcpp::Nullable<Rcpp::IntegerVector> &indices,
99 116
 const Rcpp::Nullable<Rcpp::NumericMatrix> &fixedMatrix, bool isMaster)
100 117
 {
101
-    // calculate essential parameters needed for constructing GapsRunner
102
-    const Rcpp::S4 &gapsParams(allParams["gaps"]);
103
-    GapsRng::setSeed(gapsParams.slot("seed"));
104
-    unsigned nPatterns = getNumPatterns(allParams); // TODO clarify this sets the checkpoint seed as well
105
-    bool printThreads = !processDistributedParameters(allParams).first;
106
-    bool partitionRows = processDistributedParameters(allParams).second;
107
-    std::vector<unsigned> cIndices(getSubsetIndices(indices));
118
+    // convert R parameters to GapsParameters struct
119
+    GapsParameters gapsParams(getGapsParameters(allParams, isMaster,
120
+        fixedMatrix, indices));
108 121
 
109
-    // construct GapsRunner
110
-    GapsRunner runner(data, allParams["transposeData"], nPatterns,
111
-        partitionRows, cIndices);
122
+    // create GapsRunner, note we must first initialize the random generator
123
+    GapsRng::setSeed(params.seed);
124
+    GapsRunner runner(data, gapsParams);
112 125
 
113 126
     // set uncertainty
114
-    if (!isNull(uncertainty))
127
+    if (!uncertainty.empty())
115 128
     {
116
-        runner.setUncertainty(uncertainty, allParams["transposeData"],
117
-            partitionRows, cIndices);
129
+        runner.setUncertainty(uncertainty, gapsParams);
118 130
     }
119 131
     
120
-    // populate GapsRunner from checkpoint file
121
-    if (!Rf_isNull(allParams["checkpointInFile"]))
122
-    {
123
-        std::string file(Rcpp::as<std::string>(allParams["checkpointInFile"]));
124
-        Archive ar(file, ARCHIVE_READ);
125
-        GapsRng::load(ar);
126
-        ar >> runner;
127
-        ar.close();
128
-    }
129
-    else // no checkpoint, populate from given parameters
130
-    {
131
-        // set fixed matrix
132
-        if (fixedMatrix.isNotNull())
133
-        {
134
-            GAPS_ASSERT(!Rf_isNull(allParams["whichMatrixFixed"]));
135
-            std::string which = Rcpp::as<std::string>(allParams["whichMatrixFixed"]);
136
-            runner.setFixedMatrix(which[0], convertRMatrix(Rcpp::NumericMatrix(fixedMatrix), which[0]=='P'));
137
-        }
138
-
139
-        // set parameters that would be saved in the checkpoint
140
-        runner.recordSeed(gapsParams.slot("seed"));
141
-        runner.setMaxIterations(gapsParams.slot("nIterations"));
142
-        runner.setSparsity(gapsParams.slot("alphaA"), gapsParams.slot("alphaP"),
143
-            gapsParams.slot("maxGibbsMassA"), gapsParams.slot("maxGibbsMassP"),
144
-            gapsParams.slot("singleCell"));
145
-    }
146
-
147
-    // set parameters that aren't saved in the checkpoint
148
-    runner.setMaxThreads(allParams["nThreads"]);
149
-    runner.setPrintMessages(allParams["messages"] && isMaster);
150
-    runner.setOutputFrequency(allParams["outputFrequency"]);
151
-    runner.setCheckpointOutFile(allParams["checkpointOutFile"]);
152
-    runner.setCheckpointInterval(allParams["checkpointInterval"]);
153
-
154 132
     // run cogaps
155
-    GapsResult result(runner.run(printThreads));
156
-    
133
+    GapsResult result(runner.run());
134
+
157 135
     // write result to file if requested
158 136
     if (allParams["outputToFile"] != R_NilValue)
159 137
     {
... ...
@@ -166,7 +144,7 @@ const Rcpp::Nullable<Rcpp::NumericMatrix> &fixedMatrix, bool isMaster)
166 144
         Rcpp::Named("Pmean") = createRMatrix(result.Pmean),
167 145
         Rcpp::Named("Asd") = createRMatrix(result.Asd),
168 146
         Rcpp::Named("Psd") = createRMatrix(result.Psd),
169
-        Rcpp::Named("seed") = runner.getSeed(),
147
+        Rcpp::Named("seed") = params.seed,
170 148
         Rcpp::Named("meanChiSq") = result.meanChiSq,
171 149
         Rcpp::Named("geneNames") = allParams["geneNames"],
172 150
         Rcpp::Named("sampleNames") = allParams["sampleNames"],
... ...
@@ -174,7 +152,7 @@ const Rcpp::Nullable<Rcpp::NumericMatrix> &fixedMatrix, bool isMaster)
174 152
     );
175 153
 }
176 154
 
177
-// these are the functions exposed to the R package
155
+/////////////////// functions exposed to the R package /////////////////////////
178 156
 
179 157
 // [[Rcpp::export]]
180 158
 Rcpp::List cogaps_cpp_from_file(const Rcpp::CharacterVector &data,
... ...
@@ -184,11 +162,12 @@ const Rcpp::Nullable<Rcpp::IntegerVector> &indices=R_NilValue,
184 162
 const Rcpp::Nullable<Rcpp::NumericMatrix> &fixedMatrix=R_NilValue,
185 163
 bool isMaster=true)
186 164
 {
187
-    std::string unc = ""; // interpreted as null, i.e. will be ignored
165
+    std::string unc;
188 166
     if (uncertainty.isNotNull())
189 167
     {
190 168
         unc = Rcpp::as<std::string>(Rcpp::CharacterVector(uncertainty));
191 169
     }
170
+
192 171
     return cogapsRun(Rcpp::as<std::string>(data), allParams, unc, indices,
193 172
         fixedMatrix, isMaster);
194 173
 }
... ...
@@ -201,11 +180,12 @@ const Rcpp::Nullable<Rcpp::IntegerVector> &indices=R_NilValue,
201 180
 const Rcpp::Nullable<Rcpp::NumericMatrix> &fixedMatrix=R_NilValue,
202 181
 bool isMaster=true)
203 182
 {
204
-    Matrix unc(1,1); // interpreted as null, i.e. will be ignored
183
+    Matrix unc;
205 184
     if (uncertainty.isNotNull())
206 185
     {
207 186
         unc = convertRMatrix(Rcpp::NumericMatrix(uncertainty));
208 187
     }
188
+
209 189
     return cogapsRun(convertRMatrix(data), allParams, unc, indices,
210 190
         fixedMatrix, isMaster);
211 191
 }
212 192
new file mode 100644
... ...
@@ -0,0 +1,66 @@
1
+#ifndef __COGAPS_GAPS_PARAMETERS_H__
2
+#define __COGAPS_GAPS_PARAMETERS_H__
3
+
4
+struct GapsParameters
5
+{
6
+    Matrix fixedMatrix;
7
+
8
+    std::vector<unsigned> dataIndicesSubset;
9
+
10
+    std::string checkpointFile;
11
+    std::string checkpointOutFile;
12
+
13
+    uint32_t seed;
14
+
15
+    unsigned nPatterns;
16
+    unsigned nIterations;
17
+    unsigned maxThreads;
18
+    unsigned outputFrequency;
19
+    unsigned checkpointInterval;
20
+
21
+    float alphaA;
22
+    float alphaP;
23
+    float maxGibbsMassA;
24
+    float maxGibbsMassP;
25
+
26
+    bool useFixedMatrix;
27
+    bool subsetData;
28
+    bool useCheckPoint;
29
+    bool transposeData;
30
+    bool singleCell;
31
+    bool printMessages;
32
+    bool subsetGenes;
33
+    bool printThreadUsage;
34
+
35
+    char whichFixedMatrix;
36
+
37
+    GapsParameters() :
38
+        checkpointOutFile("gaps_checkpoint.out"),
39
+        seed(0),
40
+        nIterations(1000),
41
+        maxThreads(1),
42
+        outputFrequency(500),
43
+        checkpointInterval(250),
44
+        alphaA(0.01f),
45
+        alphaP(0.01f),
46
+        maxGibbsMassA(100.f),
47
+        maxGibbsMassP(100.f),
48
+        useFixedMatrix(false),
49
+        subsetData(false),
50
+        useCheckpoint(false),
51
+        transposeData(false),
52
+        singleCell(false),
53
+        printMessages(true),
54
+        subsetGenes(false),
55
+        printThreadUsage(true),
56
+        whichFixedMatrix('N')
57
+    {}
58
+
59
+    void peekCheckpoint(const std::string &file)
60
+    {
61
+        Archive ar(file, ARCHIVE_READ);
62
+        ar >> nPatterns >> seed >> nIterations >> whichFixedMatrix;
63
+    }
64
+};
65
+
66
+#endif
0 67
\ No newline at end of file
... ...
@@ -14,73 +14,13 @@
14 14
 #include <omp.h>
15 15
 #endif
16 16
 
17
-void GapsRunner::setFixedMatrix(char which, const Matrix &mat)
18
-{
19
-    mFixedMatrix = which;
20
-    if (which == 'A')
21
-    {
22
-        mASampler.setMatrix(mat);
23
-    }
24
-    else if (which == 'P')
25
-    {
26
-        mPSampler.setMatrix(mat);
27
-    }
28
-}
29
-
30
-void GapsRunner::recordSeed(uint32_t seed)
31
-{
32
-    mSeed = seed;
33
-}
34
-
35
-uint32_t GapsRunner::getSeed() const
36
-{
37
-    return mSeed;
38
-}
39
-
40
-void GapsRunner::setMaxIterations(unsigned nIterations)
41
-{
42
-    mMaxIterations = nIterations;
43
-}
44
-
45
-void GapsRunner::setSparsity(float alphaA, float alphaP, float maxA, float maxP,
46
-bool singleCell)
47
-{
48
-    mASampler.setSparsity(alphaA, maxA, singleCell);
49
-    mPSampler.setSparsity(alphaP, maxP, singleCell);
50
-}
51
-
52
-void GapsRunner::setMaxThreads(unsigned nThreads)
53
-{
54
-    mMaxThreads = nThreads;
55
-}
56
-
57
-void GapsRunner::setPrintMessages(bool print)
58
-{
59
-    mPrintMessages = print;
60
-}
61
-
62
-void GapsRunner::setOutputFrequency(unsigned n)
63
-{
64
-    mOutputFrequency = n;
65
-}
66
-
67
-void GapsRunner::setCheckpointOutFile(const std::string &file)
68
-{
69
-    mCheckpointOutFile = file;
70
-}
71
-
72
-void GapsRunner::setCheckpointInterval(unsigned interval)
73
-{
74
-    mCheckpointInterval = interval;
75
-}
76
-
77
-GapsResult GapsRunner::run(bool printThreads)
17
+GapsResult GapsRunner::run()
78 18
 {
79 19
     mStartTime = bpt_now();
80 20
 
81 21
     // calculate appropriate number of threads if compiled with openmp
82 22
     #ifdef __GAPS_OPENMP__
83
-    if (mPrintMessages && printThreads)
23
+    if (mPrintMessages && mPrintThreadUsage)
84 24
     {
85 25
         unsigned availableThreads = omp_get_max_threads();
86 26
         mMaxThreads = gaps::min(availableThreads, mMaxThreads);
... ...
@@ -129,13 +69,13 @@ void GapsRunner::runOnePhase()
129 69
         {        
130 70
             float temp = static_cast<float>(2 * mCurrentIteration)
131 71
                 / static_cast<float>(mMaxIterations);
132
-            mASampler.setAnnealingTemp(gaps::min(1.f, temp));
133
-            mPSampler.setAnnealingTemp(gaps::min(1.f, temp));
72
+            mASampler->setAnnealingTemp(gaps::min(1.f, temp));
73
+            mPSampler->setAnnealingTemp(gaps::min(1.f, temp));
134 74
         }
135 75
     
136 76
         // number of updates per iteration is poisson 
137
-        unsigned nA = mRng.poisson(gaps::max(mASampler.nAtoms(), 10));
138
-        unsigned nP = mRng.poisson(gaps::max(mPSampler.nAtoms(), 10));
77
+        unsigned nA = mRng.poisson(gaps::max(mASampler->nAtoms(), 10));
78
+        unsigned nP = mRng.poisson(gaps::max(mPSampler->nAtoms(), 10));
139 79
         updateSampler(nA, nP);
140 80
 
141 81
         if (mPhase == 'S')
... ...
@@ -152,23 +92,23 @@ void GapsRunner::updateSampler(unsigned nA, unsigned nP)
152 92
     if (mFixedMatrix != 'A')
153 93
     {
154 94
         mNumUpdatesA += nA;
155
-        mASampler.update(nA, mMaxThreads);
95
+        mASampler->update(nA, mMaxThreads);
156 96
         if (mFixedMatrix != 'P')
157 97
         {
158
-            mPSampler.sync(mASampler, mMaxThreads);
98
+            mPSampler->sync(mASampler, mMaxThreads);
159 99
         }
160
-        GAPS_ASSERT(mASampler.internallyConsistent());
100
+        GAPS_ASSERT(mASampler->internallyConsistent());
161 101
     }
162 102
 
163 103
     if (mFixedMatrix != 'P')
164 104
     {
165 105
         mNumUpdatesP += nP;
166
-        mPSampler.update(nP, mMaxThreads);
106
+        mPSampler->update(nP, mMaxThreads);
167 107
         if (mFixedMatrix != 'A')
168 108
         {
169
-            mASampler.sync(mPSampler, mMaxThreads);
109
+            mASampler->sync(mPSampler, mMaxThreads);
170 110
         }
171
-        GAPS_ASSERT(mPSampler.internallyConsistent());
111
+        GAPS_ASSERT(mPSampler->internallyConsistent());
172 112
     }
173 113
 }
174 114
 
... ...
@@ -185,8 +125,8 @@ static double estimatedNumUpdates(double current, double total, float nAtoms)
185 125
 double GapsRunner::estimatedPercentComplete() const
186 126
 {
187 127
     double nIter = static_cast<double>(mCurrentIteration);
188
-    double nAtomsA = static_cast<double>(mASampler.nAtoms());
189
-    double nAtomsP = static_cast<double>(mPSampler.nAtoms());
128
+    double nAtomsA = static_cast<double>(mASampler->nAtoms());
129
+    double nAtomsP = static_cast<double>(mPSampler->nAtoms());
190 130
     
191 131
     if (mPhase == 'S')
192 132
     {
... ...
@@ -226,8 +166,8 @@ void GapsRunner::displayStatus()
226 166
         totalSeconds -= totalMinutes * 60;
227 167
 
228 168
         gaps_printf("%d of %d, Atoms: %lu(%lu), ChiSq: %.0f, Time: %02d:%02d:%02d / %02d:%02d:%02d\n",
229
-            mCurrentIteration + 1, mMaxIterations, mASampler.nAtoms(),
230
-            mPSampler.nAtoms(), mPSampler.chi2(), elapsedHours, elapsedMinutes,
169
+            mCurrentIteration + 1, mMaxIterations, mASampler->nAtoms(),
170
+            mPSampler->nAtoms(), mPSampler->chi2(), elapsedHours, elapsedMinutes,
231 171
             elapsedSeconds, totalHours, totalMinutes, totalSeconds);
232 172
         gaps_flush();
233 173
     }
... ...
@@ -242,30 +182,12 @@ void GapsRunner::createCheckpoint()
242 182
     
243 183
         // create checkpoint file
244 184
         Archive ar(mCheckpointOutFile, ARCHIVE_WRITE);
185
+        ar << mNumPatterns << mSeed << mMaxIterations << mFixedMatrix << mPhase
186
+            << mCurrentIteration << mNumUpdatesA << mNumUpdatesP << mRng
187
+            << *mASampler << *mPSampler;
245 188
         GapsRng::save(ar);
246
-        ar << mNumPatterns << mSeed << mASampler << mPSampler << mStatistics
247
-            << mFixedMatrix << mMaxIterations << mPhase << mCurrentIteration
248
-            << mNumUpdatesA << mNumUpdatesP << mRng;
249
-        ar.close();
250 189
 
251 190
         // delete backup file
252 191
         std::remove((mCheckpointOutFile + ".backup").c_str());
253 192
     }
254 193
 }
255
-
256
-// assume random state has been loaded and nPatterns and seed have been read
257
-Archive& operator>>(Archive &ar, GapsRunner &gr)
258
-{
259
-    ar >> gr.mNumPatterns >> gr.mSeed >> gr.mASampler >> gr.mPSampler
260
-        >> gr.mStatistics >> gr.mFixedMatrix >> gr.mMaxIterations >> gr.mPhase
261
-        >> gr.mCurrentIteration >> gr.mNumUpdatesA >> gr.mNumUpdatesP
262
-        >> gr.mRng;
263
-
264
-    gr.mASampler.sync(gr.mPSampler);
265
-    gr.mPSampler.sync(gr.mASampler);
266
-
267
-    gr.mASampler.recalculateAPMatrix();
268
-    gr.mPSampler.recalculateAPMatrix();
269
-
270
-    return ar;
271
-}
... ...
@@ -1,12 +1,11 @@
1 1
 #ifndef __COGAPS_GAPS_RUNNER_H__
2 2
 #define __COGAPS_GAPS_RUNNER_H__
3 3
 
4
+#include "GapsParameters.h"
4 5
 #include "GapsResult.h"
5 6
 #include "GapsStatistics.h"
6 7
 #include "GibbsSampler.h"
7 8
 
8
-#include "data_structures/Matrix.h"
9
-
10 9
 // boost time helpers
11 10
 #include <boost/date_time/posix_time/posix_time.hpp>
12 11
 namespace bpt = boost::posix_time;
... ...
@@ -14,94 +13,104 @@ namespace bpt = boost::posix_time;
14 13
 
15 14
 class GapsRunner
16 15
 {
16
+public:
17
+
18
+    template <class DataType>
19
+    GapsRunner(const DataType &data, const GapsParameters &params);
20
+
21
+    template <class DataType>
22
+    void setUncertainty(const DataType &unc, const GapsParameters &params);
23
+
24
+    GapsResult run();
25
+
17 26
 private:
18 27
     
19
-    GibbsSampler mASampler;
20
-    GibbsSampler mPSampler;
28
+    GibbsSampler *mASampler;
29
+    GibbsSampler *mPSampler;
21 30
     GapsStatistics mStatistics;
22 31
 
23
-    char mFixedMatrix;
24
-    unsigned mMaxIterations;
25
-    
26
-    unsigned mMaxThreads;
27
-    bool mPrintMessages;
28
-    unsigned mOutputFrequency;
32
+    mutable GapsRng mRng;
33
+
29 34
     std::string mCheckpointOutFile;
30
-    unsigned mCheckpointInterval;
31 35
 
32 36
     bpt::ptime mStartTime;
33
-    char mPhase;
34
-    unsigned mCurrentIteration;
35 37
 
36
-    // only kept since they need to be written to the start of every checkpoint
38
+    unsigned mCurrentIteration;
39
+    unsigned mMaxIterations;
40
+    unsigned mMaxThreads;
41
+    unsigned mOutputFrequency;
42
+    unsigned mCheckpointInterval;
37 43
     unsigned mNumPatterns;
38
-    uint32_t mSeed;
39
-
40 44
     unsigned mNumUpdatesA;
41 45
     unsigned mNumUpdatesP;
46
+    uint32_t mSeed;
42 47
 
43
-    mutable GapsRng mRng;
48
+    bool mPrintMessages;
49
+    bool mPrintThreadUsage;
50
+
51
+    char mPhase;
52
+    char mFixedMatrix;
44 53
         
45 54
     void runOnePhase();
46 55
     void updateSampler(unsigned nA, unsigned nP);
47 56
     double estimatedPercentComplete() const;
48 57
     void displayStatus();
49 58
     void createCheckpoint();
50
-
51
-public:
52
-
53
-    template <class DataType>
54
-    GapsRunner(const DataType &data, bool transposeData, unsigned nPatterns,
55
-        bool partitionRows, const std::vector<unsigned> &indices);
56
-
57
-    template <class DataType>
58
-    void setUncertainty(const DataType &unc, bool transposeData,
59
-        bool partitionRows, const std::vector<unsigned> &indices);
60
-
61
-    void setFixedMatrix(char which, const Matrix &mat);
62
-
63
-    void recordSeed(uint32_t seed);
64
-    uint32_t getSeed() const;
65
-
66
-    void setMaxIterations(unsigned nIterations);
67
-    void setSparsity(float alphaA, float alphaP, float maxA, float maxP,
68
-        bool singleCell);
69
-    
70
-    void setMaxThreads(unsigned nThreads);
71
-    void setPrintMessages(bool print);
72
-    void setOutputFrequency(unsigned n);
73
-    void setCheckpointOutFile(const std::string &outFile);
74
-    void setCheckpointInterval(unsigned interval);
75
-
76
-    GapsResult run(bool printThreads=true);
77
-
78
-    // serialization
79
-    friend Archive& operator>>(Archive &ar, GapsRunner &runner);
80 59
 };
81 60
 
82
-// problem with passing file parser - need to read it twice
83 61
 template <class DataType>
84
-GapsRunner::GapsRunner(const DataType &data, bool transposeData,
85
-unsigned nPatterns, bool partitionRows, const std::vector<unsigned> &indices)
62
+GapsRunner::GapsRunner(const DataType &data, const GapsParameters &params)
86 63
     :
87
-mASampler(data, !transposeData, nPatterns, !partitionRows, indices),
88
-mPSampler(data, transposeData, nPatterns, partitionRows, indices),
89
-mStatistics(mPSampler.dataRows(), mPSampler.dataCols(), nPatterns),
90
-mFixedMatrix('N'), mMaxIterations(1000), mMaxThreads(1), mPrintMessages(true),
91
-mOutputFrequency(500), mCheckpointOutFile("gaps_checkpoint.out"),
92
-mCheckpointInterval(0), mPhase('C'), mCurrentIteration(0),
93
-mNumPatterns(nPatterns), mSeed(0), mNumUpdatesA(0), mNumUpdatesP(0)
64
+mASampler(new DenseGibbsSampler(data, !params.transposeData, params.nPatterns, params.subsetGenes, params.dataIndicesSubset)),
65
+mPSampler(new DenseGibbsSampler(data, params.transposeData, params.nPatterns, params.subsetGenes, params.dataIndicesSubset)),
66
+mStatistics(mPSampler->dataRows(), mPSampler->dataCols(), params.nPatterns),
67
+mCheckpointOutFile(params.checkpointOutFile),
68
+mMaxIterations(params.nIterations),
69
+mMaxThreads(params.mMaxThreads),
70
+mOutputFrequency(params.mOutputFrequency),
71
+mCheckpointInterval(params.mCheckpointInterval),
72
+mNumPatterns(params.nPatterns),
73
+mNumUpdatesA(0),
74
+mNumUpdatesP(0),
75
+mSeed(params.seed),
76
+mPrintMessages(params.printMessages),
77
+mPrintThreadUsage(params.printThreadUsage),
78
+mPhase('C'),
79
+mFixedMatrix(params.whichFixedMatrix)
94 80
 {
95
-    mASampler.sync(mPSampler);
96
-    mPSampler.sync(mASampler);
81
+    mASampler->setSparsity(params.alphaA, params.maxGibbsMassA, params.singleCell);
82
+    mPSampler->setSparsity(params.alphaP, params.maxGibbsMassP, params.singleCell);
83
+
84
+    switch (mFixedMatrix)
85
+    {
86
+        case 'A' : mASampler->setMatrix(params.fixedMatrix); break;
87
+        case 'P' : mPSampler->setMatrix(params.fixedMatrix); break;
88
+        default: break;
89
+    }
90
+
91
+    // overwrite with info from checkpoint file
92
+    if (params.useCheckPoint)
93
+    {
94
+        Archive ar(params.checkpointFile, ARCHIVE_READ);
95
+        ar >> mNumPatterns >> mSeed >> mMaxIterations >> mFixedMatrix >> mPhase
96
+            >> mCurrentIteration >> mNumUpdatesA >> mNumUpdatesP >> mRng
97
+            >> *mASampler >> *mPSampler;
98
+        GapsRng::load(ar);
99
+    }
100
+
101
+    mASampler->sync(mPSampler);
102
+    mPSampler->sync(mASampler);
103
+    mASampler->recalculateAPMatrix();
104
+    mPSampler->recalculateAPMatrix();
97 105
 }
98 106
 
99 107
 template <class DataType>
100
-void GapsRunner::setUncertainty(const DataType &unc, bool transposeData,
101
-bool partitionRows, const std::vector<unsigned> &indices)
108
+void GapsRunner::setUncertainty(const DataType &unc, const GapsParameters &params)
102 109
 {
103
-    mASampler.setUncertainty(unc, !transposeData, !partitionRows, indices);
104
-    mPSampler.setUncertainty(unc, transposeData, partitionRows, indices);
110
+    mASampler->setUncertainty(unc, !params.transposeData, params.nPatterns,
111
+        params.subsetGenes, params.dataIndicesSubset);
112
+    mPSampler->setUncertainty(unc, params.transposeData, params.nPatterns,
113
+        params.subsetGenes, params.dataIndicesSubset);
105 114
 }
106 115
 
107 116
 #endif // __COGAPS_GAPS_RUNNER_H__
108 117
\ No newline at end of file
... ...
@@ -8,10 +8,10 @@ class GapsStatistics
8 8
 {
9 9
 private:
10 10
 
11
-    ColMatrix mAMeanMatrix;
12
-    ColMatrix mAStdMatrix;
13
-    ColMatrix mPMeanMatrix;
14
-    ColMatrix mPStdMatrix;
11
+    Matrix mAMeanMatrix;
12
+    Matrix mAStdMatrix;
13
+    Matrix mPMeanMatrix;
14
+    Matrix mPStdMatrix;
15 15
     
16 16
     unsigned mStatUpdates;
17 17
     unsigned mNumPatterns;
... ...
@@ -20,14 +20,14 @@ public:
20 20
 
21 21
     GapsStatistics(unsigned nRow, unsigned nCol, unsigned nPatterns);
22 22
 
23
-    void update(const GibbsSampler &ASampler, const GibbsSampler &PSampler);
23
+    void update(const GibbsSampler *ASampler, const GibbsSampler *PSampler);
24 24
 
25
-    ColMatrix Amean() const;
26
-    ColMatrix Pmean() const;
27
-    ColMatrix Asd() const;
28
-    ColMatrix Psd() const;
25
+    Matrix Amean() const;
26
+    Matrix Pmean() const;
27
+    Matrix Asd() const;
28
+    Matrix Psd() const;
29 29
 
30
-    float meanChiSq(const GibbsSampler &PSampler) const;
30
+    float meanChiSq(const GibbsSampler *PSampler) const;
31 31
 
32 32
     // serialization
33 33
     friend Archive& operator<<(Archive &ar, GapsStatistics &stat);
34 34
deleted file mode 100644
... ...
@@ -1,128 +0,0 @@
1
-#ifndef __COGAPS_GIBBS_SAMPLER_H__
2
-#define __COGAPS_GIBBS_SAMPLER_H__
3
-
4
-#define DEFAULT_ALPHA           0.01f
5
-#define DEFAULT_MAX_GIBBS_MASS  100.f
6
-
7
-#include "AtomicDomain.h"
8
-#include "ProposalQueue.h"
9
-#include "data_structures/Matrix.h"
10
-#include "math/Algorithms.h"
11
-#include "math/Random.h"
12
-
13
-#include <vector>
14
-
15
-class GibbsSampler
16
-{
17
-public:
18
-
19
-    template <class DataType>
20
-    GibbsSampler(const DataType &data, bool transpose, unsigned nPatterns,
21
-        bool partitionRows, const std::vector<unsigned> &indices);
22
-
23
-    template <class DataType>
24
-    void setUncertainty(const DataType &unc, bool transposeData,
25
-        bool partitionRows, const std::vector<unsigned> &indices);
26
-
27
-    unsigned dataRows() const;
28
-    unsigned dataCols() const;
29
-    
30
-    void setSparsity(float alpha, float maxGibbsMass, bool singleCell);
31
-    void setAnnealingTemp(float temp);
32
-    void setMatrix(const Matrix &mat);
33
-
34
-    float chi2() const;
35
-    uint64_t nAtoms() const;
36
-
37
-    void recalculateAPMatrix();
38
-    void sync(const GibbsSampler &sampler, unsigned nThreads=1);
39
-    void update(unsigned nSteps, unsigned nCores);
40
-
41
-    // serialization
42
-    friend Archive& operator<<(Archive &ar, GibbsSampler &sampler);
43
-    friend Archive& operator>>(Archive &ar, GibbsSampler &sampler);
44
-
45
-    #ifdef GAPS_DEBUG
46
-    bool internallyConsistent();
47
-    #endif
48
-
49
-private:
50
-
51
-    friend class GapsStatistics;
52
-
53
-    ColMatrix mDMatrix; // samples by genes for A, genes by samples for P
54
-    ColMatrix mSMatrix; // same configuration as D
55
-    ColMatrix mAPMatrix; // cached product of A and P, same configuration as D
56
-
57
-    ColMatrix mMatrix; // genes by patterns for A, samples by patterns for P
58
-    const ColMatrix *mOtherMatrix; // pointer to P if this is A, and vice versa
59
-
60
-    AtomicDomain mDomain; // data structure providing access to atoms
61
-    ProposalQueue mQueue; // creates queue of proposals that get evaluated by sampler
62
-
63
-    float mAlpha;
64
-    float mLambda;
65
-    float mMaxGibbsMass;
66
-    float mAnnealingTemp;
67
-    
68
-    unsigned mNumPatterns;
69
-    uint64_t mNumBins;
70
-    uint64_t mBinLength;
71
-    uint64_t mDomainLength;
72
-
73
-    void processProposal(const AtomicProposal &prop);
74
-    float deathProb(uint64_t nAtoms) const;
75
-
76
-    void birth(const AtomicProposal &prop);
77
-    void death(const AtomicProposal &prop);
78
-    void move(const AtomicProposal &prop);
79
-    void exchange(const AtomicProposal &prop);
80
-    void exchangeUsingMetropolisHastings(const AtomicProposal &prop,
81
-        AlphaParameters alpha);
82
-    void acceptExchange(const AtomicProposal &prop, float delta);
83
-    bool updateAtomMass(Atom *atom, float delta);
84
-
85
-    void changeMatrix(unsigned row, unsigned col, float delta);
86
-    void safelyChangeMatrix(unsigned row, unsigned col, float delta);
87
-    void updateAPMatrix(unsigned row, unsigned col, float delta);
88
-
89
-    bool canUseGibbs(unsigned col) const;
90
-    bool canUseGibbs(unsigned c1, unsigned c2) const;
91
-
92
-    AlphaParameters alphaParameters(unsigned row, unsigned col);
93
-    AlphaParameters alphaParameters(unsigned r1, unsigned c1, unsigned r2, unsigned c2);
94
-    AlphaParameters alphaParametersWithChange(unsigned row, unsigned col, float ch);
95
-};
96
-
97
-template <class DataType>
98
-GibbsSampler::GibbsSampler(const DataType &data,
99
-bool transposeData, unsigned nPatterns, bool partitionRows,
100
-const std::vector<unsigned> &indices)
101
-    :
102
-mDMatrix(data, transposeData, partitionRows, indices),
103
-mSMatrix(gaps::algo::pmax(mDMatrix, 0.1f)),
104
-mAPMatrix(mDMatrix.nRow(), mDMatrix.nCol()),
105
-mMatrix(mDMatrix.nCol(), nPatterns),
106
-mOtherMatrix(NULL),
107
-mDomain(mMatrix.nRow() * mMatrix.nCol()),
108
-mQueue(mMatrix.nRow(), mMatrix.nCol()),
109
-mLambda(0.f),
110
-mMaxGibbsMass(0.f),
111
-mAnnealingTemp(1.f),
112
-mNumPatterns(mMatrix.nCol()),
113
-mNumBins(mMatrix.nRow() * mMatrix.nCol()),
114
-mBinLength(std::numeric_limits<uint64_t>::max() / mNumBins),
115
-mDomainLength(mBinLength * mNumBins)
116
-{
117
-    // default sparsity parameters
118
-    setSparsity(DEFAULT_ALPHA, DEFAULT_MAX_GIBBS_MASS, false);
119
-}
120
-
121
-template <class DataType>
122
-void GibbsSampler::setUncertainty(const DataType &unc,
123
-bool transpose, bool partitionRows, const std::vector<unsigned> &indices)
124
-{
125
-    mSMatrix = ColMatrix(unc, transpose, partitionRows, indices);
126
-}
127
-
128
-#endif // __COGAPS_GIBBS_SAMPLER_H__
129 0
\ No newline at end of file
130 1
similarity index 100%
131 2
rename from src/GibbsSampler.cpp
132 3
rename to src/GibbsSamplerImplementation.cpp
133 4
new file mode 100644
... ...
@@ -0,0 +1,200 @@
1
+#ifndef __COGAPS_GIBBS_SAMPLER_H__
2
+#define __COGAPS_GIBBS_SAMPLER_H__
3
+
4
+#define DEFAULT_ALPHA           0.01f
5
+#define DEFAULT_MAX_GIBBS_MASS  100.f
6
+
7
+#include "AtomicDomain.h"
8
+#include "ProposalQueue.h"
9
+#include "data_structures/Matrix.h"
10
+#include "math/Algorithms.h"
11
+#include "math/Random.h"
12
+
13
+#include <vector>
14
+
15
+////////////////////////////// CLASS DEFINITIONS ///////////////////////////////
16
+
17
+// These classes provide the various implementations of a GibbsSampler. Compile
18
+// time polymorphism is used to reduce code duplication.
19
+
20
+// can't be constructed, interface is available through derived classes
21
+template <class T>
22
+class GibbsSamplerImplementation
23
+{
24
+private:
25
+
26
+    friend class T;         
27
+
28
+    GibbsSamplerImplementation();
29
+
30
+    void setAnnealingTemp(float temp);
31
+
32
+    void update(unsigned nSteps, unsigned nThreads);
33
+    void processProposal(const AtomicProposal &prop);
34
+    void birth(const AtomicProposal &prop);
35
+    void death(const AtomicProposal &prop);
36
+    void move(const AtomicProposal &prop);
37
+    void exchange(const AtomicProposal &prop);
38
+    void exchangeUsingMetropolisHastings(const AtomicProposal &prop,
39
+        AlphaParameters alpha);
40
+    void acceptExchange(const AtomicProposal &prop, float delta);
41
+    bool updateAtomMass(Atom *atom, float delta);
42
+
43
+    AtomicDomain mDomain; // data structure providing access to atoms
44
+    ProposalQueue mQueue; // creates queue of proposals that get evaluated by sampler
45
+
46
+    float mAlpha;
47
+    float mLambda;
48
+    float mMaxGibbsMass;
49
+    float mAnnealingTemp;
50
+    
51
+    unsigned mNumPatterns;
52
+    uint64_t mBinLength;
53
+};
54
+
55
+class DenseGibbsSamplerImplementation : public GibbsSamplerImplementation<DenseGibbsSamplerImplementation>
56
+{
57
+private:
58
+
59
+    friend class DenseGibbsSampler;
60
+
61
+    DenseColMatrix mDMatrix; // samples by genes for A, genes by samples for P
62
+    DenseColMatrix mSMatrix; // same configuration as D
63
+    DenseColMatrix mAPMatrix; // cached product of A and P, same configuration as D
64
+    DenseColMatrix mMatrix; // genes by patterns for A, samples by patterns for P
65
+    const DenseColMatrix *mOtherMatrix; // pointer to P if this is A, and vice versa
66
+    
67
+    // private constructor allows only DenseGibbsSampler to construct this
68
+    DenseGibbsSamplerImplementation(const DataType &data, bool transpose,
69
+        unsigned nPatterns, bool partitionRows,
70
+        const std::vector<unsigned> &indices);
71
+
72
+    template <class DataType>
73
+    void setUncertainty(const DataType &data, bool transpose, unsigned nPatterns,
74
+        bool partitionRows, const std::vector<unsigned> &indices);
75
+
76
+    unsigned dataRows() const;
77
+    unsigned dataCols() const;
78
+
79
+    void setSparsity(float alpha, float maxGibbsMass, bool singleCell);
80
+    void setMatrix(const Matrix &mat);
81
+
82
+    void sync(const GibbsSampler &sampler)
83
+
84
+    void changeMatrix(unsigned row, unsigned col, float delta);
85
+    void safelyChangeMatrix(unsigned row, unsigned col, float delta);
86
+    void updateAPMatrix(unsigned row, unsigned col, float delta);
87
+
88
+    bool canUseGibbs(unsigned col) const;
89
+    bool canUseGibbs(unsigned c1, unsigned c2) const;
90
+    AlphaParameters alphaParameters(unsigned row, unsigned col);
91
+    AlphaParameters alphaParameters(unsigned r1, unsigned c1, unsigned r2, unsigned c2);
92
+    AlphaParameters alphaParametersWithChange(unsigned row, unsigned col, float ch);
93
+};
94
+
95
+class SparseGibbsSamplerImplementation : public GibbsSamplerImplementation<SparseGibbsSamplerImplementation>
96
+{
97
+private :
98
+
99
+    friend class SparseGibbsSampler;
100
+
101
+    SparseColMatrix mDMatrix;
102
+    HybridColMatrix mSparseMatrix;
103
+    DenseRowMatrix mDenseMatrix;
104
+    const HybridColMatrix *mOtherSparseMatrix;
105
+    const DenseRowMatrix *mOtherDenseMatrix;
106
+
107
+    // private constructor allows only SparseGibbsSampler to construct this
108
+    SparseGibbsSamplerImplementation(const DataType &data, bool transpose,
109
+        unsigned nPatterns, bool partitionRows,
110
+        const std::vector<unsigned> &indices);
111
+
112
+    unsigned dataRows() const;
113
+    unsigned dataCols() const;
114
+
115
+    void setSparsity(float alpha, float maxGibbsMass, bool singleCell);
116
+    void setMatrix(const Matrix &mat);
117
+
118
+    float chiSq() const;
119
+
120
+    bool canUseGibbs(unsigned col) const;
121
+    bool canUseGibbs(unsigned c1, unsigned c2) const;
122
+    AlphaParameters alphaParameters(unsigned row, unsigned col);
123
+    AlphaParameters alphaParameters(unsigned r1, unsigned c1, unsigned r2, unsigned c2);
124
+    AlphaParameters alphaParametersWithChange(unsigned row, unsigned col, float ch);
125
+};
126
+
127
+//////////////////////IMPLEMENTATION OF TEMPLATED FUNCTIONS ////////////////////
128
+
129
+template <class T>
130
+void GibbsSamplerImplementation<T>::update(unsigned nSteps, unsigned nThreads)
131
+{
132
+    unsigned n = 0;
133
+    while (n < nSteps)
134
+    {
135
+        // populate queue, prepare domain for this queue
136
+        mQueue.populate(mDomain, nSteps - n);
137
+        n += mQueue.size();
138
+        
139
+        // process all proposed updates
140
+        #pragma omp parallel for num_threads(nThreads)
141
+        for (unsigned i = 0; i < mQueue.size(); ++i)
142
+        {
143
+            processProposal(mQueue[i]);
144
+        }
145
+        mQueue.clear();
146
+    }
147
+
148
+    GAPS_ASSERT(internallyConsistent());
149
+    GAPS_ASSERT(mDomain.isSorted());
150
+}
151
+
152
+template <class T>
153
+void GibbsSamplerImplementation<T>::processProposal(const AtomicProposal &prop)
154
+{
155
+    switch (prop.type)
156
+    {
157
+        case 'B':
158
+            birth(prop);
159
+            break;
160
+        case 'D':
161
+            death(prop);
162
+            break;
163
+        case 'M':
164
+            move(prop);
165
+            break;
166
+        case 'E':
167
+            exchange(prop);
168
+            break;
169
+    }
170
+}
171
+
172
+// add an atom at a random position, calculate mass either with an
173
+// exponential distribution or with the gibbs mass distribution
174
+void GibbsSampler::birth(const AtomicProposal &prop)
175
+{
176
+    // calculate proposed mass
177
+    float mass = canUseGibbs(prop.c1)
178
+        ? gibbsMass(alphaParameters(prop.r1, prop.c1), &(prop.rng)).value()
179
+        : prop.rng.exponential(mLambda);
180
+
181
+    // accept mass as long as it's non-zero
182
+    if (mass >= gaps::epsilon)
183
+    {
184
+        mQueue.acceptBirth();
185
+        prop.atom1->mass = mass;
186
+        changeMatrix(prop.r1, prop.c1, mass);
187
+    }
188
+    else
189
+    {
190
+        mQueue.rejectBirth();
191
+        mDomain.erase(prop.atom1->pos);
192
+    }
193
+}
194
+
195
+
196
+
197
+
198
+
199
+
200
+#endif // __COGAPS_GIBBS_SAMPLER_H__
0 201
\ No newline at end of file
1 202
new file mode 100644
... ...
@@ -0,0 +1,116 @@
1
+#ifndef __COGAPS_GIBBS_SAMPLER_H__
2
+#define __COGAPS_GIBBS_SAMPLER_H__
3
+
4
+#define DEFAULT_ALPHA           0.01f
5
+#define DEFAULT_MAX_GIBBS_MASS  100.f
6
+
7
+#include "AtomicDomain.h"
8
+#include "ProposalQueue.h"
9
+#include "data_structures/Matrix.h"
10
+#include "math/Algorithms.h"
11
+#include "math/Random.h"
12
+
13
+#include <vector>
14
+
15
+// This is a polymorphic wrapper to an underlying implementation. The purpose of
16
+// this is to move the virtual functions to the highest level possible where
17
+// they are called infrequently, rather than pay the performance cost of
18
+// dynamic dispatch of frequently called, inexpensive functions.
19
+
20
+// interface
21
+class GibbsSampler
22
+{
23
+public:
24
+
25
+    virtual unsigned dataRows() const = 0;
26
+    virtual unsigned dataCols() const = 0;
27
+    
28
+    virtual void setSparsity(float alpha, float maxGibbsMass, bool singleCell) = 0;
29
+    virtual void setAnnealingTemp(float temp) = 0;
30
+    virtual void setMatrix(const Matrix &mat) = 0;
31
+
32
+    virtual float chi2() const = 0;
33
+    virtual uint64_t nAtoms() const = 0;
34
+
35
+    virtual void recalculateAPMatrix() = 0;
36
+    virtual void sync(const GibbsSampler *sampler, unsigned nThreads=1) = 0;
37
+    virtual void update(unsigned nSteps, unsigned nCores) = 0;
38
+};
39
+
40
+// wrapper for a dense GibbsSampler implementation - all data in this class
41
+// is stored as a dense matrix
42
+class DenseGibbsSampler : public GibbsSampler
43
+{
44
+public:
45
+
46
+    template <class DataType>
47
+    DenseGibbsSampler(const DataType &data, bool transpose, unsigned nPatterns,
48
+    bool partitionRows, const std::vector<unsigned> &indices)
49
+        : mImplementation(data, transpose, nPatterns, partitionRows, indices);
50
+    {}
51
+
52
+    template <class DataType>
53
+    void setUncertainty(const DataType &data, bool transpose, unsigned nPatterns,
54
+    bool partitionRows, const std::vector<unsigned> &indices)
55
+    {
56
+        mImplementation.setUncertainty(data, transpose, nPatterns, partitionRows, indices);
57
+    }
58
+
59
+    unsigned dataRows() const { return mImplementation.dataRows(); }
60
+    unsigned dataCols() const { return mImplementation.dataCols(); }
61
+
62
+    void setSparsity(float alpha, float maxGibbsMass, bool singleCell) { mImplementation.setSparsity(alpha, maxGibbsMass, singleCell); }
63
+    void setAnnealingTemp(float temp) { mImplementation.setAnnealingTemp(temp); }
64
+    void setMatrix(const Matrix &mat) { mImplementation.setMatrix(mat); }
65
+
66
+    float chi2() const { return mImplementation.chi2(); }
67
+    uint64_t nAtoms() const { return mImplementation.nAtoms(); }
68
+
69
+    void recalculateAPMatrix() { mImplementation.recalculateAPMatrix(); }
70
+    void sync(const GibbsSampler *sampler, unsigned nThreads=1) { mImplementation.sync(sampler->mImplementation, nThreads); }
71
+    void update(unsigned nSteps, unsigned nCores) { mImplementation.update(nSteps, nCores); }
72
+
73
+private:
74
+
75
+    DenseGibbsSamplerImplementation mImplementation;
76
+};
77
+
78
+// wrapper for a sparse GibbsSampler implementation - all data in this class
79
+// is stored as a sparse matrix
80
+class SparseGibbsSampler : public GibbsSampler
81
+{
82
+public:
83
+
84
+    template <class DataType>
85
+    SparseGibbsSampler(const DataType &data, bool transpose, unsigned nPatterns,
86
+    bool partitionRows, const std::vector<unsigned> &indices)
87
+        : mImplementation(data, transpose, nPatterns, partitionRows, indices);
88
+    {}
89
+
90
+    template <class DataType>
91
+    void setUncertainty(const DataType &data, bool transpose, unsigned nPatterns,
92
+    bool partitionRows, const std::vector<unsigned> &indices)
93
+    {
94
+        GAPS_ASSERT(false); // should never reach
95
+    }
96
+
97
+    unsigned dataRows() const { return mImplementation.dataRows(); }
98
+    unsigned dataCols() const { return mImplementation.dataCols(); }
99
+
100
+    void setSparsity(float alpha, float maxGibbsMass, bool singleCell) { mImplementation.setSparsity(alpha, maxGibbsMass, singleCell); }
101
+    void setAnnealingTemp(float temp) { mImplementation.setAnnealingTemp(temp); }
102
+    void setMatrix(const Matrix &mat) { mImplementation.setMatrix(mat); }
103
+
104
+    float chi2() const { return mImplementation.chi2(); }
105
+    uint64_t nAtoms() const { return mImplementation.nAtoms(); }
106
+
107
+    void recalculateAPMatrix() { GAPS_ASSERT(false); /* should never reach */}
108
+    void sync(const GibbsSampler &sampler, unsigned nThreads=1) { mImplementation.sync(sampler); }
109
+    void update(unsigned nSteps, unsigned nCores) { mImplementation.update(nSteps, nCores); }
110
+
111
+private:
112
+    
113
+    SparseGibbsSamplerImplementation mImplementation;
114
+};
115
+
116
+#endif // __COGAPS_GIBBS_SAMPLER_H__
0 117
\ No newline at end of file
... ...
@@ -6,8 +6,12 @@
6 6
 
7 7
 #include <vector>
8 8
 
9
-class ColMatrix;
10
-typedef ColMatrix Matrix; // when we don't care about row/col major order
9
+class DenseColMatrix;
10
+class DenseRowMatrix;
11
+class SparseColMatrix;
12
+class HybridColMatrix;
13
+
14
+typedef DenseColMatrix Matrix; // when we don't care about storage
11 15
 
12 16
 class ColMatrix
13 17
 {
14 18
new file mode 100644
... ...
@@ -0,0 +1,70 @@
1
+// can access without iterator, can set elements with accessor
2
+class Vector
3
+{
4
+public:
5
+
6
+    explicit Vector(unsigned size);
7
+    explicit Vector(const std::vector<float> &v);
8
+
9
+    float& operator()(unsigned i, unsigned j); // set value
10
+    const float* ptr() const; // access without iterator
11
+
12
+    friend Archive& operator<<(Archive &ar, Vector &vec);
13
+    friend Archive& operator>>(Archive &ar, Vector &vec);
14
+
15
+private:
16
+
17
+    aligned_vector mData;
18
+};
19
+
20
+// can only access through iterator, all data is const
21
+class SparseVector
22
+{
23
+public:
24
+
25
+    explicit SparseVector(unsigned size);
26
+    explicit SparseVector(const std::vector<float> &v);
27
+
28
+    friend Archive& operator<<(Archive &ar, Vector &vec);
29
+    friend Archive& operator>>(Archive &ar, Vector &vec);
30
+
31
+private:
32
+    
33
+    std::vector<uint64_t> mIndexBitFlags;
34
+    std::vector<float> mData;
35
+};
36
+
37
+// stored as a dense vector (efficient setting of values) but maintains
38
+// index bit flags of non-zeros so it can be used with SparseIterator
39
+class HybridVector
40
+{
41
+public:
42
+
43
+    explicit HybridVector(unsigned size);
44
+    explicit HybridVector(const std::vector<float> &v);
45
+
46
+    void change(unsigned i, unsigned j, float v);
47
+
48
+    friend Archive& operator<<(Archive &ar, Vector &vec);
49
+    friend Archive& operator>>(Archive &ar, Vector &vec);
50
+
51
+private:
52
+
53
+    std::vector<uint64_t> mIndexBitFlags;
54
+    std::vector<float> mData;
55
+};
56
+
57
+class SparseIterator
58
+{
59
+public:
60
+
61
+    SparseIterator(const HybridVector &A, const SparseVector &B);
62
+
63
+    bool atEnd() const;
64
+    void next();
65
+    float firstValue();
66
+    float secondValue();
67
+    unsigned firstIndex();
68
+    unsigned secondIndex();
69
+};
70
+
... ...
@@ -237,6 +237,33 @@ unsigned nUpdates)
237 237
     return stdMat;
238 238
 }
239 239
 
240
+// vec is a column of either A or P
241
+AlphaParameters gaps::algo::alphaParameters(const Sparsevector &D,
242
+const Sparsevector &vec, const float *A, const float *P, unsigned size)
243
+{
244
+    // initialize
245
+    float s = -1.f * Z_1[column] * beta;
246
+    float su = 0.f;
247
+    for (unsigned i = 0; i < size; ++i)
248
+    {
249
+        su += A[i] * Z_2[column, i];
250
+    }
251
+    su *= -1.f * beta;
252
+
253
+    // iterate over common non-zero entries
254
+    Sparsevector it(D, vec);
255
+    while (!it.atEnd())
256
+    {
257
+        float term1 = it.firstValue() / it.secondValue();
258
+        float term2 = term1 * term1 + it.firstValue() * it.firstValue() * beta;
259
+        float term3 = beta * it.firstValue() - alpha * term1 / it.secondValue();
260
+        s += alpha * term2;
261
+        s_mu += alpha * term1 + term3 * gaps::algo::dotProduct(A, P, size);
262
+        it.next();
263
+    }
264
+    return AlphaParameters(s, s_mu);
265
+}
266
+
240 267
 AlphaParameters gaps::algo::alphaParameters(unsigned size, const float *D,
241 268
 const float *S, const float *AP, const float *mat)
242 269
 {
... ...
@@ -42,7 +42,7 @@ public:
42 42
         }
43 43
     }
44 44
 
45
-    void close()
45
+    ~Archive()
46 46
     {
47 47
         mStream.close();
48 48
     }