Browse code

Fix bug with death-rebirth

When re-factoring the logic involved in a Death update step, so that
some of it was handled by the DataModel instead of the Sampler, a bug was
introduced. This bug resulted in the rebirth attempt of the Death update
step being rejected if the gibbs sampling step failed, rather than
attempting the rebirth with the original atom's mass. This commit pulls
out the logic and implements the correct Death step entirely in the
Sampler.

sherman5 authored on 29/06/2020 08:34:34
Showing 6 changed files

... ...
@@ -147,23 +147,36 @@ void AsynchronousGibbsSampler<DataModel>::birth(const AtomicProposal &prop)
147 147
 template <class DataModel>
148 148
 void AsynchronousGibbsSampler<DataModel>::death(const AtomicProposal &prop)
149 149
 {
150
-    // try to do a rebirth in the place of this atom
150
+    // determine mass to attempt rebirth with
151
+    float rebirthMass = prop.atom1->mass(); // default rebirth mass == no change to atom
152
+    AlphaParameters alpha = DataModel::alphaParametersWithChange(prop.r1, prop.c1,
153
+        -1.f * prop.atom1->mass()) * DataModel::annealingTemp();
151 154
     if (DataModel::canUseGibbs(prop.c1))
152 155
     {
153
-        OptionalFloat mass = DataModel::sampleDeathAndRebirth(prop.r1, prop.c1,
154
-            -1.f * prop.atom1->mass(), &(prop.rng));
155
-        if (mass.hasValue() && mass.value() >= gaps::epsilon)
156
+        OptionalFloat gMass = gibbsMass(alpha, 0.f, DataModel::maxGibbsMass(), &(prop.rng),
157
+            DataModel::lambda());
158
+        if (gMass.hasValue())
156 159
         {
157
-            mQueue.rejectDeath();
158
-            DataModel::safelyChangeMatrix(prop.r1, prop.c1, mass.value() - prop.atom1->mass());
159
-            prop.atom1->updateMass(mass.value());
160
-            return;
160
+            rebirthMass = gMass.value();
161 161
         }
162 162
     }
163
-    // if rebirth fails, then kill off atom
164
-    mQueue.acceptDeath();
165
-    DataModel::safelyChangeMatrix(prop.r1, prop.c1, -1.f * prop.atom1->mass());
166
-    mDomain.cacheErase(prop.atom1);
163
+    // handle accept/reject of the rebirth
164
+    float deltaLL = rebirthMass * (alpha.s_mu - alpha.s * rebirthMass / 2.f);
165
+    if (std::log(prop.rng.uniform()) < deltaLL) // accept
166
+    {
167
+        mQueue.rejectDeath();
168
+        if (rebirthMass != prop.atom1->mass())
169
+        {
170
+            DataModel::safelyChangeMatrix(prop.r1, prop.c1, rebirthMass - prop.atom1->mass());
171
+            prop.atom1->updateMass(rebirthMass);
172
+        }
173
+    }
174
+    else // reject
175
+    {
176
+        mQueue.acceptDeath();
177
+        DataModel::safelyChangeMatrix(prop.r1, prop.c1, -1.f * prop.atom1->mass());
178
+        mDomain.cacheErase(prop.atom1);
179
+    }
167 180
 }
168 181
 
169 182
 // move mass from src to dest in the atomic domain
... ...
@@ -82,11 +82,21 @@ uint64_t DenseNormalModel::nPatterns() const
82 82
     return mMatrix.nCol();
83 83
 }
84 84
 
85
+float DenseNormalModel::annealingTemp() const
86
+{
87
+    return mAnnealingTemp;
88
+}
89
+
85 90
 float DenseNormalModel::lambda() const
86 91
 {
87 92
     return mLambda;
88 93
 }
89 94
 
95
+float DenseNormalModel::maxGibbsMass() const
96
+{
97
+    return mMaxGibbsMass;
98
+}
99
+
90 100
 bool DenseNormalModel::canUseGibbs(unsigned col) const
