Browse code

original cogaps working with checkpoint framework

sherman5 authored on 04/01/2018 18:31:37
Showing 12 changed files

... ...
@@ -256,7 +256,6 @@ double gaps::algo::deltaLL(const MatrixChange &ch, const TwoWayMatrix &D,
256 256
 const TwoWayMatrix &S, const ColMatrix &A, const RowMatrix &P,
257 257
 const TwoWayMatrix &AP)
258 258
 {
259
-    // change in A matrix
260 259
     if (ch.label == 'A' && ch.nChanges == 2 && ch.row1 != ch.row2)
261 260
     {
262 261
         return deltaLL_A(D, S, P, AP, ch.row1, ch.col1, ch.delta1)
... ...
@@ -267,14 +266,12 @@ const TwoWayMatrix &AP)
267 266
         return deltaLL_A(D, S, P, AP, ch.row1, ch.col1, ch.delta1, ch.col2,
268 267
             ch.delta2, ch.nChanges == 2);
269 268
     }
270
-
271
-    // change in P matrix
272
-    if (ch.label == 'P' && ch.nChanges == 2 && ch.col1 != ch.col2)
269
+    else if (ch.label == 'P' && ch.nChanges == 2 && ch.col1 != ch.col2)
273 270
     {
274 271
         return deltaLL_P(D, S, A, AP, ch.col1, ch.row1, ch.delta1)
275 272
             + deltaLL_P(D, S, A, AP, ch.col2, ch.row2, ch.delta2);
276 273
     }
277
-    else if (ch.label == 'P')
274
+    else
278 275
     {
279 276
         return deltaLL_P(D, S, A, AP, ch.col1, ch.row1, ch.delta1, ch.row2,
280 277
             ch.delta2, ch.nChanges == 2);
... ...
@@ -5,10 +5,11 @@ static const double EPSILON = 1.e-10;
5 5
 AtomicSupport::AtomicSupport(char label, uint64_t nrow, uint64_t ncol,
6 6
 double alpha, double lambda)
7 7
     :
8
-mNumRows(nrow), mNumCols(ncol), mNumBins(nrow * ncol),
9
-mNumAtoms(0), mTotalMass(0.0), mLabel(label), mAlpha(alpha), 
10
-mMaxNumAtoms(std::numeric_limits<uint64_t>::max()), mLambda(lambda),
11
-mBinSize(std::numeric_limits<uint64_t>::max() / (nrow * ncol))
8
+mLabel(label), mNumAtoms(0),
9
+mMaxNumAtoms(std::numeric_limits<uint64_t>::max()),
10
+mTotalMass(0.0), mNumRows(nrow), mNumCols(ncol), mNumBins(nrow * ncol),
11
+mBinSize(std::numeric_limits<uint64_t>::max() / (nrow * ncol)),
12
+mAlpha(alpha), mLambda(lambda)
12 13
 {}
13 14
 
14 15
 uint64_t AtomicSupport::getRow(uint64_t pos) const
... ...
@@ -113,7 +114,7 @@ AtomicProposal AtomicSupport::proposeExchange() const
113 114
     return AtomicProposal(mLabel, 'E', pos1, delta1, pos2, delta2);
114 115
 }
115 116
 
