forked from microsoft/cppwinrt
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmulti_threaded_map.cpp
More file actions
337 lines (301 loc) · 11.2 KB
/
multi_threaded_map.cpp
File metadata and controls
337 lines (301 loc) · 11.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
#include "pch.h"
#include <numeric>
#include <shared_mutex>
#include "multi_threaded_common.h"
using namespace winrt;
using namespace Windows::Foundation;
using namespace Windows::Foundation::Collections;
using namespace concurrent_collections;
// Map correctness tests exist elsewhere. These tests are strictly geared toward testing multi-threaded functionality.
namespace
{
// We use a customized container that mimics std::map and which
// validates that C++ concurrency rules are observed.
// C++ rules for library types are that concurrent use of const methods is allowed,
// but no method call may be concurrent with a non-const method. (Const methods may
// be "shared", but non-const methods are "exclusive".)
//
// NOTE! As the C++/WinRT implementation changes, you may need to add additional members
// to our mock.
//
// The regular single_threaded_map and multi_threaded_map functions require std::map
// or std::unordered_map, so we bypass them and go directly to the underlying classes,
// which take any container that acts map-like.
// Selects which Windows Runtime interface make_threaded_map() hands back:
// IMap<K, V> or IObservableMap<K, V>.
enum class MapKind { IMap, IObservableMap };
// Change the next line to "#if 0" to use a single-threaded map and confirm that every test fails.
// The scenarios use "CHECK" instead of "REQUIRE" so that they continue running even on failure.
// That way, you can just step through the entire test in single-threaded mode and confirm that
// something bad happens at each scenario.
//
// custom_threaded_map / custom_observable_map name the C++/WinRT implementation classes that
// make_threaded_map() instantiates for MapKind::IMap and MapKind::IObservableMap respectively.
// The "#if 1" branch selects the multi-threaded implementations under test; the "#else" branch
// swaps in input_map / observable_map as a negative control (see the note above).
#if 1
template<typename K, typename V, typename Container>
using custom_threaded_map = winrt::impl::multi_threaded_map<K, V, Container>;
template<typename K, typename V, typename Container>
using custom_observable_map = winrt::impl::multi_threaded_observable_map<K, V, Container>;
#else
template<typename K, typename V, typename Container>
using custom_threaded_map = winrt::impl::input_map<K, V, Container>;
template<typename K, typename V, typename Container>
using custom_observable_map = winrt::impl::observable_map<K, V, Container>;
#endif
// Builds the C++/WinRT map implementation chosen by `kind` around `values` and returns it
// through the matching Windows Runtime interface:
//   MapKind::IMap           -> IMap<K, V>            (custom_threaded_map)
//   anything else           -> IObservableMap<K, V>  (custom_observable_map)
// K and V are taken from the container's key_type / mapped_type.
// `values` is moved into the new object; callers in this file pass temporaries or
// std::move'd locals.
template<MapKind kind, typename Container>
auto make_threaded_map(Container&& values)
{
    using Key = typename Container::key_type;
    using Value = typename Container::mapped_type;

    if constexpr (kind != MapKind::IMap)
    {
        return static_cast<IObservableMap<Key, Value>>(
            winrt::make<custom_observable_map<Key, Value, Container>>(std::move(values)));
    }
    else
    {
        return static_cast<IMap<Key, Value>>(
            winrt::make<custom_threaded_map<Key, Value, Container>>(std::move(values)));
    }
}
#pragma region map wrapper
// concurrency_checked_map: a std::map look-alike used as the backing container of the
// maps under test. It privately inherits std::map and re-exposes only the members the
// C++/WinRT map implementation needs. Every forwarded member first takes a guard from
// concurrency_guard:
//   - lock_const()    for const members (these may legally overlap each other);
//   - lock_nonconst() for non-const members (these must not overlap anything).
// Members that correspond to a collection_action also call call_hook(action). Presumably
// this is where the test's race hook injects the competing operation (see hook->race in
// test_map_concurrency); confirm against multi_threaded_common.h.
//
// Add more wrapper methods as necessary.
// (Turns out we don't use many features of std::map and std::unordered_map.)
template<typename K, typename V, typename Compare = std::less<K>, typename Allocator = std::allocator<std::pair<const K, V>>>
struct concurrency_checked_map : private std::map<K, V, Compare, Allocator>, concurrency_guard
{
    // `map` is the injected-class-name of the std::map base.
    using inner = typename concurrency_checked_map::map;
    using key_type = typename inner::key_type;
    using mapped_type = typename inner::mapped_type;
    using value_type = typename inner::value_type;
    using size_type = typename inner::size_type;
    using difference_type = typename inner::difference_type;
    using allocator_type = typename inner::allocator_type;
    using reference = typename inner::reference;
    using const_reference = typename inner::const_reference;
    using pointer = typename inner::pointer;
    using const_pointer = typename inner::const_pointer;
    using iterator = concurrency_checked_random_access_iterator<concurrency_checked_map, typename inner::iterator>;
    using const_iterator = concurrency_checked_random_access_iterator<concurrency_checked_map, typename inner::const_iterator, typename inner::iterator>;
    using reverse_iterator = std::reverse_iterator<iterator>;
    using const_reverse_iterator = std::reverse_iterator<const_iterator>;
    using node_type = typename inner::node_type;

