... | ... |
@@ -538,24 +538,12 @@ template <class T, class MatA, class MatB> |
538 | 538 |
OptionalFloat GibbsSampler<T, MatA, MatB>::gibbsMass(AlphaParameters alpha, |
539 | 539 |
GapsRng *rng) |
540 | 540 |
{ |
541 |
- alpha.s *= mAnnealingTemp; |
|
542 |
- alpha.su *= mAnnealingTemp; |
|
543 |
- |
|
541 |
+ alpha *= mAnnealingTemp; |
|
544 | 542 |
if (alpha.s > gaps::epsilon) |
545 | 543 |
{ |
546 | 544 |
float mean = (alpha.su - mLambda) / alpha.s; |
547 | 545 |
float sd = 1.f / std::sqrt(alpha.s); |
548 |
- float pLower = gaps::p_norm(0.f, mean, sd); |
|
549 |
- |
|
550 |
- if (pLower < 1.f) |
|
551 |
- { |
|
552 |
- float m = rng->inverseNormSample(pLower, 1.f, mean, sd); |
|
553 |
- float gMass = gaps::min(m, mMaxGibbsMass / mLambda); |
|
554 |
- if (gMass >= gaps::epsilon) |
|
555 |
- { |
|
556 |
- return OptionalFloat(gMass); |
|
557 |
- } |
|
558 |
- } |
|
546 |
+ return rng->truncNormal(gaps::epsilon, mMaxGibbsMass / mLambda, mean, sd); |
|
559 | 547 |
} |
560 | 548 |
return OptionalFloat(); |
561 | 549 |
} |
... | ... |
@@ -564,22 +552,12 @@ template <class T, class MatA, class MatB> |
564 | 552 |
OptionalFloat GibbsSampler<T, MatA, MatB>::gibbsMass(AlphaParameters alpha, |
565 | 553 |
float m1, float m2, GapsRng *rng) |
566 | 554 |
{ |
567 |
- alpha.s *= mAnnealingTemp; |
|
568 |
- alpha.su *= mAnnealingTemp; |
|
569 |
- |
|
555 |
+ alpha *= mAnnealingTemp; |
|
570 | 556 |
if (alpha.s > gaps::epsilon) |
571 | 557 |
{ |
572 | 558 |
float mean = alpha.su / alpha.s; // lambda cancels out |
573 | 559 |
float sd = 1.f / std::sqrt(alpha.s); |
574 |
- float pLower = gaps::p_norm(-m1, mean, sd); |
|
575 |
- float pUpper = gaps::p_norm(m2, mean, sd); |
|
576 |
- |
|
577 |
- if (!(pLower > 0.95f || pUpper < 0.05f)) |
|
578 |
- { |
|
579 |
- float delta = rng->inverseNormSample(pLower, pUpper, mean, sd); |
|
580 |
- float gMass = gaps::min(gaps::max(-m1, delta), m2); // conserve mass |
|
581 |
- return OptionalFloat(gMass); |
|
582 |
- } |
|
560 |
+ return rng->truncNormal(-m1, m2, mean, sd); |
|
583 | 561 |
} |
584 | 562 |
return OptionalFloat(); |
585 | 563 |
} |
... | ... |
@@ -19,9 +19,15 @@ struct AlphaParameters |
19 | 19 |
AlphaParameters operator+(const AlphaParameters &other) const |
20 | 20 |
{ |
21 | 21 |
float rs = s + other.s; |
22 |
- float rsu = su - other.su; // weird |
|
22 |
+ float rsu = su - other.su; // not a typo |
|
23 | 23 |
return AlphaParameters(rs, rsu); |
24 | 24 |
} |
25 |
+ |
|
26 |
+ void operator*=(float v) |
|
27 |
+ { |
|
28 |
+ s *= v; |
|
29 |
+ su *= v; |
|
30 |
+ } |
|
25 | 31 |
}; |
26 | 32 |
|
27 | 33 |
namespace gaps |
... | ... |
@@ -227,14 +227,24 @@ float GapsRng::exponential(float lambda) |
227 | 227 |
return -1.f * std::log(uniform()) / lambda; |
228 | 228 |
} |
229 | 229 |
|
230 |
-float GapsRng::inverseNormSample(float a, float b, float mean, float sd) |
|
230 |
+OptionalFloat GapsRng::truncNormal(float a, float b, float mean, float sd) |
|
231 | 231 |
{ |
232 |
- float u = uniform(a, b); |
|
233 |
- while (u == 0.f || u == 1.f) |
|
232 |
+ float pLower = gaps::p_norm(a, mean, sd); |
|
233 |
+ float pUpper = gaps::p_norm(b, mean, sd); |
|
234 |
+ |
|
235 |
+ if (!(pLower > 0.95f || pUpper < 0.05f)) |
|
234 | 236 |
{ |
235 |
- u = uniform(a, b); |
|
237 |
+ float u = uniform(pLower, pUpper); |
|
238 |
+ while (u == 0.f || u == 1.f) |
|
239 |
+ { |
|
240 |
+ u = uniform(pLower, pUpper); |
|
241 |
+ } |
|
242 |
+ float ret = gaps::q_norm(u, mean, sd); |
|
243 |
+ GAPS_ASSERT(ret >= a); |
|
244 |
+ GAPS_ASSERT(ret <= b); |
|
245 |
+ return OptionalFloat(ret); |
|
236 | 246 |
} |
237 |
- return gaps::q_norm(u, mean, sd); |
|
247 |
+ OptionalFloat(); |
|
238 | 248 |
} |
239 | 249 |
|
240 | 250 |
float GapsRng::truncGammaUpper(float b, float shape, float scale) |
... | ... |
@@ -72,7 +72,7 @@ public: |
72 | 72 |
int poisson(double lambda); |
73 | 73 |
float exponential(float lambda); |
74 | 74 |
|
75 |
- float inverseNormSample(float a, float b, float mean, float sd); |
|
75 |
+ OptionalFloat truncNormal(float a, float b, float mean, float sd); |
|
76 | 76 |
float truncGammaUpper(float b, float shape, float scale); |
77 | 77 |
|
78 | 78 |
static void setSeed(uint64_t seed); |