116
-double AtomicSupport::updateAtomMass(char type, uint64_t pos, double delta)
117
+double AtomicSupport::updateAtomMass(uint64_t pos, double delta)
117 118
 {
118 119
     if (mAtomicDomain.count(pos)) // update atom if it exists
119 120
     {
... ...
@@ -172,10 +173,10 @@ AtomicProposal AtomicSupport::makeProposal() const
172 173
 MatrixChange AtomicSupport::acceptProposal(const AtomicProposal &prop)
173 174
 {
174 175
     MatrixChange change = getMatrixChange(prop);
175
-    change.delta1 = updateAtomMass(prop.type, prop.pos1, prop.delta1);
176
+    change.delta1 = updateAtomMass(prop.pos1, prop.delta1);
176 177
     if (prop.nChanges > 1)
177 178
     {
178
-        change.delta2 = updateAtomMass(prop.type, prop.pos2, prop.delta2);
179
+        change.delta2 = updateAtomMass(prop.pos2, prop.delta2);
179 180
     }
180 181
     return change;
181 182
 }
... ...
@@ -76,14 +76,14 @@ public:
76 76
     AtomicProposal proposeExchange() const;
77 77
 
78 78
     // update the mass of an atom, return the total amount changed
79
-    double updateAtomMass(char type, uint64_t pos, double delta);
79
+    double updateAtomMass(uint64_t pos, double delta);
80 80
 
81 81
 public:
82 82
 
83
-    // constructor
83
+    // constructors
84
+    AtomicSupport(std::ifstream &file);
84 85
     AtomicSupport(char label, uint64_t nrow, uint64_t ncol, double alpha=1.0,
85 86
         double lambda=1.0);
86
-    AtomicSupport(const std::ifstream &file);
87 87
 
88 88
     // create and accept a proposal
89 89
     AtomicProposal makeProposal() const;
... ...
@@ -103,7 +103,6 @@ public:
103 103
     void setAlpha(double alpha) {mAlpha = alpha;}
104 104
     void setLambda(double lambda) {mLambda = lambda;}
105 105
 
106
-    // serialize and write to file
107 106
     void serializeAndWrite(const std::ofstream &file);
108 107
 };
109 108
 
... ...
@@ -1,4 +1,5 @@
1 1
 #include "GibbsSampler.h"
2
+#include "Matrix.h"
2 3
 
3 4
 #include <Rcpp.h>
4 5
 #include <ctime>
... ...
@@ -8,7 +9,7 @@ typedef std::vector<Rcpp::NumericMatrix> SnapshotList;
8 9
 
9 10
 enum GapsPhase
10 11
 {
11
-    GAPS_CALIBRATION=,
12
+    GAPS_CALIBRATION,
12 13
     GAPS_COOLING,
13 14
     GAPS_SAMPLING
14 15
 };
... ...
@@ -17,8 +18,6 @@ enum GapsPhase
17 18
 // initialization depends on it
18 19
 struct GapsInternalState
19 20
 {
20
-    GibbsSampler sampler;
21
-    
22 21
     Vector chi2VecEquil;
23 22
     Vector nAtomsAEquil;
24 23
     Vector nAtomsPEquil;
... ...
@@ -34,35 +33,35 @@ struct GapsInternalState
34 33
     unsigned nEquilCool;
35 34
     unsigned nSample;
36 35
 
36
+    unsigned nSnapshots;
37
+    unsigned nOutputs;
38
+    bool messages;
39
+
37 40
     unsigned iteration;
38 41
     GapsPhase phase;
42
+    uint32_t seed;
43
+
44
+    GibbsSampler sampler;
39 45
 
40 46
     SnapshotList snapshotsA;
41 47
     SnapshotList snapshotsP;
42 48
 
43
-    uint32_t seed;
44
-
45 49
     GapsInternalState(Rcpp::NumericMatrix DMatrix, Rcpp::NumericMatrix SMatrix,
46 50
         unsigned nFactor, double alphaA, double alphaP, unsigned nE,
47 51
         unsigned nEC, unsigned nS, double maxGibbsMassA,
48 52
         double maxGibbsMassP, Rcpp::NumericMatrix fixedPatterns,
49
-        char whichMatrixFixed, bool messages, bool singleCellRNASeq,
50
-        unsigned numOutputs, unsigned numSnapshots)
53
+        char whichMatrixFixed, bool msg, bool singleCellRNASeq,
54
+        unsigned numOutputs, unsigned numSnapshots, uint32_t in_seed)
51 55
             :
52
-        chi2VecEquil(nEquil), nAtomsAEquil(nEquil), nAtomsPEquil(nEquil),
53
-        chi2VecSample(nSample), nAtomsASample(nSample), nAtomsPSample(nSample),
56
+        chi2VecEquil(nE), nAtomsAEquil(nE), nAtomsPEquil(nE),
57
+        chi2VecSample(nS), nAtomsASample(nS), nAtomsPSample(nS),
54 58
         nIterA(10), nIterP(10), nEquil(nE), nEquilCool(nEC), nSample(nS),
55
-        iteration(0), phase(GAPS_CALIBRATION),
59
+        nSnapshots(numSnapshots), nOutputs(numOutputs), messages(msg),
60
+        iteration(0), phase(GAPS_CALIBRATION), seed(in_seed),
56 61
         sampler(DMatrix, SMatrix, nFactor, alphaA, alphaP,
57 62
             maxGibbsMassA, maxGibbsMassP, singleCellRNASeq, fixedPatterns,
58 63
             whichMatrixFixed)
59 64
     {}
60
-
61
-    GapsInternalState(std::ifstream &file)
62
-        :
63
-    {}
64
-
65
-
66 65
 };
