Browse code

basic checkpoint system working

sherman5 authored on 08/01/2018 22:48:46
Showing17 changed files

... ...
@@ -34,7 +34,7 @@ Imports:
34 34
 Suggests:
35 35
     testthat,
36 36
     lintr
37
-LinkingTo: Rcpp
37
+LinkingTo: Rcpp, BH
38 38
 License: GPL (==2)
39 39
 biocViews: GeneExpression, Transcription, GeneSetEnrichment,
40 40
     DifferentialExpression, Bayesian, Clustering, TimeCourse, RNASeq, Microarray,
... ...
@@ -9,6 +9,7 @@ export(calcZ)
9 9
 export(createGWCoGAPSSets)
10 10
 export(gapsMapRun)
11 11
 export(gapsRun)
12
+export(gapsRunFromCheckpoint)
12 13
 export(generateSeeds)
13 14
 export(patternMarkers)
14 15
 export(patternMatch4Parallel)
... ...
@@ -5,6 +5,10 @@ cogaps <- function(DMatrix, SMatrix, nFactor, alphaA, alphaP, nEquil, nEquilCool
5 5
     .Call('_CoGAPS_cogaps', PACKAGE = 'CoGAPS', DMatrix, SMatrix, nFactor, alphaA, alphaP, nEquil, nEquilCool, nSample, maxGibbsMassA, maxGibbsMassP, fixedPatterns, whichMatrixFixed, seed, messages, singleCellRNASeq, numOutputs, numSnapshots)
6 6
 }
7 7
 
8
+cogapsFromCheckpoint <- function(fileName) {
9
+    .Call('_CoGAPS_cogapsFromCheckpoint', PACKAGE = 'CoGAPS', fileName)
10
+}
11
+
8 12
 run_catch_unit_tests <- function() {
9 13
     .Call('_CoGAPS_run_catch_unit_tests', PACKAGE = 'CoGAPS')
10 14
 }
... ...
@@ -138,3 +138,56 @@ gapsRun <- function(D, S, ABins = data.frame(), PBins = data.frame(),
138 138
     message(paste("Chi-Squared of Mean:", calcChiSq))
139 139
     return(cogapResult);
140 140
 }
141
+
142
+#' @export
143
+gapsRunFromCheckpoint <- function(D, S, path)
144
+{
145
+    # call to C++ Rcpp code
146
+    cogapResult <- cogapsFromCheckpoint(path)
147
+
148
+    # convert returned files to matrices to simplify visualization and processing
149
+    cogapResult$Amean <- as.matrix(cogapResult$Amean);
150
+    cogapResult$Asd <- as.matrix(cogapResult$Asd);
151
+    cogapResult$Pmean <- as.matrix(cogapResult$Pmean);
152
+    cogapResult$Psd <- as.matrix(cogapResult$Psd);
153
+
154
+    geneNames <- rownames(D);
155
+    sampleNames <- colnames(D);
156
+
157
+    # label patterns as Patt N
158
+    patternNames <- c("0");
159
+    for(i in 1:ncol(cogapResult$Amean))
160
+    {
161
+        patternNames[i] <- paste('Patt', i);
162
+    }
163
+
164
+    ## label matrices
165
+    colnames(cogapResult$Amean) <- patternNames;
166
+    rownames(cogapResult$Amean) <- geneNames;
167
+    colnames(cogapResult$Asd) <- patternNames;
168
+    rownames(cogapResult$Asd) <- geneNames;
169
+    colnames(cogapResult$Pmean) <- sampleNames;
170
+    rownames(cogapResult$Pmean) <- patternNames;
171
+    colnames(cogapResult$Psd) <- sampleNames;
172
+    rownames(cogapResult$Psd) <- patternNames;
173
+
174
+    ## calculate chi-squared of mean, this should be smaller than individual
175
+    ## chi-squared sample values if sampling is good
176
+    calcChiSq <- c(0);
177
+    MMatrix <- (cogapResult$Amean %*% cogapResult$Pmean);
178
+
179
+    for(i in 1:(nrow(MMatrix)))
180
+    {
181
+        for(j in 1:(ncol(MMatrix)))
182
+        {
183
+            calcChiSq <- calcChiSq + ((D[i,j] - MMatrix[i,j])/S[i,j])^2;
184
+        }
185
+    }
186
+
187
+    cogapResult = c(cogapResult, calcChiSq);
188
+    
189
+    # names(cogapResult)[13] <- "meanChi2";
190
+
191
+    message(paste("Chi-Squared of Mean:", calcChiSq))
192
+    return(cogapResult);
193
+}
141 194
\ No newline at end of file
142 195
new file mode 100644
... ...
@@ -0,0 +1,43 @@
1
+#ifndef __COGAPS_ARCHIVE_H__
2
+#define __COGAPS_ARCHIVE_H__
3
+
4
+#include <boost/random/mersenne_twister.hpp>
5
+#include <fstream>
6
+
7
+#define ARCHIVE_READ  std::ios::in
8
+#define ARCHIVE_WRITE std::ios::out | std::ios::trunc
9
+
10
+class Archive
11
+{
12
+private:
13
+
14
+    std::fstream mStream;
15
+
16
+public:
17
+
18
+    Archive(const std::string &path, std::ios_base::openmode flags)
19
+        : mStream(path.c_str(), std::ios::binary | flags)
20
+    {}
21
+
22
+    void close() {mStream.close();}
23
+
24
+    template<typename T>
25
+    friend void operator<<(Archive &ar, T val);
26
+
27
+    template<typename T>
28
+    friend void operator>>(Archive &ar, T &val);
29
+};
30
+
31
+template<typename T>
32
+void operator<<(Archive &ar, T val)
33
+{
34
+    ar.mStream.write(reinterpret_cast<char*>(&val), sizeof(T));
35
+}
36
+
37
+template<typename T>
38
+void operator>>(Archive &ar, T &val)
39
+{
40
+    ar.mStream.read(reinterpret_cast<char*>(&val), sizeof(T));
41
+}
42
+
43
+#endif
0 44
\ No newline at end of file
... ...
@@ -192,4 +192,52 @@ MatrixChange AtomicSupport::getMatrixChange(const AtomicProposal &prop) const
192 192
     {   
193 193
         return MatrixChange(prop.label, getRow(prop.pos1), getCol(prop.pos1), prop.delta1);
194 194
     }
