Browse code

separate likelihood calculation for death

Tom Sherman authored on 29/08/2018 18:10:29
Showing8 changed files

... ...
@@ -39,6 +39,7 @@ void GibbsSampler::setMatrix(const Matrix &mat)
39 39
 {   
40 40
     GAPS_ASSERT(mat.nRow() == mMatrix.nRow());
41 41
     GAPS_ASSERT(mat.nCol() == mMatrix.nCol());
42
+
42 43
     mMatrix = mat;
43 44
 }
44 45
 
... ...
@@ -60,34 +61,25 @@ void GibbsSampler::sync(const GibbsSampler &sampler)
60 61
 
61 62
 void GibbsSampler::update(unsigned nSteps, unsigned nCores)
62 63
 {
63
-    unsigned n = 0;
64
-    while (n < nSteps)
64
+    for (unsigned n = 0; n < nSteps; ++n)
65 65
     {
66 66
         makeAndProcessProposal();
67
-        ++n;
68 67
     }
69 68
 }
70 69
 
71 70
 void GibbsSampler::makeAndProcessProposal()
72 71
 {
73
-    // always birth when no atoms exist
74
-    if (mDomain.size() == 0)
72
+    if (mDomain.size() < 2)
75 73
     {
76 74
         return birth();
77 75
     }
78 76
 
79
-    float bdProb = mDomain.size() < 2 ? 0.6667f : 0.5f;
80
-
81 77
     float u1 = mPropRng.uniform();
82
-    float u2 = mPropRng.uniform();
83
-
84
-    float lowerBound = deathProb(mDomain.size());
85
-
86
-    if (u1 <= bdProb)
78
+    if (u1 < 0.5f)
87 79
     {
88
-        return u2 < lowerBound ? death() : birth();
80
+        return mPropRng.uniform() < deathProb(mDomain.size()) ? death() : birth();
89 81
     }
90
-    return (u1 < 0.75f || mDomain.size() < 2) ? move() : exchange();
82
+    return (u1 < 0.75f) ? move() : exchange();
91 83
 }
92 84
 
93 85
 float GibbsSampler::deathProb(uint64_t nAtoms) const
... ...
@@ -109,7 +101,7 @@ void GibbsSampler::birth()
109 101
 
110 102
     // calculate proposed mass
111 103
     float mass = canUseGibbs(col)
112
-        ? gibbsMass(alphaParameters(row, col), &rng).value
104
+        ? gibbsMass(alphaParameters(row, col), &rng).value()
113 105
         : rng.exponential(mLambda);
114 106
 
115 107
     // accept mass as long as it's non-zero
... ...
@@ -122,7 +114,7 @@ void GibbsSampler::birth()
122 114
 }
123 115
 
124 116
 // automatically accept death, attempt a rebirth at the same position, using
125
-// the original mass or the gibbs mass calculation
117
+// the original mass or the gibbs mass distribution
126 118
 void GibbsSampler::death()
127 119
 {
128 120
     // get random atom
... ...
@@ -131,30 +123,32 @@ void GibbsSampler::death()
131 123
     unsigned col = getCol(atom->pos);
132 124
     GapsRng rng;
133 125
 
134
-    // kill off atom
135
-    safelyChangeMatrix(row, col, -1.f * atom->mass);
136
-
137
-    // calculate rebirth mass
138
-    AlphaParameters alpha = alphaParameters(row, col);
126
+    // calculate alpha parameters assuming atom dies
127
+    AlphaParameters alpha = alphaParametersWithChange(row, col, -atom->mass);
139 128
     float rebirthMass = atom->mass;
129
+    bool useSame = true;
140 130
     if (canUseGibbs(col))
141 131
     {
142 132
         OptionalFloat gMass = gibbsMass(alpha, &rng);
143
-        if (gMass.hasValue)
133
+        if (gMass.hasValue())
144 134
         {
145
-            rebirthMass = gMass.value;
135
+            rebirthMass = gMass.value();
136
+            useSame = false;
146 137
         }
147 138
     }
148 139
 
149 140
     // accept/reject rebirth
150
-    if (getDeltaLL(alpha, rebirthMass) * mAnnealingTemp >= std::log(rng.uniform()))
141
+    if (std::log(rng.uniform()) < getDeltaLL(alpha, rebirthMass) * mAnnealingTemp)
151 142
     {
152
-        atom->mass = rebirthMass;
153
-        mMatrix(row, col) += rebirthMass;
154
-        updateAPMatrix(row, col, rebirthMass);
143
+        if (!useSame)
144
+        {
145
+            safelyChangeMatrix(row, col, rebirthMass - atom->mass);
146
+            atom->mass = rebirthMass;
147
+        }
155 148
     }
156 149
     else
157 150
     {
151
+        safelyChangeMatrix(row, col, -atom->mass);
158 152
         mDomain.erase(atom->pos);
159 153
     }
160 154
 }