67 66
 
68 67
 static void runGibbsSampler(GapsInternalState &state, unsigned nIterTotal,
... ...
@@ -89,31 +88,29 @@ Vector &chi2Vec, Vector &aAtomVec, Vector &pAtomVec)
89 88
         if (state.phase == GAPS_SAMPLING)
90 89
         {
91 90
             state.sampler.updateStatistics();
92
-            if (state.numSnapshots > 0 && (i + 1) % (nIterTotal / state.numSnapshots) == 0)
91
+            if (state.nSnapshots > 0 && (i + 1) % (nIterTotal / state.nSnapshots) == 0)
93 92
             {
94 93
                 state.snapshotsA.push_back(state.sampler.getNormedMatrix('A'));
95 94
                 state.snapshotsP.push_back(state.sampler.getNormedMatrix('P'));
96 95
             }
97 96
         }
98 97
 
99
-        if (phase != GAPS_COOLING)
98
+        if (state.phase != GAPS_COOLING)
100 99
         {
101
-            double numAtomsA = state.sampler.totalNumAtoms('A');
102
-            double numAtomsP = state.sampler.totalNumAtoms('P');
103
-            aAtomVec(i) = numAtomsA;
104
-            pAtomVec(i) = numAtomsP;
100
+            aAtomVec(i) = state.sampler.totalNumAtoms('A');
101
+            pAtomVec(i) = state.sampler.totalNumAtoms('P');
105 102
             chi2Vec(i) = state.sampler.chi2();
103
+            state.nIterA = gaps::random::poisson(std::max(aAtomVec(i), 10.0));
104
+            state.nIterP = gaps::random::poisson(std::max(pAtomVec(i), 10.0));
106 105
 
107
-            if ((i + 1) % state.numOutputs == 0 && state.messages)
106
+            if ((i + 1) % state.nOutputs == 0 && state.messages)
108 107
             {
109 108
                 std::string temp = state.phase == GAPS_CALIBRATION ? "Equil: "
110 109
                     : "Samp: ";
111 110
                 std::cout << temp << i + 1 << " of " << nIterTotal
112
-                    << ", Atoms:" << numAtomsA << "(" << numAtomsP
111
+                    << ", Atoms:" << aAtomVec(i) << "(" << pAtomVec(i)
113 112
                     << ") Chi2 = " << state.sampler.chi2() << '\n';
114 113
             }
115
-            state.nIterA = gaps::random::poisson(std::max(numAtomsA, 10.0));
116
-            state.nIterP = gaps::random::poisson(std::max(numAtomsP, 10.0));
117 114
         }
118 115
     }
119 116
 }
... ...
@@ -123,7 +120,7 @@ static Rcpp::List runCogaps(GapsInternalState &state)
123 120
     if (state.phase == GAPS_CALIBRATION)