    // Same contract as std::map::operator[]: returns a reference to the value stored
    // for `key`, default-inserting one if `key` is absent. Because it may insert, it
    // takes the exclusive (non-const) guard.
    mapped_type& operator[](const key_type& key)
    {
        auto guard = concurrency_guard::lock_nonconst();
        concurrency_guard::call_hook(collection_action::at);
        // Must forward to the real lookup. An iterator-shaped `{ this, ... }` initializer
        // cannot bind to mapped_type& and would ignore `key` entirely.
        return inner::operator[](key);
    }

    iterator begin()
    {
        auto guard = concurrency_guard::lock_nonconst();
        return { this, inner::begin() };
    }

    const_iterator begin() const
    {
        auto guard = concurrency_guard::lock_const();
        return { this, inner::begin() };
    }

    iterator end()
    {
        auto guard = concurrency_guard::lock_nonconst();
        return { this, inner::end() };
    }

    const_iterator end() const
    {
        auto guard = concurrency_guard::lock_const();
        return { this, inner::end() };
    }

    size_type size() const
    {
        auto guard = concurrency_guard::lock_const();
        return inner::size();
    }

    void clear()
    {
        auto guard = concurrency_guard::lock_nonconst();
        inner::clear();
    }

    // NOTE(review): only this object's guard is taken; `other` is mutated without
    // its own guard. Sufficient for the current tests, which never swap two
    // checked maps concurrently.
    void swap(concurrency_checked_map& other)
    {
        auto guard = concurrency_guard::lock_nonconst();
        inner::swap(other);
    }

    // Forwards to std::map::emplace and wraps the resulting iterator.
    template<typename... Args>
    std::pair<iterator, bool> emplace(Args&&... args)
    {
        auto guard = concurrency_guard::lock_nonconst();
        concurrency_guard::call_hook(collection_action::insert);
        auto [it, inserted] = inner::emplace(std::forward<Args>(args)...);
        return { { this, it }, inserted };
    }

    // Removes the element at `pos` and returns it as a node handle.
    node_type extract(const_iterator pos)
    {
        auto guard = concurrency_guard::lock_nonconst();
        concurrency_guard::call_hook(collection_action::erase);
        return inner::extract(pos);
    }