... ...
@@ -176,11 +170,10 @@ void GibbsSampler::move()
176 170
     {
177 171
         GapsRng rng;
178 172
         AlphaParameters alpha = alphaParameters(r1, c1, r2, c2);
179
-        float deltaLL = getDeltaLL(alpha, -1.f * hood.center->mass);
180
-        if (deltaLL * mAnnealingTemp > std::log(rng.uniform()))
173
+        if (std::log(rng.uniform()) < getDeltaLL(alpha, -hood.center->mass) * mAnnealingTemp)
181 174
         {
182 175
             hood.center->pos = newLocation;
183
-            safelyChangeMatrix(r1, c1, -1.f * hood.center->mass);
176
+            safelyChangeMatrix(r1, c1, -hood.center->mass);
184 177
             mMatrix(r2, c2) += hood.center->mass;
185 178
             updateAPMatrix(r2, c2, hood.center->mass);
186 179
         }
... ...
@@ -208,28 +201,27 @@ void GibbsSampler::exchange()
208 201
     if (r1 != r2 || c1 != c2)
209 202
     {
210 203
         GapsRng rng;
211
-
212
-        float m1 = a1->mass;
213
-        float m2 = a2->mass;
214
-
215 204
         AlphaParameters alpha = alphaParameters(r1, c1, r2, c2);
216 205
         if (canUseGibbs(c1, c2))
217 206
         {
218
-            OptionalFloat gMass = gibbsMass(alpha, m1, m2, &rng);
219
-            if (gMass.hasValue)
207
+            OptionalFloat gMass = gibbsMass(alpha, a1->mass, a2->mass, &rng);
208
+            if (gMass.hasValue())
220 209
             {
221
-                acceptExchange(a1, a2, gMass.value, r1, c1, r2, c2);
210
+                acceptExchange(a1, a2, gMass.value(), r1, c1, r2, c2);
222 211
                 return;
223 212
             }
224 213
         }
225 214
 
226
-        float newMass = rng.truncGammaUpper(m1 + m2, 2.f, 1.f / mLambda);
215
+        float newMass = rng.truncGammaUpper(a1->mass + a2->mass, 2.f, 1.f / mLambda);
227 216
 
228
-        float delta = m1 > m2 ? newMass - m1 : m2 - newMass; // change larger mass
229
-        float pOldMass = 2.f * newMass > m1 + m2 ? gaps::max(m1, m2) : gaps::min(m1, m2);
217
+        // change larger mass
218
+        float delta = a1->mass > a2->mass ? newMass - a1->mass : a2->mass - newMass;
219
+        float oldMass = (2.f * newMass > a1->mass + a2->mass)
220
+            ? gaps::max(a1->mass, a2->mass)
221
+            : gaps::min(a1->mass, a2->mass);
230 222
 
231 223
         float pNew = gaps::d_gamma(newMass, 2.f, 1.f / mLambda);
232
-        float pOld = gaps::d_gamma(pOldMass, 2.f, 1.f / mLambda);
224
+        float pOld = gaps::d_gamma(oldMass, 2.f, 1.f / mLambda);
233 225
 
234 226
         float deltaLL = getDeltaLL(alpha, delta);
235 227
         float priorLL = (pNew == 0.f) ? 1.f : pOld / pNew;
... ...
@@ -383,6 +375,14 @@ unsigned r2, unsigned c2)
383 375
     return alphaParameters(r1, c1) + alphaParameters(r2, c2);
384 376
 }