195
+}
196
+
197
+void operator<<(Archive &ar, AtomicSupport &domain)
198
+{
199
+    ar << domain.mLabel;
200
+    ar << domain.mNumAtoms;
201
+    ar << domain.mMaxNumAtoms;
202
+    ar << domain.mTotalMass;
203
+    ar << domain.mNumRows;
204
+    ar << domain.mNumCols;
205
+    ar << domain.mNumBins;
206
+    ar << domain.mBinSize;
207
+    ar << domain.mAlpha;
208
+    ar << domain.mLambda;
209
+    
210
+    std::map<uint64_t, double>::iterator iter = domain.mAtomicDomain.begin();
211
+    for (; iter != domain.mAtomicDomain.end(); ++iter)
212
+    {
213
+        uint64_t pos = iter->first;
214
+        double mass = iter->second;
215
+        ar << pos;
216
+        ar << mass;
217
+        //ar << iter->first;
218
+        //ar << iter->second;
219
+    }
220
+}
221
+
222
+void operator>>(Archive &ar, AtomicSupport &domain)
223
+{
224
+    ar >> domain.mLabel;
225
+    ar >> domain.mNumAtoms;
226
+    ar >> domain.mMaxNumAtoms;
227
+    ar >> domain.mTotalMass;
228
+    ar >> domain.mNumRows;
229
+    ar >> domain.mNumCols;
230
+    ar >> domain.mNumBins;
231
+    ar >> domain.mBinSize;
232
+    ar >> domain.mAlpha;
233
+    ar >> domain.mLambda;
234
+
235
+    uint64_t pos = 0;
236
+    double mass = 0.0;
237
+    for (unsigned i = 0; i < domain.mNumAtoms; ++i)
238
+    {
239
+        ar >> pos;
240
+        ar >> mass;
241
+        domain.mAtomicDomain.insert(std::pair<uint64_t, double>(pos, mass));
242
+    }    
195 243
 }
196 244
\ No newline at end of file
... ...
@@ -62,23 +62,6 @@ public:
62 62
     // expected magnitude of each atom (must be > 0)
63 63
     double mLambda;
64 64
 
65
-    /*friend class boost::serialization::access;
66
-    template<class Archive>
67
-    void serialize(Archive &ar)
68
-    {
69
-        ar & mLabel;
70
-        ar & mAtomicDomain;
71
-        ar & mNumAtoms;
72
-        ar & mMaxNumAtoms;
73
-        ar & mTotalMass;
74
-        ar & mNumRows;
75
-        ar & mNumCols;
76
-        ar & mNumBins;
77
-        ar & mBinSize;
78
-        ar & mAlpha;
79
-        ar & mLambda;
80
-    }*/
81
-
82 65
     // convert atomic position to row/col of the matrix
