... | ... |
@@ -20,14 +20,14 @@ TEST_CASE("Test SparseIterator.h - One Dimensional") |
20 | 20 |
v.insert(7, 8.f); |
21 | 21 |
v.insert(9, 10.f); |
22 | 22 |
|
23 |
- SparseIterator it(v); |
|
24 |
- REQUIRE(it.getValue() == 1.f); |
|
23 |
+ TemplatedSparseIterator<1> it(v); |
|
24 |
+ REQUIRE(get<1>(it) == 1.f); |
|
25 | 25 |
it.next(); |
26 |
- REQUIRE(it.getValue() == 5.f); |
|
26 |
+ REQUIRE(get<1>(it) == 5.f); |
|
27 | 27 |
it.next(); |
28 |
- REQUIRE(it.getValue() == 8.f); |
|
28 |
+ REQUIRE(get<1>(it) == 8.f); |
|
29 | 29 |
it.next(); |
30 |
- REQUIRE(it.getValue() == 10.f); |
|
30 |
+ REQUIRE(get<1>(it) == 10.f); |
|
31 | 31 |
it.next(); |
32 | 32 |
REQUIRE(it.atEnd()); |
33 | 33 |
} |
... | ... |
@@ -51,10 +51,10 @@ TEST_CASE("Test SparseIterator.h - One Dimensional") |
51 | 51 |
for (unsigned j = 0; j < ref.nCol(); ++j) |
52 | 52 |
{ |
53 | 53 |
float colSum = 0.f; |
54 |
- SparseIterator it(mat.getCol(j)); |
|
54 |
+ TemplatedSparseIterator<1> it(mat.getCol(j)); |
|
55 | 55 |
while (!it.atEnd()) |
56 | 56 |
{ |
57 |
- colSum += it.getValue(); |
|
57 |
+ colSum += get<1>(it); |
|
58 | 58 |
it.next(); |
59 | 59 |
} |
60 | 60 |
REQUIRE(colSum == gaps::sum(ref.getCol(j))); |
... | ... |
@@ -79,12 +79,12 @@ TEST_CASE("Test SparseIterator.h - Two Dimensional") |
79 | 79 |
hv.add(6, 5.f); |
80 | 80 |
hv.add(7, 6.f); |
81 | 81 |
|
82 |
- SparseIteratorTwo it(sv, hv); |
|
83 |
- REQUIRE(it.getValue_1() == 5.f); |
|
84 |
- REQUIRE(it.getValue_2() == 3.f); |
|
82 |
+ TemplatedSparseIterator<2> it(sv, hv); |
|
83 |
+ REQUIRE(get<1>(it) == 5.f); |
|
84 |
+ REQUIRE(get<2>(it) == 3.f); |
|
85 | 85 |
it.next(); |
86 |
- REQUIRE(it.getValue_1() == 8.f); |
|
87 |
- REQUIRE(it.getValue_2() == 6.f); |
|
86 |
+ REQUIRE(get<1>(it) == 8.f); |
|
87 |
+ REQUIRE(get<2>(it) == 6.f); |
|
88 | 88 |
it.next(); |
89 | 89 |
REQUIRE(it.atEnd()); |
90 | 90 |
} |
... | ... |
@@ -107,9 +107,9 @@ TEST_CASE("Test SparseIterator.h - Two Dimensional") |
107 | 107 |
hv.add(8, 9.f); |
108 | 108 |
hv.add(75, 76.f); |
109 | 109 |
|
110 |
- SparseIteratorTwo it(sv, hv); |
|
111 |
- REQUIRE(it.getValue_1() == 75.f); |
|
112 |
- REQUIRE(it.getValue_2() == 76.f); |
|
110 |
+ TemplatedSparseIterator<2> it(sv, hv); |
|
111 |
+ REQUIRE(get<1>(it) == 75.f); |
|
112 |
+ REQUIRE(get<2>(it) == 76.f); |
|
113 | 113 |
it.next(); |
114 | 114 |
REQUIRE(it.atEnd()); |
115 | 115 |
} |
... | ... |
@@ -170,7 +170,7 @@ TEST_CASE("Test SparseIterator.h - Two Dimensional") |
170 | 170 |
|
171 | 171 |
// calculate dot product |
172 | 172 |
float sdot = 0.f, ddot = 0.f; |
173 |
- SparseIteratorTwo it(sv, hv); |
|
173 |
+ TemplatedSparseIterator<2> it(sv, hv); |
|
174 | 174 |
unsigned i = 0; |
175 | 175 |
while (!it.atEnd()) |
176 | 176 |
{ |
... | ... |
@@ -182,12 +182,12 @@ TEST_CASE("Test SparseIterator.h - Two Dimensional") |
182 | 182 |
if (i < dv1.size()) |
183 | 183 |
{ |
184 | 184 |
ddot += dv1[i] * dv2[i]; |
185 |
- REQUIRE(dv1[i] == it.getValue_1()); |
|
186 |
- REQUIRE(dv2[i] == it.getValue_2()); |
|
185 |
+ REQUIRE(dv1[i] == get<1>(it)); |
|
186 |
+ REQUIRE(dv2[i] == get<2>(it)); |
|
187 | 187 |
++i; |
188 | 188 |
} |
189 | 189 |
|
190 |
- sdot += it.getValue_1() * it.getValue_2(); |
|
190 |
+ sdot += get<1>(it) * get<2>(it); |
|
191 | 191 |
|
192 | 192 |
it.next(); |
193 | 193 |
} |
... | ... |
@@ -218,10 +218,10 @@ TEST_CASE("Test SparseIterator.h - Two Dimensional") |
218 | 218 |
for (unsigned j2 = j1; j2 < ref.nCol(); ++j2) |
219 | 219 |
{ |
220 | 220 |
float dot = 0.f; |
221 |
- SparseIteratorTwo it(sMat.getCol(j1), hMat.getCol(j2)); |
|
221 |
+ TemplatedSparseIterator<2> it(sMat.getCol(j1), hMat.getCol(j2)); |
|
222 | 222 |
while (!it.atEnd()) |
223 | 223 |
{ |
224 |
- dot += it.getValue_1() * it.getValue_2(); |
|
224 |
+ dot += get<1>(it) * get<2>(it); |
|
225 | 225 |
it.next(); |
226 | 226 |
} |
227 | 227 |
REQUIRE(dot == gaps::dot(ref.getCol(j1), ref.getCol(j2))); |
... | ... |
@@ -265,14 +265,14 @@ TEST_CASE("Test SparseIterator.h - Three Dimensional") |
265 | 265 |
hv2.add(8, 7.f); |
266 | 266 |
hv2.add(9, 8.f); |
267 | 267 |
|
268 |
- SparseIteratorThree it(sv, hv1, hv2); |
|
269 |
- REQUIRE(it.getValue_1() == 5.f); // 4 |
|
270 |
- REQUIRE(it.getValue_2() == 3.f); |
|
271 |
- REQUIRE(it.getValue_3() == 6.f); |
|
268 |
+ TemplatedSparseIterator<3> it(sv, hv1, hv2); |
|
269 |
+ REQUIRE(get<1>(it) == 5.f); // 4 |
|
270 |
+ REQUIRE(get<2>(it) == 3.f); |
|
271 |
+ REQUIRE(get<3>(it) == 6.f); |
|
272 | 272 |
it.next(); |
273 |
- REQUIRE(it.getValue_1() == 10.f); // 9 |
|
274 |
- REQUIRE(it.getValue_2() == 7.f); |
|
275 |
- REQUIRE(it.getValue_3() == 8.f); |
|
273 |
+ REQUIRE(get<1>(it) == 10.f); // 9 |
|
274 |
+ REQUIRE(get<2>(it) == 7.f); |
|
275 |
+ REQUIRE(get<3>(it) == 8.f); |
|
276 | 276 |
it.next(); |
277 | 277 |
REQUIRE(it.atEnd()); |
278 | 278 |
} |
... | ... |
@@ -301,11 +301,11 @@ TEST_CASE("Test SparseIterator.h - Three Dimensional") |
301 | 301 |
for (unsigned j3 = j2; j3 < ref.nCol(); ++j3) |
302 | 302 |
{ |
303 | 303 |
float prod = 0.f; |
304 |
- SparseIteratorThree it(sMat.getCol(j1), hMat.getCol(j2), |
|
304 |
+ TemplatedSparseIterator<3> it(sMat.getCol(j1), hMat.getCol(j2), |
|
305 | 305 |
hMat.getCol(j3)); |
306 | 306 |
while (!it.atEnd()) |
307 | 307 |
{ |
308 |
- prod += it.getValue_1() * it.getValue_2() * it.getValue_3(); |
|
308 |
+ prod += get<1>(it) * get<2>(it) * get<3>(it); |
|
309 | 309 |
it.next(); |
310 | 310 |
} |
311 | 311 |
REQUIRE(prod == tripleProduct(ref.getCol(j1), |
... | ... |
@@ -21,6 +21,9 @@ public: |
21 | 21 |
|
22 | 22 |
friend class SparseIteratorTwo; |
23 | 23 |
friend class SparseIteratorThree; |
24 |
+ |
|
25 |
+ template <unsigned N> |
|
26 |
+ friend class TemplatedSparseIterator; |
|
24 | 27 |
|
25 | 28 |
explicit HybridVector(unsigned size); |
26 | 29 |
explicit HybridVector(const std::vector<float> &v); |
... | ... |
@@ -18,6 +18,183 @@ static uint64_t clearLowerBits(uint64_t u, unsigned pos) |
18 | 18 |
return u & ~((1ull << (pos + 1ull)) - 1ull); |
19 | 19 |
} |
20 | 20 |
|
21 |
+template <class Iter> |
|
22 |
+void gotoNextCommon(Iter &it) |
|
23 |
+{ |
|
24 |
+ // get the common indices in this chunk |
|
25 |
+ it.calculateCommonFlags(); |
|
26 |
+ |
|
27 |
+ // if nothing common in this chunk, find a chunk that has common indices |
|
28 |
+ while (!it.mCommonFlags) |
|
29 |
+ { |
|
30 |
+ // first count how many sparse indices we are skipping |
|
31 |
+ it.mSparseIndex += __builtin_popcountll(it.mSparseFlags); |
|
32 |
+ |
|
33 |
+ // advance to next chunk |
|
34 |
+ if (++it.mBigIndex == it.mTotalIndices) |
|
35 |
+ { |
|
36 |
+ it.mAtEnd = true; |
|
37 |
+ return; |
|
38 |
+ } |
|
39 |
+ |
|
40 |
+ // update the flags |
|
41 |
+ it.getFlags(); |
|
42 |
+ it.calculateCommonFlags(); |
|
43 |
+ } |
|
44 |
+ |
|
45 |
+ // must have at least one common value, this is our index |
|
46 |
+ it.mSmallIndex = __builtin_ffsll(it.mCommonFlags) - 1; |
|
47 |
+ |
|
48 |
+ // find the number of skipped entries in the sparse vector |
|
49 |
+ it.mSparseIndex += 1 + countLowerBits(it.mSparseFlags, it.mSmallIndex); |
|
50 |
+ |
|
51 |
+ // clear out all skipped indices and the current index from the bitflags |
|
52 |
+ it.mSparseFlags = clearLowerBits(it.mSparseFlags, it.mSmallIndex); |
|
53 |
+} |
|
54 |
+ |
|
55 |
+template<> |
|
56 |
+float get<1>(const TemplatedSparseIterator<1> &it) |
|
57 |
+{ |
|
58 |
+ return it.mSparse.mData[it.mSparseIndex]; |
|
59 |
+} |
|
60 |
+ |
|
61 |
+template<> |
|
62 |
+float get<1>(const TemplatedSparseIterator<2> &it) |
|
63 |
+{ |
|
64 |
+ return it.mSparse.mData[it.mSparseIndex]; |
|
65 |
+} |
|
66 |
+ |
|
67 |
+template<> |
|
68 |
+float get<1>(const TemplatedSparseIterator<3> &it) |
|
69 |
+{ |
|
70 |
+ return it.mSparse.mData[it.mSparseIndex]; |
|
71 |
+} |
|
72 |
+ |
|
73 |
+template<> |
|
74 |
+float get<2>(const TemplatedSparseIterator<2> &it) |
|
75 |
+{ |
|
76 |
+ return it.mHybrid_1[64 * it.mBigIndex + it.mSmallIndex]; |
|
77 |
+} |
|
78 |
+ |
|
79 |
+template<> |
|
80 |
+float get<2>(const TemplatedSparseIterator<3> &it) |
|
81 |
+{ |
|
82 |
+ return it.mHybrid_1[64 * it.mBigIndex + it.mSmallIndex]; |
|
83 |
+} |
|
84 |
+ |
|
85 |
+template<> |
|
86 |
+float get<3>(const TemplatedSparseIterator<3> &it) |
|
87 |
+{ |
|
88 |
+ return it.mHybrid_2[64 * it.mBigIndex + it.mSmallIndex]; |
|
89 |
+} |
|
90 |
+ |
|
91 |
+ |
|
92 |
+TemplatedSparseIterator<1>::TemplatedSparseIterator(const SparseVector &v) |
|
93 |
+: |
|
94 |
+mSparse(v), |
|
95 |
+mSparseIndex(0) |
|
96 |
+{} |
|
97 |
+ |
|
98 |
+bool TemplatedSparseIterator<1>::atEnd() const |
|
99 |
+{ |
|
100 |
+ return mSparseIndex == mSparse.mData.size(); |
|
101 |
+} |
|
102 |
+ |
|
103 |
+void TemplatedSparseIterator<1>::next() |
|
104 |
+{ |
|
105 |
+ ++mSparseIndex; |
|
106 |
+} |
|
107 |
+ |
|
108 |
+TemplatedSparseIterator<2>::TemplatedSparseIterator(const SparseVector &v, const HybridVector &h) |
|
109 |
+ : |
|
110 |
+mSparse(v), |
|
111 |
+mHybrid_1(h), |
|
112 |
+mSparseFlags(v.mIndexBitFlags[0]), |
|
113 |
+mHybridFlags_1(h.mIndexBitFlags[0]), |
|
114 |
+mCommonFlags(v.mIndexBitFlags[0] & h.mIndexBitFlags[0]), |
|
115 |
+mTotalIndices(v.mIndexBitFlags.size()), |
|
116 |
+mBigIndex(0), |
|
117 |
+mSmallIndex(0), |
|
118 |
+mSparseIndex(0), |
|
119 |
+mAtEnd(false) |
|
120 |
+{ |
|
121 |
+ GAPS_ASSERT(v.size() == h.size()); |
|
122 |
+ |
|
123 |
+ next(); |
|
124 |
+ mSparseIndex -= 1; // next puts us at position 1, this resets to 0 |
|
125 |
+} |
|
126 |
+ |
|
127 |
+bool TemplatedSparseIterator<2>::atEnd() const |
|
128 |
+{ |
|
129 |
+ return mAtEnd; |
|
130 |
+} |
|
131 |
+ |
|
132 |
+void TemplatedSparseIterator<2>::next() |
|
133 |
+{ |
|
134 |
+ gotoNextCommon(*this); |
|
135 |
+} |
|
136 |
+ |
|
137 |
+void TemplatedSparseIterator<2>::calculateCommonFlags() |
|
138 |
+{ |
|
139 |
+ mCommonFlags = mSparseFlags & mHybridFlags_1; |
|
140 |
+} |
|
141 |
+ |
|
142 |
+void TemplatedSparseIterator<2>::getFlags() |
|
143 |
+{ |
|
144 |
+ mSparseFlags = mSparse.mIndexBitFlags[mBigIndex]; |
|
145 |
+ mHybridFlags_1 = mHybrid_1.mIndexBitFlags[mBigIndex]; |
|
146 |
+} |
|
147 |
+ |
|
148 |
+TemplatedSparseIterator<3>::TemplatedSparseIterator(const SparseVector &v, |
|
149 |
+const HybridVector &h1, const HybridVector &h2) |
|
150 |
+ : |
|
151 |
+mSparse(v), |
|
152 |
+mHybrid_1(h1), |
|
153 |
+mHybrid_2(h2), |
|
154 |
+mSparseFlags(v.mIndexBitFlags[0]), |
|
155 |
+mHybridFlags_1(h1.mIndexBitFlags[0]), |
|
156 |
+mHybridFlags_2(h2.mIndexBitFlags[0]), |
|
157 |
+mCommonFlags(v.mIndexBitFlags[0] & h1.mIndexBitFlags[0] & h2.mIndexBitFlags[0]), |
|
158 |
+mTotalIndices(v.mIndexBitFlags.size()), |
|
159 |
+mBigIndex(0), |
|
160 |
+mSmallIndex(0), |
|
161 |
+mSparseIndex(0), |
|
162 |
+mAtEnd(false) |
|
163 |
+{ |
|
164 |
+ GAPS_ASSERT(v.size() == h1.size()); |
|
165 |
+ GAPS_ASSERT(h1.size() == h2.size()); |
|
166 |
+ |
|
167 |
+ next(); |
|
168 |
+ mSparseIndex -= 1; |
|
169 |
+} |
|
170 |
+ |
|
171 |
+bool TemplatedSparseIterator<3>::atEnd() const |
|
172 |
+{ |
|
173 |
+ return mAtEnd; |
|
174 |
+} |
|
175 |
+ |
|
176 |
+void TemplatedSparseIterator<3>::next() |
|
177 |
+{ |
|
178 |
+ gotoNextCommon(*this); |
|
179 |
+} |
|
180 |
+ |
|
181 |
+void TemplatedSparseIterator<3>::calculateCommonFlags() |
|
182 |
+{ |
|
183 |
+ mCommonFlags = mSparseFlags & mHybridFlags_1 & mHybridFlags_2; |
|
184 |
+} |
|
185 |
+ |
|
186 |
+void TemplatedSparseIterator<3>::getFlags() |
|
187 |
+{ |
|
188 |
+ mSparseFlags = mSparse.mIndexBitFlags[mBigIndex]; |
|
189 |
+ mHybridFlags_1 = mHybrid_1.mIndexBitFlags[mBigIndex]; |
|
190 |
+ mHybridFlags_2 = mHybrid_2.mIndexBitFlags[mBigIndex]; |
|
191 |
+} |
|
192 |
+ |
|
193 |
+//////////////////////////////////////////////////////////////////////////////// |
|
194 |
+//////////////////////////////////////////////////////////////////////////////// |
|
195 |
+//////////////////////////////////////////////////////////////////////////////// |
|
196 |
+//////////////////////////////////////////////////////////////////////////////// |
|
197 |
+ |
|
21 | 198 |
SparseIterator::SparseIterator(const SparseVector &v) |
22 | 199 |
: |
23 | 200 |
mSparse(v), |
... | ... |
@@ -4,7 +4,111 @@ |
4 | 4 |
#include "HybridVector.h" |
5 | 5 |
#include "SparseVector.h" |
6 | 6 |
|
7 |
-// TODO make these nicer with templates - make sure no performance lost |
|
7 |
+template <unsigned N, class Iter> |
|
8 |
+float get(const Iter &it); |
|
9 |
+ |
|
10 |
+// only allow this class to constructed with N=1,2,3 |
|
11 |
+template <unsigned N> |
|
12 |
+class TemplatedSparseIterator |
|
13 |
+{ |
|
14 |
+private: |
|
15 |
+ TemplatedSparseIterator() {} |
|
16 |
+}; |
|
17 |
+ |
|
18 |
+template<> |
|
19 |
+class TemplatedSparseIterator<1> |
|
20 |
+{ |
|
21 |
+public: |
|
22 |
+ |
|
23 |
+ TemplatedSparseIterator(const SparseVector &v); |
|
24 |
+ |
|
25 |
+ bool atEnd() const; |
|
26 |
+ void next(); |
|
27 |
+ |
|
28 |
+private: |
|
29 |
+ |
|
30 |
+ friend float get<1>(const TemplatedSparseIterator<1> &it); |
|
31 |
+ |
|
32 |
+ const SparseVector &mSparse; |
|
33 |
+ unsigned mSparseIndex; |
|
34 |
+}; |
|
35 |
+ |
|
36 |
+template<> |
|
37 |
+class TemplatedSparseIterator<2> |
|
38 |
+{ |
|
39 |
+public: |
|
40 |
+ |
|
41 |
+ TemplatedSparseIterator(const SparseVector &v, const HybridVector &h); |
|
42 |
+ |
|
43 |
+ bool atEnd() const; |
|
44 |
+ void next(); |
|
45 |
+ void calculateCommonFlags(); |
|
46 |
+ void getFlags(); |
|
47 |
+ |
|
48 |
+private: |
|
49 |
+ |
|
50 |
+ template <class Iter> |
|
51 |
+ friend void gotoNextCommon(Iter &it); |
|
52 |
+ |
|
53 |
+ friend float get<1>(const TemplatedSparseIterator<2> &it); |
|
54 |
+ friend float get<2>(const TemplatedSparseIterator<2> &it); |
|
55 |
+ |
|
56 |
+ const SparseVector &mSparse; |
|
57 |
+ const HybridVector &mHybrid_1; |
|
58 |
+ |
|
59 |
+ uint64_t mSparseFlags; |
|
60 |
+ uint64_t mHybridFlags_1; |
|
61 |
+ uint64_t mCommonFlags; |
|
62 |
+ |
|
63 |
+ unsigned mTotalIndices; |
|
64 |
+ unsigned mBigIndex; |
|
65 |
+ unsigned mSmallIndex; |
|
66 |
+ unsigned mSparseIndex; |
|
67 |
+ bool mAtEnd; |
|
68 |
+}; |
|
69 |
+ |
|
70 |
+template<> |
|
71 |
+class TemplatedSparseIterator<3> |
|
72 |
+{ |
|
73 |
+public: |
|
74 |
+ |
|
75 |
+ TemplatedSparseIterator(const SparseVector &v, const HybridVector &h1, |
|
76 |
+ const HybridVector &h2); |
|
77 |
+ |
|
78 |
+ bool atEnd() const; |
|
79 |
+ void next(); |
|
80 |
+ void calculateCommonFlags(); |
|
81 |
+ void getFlags(); |
|
82 |
+ |
|
83 |
+private: |
|
84 |
+ |
|
85 |
+ template <class Iter> |
|
86 |
+ friend void gotoNextCommon(Iter &it); |
|
87 |
+ |
|
88 |
+ friend float get<1>(const TemplatedSparseIterator<3> &it); |
|
89 |
+ friend float get<2>(const TemplatedSparseIterator<3> &it); |
|
90 |
+ friend float get<3>(const TemplatedSparseIterator<3> &it); |
|
91 |
+ |
|
92 |
+ const SparseVector &mSparse; |
|
93 |
+ const HybridVector &mHybrid_1; |
|
94 |
+ const HybridVector &mHybrid_2; |
|
95 |
+ |
|
96 |
+ uint64_t mSparseFlags; |
|
97 |
+ uint64_t mHybridFlags_1; |
|
98 |
+ uint64_t mHybridFlags_2; |
|
99 |
+ uint64_t mCommonFlags; |
|
100 |
+ |
|
101 |
+ unsigned mTotalIndices; |
|
102 |
+ unsigned mBigIndex; |
|
103 |
+ unsigned mSmallIndex; |
|
104 |
+ unsigned mSparseIndex; |
|
105 |
+ bool mAtEnd; |
|
106 |
+}; |
|
107 |
+ |
|
108 |
+//////////////////////////////////////////////////////////////////////////////// |
|
109 |
+//////////////////////////////////////////////////////////////////////////////// |
|
110 |
+//////////////////////////////////////////////////////////////////////////////// |
|
111 |
+//////////////////////////////////////////////////////////////////////////////// |
|
8 | 112 |
|
9 | 113 |
class SparseIterator |
10 | 114 |
{ |