385 377
 
378
+AlphaParameters GibbsSampler::alphaParametersWithChange(unsigned row,
379
+unsigned col, float ch)
380
+{
381
+    return gaps::algo::alphaParametersWithChange(mDMatrix.nRow(),
382
+        mDMatrix.colPtr(row), mSMatrix.colPtr(row), mAPMatrix.colPtr(row),
383
+        mOtherMatrix->colPtr(col), ch);
384
+}
385
+
386 386
 Archive& operator<<(Archive &ar, GibbsSampler &s)
387 387
 {
388 388
     // TODO
... ...
@@ -2,7 +2,6 @@
2 2
 #define __COGAPS_GIBBS_SAMPLER_H__
3 3
 
4 4
 #include "AtomicDomain.h"
5
-#include "ProposalQueue.h"
6 5
 #include "data_structures/Matrix.h"
7 6
 #include "math/Algorithms.h"
8 7
 #include "math/Random.h"
... ...
@@ -92,8 +91,8 @@ private:
92 91
     bool canUseGibbs(unsigned c1, unsigned c2) const;
93 92
 
94 93
     AlphaParameters alphaParameters(unsigned row, unsigned col);
95
-    AlphaParameters alphaParameters(unsigned r1, unsigned c1, unsigned r2,
96
-        unsigned c2);
94
+    AlphaParameters alphaParameters(unsigned r1, unsigned c1, unsigned r2, unsigned c2);
95
+    AlphaParameters alphaParametersWithChange(unsigned row, unsigned col, float ch);
97 96
 };
98 97
 
99 98
 template <class DataType>
... ...
@@ -8,7 +8,6 @@ OBJECTS =   AtomicDomain.o \
8 8
             GapsRunner.o \
9 9
             GapsStatistics.o \
10 10
             GibbsSampler.o \
11
-            ProposalQueue.o \
12 11
             RcppExports.o \
13 12
             test-runner.o \
14 13
             data_structures/Matrix.o \
