Browse code

added templated version of sparse iterator; use this one after benchmarking performance

Tom Sherman authored on 19/10/2018 21:41:06
Showing 4 changed files

... ...
@@ -20,14 +20,14 @@ TEST_CASE("Test SparseIterator.h - One Dimensional")
20 20
         v.insert(7, 8.f);
21 21
         v.insert(9, 10.f);
22 22
 
23
-        SparseIterator it(v);
24
-        REQUIRE(it.getValue() == 1.f);
23
+        TemplatedSparseIterator<1> it(v);
24
+        REQUIRE(get<1>(it) == 1.f);
25 25
         it.next();
26
-        REQUIRE(it.getValue() == 5.f);
26
+        REQUIRE(get<1>(it) == 5.f);
27 27
         it.next();
28
-        REQUIRE(it.getValue() == 8.f);
28
+        REQUIRE(get<1>(it) == 8.f);
29 29
         it.next();
30
-        REQUIRE(it.getValue() == 10.f);
30
+        REQUIRE(get<1>(it) == 10.f);
31 31
         it.next();
32 32
         REQUIRE(it.atEnd());
33 33
     }
... ...
@@ -51,10 +51,10 @@ TEST_CASE("Test SparseIterator.h - One Dimensional")
51 51
         for (unsigned j = 0; j < ref.nCol(); ++j)
52 52
         {
53 53
             float colSum = 0.f;
54
-            SparseIterator it(mat.getCol(j));
54
+            TemplatedSparseIterator<1> it(mat.getCol(j));
55 55
             while (!it.atEnd())
56 56
             {
57
-                colSum += it.getValue();
57
+                colSum += get<1>(it);
58 58
                 it.next();
59 59
             }
60 60
             REQUIRE(colSum == gaps::sum(ref.getCol(j)));
... ...
@@ -79,12 +79,12 @@ TEST_CASE("Test SparseIterator.h - Two Dimensional")
79 79
         hv.add(6, 5.f);
80 80
         hv.add(7, 6.f);
81 81
         
82
-        SparseIteratorTwo it(sv, hv);
83
-        REQUIRE(it.getValue_1() == 5.f);
84
-        REQUIRE(it.getValue_2() == 3.f);
82
+        TemplatedSparseIterator<2> it(sv, hv);
83
+        REQUIRE(get<1>(it) == 5.f);
84
+        REQUIRE(get<2>(it) == 3.f);
85 85
         it.next();
86
-        REQUIRE(it.getValue_1() == 8.f);
87
-        REQUIRE(it.getValue_2() == 6.f);
86
+        REQUIRE(get<1>(it) == 8.f);
87
+        REQUIRE(get<2>(it) == 6.f);
88 88
         it.next();
89 89
         REQUIRE(it.atEnd());
90 90
     }
... ...
@@ -107,9 +107,9 @@ TEST_CASE("Test SparseIterator.h - Two Dimensional")
107 107
         hv.add(8, 9.f);
108 108
         hv.add(75, 76.f);
109 109
 
110
-        SparseIteratorTwo it(sv, hv);
111
-        REQUIRE(it.getValue_1() == 75.f);
112
-        REQUIRE(it.getValue_2() == 76.f);
110
+        TemplatedSparseIterator<2> it(sv, hv);
111
+        REQUIRE(get<1>(it) == 75.f);
112
+        REQUIRE(get<2>(it) == 76.f);
113 113
         it.next();
114 114
         REQUIRE(it.atEnd());
115 115
     }
... ...
@@ -170,7 +170,7 @@ TEST_CASE("Test SparseIterator.h - Two Dimensional")
170 170
 
171 171
         // calculate dot product
172 172
         float sdot = 0.f, ddot = 0.f;
173
-        SparseIteratorTwo it(sv, hv);
173
+        TemplatedSparseIterator<2> it(sv, hv);
174 174
         unsigned i = 0;
175 175
         while (!it.atEnd())
176 176
         {
... ...
@@ -182,12 +182,12 @@ TEST_CASE("Test SparseIterator.h - Two Dimensional")
182 182
             if (i < dv1.size())
183 183
             {
184 184
                 ddot += dv1[i] * dv2[i];
185
-                REQUIRE(dv1[i] == it.getValue_1());
186
-                REQUIRE(dv2[i] == it.getValue_2());
185
+                REQUIRE(dv1[i] == get<1>(it));
186
+                REQUIRE(dv2[i] == get<2>(it));
187 187
                 ++i;
188 188
             }
189 189
 
190
-            sdot += it.getValue_1() * it.getValue_2();
190
+            sdot += get<1>(it) * get<2>(it);
191 191
 
192 192
             it.next();
193 193
         }
