Browse code

atomicdomain set up to be thread safe

Tom Sherman authored on 10/05/2018 18:51:29
Showing 3 changed files

... ...
@@ -46,13 +46,13 @@ uint64_t AtomicDomain::size() const
46 46
 }
47 47
 
48 48
 // O(1)
49
-Atom& AtomicDomain::left(const Atom &atom)
49
+Atom& AtomicDomain::_left(const Atom &atom)
50 50
 {
51 51
     return mAtoms[atom.leftNdx - 1];
52 52
 }
53 53
 
54 54
 // O(1)
55
-Atom& AtomicDomain::right(const Atom &atom)
55
+Atom& AtomicDomain::_right(const Atom &atom)
56 56
 {
57 57
     return mAtoms[atom.rightNdx - 1];
58 58
 }
... ...
@@ -84,8 +84,6 @@ bool AtomicDomain::hasRight(const Atom &atom) const
84 84
 // O(logN)
85 85
 Atom AtomicDomain::insert(uint64_t pos, float mass)
86 86
 {
87
-    //omp_set_lock(&lock);
88
-
89 87
     // insert position into map
90 88
     std::map<uint64_t, uint64_t>::iterator iter, iterLeft, iterRight;
91 89
     iter = mAtomPositions.insert(std::pair<uint64_t, uint64_t>(pos, size())).first;
... ...
@@ -98,21 +96,19 @@ Atom AtomicDomain::insert(uint64_t pos, float mass)
98 96
     {
99 97
         --iterLeft;
100 98
         atom.leftNdx = iterLeft->second + 1;
101
-        left(atom).rightNdx = size() + 1;
99
+        _left(atom).rightNdx = size() + 1;
102 100
     }
103 101
     if (++iter != mAtomPositions.end())
104 102
     {
105 103
         ++iterRight;
106 104
         atom.rightNdx = iterRight->second + 1;
107
-        right(atom).leftNdx = size() + 1;
105
+        _right(atom).leftNdx = size() + 1;
108 106
     } 
109 107
 
110 108
     // add atom to vector
111 109
     mAtoms.push_back(atom);
112 110
     mUsedPositions.insert(pos);
113 111
 
114
-    //omp_unset_lock(&lock);
115
-
116 112
     return atom;
117 113
 }
118 114
 
... ...
@@ -121,19 +117,17 @@ Atom AtomicDomain::insert(uint64_t pos, float mass)
121 117
 // swap with last atom in vector, pop off the back
122 118
 void AtomicDomain::erase(uint64_t pos)
123 119
 {
124
-    //omp_set_lock(&lock);
125
-
126 120
     // get vector index of this atom and erase it
127 121
     uint64_t index = mAtomPositions.at(pos);
128 122
 
129 123
     // connect neighbors of atom to be deleted
130 124
     if (hasLeft(mAtoms[index]))
131 125
     {
132
-        left(mAtoms[index]).rightNdx = mAtoms[index].rightNdx;
126
+        _left(mAtoms[index]).rightNdx = mAtoms[index].rightNdx;
133 127
     }
134 128
     if (hasRight(mAtoms[index]))
135 129
     {
136
-        right(mAtoms[index]).leftNdx = mAtoms[index].leftNdx;
130
+        _right(mAtoms[index]).leftNdx = mAtoms[index].leftNdx;
137 131
     }
138 132
 
139 133
     // replace with atom from back
... ...
@@ -149,11 +143,11 @@ void AtomicDomain::erase(uint64_t pos)
149 143
         // update moved atom's neighbors
150 144
         if (hasLeft(mAtoms[index]))
151 145
         {
152
-            left(mAtoms[index]).rightNdx = index + 1;
146
+            _left(mAtoms[index]).rightNdx = index + 1;
153 147
         }
154 148
         if (hasRight(mAtoms[index]))
155 149
         {
156
-            right(mAtoms[index]).leftNdx = index + 1;
150
+            _right(mAtoms[index]).leftNdx = index + 1;
157 151
         }
158 152
     }