83 66
     uint64_t getRow(uint64_t pos) const;
84 67
     uint64_t getCol(uint64_t pos) const;
... ...
@@ -99,7 +82,6 @@ public:
99 82
 public:
100 83
 
101 84
     // constructors
102
-    AtomicSupport(std::ifstream &file);
103 85
     AtomicSupport(char label, uint64_t nrow, uint64_t ncol, double alpha=1.0,
104 86
         double lambda=1.0);
105 87
 
... ...
@@ -122,6 +104,9 @@ public:
122 104
     void setLambda(double lambda) {mLambda = lambda;}
123 105
 
124 106
     void serializeAndWrite(const std::ofstream &file);
107
+
108
+    friend void operator<<(Archive &ar, AtomicSupport &sampler);
109
+    friend void operator>>(Archive &ar, AtomicSupport &sampler);
125 110
 };
126 111
 
127 112
 #endif
... ...
@@ -1,5 +1,6 @@
1 1
 #include "GibbsSampler.h"
2 2
 #include "Matrix.h"
3
+#include "Archive.h"
3 4
 
4 5
 #include <Rcpp.h>
5 6
 #include <ctime>
... ...
@@ -8,6 +9,8 @@
8 9
 #include <boost/archive/text_oarchive.hpp>
9 10
 #include <boost/archive/text_iarchive.hpp>
10 11
 
12
+#define ARCHIVE_MAGIC_NUM 0xCE45D32A
13
+
11 14
 typedef std::vector<Rcpp::NumericMatrix> SnapshotList;
12 15
 
13 16
 enum GapsPhase
... ...
@@ -63,32 +66,59 @@ struct GapsInternalState
63 66
             maxGibbsMassA, maxGibbsMassP, singleCellRNASeq, fixedPatterns,
64 67
             whichMatrixFixed)
65 68
     {}
69
+
70
+    GapsInternalState(unsigned nE, unsigned nS, unsigned nRow, unsigned nCol,
71
+    unsigned nFactor)
72
+            :
73
+        chi2VecEquil(nE), nAtomsAEquil(nE), nAtomsPEquil(nE),
74
+        chi2VecSample(nS), nAtomsASample(nS), nAtomsPSample(nS),
75
+        sampler(nRow, nCol, nFactor)
76
+    {}
66 77
 };
67 78
 
68
-/*template<class Archive>
69
-void boost::serialization::serialize(Archive &ar, GapsInternalState &state)
79
+void operator<<(Archive &ar, GapsInternalState &state)
80
+{
81
+    ar << state.chi2VecEquil;
82
+    ar << state.nAtomsAEquil;
83
+    ar << state.nAtomsPEquil;
84
+    ar << state.chi2VecSample;
85
+    ar << state.nAtomsASample;
86
+    ar << state.nAtomsPSample;
87
+    ar << state.nIterA;
88
+    ar << state.nIterP;
89
+    ar << state.nEquil;
90
+    ar << state.nEquilCool;
91
+    ar << state.nSample;
92
+    ar << state.nSnapshots;
93
+    ar << state.nOutputs;
94
+    ar << state.messages;
95
+    ar << state.iter;
96
+    ar << state.phase;
97
+    ar << state.seed;
98
+    ar << state.sampler;
99
+}
100
+
101
+void operator>>(Archive &ar, GapsInternalState &state)
70 102
 {
71
-    ar & state.chi2VecEquil;
72
-    ar & state.nAtomsAEquil;
73
-    ar & state.nAtomsPEquil;
74
-    ar & state.chi2VecSample;
75
-    ar & state.nAtomsASample;
76
-    ar & state.nAtomsPSample;
77
-    ar & state.nIterA;
78
-    ar & state.nIterP;
79
-    ar & state.nEquil;
80
-    ar & state.nEquilCool;
81
-    ar & state.nSample;
82
-    ar & state.nSnapshots;
83
-    ar & state.nOutputs;
84
-    ar & state.messages;
85
-    ar & state.iter;
86
-    ar & state.phase;
87
-    ar & state.seed;
88
-    ar & state.sampler;
89
-    //ar & state.snapshotsA;
90
-    //ar & state.snapshotsP;
91
-}*/
103
+    ar >> state.chi2VecEquil;
104
+    ar >> state.nAtomsAEquil;
105
+    ar >> state.nAtomsPEquil;
106
+    ar >> state.chi2VecSample;
107
+    ar >> state.nAtomsASample;
108
+    ar >> state.nAtomsPSample;
109
+    ar >> state.nIterA;
110
+    ar >> state.nIterP;
111
+    ar >> state.nEquil;
112
+    ar >> state.nEquilCool;
113
+    ar >> state.nSample;
114
+    ar >> state.nSnapshots;
115
+    ar >> state.nOutputs;
116
+    ar >> state.messages;
117
+    ar >> state.iter;
118
+    ar >> state.phase;
119
+    ar >> state.seed;
120
+    ar >> state.sampler;
121
+}
92 122
 