124 121
     {
125 122
         runGibbsSampler(state, state.nEquil, state.chi2VecEquil,
126
-            state.nAtomsAEquil, state.nAtomsPEquil)
123
+            state.nAtomsAEquil, state.nAtomsPEquil);
127 124
         state.phase = GAPS_COOLING;
128 125
     }
129 126
 
... ...
@@ -137,7 +134,7 @@ static Rcpp::List runCogaps(GapsInternalState &state)
137 134
     if (state.phase == GAPS_SAMPLING)
138 135
     {
139 136
         runGibbsSampler(state, state.nSample, state.chi2VecSample,
140
-            state.nAtomsASample, state.nAtomsPSample)
137
+            state.nAtomsASample, state.nAtomsPSample);
141 138
     }
142 139
 
143 140
     // combine chi2 vectors
... ...
@@ -155,34 +152,11 @@ static Rcpp::List runCogaps(GapsInternalState &state)
155 152
         Rcpp::Named("atomsASamp") = state.nAtomsASample.rVec(),
156 153
         Rcpp::Named("atomsPEquil") = state.nAtomsPEquil.rVec(),
157 154
         Rcpp::Named("atomsPSamp") = state.nAtomsPSample.rVec(),
158
-        Rcpp::Named("chiSqValues") = state.chi2Vec.rVec(),
155
+        Rcpp::Named("chiSqValues") = chi2Vec.rVec(),
159 156
         Rcpp::Named("randSeed") = state.seed
160 157
     );
161 158
 }
162 159
 
163
-// [[Rcpp::export]]
164
-Rcpp::List cogapsFromCheckpoint(const std::string &fileName)
165
-{   
166
-    // open file
167
-    std::ifstream file(fileName);
168
-
169
-    // verify magic number
170
-    uint32_t magicNum = 0;
171
-    file.read(reinterpret_cast<char*>(&magicNum), sizeof(uint32_t));
172
-    if (magicNum != 0xCE45D32A)
173
-    {
174
-        std::cout << "invalid checkpoint file" << std::endl;
175
-        return Rcpp::List::create();
176
-    }
177
-    
178
-    // seed random number generator and create internal state
179
-    gaps::random::load(file);
180
-    GapsInternalState state(file);
181
-
182
-    // run cogaps from this internal state
183
-    return runCogaps(state);
184
-}
185
-
186 160
 // [[Rcpp::export]]
187 161
 Rcpp::List cogaps(Rcpp::NumericMatrix DMatrix, Rcpp::NumericMatrix SMatrix,
188 162
 unsigned nFactor, double alphaA, double alphaP, unsigned nEquil,
... ...
@@ -206,9 +180,32 @@ unsigned numOutputs, unsigned numSnapshots)
206 180
     GapsInternalState state(DMatrix, SMatrix, nFactor, alphaA, alphaP,
207 181
         nEquil, nEquilCool, nSample, maxGibbsMassA, maxGibbsMassP,
208 182
         fixedPatterns, whichMatrixFixed, messages, singleCellRNASeq,
209
-        numOutputs, numSnapshots);
183
+        numOutputs, numSnapshots, seedUsed);
210 184
 
211 185
     // run cogaps from this internal state
212 186
     return runCogaps(state);
213 187
 }
214 188
 
189
+/*
190
+Rcpp::List cogapsFromCheckpoint(const std::string &fileName)
191
+{   
192
+    // open file
193
+    std::ifstream file(fileName);
194
+
195
+    // verify magic number
196
+    uint32_t magicNum = 0;
197
+    file.read(reinterpret_cast<char*>(&magicNum), sizeof(uint32_t));
198
+    if (magicNum != 0xCE45D32A)
199
+    {
200
+        std::cout << "invalid checkpoint file" << std::endl;
201
+        return Rcpp::List::create();
202
+    }
203
+    
204
+    // seed random number generator and create internal state
205
+    gaps::random::load(file);
206
+    GapsInternalState state(file);
207
+
208
+    // run cogaps from this internal state
209
+    return runCogaps(state);
210
+}
211
+*/
... ...
@@ -10,14 +10,14 @@ unsigned int nFactor, double alphaA, double alphaP, double maxGibbsMassA,
10 10
 double maxGibbsMassP, bool singleCellRNASeq, Rcpp::NumericMatrix fixedPat,
