... | ... |
@@ -5,6 +5,10 @@ cogaps <- function(DMatrix, SMatrix, nFactor, alphaA, alphaP, nEquil, nEquilCool |
5 | 5 |
.Call('_CoGAPS_cogaps', PACKAGE = 'CoGAPS', DMatrix, SMatrix, nFactor, alphaA, alphaP, nEquil, nEquilCool, nSample, maxGibbsMassA, maxGibbsMassP, fixedPatterns, whichMatrixFixed, seed, messages, singleCellRNASeq, numOutputs, numSnapshots) |
6 | 6 |
} |
7 | 7 |
|
8 |
+cogapsFromCheckpoint <- function(fileName) { |
|
9 |
+ .Call('_CoGAPS_cogapsFromCheckpoint', PACKAGE = 'CoGAPS', fileName) |
|
10 |
+} |
|
11 |
+ |
|
8 | 12 |
run_catch_unit_tests <- function() { |
9 | 13 |
.Call('_CoGAPS_run_catch_unit_tests', PACKAGE = 'CoGAPS') |
10 | 14 |
} |
... | ... |
@@ -138,3 +138,56 @@ gapsRun <- function(D, S, ABins = data.frame(), PBins = data.frame(), |
138 | 138 |
message(paste("Chi-Squared of Mean:", calcChiSq)) |
139 | 139 |
return(cogapResult); |
140 | 140 |
} |
141 |
+ |
|
142 |
+#' @export |
|
143 |
+gapsRunFromCheckpoint <- function(D, S, path) |
|
144 |
+{ |
|
145 |
+ # call to C++ Rcpp code |
|
146 |
+ cogapResult <- cogapsFromCheckpoint(path) |
|
147 |
+ |
|
148 |
+ # convert returned files to matrices to simplify visualization and processing |
|
149 |
+ cogapResult$Amean <- as.matrix(cogapResult$Amean); |
|
150 |
+ cogapResult$Asd <- as.matrix(cogapResult$Asd); |
|
151 |
+ cogapResult$Pmean <- as.matrix(cogapResult$Pmean); |
|
152 |
+ cogapResult$Psd <- as.matrix(cogapResult$Psd); |
|
153 |
+ |
|
154 |
+ geneNames <- rownames(D); |
|
155 |
+ sampleNames <- colnames(D); |
|
156 |
+ |
|
157 |
+ # label patterns as Patt N |
|
158 |
+ patternNames <- c("0"); |
|
159 |
+ for(i in 1:ncol(cogapResult$Amean)) |
|
160 |
+ { |
|
161 |
+ patternNames[i] <- paste('Patt', i); |
|
162 |
+ } |
|
163 |
+ |
|
164 |
+ ## label matrices |
|
165 |
+ colnames(cogapResult$Amean) <- patternNames; |
|
166 |
+ rownames(cogapResult$Amean) <- geneNames; |
|
167 |
+ colnames(cogapResult$Asd) <- patternNames; |
|
168 |
+ rownames(cogapResult$Asd) <- geneNames; |
|
169 |
+ colnames(cogapResult$Pmean) <- sampleNames; |
|
170 |
+ rownames(cogapResult$Pmean) <- patternNames; |
|
171 |
+ colnames(cogapResult$Psd) <- sampleNames; |
|
172 |
+ rownames(cogapResult$Psd) <- patternNames; |
|
173 |
+ |
|
174 |
+ ## calculate chi-squared of mean, this should be smaller than individual |
|
175 |
+ ## chi-squared sample values if sampling is good |
|
176 |
+ calcChiSq <- c(0); |
|
177 |
+ MMatrix <- (cogapResult$Amean %*% cogapResult$Pmean); |
|
178 |
+ |
|
179 |
+ for(i in 1:(nrow(MMatrix))) |
|
180 |
+ { |
|
181 |
+ for(j in 1:(ncol(MMatrix))) |
|
182 |
+ { |
|
183 |
+ calcChiSq <- calcChiSq + ((D[i,j] - MMatrix[i,j])/S[i,j])^2; |
|
184 |
+ } |
|
185 |
+ } |
|
186 |
+ |
|
187 |
+ cogapResult = c(cogapResult, calcChiSq); |
|
188 |
+ |
|
189 |
+ # names(cogapResult)[13] <- "meanChi2"; |
|
190 |
+ |
|
191 |
+ message(paste("Chi-Squared of Mean:", calcChiSq)) |
|
192 |
+ return(cogapResult); |
|
193 |
+} |
|
141 | 194 |
\ No newline at end of file |
142 | 195 |
new file mode 100644 |
... | ... |
@@ -0,0 +1,43 @@ |
1 |
+#ifndef __COGAPS_ARCHIVE_H__ |
|
2 |
+#define __COGAPS_ARCHIVE_H__ |
|
3 |
+ |
|
4 |
+#include <boost/random/mersenne_twister.hpp> |
|
5 |
+#include <fstream> |
|
6 |
+ |
|
7 |
+#define ARCHIVE_READ std::ios::in |
|
8 |
+#define ARCHIVE_WRITE std::ios::out | std::ios::trunc |
|
9 |
+ |
|
10 |
+class Archive |
|
11 |
+{ |
|
12 |
+private: |
|
13 |
+ |
|
14 |
+ std::fstream mStream; |
|
15 |
+ |
|
16 |
+public: |
|
17 |
+ |
|
18 |
+ Archive(const std::string &path, std::ios_base::openmode flags) |
|
19 |
+ : mStream(path.c_str(), std::ios::binary | flags) |
|
20 |
+ {} |
|
21 |
+ |
|
22 |
+ void close() {mStream.close();} |
|
23 |
+ |
|
24 |
+ template<typename T> |
|
25 |
+ friend void operator<<(Archive &ar, T val); |
|
26 |
+ |
|
27 |
+ template<typename T> |
|
28 |
+ friend void operator>>(Archive &ar, T &val); |
|
29 |
+}; |
|
30 |
+ |
|
31 |
+template<typename T> |
|
32 |
+void operator<<(Archive &ar, T val) |
|
33 |
+{ |
|
34 |
+ ar.mStream.write(reinterpret_cast<char*>(&val), sizeof(T)); |
|
35 |
+} |
|
36 |
+ |
|
37 |
+template<typename T> |
|
38 |
+void operator>>(Archive &ar, T &val) |
|
39 |
+{ |
|
40 |
+ ar.mStream.read(reinterpret_cast<char*>(&val), sizeof(T)); |
|
41 |
+} |
|
42 |
+ |
|
43 |
+#endif |
|
0 | 44 |
\ No newline at end of file |
... | ... |
@@ -192,4 +192,52 @@ MatrixChange AtomicSupport::getMatrixChange(const AtomicProposal &prop) const |
192 | 192 |
{ |
193 | 193 |
return MatrixChange(prop.label, getRow(prop.pos1), getCol(prop.pos1), prop.delta1); |
194 | 194 |
} |
195 |
+} |
|
196 |
+ |
|
197 |
+void operator<<(Archive &ar, AtomicSupport &domain) |
|
198 |
+{ |
|
199 |
+ ar << domain.mLabel; |
|
200 |
+ ar << domain.mNumAtoms; |
|
201 |
+ ar << domain.mMaxNumAtoms; |
|
202 |
+ ar << domain.mTotalMass; |
|
203 |
+ ar << domain.mNumRows; |
|
204 |
+ ar << domain.mNumCols; |
|
205 |
+ ar << domain.mNumBins; |
|
206 |
+ ar << domain.mBinSize; |
|
207 |
+ ar << domain.mAlpha; |
|
208 |
+ ar << domain.mLambda; |
|
209 |
+ |
|
210 |
+ std::map<uint64_t, double>::iterator iter = domain.mAtomicDomain.begin(); |
|
211 |
+ for (; iter != domain.mAtomicDomain.end(); ++iter) |
|
212 |
+ { |
|
213 |
+ uint64_t pos = iter->first; |
|
214 |
+ double mass = iter->second; |
|
215 |
+ ar << pos; |
|
216 |
+ ar << mass; |
|
217 |
+ //ar << iter->first; |
|
218 |
+ //ar << iter->second; |
|
219 |
+ } |
|
220 |
+} |
|
221 |
+ |
|
222 |
+void operator>>(Archive &ar, AtomicSupport &domain) |
|
223 |
+{ |
|
224 |
+ ar >> domain.mLabel; |
|
225 |
+ ar >> domain.mNumAtoms; |
|
226 |
+ ar >> domain.mMaxNumAtoms; |
|
227 |
+ ar >> domain.mTotalMass; |
|
228 |
+ ar >> domain.mNumRows; |
|
229 |
+ ar >> domain.mNumCols; |
|
230 |
+ ar >> domain.mNumBins; |
|
231 |
+ ar >> domain.mBinSize; |
|
232 |
+ ar >> domain.mAlpha; |
|
233 |
+ ar >> domain.mLambda; |
|
234 |
+ |
|
235 |
+ uint64_t pos = 0; |
|
236 |
+ double mass = 0.0; |
|
237 |
+ for (unsigned i = 0; i < domain.mNumAtoms; ++i) |
|
238 |
+ { |
|
239 |
+ ar >> pos; |
|
240 |
+ ar >> mass; |
|
241 |
+ domain.mAtomicDomain.insert(std::pair<uint64_t, double>(pos, mass)); |
|
242 |
+ } |
|
195 | 243 |
} |
196 | 244 |
\ No newline at end of file |
... | ... |
@@ -62,23 +62,6 @@ public: |
62 | 62 |
// expected magnitude of each atom (must be > 0) |
63 | 63 |
double mLambda; |
64 | 64 |
|
65 |
- /*friend class boost::serialization::access; |
|
66 |
- template<class Archive> |
|
67 |
- void serialize(Archive &ar) |
|
68 |
- { |
|
69 |
- ar & mLabel; |
|
70 |
- ar & mAtomicDomain; |
|
71 |
- ar & mNumAtoms; |
|
72 |
- ar & mMaxNumAtoms; |
|
73 |
- ar & mTotalMass; |
|
74 |
- ar & mNumRows; |
|
75 |
- ar & mNumCols; |
|
76 |
- ar & mNumBins; |
|
77 |
- ar & mBinSize; |
|
78 |
- ar & mAlpha; |
|
79 |
- ar & mLambda; |
|
80 |
- }*/ |
|
81 |
- |
|
82 | 65 |
// convert atomic position to row/col of the matrix |
83 | 66 |
uint64_t getRow(uint64_t pos) const; |
84 | 67 |
uint64_t getCol(uint64_t pos) const; |
... | ... |
@@ -99,7 +82,6 @@ public: |
99 | 82 |
public: |
100 | 83 |
|
101 | 84 |
// constructors |
102 |
- AtomicSupport(std::ifstream &file); |
|
103 | 85 |
AtomicSupport(char label, uint64_t nrow, uint64_t ncol, double alpha=1.0, |
104 | 86 |
double lambda=1.0); |
105 | 87 |
|
... | ... |
@@ -122,6 +104,9 @@ public: |
122 | 104 |
void setLambda(double lambda) {mLambda = lambda;} |
123 | 105 |
|
124 | 106 |
void serializeAndWrite(const std::ofstream &file); |
107 |
+ |
|
108 |
+ friend void operator<<(Archive &ar, AtomicSupport &sampler); |
|
109 |
+ friend void operator>>(Archive &ar, AtomicSupport &sampler); |
|
125 | 110 |
}; |
126 | 111 |
|
127 | 112 |
#endif |
... | ... |
@@ -1,5 +1,6 @@ |
1 | 1 |
#include "GibbsSampler.h" |
2 | 2 |
#include "Matrix.h" |
3 |
+#include "Archive.h" |
|
3 | 4 |
|
4 | 5 |
#include <Rcpp.h> |
5 | 6 |
#include <ctime> |
... | ... |
@@ -8,6 +9,8 @@ |
8 | 9 |
#include <boost/archive/text_oarchive.hpp> |
9 | 10 |
#include <boost/archive/text_iarchive.hpp> |
10 | 11 |
|
12 |
+#define ARCHIVE_MAGIC_NUM 0xCE45D32A |
|
13 |
+ |
|
11 | 14 |
typedef std::vector<Rcpp::NumericMatrix> SnapshotList; |
12 | 15 |
|
13 | 16 |
enum GapsPhase |
... | ... |
@@ -63,32 +66,59 @@ struct GapsInternalState |
63 | 66 |
maxGibbsMassA, maxGibbsMassP, singleCellRNASeq, fixedPatterns, |
64 | 67 |
whichMatrixFixed) |
65 | 68 |
{} |
69 |
+ |
|
70 |
+ GapsInternalState(unsigned nE, unsigned nS, unsigned nRow, unsigned nCol, |
|
71 |
+ unsigned nFactor) |
|
72 |
+ : |
|
73 |
+ chi2VecEquil(nE), nAtomsAEquil(nE), nAtomsPEquil(nE), |
|
74 |
+ chi2VecSample(nS), nAtomsASample(nS), nAtomsPSample(nS), |
|
75 |
+ sampler(nRow, nCol, nFactor) |
|
76 |
+ {} |
|
66 | 77 |
}; |
67 | 78 |
|
68 |
-/*template<class Archive> |
|
69 |
-void boost::serialization::serialize(Archive &ar, GapsInternalState &state) |
|
79 |
+void operator<<(Archive &ar, GapsInternalState &state) |
|
80 |
+{ |
|
81 |
+ ar << state.chi2VecEquil; |
|
82 |
+ ar << state.nAtomsAEquil; |
|
83 |
+ ar << state.nAtomsPEquil; |
|
84 |
+ ar << state.chi2VecSample; |
|
85 |
+ ar << state.nAtomsASample; |
|
86 |
+ ar << state.nAtomsPSample; |
|
87 |
+ ar << state.nIterA; |
|
88 |
+ ar << state.nIterP; |
|
89 |
+ ar << state.nEquil; |
|
90 |
+ ar << state.nEquilCool; |
|
91 |
+ ar << state.nSample; |
|
92 |
+ ar << state.nSnapshots; |
|
93 |
+ ar << state.nOutputs; |
|
94 |
+ ar << state.messages; |
|
95 |
+ ar << state.iter; |
|
96 |
+ ar << state.phase; |
|
97 |
+ ar << state.seed; |
|
98 |
+ ar << state.sampler; |
|
99 |
+} |
|
100 |
+ |
|
101 |
+void operator>>(Archive &ar, GapsInternalState &state) |
|
70 | 102 |
{ |
71 |
- ar & state.chi2VecEquil; |
|
72 |
- ar & state.nAtomsAEquil; |
|
73 |
- ar & state.nAtomsPEquil; |
|
74 |
- ar & state.chi2VecSample; |
|
75 |
- ar & state.nAtomsASample; |
|
76 |
- ar & state.nAtomsPSample; |
|
77 |
- ar & state.nIterA; |
|
78 |
- ar & state.nIterP; |
|
79 |
- ar & state.nEquil; |
|
80 |
- ar & state.nEquilCool; |
|
81 |
- ar & state.nSample; |
|
82 |
- ar & state.nSnapshots; |
|
83 |
- ar & state.nOutputs; |
|
84 |
- ar & state.messages; |
|
85 |
- ar & state.iter; |
|
86 |
- ar & state.phase; |
|
87 |
- ar & state.seed; |
|
88 |
- ar & state.sampler; |
|
89 |
- //ar & state.snapshotsA; |
|
90 |
- //ar & state.snapshotsP; |
|
91 |
-}*/ |
|
103 |
+ ar >> state.chi2VecEquil; |
|
104 |
+ ar >> state.nAtomsAEquil; |
|
105 |
+ ar >> state.nAtomsPEquil; |
|
106 |
+ ar >> state.chi2VecSample; |
|
107 |
+ ar >> state.nAtomsASample; |
|
108 |
+ ar >> state.nAtomsPSample; |
|
109 |
+ ar >> state.nIterA; |
|
110 |
+ ar >> state.nIterP; |
|
111 |
+ ar >> state.nEquil; |
|
112 |
+ ar >> state.nEquilCool; |
|
113 |
+ ar >> state.nSample; |
|
114 |
+ ar >> state.nSnapshots; |
|
115 |
+ ar >> state.nOutputs; |
|
116 |
+ ar >> state.messages; |
|
117 |
+ ar >> state.iter; |
|
118 |
+ ar >> state.phase; |
|
119 |
+ ar >> state.seed; |
|
120 |
+ ar >> state.sampler; |
|
121 |
+} |
|
92 | 122 |
|
93 | 123 |
static void runGibbsSampler(GapsInternalState &state, unsigned nIterTotal, |
94 | 124 |
Vector &chi2Vec, Vector &aAtomVec, Vector &pAtomVec) |
... | ... |
@@ -152,12 +182,16 @@ static Rcpp::List runCogaps(GapsInternalState &state) |
152 | 182 |
state.phase = GAPS_COOLING; |
153 | 183 |
} |
154 | 184 |
|
155 |
- std::ofstream ofs("gaps_checkpoint.out"); |
|
156 |
- { |
|
157 |
- boost::archive::text_oarchive oa(ofs); |
|
158 |
- //oa << state; |
|
159 |
- } |
|
160 |
- ofs.close(); |
|
185 |
+ Archive ar("gaps_checkpoint.out", ARCHIVE_WRITE); |
|
186 |
+ ar << ARCHIVE_MAGIC_NUM; |
|
187 |
+ gaps::random::save(ar); |
|
188 |
+ ar << state.nEquil; |
|
189 |
+ ar << state.nSample; |
|
190 |
+ ar << state.sampler.nRow(); |
|
191 |
+ ar << state.sampler.nCol(); |
|
192 |
+ ar << state.sampler.nFactor(); |
|
193 |
+ ar << state; |
|
194 |
+ ar.close(); |
|
161 | 195 |
|
162 | 196 |
if (state.phase == GAPS_COOLING) |
163 | 197 |
{ |
... | ... |
@@ -222,26 +256,37 @@ unsigned numOutputs, unsigned numSnapshots) |
222 | 256 |
return runCogaps(state); |
223 | 257 |
} |
224 | 258 |
|
225 |
-/* |
|
259 |
+// [[Rcpp::export]] |
|
226 | 260 |
Rcpp::List cogapsFromCheckpoint(const std::string &fileName) |
227 | 261 |
{ |
228 | 262 |
// open file |
229 |
- std::ifstream file(fileName); |
|
263 |
+ Archive ar(fileName, ARCHIVE_READ); |
|
230 | 264 |
|
231 | 265 |
// verify magic number |
232 | 266 |
uint32_t magicNum = 0; |
233 |
- file.read(reinterpret_cast<char*>(&magicNum), sizeof(uint32_t)); |
|
234 |
- if (magicNum != 0xCE45D32A) |
|
267 |
+ ar >> magicNum; |
|
268 |
+ if (magicNum != ARCHIVE_MAGIC_NUM) |
|
235 | 269 |
{ |
236 | 270 |
std::cout << "invalid checkpoint file" << std::endl; |
237 | 271 |
return Rcpp::List::create(); |
238 | 272 |
} |
239 | 273 |
|
240 | 274 |
// seed random number generator and create internal state |
241 |
- gaps::random::load(file); |
|
242 |
- GapsInternalState state(file); |
|
275 |
+ gaps::random::load(ar); |
|
276 |
+ |
|
277 |
+ // read needed parameters |
|
278 |
+ unsigned nE = 0, nS = 0, nRow = 0, nCol = 0, nFactor = 0; |
|
279 |
+ ar >> nE; |
|
280 |
+ ar >> nS; |
|
281 |
+ ar >> nRow; |
|
282 |
+ ar >> nCol; |
|
283 |
+ ar >> nFactor; |
|
284 |
+ |
|
285 |
+ // construct empty state of the correct size, populate from file |
|
286 |
+ GapsInternalState state(nE, nS, nRow, nCol, nFactor); |
|
287 |
+ ar >> state; |
|
243 | 288 |
|
244 | 289 |
// run cogaps from this internal state |
245 | 290 |
return runCogaps(state); |
246 | 291 |
} |
247 |
-*/ |
|
292 |
+ |
... | ... |
@@ -5,6 +5,15 @@ |
5 | 5 |
|
6 | 6 |
static const double EPSILON = 1.e-10; |
7 | 7 |
|
8 |
+GibbsSampler::GibbsSampler(unsigned nRow, unsigned nCol, unsigned nFactor) |
|
9 |
+ : |
|
10 |
+mDMatrix(nRow, nCol), mSMatrix(nRow, nCol), mAPMatrix(nRow, nCol), |
|
11 |
+mAMatrix(nRow, nFactor), mPMatrix(nFactor, nCol), mADomain('A', nRow, nFactor), |
|
12 |
+mPDomain('P', nFactor, nCol), mAMeanMatrix(nRow, nFactor), |
|
13 |
+mAStdMatrix(nRow, nFactor), mPMeanMatrix(nFactor, nCol), |
|
14 |
+mPStdMatrix(nFactor, nCol) |
|
15 |
+{} |
|
16 |
+ |
|
8 | 17 |
GibbsSampler::GibbsSampler(Rcpp::NumericMatrix D, Rcpp::NumericMatrix S, |
9 | 18 |
unsigned int nFactor, double alphaA, double alphaP, double maxGibbsMassA, |
10 | 19 |
double maxGibbsMassP, bool singleCellRNASeq, Rcpp::NumericMatrix fixedPat, |
... | ... |
@@ -437,4 +446,50 @@ void GibbsSampler::updateStatistics() |
437 | 446 |
mPStdMatrix.getRow(r) += gaps::algo::squaredScalarDivision(mPMatrix.getRow(r), |
438 | 447 |
normVec(r)); |
439 | 448 |
} |
449 |
+} |
|
450 |
+ |
|
451 |
+void operator<<(Archive &ar, GibbsSampler &sampler) |
|
452 |
+{ |
|
453 |
+ ar << sampler.mDMatrix; |
|
454 |
+ ar << sampler.mSMatrix; |
|
455 |
+ ar << sampler.mAPMatrix; |
|
456 |
+ ar << sampler.mAMatrix; |
|
457 |
+ ar << sampler.mPMatrix; |
|
458 |
+ ar << sampler.mADomain; |
|
459 |
+ ar << sampler.mPDomain; |
|
460 |
+ ar << sampler.mAMeanMatrix; |
|
461 |
+ ar << sampler.mAStdMatrix; |
|
462 |
+ ar << sampler.mPMeanMatrix; |
|
463 |
+ ar << sampler.mPStdMatrix; |
|
464 |
+ ar << sampler.mStatUpdates; |
|
465 |
+ ar << sampler.mMaxGibbsMassA; |
|
466 |
+ ar << sampler.mMaxGibbsMassP; |
|
467 |
+ ar << sampler.mAnnealingTemp; |
|
468 |
+ ar << sampler.mChi2; |
|
469 |
+ ar << sampler.mSingleCellRNASeq; |
|
470 |
+ ar << sampler.mNumFixedPatterns; |
|
471 |
+ ar << sampler.mFixedMat; |
|
472 |
+} |
|
473 |
+ |
|
474 |
+void operator>>(Archive &ar, GibbsSampler &sampler) |
|
475 |
+{ |
|
476 |
+ ar >> sampler.mDMatrix; |
|
477 |
+ ar >> sampler.mSMatrix; |
|
478 |
+ ar >> sampler.mAPMatrix; |
|
479 |
+ ar >> sampler.mAMatrix; |
|
480 |
+ ar >> sampler.mPMatrix; |
|
481 |
+ ar >> sampler.mADomain; |
|
482 |
+ ar >> sampler.mPDomain; |
|
483 |
+ ar >> sampler.mAMeanMatrix; |
|
484 |
+ ar >> sampler.mAStdMatrix; |
|
485 |
+ ar >> sampler.mPMeanMatrix; |
|
486 |
+ ar >> sampler.mPStdMatrix; |
|
487 |
+ ar >> sampler.mStatUpdates; |
|
488 |
+ ar >> sampler.mMaxGibbsMassA; |
|
489 |
+ ar >> sampler.mMaxGibbsMassP; |
|
490 |
+ ar >> sampler.mAnnealingTemp; |
|
491 |
+ ar >> sampler.mChi2; |
|
492 |
+ ar >> sampler.mSingleCellRNASeq; |
|
493 |
+ ar >> sampler.mNumFixedPatterns; |
|
494 |
+ ar >> sampler.mFixedMat; |
|
440 | 495 |
} |
441 | 496 |
\ No newline at end of file |
... | ... |
@@ -36,31 +36,6 @@ public: |
36 | 36 |
unsigned mNumFixedPatterns; |
37 | 37 |
char mFixedMat; |
38 | 38 |
|
39 |
- /*friend class boost::serialization::access; |
|
40 |
- template<class Archive> |
|
41 |
- void serialize(Archive &ar) |
|
42 |
- { |
|
43 |
- ar & mDMatrix; |
|
44 |
- ar & mSMatrix; |
|
45 |
- ar & mAPMatrix; |
|
46 |
- ar & mAMatrix; |
|
47 |
- ar & mPMatrix; |
|
48 |
- ar & mADomain; |
|
49 |
- ar & mPDomain; |
|
50 |
- ar & mAMeanMatrix; |
|
51 |
- ar & mAStdMatrix; |
|
52 |
- ar & mPMeanMatrix; |
|
53 |
- ar & mPStdMatrix; |
|
54 |
- ar & mStatUpdates; |
|
55 |
- ar & mMaxGibbsMassA; |
|
56 |
- ar & mMaxGibbsMassP; |
|
57 |
- ar & mAnnealingTemp; |
|
58 |
- ar & mChi2; |
|
59 |
- ar & mSingleCellRNASeq; |
|
60 |
- ar & mNumFixedPatterns; |
|
61 |
- ar & mFixedMat; |
|
62 |
- }*/ |
|
63 |
- |
|
64 | 39 |
bool death(AtomicSupport &domain, AtomicProposal &proposal); |
65 | 40 |
bool birth(AtomicSupport &domain, AtomicProposal &proposal); |
66 | 41 |
bool move(AtomicSupport &domain, AtomicProposal &proposal); |
... | ... |
@@ -82,7 +57,7 @@ public: |
82 | 57 |
|
83 | 58 |
public: |
84 | 59 |
|
85 |
- GibbsSampler(std::ifstream &file); |
|
60 |
+ GibbsSampler(unsigned nRow, unsigned nCol, unsigned nFactor); |
|
86 | 61 |
GibbsSampler(Rcpp::NumericMatrix D, Rcpp::NumericMatrix S, unsigned nFactor, |
87 | 62 |
double alphaA, double alphaP, double maxGibbsMassA, double maxGibbsMassP, |
88 | 63 |
bool singleCellRNASeq, Rcpp::NumericMatrix fixedPat, char whichMat); |
... | ... |
@@ -102,7 +77,12 @@ public: |
102 | 77 |
|
103 | 78 |
Rcpp::NumericMatrix getNormedMatrix(char mat); |
104 | 79 |
|
105 |
- void serializeAndWrite(const std::ofstream &file); |
|
80 |
+ unsigned nRow() const {return mDMatrix.nRow();} |
|
81 |
+ unsigned nCol() const {return mDMatrix.nCol();} |
|
82 |
+ unsigned nFactor() const {return mAMatrix.nCol();} |
|
83 |
+ |
|
84 |
+ friend void operator<<(Archive &ar, GibbsSampler &sampler); |
|
85 |
+ friend void operator>>(Archive &ar, GibbsSampler &sampler); |
|
106 | 86 |
}; |
107 | 87 |
|
108 | 88 |
#endif |
... | ... |
@@ -50,6 +50,22 @@ void Vector::operator+=(const Vector &vec) |
50 | 50 |
} |
51 | 51 |
} |
52 | 52 |
|
53 |
+void operator<<(Archive &ar, Vector &vec) |
|
54 |
+{ |
|
55 |
+ for (unsigned i = 0; i < vec.size(); ++i) |
|
56 |
+ { |
|
57 |
+ ar << vec[i]; |
|
58 |
+ } |
|
59 |
+} |
|
60 |
+ |
|
61 |
+void operator>>(Archive &ar, Vector &vec) |
|
62 |
+{ |
|
63 |
+ for (unsigned i = 0; i < vec.size(); ++i) |
|
64 |
+ { |
|
65 |
+ ar >> vec.mValues[i]; |
|
66 |
+ } |
|
67 |
+} |
|
68 |
+ |
|
53 | 69 |
/****************************** ROW MATRIX *****************************/ |
54 | 70 |
|
55 | 71 |
RowMatrix::RowMatrix(unsigned nrow, unsigned ncol) |
... | ... |
@@ -93,6 +109,22 @@ Rcpp::NumericMatrix RowMatrix::rMatrix() const |
93 | 109 |
return convertToRMatrix<RowMatrix>(*this); |
94 | 110 |
} |
95 | 111 |
|
112 |
+void operator<<(Archive &ar, RowMatrix &mat) |
|
113 |
+{ |
|
114 |
+ for (unsigned i = 0; i < mat.nRow(); ++i) |
|
115 |
+ { |
|
116 |
+ ar << mat.mRows[i]; |
|
117 |
+ } |
|
118 |
+} |
|
119 |
+ |
|
120 |
+void operator>>(Archive &ar, RowMatrix &mat) |
|
121 |
+{ |
|
122 |
+ for (unsigned i = 0; i < mat.nRow(); ++i) |
|
123 |
+ { |
|
124 |
+ ar >> mat.mRows[i]; |
|
125 |
+ } |
|
126 |
+} |
|
127 |
+ |
|
96 | 128 |
/**************************** COLUMN MATRIX ****************************/ |
97 | 129 |
|
98 | 130 |
ColMatrix::ColMatrix(unsigned nrow, unsigned ncol) |
... | ... |
@@ -135,3 +167,33 @@ Rcpp::NumericMatrix ColMatrix::rMatrix() const |
135 | 167 |
{ |
136 | 168 |
return convertToRMatrix<ColMatrix>(*this); |
137 | 169 |
} |
170 |
+ |
|
171 |
+void operator<<(Archive &ar, ColMatrix &mat) |
|
172 |
+{ |
|
173 |
+ for (unsigned j = 0; j < mat.nCol(); ++j) |
|
174 |
+ { |
|
175 |
+ ar << mat.mCols[j]; |
|
176 |
+ } |
|
177 |
+} |
|
178 |
+ |
|
179 |
+void operator>>(Archive &ar, ColMatrix &mat) |
|
180 |
+{ |
|
181 |
+ for (unsigned j = 0; j < mat.nCol(); ++j) |
|
182 |
+ { |
|
183 |
+ ar >> mat.mCols[j]; |
|
184 |
+ } |
|
185 |
+} |
|
186 |
+ |
|
187 |
+/**************************** TWO-WAY MATRIX ***************************/ |
|
188 |
+ |
|
189 |
+void operator<<(Archive &ar, TwoWayMatrix &mat) |
|
190 |
+{ |
|
191 |
+ ar << mat.mRowMatrix; |
|
192 |
+ ar << mat.mColMatrix; |
|
193 |
+} |
|
194 |
+ |
|
195 |
+void operator>>(Archive &ar, TwoWayMatrix &mat) |
|
196 |
+{ |
|
197 |
+ ar >> mat.mRowMatrix; |
|
198 |
+ ar >> mat.mColMatrix; |
|
199 |
+} |
|
138 | 200 |
\ No newline at end of file |
... | ... |
@@ -1,9 +1,10 @@ |
1 | 1 |
#ifndef __COGAPS_MATRIX_H__ |
2 | 2 |
#define __COGAPS_MATRIX_H__ |
3 | 3 |
|
4 |
+#include "Archive.h" |
|
5 |
+ |
|
4 | 6 |
#include <Rcpp.h> |
5 | 7 |
#include <vector> |
6 |
-//#include <boost/serialization/vector.hpp> |
|
7 | 8 |
|
8 | 9 |
// temporary: used for testing performance of float vs double |
9 | 10 |
typedef double matrix_data_t; |
... | ... |
@@ -33,22 +34,12 @@ struct MatrixChange |
33 | 34 |
{} |
34 | 35 |
}; |
35 | 36 |
|
36 |
-// no polymorphism to prevent virtual function overhead, not really |
|
37 |
-// needed anyways since few functions are used on all types of matrices |
|
38 |
- |
|
39 | 37 |
class Vector |
40 | 38 |
{ |
41 | 39 |
private: |
42 | 40 |
|
43 | 41 |
std::vector<matrix_data_t> mValues; |
44 | 42 |
|
45 |
- /*friend class boost::serialization::access; |
|
46 |
- template<class Archive> |
|
47 |
- void serialize(Archive &ar) |
|
48 |
- { |
|
49 |
- ar & mValues; |
|
50 |
- }*/ |
|
51 |
- |
|
52 | 43 |
public: |
53 | 44 |
|
54 | 45 |
Vector(unsigned size) : mValues(std::vector<matrix_data_t>(size, 0.0)) {} |
... | ... |
@@ -64,6 +55,9 @@ public: |
64 | 55 |
Rcpp::NumericVector rVec() const {return Rcpp::wrap(mValues);} |
65 | 56 |
void concat(const Vector& vec); |
66 | 57 |
void operator+=(const Vector &vec); |
58 |
+ |
|
59 |
+ friend void operator<<(Archive &ar, Vector &vec); |
|
60 |
+ friend void operator>>(Archive &ar, Vector &vec); |
|
67 | 61 |
}; |
68 | 62 |
|
69 | 63 |
class RowMatrix |
... | ... |
@@ -73,15 +67,6 @@ private: |
73 | 67 |
std::vector<Vector> mRows; |
74 | 68 |
unsigned mNumRows, mNumCols; |
75 | 69 |
|
76 |
- /*friend class boost::serialization::access; |
|
77 |
- template<class Archive> |
|
78 |
- void serialize(Archive &ar) |
|
79 |
- { |
|
80 |
- ar & mRows; |
|
81 |
- ar & mNumRows; |
|
82 |
- ar & mNumCols; |
|
83 |
- }*/ |
|
84 |
- |
|
85 | 70 |
public: |
86 | 71 |
|
87 | 72 |
RowMatrix(unsigned nrow, unsigned ncol); |
... | ... |
@@ -99,6 +84,9 @@ public: |
99 | 84 |
|
100 | 85 |
void update(const MatrixChange &change); |
101 | 86 |
Rcpp::NumericMatrix rMatrix() const; |
87 |
+ |
|
88 |
+ friend void operator<<(Archive &ar, RowMatrix &mat); |
|
89 |
+ friend void operator>>(Archive &ar, RowMatrix &mat); |
|
102 | 90 |
}; |
103 | 91 |
|
104 | 92 |
class ColMatrix |
... | ... |
@@ -108,15 +96,6 @@ private: |
108 | 96 |
std::vector<Vector> mCols; |
109 | 97 |
unsigned mNumRows, mNumCols; |
110 | 98 |
|
111 |
- /*friend class boost::serialization::access; |
|
112 |
- template<class Archive> |
|
113 |
- void serialize(Archive &ar) |
|
114 |
- { |
|
115 |
- ar & mCols; |
|
116 |
- ar & mNumRows; |
|
117 |
- ar & mNumCols; |
|
118 |
- }*/ |
|
119 |
- |
|
120 | 99 |
public: |
121 | 100 |
|
122 | 101 |
ColMatrix(unsigned nrow, unsigned ncol); |
... | ... |
@@ -134,6 +113,9 @@ public: |
134 | 113 |
|
135 | 114 |
void update(const MatrixChange &change); |
136 | 115 |
Rcpp::NumericMatrix rMatrix() const; |
116 |
+ |
|
117 |
+ friend void operator<<(Archive &ar, ColMatrix &mat); |
|
118 |
+ friend void operator>>(Archive &ar, ColMatrix &mat); |
|
137 | 119 |
}; |
138 | 120 |
|
139 | 121 |
// gain performance at the expense of memory |
... | ... |
@@ -144,14 +126,6 @@ private: |
144 | 126 |
RowMatrix mRowMatrix; |
145 | 127 |
ColMatrix mColMatrix; |
146 | 128 |
|
147 |
- /*friend class boost::serialization::access; |
|
148 |
- template<class Archive> |
|
149 |
- void serialize(Archive &ar) |
|
150 |
- { |
|
151 |
- ar & mRowMatrix; |
|
152 |
- ar & mColMatrix; |
|
153 |
- }*/ |
|
154 |
- |
|
155 | 129 |
public: |
156 | 130 |
|
157 | 131 |
TwoWayMatrix(unsigned nrow, unsigned ncol) |
... | ... |
@@ -165,12 +139,6 @@ public: |
165 | 139 |
unsigned nRow() const {return mRowMatrix.nRow();} |
166 | 140 |
unsigned nCol() const {return mRowMatrix.nCol();} |
167 | 141 |
|
168 |
- // TODO remove since accessing this way defeats the purpose |
|
169 |
- /*matrix_data_t operator()(unsigned r, unsigned c) const |
|
170 |
- { |
|
171 |
- return mRowMatrix(r,c); |
|
172 |
- }*/ |
|
173 |
- |
|
174 | 142 |
const Vector& getRow(unsigned row) const {return mRowMatrix.getRow(row);} |
175 | 143 |
const Vector& getCol(unsigned col) const {return mColMatrix.getCol(col);} |
176 | 144 |
|
... | ... |
@@ -184,6 +152,9 @@ public: |
184 | 152 |
{ |
185 | 153 |
return mRowMatrix.rMatrix(); |
186 | 154 |
} |
155 |
+ |
|
156 |
+ friend void operator<<(Archive &ar, TwoWayMatrix &mat); |
|
157 |
+ friend void operator>>(Archive &ar, TwoWayMatrix &mat); |
|
187 | 158 |
}; |
188 | 159 |
|
189 | 160 |
#endif |
190 | 161 |
\ No newline at end of file |
... | ... |
@@ -10,21 +10,26 @@ |
10 | 10 |
#include <boost/math/distributions/normal.hpp> |
11 | 11 |
#include <boost/math/distributions/exponential.hpp> |
12 | 12 |
|
13 |
+#include <boost/random/mersenne_twister.hpp> |
|
14 |
+ |
|
13 | 15 |
#include <stdint.h> |
14 | 16 |
|
15 | 17 |
#define Q_GAMMA_THRESHOLD 1E-6 |
16 | 18 |
#define Q_GAMMA_MIN_VALUE 0.0 |
17 | 19 |
|
20 |
+typedef boost::random::mt19937 RNGType; |
|
21 |
+//typedef boost::random::mt11213b RNGType; // should be faster |
|
22 |
+ |
|
18 | 23 |
static RNGType rng; |
19 | 24 |
|
20 |
-RNGType gaps::random::getGenerator() |
|
25 |
+void gaps::random::save(Archive &ar) |
|
21 | 26 |
{ |
22 |
- return rng; |
|
27 |
+ ar << rng; |
|
23 | 28 |
} |
24 | 29 |
|
25 |
-void gaps::random::setGenerator(RNGType temp) |
|
30 |
+void gaps::random::load(Archive &ar) |
|
26 | 31 |
{ |
27 |
- rng = temp; |
|
32 |
+ ar >> rng; |
|
28 | 33 |
} |
29 | 34 |
|
30 | 35 |
void gaps::random::setSeed(uint32_t seed) |
... | ... |
@@ -1,13 +1,11 @@ |
1 | 1 |
#ifndef __COGAPS_RANDOM_H__ |
2 | 2 |
#define __COGAPS_RANDOM_H__ |
3 | 3 |
|
4 |
+#include "Archive.h" |
|
5 |
+ |
|
4 | 6 |
#include <stdint.h> |
5 | 7 |
#include <vector> |
6 | 8 |
#include <fstream> |
7 |
-#include <boost/random/mersenne_twister.hpp> |
|
8 |
- |
|
9 |
-typedef boost::random::mt19937 RNGType; |
|
10 |
-//typedef boost::random::mt11213b RNGType; // should be faster |
|
11 | 9 |
|
12 | 10 |
namespace gaps |
13 | 11 |
{ |
... | ... |
@@ -33,8 +31,8 @@ namespace random |
33 | 31 |
double q_norm(double q, double mean, double sd); |
34 | 32 |
double p_norm(double p, double mean, double sd); |
35 | 33 |
|
36 |
- RNGType getGenerator(); |
|
37 |
- void setGenerator(RNGType rng); |
|
34 |
+ void save(Archive &ar); |
|
35 |
+ void load(Archive &ar); |
|
38 | 36 |
} |
39 | 37 |
|
40 | 38 |
} |
... | ... |
@@ -32,6 +32,17 @@ BEGIN_RCPP |
32 | 32 |
return rcpp_result_gen; |
33 | 33 |
END_RCPP |
34 | 34 |
} |
35 |
+// cogapsFromCheckpoint |
|
36 |
+Rcpp::List cogapsFromCheckpoint(const std::string& fileName); |
|
37 |
+RcppExport SEXP _CoGAPS_cogapsFromCheckpoint(SEXP fileNameSEXP) { |
|
38 |
+BEGIN_RCPP |
|
39 |
+ Rcpp::RObject rcpp_result_gen; |
|
40 |
+ Rcpp::RNGScope rcpp_rngScope_gen; |
|
41 |
+ Rcpp::traits::input_parameter< const std::string& >::type fileName(fileNameSEXP); |
|
42 |
+ rcpp_result_gen = Rcpp::wrap(cogapsFromCheckpoint(fileName)); |
|
43 |
+ return rcpp_result_gen; |
|
44 |
+END_RCPP |
|
45 |
+} |
|
35 | 46 |
// run_catch_unit_tests |
36 | 47 |
int run_catch_unit_tests(); |
37 | 48 |
RcppExport SEXP _CoGAPS_run_catch_unit_tests() { |
... | ... |
@@ -45,6 +56,7 @@ END_RCPP |
45 | 56 |
|
46 | 57 |
static const R_CallMethodDef CallEntries[] = { |
47 | 58 |
{"_CoGAPS_cogaps", (DL_FUNC) &_CoGAPS_cogaps, 17}, |
59 |
+ {"_CoGAPS_cogapsFromCheckpoint", (DL_FUNC) &_CoGAPS_cogapsFromCheckpoint, 1}, |
|
48 | 60 |
{"_CoGAPS_run_catch_unit_tests", (DL_FUNC) &_CoGAPS_run_catch_unit_tests, 0}, |
49 | 61 |
{NULL, NULL, 0} |
50 | 62 |
}; |