Browse code

unit tests for checkpoints

sherman5 authored on 10/01/2018 17:37:42
Showing10 changed files

... ...
@@ -31,4 +31,5 @@ update_remotes.sh
31 31
 inst/profiling/callgrind.out.*
32 32
 inst/profiling/vgcore.*
33 33
 inst/profiling/valgrind_out.txt
34
->>>>>>> develop
34
+
35
+*.temp
... ...
@@ -66,7 +66,6 @@ gapsRun <- function(D, S, ABins = data.frame(), PBins = data.frame(),
66 66
                     alphaP = 0.01, nMaxP = 100000, max_gibbmass_paraP = 100.0,
67 67
                     seed=-1, messages=TRUE, singleCellRNASeq=FALSE,
68 68
                     fixedPatterns = matrix(0), whichMatrixFixed = 'N',
69
-                    checkpoint_file_name = "gaps_checkpoint.out",
70 69
                     checkpoint_interval = 0)
71 70
 {
72 71
     # Floor the parameters that are integers to prevent allowing doubles.
... ...
@@ -13,8 +13,7 @@ gapsRun(D, S, ABins = data.frame(), PBins = data.frame(), nFactor = 7,
13 13
   max_gibbmass_paraA = 100, alphaP = 0.01, nMaxP = 1e+05,
14 14
   max_gibbmass_paraP = 100, seed = -1, messages = TRUE,
15 15
   singleCellRNASeq = FALSE, fixedPatterns = matrix(0),
16
-  whichMatrixFixed = "N", checkpoint_file_name = "gaps_checkpoint.out",
17
-  checkpoint_interval = 0)
16
+  whichMatrixFixed = "N", checkpoint_interval = 0)
18 17
 }
19 18
 \arguments{
20 19
 \item{D}{data matrix}
... ...
@@ -71,9 +70,9 @@ domain for relative probabilities}
71 70
 \item{whichMatrixFixed}{character to indicate whether A or P matrix
72 71
 contains the fixed patterns}
73 72
 
74
-\item{checkpoint_file_name}{name of file to store checkpoint}
75
-
76 73
 \item{checkpoint_interval}{time (in seconds) between cogaps checkpoints}
74
+
75
+\item{checkpoint_file_name}{name of file to store checkpoint}
77 76
 }
78 77
 \description{
79 78
 \code{gapsRun} calls the C++ MCMC code and performs Bayesian
... ...
@@ -98,13 +98,12 @@ public:
98 98
     double totalMass() const {return mTotalMass;}
99 99
     uint64_t numAtoms() const {return mNumAtoms;}
100 100
     double at(uint64_t loc) const {return mAtomicDomain.at(loc);}
101
+    double test_at(uint64_t loc) const {return mAtomicDomain.count(loc) ? mAtomicDomain.at(loc) : 0.0;}
101 102
 
102 103
     // setters
103 104
     void setAlpha(double alpha) {mAlpha = alpha;}
104 105
     void setLambda(double lambda) {mLambda = lambda;}
105 106
 
106
-    void serializeAndWrite(const std::ofstream &file);
107
-
108 107
     friend void operator<<(Archive &ar, AtomicSupport &sampler);
109 108
     friend void operator>>(Archive &ar, AtomicSupport &sampler);
110 109
 };
... ...
@@ -10,6 +10,11 @@
10 10
 #include <boost/archive/text_oarchive.hpp>
11 11
 #include <boost/archive/text_iarchive.hpp>
12 12
 
13
+// no C++11 std::to_string
14
+#include <sstream>
15
+#define SSTR( x ) static_cast< std::ostringstream & >( \
16
+        ( std::ostringstream() << std::dec << x ) ).str()
17
+
13 18
 #define ARCHIVE_MAGIC_NUM 0xCE45D32A
14 19
 
15 20
 namespace bpt = boost::posix_time;
... ...
@@ -18,9 +23,11 @@ static bpt::ptime lastCheckpoint;
18 23
 
19 24
 static void createCheckpoint(GapsInternalState &state)