159 153
 
... ...
@@ -161,14 +155,58 @@ void AtomicDomain::erase(uint64_t pos)
161 155
     mAtomPositions.erase(pos);
162 156
     mAtoms.pop_back();
163 157
     mUsedPositions.erase(pos);
158
+}
159
+
160
+void AtomicDomain::cacheInsert(uint64_t pos, float mass) const
161
+{
162
+    unsigned ndx = 0;
163
+
164
+    #pragma omp critical(atomicInsert)
165
+    {
166
+        ndx = mInsertCacheIndex++;
167
+    }
168
+
169
+    mInsertCache[ndx] = RawAtom(pos, mass);
170
+}
171
+
172
+void AtomicDomain::cacheErase(uint64_t pos) const
173
+{
174
+    unsigned ndx = 0;
175
+
176
+    #pragma omp critical(atomicErase)
177
+    {
178
+        ndx = mEraseCacheIndex++;
179
+    }
180
+
181
+    mEraseCache[ndx] = pos;
182
+}
183
+
184
+void AtomicDomain::resetCache(unsigned n)
185
+{
186
+    mInsertCacheIndex = 0;
187
+    mEraseCacheIndex = 0;
188
+    mInsertCache.resize(n);
189
+    mEraseCache.resize(n);
190
+}
191
+
192
+void AtomicDomain::flushCache()
193
+{
194
+    for (unsigned i = 0; i < mEraseCacheIndex; ++i)
195
+    {
196
+        erase(mEraseCache[i]);
197
+    }
198
+
199
+    for (unsigned i = 0; i < mInsertCacheIndex; ++i)
200
+    {
201
+        insert(mInsertCache[i].pos, mInsertCache[i].mass);
202
+    }
164 203
 
165
-    //omp_unset_lock(&lock);
204
+    mInsertCache.clear();
205
+    mEraseCache.clear();
166 206
 }
167 207
 
168 208
 // O(logN)
169 209
 void AtomicDomain::updateMass(uint64_t pos, float newMass)
170 210
 {
171
-    //omp_set_lock(&lock);
172 211
     mAtoms[mAtomPositions.at(pos)].mass = newMass;
173
-    //omp_unset_lock(&lock);
174 212
 }
175 213
\ No newline at end of file
... ...
@@ -43,6 +43,15 @@ public:
43 43
     }
44 44
 };
45 45
 
46
+struct RawAtom
47
+{
48
+    uint64_t pos;
49
+    float mass;
50
+
51
+    RawAtom() : pos(0), mass(0.f) {}
52
+    RawAtom(uint64_t p, float m) : pos(p), mass(m) {}
53
+};
54
+
46 55
 // data structure that holds atoms
47 56
 class AtomicDomain
48 57
 {
... ...
@@ -58,6 +67,15 @@ private:
58 67
     // TODO google_dense_set - first profile and benchmark
59 68
     boost::unordered_set<uint64_t> mUsedPositions;
60 69
 
70
+    mutable std::vector<RawAtom> mInsertCache;
71
+    mutable std::vector<uint64_t> mEraseCache;
72
+
73
+    mutable unsigned mInsertCacheIndex;
74
+    mutable unsigned mEraseCacheIndex;
75
+
76
+    Atom& _left(const Atom &atom);
77
+    Atom& _right(const Atom &atom);
78
+
61 79
 public:
62 80
 
63 81
     AtomicDomain();
... ...
@@ -70,8 +88,6 @@ public:
70 88
     uint64_t randomFreePosition() const;
71 89
     uint64_t size() const;
72 90
 
73
-    Atom& left(const Atom &atom);
74
-    Atom& right(const Atom &atom);
75 91
     const Atom& left(const Atom &atom) const;
76 92
     const Atom& right(const Atom &atom) const;
77 93
     bool hasLeft(const Atom &atom) const;
... ...
@@ -80,11 +96,15 @@ public:
80 96
     // modify domain
81 97
     Atom insert(uint64_t pos, float mass);
82 98
     void erase(uint64_t pos);
99
+    void cacheInsert(uint64_t pos, float mass) const;
100
+    void cacheErase(uint64_t pos) const;
83 101
     void updateMass(uint64_t pos, float newMass);
102
+    void flushCache();
103
+    void resetCache(unsigned n);
84 104
 
85 105
     // serialization
86 106
     friend Archive& operator<<(Archive &ar, AtomicDomain &domain);
87 107
     friend Archive& operator>>(Archive &ar, AtomicDomain &domain);
88 108
 };
