Browse code

bundle truncated normal in it's own function

Tom Sherman authored on 23/08/2018 23:01:18
Showing 4 changed files

... ...
@@ -538,24 +538,12 @@ template <class T, class MatA, class MatB>
538 538
 OptionalFloat GibbsSampler<T, MatA, MatB>::gibbsMass(AlphaParameters alpha,
539 539
 GapsRng *rng)
540 540
 {        
541
-    alpha.s *= mAnnealingTemp;
542
-    alpha.su *= mAnnealingTemp;
543
-
541
+    alpha *= mAnnealingTemp;
544 542
     if (alpha.s > gaps::epsilon)
545 543
     {
546 544
         float mean = (alpha.su - mLambda) / alpha.s;
547 545
         float sd = 1.f / std::sqrt(alpha.s);
548
-        float pLower = gaps::p_norm(0.f, mean, sd);
549
-
550
-        if (pLower < 1.f)
551
-        {
552
-            float m = rng->inverseNormSample(pLower, 1.f, mean, sd);
553
-            float gMass = gaps::min(m, mMaxGibbsMass / mLambda);
554
-            if (gMass >= gaps::epsilon)
555
-            {
556
-                return OptionalFloat(gMass);
557
-            }
558
-        }
546
+        return rng->truncNormal(gaps::epsilon, mMaxGibbsMass / mLambda, mean, sd);
559 547
     }
560 548
     return OptionalFloat();
561 549
 }
... ...
@@ -564,22 +552,12 @@ template <class T, class MatA, class MatB>
564 552
 OptionalFloat GibbsSampler<T, MatA, MatB>::gibbsMass(AlphaParameters alpha,
565 553
 float m1, float m2, GapsRng *rng)
566 554
 {
567
-    alpha.s *= mAnnealingTemp;
568
-    alpha.su *= mAnnealingTemp;
569
-
555
+    alpha *= mAnnealingTemp;
570 556
     if (alpha.s > gaps::epsilon)
571 557
     {
572 558
         float mean = alpha.su / alpha.s; // lambda cancels out
573 559
         float sd = 1.f / std::sqrt(alpha.s);
574
-        float pLower = gaps::p_norm(-m1, mean, sd);
575
-        float pUpper = gaps::p_norm(m2, mean, sd);
576
-
577
-        if (!(pLower >  0.95f || pUpper < 0.05f))
578
-        {
579
-            float delta = rng->inverseNormSample(pLower, pUpper, mean, sd);
580
-            float gMass = gaps::min(gaps::max(-m1, delta), m2); // conserve mass
581
-            return OptionalFloat(gMass);
582
-        }
560
+        return rng->truncNormal(-m1, m2, mean, sd);
583 561
     }
584 562
     return OptionalFloat();
585 563
 }
... ...
@@ -19,9 +19,15 @@ struct AlphaParameters
19 19
     AlphaParameters operator+(const AlphaParameters &other) const
20 20
     {
21 21
         float rs = s + other.s;
22
-        float rsu = su - other.su; // weird
22
+        float rsu = su - other.su; // not a typo
23 23
         return AlphaParameters(rs, rsu);
24 24
     }
25
+
26
+    void operator*=(float v)
27
+    {
28
+        s *= v;
29
+        su *= v;
30
+    }
25 31
 };
26 32
 
27 33
 namespace gaps
... ...
@@ -227,14 +227,24 @@ float GapsRng::exponential(float lambda)
227 227
     return -1.f * std::log(uniform()) / lambda;
228 228
 }
229 229
 
230
-float GapsRng::inverseNormSample(float a, float b, float mean, float sd)
230
+OptionalFloat GapsRng::truncNormal(float a, float b, float mean, float sd)
231 231
 {
232
-    float u = uniform(a, b);
233
-    while (u == 0.f || u == 1.f)
232
+    float pLower = gaps::p_norm(a, mean, sd);
233
+    float pUpper = gaps::p_norm(b, mean, sd);
234
+
235
+    if (!(pLower >  0.95f || pUpper < 0.05f))
234 236
     {
235
-        u = uniform(a, b);
237
+        float u = uniform(pLower, pUpper);
238
+        while (u == 0.f || u == 1.f)
239
+        {
240
+            u = uniform(pLower, pUpper);
241
+        }
242
+        float ret = gaps::q_norm(u, mean, sd);
243
+        GAPS_ASSERT(ret >= a);
244
+        GAPS_ASSERT(ret <= b);
245
+        return OptionalFloat(ret);
236 246
     }
237
-    return gaps::q_norm(u, mean, sd);
247
+    OptionalFloat();
238 248
 }
239 249
 
240 250
 float GapsRng::truncGammaUpper(float b, float shape, float scale)
... ...
@@ -72,7 +72,7 @@ public:
72 72
     int poisson(double lambda);
73 73
     float exponential(float lambda);
74 74
 
75
-    float inverseNormSample(float a, float b, float mean, float sd);
75
+    OptionalFloat truncNormal(float a, float b, float mean, float sd);
76 76
     float truncGammaUpper(float b, float shape, float scale);
77 77
 
78 78
     static void setSeed(uint64_t seed);