20 25
 {
26
+    state.numCheckpoints++;
21 27
     std::cout << "creating gaps checkpoint...";
22 28
     bpt::ptime start = bpt::microsec_clock::local_time();
23
-    Archive ar("gaps_checkpoint.out", ARCHIVE_WRITE);
29
+    Archive ar("gaps_checkpoint_" + SSTR(state.numCheckpoints) + ".out",
30
+        ARCHIVE_WRITE);
24 31
     ar << ARCHIVE_MAGIC_NUM;
25 32
     gaps::random::save(ar);
26 33
     ar << state.nEquil;
... ...
@@ -41,6 +41,7 @@ struct GapsInternalState
41 41
     uint32_t seed;
42 42
 
43 43
     long checkpointInterval;
44
+    unsigned numCheckpoints;
44 45
 
45 46
     GibbsSampler sampler;
46 47
 
... ...
@@ -60,7 +61,7 @@ struct GapsInternalState
60 61
         nIterA(10), nIterP(10), nEquil(nE), nEquilCool(nEC), nSample(nS),
61 62
         nSnapshots(numSnapshots), nOutputs(numOutputs), messages(msg),
62 63
         iter(0), phase(GAPS_CALIBRATION), seed(in_seed),
63
-        checkpointInterval(cptInterval),
64
+        checkpointInterval(cptInterval), numCheckpoints(0),
64 65
         sampler(DMatrix, SMatrix, nFactor, alphaA, alphaP,
65 66
             maxGibbsMassA, maxGibbsMassP, singleCellRNASeq, fixedPatterns,
66 67
             whichMatrixFixed)
... ...
@@ -75,7 +76,7 @@ struct GapsInternalState
75 76
     {}
76 77
 };
77 78
 
78
-void operator<<(Archive &ar, GapsInternalState &state)
79
+inline void operator<<(Archive &ar, GapsInternalState &state)
79 80
 {
80 81
     ar << state.chi2VecEquil;
81 82
     ar << state.nAtomsAEquil;
... ...
@@ -95,10 +96,11 @@ void operator<<(Archive &ar, GapsInternalState &state)
95 96
     ar << state.phase;
96 97
     ar << state.seed;
97 98
     ar << state.checkpointInterval;
99
+    ar << state.numCheckpoints;
98 100
     ar << state.sampler;
99 101
 }
100 102
 
101
-void operator>>(Archive &ar, GapsInternalState &state)
103
+inline void operator>>(Archive &ar, GapsInternalState &state)
102 104
 {
103 105
     ar >> state.chi2VecEquil;
104 106
     ar >> state.nAtomsAEquil;
... ...
@@ -117,7 +119,8 @@ void operator>>(Archive &ar, GapsInternalState &state)
117 119
     ar >> state.iter;
118 120
     ar >> state.phase;
119 121
     ar >> state.seed;
120
-    ar << state.checkpointInterval;
122
+    ar >> state.checkpointInterval;
123
+    ar >> state.numCheckpoints;
121 124
     ar >> state.sampler;
122 125
 }
123 126
 
... ...
@@ -11,5 +11,6 @@ OBJECTS =   Algorithms.o \
11 11
             cpp_tests/testAtomicSupport.o \
12 12
             cpp_tests/testGibbsSampler.o \
13 13
             cpp_tests/testMatrix.o \
14
-            cpp_tests/testRandom.o
14
+            cpp_tests/testRandom.o \
15
+            cpp_tests/testSerialization.o
15 16
 
... ...
@@ -126,14 +126,10 @@ TEST_CASE("Internal AtomicSupport Tests")
126 126
     {
127 127
         for (unsigned i = 0; i < 10000; ++i)
128 128
         {
129
-            REQUIRE(Adomain.getRow(gaps::random::uniform64()) >= 0);
130 129
             REQUIRE(Adomain.getRow(gaps::random::uniform64()) < nrow);
131
-            REQUIRE(Adomain.getCol(gaps::random::uniform64()) >= 0);
132 130
             REQUIRE(Adomain.getCol(gaps::random::uniform64()) < ncol);
133 131
 
134
-            REQUIRE(Pdomain.getRow(gaps::random::uniform64()) >= 0);
135 132
             REQUIRE(Pdomain.getRow(gaps::random::uniform64()) < nrow);
136
-            REQUIRE(Pdomain.getCol(gaps::random::uniform64()) >= 0);
137 133
             REQUIRE(Pdomain.getCol(gaps::random::uniform64()) < ncol);            
138 134
         }
139 135
     }
... ...
@@ -30,9 +30,6 @@ TEST_CASE("Test Random.h - Random Number Generation")
30 30
 
31 31
     SECTION("Test uniform distribution over general interval")