93 123
 static void runGibbsSampler(GapsInternalState &state, unsigned nIterTotal,
94 124
 Vector &chi2Vec, Vector &aAtomVec, Vector &pAtomVec)
... ...
@@ -152,12 +182,16 @@ static Rcpp::List runCogaps(GapsInternalState &state)
152 182
         state.phase = GAPS_COOLING;
153 183
     }
154 184
 
155
-    std::ofstream ofs("gaps_checkpoint.out");
156
-    {
157
-        boost::archive::text_oarchive oa(ofs);
158
-        //oa << state;
159
-    }
160
-    ofs.close();
185
+    Archive ar("gaps_checkpoint.out", ARCHIVE_WRITE);
186
+    ar << ARCHIVE_MAGIC_NUM;
187
+    gaps::random::save(ar);
188
+    ar << state.nEquil;
189
+    ar << state.nSample;
190
+    ar << state.sampler.nRow();
191
+    ar << state.sampler.nCol();
192
+    ar << state.sampler.nFactor();
193
+    ar << state;
194
+    ar.close();
161 195
 
162 196
     if (state.phase == GAPS_COOLING)
163 197
     {
... ...
@@ -222,26 +256,37 @@ unsigned numOutputs, unsigned numSnapshots)
222 256
     return runCogaps(state);
223 257
 }
224 258
 
225
-/*
259
+// [[Rcpp::export]]
226 260
 Rcpp::List cogapsFromCheckpoint(const std::string &fileName)
227 261
 {   
228 262
     // open file
229
-    std::ifstream file(fileName);
263
+    Archive ar(fileName, ARCHIVE_READ);
230 264
 
231 265
     // verify magic number
232 266
     uint32_t magicNum = 0;
233
-    file.read(reinterpret_cast<char*>(&magicNum), sizeof(uint32_t));
234
-    if (magicNum != 0xCE45D32A)
267
+    ar >> magicNum;
268
+    if (magicNum != ARCHIVE_MAGIC_NUM)
235 269
     {
236 270
         std::cout << "invalid checkpoint file" << std::endl;
237 271
         return Rcpp::List::create();
238 272
     }
239 273
     
240 274
     // seed random number generator and create internal state
241
-    gaps::random::load(file);
242
-    GapsInternalState state(file);
275
+    gaps::random::load(ar);
276
+
277
+    // read needed parameters
278
+    unsigned nE = 0, nS = 0, nRow = 0, nCol = 0, nFactor = 0;
279
+    ar >> nE;
280
+    ar >> nS;
281
+    ar >> nRow;
282
+    ar >> nCol;
283
+    ar >> nFactor;
284
+    
285
+    // construct empty state of the correct size, populate from file
286
+    GapsInternalState state(nE, nS, nRow, nCol, nFactor);
287
+    ar >> state;
243 288
 
244 289
     // run cogaps from this internal state
245 290
     return runCogaps(state);
246 291
 }
247
-*/
292
+
... ...
@@ -5,6 +5,15 @@
5 5
 
6 6
 static const double EPSILON = 1.e-10;
7 7
 
8
+GibbsSampler::GibbsSampler(unsigned nRow, unsigned nCol, unsigned nFactor)
9
+    :
10
+mDMatrix(nRow, nCol), mSMatrix(nRow, nCol), mAPMatrix(nRow, nCol),
11
+mAMatrix(nRow, nFactor), mPMatrix(nFactor, nCol), mADomain('A', nRow, nFactor),
12
+mPDomain('P', nFactor, nCol), mAMeanMatrix(nRow, nFactor),
13
+mAStdMatrix(nRow, nFactor), mPMeanMatrix(nFactor, nCol),
14
+mPStdMatrix(nFactor, nCol)
15
+{}
16
+
8 17
 GibbsSampler::GibbsSampler(Rcpp::NumericMatrix D, Rcpp::NumericMatrix S,
9 18
 unsigned int nFactor, double alphaA, double alphaP, double maxGibbsMassA,
10 19
 double maxGibbsMassP, bool singleCellRNASeq, Rcpp::NumericMatrix fixedPat,
... ...
@@ -437,4 +446,50 @@ void GibbsSampler::updateStatistics()
437 446
         mPStdMatrix.getRow(r) += gaps::algo::squaredScalarDivision(mPMatrix.getRow(r),
438 447
             normVec(r));
439 448
     }