89 109
 
90
-#endif
110
+#endif
91 111
\ No newline at end of file
... ...
@@ -185,6 +185,7 @@ void GibbsSampler<T, MatA, MatB>::update(unsigned nSteps, unsigned nCores)
185 185
         mNumQueues += 1.f;
186 186
         mAvgQueue = mQueue.size() / mNumQueues + mAvgQueue * (mNumQueues - 1.f) / mNumQueues;
187 187
         n += mQueue.size();
188
+        mDomain.resetCache(mQueue.size());
188 189
         //Rprintf("round: %d\n", count++);
189 190
 
190 191
         #pragma omp parallel for num_threads(nCores)
... ...
@@ -192,6 +193,7 @@ void GibbsSampler<T, MatA, MatB>::update(unsigned nSteps, unsigned nCores)
192 193
         {
193 194
             processProposal(mQueue[i]);
194 195
         }
196
+        mDomain.flushCache();
195 197
         mQueue.clear(1);
196 198
         GAPS_ASSERT(n <= nSteps);
197 199
     }
... ...
@@ -229,23 +231,17 @@ void GibbsSampler<T, MatA, MatB>::processProposal(const AtomicProposal &prop)
229 231
 template <class T, class MatA, class MatB>
230 232
 void GibbsSampler<T, MatA, MatB>::addMass(uint64_t pos, float mass, unsigned row, unsigned col)
231 233
 {
232
-    #pragma omp critical(gibbs)
233
-    {
234
-        mDomain.insert(pos, mass);
235
-        mMatrix(row, col) += mass;
236
-        impl()->updateAPMatrix(row, col, mass);
237
-    }
234
+    mDomain.cacheInsert(pos, mass);
235
+    mMatrix(row, col) += mass;
236
+    impl()->updateAPMatrix(row, col, mass);
238 237
 }
239 238
 
240 239
 template <class T, class MatA, class MatB>
241 240
 void GibbsSampler<T, MatA, MatB>::removeMass(uint64_t pos, float mass, unsigned row, unsigned col)
242 241
 {
243
-    #pragma omp critical(gibbs)
244
-    {
245
-        mDomain.erase(pos);
246
-        mMatrix(row, col) += -mass;
247
-        impl()->updateAPMatrix(row, col, -mass);
248
-    }
242
+    mDomain.cacheErase(pos);
243
+    mMatrix(row, col) += -mass;
244
+    impl()->updateAPMatrix(row, col, -mass);
249 245
 }
250 246
 
251 247
 // add an atom at pos, calculate mass either with an exponential distribution