... ...
@@ -218,10 +218,10 @@ TEST_CASE("Test SparseIterator.h - Two Dimensional")
218 218
             for (unsigned j2 = j1; j2 < ref.nCol(); ++j2)
219 219
             {
220 220
                 float dot = 0.f;
221
-                SparseIteratorTwo it(sMat.getCol(j1), hMat.getCol(j2));
221
+                TemplatedSparseIterator<2> it(sMat.getCol(j1), hMat.getCol(j2));
222 222
                 while (!it.atEnd())
223 223
                 {
224
-                    dot += it.getValue_1() * it.getValue_2();
224
+                    dot += get<1>(it) * get<2>(it);
225 225
                     it.next();
226 226
                 }
227 227
                 REQUIRE(dot == gaps::dot(ref.getCol(j1), ref.getCol(j2)));
... ...
@@ -265,14 +265,14 @@ TEST_CASE("Test SparseIterator.h - Three Dimensional")
265 265
         hv2.add(8, 7.f);
266 266
         hv2.add(9, 8.f);
267 267
         
268
-        SparseIteratorThree it(sv, hv1, hv2);
269
-        REQUIRE(it.getValue_1() == 5.f); // 4
270
-        REQUIRE(it.getValue_2() == 3.f);
271
-        REQUIRE(it.getValue_3() == 6.f);
268
+        TemplatedSparseIterator<3> it(sv, hv1, hv2);
269
+        REQUIRE(get<1>(it) == 5.f); // 4
270
+        REQUIRE(get<2>(it) == 3.f);
271
+        REQUIRE(get<3>(it) == 6.f);
272 272
         it.next();
273
-        REQUIRE(it.getValue_1() == 10.f); // 9
274
-        REQUIRE(it.getValue_2() == 7.f);
275
-        REQUIRE(it.getValue_3() == 8.f);
273
+        REQUIRE(get<1>(it) == 10.f); // 9
274
+        REQUIRE(get<2>(it) == 7.f);
275
+        REQUIRE(get<3>(it) == 8.f);
276 276
         it.next();
277 277
         REQUIRE(it.atEnd());
278 278
     }
... ...
@@ -301,11 +301,11 @@ TEST_CASE("Test SparseIterator.h - Three Dimensional")
301 301
                 for (unsigned j3 = j2; j3 < ref.nCol(); ++j3)