11 11
 char whichMat)
12 12
     :
13
-mDMatrix(D), mSMatrix(S), mAMatrix(D.nrow(), nFactor),
14
-mPMatrix(nFactor, D.ncol()), mAPMatrix(D.nrow(), D.ncol()),
13
+mDMatrix(D), mSMatrix(S), mAPMatrix(D.nrow(), D.ncol()),
14
+mAMatrix(D.nrow(), nFactor), mPMatrix(nFactor, D.ncol()), 
15 15
 mADomain('A', D.nrow(), nFactor), mPDomain('P', nFactor, D.ncol()),
16
-mMaxGibbsMassA(maxGibbsMassA), mMaxGibbsMassP(maxGibbsMassP),
17
-mAnnealingTemp(1.0), mChi2(0.0), mSingleCellRNASeq(singleCellRNASeq),
18 16
 mAMeanMatrix(D.nrow(), nFactor), mAStdMatrix(D.nrow(), nFactor),
19 17
 mPMeanMatrix(nFactor, D.ncol()), mPStdMatrix(nFactor, D.ncol()),
20
-mStatUpdates(0), mFixedMat(whichMat)
18
+mStatUpdates(0), mMaxGibbsMassA(maxGibbsMassA), mMaxGibbsMassP(maxGibbsMassP),
19
+mAnnealingTemp(1.0), mChi2(0.0), mSingleCellRNASeq(singleCellRNASeq),
20
+mFixedMat(whichMat)
21 21
 {
22 22
     double meanD = mSingleCellRNASeq ? gaps::algo::nonZeroMean(mDMatrix)
23 23
         : gaps::algo::mean(mDMatrix);
... ...
@@ -13,14 +13,16 @@ private:
13 13
 public:
14 14
 #endif
15 15
 
16
-    AtomicSupport mADomain, mPDomain;
17
-
18 16
     TwoWayMatrix mDMatrix, mSMatrix, mAPMatrix;
19
-    RowMatrix mPMatrix;
17
+
20 18
     ColMatrix mAMatrix;
19
+    RowMatrix mPMatrix;
20
+
21
+    AtomicSupport mADomain, mPDomain;
21 22
 
22
-    RowMatrix mPMeanMatrix, mPStdMatrix;
23 23
     ColMatrix mAMeanMatrix, mAStdMatrix;
24
+    RowMatrix mPMeanMatrix, mPStdMatrix;
25
+
24 26
     unsigned mStatUpdates;
25 27
 
26 28
     double mMaxGibbsMassA;
... ...
@@ -55,10 +57,10 @@ public:
55 57
 
56 58
 public:
57 59
 
60
+    GibbsSampler(std::ifstream &file);
58 61
     GibbsSampler(Rcpp::NumericMatrix D, Rcpp::NumericMatrix S, unsigned nFactor,
59 62
         double alphaA, double alphaP, double maxGibbsMassA, double maxGibbsMassP,
60 63
         bool singleCellRNASeq, Rcpp::NumericMatrix fixedPat, char whichMat);
61
-    GibbsSampler(const std::ifstream &file);
62 64
 
63 65
     void update(char matrixLabel);
64 66
 
... ...
@@ -75,7 +77,6 @@ public:
75 77
 
76 78
     Rcpp::NumericMatrix getNormedMatrix(char mat);
77 79
 
78
-    // serialize and write to file
79 80
     void serializeAndWrite(const std::ofstream &file);
80 81
 };
81 82
 
