Browse code

cleaned up gibbs mass calculation for B/D

Tom Sherman authored on 22/05/2018 12:34:05
Showing 1 changed files

... ...
@@ -420,30 +420,20 @@ float GibbsSampler<T, MatA, MatB>::gibbsMass(unsigned row, unsigned col, float m
420 420
     AlphaParameters alpha = impl()->alphaParameters(row, col);
421 421
     alpha.s *= mAnnealingTemp;
422 422
     alpha.su *= mAnnealingTemp;
423
-    float mean = (alpha.su - mLambda) / alpha.s;
424
-    float sd = 1.f / std::sqrt(alpha.s);
425
-    float pLower = gaps::random::p_norm(0.f, mean, sd);
426 423
 
427
-    float newMass = 0.f;
428
-    if (pLower == 1.f || alpha.s < 0.00001f)
429
-    {
430
-        newMass = mass < 0.f ? std::abs(mass) : 0.f;
431
-    }
432
-    else if (pLower >= 0.99f) // what's the point of this? TODO
424
+    if (alpha.s > gaps::algo::epsilon)
433 425
     {
434
-        float tmp1 = gaps::random::d_norm(0.f, mean, sd);
435
-        float tmp2 = gaps::random::d_norm(10.f * mLambda, mean, sd);
426
+        float mean = (alpha.su - mLambda) / alpha.s;
427
+        float sd = 1.f / std::sqrt(alpha.s);
428
+        float pLower = gaps::random::p_norm(0.f, mean, sd);
436 429
 
437
-        if (tmp1 > gaps::algo::epsilon && std::abs(tmp1 - tmp2) < gaps::algo::epsilon)
430
+        if (pLower < 1.f)
438 431
         {
439
-            return mass < 0.f ? 0.0 : mass;
432
+            float m = gaps::random::inverseNormSample(pLower, 1.f, mean, sd);
433
+            return std::max(std::min(m, mMaxGibbsMass), 0.f);
440 434
         }
441 435
     }
442
-    else
443
-    {
444
-        newMass = gaps::random::inverseNormSample(pLower, 1.f, mean, sd);
445
-    }
446
-    return std::max(std::min(newMass, mMaxGibbsMass), 0.f);
436
+    return std::min(mass < 0.f ? std::abs(mass) : 0.f, mMaxGibbsMass);
447 437
 }
448 438
 
449 439
 template <class T, class MatA, class MatB>
... ...
@@ -456,7 +446,7 @@ unsigned r2, unsigned c2, float m2)
456 446
 
457 447
     if (alpha.s > gaps::algo::epsilon)
458 448
     {
459
-        float mean = alpha.su / alpha.s;
449
+        float mean = alpha.su / alpha.s; // TODO why not subtract lambda
460 450
         float sd = 1.f / std::sqrt(alpha.s);
461 451
         float pLower = gaps::random::p_norm(-m1, mean, sd);
462 452
         float pUpper = gaps::random::p_norm(m2, mean, sd);