... ...
@@ -256,19 +252,16 @@ unsigned col)
256 252
 {
257 253
     float mass = impl()->canUseGibbs(row, col) ? gibbsMass(row, col, mass)
258 254
         : gaps::random::exponential(mLambda);
259
-    #pragma omp critical(gibbs)
255
+    if (mass >= gaps::algo::epsilon)
260 256
     {
261
-        if (mass >= gaps::algo::epsilon)
262
-        {
263
-            mDomain.updateMass(pos, mass);
264
-            mMatrix(row, col) += mass;
265
-            impl()->updateAPMatrix(row, col, mass);
266
-        }
267
-        else
268
-        {
269
-            mDomain.erase(pos);
270
-            mQueue.rejectBirth();
271
-        }
257
+        mDomain.updateMass(pos, mass);
258
+        mMatrix(row, col) += mass;
259
+        impl()->updateAPMatrix(row, col, mass);
260
+    }
261
+    else
262
+    {
263
+        mDomain.cacheErase(pos);
264
+        mQueue.rejectBirth();
272 265
     }
273 266
 }
274 267
 
... ...
@@ -279,17 +272,24 @@ void GibbsSampler<T, MatA, MatB>::death(uint64_t pos, float mass, unsigned row,
279 272
 unsigned col)
280 273
 {
281 274
     GAPS_ASSERT(mass > 0.f);
282
-    removeMass(pos, mass, row, col);
275
+
276
+    //removeMass(pos, mass, row, col);
277
+    mMatrix(row, col) += -mass;
278
+    impl()->updateAPMatrix(row, col, -mass);
279
+
283 280
     float newMass = impl()->canUseGibbs(row, col) ? gibbsMass(row, col, -mass) : 0.f;
284 281
     mass = newMass < gaps::algo::epsilon ? mass : newMass;
285 282
     float deltaLL = impl()->computeDeltaLL(row, col, mass);
286 283
     if (deltaLL * mAnnealingTemp >= std::log(gaps::random::uniform()))
287 284
     {
288
-        addMass(pos, mass, row, col);
285
+        mDomain.updateMass(pos, mass);
286
+        mMatrix(row, col) += mass;
287
+        impl()->updateAPMatrix(row, col, mass);
289 288
         mQueue.rejectDeath();
290 289
     }
291 290
     else
292 291
     {
292
+        mDomain.cacheErase(pos);
293 293
         mQueue.acceptDeath();
294 294
     }
295 295
 }
... ...
@@ -374,22 +374,17 @@ template <class T, class MatA, class MatB>
374 374
 float GibbsSampler<T, MatA, MatB>::updateAtomMass(uint64_t pos, float mass,
375 375
 float delta)
376 376
 {
377
-    bool ret_val = false;
378
-    #pragma omp critical(gibbs)
377
+    if (mass + delta < gaps::algo::epsilon)
379 378
     {
380
-        if (mass + delta < gaps::algo::epsilon)
381
-        {
382
-            mDomain.erase(pos);
383
-            mQueue.acceptDeath();
384
-            ret_val = false;
385
-        }
386
-        else
387
-        {
388
-            mDomain.updateMass(pos, mass + delta);
389
-            ret_val = true;
390
-        }
379
+        mDomain.cacheErase(pos);
380
+        mQueue.acceptDeath();
381
+        return false;
382
+    }
383
+    else
384
+    {
385
+        mDomain.updateMass(pos, mass + delta);
386
+        return true;
391 387
     }
392
-    return ret_val;
393 388
 }
394 389
 
395 390
 // helper function for exchange step, updates the atomic domain, matrix, and
... ...
@@ -412,13 +407,10 @@ unsigned r2, unsigned c2)
412 407
         mQueue.rejectDeath();
413 408
     }
414 409
 
415
-    #pragma omp critical(gibbs)
416
-    {
417
-        mMatrix(r1, c1) += d1;
418
-        mMatrix(r2, c2) += d2;
419
-        impl()->updateAPMatrix(r1, c1, d1);
420
-        impl()->updateAPMatrix(r2, c2, d2);
421
-    }
410
+    mMatrix(r1, c1) += d1;
411
+    mMatrix(r2, c2) += d2;
412
+    impl()->updateAPMatrix(r1, c1, d1);
413
+    impl()->updateAPMatrix(r2, c2, d2);
422 414
 }
423 415
 
424 416
 template <class T, class MatA, class MatB>