Browse code

rolled back some changes

Tom Sherman authored on 24/08/2018 04:36:28
Showing3 changed files

... ...
@@ -538,12 +538,24 @@ template <class T, class MatA, class MatB>
538 538
 OptionalFloat GibbsSampler<T, MatA, MatB>::gibbsMass(AlphaParameters alpha,
539 539
 GapsRng *rng)
540 540
 {        
541
-    alpha *= mAnnealingTemp;
541
+    alpha.s *= mAnnealingTemp;
542
+    alpha.su *= mAnnealingTemp;
543
+
542 544
     if (alpha.s > gaps::epsilon)
543 545
     {
544 546
         float mean = (alpha.su - mLambda) / alpha.s;
545 547
         float sd = 1.f / std::sqrt(alpha.s);
546
-        return rng->truncNormal(gaps::epsilon, mMaxGibbsMass / mLambda, mean, sd);
548
+        float pLower = gaps::p_norm(0.f, mean, sd);
549
+
550
+        if (pLower < 1.f)
551
+        {
552
+            float m = rng->inverseNormSample(pLower, 1.f, mean, sd);
553
+            float gMass = gaps::min(m, mMaxGibbsMass / mLambda);
554
+            if (gMass >= gaps::epsilon)
555
+            {
556
+                return OptionalFloat(gMass);
557
+            }
558
+        }
547 559
     }
548 560
     return OptionalFloat();
549 561
 }
... ...
@@ -552,12 +564,22 @@ template <class T, class MatA, class MatB>
552 564
 OptionalFloat GibbsSampler<T, MatA, MatB>::gibbsMass(AlphaParameters alpha,
553 565
 float m1, float m2, GapsRng *rng)
554 566
 {
555
-    alpha *= mAnnealingTemp;
567
+    alpha.s *= mAnnealingTemp;
568
+    alpha.su *= mAnnealingTemp;
569
+
556 570
     if (alpha.s > gaps::epsilon)
557 571
     {
558 572
         float mean = alpha.su / alpha.s; // lambda cancels out
559 573
         float sd = 1.f / std::sqrt(alpha.s);
560
-        return rng->truncNormal(-m1, m2, mean, sd);
574
+        float pLower = gaps::p_norm(-m1, mean, sd);
575
+        float pUpper = gaps::p_norm(m2, mean, sd);
576
+
577
+        if (!(pLower >  0.95f || pUpper < 0.05f))
578
+        {
579
+            float delta = rng->inverseNormSample(pLower, pUpper, mean, sd);
580
+            float gMass = gaps::min(gaps::max(-m1, delta), m2); // conserve mass
581
+            return OptionalFloat(gMass);
582
+        }
561 583
     }
562 584
     return OptionalFloat();
563 585
 }
... ...
@@ -594,4 +616,4 @@ Archive& operator>>(Archive &ar, GibbsSampler<T, MatA, MatB> &s)
594 616
     return ar;
595 617
 }
596 618
 
597
-#endif // __COGAPS_GIBBS_SAMPLER_H__
619
+#endif // __COGAPS_GIBBS_SAMPLER_H__
598 620
\ No newline at end of file
... ...
@@ -227,23 +227,14 @@ float GapsRng::exponential(float lambda)
227 227
     return -1.f * std::log(uniform()) / lambda;
228 228
 }
229 229
 
230
-OptionalFloat GapsRng::truncNormal(float a, float b, float mean, float sd)
230
+float GapsRng::inverseNormSample(float a, float b, float mean, float sd)
231 231
 {
232
-    float pLower = gaps::p_norm(a, mean, sd);
233
-    float pUpper = gaps::p_norm(b, mean, sd);
234
-
235
-    if (!(pLower >  0.95f || pUpper < 0.05f))
232
+    float u = uniform(a, b);
233
+    while (u == 0.f || u == 1.f)
236 234
     {
237
-        float u = uniform(pLower, pUpper);
238
-        while (u == 0.f || u == 1.f)
239
-        {
240
-            u = uniform(pLower, pUpper);
241
-        }
242
-        float ret = gaps::q_norm(u, mean, sd);
243
-        ret = gaps::min(gaps::max(a, ret), b); // need this for truncation error
244
-        return OptionalFloat(ret);
235
+        u = uniform(a, b);
245 236
     }
246
-    OptionalFloat();
237
+    return gaps::q_norm(u, mean, sd);
247 238
 }
248 239
 
249 240
 float GapsRng::truncGammaUpper(float b, float shape, float scale)
... ...
@@ -319,4 +310,4 @@ float gaps::d_norm_fast(float d, float mean, float sd)
319 310
 {
320 311
     return std::exp((d - mean) * (d - mean) / (-2.f * sd * sd))
321 312
         / std::sqrt(2.f * gaps::pi * sd * sd);
322
-}
313
+}
323 314
\ No newline at end of file
... ...
@@ -72,7 +72,7 @@ public:
72 72
     int poisson(double lambda);
73 73
     float exponential(float lambda);
74 74
 
75
-    OptionalFloat truncNormal(float a, float b, float mean, float sd);
75
+    float inverseNormSample(float a, float b, float mean, float sd);
76 76
     float truncGammaUpper(float b, float shape, float scale);
77 77
 
78 78
     static void setSeed(uint64_t seed);
... ...
@@ -95,4 +95,4 @@ private:
95 95
     friend Archive& operator>>(Archive &ar, GapsRng &gen);
96 96
 };
97 97
 
98
-#endif // __COGAPS_RANDOM_H__
98
+#endif // __COGAPS_RANDOM_H__
99 99
\ No newline at end of file