15 14
deleted file mode 100644
... ...
@@ -1,277 +0,0 @@
1
-#include "utils/GapsAssert.h"
2
-#include "ProposalQueue.h"
3
-#include "math/Random.h"
4
-
5
-//////////////////////////////// AtomicProposal ////////////////////////////////
6
-
7
-// birth
8
-AtomicProposal::AtomicProposal(char t, uint64_t pos)
9
-    : type(t), birthPos(pos), moveDest(0), atom1(NULL), atom2(NULL)
10
-{}
11
-    
12
-// death
13
-AtomicProposal::AtomicProposal(char t, Atom *atom)
14
-    : type(t), birthPos(0), moveDest(0), atom1(atom), atom2(NULL)
15
-{}
16
-
17
-// move
18
-AtomicProposal::AtomicProposal(char t, Atom *atom, uint64_t dest)
19
-    : type(t), birthPos(0), moveDest(dest), atom1(atom), atom2(NULL)
20
-{}
21
-
22
-// exchange
23
-AtomicProposal::AtomicProposal(char t, Atom *a1, Atom *a2)
24
-    : type(t), birthPos(0), moveDest(0), atom1(a1), atom2(a2)
25
-{}
26
-
27
-
28
-//////////////////////////////// ProposalQueue /////////////////////////////////
29
-
30
-ProposalQueue::ProposalQueue(unsigned primaryDimSize, unsigned secondaryDimSize)
31
-    :
32
-mMinAtoms(0), mMaxAtoms(0), mNumBins(primaryDimSize * secondaryDimSize),
33
-mBinLength(std::numeric_limits<uint64_t>::max() / mNumBins),
34
-mSecondaryDimLength(mBinLength * secondaryDimSize),
35
-mDomainLength(mBinLength * mNumBins), mSecondaryDimSize(secondaryDimSize),
36
-mAlpha(0.f), mU1(0.f), mU2(0.f), mUseCachedRng(false)
37
-{
38
-    mUsedIndices.setDimensionSize(primaryDimSize);
39
-}
40
-
41
-void ProposalQueue::setAlpha(float alpha)
42
-{
43
-    mAlpha = alpha;
44
-}
45
-
46
-void ProposalQueue::populate(AtomicDomain &domain, unsigned limit)
47
-{
48
-    GAPS_ASSERT(mMinAtoms == mMaxAtoms);
49
-    GAPS_ASSERT(mMaxAtoms == domain.size());
50
-
51
-    unsigned nIter = 0;
52
-    bool success = true;
53
-    while (nIter++ < limit && success)
54
-    {
55
-        success = makeProposal(domain);
56
-        if (!success)
57
-        {
58
-            mUseCachedRng = true;
59
-        }
60
-    }
61
-}
62
-
63
-void ProposalQueue::clear()
64
-{
65
-    GAPS_ASSERT(mMinAtoms == mMaxAtoms);
66
-
67
-    mQueue.clear();
68
-    mUsedPositions.clear();
69
-    mUsedIndices.clear();
70
-}
71
-
72
-unsigned ProposalQueue::size() const
73
-{
74
-    return mQueue.size();
75
-}
76
-
77
-AtomicProposal& ProposalQueue::operator[](int n)
78
-{
79
-    GAPS_ASSERT(mQueue.size() > 0);
80
-    GAPS_ASSERT(n < mQueue.size());
81
-
82
-    return mQueue[n];
83
-}
84
-
85
-void ProposalQueue::acceptDeath()
86
-{
87
-    #pragma omp atomic
88
-    mMaxAtoms--;
89
-}
90
-
91
-void ProposalQueue::rejectDeath()
92
-{
93
-    #pragma omp atomic
94
-    mMinAtoms++;
95
-}
96
-
97
-void ProposalQueue::acceptBirth()
98
-{
99
-    #pragma omp atomic
100
-    mMinAtoms++;
101
-}
102
-
103
-void ProposalQueue::rejectBirth()
104
-{
105
-    #pragma omp atomic
106
-    mMaxAtoms--;
107
-}
108
-
109
-float ProposalQueue::deathProb(uint64_t nAtoms) const
110
-{
111
-    double size = static_cast<double>(mDomainLength);
112
-    double term1 = (size - static_cast<double>(nAtoms)) / size;
113
-    double term2 = mAlpha * static_cast<double>(mNumBins) * term1;
114
-    return static_cast<double>(nAtoms) / (static_cast<double>(nAtoms) + term2);
115
-}
116
-
117
-bool ProposalQueue::makeProposal(AtomicDomain &domain)
118
-{
119
-    // always birth when no atoms exist
120
-    if (domain.size() == 0)
121
-    {
122
-        return birth(domain);
123
-    }
124
-
125
-    float bdProb = domain.size() < 2 ? 0.6667f : 0.5f;
126
-
127
-    mU1 = mRng.uniform();
128
-    mU2 = mRng.uniform();
129
-
130
-    float lowerBound = deathProb(domain.size());
131
-
132
-    if (mU1 <= bdProb)
133
-    {
134
-        return mU2 < lowerBound ? death(domain) : birth(domain);
135
-    }
136
-    return (mU1 < 0.75f || mMaxAtoms < 2) ? move(domain) : exchange(domain);
137
-}
138
-    
139
-unsigned ProposalQueue::primaryIndex(uint64_t pos) const
140
-{
141
-    return pos / mSecondaryDimLength;
142
-}
143
-
144
-unsigned ProposalQueue::secondaryIndex(uint64_t pos) const
145
-{
146
-    return (pos / mBinLength) % mSecondaryDimSize;
147
-}
148
-
149
-// TODO add atoms with empty mass? fill in mass in gibbssampler?
150
-// inserting invalidates previous pointers, but not inserting
151
-// prevents them from being selected for death
152
-bool ProposalQueue::birth(AtomicDomain &domain)
153
-{
154
-    uint64_t pos = domain.randomFreePosition();
155
-    if (mUsedIndices.contains(primaryIndex(pos)))
156
-    {
157
-        return false; // matrix conflict - can't compute gibbs mass
158
-    }
159
-
160
-    mQueue.push_back(AtomicProposal('B', pos));
161
-    mUsedIndices.insert(primaryIndex(pos));
162
-    mUsedPositions.insert(pos);
163
-    ++mMaxAtoms;
164
-    return true;
165
-}
166
-
167
-bool ProposalQueue::death(AtomicDomain &domain)
168
-{
169
-    Atom* a = domain.randomAtom();
170
-    if (mUsedIndices.contains(primaryIndex(a->pos)))
171
-    {
172
-        return false; // matrix conflict - can't compute gibbs mass or deltaLL
173
-    }
174
-
175
-    mQueue.push_back(AtomicProposal('D', a));
176
-    mUsedIndices.insert(primaryIndex(a->pos));
177
-    mUsedPositions.insert(a->pos);
178
-    --mMinAtoms;
179
-    return true;
180
-}
181
-
182
-bool ProposalQueue::move(AtomicDomain &domain)
183
-{
184
-    AtomNeighborhood hood = domain.randomAtomWithNeighbors();
185
-    uint64_t lbound = hood.hasLeft() ? hood.left->pos : 0;
186
-    uint64_t rbound = hood.hasRight() ? hood.right->pos : mDomainLength;
187
-
188
-    if (!mUsedPositions.isEmptyInterval(lbound, rbound))
189
-    {
190
-        return false; // atomic conflict - don't know neighbors
191
-    }
192
-
193
-    uint64_t newLocation = mRng.uniform64(lbound + 1, rbound - 1);
194
-
195
-    if (primaryIndex(hood.center->pos) == primaryIndex(newLocation)
196
-    && secondaryIndex(hood.center->pos) == secondaryIndex(newLocation))
197
-    {
198
-        hood.center->pos = newLocation; // automatically accept moves in same bin
199
-        return true;
200
-    }
201
-
202
-    if (mUsedIndices.contains(primaryIndex(hood.center->pos))
203
-    || mUsedIndices.contains(primaryIndex(newLocation)))
204
-    {
205
-        return false; // matrix conflict - can't compute deltaLL
206
-    }
207
-
208
-    mQueue.push_back(AtomicProposal('M', hood.center, newLocation));
209
-    mUsedIndices.insert(primaryIndex(hood.center->pos));
210
-    mUsedIndices.insert(primaryIndex(newLocation));
211
-    mUsedPositions.insert(hood.center->pos);
212
-    mUsedPositions.insert(newLocation);
213
-    return true;
214
-}
215
-
216
-bool ProposalQueue::exchange(AtomicDomain &domain)
217
-{
218
-    AtomNeighborhood hood = domain.randomAtomWithRightNeighbor();
219
-    Atom* a1 = hood.center;
220
-    Atom* a2 = hood.hasRight() ? hood.right : domain.front();
221
-
222
-    if (hood.hasRight()) // has neighbor
223
-    {
224
-        if (!mUsedPositions.isEmptyInterval(a1->pos, a2->pos))
225
-        {
226
-            return false; // atomic conflict - don't know right neighbor
227
-        }
228
-    }
229
-    else // exchange with first atom
230
-    {
231
-        if (!mUsedPositions.isEmptyInterval(a1->pos, mDomainLength))
232
-        {
233
-            return false; // atomic conflict - don't know right neighbor
234
-        }
235
-        
236
-        if (!mUsedPositions.isEmptyInterval(0, domain.front()->pos))
237
-        {
238
-            return false; // atomic conflict - don't know right neighbor
239
-        }
240
-    }
241
-
242
-    if (primaryIndex(a1->pos) == primaryIndex(a2->pos)
243
-    && secondaryIndex(a1->pos) == secondaryIndex(a2->pos))
244
-    {
245
-        return true; // TODO automatically accept exchanges in same bin
246
-    }
247
-
248
-    if (mUsedIndices.contains(primaryIndex(a1->pos))
249
-    || mUsedIndices.contains(primaryIndex(a2->pos)))
250
-    {
251
-        return false; // matrix conflict - can't compute gibbs mass or deltaLL
252
-    }
253
-
254
-    mQueue.push_back(AtomicProposal('E', a1, a2));
255
-    mUsedIndices.insert(primaryIndex(a1->pos));
256
-    mUsedIndices.insert(primaryIndex(a2->pos));
257
-    mUsedPositions.insert(a1->pos);
258
-    mUsedPositions.insert(a2->pos);
259
-    --mMinAtoms;
260
-    return true;
261
-}
262
-
263
-Archive& operator<<(Archive &ar, ProposalQueue &q)
264
-{
265
-    ar << q.mMinAtoms << q.mMaxAtoms << q.mNumBins << q.mBinLength
266
-        << q.mSecondaryDimLength << q.mDomainLength << q.mSecondaryDimSize
267
-        << q.mAlpha << q.mRng;
268
-    return ar;
269
-}
270
-
271
-Archive& operator>>(Archive &ar, ProposalQueue &q)
272
-{
273
-    ar >> q.mMinAtoms >> q.mMaxAtoms >> q.mNumBins >> q.mBinLength
274
-        >> q.mSecondaryDimLength >> q.mDomainLength >> q.mSecondaryDimSize
275
-        >> q.mAlpha >> q.mRng;
276
-    return ar;
277
-}
278 0
deleted file mode 100644
... ...
@@ -1,89 +0,0 @@
1
-#ifndef __GAPS_PROPOSAL_QUEUE_H__
2
-#define __GAPS_PROPOSAL_QUEUE_H__
3
-
4
-#include "utils/Archive.h"
5
-#include "AtomicDomain.h"
6
-#include "data_structures/EfficientSets.h"
7
-#include "math/Random.h"
8
-
9
-#include <cstddef>
10
-#include <stdint.h>
11
-#include <vector>
12
-
13
-struct AtomicProposal
14
-{
15
-    char type;
16
-    uint64_t birthPos; // used in birth
17
-    uint64_t moveDest; // used in move
18
-
19
-    Atom *atom1; // used in death, move, exchange
20
-    Atom *atom2; // used in exchange
21
-
22
-    mutable GapsRng rng;
23
-
24
-    AtomicProposal(char t, uint64_t pos); // birth
25
-    AtomicProposal(char t, Atom *atom); // death
26
-    AtomicProposal(char t, Atom *atom, uint64_t dest); // move
27
-    AtomicProposal(char t, Atom *a1, Atom *a2); // exchange
28
-};
29
-
30
-class ProposalQueue
31
-{
32
-public:
33
-
34
-    ProposalQueue(unsigned primaryDimSize, unsigned secondaryDimSize);
35
-    void setAlpha(float alpha);
36
-
37
-    // modify/access queue
38
-    void populate(AtomicDomain &domain, unsigned limit);
39
-    void clear();
40
-    unsigned size() const;
41
-    AtomicProposal& operator[](int n);
42
-
43
-    // update min/max atoms
44
-    void acceptDeath();
45
-    void rejectDeath();
46
-    void acceptBirth();
47
-    void rejectBirth();
48
-
49
-private:
50
-
51
-    std::vector<AtomicProposal> mQueue; // not really a queue for now
52
-    
53
-    IntFixedHashSet mUsedIndices;
54
-    IntDenseOrderedSet mUsedPositions;
55
-
56
-    uint64_t mMinAtoms;
57
-    uint64_t mMaxAtoms;
58
-
59
-    unsigned mNumBins; // number of matrix elements
60
-    uint64_t mBinLength; // atomic length of one bin
61
-    uint64_t mSecondaryDimLength; // atomic length of one row (col) for A (P)
62
-    uint64_t mDomainLength; // length of entire atomic domain
63
-    unsigned mSecondaryDimSize; // number of cols (rows) for A (P)
64
-
65
-    float mAlpha;
66
-
67
-    mutable GapsRng mRng;
68
-
69
-    float mU1;
70
-    float mU2;
71
-    bool mUseCachedRng;
72
-
73
-    unsigned primaryIndex(uint64_t pos) const;
74
-    unsigned secondaryIndex(uint64_t pos) const;
75
-
76
-    float deathProb(uint64_t nAtoms) const;
77
-    bool birth(AtomicDomain &domain);
78
-    bool death(AtomicDomain &domain);
79
-    bool move(AtomicDomain &domain);
80
-    bool exchange(AtomicDomain &domain);
81
-
82
-    bool makeProposal(AtomicDomain &domain);
83
-
84
-    // serialization
85
-    friend Archive& operator<<(Archive &ar, ProposalQueue &queue);
86
-    friend Archive& operator>>(Archive &ar, ProposalQueue &queue);
87
-};
88
-
89
-#endif
90 0
\ No newline at end of file
... ...
@@ -148,7 +148,7 @@ bool gaps::algo::isVectorZero(const float *vec, unsigned size)
148 148
 AlphaParameters gaps::algo::alphaParameters(unsigned size, const float *D,
149 149
 const float *S, const float *AP, const float *mat)