32 32
     {
33
-        // invalid bounds
34
-        REQUIRE_THROWS(gaps::random::uniform(2.0, 1.4));
35
-
36 33
         // bounds equal
37 34
         REQUIRE(gaps::random::uniform(4.3,4.3) == 4.3);
38 35
 
... ...
@@ -0,0 +1,219 @@
1
+#include "catch.h"
2
+#include "../Archive.h"
3
+#include "../Matrix.h"
4
+#include "../AtomicSupport.h"
5
+#include "../GibbsSampler.h"
6
+#include "../InternalState.h"
7
+#include "../Random.h"
8
+
9
+TEST_CASE("Test Archive.h")
10
+{
11
+    SECTION("Reading/Writing to an Archive")
12
+    {
13
+        Archive ar1("test_ar.temp", ARCHIVE_WRITE);
14
+        ar1 << 3;
15
+        ar1.close();
16
+
17
+        Archive ar2("test_ar.temp", ARCHIVE_READ);
18
+        unsigned i = 0;
19
+        ar2 >> i;
20
+        REQUIRE(i == 3);
21
+        ar2.close();
22
+    }
23
+
24
+    SECTION("Serialization of primitive types")
25
+    {
26
+        // test values
27
+        unsigned u_read = 0, u_write = 456;
28
+        uint32_t u32_read = 0, u32_write = 512;
29
+        uint64_t u64_read = 0, u64_write = 0xAABBCCDDEE;
30
+        float f_read = 0.f, f_write = 0.123542f;
31
+        double d_read = 0., d_write = 0.54362;
32
+        bool b_read = false, b_write = true;
33
+
34
+        // write to archive
35
+        Archive arWrite("test_ar.temp", ARCHIVE_WRITE);
36
+        arWrite << u_write;
37
+        arWrite << u32_write;
38
+        arWrite << u64_write;
39
+        arWrite << f_write;
40
+        arWrite << d_write;
41
+        arWrite << b_write;
42
+        arWrite.close();
43
+
44
+        // read from archive
45
+        Archive arRead("test_ar.temp", ARCHIVE_READ);
46
+        arRead >> u_read;
47
+        arRead >> u32_read;
48
+        arRead >> u64_read;
49
+        arRead >> f_read;
50
+        arRead >> d_read;
51
+        arRead >> b_read;
52
+        arRead.close();
53
+
54
+        // test that values are the same
55
+        REQUIRE(u_read == u_write);
56
+        REQUIRE(u32_read == u32_write);
57
+        REQUIRE(u64_read == u64_write);
58
+        REQUIRE(f_read == f_write);
59
+        REQUIRE(d_read == d_write);
60
+        REQUIRE(b_read == b_write);
61
+    }
62
+    
63
+    SECTION("Vector Serialization")
64
+    {
65
+        Vector vec_read(100), vec_write(100);
66
+        for (unsigned i = 0; i < 100; ++i)
67
+        {
68
+            vec_write[i] = gaps::random::normal(0.0, 2.0);
69
+        }
70
+
71
+        Archive arWrite("test_ar.temp", ARCHIVE_WRITE);
72
+        arWrite << vec_write;
73
+        arWrite.close();
74
+
75
+        Archive arRead("test_ar.temp", ARCHIVE_READ);
76
+        arRead >> vec_read;
77
+        arRead.close();
78
+
79
+        REQUIRE(vec_read.size() == vec_write.size());
80
+
81
+        for (unsigned i = 0; i < 100; ++i)
82
+        {
83
+            REQUIRE(vec_read[i] == vec_write[i]);
84
+        }
85
+    }
86
+
87
+    SECTION("Matrix Serialization")
88
+    {
89
+        RowMatrix rMat_read(100,100), rMat_write(100,100);
90
+        ColMatrix cMat_read(100,100), cMat_write(100,100);
91
+        TwoWayMatrix twMat_read(100,100), twMat_write(100,100);
92
+
93
+        for (unsigned i = 0; i < 100; ++i)
94
+        {
95
+            for (unsigned j = 0; j < 100; ++j)
96
+            {
97
+                rMat_write(i,j) = gaps::random::normal(0.0, 2.0);
98
+                cMat_write(i,j) = gaps::random::normal(0.0, 2.0);
99
+                twMat_write.set(i,j,gaps::random::normal(0.0, 2.0));
100
+            }
101
+        }
102
+
103
+        Archive arWrite("test_ar.temp", ARCHIVE_WRITE);
104
+        arWrite << rMat_write;
105
+        arWrite << cMat_write;
106
+        arWrite << twMat_write;
107
+        arWrite.close();
108
+
109
+        Archive arRead("test_ar.temp", ARCHIVE_READ);
110
+        arRead >> rMat_read;
111
+        arRead >> cMat_read;
112
+        arRead >> twMat_read;
113
+        arRead.close();
114
+
115
+        REQUIRE(rMat_read.nRow() == rMat_write.nRow());
116
+        REQUIRE(rMat_read.nCol() == rMat_write.nCol());
117
+        REQUIRE(cMat_read.nRow() == cMat_write.nRow());
118
+        REQUIRE(cMat_read.nCol() == cMat_write.nCol());
119
+        REQUIRE(twMat_read.nRow() == twMat_write.nRow());
120
+        REQUIRE(twMat_read.nCol() == twMat_write.nCol());
121
+    
122
+        for (unsigned i = 0; i < 100; ++i)
123
+        {
124
+            for (unsigned j = 0; j < 100; ++j)
125
+            {
126
+                REQUIRE(rMat_read(i,j) == rMat_write(i,j));
127
+                REQUIRE(cMat_read(i,j) == cMat_write(i,j));
128
+                REQUIRE(twMat_read.getRow(i)[j] == twMat_write.getRow(i)[j]);
129
+            }
130
+        }
131
+    }
132
+    
133
+    SECTION("Atomic Serialization")
134
+    {
135
+        AtomicSupport domain_read('A',100,100), domain_write('A',100,100);
136
+        std::vector<uint64_t> locations;
137
+        for (unsigned i = 0; i < 10000; ++i)
138
+        {
139
+            AtomicProposal prop = domain_write.makeProposal();
140
+            locations.push_back(prop.pos1);
141
+            locations.push_back(prop.pos2);
142
+            domain_write.acceptProposal(prop);
143
+        }
144
+
145
+        Archive arWrite("test_ar.temp", ARCHIVE_WRITE);
146
+        arWrite << domain_write;
147
+        arWrite.close();
148
+    
149
+        Archive arRead("test_ar.temp", ARCHIVE_READ);
150
+        arRead >> domain_read;
151
+        arRead.close();
152
+
153
+        REQUIRE(domain_read.alpha() == domain_write.alpha());
154
+        REQUIRE(domain_read.lambda() == domain_write.lambda());
155
+        REQUIRE(domain_read.totalMass() == domain_write.totalMass());
156
+        REQUIRE(domain_read.numAtoms() == domain_write.numAtoms());
157
+
158
+#ifdef GAPS_INTERNAL_TESTS
159
+        std::map<uint64_t, double>::iterator readIter, writeIter;
160
+        readIter = domain_read.mAtomicDomain.begin();
161
+        writeIter = domain_write.mAtomicDomain.begin();        
162
+        while (readIter != domain_read.mAtomicDomain.end())
163
+        {
164
+            REQUIRE(readIter->first == writeIter->first);   
165
+            REQUIRE(readIter->second == writeIter->second);
166
+            ++readIter;
167
+            ++writeIter;
168
+        }
169
+#endif
170
+    }
171
+
172
+    SECTION("GibbsSampler Serialization")
173
+    {
174
+
175
+    }
176
+
177
+    SECTION("GapsInternalState Serialization")
178
+    {
179
+
180
+    }
181
+
182
+    SECTION("Random Generator Serialization")
183
+    {
184
+        std::vector<double> randSequence;
185
+
186
+        gaps::random::setSeed(123);
187
+        volatile double burn_in = 0.0;
188
+        for (unsigned i = 0; i < 1000; ++i)
189
+        {
190
+            burn_in = gaps::random::uniform(0,1);
191
+        }
192
+        REQUIRE(burn_in < 1);
193
+
194
+        Archive arWrite("test_ar.temp", ARCHIVE_WRITE);
195
+        gaps::random::save(arWrite);
196
+        arWrite.close();
197
+
198
+        for (unsigned i = 0; i < 1000; ++i)
199
+        {
200
+            randSequence.push_back(gaps::random::uniform());
201
+            randSequence.push_back(gaps::random::uniform(0.3, 6.4));
202
+            randSequence.push_back(gaps::random::normal(0.0, 2.0));
203
+            randSequence.push_back(gaps::random::exponential(5.5));
204
+        }
205
+    
206
+        gaps::random::setSeed(11);
207
+        Archive arRead("test_ar.temp", ARCHIVE_READ);
208
+        gaps::random::load(arRead);
209
+        arRead.close();
210
+
211
+        for (unsigned i = 0; i < 1000; ++i)
212
+        {
213
+            REQUIRE(gaps::random::uniform() == randSequence[i++]);
214
+            REQUIRE(gaps::random::uniform(0.3, 6.4) == randSequence[i++]);
215
+            REQUIRE(gaps::random::normal(0.0, 2.0) == randSequence[i++]);
216
+            REQUIRE(gaps::random::exponential(5.5) == randSequence[i]);
217
+        }
218
+    }
219
+}
0 220
\ No newline at end of file