... | ... |
@@ -76,31 +76,34 @@ void IndexFlagSet::release(uint64_t n) |
76 | 76 |
|
77 | 77 |
AtomPool::AtomPool() |
78 | 78 |
{ |
79 |
- mPool = new Atom[POOL_SIZE]; |
|
79 |
+ //mPool = new Atom[POOL_SIZE]; |
|
80 | 80 |
} |
81 | 81 |
|
82 | 82 |
AtomPool::~AtomPool() |
83 | 83 |
{ |
84 |
- delete[] mPool; |
|
84 |
+ //delete[] mPool; |
|
85 | 85 |
} |
86 | 86 |
|
87 | 87 |
Atom* AtomPool::alloc() |
88 | 88 |
{ |
89 |
- unsigned n = mIndexFlags.getFirstFree(); |
|
90 |
- mIndexFlags.set(n); |
|
91 |
- Atom *a = &(mPool[n]); |
|
92 |
- a->poolIndex = n; |
|
93 |
- return a; |
|
89 |
+ //unsigned n = mIndexFlags.getFirstFree(); |
|
90 |
+ //mIndexFlags.set(n); |
|
91 |
+ //Atom *a = &(mPool[n]); |
|
92 |
+ //a->poolIndex = n; |
|
93 |
+ //return a; |
|
94 |
+ return new Atom(); |
|
94 | 95 |
} |
95 | 96 |
|
96 | 97 |
void AtomPool::free(Atom* a) |
97 | 98 |
{ |
98 |
- mIndexFlags.release(a->poolIndex); |
|
99 |
+ //mIndexFlags.release(a->poolIndex); |
|
100 |
+ delete a; |
|
99 | 101 |
} |
100 | 102 |
|
101 | 103 |
bool AtomPool::depleted() const |
102 | 104 |
{ |
103 |
- return !mIndexFlags.isAnyFree(); |
|
105 |
+ //return !mIndexFlags.isAnyFree(); |
|
106 |
+ return false; |
|
104 | 107 |
} |
105 | 108 |
|
106 | 109 |
////////////////////////////// AtomAllocator /////////////////////////////////// |
... | ... |
@@ -242,6 +242,7 @@ void GapsRunner::createCheckpoint() |
242 | 242 |
|
243 | 243 |
// create checkpoint file |
244 | 244 |
Archive ar(mCheckpointOutFile, ARCHIVE_WRITE); |
245 |
+ GapsRng::save(ar); |
|
245 | 246 |
ar << mNumPatterns << mSeed << mASampler << mPSampler << mStatistics |
246 | 247 |
<< mFixedMatrix << mMaxIterations << mPhase << mCurrentIteration |
247 | 248 |
<< mNumUpdatesA << mNumUpdatesP << mRng; |
... | ... |
@@ -84,7 +84,7 @@ template <class DataType> |
84 | 84 |
GapsRunner::GapsRunner(const DataType &data, bool transposeData, |
85 | 85 |
unsigned nPatterns, bool partitionRows, const std::vector<unsigned> &indices) |
86 | 86 |
: |
87 |
-mASampler(data, !transposeData, nPatterns,!partitionRows, indices), |
|
87 |
+mASampler(data, !transposeData, nPatterns, !partitionRows, indices), |
|
88 | 88 |
mPSampler(data, transposeData, nPatterns, partitionRows, indices), |
89 | 89 |
mStatistics(mPSampler.dataRows(), mPSampler.dataCols(), nPatterns), |
90 | 90 |
mFixedMatrix('N'), mMaxIterations(1000), mMaxThreads(1), mPrintMessages(true), |
... | ... |
@@ -120,24 +120,18 @@ void GibbsSampler::processProposal(const AtomicProposal &prop) |
120 | 120 |
// exponential distribution or with the gibbs mass distribution |
121 | 121 |
void GibbsSampler::birth(const AtomicProposal &prop) |
122 | 122 |
{ |
123 |
- // calculate proposed mass |
|
124 |
- float mass = 0.f; |
|
125 |
- if (canUseGibbs(prop.c1)) |
|
126 |
- { |
|
127 |
- mass = gibbsMass(alphaParameters(prop.r1, prop.c1) * mAnnealingTemp, |
|
128 |
- 0.f, mMaxGibbsMass, mLambda, &(prop.rng)).value(); |
|
129 |
- } |
|
130 |
- else |
|
131 |
- { |
|
132 |
- mass = prop.rng.exponential(mLambda); |
|
133 |
- } |
|
134 |
- |
|
135 |
- // accept mass as long as it's non-zero |
|
136 |
- if (mass >= gaps::epsilon) |
|
123 |
+ // try to get mass using gibbs, resort to exponential if needed |
|
124 |
+ OptionalFloat mass = canUseGibbs(prop.c1) |
|
125 |
+ ? gibbsMass(alphaParameters(prop.r1, prop.c1) * mAnnealingTemp, 0.f, |
|
126 |
+ mMaxGibbsMass, mLambda, &(prop.rng)) |
|
127 |
+ : prop.rng.exponential(mLambda); |
|
128 |
+ |
|
129 |
+ // accept mass only if one was produced and it is at least gaps::epsilon |
|
130 |
+ if (mass.hasValue() && mass.value() >= gaps::epsilon) |
|
137 | 131 |
{ |
138 | 132 |
mQueue.acceptBirth(); |
139 |
- prop.atom1->mass = mass; |
|
140 |
- changeMatrix(prop.r1, prop.c1, mass); |
|
133 |
+ prop.atom1->mass = mass.value(); |
|
134 |
+ changeMatrix(prop.r1, prop.c1, mass.value()); |
|
141 | 135 |
} |
142 | 136 |
else |
143 | 137 |
{ |
... | ... |
@@ -154,7 +148,6 @@ void GibbsSampler::death(const AtomicProposal &prop) |
154 | 148 |
AlphaParameters alpha = alphaParametersWithChange(prop.r1, prop.c1, |
155 | 149 |
-prop.atom1->mass); |
156 | 150 |
|
157 |
- // try to calculate rebirth mass using gibbs distribution, otherwise use original |
|
158 | 151 |
float rebirthMass = prop.atom1->mass; |
159 | 152 |
if (canUseGibbs(prop.c1)) |
160 | 153 |
{ |
... | ... |
@@ -223,7 +216,7 @@ AlphaParameters alpha) |
223 | 216 |
{ |
224 | 217 |
// compute amount of mass to be exchanged |
225 | 218 |
float totalMass = prop.atom1->mass + prop.atom2->mass; |
226 |
- float newMass = prop.rng.truncGammaUpper(totalMass, 2.f, 1.f / mLambda); |
|
219 |
+ float newMass = prop.rng.truncGammaUpper(totalMass, 1.f / mLambda); |
|
227 | 220 |
|
228 | 221 |
// compute amount to change atom1 by - always change larger mass to newMass |
229 | 222 |
float delta = (prop.atom1->mass > prop.atom2->mass) |
... | ... |
@@ -266,8 +259,9 @@ void GibbsSampler::acceptExchange(const AtomicProposal &prop, float delta) |
266 | 259 |
void GibbsSampler::changeMatrix(unsigned row, unsigned col, float delta) |
267 | 260 |
{ |
268 | 261 |
mMatrix(row, col) += delta; |
269 |
- GAPS_ASSERT(mMatrix(row, col) >= 0.f); |
|
270 | 262 |
updateAPMatrix(row, col, delta); |
263 |
+ |
|
264 |
+ GAPS_ASSERT(mMatrix(row, col) >= 0.f); |
|
271 | 265 |
} |
272 | 266 |
|
273 | 267 |
// delta could be negative, this is needed to prevent negative values in matrix |
... | ... |
@@ -340,13 +334,17 @@ unsigned col, float ch) |
340 | 334 |
|
341 | 335 |
Archive& operator<<(Archive &ar, GibbsSampler &s) |
342 | 336 |
{ |
343 |
- // TODO |
|
337 |
+ ar << s.mMatrix << s.mDomain << s.mQueue << s.mAlpha |
|
338 |
+ << s.mLambda << s.mMaxGibbsMass << s.mAnnealingTemp << s.mNumPatterns |
|
339 |
+ << s.mNumBins << s.mBinLength << s.mDomainLength; |
|
344 | 340 |
return ar; |
345 | 341 |
} |
346 | 342 |
|
347 | 343 |
Archive& operator>>(Archive &ar, GibbsSampler &s) |
348 | 344 |
{ |
349 |
- // TODO |
|
345 |
+ ar >> s.mMatrix >> s.mDomain >> s.mQueue >> s.mAlpha |
|
346 |
+ >> s.mLambda >> s.mMaxGibbsMass >> s.mAnnealingTemp >> s.mNumPatterns |
|
347 |
+ >> s.mNumBins >> s.mBinLength >> s.mDomainLength; |
|
350 | 348 |
return ar; |
351 | 349 |
} |
352 | 350 |
|
... | ... |
@@ -1,6 +1,9 @@ |
1 | 1 |
#ifndef __COGAPS_GIBBS_SAMPLER_H__ |
2 | 2 |
#define __COGAPS_GIBBS_SAMPLER_H__ |
3 | 3 |
|
4 |
+#define DEFAULT_ALPHA 0.01f |
|
5 |
+#define DEFAULT_MAX_GIBBS_MASS 100.f |
|
6 |
+ |
|
4 | 7 |
#include "AtomicDomain.h" |
5 | 8 |
#include "ProposalQueue.h" |
6 | 9 |
#include "data_structures/Matrix.h" |
... | ... |
@@ -112,7 +115,7 @@ mBinLength(std::numeric_limits<uint64_t>::max() / mNumBins), |
112 | 115 |
mDomainLength(mBinLength * mNumBins) |
113 | 116 |
{ |
114 | 117 |
// default sparsity parameters |
115 |
- setSparsity(0.01, 100.f, false); |
|
118 |
+ setSparsity(DEFAULT_ALPHA, DEFAULT_MAX_GIBBS_MASS, false); |
|
116 | 119 |
} |
117 | 120 |
|
118 | 121 |
template <class DataType> |
... | ... |
@@ -251,14 +251,14 @@ bool ProposalQueue::exchange(AtomicDomain &domain) |
251 | 251 |
|
252 | 252 |
if (prop.r1 == prop.r2 && prop.c1 == prop.c2) |
253 | 253 |
{ |
254 |
- float newMass = prop.rng.truncGammaUpper(prop.atom1->mass + prop.atom2->mass, 2.f, 1.f / mLambda); |
|
254 |
+ float newMass = prop.rng.truncGammaUpper(prop.atom1->mass + prop.atom2->mass, 1.f / mLambda); |
|
255 | 255 |
float delta = (prop.atom1->mass > prop.atom2->mass) ? newMass - prop.atom1->mass : prop.atom2->mass - newMass; |
256 | 256 |
if (prop.atom1->mass + delta > gaps::epsilon && prop.atom2->mass - delta > gaps::epsilon) |
257 | 257 |
{ |
258 | 258 |
prop.atom1->mass += delta; |
259 | 259 |
prop.atom2->mass -= delta; |
260 | 260 |
} |
261 |
- return true; |
|
261 |
+ return true; // automatically accept exchanges in same bin |
|
262 | 262 |
} |
263 | 263 |
|
264 | 264 |
mQueue.push_back(prop); |
... | ... |
@@ -4,6 +4,9 @@ |
4 | 4 |
|
5 | 5 |
TEST_CASE("AtomicDomain") |
6 | 6 |
{ |
7 |
+ GapsRng::setSeed(123); |
|
8 |
+ GapsRng rng; |
|
9 |
+ |
|
7 | 10 |
SECTION("Construction") |
8 | 11 |
{ |
9 | 12 |
AtomicDomain domain(10); |
... | ... |
@@ -47,7 +50,7 @@ TEST_CASE("AtomicDomain") |
47 | 50 |
for (unsigned i = 0; i < 1000; ++i) |
48 | 51 |
{ |
49 | 52 |
domain.insert(i, static_cast<float>(i)); |
50 |
- REQUIRE_NOTHROW(domain.randomFreePosition()); |
|
53 |
+ REQUIRE_NOTHROW(domain.randomFreePosition(&rng)); |
|
51 | 54 |
} |
52 | 55 |
} |
53 | 56 |
|
... | ... |
@@ -60,10 +63,10 @@ TEST_CASE("AtomicDomain") |
60 | 63 |
domain.insert(i, static_cast<float>(i)); |
61 | 64 |
|
62 | 65 |
// single random atom |
63 |
- REQUIRE(domain.randomAtom()->pos < i + 1); |
|
66 |
+ REQUIRE(domain.randomAtom(&rng)->pos < i + 1); |
|
64 | 67 |
|
65 | 68 |
// random atom for exchange |
66 |
- AtomNeighborhood hood = domain.randomAtomWithRightNeighbor(); |
|
69 |
+ AtomNeighborhood hood = domain.randomAtomWithRightNeighbor(&rng); |
|
67 | 70 |
REQUIRE(hood.center->pos < i + 1); |
68 | 71 |
|
69 | 72 |
REQUIRE(!hood.hasLeft()); |
... | ... |
@@ -78,7 +81,7 @@ TEST_CASE("AtomicDomain") |
78 | 81 |
} |
79 | 82 |
|
80 | 83 |
// random atom for move |
81 |
- hood = domain.randomAtomWithNeighbors(); |
|
84 |
+ hood = domain.randomAtomWithNeighbors(&rng); |
|
82 | 85 |
REQUIRE(hood.center->pos < i + 1); |
83 | 86 |
|
84 | 87 |
if (hood.center->pos == 0) |
... | ... |
@@ -4,10 +4,11 @@ |
4 | 4 |
#include "../GibbsSampler.h" |
5 | 5 |
#include "../math/Random.h" |
6 | 6 |
|
7 |
-#if 0 |
|
8 |
- |
|
9 | 7 |
TEST_CASE("Test utils/Archive.h") |
10 | 8 |
{ |
9 |
+ GapsRng::setSeed(123); |
|
10 |
+ GapsRng rng; |
|
11 |
+ |
|
11 | 12 |
SECTION("Reading/Writing to an Archive") |
12 | 13 |
{ |
13 | 14 |
Archive ar1("test_ar.temp", ARCHIVE_WRITE); |
... | ... |
@@ -59,13 +60,13 @@ TEST_CASE("Test utils/Archive.h") |
59 | 60 |
REQUIRE(d_read == d_write); |
60 | 61 |
REQUIRE(b_read == b_write); |
61 | 62 |
} |
62 |
- |
|
63 |
+ |
|
63 | 64 |
SECTION("Vector Serialization") |
64 | 65 |
{ |
65 | 66 |
Vector vec_read(100), vec_write(100); |
66 | 67 |
for (unsigned i = 0; i < 100; ++i) |
67 | 68 |
{ |
68 |
- vec_write[i] = gaps::random::uniform(0.0, 2.0); |
|
69 |
+ vec_write[i] = rng.uniform(0.f, 2.f); |
|
69 | 70 |
} |
70 | 71 |
|
71 | 72 |
Archive arWrite("test_ar.temp", ARCHIVE_WRITE); |
... | ... |
@@ -86,30 +87,24 @@ TEST_CASE("Test utils/Archive.h") |
86 | 87 |
|
87 | 88 |
SECTION("Matrix Serialization") |
88 | 89 |
{ |
89 |
- RowMatrix rMat_read(100,100), rMat_write(100,100); |
|
90 | 90 |
ColMatrix cMat_read(100,100), cMat_write(100,100); |
91 | 91 |
|
92 | 92 |
for (unsigned i = 0; i < 100; ++i) |
93 | 93 |
{ |
94 | 94 |
for (unsigned j = 0; j < 100; ++j) |
95 | 95 |
{ |
96 |
- rMat_write(i,j) = gaps::random::uniform(0.0, 2.0); |
|
97 |
- cMat_write(i,j) = gaps::random::uniform(0.0, 2.0); |
|
96 |
+ cMat_write(i,j) = rng.uniform(0.f, 2.f); |
|
98 | 97 |
} |
99 | 98 |
} |
100 | 99 |
|
101 | 100 |
Archive arWrite("test_ar.temp", ARCHIVE_WRITE); |
102 |
- arWrite << rMat_write; |
|
103 | 101 |
arWrite << cMat_write; |
104 | 102 |
arWrite.close(); |
105 | 103 |
|
106 | 104 |
Archive arRead("test_ar.temp", ARCHIVE_READ); |
107 |
- arRead >> rMat_read; |
|
108 | 105 |
arRead >> cMat_read; |
109 | 106 |
arRead.close(); |
110 | 107 |
|
111 |
- REQUIRE(rMat_read.nRow() == rMat_write.nRow()); |
|
112 |
- REQUIRE(rMat_read.nCol() == rMat_write.nCol()); |
|
113 | 108 |
REQUIRE(cMat_read.nRow() == cMat_write.nRow()); |
114 | 109 |
REQUIRE(cMat_read.nCol() == cMat_write.nCol()); |
115 | 110 |
|
... | ... |
@@ -117,7 +112,6 @@ TEST_CASE("Test utils/Archive.h") |
117 | 112 |
{ |
118 | 113 |
for (unsigned j = 0; j < 100; ++j) |
119 | 114 |
{ |
120 |
- REQUIRE(rMat_read(i,j) == rMat_write(i,j)); |
|
121 | 115 |
REQUIRE(cMat_read(i,j) == cMat_write(i,j)); |
122 | 116 |
} |
123 | 117 |
} |
... | ... |
@@ -125,9 +119,34 @@ TEST_CASE("Test utils/Archive.h") |
125 | 119 |
|
126 | 120 |
SECTION("GibbsSampler Serialization") |
127 | 121 |
{ |
128 |
- //TODO |
|
122 |
+ Rcpp::Environment env = Rcpp::Environment::global_env(); |
|
123 |
+ std::string csvPath = Rcpp::as<std::string>(env["gistCsvPath"]); |
|
124 |
+ |
|
125 |
+ GibbsSampler Asampler(csvPath, false, 7, false, std::vector<unsigned>()); |
|
126 |
+ GibbsSampler Psampler(csvPath, true, 7, false, std::vector<unsigned>()); |
|
127 |
+ Asampler.sync(Psampler); |
|
128 |
+ Psampler.sync(Asampler); |
|
129 |
+ |
|
130 |
+ Asampler.update(10000, 1); |
|
131 |
+ |
|
132 |
+ Archive arWrite("test_ar.temp", ARCHIVE_WRITE); |
|
133 |
+ arWrite << Asampler; |
|
134 |
+ arWrite.close(); |
|
135 |
+ |
|
136 |
+ GibbsSampler savedAsampler(csvPath, false, 7, false, std::vector<unsigned>()); |
|
137 |
+ Archive arRead("test_ar.temp", ARCHIVE_READ); |
|
138 |
+ arRead >> savedAsampler; |
|
139 |
+ arRead.close(); |
|
140 |
+ } |
|
141 |
+ |
|
142 |
+#ifdef GAPS_INTERNAL_TESTS |
|
143 |
+ SECTION("AtomicDomain Serialization") |
|
144 |
+ { |
|
145 |
+ |
|
129 | 146 |
} |
147 |
+#endif |
|
130 | 148 |
|
149 |
+/* |
|
131 | 150 |
SECTION("Random Generator Serialization") |
132 | 151 |
{ |
133 | 152 |
std::vector<float> randSequence; |
... | ... |
@@ -163,9 +182,7 @@ TEST_CASE("Test utils/Archive.h") |
163 | 182 |
REQUIRE(gaps::random::exponential(5.5) == randSequence[i]); |
164 | 183 |
} |
165 | 184 |
} |
166 |
- |
|
185 |
+*/ |
|
167 | 186 |
// cleanup directory |
168 | 187 |
std::remove("test_ar.temp"); |
169 |
-} |
|
170 |
- |
|
171 |
-#endif |
|
172 | 188 |
\ No newline at end of file |
189 |
+} |
|
173 | 190 |
\ No newline at end of file |
... | ... |
@@ -181,8 +181,8 @@ Archive& operator>>(Archive &ar, ColMatrix &mat) |
181 | 181 |
// should already be allocated |
182 | 182 |
unsigned nr = 0, nc = 0; |
183 | 183 |
ar >> nr >> nc; |
184 |
- GAPS_ASSERT(nr == mat.mNumRows); |
|
185 |
- GAPS_ASSERT(nc == mat.mNumCols); |
|
184 |
+ GAPS_ASSERT_MSG(nr == mat.mNumRows, nr << " != " << mat.mNumRows); |
|
185 |
+ GAPS_ASSERT_MSG(nc == mat.mNumCols, nc << " != " << mat.mNumCols); |
|
186 | 186 |
|
187 | 187 |
// read in data |
188 | 188 |
for (unsigned i = 0; i < mat.mNumCols; ++i) |
... | ... |
@@ -26,6 +26,7 @@ namespace gaps |
26 | 26 |
const float epsilon = 1.0e-5f; |
27 | 27 |
const float pi = 3.1415926535897932384626433832795f; |
28 | 28 |
const float pi_double = 3.1415926535897932384626433832795; |
29 |
+ const float sqrt2 = 1.4142135623730950488016887242097f; |
|
29 | 30 |
|
30 | 31 |
float min(float a, float b); |
31 | 32 |
unsigned min(unsigned a, unsigned b); |
... | ... |
@@ -20,6 +20,41 @@ const double maxU32AsDouble = static_cast<double>(std::numeric_limits<uint32_t>: |
20 | 20 |
|
21 | 21 |
static Xoroshiro128plus seeder; |
22 | 22 |
|
23 |
+////////////////////////////// Lookup Tables /////////////////////////////////// |
|
24 |
+ |
|
25 |
+static float erf_lookup_table[3001]; |
|
26 |
+static float erfinv_lookup_table[5001]; |
|
27 |
+static float qgamma_lookup_table[5001]; |
|
28 |
+ |
|
29 |
+static void initLookupTables() |
|
30 |
+{ |
|
31 |
+ // erf |
|
32 |
+ for (unsigned i = 0; i < 3001; ++i) |
|
33 |
+ { |
|
34 |
+ float x = static_cast<float>(i) / 1000.f; |
|
35 |
+ erf_lookup_table[i] = 2.f * gaps::p_norm(x * gaps::sqrt2, 0.f, 1.f) - 1.f; |
|
36 |
+ } |
|
37 |
+ |
|
38 |
+ // erfinv |
|
39 |
+ for (unsigned i = 0; i < 5000; ++i) |
|
40 |
+ { |
|
41 |
+ float x = static_cast<float>(i) / 5000.f; |
|
42 |
+ erfinv_lookup_table[i] = gaps::q_norm((1.f + x) / 2.f, 0.f, 1.f) / gaps::sqrt2; |
|
43 |
+ } |
|
44 |
+ erfinv_lookup_table[5000] = gaps::q_norm(1.9998f / 2.f, 0.f, 1.f) / gaps::sqrt2; |
|
45 |
+ |
|
46 |
+ // qgamma |
|
47 |
+ qgamma_lookup_table[0] = 0.f; |
|
48 |
+ for (unsigned i = 1; i < 5000; ++i) |
|
49 |
+ { |
|
50 |
+ float x = static_cast<float>(i) / 5000.f; |
|
51 |
+ qgamma_lookup_table[i] = gaps::q_gamma(x, 2.f, 1.f); |
|
52 |
+ } |
|
53 |
+ qgamma_lookup_table[5000] = gaps::q_gamma(0.9998f, 2.f, 1.f); |
|
54 |
+ |
|
55 |
+ GAPS_ASSERT(erf_lookup_table[3000] < 1.f); |
|
56 |
+} |
|
57 |
+ |
|
23 | 58 |
/////////////////////////////// OptionalFloat ////////////////////////////////// |
24 | 59 |
|
25 | 60 |
OptionalFloat::OptionalFloat() : mValue(0.f), mHasValue(false) {} |
... | ... |
@@ -94,6 +129,7 @@ Archive& operator>>(Archive &ar, Xoroshiro128plus &gen) |
94 | 129 |
|
95 | 130 |
void GapsRng::setSeed(uint32_t sd) |
96 | 131 |
{ |
132 |
+ initLookupTables(); |
|
97 | 133 |
seeder.seed(sd); |
98 | 134 |
} |
99 | 135 |
|
... | ... |
@@ -105,6 +141,7 @@ Archive& GapsRng::save(Archive &ar) |
105 | 141 |
|
106 | 142 |
Archive& GapsRng::load(Archive &ar) |
107 | 143 |
{ |
144 |
+ initLookupTables(); |
|
108 | 145 |
ar >> seeder; |
109 | 146 |
return ar; |
110 | 147 |
} |
... | ... |
@@ -204,14 +241,7 @@ uint64_t GapsRng::uniform64(uint64_t a, uint64_t b) |
204 | 241 |
|
205 | 242 |
int GapsRng::poisson(double lambda) |
206 | 243 |
{ |
207 |
- if (lambda <= 5.0) |
|
208 |
- { |
|
209 |
- return poissonSmall(lambda); |
|
210 |
- } |
|
211 |
- else |
|
212 |
- { |
|
213 |
- return poissonLarge(lambda); |
|
214 |
- } |
|
244 |
+ return lambda <= 5.0 ? poissonSmall(lambda) : poissonLarge(lambda); |
|
215 | 245 |
} |
216 | 246 |
|
217 | 247 |
// lambda <= 5 |
... | ... |
@@ -263,33 +293,29 @@ float GapsRng::exponential(float lambda) |
263 | 293 |
return -1.f * std::log(uniform()) / lambda; |
264 | 294 |
} |
265 | 295 |
|
296 |
+// fails if too far in tail |
|
266 | 297 |
OptionalFloat GapsRng::truncNormal(float a, float b, float mean, float sd) |
267 | 298 |
{ |
268 |
- float pLower = gaps::p_norm(a, mean, sd); |
|
269 |
- float pUpper = gaps::p_norm(b, mean, sd); |
|
299 |
+ float pLower = gaps::p_norm_fast(a, mean, sd); |
|
300 |
+ float pUpper = gaps::p_norm_fast(b, mean, sd); |
|
270 | 301 |
if (!(pLower > 0.95f || pUpper < 0.05f)) // too far in tail |
271 | 302 |
{ |
272 |
- float u = uniform(pLower, pUpper); |
|
273 |
- while (u == 0.f || u == 1.f) |
|
274 |
- { |
|
275 |
- u = uniform(pLower, pUpper); |
|
276 |
- } |
|
277 |
- float m = gaps::q_norm(u, mean, sd); |
|
278 |
- m = gaps::max(a, gaps::min(m, b)); |
|
279 |
- return m; |
|
303 |
+ GAPS_ASSERT(pLower > 0.f); |
|
304 |
+ GAPS_ASSERT(pUpper < 1.f); |
|
305 |
+ |
|
306 |
+ float z = gaps::q_norm_fast(uniform(pLower, pUpper), mean, sd); |
|
307 |
+ z = gaps::max(a, gaps::min(z, b)); |
|
308 |
+ return z; |
|
280 | 309 |
} |
281 | 310 |
return OptionalFloat(); |
282 | 311 |
} |
283 | 312 |
|
284 |
-float GapsRng::truncGammaUpper(float b, float shape, float scale) |
|
313 |
+// shape is hardcoded to 2 since it never changes |
|
314 |
+float GapsRng::truncGammaUpper(float b, float scale) |
|
285 | 315 |
{ |
286 |
- float upper = gaps::p_gamma(b, shape, scale); |
|
287 |
- float u = uniform(0.f, upper); |
|
288 |
- while (u == 0.f || u == 1.f) |
|
289 |
- { |
|
290 |
- u = uniform(0.f, upper); |
|
291 |
- } |
|
292 |
- return gaps::q_gamma(u, shape, scale); |
|
316 |
+ float upper = 1.f - std::exp(-b / scale) * (1.f + b / scale); |
|
317 |
+ unsigned ndx = static_cast<unsigned>(uniform(0.f, upper * 5000.f)); |
|
318 |
+ return qgamma_lookup_table[ndx] * scale; |
|
293 | 319 |
} |
294 | 320 |
|
295 | 321 |
Archive& operator<<(Archive &ar, GapsRng &gen) |
... | ... |
@@ -346,13 +372,39 @@ float gaps::p_norm(float p, float mean, float sd) |
346 | 372 |
return cdf(norm, p); |
347 | 373 |
} |
348 | 374 |
|
349 |
-double gaps::lgamma(double x) |
|
375 |
+float gaps::p_norm_fast(float p, float mean, float sd) |
|
350 | 376 |
{ |
351 |
- return boost::math::lgamma(x); |
|
377 |
+ float term = (p - mean) / (sd * gaps::sqrt2); |
|
378 |
+ float erf = 0.f; |
|
379 |
+ if (term < 0.f) |
|
380 |
+ { |
|
381 |
+ term = gaps::max(term, -3.f); |
|
382 |
+ erf = -erf_lookup_table[static_cast<unsigned>(-term * 1000.f)]; |
|
383 |
+ } |
|
384 |
+ else |
|
385 |
+ { |
|
386 |
+ term = gaps::min(term, 3.f); |
|
387 |
+ erf = erf_lookup_table[static_cast<unsigned>(term * 1000.f)]; |
|
388 |
+ } |
|
389 |
+ return 0.5f * (1.f + erf); |
|
352 | 390 |
} |
353 | 391 |
|
354 |
-float gaps::d_norm_fast(float d, float mean, float sd) |
|
392 |
+float gaps::q_norm_fast(float q, float mean, float sd) |
|
355 | 393 |
{ |
356 |
- return std::exp((d - mean) * (d - mean) / (-2.f * sd * sd)) |
|
357 |
- / std::sqrt(2.f * gaps::pi * sd * sd); |
|
394 |
+ float term = 2.f * q - 1.f; |
|
395 |
+ float erfinv = 0.f; |
|
396 |
+ if (term < 0.f) |
|
397 |
+ { |
|
398 |
+ erfinv = -erfinv_lookup_table[static_cast<unsigned>(-term * 5000.f)]; |
|
399 |
+ } |
|
400 |
+ else |
|
401 |
+ { |
|
402 |
+ erfinv = erfinv_lookup_table[static_cast<unsigned>(term * 5000.f)]; |
|
403 |
+ } |
|
404 |
+ return mean + sd * gaps::sqrt2 * erfinv; |
|
405 |
+} |
|
406 |
+ |
|
407 |
+double gaps::lgamma(double x) |
|
408 |
+{ |
|
409 |
+ return boost::math::lgamma(x); |
|
358 | 410 |
} |
359 | 411 |
\ No newline at end of file |
... | ... |
@@ -27,19 +27,20 @@ namespace gaps |
27 | 27 |
{ |
28 | 28 |
double lgamma(double x); |
29 | 29 |
|
30 |
+ // fast enough with default implementation |
|
30 | 31 |
float d_gamma(float d, float shape, float scale); |
31 | 32 |
float p_gamma(float p, float shape, float scale); |
33 |
+ |
|
34 |
+ // standard functions |
|
32 | 35 |
float q_gamma(float q, float shape, float scale); |
33 | 36 |
float d_norm(float d, float mean, float sd); |
34 |
- float q_norm(float q, float mean, float sd); |
|
35 | 37 |
float p_norm(float p, float mean, float sd); |
38 |
+ float q_norm(float q, float mean, float sd); |
|
36 | 39 |
|
37 |
- float d_gamma_fast(float d, float shape, float scale); |
|
38 |
- float p_gamma_fast(float p, float shape, float scale); |
|
39 |
- float q_gamma_fast(float q, float shape, float scale); |
|
40 |
- float d_norm_fast(float d, float mean, float sd); |
|
40 |
+ // fast versions, mostly using lookup tables |
|
41 | 41 |
float p_norm_fast(float p, float mean, float sd); |
42 | 42 |
float q_norm_fast(float q, float mean, float sd); |
43 |
+ |
|
43 | 44 |
} |
44 | 45 |
|
45 | 46 |
// used for seeding individual rngs |
... | ... |
@@ -83,7 +84,7 @@ public: |
83 | 84 |
float exponential(float lambda); |
84 | 85 |
|
85 | 86 |
OptionalFloat truncNormal(float a, float b, float mean, float sd); |
86 |
- float truncGammaUpper(float b, float shape, float scale); |
|
87 |
+ float truncGammaUpper(float b, float scale); // shape hardcoded to 2 |
|
87 | 88 |
|
88 | 89 |
static void setSeed(uint32_t sd); |
89 | 90 |
static Archive& save(Archive &ar); |