150 150
 {
151
-    gaps::simd::packedFloat ratio, pMat, pD, pAP, pS;
151
+    gaps::simd::packedFloat pMat, pD, pAP, pS;
152 152
     gaps::simd::packedFloat partialS(0.f), partialSU(0.f);
153 153
     gaps::simd::Index i(0);
154 154
     for (; i <= size - i.increment(); ++i)
... ...
@@ -157,25 +157,24 @@ const float *S, const float *AP, const float *mat)
157 157
         pD.load(D + i);
158 158
         pAP.load(AP + i);
159 159
         pS.load(S + i);
160
-        ratio = pMat / pS;
161
-        partialS += ratio * ratio;
162
-        partialSU += (ratio * (pD - pAP)) / pS;
160
+        gaps::simd::packedFloat ratio(pMat / (pS * pS));
161
+        partialS += pMat * ratio;
162
+        partialSU += ratio * (pD - pAP);
163 163
     }
164
-    float fratio, s = partialS.scalar(), su = partialSU.scalar();
164
+    float s = partialS.scalar(), su = partialSU.scalar();
165 165
     for (unsigned j = i.value(); j < size; ++j)
166 166
     {
167
-        fratio = mat[j] / S[j]; // can save one division here by dividing by S^2
168
-        s += fratio * fratio;
169
-        su += (fratio * (D[j] - AP[j])) / S[j];
167
+        float ratio = mat[j] / (S[j] * S[j]);
168
+        s += mat[j] * ratio;
169
+        su += ratio * (D[j] - AP[j]);
170 170
     }
