Browse code

no longer crashing with checkpoints; still not consistent

Tom Sherman authored on 01/10/2018 17:41:04
Showing 14 changed files

... ...
@@ -1,5 +1,5 @@
1 1
 Package: CoGAPS
2
-Version: 3.3.32
2
+Version: 3.3.33
3 3
 Date: 2018-04-24
4 4
 Title: Coordinated Gene Activity in Pattern Sets
5 5
 Author: Thomas Sherman, Wai-shing Lee, Conor Kelton, Ondrej Maxian, Jacob Carey,
... ...
@@ -76,31 +76,34 @@ void IndexFlagSet::release(uint64_t n)
76 76
 
77 77
 AtomPool::AtomPool()
78 78
 {
79
-    mPool = new Atom[POOL_SIZE];
79
+    //mPool = new Atom[POOL_SIZE];
80 80
 }
81 81
 
82 82
 AtomPool::~AtomPool()
83 83
 {
84
-    delete[] mPool;
84
+    //delete[] mPool;
85 85
 }
86 86
 
87 87
 Atom* AtomPool::alloc()
88 88
 {
89
-    unsigned n = mIndexFlags.getFirstFree();
90
-    mIndexFlags.set(n);
91
-    Atom *a = &(mPool[n]);
92
-    a->poolIndex = n;
93
-    return a;
89
+    //unsigned n = mIndexFlags.getFirstFree();
90
+    //mIndexFlags.set(n);
91
+    //Atom *a = &(mPool[n]);
92
+    //a->poolIndex = n;
93
+    //return a;
94
+    return new Atom();
94 95
 }
95 96
 
96 97
 void AtomPool::free(Atom* a)
97 98
 {
98
-    mIndexFlags.release(a->poolIndex);
99
+    //mIndexFlags.release(a->poolIndex);
100
+    delete a;
99 101
 }
100 102
 
101 103
 bool AtomPool::depleted() const
102 104
 {
103
-    return !mIndexFlags.isAnyFree();
105
+    //return !mIndexFlags.isAnyFree();
106
+    return false;
104 107
 }
105 108
 
106 109
 ////////////////////////////// AtomAllocator ///////////////////////////////////
... ...
@@ -30,6 +30,7 @@ public:
30 30
 
31 31
     AtomicDomain(uint64_t nBins);
32 32
 
33
+    // TODO can we have internal rng since these are always called sequentially
33 34
     // access atoms
34 35
     Atom* front();
35 36
     Atom* randomAtom(GapsRng *rng);
... ...
@@ -242,6 +242,7 @@ void GapsRunner::createCheckpoint()
242 242
     
243 243
         // create checkpoint file
244 244
         Archive ar(mCheckpointOutFile, ARCHIVE_WRITE);
245
+        GapsRng::save(ar);
245 246
         ar << mNumPatterns << mSeed << mASampler << mPSampler << mStatistics
246 247
             << mFixedMatrix << mMaxIterations << mPhase << mCurrentIteration
247 248
             << mNumUpdatesA << mNumUpdatesP << mRng;
... ...
@@ -84,7 +84,7 @@ template <class DataType>
84 84
 GapsRunner::GapsRunner(const DataType &data, bool transposeData,
85 85
 unsigned nPatterns, bool partitionRows, const std::vector<unsigned> &indices)
86 86
     :
87
-mASampler(data, !transposeData, nPatterns,!partitionRows, indices),
87
+mASampler(data, !transposeData, nPatterns, !partitionRows, indices),
88 88
 mPSampler(data, transposeData, nPatterns, partitionRows, indices),
89 89
 mStatistics(mPSampler.dataRows(), mPSampler.dataCols(), nPatterns),
90 90
 mFixedMatrix('N'), mMaxIterations(1000), mMaxThreads(1), mPrintMessages(true),
... ...
@@ -120,24 +120,18 @@ void GibbsSampler::processProposal(const AtomicProposal &prop)
120 120
 // exponential distribution or with the gibbs mass distribution