449
+}
450
+
451
+void operator<<(Archive &ar, GibbsSampler &sampler)
452
+{
453
+    ar << sampler.mDMatrix;
454
+    ar << sampler.mSMatrix;
455
+    ar << sampler.mAPMatrix;
456
+    ar << sampler.mAMatrix;
457
+    ar << sampler.mPMatrix;
458
+    ar << sampler.mADomain;
459
+    ar << sampler.mPDomain;
460
+    ar << sampler.mAMeanMatrix;
461
+    ar << sampler.mAStdMatrix;
462
+    ar << sampler.mPMeanMatrix;
463
+    ar << sampler.mPStdMatrix;
464
+    ar << sampler.mStatUpdates;
465
+    ar << sampler.mMaxGibbsMassA;
466
+    ar << sampler.mMaxGibbsMassP;
467
+    ar << sampler.mAnnealingTemp;
468
+    ar << sampler.mChi2;
469
+    ar << sampler.mSingleCellRNASeq;
470
+    ar << sampler.mNumFixedPatterns;
471
+    ar << sampler.mFixedMat;
472
+}
473
+
474
+void operator>>(Archive &ar, GibbsSampler &sampler)
475
+{
476
+    ar >> sampler.mDMatrix;
477
+    ar >> sampler.mSMatrix;
478
+    ar >> sampler.mAPMatrix;
479
+    ar >> sampler.mAMatrix;
480
+    ar >> sampler.mPMatrix;
481
+    ar >> sampler.mADomain;
482
+    ar >> sampler.mPDomain;
483
+    ar >> sampler.mAMeanMatrix;
484
+    ar >> sampler.mAStdMatrix;
485
+    ar >> sampler.mPMeanMatrix;
486
+    ar >> sampler.mPStdMatrix;
487
+    ar >> sampler.mStatUpdates;
488
+    ar >> sampler.mMaxGibbsMassA;
489
+    ar >> sampler.mMaxGibbsMassP;
490
+    ar >> sampler.mAnnealingTemp;
491
+    ar >> sampler.mChi2;
492
+    ar >> sampler.mSingleCellRNASeq;
493
+    ar >> sampler.mNumFixedPatterns;
494
+    ar >> sampler.mFixedMat;
440 495
 }
441 496
\ No newline at end of file
... ...
@@ -36,31 +36,6 @@ public:
36 36
     unsigned mNumFixedPatterns;
37 37
     char mFixedMat;
38 38
 
39
-    /*friend class boost::serialization::access;    
40
-    template<class Archive>
41
-    void serialize(Archive &ar)
42
-    {
43
-        ar & mDMatrix;
44
-        ar & mSMatrix;
45
-        ar & mAPMatrix;
46
-        ar & mAMatrix;
47
-        ar & mPMatrix;
48
-        ar & mADomain;  
49
-        ar & mPDomain;
50
-        ar & mAMeanMatrix;
51
-        ar & mAStdMatrix;
52
-        ar & mPMeanMatrix;
53
-        ar & mPStdMatrix;
54
-        ar & mStatUpdates;
55
-        ar & mMaxGibbsMassA;
56
-        ar & mMaxGibbsMassP;
57
-        ar & mAnnealingTemp;
58
-        ar & mChi2;
59
-        ar & mSingleCellRNASeq;
60
-        ar & mNumFixedPatterns;
61
-        ar & mFixedMat;
62
-    }*/
63
-
64 39
     bool death(AtomicSupport &domain, AtomicProposal &proposal);
65 40
     bool birth(AtomicSupport &domain, AtomicProposal &proposal);
66 41
     bool move(AtomicSupport &domain, AtomicProposal &proposal);
... ...
@@ -82,7 +57,7 @@ public:
82 57
 
83 58
 public:
84 59
 