    const_iterator find(const K& key) const
    {
        auto guard = concurrency_guard::lock_const();
        concurrency_guard::call_hook(collection_action::lookup);
        return { this, inner::find(key) };
    }
};
#pragma endregion
template<typename T, MapKind kind>
void test_map_concurrency()
{
auto raw = concurrency_checked_map<int, T>();
auto hook = raw.hook;
// Convert the raw_map into the desired Windows Runtime map interface.
auto m = make_threaded_map<kind>(std::move(raw));
auto race = [&](collection_action action, auto&& background, auto&& foreground)
{
// Map initial contents are [1] = 1, [2] = 2.
m.Clear();
m.Insert(1, conditional_box<T>(1));
m.Insert(2, conditional_box<T>(2));
hook->race(action, background, foreground);
};
// Verify that Insert does not run concurrently with HasKey().
race(collection_action::insert, [&]
{
m.Insert(42, conditional_box<T>(42));
}, [&]
{
CHECK(m.HasKey(2));
});
// Verify that Insert does not run concurrently with Lookup().
race(collection_action::insert, [&]
{
m.Insert(42, conditional_box<T>(42));
}, [&]
{
CHECK(conditional_unbox<T>(m.Lookup(2)));
});
// Verify that Insert does not run concurrently with another Insert().
race(collection_action::insert, [&]
{
m.Insert(43, conditional_box<T>(43));
}, [&]
{
m.Insert(43, conditional_box<T>(43));
});
// Iterator invalidation tests are a little different because we perform
// the mutation from the foreground thread after the read operation
// has begun on the background thread.
//
// Verify that iterator invalidation doesn't race against
// iterator use.
{
// Current vs Remove
IKeyValuePair<int, T> kvp;
race(collection_action::at, [&]
{
try
{
kvp = m.First().Current();
}
catch (hresult_error const&)
{
}
}, [&]
{
m.Remove(1);
});
CHECK((kvp && conditional_unbox<T>(kvp.Value()) == 1));
}
{
// MoveNext vs Remove
bool moved = false;
race(collection_action::at, [&]
{
try
{
moved = m.First().MoveNext();
}
catch (hresult_error const&)
{
}
}, [&]
{
m.Remove(1);
});
CHECK(moved);
}
{
// Current vs Insert
IKeyValuePair<int, T> kvp;
race(collection_action::at, [&]
{
try
{
kvp = m.First().Current();
}
catch (hresult_error const&)
{
}
}, [&]
{
m.Insert(42, conditional_box<T>(42));
});
CHECK((kvp && conditional_unbox<T>(kvp.Value()) == 1));
}
{
// MoveNext vs Insert
bool moved = false;
race(collection_action::at, [&]
{
try
{
moved = m.First().MoveNext();
}
catch (hresult_error const&)
{
}
}, [&]
{
m.Insert(42, conditional_box<T>(42));
});
CHECK(moved);
}
{
// Verify that concurrent iteration works via GetMany(), which is atomic.
// (Current + MoveNext is non-atomic and can result in two threads
// both grabbing the same Current and then moving two steps forward.)
decltype(m.First()) it;
IKeyValuePair<int, T> kvp1[1];
IKeyValuePair<int, T> kvp2[1];
race(collection_action::at, [&]
{
it = m.First();
CHECK(it.GetMany(kvp1) == 1);
}, [&]
{
CHECK(it.GetMany(kvp2) == 1);
});
CHECK(kvp1[0].Key() != kvp2[0].Key());
}
}
// Stores a deadlock_object at key 0, removes it from a background thread, and requires
// the removal to complete within DEADLOCK_TIMEOUT milliseconds.
// NOTE(review): deadlock_object is defined elsewhere (presumably multi_threaded_common.h).
// It is handed the map itself, so presumably it calls back into the map when released.
// A Remove that destroys the old value while still holding the map's lock would then
// hang. Confirm against its definition.
void deadlock_test()
{
    auto target = make_threaded_map<MapKind::IMap>(concurrency_checked_map<int, IInspectable>());
    target.Insert(0, make<deadlock_object<IMap<int, IInspectable>>>(target));

    // Captureless coroutine lambda. The map is taken by value, so the coroutine frame owns
    // its own reference and nothing dangles after this function's locals go away.
    auto remove_on_background = [](auto map_ref) -> IAsyncAction
    {
        co_await resume_background();
        map_ref.Remove(0);
    };
    auto operation = remove_on_background(target);

    auto outcome = operation.wait_for(std::chrono::milliseconds(DEADLOCK_TIMEOUT));
    REQUIRE(outcome == AsyncStatus::Completed);
}
}
// IMap flavor: run the concurrency scenarios with an int payload and with an
// IInspectable payload, then check that Remove() does not deadlock.
TEST_CASE("multi_threaded_map")
{
    test_map_concurrency<int, MapKind::IMap>();
    test_map_concurrency<IInspectable, MapKind::IMap>();
    deadlock_test();
}
// IObservableMap flavor: the same concurrency scenarios with int and IInspectable
// payloads. (deadlock_test() is only exercised through IMap above.)
TEST_CASE("multi_threaded_observable_map")
{
    test_map_concurrency<int, MapKind::IObservableMap>();
    test_map_concurrency<IInspectable, MapKind::IObservableMap>();
}