302 302
                 {
303 303
                     float prod = 0.f;
304
-                    SparseIteratorThree it(sMat.getCol(j1), hMat.getCol(j2),
304
+                    TemplatedSparseIterator<3> it(sMat.getCol(j1), hMat.getCol(j2),
305 305
                         hMat.getCol(j3));
306 306
                     while (!it.atEnd())
307 307
                     {
308
-                        prod += it.getValue_1() * it.getValue_2() * it.getValue_3();
308
+                        prod += get<1>(it) * get<2>(it) * get<3>(it);
309 309
                         it.next();
310 310
                     }
311 311
                     REQUIRE(prod == tripleProduct(ref.getCol(j1),
... ...
@@ -21,6 +21,9 @@ public:
21 21
 
22 22
     friend class SparseIteratorTwo;
23 23
     friend class SparseIteratorThree;
24
+    
25
+    template <unsigned N>
26
+    friend class TemplatedSparseIterator;
24 27
 
25 28
     explicit HybridVector(unsigned size);
26 29
     explicit HybridVector(const std::vector<float> &v);
... ...
@@ -18,6 +18,183 @@ static uint64_t clearLowerBits(uint64_t u, unsigned pos)
18 18
     return u & ~((1ull << (pos + 1ull)) - 1ull);
19 19
 }
20 20
 
21
+template <class Iter>
22
+void gotoNextCommon(Iter &it)
23
+{
24
+    // get the common indices in this chunk
25
+    it.calculateCommonFlags();
26
+
27
+    // if nothing common in this chunk, find a chunk that has common indices
28
+    while (!it.mCommonFlags)
29
+    {
30
+        // first count how many sparse indices we are skipping
31
+        it.mSparseIndex += __builtin_popcountll(it.mSparseFlags);
32
+        
33
+        // advance to next chunk
34
+        if (++it.mBigIndex == it.mTotalIndices)
35
+        {   
36
+            it.mAtEnd = true;
37
+            return;
38
+        }
39
+
40
+        // update the flags
41
+        it.getFlags();
42
+        it.calculateCommonFlags();
43
+    }
44
+
45
+    // must have at least one common value, this is our index
46
+    it.mSmallIndex = __builtin_ffsll(it.mCommonFlags) - 1;
47
+
48
+    // find the number of skipped entries in the sparse vector
49
+    it.mSparseIndex += 1 + countLowerBits(it.mSparseFlags, it.mSmallIndex);
50
+
51
+    // clear out all skipped indices and the current index from the bitflags
52
+    it.mSparseFlags = clearLowerBits(it.mSparseFlags, it.mSmallIndex);
53
+}
54
+
55
+template<>
56
+float get<1>(const TemplatedSparseIterator<1> &it)
57
+{
58
+    return it.mSparse.mData[it.mSparseIndex];
59
+}
60
+
61
+template<>
62
+float get<1>(const TemplatedSparseIterator<2> &it)
63
+{
64
+    return it.mSparse.mData[it.mSparseIndex];
65
+}
66
+
67
+template<>
68
+float get<1>(const TemplatedSparseIterator<3> &it)
69
+{
70
+    return it.mSparse.mData[it.mSparseIndex];
71
+}
72
+
73
+template<>
74
+float get<2>(const TemplatedSparseIterator<2> &it)
75
+{
76
+    return it.mHybrid_1[64 * it.mBigIndex + it.mSmallIndex];
77
+}
78
+
79
+template<>
80
+float get<2>(const TemplatedSparseIterator<3> &it)
81
+{
82
+    return it.mHybrid_1[64 * it.mBigIndex + it.mSmallIndex];
83
+}
84
+
85
+template<>
86
+float get<3>(const TemplatedSparseIterator<3> &it)
87
+{
88
+    return it.mHybrid_2[64 * it.mBigIndex + it.mSmallIndex];
89
+}
90
+
91
+
92
+TemplatedSparseIterator<1>::TemplatedSparseIterator(const SparseVector &v)
93
+:
94
+mSparse(v),
95
+mSparseIndex(0)
96
+{}
97
+
98
+bool TemplatedSparseIterator<1>::atEnd() const
99
+{
100
+    return mSparseIndex == mSparse.mData.size();
101
+}
102
+
103
+void TemplatedSparseIterator<1>::next()
104
+{
105
+    ++mSparseIndex;
106
+}
107
+
108
+TemplatedSparseIterator<2>::TemplatedSparseIterator(const SparseVector &v, const HybridVector &h)
109
+    :
110
+mSparse(v),
111
+mHybrid_1(h),
112
+mSparseFlags(v.mIndexBitFlags[0]),
113
+mHybridFlags_1(h.mIndexBitFlags[0]),
114
+mCommonFlags(v.mIndexBitFlags[0] & h.mIndexBitFlags[0]),
115
+mTotalIndices(v.mIndexBitFlags.size()),
116
+mBigIndex(0),
117
+mSmallIndex(0),
118
+mSparseIndex(0),
119
+mAtEnd(false)
120
+{
121
+    GAPS_ASSERT(v.size() == h.size());
122
+
123
+    next();
124
+    mSparseIndex -= 1; // next puts us at position 1, this resets to 0
125
+}
126
+
127
+bool TemplatedSparseIterator<2>::atEnd() const
128
+{
129
+    return mAtEnd;
130
+}
131
+
132
+void TemplatedSparseIterator<2>::next()
133
+{
134
+    gotoNextCommon(*this);
135
+}
136
+
137
+void TemplatedSparseIterator<2>::calculateCommonFlags()
138
+{
139
+    mCommonFlags = mSparseFlags & mHybridFlags_1;
140
+}
141
+
142
+void TemplatedSparseIterator<2>::getFlags()
143
+{
144
+    mSparseFlags = mSparse.mIndexBitFlags[mBigIndex];
145
+    mHybridFlags_1 = mHybrid_1.mIndexBitFlags[mBigIndex];
146
+}
147
+
148
+TemplatedSparseIterator<3>::TemplatedSparseIterator(const SparseVector &v,
149
+const HybridVector &h1, const HybridVector &h2)
150
+    :
151
+mSparse(v),
152
+mHybrid_1(h1),
153
+mHybrid_2(h2),
154
+mSparseFlags(v.mIndexBitFlags[0]),
155
+mHybridFlags_1(h1.mIndexBitFlags[0]),
156
+mHybridFlags_2(h2.mIndexBitFlags[0]),
157
+mCommonFlags(v.mIndexBitFlags[0] & h1.mIndexBitFlags[0] & h2.mIndexBitFlags[0]),
158
+mTotalIndices(v.mIndexBitFlags.size()),
159
+mBigIndex(0),
160
+mSmallIndex(0),
161
+mSparseIndex(0),
162
+mAtEnd(false)
163
+{
164
+    GAPS_ASSERT(v.size() == h1.size());
165
+    GAPS_ASSERT(h1.size() == h2.size());
166
+
167
+    next();
168
+    mSparseIndex -= 1;
169
+}
170
+
171
+bool TemplatedSparseIterator<3>::atEnd() const
172
+{
173
+    return mAtEnd;
174
+}
175
+
176
+void TemplatedSparseIterator<3>::next()
177
+{
178
+    gotoNextCommon(*this);
179
+}
180
+
181
+void TemplatedSparseIterator<3>::calculateCommonFlags()
182
+{
183
+    mCommonFlags = mSparseFlags & mHybridFlags_1 & mHybridFlags_2;
184
+}
185
+
186
+void TemplatedSparseIterator<3>::getFlags()
187
+{
188
+    mSparseFlags = mSparse.mIndexBitFlags[mBigIndex];
189
+    mHybridFlags_1 = mHybrid_1.mIndexBitFlags[mBigIndex];
190
+    mHybridFlags_2 = mHybrid_2.mIndexBitFlags[mBigIndex];
191
+}
192
+
193
+////////////////////////////////////////////////////////////////////////////////
194
+////////////////////////////////////////////////////////////////////////////////
195
+////////////////////////////////////////////////////////////////////////////////
196
+////////////////////////////////////////////////////////////////////////////////
197
+
21 198
 SparseIterator::SparseIterator(const SparseVector &v)
22 199
     :
23 200
 mSparse(v),
... ...
@@ -4,7 +4,111 @@
4 4
 #include "HybridVector.h"
5 5
 #include "SparseVector.h"
6 6
 
7
-// TODO make these nicer with templates - make sure no performance lost
7
+template <unsigned N, class Iter>
8
+float get(const Iter &it);
9
+
10
+// only allow this class to constructed with N=1,2,3
11
+template <unsigned N>
12
+class TemplatedSparseIterator
13
+{
14
+private:
15
+    TemplatedSparseIterator() {}
16
+};
17
+
18
+template<>
19
+class TemplatedSparseIterator<1>
20
+{
21
+public:
22
+
23
+    TemplatedSparseIterator(const SparseVector &v);
24
+
25
+    bool atEnd() const;
26
+    void next();
27
+
28
+private:
29
+
30
+    friend float get<1>(const TemplatedSparseIterator<1> &it);
31
+
32
+    const SparseVector &mSparse;
33
+    unsigned mSparseIndex;
34
+};
35
+
36
+template<>
37
+class TemplatedSparseIterator<2>
38
+{
39
+public:
40
+
41
+    TemplatedSparseIterator(const SparseVector &v, const HybridVector &h);
42
+
43
+    bool atEnd() const;
44
+    void next();
45
+    void calculateCommonFlags();
46
+    void getFlags();
47
+
48
+private:
49
+
50
+    template <class Iter>
51
+    friend void gotoNextCommon(Iter &it);
52
+
53
+    friend float get<1>(const TemplatedSparseIterator<2> &it);
54
+    friend float get<2>(const TemplatedSparseIterator<2> &it);
55
+
56
+    const SparseVector &mSparse;  
57
+    const HybridVector &mHybrid_1;
58
+
59
+    uint64_t mSparseFlags;
60
+    uint64_t mHybridFlags_1;
61
+    uint64_t mCommonFlags;
62
+
63
+    unsigned mTotalIndices;
64
+    unsigned mBigIndex;
65
+    unsigned mSmallIndex;
66
+    unsigned mSparseIndex;
67
+    bool mAtEnd;
68
+};
69
+
70
+template<>
71
+class TemplatedSparseIterator<3>
72
+{
73
+public:
74
+
75
+    TemplatedSparseIterator(const SparseVector &v, const HybridVector &h1,
76
+    const HybridVector &h2);
77
+
78
+    bool atEnd() const;
79
+    void next();
80
+    void calculateCommonFlags();
81
+    void getFlags();
82
+
83
+private:
84
+
85
+    template <class Iter>
86
+    friend void gotoNextCommon(Iter &it);
87
+
88
+    friend float get<1>(const TemplatedSparseIterator<3> &it);
89
+    friend float get<2>(const TemplatedSparseIterator<3> &it);
90
+    friend float get<3>(const TemplatedSparseIterator<3> &it);
91
+
92
+    const SparseVector &mSparse;  
93
+    const HybridVector &mHybrid_1;
94
+    const HybridVector &mHybrid_2;
95
+
96
+    uint64_t mSparseFlags;
97
+    uint64_t mHybridFlags_1;
98
+    uint64_t mHybridFlags_2;
99
+    uint64_t mCommonFlags;
100
+
101
+    unsigned mTotalIndices;
102
+    unsigned mBigIndex;
103
+    unsigned mSmallIndex;
104
+    unsigned mSparseIndex;
105
+    bool mAtEnd;
106
+};
107
+
108
+////////////////////////////////////////////////////////////////////////////////
109
+////////////////////////////////////////////////////////////////////////////////
110
+////////////////////////////////////////////////////////////////////////////////
111
+////////////////////////////////////////////////////////////////////////////////
8 112
 
9 113
 class SparseIterator
10 114
 {