171 171
     return AlphaParameters(s,su);
172 172
 }
173 173
 
174
-//
175 174
 AlphaParameters gaps::algo::alphaParameters(unsigned size, const float *D,
176 175
 const float *S, const float *AP, const float *mat1, const float *mat2)
177 176
 {
178
-    gaps::simd::packedFloat ratio, pMat1, pMat2, pD, pAP, pS;
177
+    gaps::simd::packedFloat pMat1, pMat2, pD, pAP, pS;
179 178
     gaps::simd::packedFloat partialS(0.f), partialSU(0.f);
180 179
     gaps::simd::Index i(0);
181 180
     for (; i <= size - i.increment(); ++i)
... ...
@@ -185,17 +184,44 @@ const float *S, const float *AP, const float *mat1, const float *mat2)
185 184
         pD.load(D + i);
186 185
         pAP.load(AP + i);
187 186
         pS.load(S + i);
188
-        ratio = (pMat1 - pMat2) / pS;
189
-        partialS += ratio * ratio;
190
-        partialSU += ratio * (pD - pAP) / pS;
187
+        gaps::simd::packedFloat ratio((pMat1 - pMat2) / (pS * pS));
188
+        partialS += (pMat1 - pMat2) * ratio;
189
+        partialSU += ratio * (pD - pAP);
191 190
     }
