... | ... |
@@ -24,6 +24,12 @@ static bool vecContains(const std::vector<Atom*> &vec, uint64_t pos) |
24 | 24 |
return std::binary_search(vec.begin(), vec.end(), &temp, compareAtom); |
25 | 25 |
} |
26 | 26 |
|
27 |
+// check if a position is contained in a sorted vector of positions |
|
28 |
+static bool vecContains(const std::vector<uint64_t> &vec, uint64_t pos) |
|
29 |
+{ |
|
30 |
+ return std::binary_search(vec.begin(), vec.end(), pos); |
|
31 |
+} |
|
32 |
+ |
|
27 | 33 |
// used in debug mode to check if vector is always sorted |
28 | 34 |
static bool isSorted(const std::vector<Atom*> &vec) |
29 | 35 |
{ |
... | ... |
@@ -132,11 +138,16 @@ AtomNeighborhood AtomicDomain::randomAtomWithRightNeighbor(GapsRng *rng) |
132 | 138 |
return AtomNeighborhood(NULL, mAtoms[index], right); |
133 | 139 |
} |
134 | 140 |
|
135 |
-uint64_t AtomicDomain::randomFreePosition(GapsRng *rng) const |
|
141 |
+uint64_t AtomicDomain::randomFreePosition(GapsRng *rng, |
|
142 |
+const std::vector<uint64_t> &possibleDeaths) const |
|
136 | 143 |
{ |
137 | 144 |
uint64_t pos = rng->uniform64(1, mDomainLength); |
138 | 145 |
while (vecContains(mAtoms, pos)) |
139 | 146 |
{ |
147 |
+ if (vecContains(possibleDeaths, pos)) |
|
148 |
+ { |
|
149 |
+ return 0; // might actually be a free position |
|
150 |
+ } |
|
140 | 151 |
pos = rng->uniform64(1, mDomainLength); |
141 | 152 |
} |
142 | 153 |
return pos; |
... | ... |
@@ -163,6 +174,25 @@ void AtomicDomain::erase(uint64_t pos) |
163 | 174 |
delete a; |
164 | 175 |
} |
165 | 176 |
|
177 |
+void AtomicDomain::move(uint64_t src, uint64_t dest) |
|
178 |
+{ |
|
179 |
+ GAPS_ASSERT(size() > 0); |
|
180 |
+ GAPS_ASSERT(vecContains(mAtoms, src)); |
|
181 |
+ |
|
182 |
+ #pragma omp critical(AtomicInsertOrErase) |
|
183 |
+ { |
|
184 |
+ std::vector<Atom*>::iterator it = std::lower_bound(mAtoms.begin(), mAtoms.end(), src, compareAtomLower); |
|
185 |
+ unsigned ndx = std::distance(mAtoms.begin(), it); |
|
186 |
+ while (ndx + 1 < mAtoms.size() && dest > mAtoms[ndx + 1]->pos) |
|
187 |
+ { |
|
188 |
+ Atom* temp = mAtoms[ndx]; |
|
189 |
+ mAtoms[ndx] = mAtoms[ndx + 1]; |
|
190 |
+ mAtoms[ndx + 1] = temp; |
|
191 |
+ ++ndx; |
|
192 |
+ } |
|
193 |
+ } |
|
194 |
+} |
|
195 |
+ |
|
166 | 196 |
Atom* AtomicDomain::insert(uint64_t pos, float mass) |
167 | 197 |
{ |
168 | 198 |
Atom *newAtom = new Atom(pos, mass); |
... | ... |
@@ -33,6 +33,8 @@ struct AtomNeighborhood |
33 | 33 |
bool hasRight(); |
34 | 34 |
}; |
35 | 35 |
|
36 |
+class ProposalQueue; // needed for friend |
|
37 |
+ |
|
36 | 38 |
class AtomicDomain |
37 | 39 |
{ |
38 | 40 |
public: |
... | ... |
@@ -49,12 +51,13 @@ public: |
49 | 51 |
Atom* getLeftNeighbor(uint64_t pos); |
50 | 52 |
Atom* getRightNeighbor(uint64_t pos); |
51 | 53 |
|
52 |
- uint64_t randomFreePosition(GapsRng *rng) const; |
|
54 |
+ uint64_t randomFreePosition(GapsRng *rng, |
|
55 |
+ const std::vector<uint64_t> &possibleDeaths) const; |
|
53 | 56 |
uint64_t size() const; |
54 | 57 |
|
55 | 58 |
// these need to happen concurrently without invalidating pointers |
56 | 59 |
void erase(uint64_t pos); |
57 |
- Atom* insert(uint64_t pos, float mass); |
|
60 |
+ void move(uint64_t src, uint64_t dest); |
|
58 | 61 |
|
59 | 62 |
// serialization |
60 | 63 |
friend Archive& operator<<(Archive &ar, AtomicDomain &domain); |
... | ... |
@@ -64,6 +67,10 @@ public: |
64 | 67 |
private: |
65 | 68 |
#endif |
66 | 69 |
|
70 |
+ // only the proposal queue can insert |
|
71 |
+ friend class ProposalQueue; |
|
72 |
+ Atom* insert(uint64_t pos, float mass); |
|
73 |
+ |
|
67 | 74 |
// size of atomic domain to ensure all bins are equal length |
68 | 75 |
uint64_t mDomainLength; |
69 | 76 |
|
... | ... |
@@ -62,6 +62,7 @@ unsigned getNumPatterns(const Rcpp::List &allParams) |
62 | 62 |
{ |
63 | 63 |
std::string file(Rcpp::as<std::string>(allParams["checkpointInFile"])); |
64 | 64 |
Archive ar(file, ARCHIVE_READ); |
65 |
+ GapsRng::load(ar); |
|
65 | 66 |
ar >> nPatterns; |
66 | 67 |
ar.close(); |
67 | 68 |
} |
... | ... |
@@ -99,6 +100,7 @@ const Rcpp::Nullable<Rcpp::NumericMatrix> &fixedMatrix, bool isMaster) |
99 | 100 |
{ |
100 | 101 |
// calculate essential parameters needed for constructing GapsRunner |
101 | 102 |
const Rcpp::S4 &gapsParams(allParams["gaps"]); |
103 |
+ GapsRng::setSeed(gapsParams.slot("seed")); |
|
102 | 104 |
unsigned nPatterns = getNumPatterns(allParams); // TODO clarify this sets the checkpoint seed as well |
103 | 105 |
bool printThreads = !processDistributedParameters(allParams).first; |
104 | 106 |
bool partitionRows = processDistributedParameters(allParams).second; |
... | ... |
@@ -106,7 +108,7 @@ const Rcpp::Nullable<Rcpp::NumericMatrix> &fixedMatrix, bool isMaster) |
106 | 108 |
|
107 | 109 |
// construct GapsRunner |
108 | 110 |
GapsRunner runner(data, allParams["transposeData"], nPatterns, |
109 |
- partitionRows, cIndices, gapsParams.slot("seed")); |
|
111 |
+ partitionRows, cIndices); |
|
110 | 112 |
|
111 | 113 |
// set uncertainty |
112 | 114 |
if (!isNull(uncertainty)) |
... | ... |
@@ -120,6 +122,7 @@ const Rcpp::Nullable<Rcpp::NumericMatrix> &fixedMatrix, bool isMaster) |
120 | 122 |
{ |
121 | 123 |
std::string file(Rcpp::as<std::string>(allParams["checkpointInFile"])); |
122 | 124 |
Archive ar(file, ARCHIVE_READ); |
125 |
+ GapsRng::load(ar); |
|
123 | 126 |
ar >> runner; |
124 | 127 |
ar.close(); |
125 | 128 |
} |
... | ... |
@@ -134,6 +137,7 @@ const Rcpp::Nullable<Rcpp::NumericMatrix> &fixedMatrix, bool isMaster) |
134 | 137 |
} |
135 | 138 |
|
136 | 139 |
// set parameters that would be saved in the checkpoint |
140 |
+ runner.recordSeed(gapsParams.slot("seed")); |
|
137 | 141 |
runner.setMaxIterations(gapsParams.slot("nIterations")); |
138 | 142 |
runner.setSparsity(gapsParams.slot("alphaA"), |
139 | 143 |
gapsParams.slot("alphaP"), gapsParams.slot("singleCell")); |
... | ... |
@@ -52,8 +52,7 @@ public: |
52 | 52 |
|
53 | 53 |
template <class DataType> |
54 | 54 |
GapsRunner(const DataType &data, bool transposeData, unsigned nPatterns, |
55 |
- bool partitionRows, const std::vector<unsigned> &indices, |
|
56 |
- uint32_t seed); |
|
55 |
+ bool partitionRows, const std::vector<unsigned> &indices); |
|
57 | 56 |
|
58 | 57 |
template <class DataType> |
59 | 58 |
void setUncertainty(const DataType &unc, bool transposeData, |
... | ... |
@@ -61,6 +60,7 @@ public: |
61 | 60 |
|
62 | 61 |
void setFixedMatrix(char which, const Matrix &mat); |
63 | 62 |
|
63 |
+ void recordSeed(uint32_t seed); |
|
64 | 64 |
uint32_t getSeed() const; |
65 | 65 |
|
66 | 66 |
void setMaxIterations(unsigned nIterations); |
... | ... |
@@ -82,8 +82,7 @@ public: |
82 | 82 |
// problem with passing file parser - need to read it twice |
83 | 83 |
template <class DataType> |
84 | 84 |
GapsRunner::GapsRunner(const DataType &data, bool transposeData, |
85 |
-unsigned nPatterns, bool partitionRows, const std::vector<unsigned> &indices, |
|
86 |
-uint32_t seed) |
|
85 |
+unsigned nPatterns, bool partitionRows, const std::vector<unsigned> &indices) |
|
87 | 86 |
: |
88 | 87 |
mASampler(data, !transposeData, nPatterns,!partitionRows, indices), |
89 | 88 |
mPSampler(data, transposeData, nPatterns, partitionRows, indices), |
... | ... |
@@ -91,14 +90,10 @@ mStatistics(mPSampler.dataRows(), mPSampler.dataCols(), nPatterns), |
91 | 90 |
mFixedMatrix('N'), mMaxIterations(1000), mMaxThreads(1), mPrintMessages(true), |
92 | 91 |
mOutputFrequency(500), mCheckpointOutFile("gaps_checkpoint.out"), |
93 | 92 |
mCheckpointInterval(0), mPhase('C'), mCurrentIteration(0), |
94 |
-mNumPatterns(nPatterns), mSeed(seed), mNumUpdatesA(0), mNumUpdatesP(0), |
|
95 |
-mRng(seed) |
|
93 |
+mNumPatterns(nPatterns), mSeed(0), mNumUpdatesA(0), mNumUpdatesP(0) |
|
96 | 94 |
{ |
97 | 95 |
mASampler.sync(mPSampler); |
98 | 96 |
mPSampler.sync(mASampler); |
99 |
- |
|
100 |
- mASampler.setSeed(mRng.uniform64()); |
|
101 |
- mPSampler.setSeed(mRng.uniform64()); |
|
102 | 97 |
} |
103 | 98 |
|
104 | 99 |
template <class DataType> |
... | ... |
@@ -44,11 +44,6 @@ void GibbsSampler::setMatrix(const Matrix &mat) |
44 | 44 |
mMatrix = mat; |
45 | 45 |
} |
46 | 46 |
|
47 |
-void GibbsSampler::setSeed(uint64_t seed) |
|
48 |
-{ |
|
49 |
- mSeeder.seed(seed); |
|
50 |
-} |
|
51 |
- |
|
52 | 47 |
float GibbsSampler::chi2() const |
53 | 48 |
{ |
54 | 49 |
return 2.f * gaps::algo::loglikelihood(mDMatrix, mSMatrix, mAPMatrix); |
... | ... |
@@ -79,26 +74,19 @@ void GibbsSampler::update(unsigned nSteps, unsigned nCores) |
79 | 74 |
mQueue.populate(mDomain, nSteps - n); |
80 | 75 |
n += mQueue.size(); |
81 | 76 |
|
82 |
- // update average queue count |
|
83 |
- #ifdef GAPS_DEBUG |
|
84 |
- mNumQueues += 1.f; |
|
85 |
- mAvgQueue *= (mNumQueues - 1.f) / mNumQueues; |
|
86 |
- mAvgQueue += mQueue.size() / mNumQueues; |
|
87 |
- #endif |
|
88 |
- |
|
89 | 77 |
// process all proposed updates |
90 | 78 |
#pragma omp parallel for num_threads(nCores) |
91 | 79 |
for (unsigned i = 0; i < mQueue.size(); ++i) |
92 | 80 |
{ |
93 |
- processProposal(&mQueue[i]); |
|
81 |
+ processProposal(mQueue[i]); |
|
94 | 82 |
} |
95 | 83 |
mQueue.clear(); |
96 | 84 |
} |
97 | 85 |
} |
98 | 86 |
|
99 |
-void GibbsSampler::processProposal(AtomicProposal *prop) |
|
87 |
+void GibbsSampler::processProposal(const AtomicProposal &prop) |
|
100 | 88 |
{ |
101 |
- switch (prop->type) |
|
89 |
+ switch (prop.type) |
|
102 | 90 |
{ |
103 | 91 |
case 'B': |
104 | 92 |
birth(prop); |
... | ... |
@@ -117,45 +105,40 @@ void GibbsSampler::processProposal(AtomicProposal *prop) |
117 | 105 |
|
118 | 106 |
// add an atom at a random position, calculate mass either with an |
119 | 107 |
// exponential distribution or with the gibbs mass distribution |
120 |
-void GibbsSampler::birth(AtomicProposal *prop) |
|
108 |
+void GibbsSampler::birth(const AtomicProposal &prop) |
|
121 | 109 |
{ |
122 |
- unsigned row = getRow(prop->atom1->pos); |
|
123 |
- unsigned col = getCol(prop->atom1->pos); |
|
124 |
- |
|
125 | 110 |
// calculate proposed mass |
126 |
- float mass = canUseGibbs(col) |
|
127 |
- ? gibbsMass(alphaParameters(row, col), &(prop->rng)).value() |
|
128 |
- : prop->rng.exponential(mLambda); |
|
111 |
+ float mass = canUseGibbs(prop.c1) |
|
112 |
+ ? gibbsMass(alphaParameters(prop.r1, prop.c1), &(prop.rng)).value() |
|
113 |
+ : prop.rng.exponential(mLambda); |
|
129 | 114 |
|
130 | 115 |
// accept mass as long as it's non-zero |
131 | 116 |
if (mass >= gaps::epsilon) |
132 | 117 |
{ |
133 | 118 |
mQueue.acceptBirth(); |
134 |
- prop->atom1->mass = mass; |
|
135 |
- changeMatrix(row, col, mass); |
|
119 |
+ prop.atom1->mass = mass; |
|
120 |
+ changeMatrix(prop.r1, prop.c1, mass); |
|
136 | 121 |
} |
137 | 122 |
else |
138 | 123 |
{ |
139 | 124 |
mQueue.rejectBirth(); |
140 |
- mDomain.erase(prop->atom1->pos); |
|
125 |
+ mDomain.erase(prop.atom1->pos); |
|
141 | 126 |
} |
142 | 127 |
} |
143 | 128 |
|
144 | 129 |
// automatically accept death, attempt a rebirth at the same position, using |
145 | 130 |
// the original mass or the gibbs mass distribution |
146 |
-void GibbsSampler::death(AtomicProposal *prop) |
|
131 |
+void GibbsSampler::death(const AtomicProposal &prop) |
|
147 | 132 |
{ |
148 |
- unsigned row = getRow(prop->atom1->pos); |
|
149 |
- unsigned col = getCol(prop->atom1->pos); |
|
150 |
- |
|
151 | 133 |
// calculate alpha parameters assuming atom dies |
152 |
- AlphaParameters alpha = alphaParametersWithChange(row, col, -prop->atom1->mass); |
|
134 |
+ AlphaParameters alpha = alphaParametersWithChange(prop.r1, prop.c1, |
|
135 |
+ -prop.atom1->mass); |
|
153 | 136 |
|
154 | 137 |
// try to calculate rebirth mass using gibbs distribution, otherwise exponential |
155 |
- float rebirthMass = prop->atom1->mass; |
|
156 |
- if (canUseGibbs(col)) |
|
138 |
+ float rebirthMass = prop.atom1->mass; |
|
139 |
+ if (canUseGibbs(prop.c1)) |
|
157 | 140 |
{ |
158 |
- OptionalFloat gMass = gibbsMass(alpha, &(prop->rng)); |
|
141 |
+ OptionalFloat gMass = gibbsMass(alpha, &(prop.rng)); |
|
159 | 142 |
if (gMass.hasValue()) |
160 | 143 |
{ |
161 | 144 |
rebirthMass = gMass.value(); |
... | ... |
@@ -164,84 +147,73 @@ void GibbsSampler::death(AtomicProposal *prop) |
164 | 147 |
|
165 | 148 |
// accept/reject rebirth |
166 | 149 |
float deltaLL = getDeltaLL(alpha, rebirthMass) * mAnnealingTemp; |
167 |
- if (std::log(prop->rng.uniform()) < deltaLL) |
|
150 |
+ if (std::log(prop.rng.uniform()) < deltaLL) |
|
168 | 151 |
{ |
169 | 152 |
mQueue.rejectDeath(); |
170 |
- if (rebirthMass != prop->atom1->mass) |
|
153 |
+ if (rebirthMass != prop.atom1->mass) |
|
171 | 154 |
{ |
172 |
- safelyChangeMatrix(row, col, rebirthMass - prop->atom1->mass); |
|
155 |
+ safelyChangeMatrix(prop.r1, prop.c1, rebirthMass - prop.atom1->mass); |
|
173 | 156 |
} |
174 |
- prop->atom1->mass = rebirthMass; |
|
157 |
+ prop.atom1->mass = rebirthMass; |
|
175 | 158 |
} |
176 | 159 |
else |
177 | 160 |
{ |
178 | 161 |
mQueue.acceptDeath(); |
179 |
- safelyChangeMatrix(row, col, -prop->atom1->mass); |
|
180 |
- mDomain.erase(prop->atom1->pos); |
|
162 |
+ safelyChangeMatrix(prop.r1, prop.c1, -prop.atom1->mass); |
|
163 |
+ mDomain.erase(prop.atom1->pos); |
|
181 | 164 |
} |
182 | 165 |
} |
183 | 166 |
|
184 | 167 |
// move mass from src to dest in the atomic domain |
185 |
-void GibbsSampler::move(AtomicProposal *prop) |
|
168 |
+void GibbsSampler::move(const AtomicProposal &prop) |
|
186 | 169 |
{ |
187 |
- unsigned r1 = getRow(prop->atom1->pos); |
|
188 |
- unsigned c1 = getCol(prop->atom1->pos); |
|
189 |
- unsigned r2 = getRow(prop->pos); |
|
190 |
- unsigned c2 = getCol(prop->pos); |
|
191 |
- GAPS_ASSERT(r1 != r2 || c1 != c2); |
|
192 |
- |
|
193 |
- AlphaParameters alpha = alphaParameters(r1, c1, r2, c2); |
|
194 |
- if (std::log(prop->rng.uniform()) < getDeltaLL(alpha, -prop->atom1->mass) * mAnnealingTemp) |
|
170 |
+ AlphaParameters alpha = alphaParameters(prop.r1, prop.c1, prop.r2, prop.c2); |
|
171 |
+ if (std::log(prop.rng.uniform()) < getDeltaLL(alpha, -prop.atom1->mass) * mAnnealingTemp) |
|
195 | 172 |
{ |
196 |
- prop->atom1->pos = prop->pos; |
|
197 |
- safelyChangeMatrix(r1, c1, -prop->atom1->mass); |
|
198 |
- changeMatrix(r2, c2, prop->atom1->mass); |
|
173 |
+ //prop.atom1->pos = prop.pos; |
|
174 |
+ mDomain.move(prop.atom1->pos, prop.pos); |
|
175 |
+ safelyChangeMatrix(prop.r1, prop.c1, -prop.atom1->mass); |
|
176 |
+ changeMatrix(prop.r2, prop.c2, prop.atom1->mass); |
|
199 | 177 |
} |
200 | 178 |
} |
201 | 179 |
|
202 | 180 |
// exchange some amount of mass between two positions, note it is possible |
203 | 181 |
// for one of the atoms to be deleted if its mass becomes too small |
204 |
-void GibbsSampler::exchange(AtomicProposal *prop) |
|
182 |
+void GibbsSampler::exchange(const AtomicProposal &prop) |
|
205 | 183 |
{ |
206 |
- unsigned r1 = getRow(a1->pos); |
|
207 |
- unsigned c1 = getCol(a1->pos); |
|
208 |
- unsigned r2 = getRow(a2->pos); |
|
209 |
- unsigned c2 = getCol(a2->pos); |
|
210 |
- GAPS_ASSERT(r1 != r2 || c1 != c2); |
|
211 |
- |
|
212 | 184 |
// attempt gibbs distribution exchange |
213 |
- AlphaParameters alpha = alphaParameters(r1, c1, r2, c2); |
|
214 |
- if (canUseGibbs(c1, c2)) |
|
185 |
+ AlphaParameters alpha = alphaParameters(prop.r1, prop.c1, prop.r2, prop.c2); |
|
186 |
+ if (canUseGibbs(prop.c1, prop.c2)) |
|
215 | 187 |
{ |
216 |
- OptionalFloat gMass = gibbsMass(alpha, prop->atom1->mass, |
|
217 |
- prop->atom2->mass, &(prop->rng)); |
|
188 |
+ OptionalFloat gMass = gibbsMass(alpha, prop.atom1->mass, |
|
189 |
+ prop.atom2->mass, &(prop.rng)); |
|
218 | 190 |
if (gMass.hasValue()) |
219 | 191 |
{ |
220 |
- acceptExchange(prop->atom1, prop->atom2, gMass.value(), r1, c1, r2, c2); |
|
192 |
+ acceptExchange(prop, gMass.value()); |
|
221 | 193 |
return; |
222 | 194 |
} |
223 | 195 |
} |
224 | 196 |
|
225 | 197 |
// resort to metropolis-hastings if gibbs fails |
226 |
- exchangeUsingMetropolisHastings(prop, alpha, r1, c1, r2, c2); |
|
198 |
+ exchangeUsingMetropolisHastings(prop, alpha); |
|
227 | 199 |
} |
228 | 200 |
|
229 |
-void GibbsSampler::exchangeUsingMetropolisHastings(AtomicProposal *prop, |
|
230 |
-AlphaParameters alpha, unsigned r1, unsigned c1, unsigned r2, unsigned c2) |
|
201 |
+void GibbsSampler::exchangeUsingMetropolisHastings(const AtomicProposal &prop, |
|
202 |
+AlphaParameters alpha) |
|
231 | 203 |
{ |
232 | 204 |
// compute amount of mass to be exchanged |
233 |
- float totalMass = prop->atom1->mass + prop->atom2->mass; |
|
234 |
- float newMass = prop->rng.truncGammaUpper(totalMass, 2.f, 1.f / mLambda); |
|
205 |
+ float totalMass = prop.atom1->mass + prop.atom2->mass; |
|
206 |
+ float newMass = prop.rng.truncGammaUpper(totalMass, 2.f, 1.f / mLambda); |
|
235 | 207 |
|
236 | 208 |
// compute amount to change atom1 by - always change larger mass to newMass |
237 |
- float delta = (prop->atom1->mass > prop->atom2->mass) |
|
238 |
- ? newMass - prop->atom1->mass |
|
239 |
- : prop->atom2->mass - newMass; |
|
209 |
+ float delta = (prop.atom1->mass > prop.atom2->mass) |
|
210 |
+ ? newMass - prop.atom1->mass |
|
211 |
+ : prop.atom2->mass - newMass; |
|
240 | 212 |
|
241 | 213 |
// choose mass for priorLL calculation |
242 | 214 |
float oldMass = (2.f * newMass > totalMass) |
243 |
- ? gaps::max(prop->atom1->mass, prop->atom2->mass) |
|
244 |
- : gaps::min(prop->atom1->mass, prop->atom2->mass); |
|
215 |
+ ? gaps::max(prop.atom1->mass, prop.atom2->mass) |
|
216 |
+ : gaps::min(prop.atom1->mass, prop.atom2->mass); |
|
245 | 217 |
|
246 | 218 |
// calculate priorLL |
247 | 219 |
float pNew = gaps::d_gamma(newMass, 2.f, 1.f / mLambda); |
... | ... |
@@ -250,24 +222,23 @@ AlphaParameters alpha, unsigned r1, unsigned c1, unsigned r2, unsigned c2) |
250 | 222 |
|
251 | 223 |
// accept/reject |
252 | 224 |
float deltaLL = getDeltaLL(alpha, delta) * mAnnealingTemp; |
253 |
- if (priorLL == 0.f || std::log(prop->rng.uniform() * priorLL) < deltaLL) |
|
225 |
+ if (priorLL == 0.f || std::log(prop.rng.uniform() * priorLL) < deltaLL) |
|
254 | 226 |
{ |
255 |
- acceptExchange(prop->atom1, prop->atom2, delta, r1, c1, r2, c2); |
|
227 |
+ acceptExchange(prop, delta); |
|
256 | 228 |
return; |
257 | 229 |
} |
258 | 230 |
} |
259 | 231 |
|
260 | 232 |
// helper function for exchange step |
261 |
-void GibbsSampler::acceptExchange(Atom *a1, Atom *a2, float delta, |
|
262 |
-unsigned r1, unsigned c1, unsigned r2, unsigned c2) |
|
233 |
+void GibbsSampler::acceptExchange(const AtomicProposal &prop, float delta) |
|
263 | 234 |
{ |
264 |
- if (a1->mass + delta > gaps::epsilon && a2->mass - delta > gaps::epsilon) |
|
235 |
+ if (prop.atom1->mass + delta > gaps::epsilon && prop.atom2->mass - delta > gaps::epsilon) |
|
265 | 236 |
{ |
266 |
- a1->mass += delta; |
|
267 |
- a2->mass -= delta; |
|
237 |
+ prop.atom1->mass += delta; |
|
238 |
+ prop.atom2->mass -= delta; |
|
268 | 239 |
|
269 |
- changeMatrix(r1, c1, delta); |
|
270 |
- changeMatrix(r2, c2, -delta); |
|
240 |
+ changeMatrix(prop.r1, prop.c1, delta); |
|
241 |
+ changeMatrix(prop.r2, prop.c2, -delta); |
|
271 | 242 |
} |
272 | 243 |
} |
273 | 244 |
|
... | ... |
@@ -319,7 +290,15 @@ float m1, float m2, GapsRng *rng) |
319 | 290 |
return OptionalFloat(); |
320 | 291 |
} |
321 | 292 |
|
322 |
-// needed to prevent negative values in matrix |
|
293 |
+// here mass + delta is guaranteed to be positive |
|
294 |
+void GibbsSampler::changeMatrix(unsigned row, unsigned col, float delta) |
|
295 |
+{ |
|
296 |
+ mMatrix(row, col) += delta; |
|
297 |
+ GAPS_ASSERT(mMatrix(row, col) >= 0.f); |
|
298 |
+ updateAPMatrix(row, col, delta); |
|
299 |
+} |
|
300 |
+ |
|
301 |
+// delta could be negative, this is needed to prevent negative values in matrix |
|
323 | 302 |
void GibbsSampler::safelyChangeMatrix(unsigned row, unsigned col, float delta) |
324 | 303 |
{ |
325 | 304 |
float newVal = gaps::max(mMatrix(row, col) + delta, 0.f); |
... | ... |
@@ -350,16 +329,6 @@ void GibbsSampler::updateAPMatrix(unsigned row, unsigned col, float delta) |
350 | 329 |
} |
351 | 330 |
} |
352 | 331 |
|
353 |
-unsigned GibbsSampler::getRow(uint64_t pos) const |
|
354 |
-{ |
|
355 |
- return pos / (mBinSize * mNumPatterns); // nCol == nPatterns |
|
356 |
-} |
|
357 |
- |
|
358 |
-unsigned GibbsSampler::getCol(uint64_t pos) const |
|
359 |
-{ |
|
360 |
- return (pos / mBinSize) % mNumPatterns; // nCol == nPatterns |
|
361 |
-} |
|
362 |
- |
|
363 | 332 |
bool GibbsSampler::canUseGibbs(unsigned col) const |
364 | 333 |
{ |
365 | 334 |
return !gaps::algo::isVectorZero(mOtherMatrix->colPtr(col), |
... | ... |
@@ -2,6 +2,7 @@ |
2 | 2 |
#define __COGAPS_GIBBS_SAMPLER_H__ |
3 | 3 |
|
4 | 4 |
#include "AtomicDomain.h" |
5 |
+#include "ProposalQueue.h" |
|
5 | 6 |
#include "data_structures/Matrix.h" |
6 | 7 |
#include "math/Algorithms.h" |
7 | 8 |
#include "math/Random.h" |
... | ... |
@@ -27,7 +28,6 @@ public: |
27 | 28 |
void setMaxGibbsMass(float max); |
28 | 29 |
void setAnnealingTemp(float temp); |
29 | 30 |
void setMatrix(const Matrix &mat); |
30 |
- void setSeed(uint64_t seed); |
|
31 | 31 |
|
32 | 32 |
float chi2() const; |
33 | 33 |
uint64_t nAtoms() const; |
... | ... |
@@ -68,21 +68,22 @@ private: |
68 | 68 |
uint64_t mBinSize; |
69 | 69 |
uint64_t mDomainLength; |
70 | 70 |
|
71 |
- void makeAndProcessProposal(GapsRng *rng); |
|
71 |
+ void processProposal(const AtomicProposal &prop); |
|
72 | 72 |
float deathProb(uint64_t nAtoms) const; |
73 | 73 |
|
74 |
- void birth(GapsRng *rng); |
|
75 |
- void death(GapsRng *rng); |
|
76 |
- void move(GapsRng *rng); |
|
77 |
- void exchange(GapsRng *rng); |
|
78 |
- |
|
79 |
- void acceptExchange(Atom *a1, Atom *a2, float d1, unsigned r1, |
|
80 |
- unsigned c1, unsigned r2, unsigned c2); |
|
74 |
+ void birth(const AtomicProposal &prop); |
|
75 |
+ void death(const AtomicProposal &prop); |
|
76 |
+ void move(const AtomicProposal &prop); |
|
77 |
+ void exchange(const AtomicProposal &prop); |
|
78 |
+ void exchangeUsingMetropolisHastings(const AtomicProposal &prop, |
|
79 |
+ AlphaParameters alpha); |
|
80 |
+ void acceptExchange(const AtomicProposal &prop, float delta); |
|
81 | 81 |
bool updateAtomMass(Atom *atom, float delta); |
82 | 82 |
|
83 | 83 |
OptionalFloat gibbsMass(AlphaParameters alpha, GapsRng *rng); |
84 | 84 |
OptionalFloat gibbsMass(AlphaParameters alpha, float m1, float m2, GapsRng *rng); |
85 | 85 |
|
86 |
+ void changeMatrix(unsigned row, unsigned col, float delta); |
|
86 | 87 |
void safelyChangeMatrix(unsigned row, unsigned col, float delta); |
87 | 88 |
void updateAPMatrix(unsigned row, unsigned col, float delta); |
88 | 89 |
|
... | ... |
@@ -105,6 +106,7 @@ mAPMatrix(mDMatrix.nRow(), mDMatrix.nCol()), |
105 | 106 |
mMatrix(mDMatrix.nCol(), nPatterns), |
106 | 107 |
mOtherMatrix(NULL), |
107 | 108 |
mDomain(mMatrix.nRow() * mMatrix.nCol()), |
109 |
+mQueue(mMatrix.nRow(), mMatrix.nCol()), |
|
108 | 110 |
mLambda(0.f), |
109 | 111 |
mMaxGibbsMass(100.f), |
110 | 112 |
mAnnealingTemp(1.f), |
... | ... |
@@ -8,8 +8,10 @@ OBJECTS = AtomicDomain.o \ |
8 | 8 |
GapsRunner.o \ |
9 | 9 |
GapsStatistics.o \ |
10 | 10 |
GibbsSampler.o \ |
11 |
+ ProposalQueue.o \ |
|
11 | 12 |
RcppExports.o \ |
12 | 13 |
test-runner.o \ |
14 |
+ data_structures/HashSets.o \ |
|
13 | 15 |
data_structures/Matrix.o \ |
14 | 16 |
data_structures/Vector.o \ |
15 | 17 |
file_parser/CsvParser.o \ |
... | ... |
@@ -21,7 +23,6 @@ OBJECTS = AtomicDomain.o \ |
21 | 23 |
math/Random.o \ |
22 | 24 |
cpp_tests/testAlgorithms.o \ |
23 | 25 |
cpp_tests/testAtomicDomain.o \ |
24 |
- cpp_tests/testEfficientSets.o \ |
|
25 | 26 |
cpp_tests/testFileParsers.o \ |
26 | 27 |
cpp_tests/testMatrix.o \ |
27 | 28 |
cpp_tests/testRandom.o \ |
... | ... |
@@ -1,6 +1,6 @@ |
1 |
-#include "GapsAssert.h" |
|
2 | 1 |
#include "ProposalQueue.h" |
3 | 2 |
#include "math/Random.h" |
3 |
+#include "utils/GapsAssert.h" |
|
4 | 4 |
|
5 | 5 |
//////////////////////////////// AtomicProposal //////////////////////////////// |
6 | 6 |
|
... | ... |
@@ -11,20 +11,24 @@ AtomicProposal::AtomicProposal(char t) |
11 | 11 |
|
12 | 12 |
//////////////////////////////// ProposalQueue ///////////////////////////////// |
13 | 13 |
|
14 |
-ProposalQueue::ProposalQueue(unsigned primaryDimSize, unsigned secondaryDimSize) |
|
14 |
+ProposalQueue::ProposalQueue(unsigned nrow, unsigned ncol) |
|
15 | 15 |
: |
16 |
-mMinAtoms(0), mMaxAtoms(0), mNumBins(primaryDimSize * secondaryDimSize), |
|
17 |
-mBinLength(std::numeric_limits<uint64_t>::max() / mNumBins), |
|
18 |
-mSecondaryDimLength(mBinLength * secondaryDimSize), |
|
19 |
-mDomainLength(mBinLength * mNumBins), mSecondaryDimSize(secondaryDimSize), |
|
20 |
-mAlpha(0.f), mU1(0.f), mU2(0.f), mUseCachedRng(false) |
|
21 |
-{ |
|
22 |
- mUsedIndices.setDimensionSize(primaryDimSize); |
|
23 |
-} |
|
16 |
+mUsedMatrixIndices(nrow), |
|
17 |
+mMinAtoms(0), |
|
18 |
+mMaxAtoms(0), |
|
19 |
+mBinLength(std::numeric_limits<uint64_t>::max() / static_cast<uint64_t>(nrow * ncol)), |
|
20 |
+mNumCols(ncol), |
|
21 |
+mAlpha(0.0), |
|
22 |
+mDomainLength(static_cast<double>(mBinLength * static_cast<uint64_t>(nrow * ncol))), |
|
23 |
+mNumBins(static_cast<double>(nrow * ncol)), |
|
24 |
+mU1(0.f), |
|
25 |
+mU2(0.f), |
|
26 |
+mUseCachedRng(false) |
|
27 |
+{} |
|
24 | 28 |
|
25 | 29 |
void ProposalQueue::setAlpha(float alpha) |
26 | 30 |
{ |
27 |
- mAlpha = alpha; |
|
31 |
+ mAlpha = static_cast<double>(alpha); |
|
28 | 32 |
} |
29 | 33 |
|
30 | 34 |
void ProposalQueue::populate(AtomicDomain &domain, unsigned limit) |
... | ... |
@@ -36,9 +40,9 @@ void ProposalQueue::populate(AtomicDomain &domain, unsigned limit) |
36 | 40 |
bool success = true; |
37 | 41 |
while (nIter++ < limit && success) |
38 | 42 |
{ |
39 |
- success = makeProposal(domain); |
|
40 |
- if (!success) |
|
43 |
+ if (!makeProposal(domain)) |
|
41 | 44 |
{ |
45 |
+ success = false; |
|
42 | 46 |
mUseCachedRng = true; |
43 | 47 |
} |
44 | 48 |
} |
... | ... |
@@ -49,8 +53,8 @@ void ProposalQueue::clear() |
49 | 53 |
GAPS_ASSERT(mMinAtoms == mMaxAtoms); |
50 | 54 |
|
51 | 55 |
mQueue.clear(); |
52 |
- mUsedPositions.clear(); |
|
53 |
- mUsedIndices.clear(); |
|
56 |
+ mUsedMatrixIndices.clear(); |
|
57 |
+ mUsedAtoms.clear(); |
|
54 | 58 |
} |
55 | 59 |
|
56 | 60 |
unsigned ProposalQueue::size() const |
... | ... |
@@ -69,25 +73,25 @@ AtomicProposal& ProposalQueue::operator[](int n) |
69 | 73 |
void ProposalQueue::acceptDeath() |
70 | 74 |
{ |
71 | 75 |
#pragma omp atomic |
72 |
- mMaxAtoms--; |
|
76 |
+ --mMaxAtoms; |
|
73 | 77 |
} |
74 | 78 |
|
75 | 79 |
void ProposalQueue::rejectDeath() |
76 | 80 |
{ |
77 | 81 |
#pragma omp atomic |
78 |
- mMinAtoms++; |
|
82 |
+ ++mMinAtoms; |
|
79 | 83 |
} |
80 | 84 |
|
81 | 85 |
void ProposalQueue::acceptBirth() |
82 | 86 |
{ |
83 | 87 |
#pragma omp atomic |
84 |
- mMinAtoms++; |
|
88 |
+ ++mMinAtoms; |
|
85 | 89 |
} |
86 | 90 |
|
87 | 91 |
void ProposalQueue::rejectBirth() |
88 | 92 |
{ |
89 | 93 |
#pragma omp atomic |
90 |
- mMaxAtoms--; |
|
94 |
+ --mMaxAtoms; |
|
91 | 95 |
} |
92 | 96 |
|
93 | 97 |
float ProposalQueue::deathProb(double nAtoms) const |
... | ... |
@@ -98,16 +102,14 @@ float ProposalQueue::deathProb(double nAtoms) const |
98 | 102 |
|
99 | 103 |
bool ProposalQueue::makeProposal(AtomicDomain &domain) |
100 | 104 |
{ |
101 |
- // special indeterminate case |
|
102 | 105 |
if (mMinAtoms < 2 && mMaxAtoms >= 2) |
103 | 106 |
{ |
104 |
- return false; |
|
107 |
+ return false; // special indeterminate case |
|
105 | 108 |
} |
106 | 109 |
|
107 |
- // always birth when no atoms exist |
|
108 | 110 |
if (mMaxAtoms < 2) |
109 | 111 |
{ |
110 |
- return birth(domain); |
|
112 |
+ return birth(domain); // always birth when no atoms exist |
|
111 | 113 |
} |
112 | 114 |
|
113 | 115 |
mU1 = mUseCachedRng ? mU1 : mRng.uniform(); |
... | ... |
@@ -131,24 +133,29 @@ bool ProposalQueue::makeProposal(AtomicDomain &domain) |
131 | 133 |
} |
132 | 134 |
return (mU1 < 0.75f) ? move(domain) : exchange(domain); |
133 | 135 |
} |
134 |
- |
|
136 |
+ |
|
135 | 137 |
bool ProposalQueue::birth(AtomicDomain &domain) |
136 | 138 |
{ |
137 | 139 |
AtomicProposal prop('B'); |
138 |
- prop.r1 = (prop.atom1->pos / mBinSize) / mNumCols; |
|
139 |
- prop.c1 = (prop.atom1->pos / mBinSize) % mNumCols; |
|
140 |
- |
|
141 |
- prop.init(); // initialize rng |
|
140 |
+ uint64_t pos = domain.randomFreePosition(&(prop.rng), mUsedAtoms.vec()); |
|
141 |
+ if (pos == 0) |
|
142 |
+ { |
|
143 |
+ DEBUG_PING // want to notify since this event should have near 0 probability |
|
144 |
+ GapsRng::rollBackOnce(); // ensure same proposal next time |
|
145 |
+ return false; // atom conflict, might have open spot if atom moves/dies |
|
146 |
+ } |
|
142 | 147 |
|
148 |
+ prop.r1 = (pos / mBinLength) / mNumCols; |
|
149 |
+ prop.c1 = (pos / mBinLength) % mNumCols; |
|
143 | 150 |
if (mUsedMatrixIndices.contains(prop.r1)) |
144 | 151 |
{ |
145 |
- return false; // matrix conflict - can't compute alpha parameters |
|
152 |
+ GapsRng::rollBackOnce(); // ensure same proposal next time |
|
153 |
+ return false; // matrix conflict - can't compute gibbs mass |
|
146 | 154 |
} |
147 |
- prop.atom1 = mDomain.insert(domain.randomFreePosition(), 0.f); |
|
155 |
+ prop.atom1 = domain.insert(pos, 0.f); |
|
148 | 156 |
|
149 | 157 |
mUsedMatrixIndices.insert(prop.r1); |
150 | 158 |
mUsedAtoms.insert(prop.atom1->pos); |
151 |
- |
|
152 | 159 |
mQueue.push_back(prop); |
153 | 160 |
++mMaxAtoms; |
154 | 161 |
return true; |
... | ... |
@@ -157,18 +164,18 @@ bool ProposalQueue::birth(AtomicDomain &domain) |
157 | 164 |
bool ProposalQueue::death(AtomicDomain &domain) |
158 | 165 |
{ |
159 | 166 |
AtomicProposal prop('D'); |
160 |
- prop->atom1 = domain.randomAtom(); |
|
161 |
- prop.r1 = (prop.atom1->pos / mBinSize) / mNumCols; |
|
162 |
- prop.c1 = (prop.atom1->pos / mBinSize) % mNumCols; |
|
167 |
+ prop.atom1 = domain.randomAtom(&(prop.rng)); |
|
168 |
+ prop.r1 = (prop.atom1->pos / mBinLength) / mNumCols; |
|
169 |
+ prop.c1 = (prop.atom1->pos / mBinLength) % mNumCols; |
|
163 | 170 |
|
164 | 171 |
if (mUsedMatrixIndices.contains(prop.r1)) |
165 | 172 |
{ |
166 |
- return false; // matrix conflict - can't compute alpha parameters |
|
173 |
+ GapsRng::rollBackOnce(); // ensure same proposal next time |
|
174 |
+ return false; // matrix conflict - can't compute gibbs mass or deltaLL |
|
167 | 175 |
} |
176 |
+ |
|
168 | 177 |
mUsedMatrixIndices.insert(prop.r1); |
169 | 178 |
mUsedAtoms.insert(prop.atom1->pos); |
170 |
- |
|
171 |
- prop.init(); // initialize rng |
|
172 | 179 |
mQueue.push_back(prop); |
173 | 180 |
--mMinAtoms; |
174 | 181 |
return true; |
... | ... |
@@ -177,107 +184,89 @@ bool ProposalQueue::death(AtomicDomain &domain) |
177 | 184 |
bool ProposalQueue::move(AtomicDomain &domain) |
178 | 185 |
{ |
179 | 186 |
AtomicProposal prop('M'); |
187 |
+ AtomNeighborhood hood = domain.randomAtomWithNeighbors(&(prop.rng)); |
|
180 | 188 |
|
181 |
- AtomNeighborhood hood = domain.randomAtomWithNeighbors(); |
|
182 |
- |
|
183 | 189 |
uint64_t lbound = hood.hasLeft() ? hood.left->pos : 0; |
184 | 190 |
uint64_t rbound = hood.hasRight() ? hood.right->pos : mDomainLength; |
185 |
- |
|
186 |
- uint64_t newLocation = mRng.uniform64(lbound + 1, rbound - 1); |
|
187 |
- |
|
188 | 191 |
if (mUsedAtoms.contains(lbound) || mUsedAtoms.contains(rbound)) |
189 | 192 |
{ |
193 |
+ GapsRng::rollBackOnce(); |
|
190 | 194 |
return false; // atomic conflict - don't know neighbors |
191 | 195 |
} |
192 | 196 |
|
193 |
- if (primaryIndex(hood.center->pos) == primaryIndex(newLocation) |
|
194 |
- && secondaryIndex(hood.center->pos) == secondaryIndex(newLocation)) |
|
197 |
+ prop.pos = prop.rng.uniform64(lbound + 1, rbound - 1); |
|
198 |
+ prop.atom1 = hood.center; |
|
199 |
+ prop.r1 = (prop.atom1->pos / mBinLength) / mNumCols; |
|
200 |
+ prop.c1 = (prop.atom1->pos / mBinLength) % mNumCols; |
|
201 |
+ prop.r2 = (prop.pos / mBinLength) / mNumCols; |
|
202 |
+ prop.c2 = (prop.pos / mBinLength) % mNumCols; |
|
203 |
+ |
|
204 |
+ if (mUsedMatrixIndices.contains(prop.r1) || mUsedMatrixIndices.contains(prop.r2)) |
|
195 | 205 |
{ |
196 |
- hood.center->pos = newLocation; // automatically accept moves in same bin |
|
197 |
- return true; |
|
206 |
+ GapsRng::rollBackOnce(); // ensure same proposal next time |
|
207 |
+ return false; // matrix conflict - can't compute deltaLL |
|
198 | 208 |
} |
199 | 209 |
|
200 |
- if (mUsedIndices.contains(primaryIndex(hood.center->pos)) |
|
201 |
- || mUsedIndices.contains(primaryIndex(newLocation))) |
|
210 |
+ if (prop.r1 == prop.r2 && prop.c1 == prop.c2) |
|
202 | 211 |
{ |
203 |
- return false; // matrix conflict - can't compute deltaLL |
|
212 |
+ prop.atom1->pos = prop.pos; // automatically accept moves in same bin |
|
213 |
+ return true; |
|
204 | 214 |
} |
205 | 215 |
|
206 |
- mQueue.push_back(AtomicProposal('M', hood.center, newLocation)); |
|
207 |
- mUsedIndices.insert(primaryIndex(hood.center->pos)); |
|
208 |
- mUsedIndices.insert(primaryIndex(newLocation)); |
|
209 |
- mUsedPositions.insert(hood.center->pos); |
|
210 |
- mUsedPositions.insert(newLocation); |
|
216 |
+ mQueue.push_back(prop); |
|
217 |
+ mUsedMatrixIndices.insert(prop.r1); |
|
218 |
+ mUsedMatrixIndices.insert(prop.r2); |
|
219 |
+ mUsedAtoms.insert(prop.atom1->pos); |
|
220 |
+ mUsedAtoms.insert(prop.pos); |
|
211 | 221 |
return true; |
212 | 222 |
} |
213 | 223 |
|
214 | 224 |
bool ProposalQueue::exchange(AtomicDomain &domain) |
215 | 225 |
{ |
216 |
- AtomNeighborhood hood = domain.randomAtomWithRightNeighbor(); |
|
217 |
- Atom* a1 = hood.center; |
|
218 |
- Atom* a2 = hood.hasRight() ? hood.right : domain.front(); |
|
226 |
+ AtomicProposal prop('E'); |
|
227 |
+ AtomNeighborhood hood = domain.randomAtomWithRightNeighbor(&(prop.rng)); |
|
228 |
+ prop.atom1 = hood.center; |
|
229 |
+ prop.atom2 = hood.hasRight() ? hood.right : domain.front(); |
|
219 | 230 |
|
220 |
- if (hood.hasRight()) // has neighbor |
|
221 |
- { |
|
222 |
- if (!mUsedPositions.isEmptyInterval(a1->pos, a2->pos)) |
|
223 |
- { |
|
224 |
- return false; // atomic conflict - don't know right neighbor |
|
225 |
- } |
|
226 |
- } |
|
227 |
- else // exchange with first atom |
|
228 |
- { |
|
229 |
- if (!mUsedPositions.isEmptyInterval(a1->pos, mDomainLength)) |
|
230 |
- { |
|
231 |
- return false; // atomic conflict - don't know right neighbor |
|
232 |
- } |
|
233 |
- |
|
234 |
- if (!mUsedPositions.isEmptyInterval(0, domain.front()->pos)) |
|
235 |
- { |
|
236 |
- return false; // atomic conflict - don't know right neighbor |
|
237 |
- } |
|
238 |
- } |
|
231 |
+ prop.r1 = (prop.atom1->pos / mBinLength) / mNumCols; |
|
232 |
+ prop.c1 = (prop.atom1->pos / mBinLength) % mNumCols; |
|
233 |
+ prop.r2 = (prop.atom2->pos / mBinLength) / mNumCols; |
|
234 |
+ prop.c2 = (prop.atom2->pos / mBinLength) % mNumCols; |
|
239 | 235 |
|
240 |
- if (primaryIndex(a1->pos) == primaryIndex(a2->pos) |
|
241 |
- && secondaryIndex(a1->pos) == secondaryIndex(a2->pos)) |
|
236 |
+ if (mUsedMatrixIndices.contains(prop.r1) || mUsedMatrixIndices.contains(prop.r2)) |
|
242 | 237 |
{ |
243 |
- GapsRng rng; |
|
244 |
- float newMass = rng.truncGammaUpper(a1->mass + a2->mass, 2.f, 1.f / mLambda); |
|
245 |
- float delta = (a1->mass > a2->mass) ? newMass - a1->mass : a2->mass - newMass; |
|
246 |
- if (a1->mass + delta > gaps::epsilon && a2->mass - delta > gaps::epsilon) |
|
247 |
- { |
|
248 |
- a1->mass += delta; |
|
249 |
- a2->mass -= delta; |
|
250 |
- } |
|
251 |
- return true; |
|
238 |
+ GapsRng::rollBackOnce(); // ensure same proposal next time |
|
239 |
+ return false; // matrix conflict - can't compute deltaLL or gibbs mass |
|
252 | 240 |
} |
253 | 241 |
|
254 |
- if (mUsedIndices.contains(primaryIndex(a1->pos)) |
|
255 |
- || mUsedIndices.contains(primaryIndex(a2->pos))) |
|
242 |
+ if (prop.r1 == prop.r2 && prop.c1 == prop.c2) |
|
256 | 243 |
{ |
257 |
- return false; // matrix conflict - can't compute gibbs mass or deltaLL |
|
244 |
+ //float newMass = prop.rng.truncGammaUpper(a1->mass + a2->mass, 2.f, 1.f / mLambda); |
|
245 |
+ //float delta = (a1->mass > a2->mass) ? newMass - a1->mass : a2->mass - newMass; |
|
246 |
+ //if (a1->mass + delta > gaps::epsilon && a2->mass - delta > gaps::epsilon) |
|
247 |
+ //{ |
|
248 |
+ // a1->mass += delta; |
|
249 |
+ // a2->mass -= delta; |
|
250 |
+ //} |
|
251 |
+ return true; // TODO automatically accept exchanges in same bin |
|
258 | 252 |
} |
259 | 253 |
|
260 |
- mQueue.push_back(AtomicProposal('E', a1, a2)); |
|
261 |
- mUsedIndices.insert(primaryIndex(a1->pos)); |
|
262 |
- mUsedIndices.insert(primaryIndex(a2->pos)); |
|
263 |
- mUsedPositions.insert(a1->pos); |
|
264 |
- mUsedPositions.insert(a2->pos); |
|
265 |
- --mMinAtoms; |
|
254 |
+ mQueue.push_back(prop); |
|
255 |
+ mUsedMatrixIndices.insert(prop.r1); |
|
256 |
+ mUsedMatrixIndices.insert(prop.r2); |
|
266 | 257 |
return true; |
267 | 258 |
} |
268 | 259 |
|
269 | 260 |
Archive& operator<<(Archive &ar, ProposalQueue &q) |
270 | 261 |
{ |
271 |
- ar << q.mMinAtoms << q.mMaxAtoms << q.mNumBins << q.mBinLength |
|
272 |
- << q.mSecondaryDimLength << q.mDomainLength << q.mSecondaryDimSize |
|
273 |
- << q.mAlpha << q.mRng; |
|
262 |
+ ar << q.mRng << q.mMinAtoms << q.mMaxAtoms << q.mBinLength << q.mNumCols |
|
263 |
+ << q.mAlpha << q.mDomainLength << q.mNumBins; |
|
274 | 264 |
return ar; |
275 | 265 |
} |
276 | 266 |
|
277 | 267 |
Archive& operator>>(Archive &ar, ProposalQueue &q) |
278 | 268 |
{ |
279 |
- ar >> q.mMinAtoms >> q.mMaxAtoms >> q.mNumBins >> q.mBinLength |
|
280 |
- >> q.mSecondaryDimLength >> q.mDomainLength >> q.mSecondaryDimSize |
|
281 |
- >> q.mAlpha >> q.mRng; |
|
269 |
+ ar >> q.mRng >> q.mMinAtoms >> q.mMaxAtoms >> q.mBinLength >> q.mNumCols |
|
270 |
+ >> q.mAlpha >> q.mDomainLength >> q.mNumBins; |
|
282 | 271 |
return ar; |
283 | 272 |
} |
284 | 273 |
\ No newline at end of file |
... | ... |
@@ -1,10 +1,10 @@ |
1 | 1 |
#ifndef __GAPS_PROPOSAL_QUEUE_H__ |
2 | 2 |
#define __GAPS_PROPOSAL_QUEUE_H__ |
3 | 3 |
|
4 |
-#include "Archive.h" |
|
5 | 4 |
#include "AtomicDomain.h" |
6 |
-#include "data_structures/EfficientSets.h" |
|
5 |
+#include "data_structures/HashSets.h" |
|
7 | 6 |
#include "math/Random.h" |
7 |
+#include "utils/Archive.h" |
|
8 | 8 |
|
9 | 9 |
#include <cstddef> |
10 | 10 |
#include <stdint.h> |
... | ... |
@@ -12,7 +12,7 @@ |
12 | 12 |
|
13 | 13 |
struct AtomicProposal |
14 | 14 |
{ |
15 |
- GapsRng rng; // used for consistency no matter number of threads |
|
15 |
+ mutable GapsRng rng; // used for consistency no matter number of threads |
|
16 | 16 |
|
17 | 17 |
uint64_t pos; // used for move |
18 | 18 |
Atom *atom1; // used for birth/death/move/exchange |
... | ... |
@@ -32,7 +32,8 @@ class ProposalQueue |
32 | 32 |
{ |
33 | 33 |
public: |
34 | 34 |
|
35 |
- ProposalQueue(unsigned primaryDimSize, unsigned secondaryDimSize); |
|
35 |
+ // initialize |
|
36 |
+ ProposalQueue(unsigned nrow, unsigned ncol); |
|
36 | 37 |
void setAlpha(float alpha); |
37 | 38 |
|
38 | 39 |
// modify/access queue |
... | ... |
@@ -47,6 +48,10 @@ public: |
47 | 48 |
void acceptBirth(); |
48 | 49 |
void rejectBirth(); |
49 | 50 |
|
51 |
+ // serialization |
|
52 |
+ friend Archive& operator<<(Archive &ar, ProposalQueue &queue); |
|
53 |
+ friend Archive& operator>>(Archive &ar, ProposalQueue &queue); |
|
54 |
+ |
|
50 | 55 |
private: |
51 | 56 |
|
52 | 57 |
std::vector<AtomicProposal> mQueue; // not really a queue for now |
... | ... |
@@ -58,34 +63,25 @@ private: |
58 | 63 |
|
59 | 64 |
uint64_t mMinAtoms; |
60 | 65 |
uint64_t mMaxAtoms; |
61 |
- uint64_t mBinLength; // atomic length of one bin |
|
62 |
- uint64_t mSecondaryDimLength; // atomic length of one row (col) for A (P) |
|
66 |
+ uint64_t mBinLength; // length of single bin |
|
67 |
+ uint64_t mNumCols; |
|
63 | 68 |
|
69 |
+ double mAlpha; |
|
64 | 70 |
double mDomainLength; // length of entire atomic domain |
65 | 71 |
double mNumBins; // number of matrix elements |
66 | 72 |
|
67 |
- unsigned mSecondaryDimSize; // number of cols (rows) for A (P) |
|
68 |
- |
|
69 |
- float mAlpha; |
|
70 | 73 |
float mU1; |
71 | 74 |
float mU2; |
72 | 75 |
|
73 | 76 |
bool mUseCachedRng; |
74 | 77 |
|
75 |
- unsigned primaryIndex(uint64_t pos) const; |
|
76 |
- unsigned secondaryIndex(uint64_t pos) const; |
|
78 |
+ float deathProb(double nAtoms) const; |
|
77 | 79 |
|
78 |
- float deathProb(uint64_t nAtoms) const; |
|
80 |
+ bool makeProposal(AtomicDomain &domain); |
|
79 | 81 |
bool birth(AtomicDomain &domain); |
80 | 82 |
bool death(AtomicDomain &domain); |
81 | 83 |
bool move(AtomicDomain &domain); |
82 | 84 |
bool exchange(AtomicDomain &domain); |
83 |
- |
|
84 |
- bool makeProposal(AtomicDomain &domain); |
|
85 |
- |
|
86 |
- // serialization |
|
87 |
- friend Archive& operator<<(Archive &ar, ProposalQueue &queue); |
|
88 |
- friend Archive& operator>>(Archive &ar, ProposalQueue &queue); |
|
89 | 85 |
}; |
90 | 86 |
|
91 | 87 |
#endif |
92 | 88 |
\ No newline at end of file |
93 | 89 |
deleted file mode 100644 |
... | ... |
@@ -1,30 +0,0 @@ |
1 |
-#include "catch.h" |
|
2 |
-#include "../data_structures/EfficientSets.h" |
|
3 |
- |
|
4 |
-TEST_CASE("Test IntFixedHashSet") |
|
5 |
-{ |
|
6 |
- IntFixedHashSet set; |
|
7 |
- |
|
8 |
- set.setDimensionSize(1000); |
|
9 |
- REQUIRE(!set.contains(123)); |
|
10 |
- |
|
11 |
- set.insert(123); |
|
12 |
- REQUIRE(set.contains(123)); |
|
13 |
- |
|
14 |
- set.clear(); |
|
15 |
- REQUIRE(!set.contains(123)); |
|
16 |
- |
|
17 |
- set.insert(123); |
|
18 |
- REQUIRE(set.contains(123)); |
|
19 |
-} |
|
20 |
- |
|
21 |
-TEST_CASE("Test IntDenseOrderedSet") |
|
22 |
-{ |
|
23 |
- IntDenseOrderedSet set; |
|
24 |
- |
|
25 |
- REQUIRE(set.isEmptyInterval(10, 500)); |
|
26 |
- set.insert(100); |
|
27 |
- REQUIRE(!set.isEmptyInterval(10, 500)); |
|
28 |
- set.clear(); |
|
29 |
- REQUIRE(set.isEmptyInterval(10, 500)); |
|
30 |
-} |
31 | 0 |
deleted file mode 100644 |
... | ... |
@@ -1,52 +0,0 @@ |
1 |
-#ifndef __COGAPS_EFFICIENT_SETS_H__ |
|
2 |
-#define __COGAPS_EFFICIENT_SETS_H__ |
|
3 |
- |
|
4 |
-#include <vector> |
|
5 |
-#include <stdint.h> |
|
6 |
- |
|
7 |
-class IntFixedHashSet |
|
8 |
-{ |
|
9 |
-public: |
|
10 |
- |
|
11 |
- IntFixedHashSet() : mCurrentKey(1) {} |
|
12 |
- |
|
13 |
- void setDimensionSize(unsigned size) {mSet.resize(size, 0);} |
|
14 |
- void clear() {++mCurrentKey;} |
|
15 |
- bool contains(unsigned n) {return mSet[n] == mCurrentKey;} |
|
16 |
- void insert(unsigned n) {mSet[n] = mCurrentKey;} |
|
17 |
- |
|
18 |
-private: |
|
19 |
- |
|
20 |
- std::vector<uint64_t> mSet; |
|
21 |
- uint64_t mCurrentKey; |
|
22 |
-}; |
|
23 |
- |
|
24 |
-// TODO have sorted vector with at least some % of holes |
|
25 |
-// even distribute entries along it |
|
26 |
-// when shift happens, should be minimal |
|
27 |
-class IntDenseOrderedSet |
|
28 |
-{ |
|
29 |
-public: |
|
30 |
- |
|
31 |
- void insert(uint64_t p) {mVec.push_back(p);} |
|
32 |
- void clear() {mVec.clear();} |
|
33 |
- |
|
34 |
- // inclusive of a and b, TODO improve performance |
|
35 |
- bool isEmptyInterval(uint64_t a, uint64_t b) |
|
36 |
- { |
|
37 |
- for (unsigned i = 0; i < mVec.size(); ++i) |
|
38 |
- { |
|
39 |
- if (mVec[i] >= a && mVec[i] <= b) |
|
40 |
- { |
|
41 |
- return false; |
|
42 |
- } |
|
43 |
- } |
|
44 |
- return true; |
|
45 |
- } |
|
46 |
- |
|
47 |
-private: |
|
48 |
- |
|
49 |
- std::vector<uint64_t> mVec; |
|
50 |
-}; |
|
51 |
- |
|
52 |
-#endif // __COGAPS_EFFICIENT_SETS_H__ |
|
53 | 0 |
\ No newline at end of file |
54 | 1 |
new file mode 100644 |
... | ... |
@@ -0,0 +1,72 @@ |
1 |
+#include "HashSets.h" |
|
2 |
+ |
|
3 |
+///////////////////////////// FixedHashSetU32 ////////////////////////////////// |
|
4 |
+ |
|
5 |
+FixedHashSetU32::FixedHashSetU32(unsigned size) |
|
6 |
+ : mSet(std::vector<uint32_t>(size, 0)), mCurrentKey(1) |
|
7 |
+{} |
|
8 |
+ |
|
9 |
+void FixedHashSetU32::insert(unsigned n) |
|
10 |
+{ |
|
11 |
+ mSet[n] = mCurrentKey; |
|
12 |
+} |
|
13 |
+ |
|
14 |
+void FixedHashSetU32::clear() |
|
15 |
+{ |
|
16 |
+ ++mCurrentKey; |
|
17 |
+} |
|
18 |
+ |
|
19 |
+bool FixedHashSetU32::contains(unsigned n) |
|
20 |
+{ |
|
21 |
+ return mSet[n] == mCurrentKey; |
|
22 |
+} |
|
23 |
+ |
|
24 |
+bool FixedHashSetU32::isEmpty() |
|
25 |
+{ |
|
26 |
+ unsigned sz = mSet.size(); |
|
27 |
+ for (unsigned i = 0; i < sz; ++i) |
|
28 |
+ { |
|
29 |
+ if (mSet[i] == mCurrentKey) |
|
30 |
+ { |
|
31 |
+ return false; |
|
32 |
+ } |
|
33 |
+ } |
|
34 |
+ return true; |
|
35 |
+} |
|
36 |
+ |
|
37 |
+///////////////////////////// SmallHashSetU64 ////////////////////////////////// |
|
38 |
+ |
|
39 |
+SmallHashSetU64::SmallHashSetU64() {} |
|
40 |
+ |
|
41 |
+void SmallHashSetU64::insert(uint64_t pos) |
|
42 |
+{ |
|
43 |
+ mSet.push_back(pos); |
|
44 |
+} |
|
45 |
+ |
|
46 |
+void SmallHashSetU64::clear() |
|
47 |
+{ |
|
48 |
+ mSet.clear(); |
|
49 |
+} |
|
50 |
+ |
|
51 |
+bool SmallHashSetU64::contains(uint64_t pos) |
|
52 |
+{ |
|
53 |
+ unsigned sz = mSet.size(); |
|
54 |
+ for (unsigned i = 0; i < sz; ++i) |
|
55 |
+ { |
|
56 |
+ if (mSet[i] == pos) |
|
57 |
+ { |
|
58 |
+ return true; |
|
59 |
+ } |
|
60 |
+ } |
|
61 |
+ return false; |
|
62 |
+} |
|
63 |
+ |
|
64 |
+bool SmallHashSetU64::isEmpty() |
|
65 |
+{ |
|
66 |
+ return mSet.empty(); |
|
67 |
+} |
|
68 |
+ |
|
69 |
+const std::vector<uint64_t>& SmallHashSetU64::vec() |
|
70 |
+{ |
|
71 |
+ return mSet; |
|
72 |
+} |
0 | 73 |
new file mode 100644 |
... | ... |
@@ -0,0 +1,49 @@ |
1 |
+#ifndef __COGAPS_HASH_SETS_H__ |
|
2 |
+#define __COGAPS_HASH_SETS_H__ |
|
3 |
+ |
|
4 |
+#include "../utils/GlobalConfig.h" |
|
5 |
+ |
|
6 |
+#include <stdint.h> |
|
7 |
+#include <vector> |
|
8 |
+ |
|
9 |
+#include <boost/unordered_set.hpp> |
|
10 |
+ |
|
11 |
+#ifdef __GAPS_OPENMP__ |
|
12 |
+#include <omp.h> |
|
13 |
+#endif |
|
14 |
+ |
|
15 |
+class FixedHashSetU32 |
|
16 |
+{ |
|
17 |
+public: |
|
18 |
+ |
|
19 |
+ FixedHashSetU32(unsigned size); |
|
20 |
+ |
|
21 |
+ void insert(unsigned n); |
|
22 |
+ void clear(); |
|
23 |
+ bool contains(unsigned n); |
|
24 |
+ bool isEmpty(); |
|
25 |
+ |
|
26 |
+private: |
|
27 |
+ |
|
28 |
+ std::vector<uint32_t> mSet; |
|
29 |
+ uint64_t mCurrentKey; |
|
30 |
+}; |
|
31 |
+ |
|
32 |
+class SmallHashSetU64 |
|
33 |
+{ |
|
34 |
+public: |
|
35 |
+ |
|
36 |
+ SmallHashSetU64(); |
|
37 |
+ |
|
38 |
+ void insert(uint64_t pos); |
|
39 |
+ void clear(); |
|
40 |
+ bool contains(uint64_t pos); |
|
41 |
+ bool isEmpty(); |
|
42 |
+ const std::vector<uint64_t>& vec(); |
|
43 |
+ |
|
44 |
+private: |
|
45 |
+ |
|
46 |
+ std::vector<uint64_t> mSet; |
|
47 |
+}; |
|
48 |
+ |
|
49 |
+#endif // __COGAPS_HASH_SETS_H__ |
... | ... |
@@ -22,7 +22,7 @@ namespace gaps |
22 | 22 |
template <class T> |
23 | 23 |
std::string to_string(T a); |
24 | 24 |
|
25 |
- const float epsilon = 1.0e-10f; |
|
25 |
+ const float epsilon = 1.0e-5f; |
|
26 | 26 |
const float pi = 3.1415926535897932384626433832795f; |
27 | 27 |
const float pi_double = 3.1415926535897932384626433832795; |
28 | 28 |
|
... | ... |
@@ -52,6 +52,9 @@ void Xoroshiro128plus::seed(uint64_t seed) |
52 | 52 |
|
53 | 53 |
uint64_t Xoroshiro128plus::next() |
54 | 54 |
{ |
55 |
+ mPreviousState[0] = mState[0]; |
|
56 |
+ mPreviousState[1] = mState[1]; |
|
57 |
+ |
|
55 | 58 |
const uint64_t s0 = mState[0]; |
56 | 59 |
uint64_t s1 = mState[1]; |
57 | 60 |
uint64_t result = s0 + s1; |
... | ... |
@@ -61,6 +64,12 @@ uint64_t Xoroshiro128plus::next() |
61 | 64 |
return result; |
62 | 65 |
} |
63 | 66 |
|
67 |
+void Xoroshiro128plus::rollBackOnce() |
|
68 |
+{ |
|
69 |
+ mState[0] = mPreviousState[0]; |
|
70 |
+ mState[1] = mPreviousState[1]; |
|
71 |
+} |
|
72 |
+ |
|
64 | 73 |
void Xoroshiro128plus::warmup() |
65 | 74 |
{ |
66 | 75 |
for (unsigned i = 0; i < 5000; ++i) |
... | ... |
@@ -100,15 +109,15 @@ Archive& GapsRng::load(Archive &ar) |
100 | 109 |
return ar; |
101 | 110 |
} |
102 | 111 |
|
103 |
-GapsRng::GapsRng() |
|
104 |
- : mState(0) |
|
105 |
-{} |
|
106 |
- |
|
107 |
-void GapsRng::init() |
|
112 |
+void GapsRng::rollBackOnce() |
|
108 | 113 |
{ |
109 |
- mState = seeder.next(); |
|
114 |
+ seeder.rollBackOnce(); |
|
110 | 115 |
} |
111 | 116 |
|
117 |
+GapsRng::GapsRng() |
|
118 |
+ : mState(seeder.next()) |
|
119 |
+{} |
|
120 |
+ |
|
112 | 121 |
uint32_t GapsRng::next() |
113 | 122 |
{ |
114 | 123 |
advance(); |
... | ... |
@@ -117,6 +126,7 @@ uint32_t GapsRng::next() |
117 | 126 |
|
118 | 127 |
void GapsRng::advance() |
119 | 128 |
{ |
129 |
+ mPreviousState = mState; |
|
120 | 130 |
mState = mState * 6364136223846793005ull + (54u|1); |
121 | 131 |
} |
122 | 132 |
|
... | ... |
@@ -274,13 +284,13 @@ float GapsRng::truncGammaUpper(float b, float shape, float scale) |
274 | 284 |
|
275 | 285 |
Archive& operator<<(Archive &ar, GapsRng &gen) |
276 | 286 |
{ |
277 |
- ar << gen.mState << gen.mStream; |
|
287 |
+ ar << gen.mState; |
|
278 | 288 |
return ar; |
279 | 289 |
} |
280 | 290 |
|
281 | 291 |
Archive& operator>>(Archive &ar, GapsRng &gen) |
282 | 292 |
{ |
283 |
- ar >> gen.mState >> gen.mStream; |
|
293 |
+ ar >> gen.mState; |
|
284 | 294 |
return ar; |
285 | 295 |
} |
286 | 296 |
|
... | ... |
@@ -49,10 +49,13 @@ public: |
49 | 49 |
|
50 | 50 |
void seed(uint64_t seed); |
51 | 51 |
uint64_t next(); |
52 |
+ void rollBackOnce(); |
|
52 | 53 |
|
53 | 54 |
private: |
54 | 55 |
|
55 | 56 |
uint64_t mState[2]; |
57 |
+ uint64_t mPreviousState[2]; |
|
58 |
+ |
|
56 | 59 |
void warmup(); |
57 | 60 |
|
58 | 61 |
friend Archive& operator<<(Archive &ar, Xoroshiro128plus &gen); |
... | ... |
@@ -83,12 +86,14 @@ public: |
83 | 86 |
float truncGammaUpper(float b, float shape, float scale); |
84 | 87 |
|
85 | 88 |
static void setSeed(uint32_t sd); |
86 |
- static void save(); |
|
87 |
- static void load(Archive &ar); |
|
88 |
- |
|
89 |
+ static Archive& save(Archive &ar); |
|
90 |
+ static Archive& load(Archive &ar); |
|
91 |
+ static void rollBackOnce(); |
|
92 |
+ |
|
89 | 93 |
private: |
90 | 94 |
|
91 | 95 |
uint64_t mState; |
96 |
+ uint32_t mPreviousState; |
|
92 | 97 |
|
93 | 98 |
uint32_t next(); |
94 | 99 |
void advance(); |