... | ... |
@@ -538,12 +538,24 @@ template <class T, class MatA, class MatB> |
538 | 538 |
OptionalFloat GibbsSampler<T, MatA, MatB>::gibbsMass(AlphaParameters alpha, |
539 | 539 |
GapsRng *rng) |
540 | 540 |
{ |
541 |
- alpha *= mAnnealingTemp; |
|
541 |
+ alpha.s *= mAnnealingTemp; |
|
542 |
+ alpha.su *= mAnnealingTemp; |
|
543 |
+ |
|
542 | 544 |
if (alpha.s > gaps::epsilon) |
543 | 545 |
{ |
544 | 546 |
float mean = (alpha.su - mLambda) / alpha.s; |
545 | 547 |
float sd = 1.f / std::sqrt(alpha.s); |
546 |
- return rng->truncNormal(gaps::epsilon, mMaxGibbsMass / mLambda, mean, sd); |
|
548 |
+ float pLower = gaps::p_norm(0.f, mean, sd); |
|
549 |
+ |
|
550 |
+ if (pLower < 1.f) |
|
551 |
+ { |
|
552 |
+ float m = rng->inverseNormSample(pLower, 1.f, mean, sd); |
|
553 |
+ float gMass = gaps::min(m, mMaxGibbsMass / mLambda); |
|
554 |
+ if (gMass >= gaps::epsilon) |
|
555 |
+ { |
|
556 |
+ return OptionalFloat(gMass); |
|
557 |
+ } |
|
558 |
+ } |
|
547 | 559 |
} |
548 | 560 |
return OptionalFloat(); |
549 | 561 |
} |
... | ... |
@@ -552,12 +564,22 @@ template <class T, class MatA, class MatB> |
552 | 564 |
OptionalFloat GibbsSampler<T, MatA, MatB>::gibbsMass(AlphaParameters alpha, |
553 | 565 |
float m1, float m2, GapsRng *rng) |
554 | 566 |
{ |
555 |
- alpha *= mAnnealingTemp; |
|
567 |
+ alpha.s *= mAnnealingTemp; |
|
568 |
+ alpha.su *= mAnnealingTemp; |
|
569 |
+ |
|
556 | 570 |
if (alpha.s > gaps::epsilon) |
557 | 571 |
{ |
558 | 572 |
float mean = alpha.su / alpha.s; // lambda cancels out |
559 | 573 |
float sd = 1.f / std::sqrt(alpha.s); |
560 |
- return rng->truncNormal(-m1, m2, mean, sd); |
|
574 |
+ float pLower = gaps::p_norm(-m1, mean, sd); |
|
575 |
+ float pUpper = gaps::p_norm(m2, mean, sd); |
|
576 |
+ |
|
577 |
+ if (!(pLower > 0.95f || pUpper < 0.05f)) |
|
578 |
+ { |
|
579 |
+ float delta = rng->inverseNormSample(pLower, pUpper, mean, sd); |
|
580 |
+ float gMass = gaps::min(gaps::max(-m1, delta), m2); // conserve mass |
|
581 |
+ return OptionalFloat(gMass); |
|
582 |
+ } |
|
561 | 583 |
} |
562 | 584 |
return OptionalFloat(); |
563 | 585 |
} |
... | ... |
@@ -594,4 +616,4 @@ Archive& operator>>(Archive &ar, GibbsSampler<T, MatA, MatB> &s) |
594 | 616 |
return ar; |
595 | 617 |
} |
596 | 618 |
|
597 |
-#endif // __COGAPS_GIBBS_SAMPLER_H__ |
|
619 |
+#endif // __COGAPS_GIBBS_SAMPLER_H__ |
|
598 | 620 |
\ No newline at end of file |
... | ... |
@@ -227,23 +227,14 @@ float GapsRng::exponential(float lambda) |
227 | 227 |
return -1.f * std::log(uniform()) / lambda; |
228 | 228 |
} |
229 | 229 |
|
230 |
-OptionalFloat GapsRng::truncNormal(float a, float b, float mean, float sd) |
|
230 |
+float GapsRng::inverseNormSample(float a, float b, float mean, float sd) |
|
231 | 231 |
{ |
232 |
- float pLower = gaps::p_norm(a, mean, sd); |
|
233 |
- float pUpper = gaps::p_norm(b, mean, sd); |
|
234 |
- |
|
235 |
- if (!(pLower > 0.95f || pUpper < 0.05f)) |
|
232 |
+ float u = uniform(a, b); |
|
233 |
+ while (u == 0.f || u == 1.f) |
|
236 | 234 |
{ |
237 |
- float u = uniform(pLower, pUpper); |
|
238 |
- while (u == 0.f || u == 1.f) |
|
239 |
- { |
|
240 |
- u = uniform(pLower, pUpper); |
|
241 |
- } |
|
242 |
- float ret = gaps::q_norm(u, mean, sd); |
|
243 |
- ret = gaps::min(gaps::max(a, ret), b); // need this for truncation error |
|
244 |
- return OptionalFloat(ret); |
|
235 |
+ u = uniform(a, b); |
|
245 | 236 |
} |
246 |
- OptionalFloat(); |
|
237 |
+ return gaps::q_norm(u, mean, sd); |
|
247 | 238 |
} |
248 | 239 |
|
249 | 240 |
float GapsRng::truncGammaUpper(float b, float shape, float scale) |
... | ... |
@@ -319,4 +310,4 @@ float gaps::d_norm_fast(float d, float mean, float sd) |
319 | 310 |
{ |
320 | 311 |
return std::exp((d - mean) * (d - mean) / (-2.f * sd * sd)) |
321 | 312 |
/ std::sqrt(2.f * gaps::pi * sd * sd); |
322 |
-} |
|
313 |
+} |
|
323 | 314 |
\ No newline at end of file |
... | ... |
@@ -72,7 +72,7 @@ public: |
72 | 72 |
int poisson(double lambda); |
73 | 73 |
float exponential(float lambda); |
74 | 74 |
|
75 |
- OptionalFloat truncNormal(float a, float b, float mean, float sd); |
|
75 |
+ float inverseNormSample(float a, float b, float mean, float sd); |
|
76 | 76 |
float truncGammaUpper(float b, float shape, float scale); |
77 | 77 |
|
78 | 78 |
static void setSeed(uint64_t seed); |
... | ... |
@@ -95,4 +95,4 @@ private: |
95 | 95 |
friend Archive& operator>>(Archive &ar, GapsRng &gen); |
96 | 96 |
}; |
97 | 97 |
|
98 |
-#endif // __COGAPS_RANDOM_H__ |
|
98 |
+#endif // __COGAPS_RANDOM_H__ |
|
99 | 99 |
\ No newline at end of file |