192 191
 
193
-    float fratio, s = partialS.scalar(), su = partialSU.scalar();
192
+    float s = partialS.scalar(), su = partialSU.scalar();
194 193
     for (unsigned j = i.value(); j < size; ++j)
195 194
     {
196
-        fratio = (mat1[j] - mat2[j]) / S[j];
197
-        s += fratio * fratio;
198
-        su += fratio * (D[j] - AP[j]) / S[j];
195
+        float ratio = (mat1[j] - mat2[j]) / (S[j] * S[j]);
196
+        s += (mat1[j] - mat2[j]) * ratio;
197
+        su += ratio * (D[j] - AP[j]);
198
+    }
199
+    return AlphaParameters(s,su);
200
+}
201
+
202
+AlphaParameters gaps::algo::alphaParametersWithChange(unsigned size,
203
+const float *D, const float *S, const float *AP, const float *mat, float ch)
204
+{
205
+    gaps::simd::packedFloat pCh(ch);
206
+    gaps::simd::packedFloat pMat, pD, pAP, pS;
207
+    gaps::simd::packedFloat partialS(0.f), partialSU(0.f);
208
+    gaps::simd::Index i(0);
209
+    for (; i <= size - i.increment(); ++i)
210
+    {   
211
+        pMat.load(mat + i);
212
+        pD.load(D + i);
213
+        pAP.load(AP + i);
214
+        pS.load(S + i);
215
+        gaps::simd::packedFloat ratio(pMat / (pS * pS));
216
+        partialS += pMat * ratio;
217
+        partialSU += ratio * (pD - (pAP + pCh * pMat));
218
+    }
219
+    float s = partialS.scalar(), su = partialSU.scalar();
220
+    for (unsigned j = i.value(); j < size; ++j)
221
+    {
222
+        float ratio = mat[j] / (S[j] * S[j]);
223
+        s += mat[j] * ratio;
224
+        su += ratio * (D[j] - (AP[j] + ch * mat[j]));
199 225
     }
200 226
     return AlphaParameters(s,su);
201 227
 }