... ...
@@ -1,4 +1,4 @@
1
-PKG_CPPFLAGS = 
1
+PKG_CPPFLAGS = -Wall -Wextra
2 2
 OBJECTS =   Algorithms.o \
3 3
             AtomicSupport.o \
4 4
             Cogaps.o \
... ...
@@ -50,6 +50,7 @@ public:
50 50
 
51 51
     Vector(unsigned size) : mValues(std::vector<matrix_data_t>(size, 0.0)) {}
52 52
     Vector(const std::vector<matrix_data_t>& v) : mValues(v) {}
53
+    Vector(const Vector &vec) : mValues(vec.mValues) {}
53 54
 
54 55
     matrix_data_t& operator()(unsigned i) {return mValues[i];}
55 56
     matrix_data_t operator()(unsigned i) const {return mValues[i];}
... ...
@@ -57,7 +57,7 @@ double gaps::random::uniform(double a, double b)
57 57
     {
58 58
         return a;
59 59
     }
60
-    else if (a < b)
60
+    else
61 61
     {
62 62
         boost::random::uniform_real_distribution<> dist(a,b);
63 63
         return dist(rng);
... ...
@@ -77,7 +77,7 @@ uint64_t gaps::random::uniform64(uint64_t a, uint64_t b)
77 77
     {
78 78
         return a;
79 79
     }
80
-    else if (a < b)
80
+    else
81 81
     {
82 82
         boost::random::uniform_int_distribution<uint64_t> dist(a,b);
83 83
         return dist(rng);
... ...
@@ -3,6 +3,7 @@
3 3
 
4 4
 #include <stdint.h>
5 5
 #include <vector>
6
+#include <fstream>
6 7
 
7 8
 namespace gaps
8 9
 {
... ...
@@ -28,8 +29,8 @@ namespace random
28 29
     double q_norm(double q, double mean, double sd);
29 30
     double p_norm(double p, double mean, double sd);
30 31
 
31
-    void save(ofstream &file);
32
-    void load(ofstream &file);
32
+    void save(std::ofstream &file);
33
+    void load(std::ifstream &file);
33 34
 }
34 35
 
35 36
 }
... ...
@@ -11537,7 +11537,7 @@ int main (int argc, char * const argv[]) {
11537 11537
 #endif
11538 11538
 
11539 11539
 #define INFO( msg ) INTERNAL_CATCH_INFO( "INFO", msg )
11540
-#define WARN( msg ) INTERNAL_CATCH_MSG( "WARN", Catch::ResultWas::Warning, Catch::ResultDisposition::ContinueOnFailure, msg )
11540
+//#define WARN( msg ) INTERNAL_CATCH_MSG( "WARN", Catch::ResultWas::Warning, Catch::ResultDisposition::ContinueOnFailure, msg )
11541 11541
 #define SCOPED_INFO( msg ) INTERNAL_CATCH_INFO( "INFO", msg )
11542 11542
 #define CAPTURE( msg ) INTERNAL_CATCH_INFO( "CAPTURE", #msg " := " << Catch::toString(msg) )
11543 11543
 #define SCOPED_CAPTURE( msg ) INTERNAL_CATCH_INFO( "CAPTURE", #msg " := " << Catch::toString(msg) )
... ...
@@ -42,10 +42,8 @@ TEST_CASE("Test AtomicSupport.h")
42 42
             
43 43
             REQUIRE(change.label == 'A');
44 44
             REQUIRE(change.nChanges == prop.nChanges);
45
-            cond = change.row1 >= 0 && change.row2 < nrow;
46
-            REQUIRE(cond);
47
-            cond = change.col1 >= 0 && change.col2 < ncol;
48
-            REQUIRE(cond);
45
+            REQUIRE(change.row2 < nrow);
46
+            REQUIRE(change.col2 < ncol);
49 47
             REQUIRE(change.delta1 == prop.delta1);
50 48
             REQUIRE(change.delta2 == prop.delta2);
51 49