121 121
 void GibbsSampler::birth(const AtomicProposal &prop)
122 122
 {
123
-    // calculate proposed mass
124
-    float mass = 0.f;
125
-    if (canUseGibbs(prop.c1))
126
-    {
127
-        mass = gibbsMass(alphaParameters(prop.r1, prop.c1) * mAnnealingTemp,
128
-            0.f, mMaxGibbsMass, mLambda, &(prop.rng)).value();
129
-    }
130
-    else
131
-    {
132
-        mass = prop.rng.exponential(mLambda);
133
-    }
134
-
135
-    // accept mass as long as it's non-zero
136
-    if (mass >= gaps::epsilon)
123
+    // try to get mass using gibbs, resort to exponential if needed
124
+    OptionalFloat mass = canUseGibbs(prop.c1)
125
+        ? gibbsMass(alphaParameters(prop.r1, prop.c1) * mAnnealingTemp, 0.f,
126
+            mMaxGibbsMass, mLambda, &(prop.rng))
127
+        : prop.rng.exponential(mLambda);
128
+
129
+    // accept mass as long as gibbs succeeded or it's non-zero
130
+    if (mass.hasValue() && mass.value() >= gaps::epsilon)
137 131
     {
138 132
         mQueue.acceptBirth();
139
-        prop.atom1->mass = mass;
140
-        changeMatrix(prop.r1, prop.c1, mass);
133
+        prop.atom1->mass = mass.value();
134
+        changeMatrix(prop.r1, prop.c1, mass.value());
141 135
     }
142 136
     else
143 137
     {
... ...
@@ -154,7 +148,6 @@ void GibbsSampler::death(const AtomicProposal &prop)
154 148
     AlphaParameters alpha = alphaParametersWithChange(prop.r1, prop.c1,
155 149
         -prop.atom1->mass);
156 150
 
157
-    // try to calculate rebirth mass using gibbs distribution, otherwise use original
158 151
     float rebirthMass = prop.atom1->mass;
159 152
     if (canUseGibbs(prop.c1))
160 153
     {
... ...
@@ -223,7 +216,7 @@ AlphaParameters alpha)
223 216
 {
224 217
     // compute amount of mass to be exchanged
225 218
     float totalMass = prop.atom1->mass + prop.atom2->mass;
226
-    float newMass = prop.rng.truncGammaUpper(totalMass, 2.f, 1.f / mLambda);
219
+    float newMass = prop.rng.truncGammaUpper(totalMass, 1.f / mLambda);
227 220
 
228 221
     // compute amount to change atom1 by - always change larger mass to newMass
229 222
     float delta = (prop.atom1->mass > prop.atom2->mass)
... ...
@@ -266,8 +259,9 @@ void GibbsSampler::acceptExchange(const AtomicProposal &prop, float delta)
266 259
 void GibbsSampler::changeMatrix(unsigned row, unsigned col, float delta)
267 260
 {
268 261
     mMatrix(row, col) += delta;
269
-    GAPS_ASSERT(mMatrix(row, col) >= 0.f);
270 262
     updateAPMatrix(row, col, delta);
263
+
264
+    GAPS_ASSERT(mMatrix(row, col) >= 0.f);
271 265
 }
272 266
 
273 267
 // delta could be negative, this is needed to prevent negative values in matrix
... ...
@@ -340,13 +334,17 @@ unsigned col, float ch)
340 334
 
341 335
 Archive& operator<<(Archive &ar, GibbsSampler &s)
342 336
 {
343
-    // TODO
337
+    ar << s.mMatrix << s.mDomain << s.mQueue << s.mAlpha
338
+        << s.mLambda << s.mMaxGibbsMass << s.mAnnealingTemp << s.mNumPatterns
339
+        << s.mNumBins << s.mBinLength << s.mDomainLength;
344 340
     return ar;
345 341
 }
346 342
 
347 343
 Archive& operator>>(Archive &ar, GibbsSampler &s)
348 344
 {
349
-    // TODO
345
+    ar >> s.mMatrix >> s.mDomain >> s.mQueue >> s.mAlpha
346
+        >> s.mLambda >> s.mMaxGibbsMass >> s.mAnnealingTemp >> s.mNumPatterns
347
+        >> s.mNumBins >> s.mBinLength >> s.mDomainLength;
350 348
     return ar;
351 349
 }
352 350
 
... ...
@@ -1,6 +1,9 @@
1 1
 #ifndef __COGAPS_GIBBS_SAMPLER_H__
2 2
 #define __COGAPS_GIBBS_SAMPLER_H__
3 3
 
4
+#define DEFAULT_ALPHA           0.01f
5
+#define DEFAULT_MAX_GIBBS_MASS  100.f
6
+
4 7
 #include "AtomicDomain.h"
5 8
 #include "ProposalQueue.h"
6 9
 #include "data_structures/Matrix.h"
... ...
@@ -112,7 +115,7 @@ mBinLength(std::numeric_limits<uint64_t>::max() / mNumBins),
112 115
 mDomainLength(mBinLength * mNumBins)
113 116
 {
114 117
     // default sparsity parameters
115
-    setSparsity(0.01, 100.f, false);
118
+    setSparsity(DEFAULT_ALPHA, DEFAULT_MAX_GIBBS_MASS, false);
116 119
 }
117 120
 
118 121
 template <class DataType>
... ...
@@ -251,14 +251,14 @@ bool ProposalQueue::exchange(AtomicDomain &domain)
251 251
 
252 252
     if (prop.r1 == prop.r2 && prop.c1 == prop.c2)
253 253
     {
254
-        float newMass = prop.rng.truncGammaUpper(prop.atom1->mass + prop.atom2->mass, 2.f, 1.f / mLambda);
254
+        float newMass = prop.rng.truncGammaUpper(prop.atom1->mass + prop.atom2->mass, 1.f / mLambda);
255 255
         float delta = (prop.atom1->mass > prop.atom2->mass) ? newMass - prop.atom1->mass : prop.atom2->mass - newMass;
256 256
         if (prop.atom1->mass + delta > gaps::epsilon && prop.atom2->mass - delta > gaps::epsilon)
257 257
         {
258 258
             prop.atom1->mass += delta;
259 259
             prop.atom2->mass -= delta;
260 260
         }        
261
-        return true;
261
+        return true; // automatically accept exchanges in same bin
262 262
     }
263 263
 
264 264
     mQueue.push_back(prop);
... ...
@@ -4,6 +4,9 @@
4 4
 
5 5
 TEST_CASE("AtomicDomain")
6 6
 {
7
+    GapsRng::setSeed(123);
8
+    GapsRng rng;
9
+
7 10
     SECTION("Construction")
8 11
     {
9 12
         AtomicDomain domain(10);
... ...
@@ -47,7 +50,7 @@ TEST_CASE("AtomicDomain")
47 50
         for (unsigned i = 0; i < 1000; ++i)
48 51
         {
49 52
             domain.insert(i, static_cast<float>(i));
50
-            REQUIRE_NOTHROW(domain.randomFreePosition());
53
+            REQUIRE_NOTHROW(domain.randomFreePosition(&rng));
51 54
         }
52 55
     }
53 56
 
... ...
@@ -60,10 +63,10 @@ TEST_CASE("AtomicDomain")
60 63
             domain.insert(i, static_cast<float>(i));
61 64
             
62 65
             // single random atom
63
-            REQUIRE(domain.randomAtom()->pos < i + 1);
66
+            REQUIRE(domain.randomAtom(&rng)->pos < i + 1);
64 67
 
65 68
             // random atom for exchange
66
-            AtomNeighborhood hood = domain.randomAtomWithRightNeighbor();
69
+            AtomNeighborhood hood = domain.randomAtomWithRightNeighbor(&rng);
67 70
             REQUIRE(hood.center->pos < i + 1);
68 71
 
69 72
             REQUIRE(!hood.hasLeft());
... ...
@@ -78,7 +81,7 @@ TEST_CASE("AtomicDomain")
78 81
             }
79 82
 
80 83
             // random atom for move
81
-            hood = domain.randomAtomWithNeighbors();
84
+            hood = domain.randomAtomWithNeighbors(&rng);
82 85
             REQUIRE(hood.center->pos < i + 1);
83 86
 
84 87
             if (hood.center->pos == 0)
... ...
@@ -4,10 +4,11 @@
4 4
 #include "../GibbsSampler.h"
5 5
 #include "../math/Random.h"
6 6
 
7
-#if 0
8
-
9 7
 TEST_CASE("Test utils/Archive.h")
10 8
 {
9
+    GapsRng::setSeed(123);
10
+    GapsRng rng;
11
+
11 12
     SECTION("Reading/Writing to an Archive")
12 13
     {
13 14
         Archive ar1("test_ar.temp", ARCHIVE_WRITE);
... ...
@@ -59,13 +60,13 @@ TEST_CASE("Test utils/Archive.h")
59 60
         REQUIRE(d_read == d_write);
60 61
         REQUIRE(b_read == b_write);
61 62
     }
62
-    
63
+
63 64
     SECTION("Vector Serialization")
64 65
     {
65 66
         Vector vec_read(100), vec_write(100);
66 67
         for (unsigned i = 0; i < 100; ++i)
67 68
         {
68
-            vec_write[i] = gaps::random::uniform(0.0, 2.0);
69
+            vec_write[i] = rng.uniform(0.f, 2.f);
69 70
         }
70 71
 
71 72
         Archive arWrite("test_ar.temp", ARCHIVE_WRITE);
... ...
@@ -86,30 +87,24 @@ TEST_CASE("Test utils/Archive.h")
86 87
 
87 88
     SECTION("Matrix Serialization")
88 89
     {
89
-        RowMatrix rMat_read(100,100), rMat_write(100,100);
90 90
         ColMatrix cMat_read(100,100), cMat_write(100,100);
91 91
 
92 92
         for (unsigned i = 0; i < 100; ++i)
93 93
         {
94 94
             for (unsigned j = 0; j < 100; ++j)
95 95
             {
96
-                rMat_write(i,j) = gaps::random::uniform(0.0, 2.0);
97
-                cMat_write(i,j) = gaps::random::uniform(0.0, 2.0);
96
+                cMat_write(i,j) = rng.uniform(0.f, 2.f);
98 97
             }
99 98
         }
100 99
 
101 100
         Archive arWrite("test_ar.temp", ARCHIVE_WRITE);
102
-        arWrite << rMat_write;
103 101
         arWrite << cMat_write;
104 102
         arWrite.close();
105 103
 
106 104
         Archive arRead("test_ar.temp", ARCHIVE_READ);
107
-        arRead >> rMat_read;
108 105
         arRead >> cMat_read;
109 106
         arRead.close();
110 107
 
111
-        REQUIRE(rMat_read.nRow() == rMat_write.nRow());
112
-        REQUIRE(rMat_read.nCol() == rMat_write.nCol());
113 108
         REQUIRE(cMat_read.nRow() == cMat_write.nRow());
114 109
         REQUIRE(cMat_read.nCol() == cMat_write.nCol());
115 110
     
... ...
@@ -117,7 +112,6 @@ TEST_CASE("Test utils/Archive.h")
117 112
         {
118 113
             for (unsigned j = 0; j < 100; ++j)
119 114
             {
120
-                REQUIRE(rMat_read(i,j) == rMat_write(i,j));
121 115
                 REQUIRE(cMat_read(i,j) == cMat_write(i,j));
122 116
             }
123 117
         }
... ...
@@ -125,9 +119,34 @@ TEST_CASE("Test utils/Archive.h")
125 119
 
126 120
     SECTION("GibbsSampler Serialization")
127 121
     {
128
-        //TODO
122
+        Rcpp::Environment env = Rcpp::Environment::global_env();
123
+        std::string csvPath = Rcpp::as<std::string>(env["gistCsvPath"]);
124
+
125
+        GibbsSampler Asampler(csvPath, false, 7, false, std::vector<unsigned>());
126
+        GibbsSampler Psampler(csvPath, true, 7, false, std::vector<unsigned>());
127
+        Asampler.sync(Psampler);
128
+        Psampler.sync(Asampler);
129
+        
130
+        Asampler.update(10000, 1);
131
+
132
+        Archive arWrite("test_ar.temp", ARCHIVE_WRITE);
133
+        arWrite << Asampler;
134
+        arWrite.close();
135
+
136
+        GibbsSampler savedAsampler(csvPath, false, 7, false, std::vector<unsigned>());
137
+        Archive arRead("test_ar.temp", ARCHIVE_READ);
138
+        arRead >> savedAsampler;
139
+        arRead.close();
140
+    }
141
+
142
+#ifdef GAPS_INTERNAL_TESTS    
143
+    SECTION("AtomicDomain Serialization")
144
+    {
145
+
129 146
     }
147
+#endif
130 148
 
149
+/*
131 150
     SECTION("Random Generator Serialization")
132 151
     {
133 152
         std::vector<float> randSequence;
... ...
@@ -163,9 +182,7 @@ TEST_CASE("Test utils/Archive.h")
163 182
             REQUIRE(gaps::random::exponential(5.5) == randSequence[i]);
164 183
         }
165 184
     }
166
-
185
+*/
167 186
     // cleanup directory
168 187
     std::remove("test_ar.temp");
169
-}
170
-
171
-#endif
172 188
\ No newline at end of file
189
+}
173 190
\ No newline at end of file
... ...
@@ -181,8 +181,8 @@ Archive& operator>>(Archive &ar, ColMatrix &mat)
181 181
     // should already be allocated
182 182
     unsigned nr = 0, nc = 0;
183 183
     ar >> nr >> nc;
184
-    GAPS_ASSERT(nr == mat.mNumRows);
185
-    GAPS_ASSERT(nc == mat.mNumCols);
184
+    GAPS_ASSERT_MSG(nr == mat.mNumRows, nr << " != " << mat.mNumRows);
185
+    GAPS_ASSERT_MSG(nc == mat.mNumCols, nc << " != " << mat.mNumCols);
186 186
 
187 187
     // read in data
188 188
     for (unsigned i = 0; i < mat.mNumCols; ++i)
... ...
@@ -26,6 +26,7 @@ namespace gaps
26 26
     const float epsilon = 1.0e-5f;
27 27
     const float pi = 3.1415926535897932384626433832795f;
28 28
     const float pi_double = 3.1415926535897932384626433832795;
29
+    const float sqrt2 = 1.4142135623730950488016887242097f;
29 30
 
30 31
     float min(float a, float b);
31 32
     unsigned min(unsigned a, unsigned b);
... ...
@@ -20,6 +20,41 @@ const double maxU32AsDouble = static_cast<double>(std::numeric_limits<uint32_t>:
20 20
 
21 21
 static Xoroshiro128plus seeder;
22 22
 
23
+////////////////////////////// Lookup Tables ///////////////////////////////////
24
+
25
+static float erf_lookup_table[3001];
26
+static float erfinv_lookup_table[5001];
27
+static float qgamma_lookup_table[5001];
28
+
29
+static void initLookupTables()
30
+{
31
+    // erf
32
+    for (unsigned i = 0; i < 3001; ++i)
33
+    {
34
+        float x = static_cast<float>(i) / 1000.f;
35
+        erf_lookup_table[i] = 2.f * gaps::p_norm(x * gaps::sqrt2, 0.f, 1.f) - 1.f;
36
+    }
37
+
38
+    // erfinv
39
+    for (unsigned i = 0; i < 5000; ++i)
40
+    {
41
+        float x = static_cast<float>(i) / 5000.f;
42
+        erfinv_lookup_table[i] = gaps::q_norm((1.f + x) / 2.f, 0.f, 1.f) / gaps::sqrt2;
43
+    }
44
+    erfinv_lookup_table[5000] = gaps::q_norm(1.9998f / 2.f, 0.f, 1.f) / gaps::sqrt2;
45
+
46
+    // qgamma
47
+    qgamma_lookup_table[0] = 0.f;
48
+    for (unsigned i = 1; i < 5000; ++i)
49
+    {
50
+        float x = static_cast<float>(i) / 5000.f;
51
+        qgamma_lookup_table[i] = gaps::q_gamma(x, 2.f, 1.f);
52
+    }
53
+    qgamma_lookup_table[5000] = gaps::q_gamma(0.9998f, 2.f, 1.f);
54
+
55
+    GAPS_ASSERT(erf_lookup_table[3000] < 1.f);
56
+}
57
+
23 58
 /////////////////////////////// OptionalFloat //////////////////////////////////
24 59
 
25 60
 OptionalFloat::OptionalFloat() : mValue(0.f), mHasValue(false) {}
... ...
@@ -94,6 +129,7 @@ Archive& operator>>(Archive &ar, Xoroshiro128plus &gen)
94 129
 
95 130
 void GapsRng::setSeed(uint32_t sd)
96 131
 {
132
+    initLookupTables();
97 133
     seeder.seed(sd);
98 134
 }
99 135
 
... ...
@@ -105,6 +141,7 @@ Archive& GapsRng::save(Archive &ar)
105 141
 
106 142
 Archive& GapsRng::load(Archive &ar)
107 143
 {
144
+    initLookupTables();
108 145
     ar >> seeder;
109 146
     return ar;
110 147
 }
... ...
@@ -204,14 +241,7 @@ uint64_t GapsRng::uniform64(uint64_t a, uint64_t b)
204 241
 
205 242
 int GapsRng::poisson(double lambda)
206 243
 {
207
-    if (lambda <= 5.0)
208
-    {
209
-        return poissonSmall(lambda);
210
-    }
211
-    else
212
-    {
213
-        return poissonLarge(lambda);
214
-    }
244
+    return lambda <= 5.0 ? poissonSmall(lambda) : poissonLarge(lambda);
215 245
 }
216 246
 
217 247
 // lambda <= 5
... ...
@@ -263,33 +293,29 @@ float GapsRng::exponential(float lambda)
263 293
     return -1.f * std::log(uniform()) / lambda;
264 294
 }
265 295
 
296
+// fails if too far in tail
266 297
 OptionalFloat GapsRng::truncNormal(float a, float b, float mean, float sd)
267 298
 {
268
-    float pLower = gaps::p_norm(a, mean, sd);
269
-    float pUpper = gaps::p_norm(b, mean, sd);
299
+    float pLower = gaps::p_norm_fast(a, mean, sd);
300
+    float pUpper = gaps::p_norm_fast(b, mean, sd);
270 301
     if (!(pLower > 0.95f || pUpper < 0.05f)) // too far in tail
271 302
     {
272
-        float u = uniform(pLower, pUpper);
273
-        while (u == 0.f || u == 1.f)
274
-        {
275
-            u = uniform(pLower, pUpper);
276
-        }
277
-        float m = gaps::q_norm(u, mean, sd); 
278
-        m = gaps::max(a, gaps::min(m, b));
279
-        return m;
303
+        GAPS_ASSERT(pLower > 0.f);
304
+        GAPS_ASSERT(pUpper < 1.f);
305
+
306
+        float z = gaps::q_norm_fast(uniform(pLower, pUpper), mean, sd); 
307
+        z = gaps::max(a, gaps::min(z, b));
308
+        return z;
280 309
     }
281 310
     return OptionalFloat();
282 311
 }
283 312
 
284
-float GapsRng::truncGammaUpper(float b, float shape, float scale)
313
+// shape is hardcoded to 2 since it never changes
314
+float GapsRng::truncGammaUpper(float b, float scale)
285 315
 {
286
-    float upper = gaps::p_gamma(b, shape, scale);
287
-    float u = uniform(0.f, upper);
288
-    while (u == 0.f || u == 1.f)
289
-    {
290
-        u = uniform(0.f, upper);
291
-    }
292
-    return gaps::q_gamma(u, shape, scale);
316
+    float upper = 1.f - std::exp(-b / scale) * (1.f + b / scale);
317
+    unsigned ndx = static_cast<unsigned>(uniform(0.f, upper * 5000.f));
318
+    return qgamma_lookup_table[ndx] * scale;
293 319
 }
294 320
 
295 321
 Archive& operator<<(Archive &ar, GapsRng &gen)
... ...
@@ -346,13 +372,39 @@ float gaps::p_norm(float p, float mean, float sd)
346 372
     return cdf(norm, p);
347 373
 }
348 374
 
349
-double gaps::lgamma(double x)
375
+float gaps::p_norm_fast(float p, float mean, float sd)
350 376
 {
351
-    return boost::math::lgamma(x);
377
+    float term = (p - mean) / (sd * gaps::sqrt2);
378
+    float erf = 0.f;
379
+    if (term < 0.f)
380
+    {
381
+        term = gaps::max(term, -3.f);
382
+        erf = -erf_lookup_table[static_cast<unsigned>(-term * 1000.f)];
383
+    }
384
+    else
385
+    {
386
+        term = gaps::min(term, 3.f);
387
+        erf = erf_lookup_table[static_cast<unsigned>(term * 1000.f)];
388
+    }
389
+    return 0.5f * (1.f + erf);
352 390
 }
353 391
 
354
-float gaps::d_norm_fast(float d, float mean, float sd)
392
+float gaps::q_norm_fast(float q, float mean, float sd)
355 393
 {
356
-    return std::exp((d - mean) * (d - mean) / (-2.f * sd * sd))
357
-        / std::sqrt(2.f * gaps::pi * sd * sd);
394
+    float term = 2.f * q - 1.f;
395
+    float erfinv = 0.f;
396
+    if (term < 0.f)
397
+    {
398
+        erfinv = -erfinv_lookup_table[static_cast<unsigned>(-term * 5000.f)];
399
+    }
400
+    else
401
+    {
402
+        erfinv = erfinv_lookup_table[static_cast<unsigned>(term * 5000.f)];
403
+    }
404
+    return mean + sd * gaps::sqrt2 * erfinv;
405
+}
406
+
407
+double gaps::lgamma(double x)
408
+{
409
+    return boost::math::lgamma(x);
358 410
 }
359 411
\ No newline at end of file
... ...
@@ -27,19 +27,20 @@ namespace gaps
27 27
 {
28 28
     double lgamma(double x);
29 29
 
30
+    // fast enough with default implementation
30 31
     float d_gamma(float d, float shape, float scale);
31 32
     float p_gamma(float p, float shape, float scale);
33
+
34
+    // standard functions
32 35
     float q_gamma(float q, float shape, float scale);
33 36
     float d_norm(float d, float mean, float sd);
34
-    float q_norm(float q, float mean, float sd);
35 37
     float p_norm(float p, float mean, float sd);
38
+    float q_norm(float q, float mean, float sd);
36 39
 
37
-    float d_gamma_fast(float d, float shape, float scale);
38
-    float p_gamma_fast(float p, float shape, float scale);
39
-    float q_gamma_fast(float q, float shape, float scale);
40
-    float d_norm_fast(float d, float mean, float sd);
40
+    // fast versions, mostly using lookup tables
41 41
     float p_norm_fast(float p, float mean, float sd);
42 42
     float q_norm_fast(float q, float mean, float sd);
43
+
43 44
 }
44 45
 
45 46
 // used for seeding individual rngs
... ...
@@ -83,7 +84,7 @@ public:
83 84
     float exponential(float lambda);
84 85
 
85 86
     OptionalFloat truncNormal(float a, float b, float mean, float sd);
86
-    float truncGammaUpper(float b, float shape, float scale);
87
+    float truncGammaUpper(float b, float scale); // shape hardcoded to 2
87 88
 
88 89
     static void setSeed(uint32_t sd);
89 90
     static Archive& save(Archive &ar);