When re-factoring the logic involved in a Death update step, so that
some of it was handled by the DataModel instead of the Sampler, a bug was
introduced. This bug resulted in the rebirth attempt of the Death update
step being rejected if the gibbs sampling step failed, rather than
attempting the rebirth with the original atom's mass. This commit pulls
out the logic and implements the correct Death step entirely in the
Sampler.
... | ... |
@@ -147,23 +147,36 @@ void AsynchronousGibbsSampler<DataModel>::birth(const AtomicProposal &prop) |
147 | 147 |
template <class DataModel> |
148 | 148 |
void AsynchronousGibbsSampler<DataModel>::death(const AtomicProposal &prop) |
149 | 149 |
{ |
150 |
- // try to do a rebirth in the place of this atom |
|
150 |
+ // determine mass to attempt rebirth with |
|
151 |
+ float rebirthMass = prop.atom1->mass(); // default rebirth mass == no change to atom |
|
152 |
+ AlphaParameters alpha = DataModel::alphaParametersWithChange(prop.r1, prop.c1, |
|
153 |
+ -1.f * prop.atom1->mass()) * DataModel::annealingTemp(); |
|
151 | 154 |
if (DataModel::canUseGibbs(prop.c1)) |
152 | 155 |
{ |
153 |
- OptionalFloat mass = DataModel::sampleDeathAndRebirth(prop.r1, prop.c1, |
|
154 |
- -1.f * prop.atom1->mass(), &(prop.rng)); |
|
155 |
- if (mass.hasValue() && mass.value() >= gaps::epsilon) |
|
156 |
+ OptionalFloat gMass = gibbsMass(alpha, 0.f, DataModel::maxGibbsMass(), &(prop.rng), |
|
157 |
+ DataModel::lambda()); |
|
158 |
+ if (gMass.hasValue()) |
|
156 | 159 |
{ |
157 |
- mQueue.rejectDeath(); |
|
158 |
- DataModel::safelyChangeMatrix(prop.r1, prop.c1, mass.value() - prop.atom1->mass()); |
|
159 |
- prop.atom1->updateMass(mass.value()); |
|
160 |
- return; |
|
160 |
+ rebirthMass = gMass.value(); |
|
161 | 161 |
} |
162 | 162 |
} |
163 |
- // if rebirth fails, then kill off atom |
|
164 |
- mQueue.acceptDeath(); |
|
165 |
- DataModel::safelyChangeMatrix(prop.r1, prop.c1, -1.f * prop.atom1->mass()); |
|
166 |
- mDomain.cacheErase(prop.atom1); |
|
163 |
+ // handle accept/reject of the rebirth |
|
164 |
+ float deltaLL = rebirthMass * (alpha.s_mu - alpha.s * rebirthMass / 2.f); |
|
165 |
+ if (std::log(prop.rng.uniform()) < deltaLL) // accept |
|
166 |
+ { |
|
167 |
+ mQueue.rejectDeath(); |
|
168 |
+ if (rebirthMass != prop.atom1->mass()) |
|
169 |
+ { |
|
170 |
+ DataModel::safelyChangeMatrix(prop.r1, prop.c1, rebirthMass - prop.atom1->mass()); |
|
171 |
+ prop.atom1->updateMass(rebirthMass); |
|
172 |
+ } |
|
173 |
+ } |
|
174 |
+ else // reject |
|
175 |
+ { |
|
176 |
+ mQueue.acceptDeath(); |
|
177 |
+ DataModel::safelyChangeMatrix(prop.r1, prop.c1, -1.f * prop.atom1->mass()); |
|
178 |
+ mDomain.cacheErase(prop.atom1); |
|
179 |
+ } |
|
167 | 180 |
} |
168 | 181 |
|
169 | 182 |
// move mass from src to dest in the atomic domain |
... | ... |
@@ -82,11 +82,21 @@ uint64_t DenseNormalModel::nPatterns() const |
82 | 82 |
return mMatrix.nCol(); |
83 | 83 |
} |
84 | 84 |
|
85 |
+float DenseNormalModel::annealingTemp() const |
|
86 |
+{ |
|
87 |
+ return mAnnealingTemp; |
|
88 |
+} |
|
89 |
+ |
|
85 | 90 |
float DenseNormalModel::lambda() const |
86 | 91 |
{ |
87 | 92 |
return mLambda; |
88 | 93 |
} |
89 | 94 |
|
95 |
+float DenseNormalModel::maxGibbsMass() const |
|
96 |
+{ |
|
97 |
+ return mMaxGibbsMass; |
|
98 |
+} |
|
99 |
+ |
|
90 | 100 |
bool DenseNormalModel::canUseGibbs(unsigned col) const |
91 | 101 |
{ |
92 | 102 |
return !gaps::isVectorZero(mOtherMatrix->getCol(col)); |
... | ... |
@@ -33,7 +33,9 @@ protected: |
33 | 33 |
friend class GapsStatistics; |
34 | 34 |
uint64_t nElements() const; |
35 | 35 |
uint64_t nPatterns() const; |
36 |
+ float annealingTemp() const; |
|
36 | 37 |
float lambda() const; |
38 |
+ float maxGibbsMass() const; |
|
37 | 39 |
bool canUseGibbs(unsigned col) const; |
38 | 40 |
bool canUseGibbs(unsigned c1, unsigned c2) const; |
39 | 41 |
void changeMatrix(unsigned row, unsigned col, float delta); |
... | ... |
@@ -157,23 +157,34 @@ void SingleThreadedGibbsSampler<DataModel>::death() |
157 | 157 |
AtomType *atom = mDomain.randomAtom(&mRng); |
158 | 158 |
unsigned row = (atom->pos() / mBinLength) / mNumPatterns; |
159 | 159 |
unsigned col = (atom->pos() / mBinLength) % mNumPatterns; |
160 |
- |
|
161 |
- // try to do a rebirth in the place of this atom |
|
160 |
+ // determine mass to attempt rebirth with |
|
161 |
+ float rebirthMass = atom->mass(); // default rebirth mass == no change to atom |
|
162 |
+ AlphaParameters alpha = DataModel::alphaParametersWithChange(row, col, -1.f * atom->mass()) |
|
163 |
+ * DataModel::annealingTemp(); |
|
162 | 164 |
if (DataModel::canUseGibbs(col)) |
163 | 165 |
{ |
164 |
- OptionalFloat mass = DataModel::sampleDeathAndRebirth(row, col, |
|
165 |
- -1.f * atom->mass(), &mRng); |
|
166 |
- if (mass.hasValue() && mass.value() >= gaps::epsilon) |
|
166 |
+ OptionalFloat gMass = gibbsMass(alpha, 0.f, DataModel::maxGibbsMass(), &mRng, |
|
167 |
+ DataModel::lambda()); |
|
168 |
+ if (gMass.hasValue()) |
|
167 | 169 |
{ |
168 |
- DataModel::safelyChangeMatrix(row, col, mass.value() - atom->mass()); |
|
169 |
- atom->updateMass(mass.value()); |
|
170 |
- return; |
|
170 |
+ rebirthMass = gMass.value(); |
|
171 | 171 |
} |
172 | 172 |
} |
173 |
- |
|
174 |
- // if rebirth fails, then kill off atom |
|
175 |
- DataModel::safelyChangeMatrix(row, col, -1.f * atom->mass()); |
|
176 |
- mDomain.erase(atom); |
|
173 |
+ // handle accept/reject of the rebirth |
|
174 |
+ float deltaLL = rebirthMass * (alpha.s_mu - alpha.s * rebirthMass / 2.f); |
|
175 |
+ if (std::log(mRng.uniform()) < deltaLL) // accept |
|
176 |
+ { |
|
177 |
+ if (rebirthMass != atom->mass()) |
|
178 |
+ { |
|
179 |
+ DataModel::safelyChangeMatrix(row, col, rebirthMass - atom->mass()); |
|
180 |
+ atom->updateMass(rebirthMass); |
|
181 |
+ } |
|
182 |
+ } |
|
183 |
+ else // reject |
|
184 |
+ { |
|
185 |
+ DataModel::safelyChangeMatrix(row, col, -1.f * atom->mass()); |
|
186 |
+ mDomain.erase(atom); |
|
187 |
+ } |
|
177 | 188 |
} |
178 | 189 |
|
179 | 190 |
// move mass from src to dest in the atomic domain |
... | ... |
@@ -74,11 +74,21 @@ uint64_t SparseNormalModel::nPatterns() const |
74 | 74 |
return mMatrix.nCol(); |
75 | 75 |
} |
76 | 76 |
|
77 |
+float SparseNormalModel::annealingTemp() const |
|
78 |
+{ |
|
79 |
+ return mAnnealingTemp; |
|
80 |
+} |
|
81 |
+ |
|
77 | 82 |
float SparseNormalModel::lambda() const |
78 | 83 |
{ |
79 | 84 |
return mLambda; |
80 | 85 |
} |
81 | 86 |
|
87 |
+float SparseNormalModel::maxGibbsMass() const |
|
88 |
+{ |
|
89 |
+ return mMaxGibbsMass; |
|
90 |
+} |
|
91 |
+ |
|
82 | 92 |
bool SparseNormalModel::canUseGibbs(unsigned col) const |
83 | 93 |
{ |
84 | 94 |
return !gaps::isVectorZero(mOtherMatrix->getCol(col)); |
... | ... |
@@ -33,7 +33,9 @@ public: |
33 | 33 |
protected: |
34 | 34 |
uint64_t nElements() const; |
35 | 35 |
uint64_t nPatterns() const; |
36 |
+ float annealingTemp() const; |
|
36 | 37 |
float lambda() const; |
38 |
+ float maxGibbsMass() const; |
|
37 | 39 |
bool canUseGibbs(unsigned col) const; |
38 | 40 |
bool canUseGibbs(unsigned c1, unsigned c2) const; |
39 | 41 |
void changeMatrix(unsigned row, unsigned col, float delta); |