... ...
@@ -72,6 +72,9 @@ namespace algo
72 72
     AlphaParameters alphaParameters(unsigned size, const float *D,
73 73
         const float *S, const float *AP, const float *mat);
74 74
 
75
+    AlphaParameters alphaParametersWithChange(unsigned size, const float *D,
76
+        const float *S, const float *AP, const float *mat, float d);
77
+
75 78
     AlphaParameters alphaParameters(unsigned size, const float *D,
76 79
         const float *S, const float *AP, const float *mat1, const float *mat2);
77 80
 
... ...
@@ -9,11 +9,18 @@
9 9
 
10 10
 struct OptionalFloat
11 11
 {
12
-    float value;
13
-    bool hasValue;
12
+public :
14 13
 
15
-    OptionalFloat() : hasValue(false), value(0.f) {}
16
-    OptionalFloat(float f) : hasValue(true), value(f) {}
14
+    OptionalFloat() : mHasValue(false), mValue(0.f) {}
15
+    OptionalFloat(float f) : mHasValue(true), mValue(f) {}
16
+
17
+    float value() { return mValue; }
18
+    bool hasValue() const { return mHasValue; }
19
+
20
+private :
21
+
22
+    float mValue;
23
+    bool mHasValue;
17 24
 };
18 25
 
19 26
 namespace gaps