85
-    GibbsSampler(std::ifstream &file);
60
+    GibbsSampler(unsigned nRow, unsigned nCol, unsigned nFactor);
86 61
     GibbsSampler(Rcpp::NumericMatrix D, Rcpp::NumericMatrix S, unsigned nFactor,
87 62
         double alphaA, double alphaP, double maxGibbsMassA, double maxGibbsMassP,
88 63
         bool singleCellRNASeq, Rcpp::NumericMatrix fixedPat, char whichMat);
... ...
@@ -102,7 +77,12 @@ public:
102 77
 
103 78
     Rcpp::NumericMatrix getNormedMatrix(char mat);
104 79
 
105
-    void serializeAndWrite(const std::ofstream &file);
80
+    unsigned nRow() const {return mDMatrix.nRow();}
81
+    unsigned nCol() const {return mDMatrix.nCol();}
82
+    unsigned nFactor() const {return mAMatrix.nCol();}
83
+
84
+    friend void operator<<(Archive &ar, GibbsSampler &sampler);
85
+    friend void operator>>(Archive &ar, GibbsSampler &sampler);
106 86
 };
107 87
 
108 88
 #endif
... ...
@@ -1,5 +1,4 @@
1
-PKG_CPPFLAGS = -Wall -Wextra -O2 -I.
2
-PKG_LIBS = `$(R_HOME)/bin/Rscript -e "Rcpp:::LdFlags()"` -lboost_serialization
1
+PKG_CPPFLAGS = -Wall -Wextra -O2
3 2
 OBJECTS =   Algorithms.o \
4 3
             AtomicSupport.o \
5 4
             Cogaps.o \
... ...
@@ -50,6 +50,22 @@ void Vector::operator+=(const Vector &vec)
50 50
     }
51 51
 }
52 52
 
53
+void operator<<(Archive &ar, Vector &vec)
54
+{
55
+    for (unsigned i = 0; i < vec.size(); ++i)
56
+    {
57
+        ar << vec[i];
58
+    }
59
+}
60
+
61
+void operator>>(Archive &ar, Vector &vec)
62
+{
63
+    for (unsigned i = 0; i < vec.size(); ++i)
64
+    {
65
+        ar >> vec.mValues[i];
66
+    }
67
+}
68
+
53 69
 /****************************** ROW MATRIX *****************************/
54 70
 
55 71
 RowMatrix::RowMatrix(unsigned nrow, unsigned ncol)
... ...
@@ -93,6 +109,22 @@ Rcpp::NumericMatrix RowMatrix::rMatrix() const
93 109
     return convertToRMatrix<RowMatrix>(*this);
94 110
 }
95 111
 
112
+void operator<<(Archive &ar, RowMatrix &mat)
113
+{
114
+    for (unsigned i = 0; i < mat.nRow(); ++i)
115
+    {
116
+        ar << mat.mRows[i];
117
+    }
118
+}
119
+
120
+void operator>>(Archive &ar, RowMatrix &mat)
121
+{
122
+    for (unsigned i = 0; i < mat.nRow(); ++i)
123
+    {
124
+        ar >> mat.mRows[i];
125
+    }
126
+}
127
+
96 128
 /**************************** COLUMN MATRIX ****************************/
97 129
 
98 130
 ColMatrix::ColMatrix(unsigned nrow, unsigned ncol)
... ...
@@ -135,3 +167,33 @@ Rcpp::NumericMatrix ColMatrix::rMatrix() const
135 167
 {
136 168
     return convertToRMatrix<ColMatrix>(*this);
137 169
 }
170
+
171
+void operator<<(Archive &ar, ColMatrix &mat)
172
+{
173
+    for (unsigned j = 0; j < mat.nCol(); ++j)
174
+    {
175
+        ar << mat.mCols[j];
176
+    }
177
+}
178
+
179
+void operator>>(Archive &ar, ColMatrix &mat)
180
+{
181
+    for (unsigned j = 0; j < mat.nCol(); ++j)
182
+    {
183
+        ar >> mat.mCols[j];
184
+    }
185
+}
186
+
187
+/**************************** TWO-WAY MATRIX ***************************/
188
+
189
+void operator<<(Archive &ar, TwoWayMatrix &mat)
190
+{
191
+    ar << mat.mRowMatrix;
192
+    ar << mat.mColMatrix;
193
+}
194
+
195
+void operator>>(Archive &ar, TwoWayMatrix &mat)
196
+{
197
+    ar >> mat.mRowMatrix;
198
+    ar >> mat.mColMatrix;
199
+}
138 200
\ No newline at end of file
... ...
@@ -1,9 +1,10 @@
1 1
 #ifndef __COGAPS_MATRIX_H__