91 101
 {
92 102
     return !gaps::isVectorZero(mOtherMatrix->getCol(col));
... ...
@@ -33,7 +33,9 @@ protected:
33 33
     friend class GapsStatistics;
34 34
     uint64_t nElements() const;
35 35
     uint64_t nPatterns() const;
36
+    float annealingTemp() const;
36 37
     float lambda() const;
38
+    float maxGibbsMass() const;
37 39
     bool canUseGibbs(unsigned col) const;
38 40
     bool canUseGibbs(unsigned c1, unsigned c2) const;
39 41
     void changeMatrix(unsigned row, unsigned col, float delta);
... ...
@@ -157,23 +157,34 @@ void SingleThreadedGibbsSampler<DataModel>::death()
157 157
     AtomType *atom = mDomain.randomAtom(&mRng);
158 158
     unsigned row = (atom->pos() / mBinLength) / mNumPatterns;
159 159
     unsigned col = (atom->pos() / mBinLength) % mNumPatterns;
160
-
161
-    // try to do a rebirth in the place of this atom
160
+    // determine mass to attempt rebirth with
161
+    float rebirthMass = atom->mass(); // default rebirth mass == no change to atom
162
+    AlphaParameters alpha = DataModel::alphaParametersWithChange(row, col, -1.f * atom->mass())
163
+        * DataModel::annealingTemp();
162 164
     if (DataModel::canUseGibbs(col))
163 165
     {
164
-        OptionalFloat mass = DataModel::sampleDeathAndRebirth(row, col,
165
-            -1.f * atom->mass(), &mRng);
166
-        if (mass.hasValue() && mass.value() >= gaps::epsilon)
166
+        OptionalFloat gMass = gibbsMass(alpha, 0.f, DataModel::maxGibbsMass(), &mRng,
167
+            DataModel::lambda());
168
+        if (gMass.hasValue())
167 169
         {
168
-            DataModel::safelyChangeMatrix(row, col, mass.value() - atom->mass());
169
-            atom->updateMass(mass.value());
170
-            return;
170
+            rebirthMass = gMass.value();
171 171
         }
172 172
     }
173
-
174
-    // if rebirth fails, then kill off atom
175
-    DataModel::safelyChangeMatrix(row, col, -1.f * atom->mass());
176
-    mDomain.erase(atom);
173
+    // handle accept/reject of the rebirth
174
+    float deltaLL = rebirthMass * (alpha.s_mu - alpha.s * rebirthMass / 2.f);
175
+    if (std::log(mRng.uniform()) < deltaLL) // accept
176
+    {
177
+        if (rebirthMass != atom->mass())
178
+        {
179
+            DataModel::safelyChangeMatrix(row, col, rebirthMass - atom->mass());
180
+            atom->updateMass(rebirthMass);
181
+        }
182
+    }
183
+    else // reject
184
+    {
185
+        DataModel::safelyChangeMatrix(row, col, -1.f * atom->mass());
186
+        mDomain.erase(atom);
187
+    }
177 188
 }
178 189
 
179 190
 // move mass from src to dest in the atomic domain
... ...
@@ -74,11 +74,21 @@ uint64_t SparseNormalModel::nPatterns() const
74 74
     return mMatrix.nCol();
75 75
 }
76 76
 
77
+float SparseNormalModel::annealingTemp() const
78
+{
79
+    return mAnnealingTemp;
80
+}
81
+
77 82
 float SparseNormalModel::lambda() const
78 83
 {
79 84
     return mLambda;
80 85
 }
81 86
 
87
+float SparseNormalModel::maxGibbsMass() const
88
+{
89
+    return mMaxGibbsMass;
90
+}
91
+
82 92
 bool SparseNormalModel::canUseGibbs(unsigned col) const
83 93
 {
84 94
     return !gaps::isVectorZero(mOtherMatrix->getCol(col));
... ...
@@ -33,7 +33,9 @@ public:
33 33
 protected:
34 34
     uint64_t nElements() const;
35 35
     uint64_t nPatterns() const;
36
+    float annealingTemp() const;
36 37
     float lambda() const;
38
+    float maxGibbsMass() const;
37 39
     bool canUseGibbs(unsigned col) const;
38 40
     bool canUseGibbs(unsigned c1, unsigned c2) const;
39 41
     void changeMatrix(unsigned row, unsigned col, float delta);