2 2
 #define __COGAPS_MATRIX_H__
3 3
 
4
+#include "Archive.h"
5
+
4 6
 #include <Rcpp.h>
5 7
 #include <vector>
6
-//#include <boost/serialization/vector.hpp>
7 8
 
8 9
 // temporary: used for testing performance of float vs double
9 10
 typedef double matrix_data_t;
... ...
@@ -33,22 +34,12 @@ struct MatrixChange
33 34
     {}
34 35
 };
35 36
 
36
-// no polymorphism to prevent virtual function overhead, not really
37
-// needed anyways since few functions are used on all types of matrices
38
-
39 37
 class Vector
40 38
 {
41 39
 private:
42 40
 
43 41
     std::vector<matrix_data_t> mValues;
44 42
 
45
-    /*friend class boost::serialization::access;
46
-    template<class Archive>
47
-    void serialize(Archive &ar)
48
-    {
49
-        ar & mValues;
50
-    }*/
51
-
52 43
 public:
53 44
 
54 45
     Vector(unsigned size) : mValues(std::vector<matrix_data_t>(size, 0.0)) {}
... ...
@@ -64,6 +55,9 @@ public:
64 55
     Rcpp::NumericVector rVec() const {return Rcpp::wrap(mValues);}
65 56
     void concat(const Vector& vec);
66 57
     void operator+=(const Vector &vec);
58
+
59
+    friend void operator<<(Archive &ar, Vector &vec);
60
+    friend void operator>>(Archive &ar, Vector &vec);
67 61
 };
68 62
 
69 63
 class RowMatrix
... ...
@@ -73,15 +67,6 @@ private:
73 67
     std::vector<Vector> mRows;
74 68
     unsigned mNumRows, mNumCols;
75 69
 
76
-    /*friend class boost::serialization::access;
77
-    template<class Archive>
78
-    void serialize(Archive &ar)
79
-    {
80
-        ar & mRows;
81
-        ar & mNumRows;
82
-        ar & mNumCols;
83
-    }*/
84
-
85 70
 public:
86 71
 
87 72
     RowMatrix(unsigned nrow, unsigned ncol);
... ...
@@ -99,6 +84,9 @@ public:
99 84
 
100 85
     void update(const MatrixChange &change);
101 86
     Rcpp::NumericMatrix rMatrix() const;
87
+
88
+    friend void operator<<(Archive &ar, RowMatrix &mat);
89
+    friend void operator>>(Archive &ar, RowMatrix &mat);
102 90
 };
103 91
 
104 92
 class ColMatrix
... ...
@@ -108,15 +96,6 @@ private:
108 96
     std::vector<Vector> mCols;
109 97
     unsigned mNumRows, mNumCols;
110 98
 
111
-    /*friend class boost::serialization::access;
112
-    template<class Archive>
113
-    void serialize(Archive &ar)
114
-    {
115
-        ar & mCols;
116
-        ar & mNumRows;
117
-        ar & mNumCols;
118
-    }*/
119
-
120 99
 public:
121 100
 
122 101
     ColMatrix(unsigned nrow, unsigned ncol);
... ...
@@ -134,6 +113,9 @@ public:
134 113
 
135 114
     void update(const MatrixChange &change);
136 115
     Rcpp::NumericMatrix rMatrix() const;
116
+
117
+    friend void operator<<(Archive &ar, ColMatrix &mat);
118
+    friend void operator>>(Archive &ar, ColMatrix &mat);
137 119
 };
138 120
 
139 121
 // gain performance at the expense of memory
... ...
@@ -144,14 +126,6 @@ private:
144 126
     RowMatrix mRowMatrix;
145 127
     ColMatrix mColMatrix;
146 128
 
147
-    /*friend class boost::serialization::access;
148
-    template<class Archive>
149
-    void serialize(Archive &ar)
150
-    {
151
-        ar & mRowMatrix;
152
-        ar & mColMatrix;
153
-    }*/
154
-
155 129
 public:
156 130
 
157 131
     TwoWayMatrix(unsigned nrow, unsigned ncol)
... ...
@@ -165,12 +139,6 @@ public:
165 139
     unsigned nRow() const {return mRowMatrix.nRow();}
166 140
     unsigned nCol() const {return mRowMatrix.nCol();}
167 141
     
168
-    // TODO remove since accessing this way defeats the purpose
169
-    /*matrix_data_t operator()(unsigned r, unsigned c) const
170
-    {
171
-        return mRowMatrix(r,c);
172
-    }*/
173
-
174 142
     const Vector& getRow(unsigned row) const {return mRowMatrix.getRow(row);}
175 143
     const Vector& getCol(unsigned col) const {return mColMatrix.getCol(col);}
176 144
 
... ...
@@ -184,6 +152,9 @@ public:
184 152
     {
185 153
         return mRowMatrix.rMatrix();
186 154
     }
155
+
156
+    friend void operator<<(Archive &ar, TwoWayMatrix &mat);
157
+    friend void operator>>(Archive &ar, TwoWayMatrix &mat);
187 158
 };
188 159
 
189 160
 #endif
190 161
\ No newline at end of file
... ...
@@ -10,21 +10,26 @@
10 10
 #include <boost/math/distributions/normal.hpp>
11 11
 #include <boost/math/distributions/exponential.hpp>
12 12
 
13
+#include <boost/random/mersenne_twister.hpp>
14
+
13 15
 #include <stdint.h>
14 16
 
15 17
 #define Q_GAMMA_THRESHOLD 1E-6
16 18
 #define Q_GAMMA_MIN_VALUE 0.0
17 19
 
20
+typedef boost::random::mt19937 RNGType;
21
+//typedef boost::random::mt11213b RNGType; // should be faster
22
+
18 23
 static RNGType rng;
19 24
 
20
-RNGType gaps::random::getGenerator()
25
+void gaps::random::save(Archive &ar)
21 26
 {
22
-    return rng;
27
+    ar << rng;
23 28
 }
24 29
 
25
-void gaps::random::setGenerator(RNGType temp)
30
+void gaps::random::load(Archive &ar)
26 31
 {
27
-    rng = temp;
32
+    ar >> rng;
28 33
 }
29 34
 
30 35
 void gaps::random::setSeed(uint32_t seed)
... ...
@@ -1,13 +1,11 @@
1 1
 #ifndef __COGAPS_RANDOM_H__
2 2
 #define __COGAPS_RANDOM_H__
3 3
 
4
+#include "Archive.h"
5
+
4 6
 #include <stdint.h>
5 7
 #include <vector>
6 8
 #include <fstream>
7
-#include <boost/random/mersenne_twister.hpp>
8
-
9
-typedef boost::random::mt19937 RNGType;
10
-//typedef boost::random::mt11213b RNGType; // should be faster
11 9
 
12 10
 namespace gaps
13 11
 {
... ...
@@ -33,8 +31,8 @@ namespace random
33 31
     double q_norm(double q, double mean, double sd);
34 32
     double p_norm(double p, double mean, double sd);
35 33
 
36
-    RNGType getGenerator();
37
-    void setGenerator(RNGType rng);
34
+    void save(Archive &ar);
35
+    void load(Archive &ar);
38 36
 }
39 37
 
40 38
 }
... ...
@@ -32,6 +32,17 @@ BEGIN_RCPP
32 32
     return rcpp_result_gen;
33 33
 END_RCPP
34 34
 }
35
+// cogapsFromCheckpoint
36
+Rcpp::List cogapsFromCheckpoint(const std::string& fileName);
37
+RcppExport SEXP _CoGAPS_cogapsFromCheckpoint(SEXP fileNameSEXP) {
38
+BEGIN_RCPP
39
+    Rcpp::RObject rcpp_result_gen;
40
+    Rcpp::RNGScope rcpp_rngScope_gen;
41
+    Rcpp::traits::input_parameter< const std::string& >::type fileName(fileNameSEXP);
42
+    rcpp_result_gen = Rcpp::wrap(cogapsFromCheckpoint(fileName));
43
+    return rcpp_result_gen;
44
+END_RCPP
45
+}
35 46
 // run_catch_unit_tests
36 47
 int run_catch_unit_tests();
37 48
 RcppExport SEXP _CoGAPS_run_catch_unit_tests() {
... ...
@@ -45,6 +56,7 @@ END_RCPP
45 56
 
46 57
 static const R_CallMethodDef CallEntries[] = {
47 58
     {"_CoGAPS_cogaps", (DL_FUNC) &_CoGAPS_cogaps, 17},
59
+    {"_CoGAPS_cogapsFromCheckpoint", (DL_FUNC) &_CoGAPS_cogapsFromCheckpoint, 1},
48 60
     {"_CoGAPS_run_catch_unit_tests", (DL_FUNC) &_CoGAPS_run_catch_unit_tests, 0},
49 61
     {NULL, NULL, 0}
50 62
 };
51 63
new file mode 100644