Commit 2886ae2

Implement buffer locking mechanism
The buffer manager introduced two issues (documented in the code) that could cause certain buffer access patterns to become unsafe. This adds a coarse-grained buffer locking mechanism to work around these issues.
1 parent 097c3b6 commit 2886ae2

File tree

8 files changed: +239 additions, −13 deletions

include/buffer_manager.h

Lines changed: 60 additions & 1 deletion

@@ -16,6 +16,7 @@
 #include "ranges.h"
 #include "region_map.h"
 #include "types.h"
+ #include <unordered_set>

 namespace celerity {
 namespace detail {
@@ -62,7 +63,7 @@ namespace detail {
 * Essentially, this means that any requests made to the buffer_manager are assumed to be operations
 * that are currently allowed by the command graph.
 *
- * FIXME: There are two important caveats that we need to deal with:
+ * There are two important caveats that we need to deal with:
 *
 * - Reading from a buffer is no longer a const operation, as the buffer may need to be resized.
 * This means that two tasks that could be considered independent on a TDAG basis actually have an
@@ -75,6 +76,11 @@ namespace detail {
 * buffer first with "discard_write" and followed by a "read" should result in a combined "write" mode.
 * However the effect of the discard_write is recorded immediately, and the buffer_manager will thus
 * wrongly assume that no coherence update for the "read" is required.
+ *
+ * Currently, these issues are handled through the buffer locking mechanism.
+ * See buffer_manager::try_lock, buffer_manager::unlock and buffer_manager::is_locked.
+ *
+ * FIXME: The current buffer locking mechanism limits task parallelism. Come up with a better solution.
 */
 class buffer_manager {
 public:
@@ -104,6 +110,8 @@ namespace detail {
 cl::sycl::id<Dims> offset;
 };

+ using buffer_lock_id = size_t;
+
 public:
 buffer_manager(device_queue& queue, buffer_lifecycle_callback lifecycle_cb);

@@ -212,6 +220,8 @@ namespace detail {
 }
 }

+ audit_buffer_access(bid, new_buffer.is_allocated(), mode);
+
 backing_buffer& target_buffer = new_buffer.is_allocated() ? new_buffer : old_buffer;
 const backing_buffer empty{};
 const backing_buffer& previous_buffer = new_buffer.is_allocated() ? old_buffer : empty;
@@ -242,6 +252,8 @@ namespace detail {
 }
 }

+ audit_buffer_access(bid, new_buffer.is_allocated(), mode);
+
 backing_buffer& target_buffer = new_buffer.is_allocated() ? new_buffer : old_buffer;
 const backing_buffer empty{};
 const backing_buffer& previous_buffer = new_buffer.is_allocated() ? old_buffer : empty;
@@ -253,6 +265,32 @@ namespace detail {
 id_cast<Dims>(buffers[bid].host_buf.offset)};
 }

+ /**
+  * @brief Tries to lock the given list of @p buffers using the given lock @p id.
+  *
+  * If any of the buffers is currently locked, the locking attempt fails.
+  *
+  * Locking is currently an optional (opt-in) mechanism, i.e., buffers can also be
+  * accessed without being locked. This is because locking is a bit of a band-aid fix
+  * that doesn't properly cover all use-cases (for example, host-pointer initialized buffers).
+  *
+  * However, when accessing a locked buffer, the buffer_manager enforces additional
+  * rules to ensure they are used in a safe manner for the duration of the lock:
+  * - A locked buffer may only be resized at most once, and only for the first access.
+  * - A locked buffer may not be accessed using consumer access modes, if it was previously
+  *   accessed using a pure producer mode.
+  *
+  * @returns Returns true if the list of buffers was successfully locked.
+  */
+ bool try_lock(buffer_lock_id, const std::unordered_set<buffer_id>& buffers);
+
+ /**
+  * Unlocks all buffers that were previously locked with a call to try_lock with the given @p id.
+  */
+ void unlock(buffer_lock_id id);
+
+ bool is_locked(buffer_id bid) const;
+
 private:
 struct backing_buffer {
 std::unique_ptr<buffer_storage> storage = nullptr;
@@ -302,6 +340,15 @@ namespace detail {
 struct buffer_type_guard : buffer_type_guard_base {};
 #endif

+ struct buffer_lock_info {
+ bool is_locked = false;
+
+ // For lack of a better name, this stores *an* access mode that has already been used during this lock.
+ // While it initially stores whatever is first used to access the buffer, it will always be overwritten
+ // by subsequent pure producer accesses, as those are the only ones we really care about.
+ std::optional<cl::sycl::access::mode> earlier_access_mode = std::nullopt;
+ };
+
 private:
 device_queue& queue;
 buffer_lifecycle_callback lifecycle_cb;
@@ -312,6 +359,9 @@ namespace detail {
 std::unordered_map<buffer_id, std::vector<transfer>> scheduled_transfers;
 std::unordered_map<buffer_id, region_map<data_location>> newest_data_location;

+ std::unordered_map<buffer_id, buffer_lock_info> buffer_lock_infos;
+ std::unordered_map<buffer_lock_id, std::vector<buffer_id>> buffer_locks_by_id;
+
 #if !defined(NDEBUG)
 // Since we store buffers without type information (i.e., its data type and dimensionality),
 // it is the user's responsibility to only request access to a buffer using the correct type.
@@ -356,6 +406,15 @@ namespace detail {
 */
 void make_buffer_subrange_coherent(buffer_id bid, cl::sycl::access::mode mode, backing_buffer& target_buffer, const subrange<3>& coherent_sr,
 const backing_buffer& previous_buffer = backing_buffer{});
+
+ /**
+  * Checks whether access to a currently locked buffer is safe.
+  *
+  * There's two distinct issues that can cause an access to be unsafe:
+  * - If a buffer that has been accessed earlier needs to be resized (reallocated) now
+  * - If a buffer was previously accessed using a discard_* mode and is now accessed using a consumer mode
+  */
+ void audit_buffer_access(buffer_id bid, bool requires_allocation, cl::sycl::access::mode mode);
 };

 } // namespace detail
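
The semantics documented in the header above can be summarized in a short, test-style sketch against this interface. This is a hypothetical illustration, not part of the commit: the lock ids and buffer ids are made up, and a buffer_manager instance is assumed to exist already.

    #include <cassert>
    #include <unordered_set>
    #include "buffer_manager.h" // the header shown above

    void lock_semantics_sketch(celerity::detail::buffer_manager& bm) {
        using namespace celerity::detail;
        const buffer_manager::buffer_lock_id lock_a = 1, lock_b = 2;

        bool ok = bm.try_lock(lock_a, {0, 1}); // locks buffers 0 and 1
        assert(ok && bm.is_locked(0) && bm.is_locked(1));

        ok = bm.try_lock(lock_b, {1});         // fails: buffer 1 is already locked
        assert(!ok);

        bm.unlock(lock_a);                     // releases all buffers locked under lock_a
        assert(!bm.is_locked(0) && !bm.is_locked(1));

        ok = bm.try_lock(lock_b, {1});         // now succeeds
        assert(ok);
        bm.unlock(lock_b);
    }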

include/executor.h

Lines changed: 4 additions & 1 deletion

@@ -3,6 +3,7 @@
 #include <chrono>
 #include <thread>

+ #include "buffer_manager.h"
 #include "buffer_transfer_manager.h"
 #include "logger.h"
 #include "worker_job.h"
@@ -41,7 +42,7 @@ namespace detail {
 class executor {
 public:
 // TODO: Try to decouple this more.
- executor(host_queue& h_queue, device_queue& d_queue, task_manager& tm, std::shared_ptr<logger> execution_logger);
+ executor(host_queue& h_queue, device_queue& d_queue, task_manager& tm, buffer_manager& buffer_mngr, std::shared_ptr<logger> execution_logger);

 void startup();

@@ -59,6 +60,8 @@ namespace detail {
 host_queue& h_queue;
 device_queue& d_queue;
 task_manager& task_mngr;
+ // FIXME: We currently need this for buffer locking in some jobs, which is a bit of a band-aid fix. Get rid of this at some point.
+ buffer_manager& buffer_mngr;
 std::unique_ptr<buffer_transfer_manager> btm;
 std::shared_ptr<logger> execution_logger;
 std::thread exec_thrd;

include/worker_job.h

Lines changed: 10 additions & 5 deletions

@@ -6,6 +6,7 @@
 #include <limits>
 #include <utility>

+ #include "buffer_manager.h"
 #include "buffer_transfer_manager.h"
 #include "command.h"
 #include "host_queue.h"
@@ -84,12 +85,14 @@ namespace detail {

 class push_job : public worker_job {
 public:
- push_job(command_pkg pkg, std::shared_ptr<logger> job_logger, buffer_transfer_manager& btm) : worker_job(pkg, job_logger), btm(btm) {
+ push_job(command_pkg pkg, std::shared_ptr<logger> job_logger, buffer_transfer_manager& btm, buffer_manager& bm)
+     : worker_job(pkg, job_logger), btm(btm), buffer_mngr(bm) {
 assert(pkg.cmd == command_type::PUSH);
 }

 private:
 buffer_transfer_manager& btm;
+ buffer_manager& buffer_mngr;
 std::shared_ptr<const buffer_transfer_manager::transfer_handle> data_handle = nullptr;

 bool execute(const command_pkg& pkg, std::shared_ptr<logger> logger) override;
@@ -99,14 +102,15 @@ namespace detail {
 // host-compute jobs, master-node tasks and collective host tasks
 class host_execute_job : public worker_job {
 public:
- host_execute_job(command_pkg pkg, std::shared_ptr<logger> job_logger, detail::host_queue& queue, detail::task_manager& tm)
-     : worker_job(pkg, job_logger), queue(queue), task_mngr(tm) {
+ host_execute_job(command_pkg pkg, std::shared_ptr<logger> job_logger, detail::host_queue& queue, detail::task_manager& tm, buffer_manager& bm)
+     : worker_job(pkg, job_logger), queue(queue), task_mngr(tm), buffer_mngr(bm) {
 assert(pkg.cmd == command_type::TASK);
 }

 private:
 detail::host_queue& queue;
 detail::task_manager& task_mngr;
+ detail::buffer_manager& buffer_mngr;
 std::future<detail::host_queue::execution_info> future;
 bool submitted = false;

@@ -120,14 +124,15 @@ namespace detail {
 */
 class device_execute_job : public worker_job {
 public:
- device_execute_job(command_pkg pkg, std::shared_ptr<logger> job_logger, detail::device_queue& queue, detail::task_manager& tm)
-     : worker_job(pkg, job_logger), queue(queue), task_mngr(tm) {
+ device_execute_job(command_pkg pkg, std::shared_ptr<logger> job_logger, detail::device_queue& queue, detail::task_manager& tm, buffer_manager& bm)
+     : worker_job(pkg, job_logger), queue(queue), task_mngr(tm), buffer_mngr(bm) {
 assert(pkg.cmd == command_type::TASK);
 }

 private:
 detail::device_queue& queue;
 detail::task_manager& task_mngr;
+ detail::buffer_manager& buffer_mngr;
 cl::sycl::event event;
 bool submitted = false;

src/buffer_manager.cc

Lines changed: 55 additions & 0 deletions

@@ -73,6 +73,32 @@ namespace detail {
 scheduled_transfers[bid].push_back({std::move(data), offset});
 }

+ bool buffer_manager::try_lock(buffer_lock_id id, const std::unordered_set<buffer_id>& buffers) {
+ assert(buffer_locks_by_id.count(id) == 0);
+ for(auto bid : buffers) {
+ if(buffer_lock_infos[bid].is_locked) return false;
+ }
+ buffer_locks_by_id[id].reserve(buffers.size());
+ for(auto bid : buffers) {
+ buffer_lock_infos[bid] = {true, std::nullopt};
+ buffer_locks_by_id[id].push_back(bid);
+ }
+ return true;
+ }
+
+ void buffer_manager::unlock(buffer_lock_id id) {
+ assert(buffer_locks_by_id.count(id) != 0);
+ for(auto bid : buffer_locks_by_id[id]) {
+ buffer_lock_infos[bid] = {};
+ }
+ buffer_locks_by_id.erase(id);
+ }
+
+ bool buffer_manager::is_locked(buffer_id bid) const {
+ if(buffer_lock_infos.count(bid) == 0) return false;
+ return buffer_lock_infos.at(bid).is_locked;
+ }
+
 // TODO: Something we could look into is to dispatch all memory copies concurrently and wait for them in the end.
 void buffer_manager::make_buffer_subrange_coherent(
 buffer_id bid, cl::sycl::access::mode mode, backing_buffer& target_buffer, const subrange<3>& coherent_sr, const backing_buffer& previous_buffer) {
@@ -218,5 +244,34 @@ namespace detail {
 if(detail::access::mode_traits::is_producer(mode)) { newest_data_location.at(bid).update_region(coherent_box, target_buffer_location); }
 }

+ void buffer_manager::audit_buffer_access(buffer_id bid, bool requires_allocation, cl::sycl::access::mode mode) {
+ auto& lock_info = buffer_lock_infos[bid];
+
+ // Buffer locking is currently opt-in, so if this buffer isn't locked, we won't check anything else.
+ if(!lock_info.is_locked) return;
+
+ if(lock_info.earlier_access_mode == std::nullopt) {
+ // First access, all good.
+ lock_info.earlier_access_mode = mode;
+ return;
+ }
+
+ if(requires_allocation) {
+ // Re-allocation of a buffer that is currently being accessed never works.
+ throw std::runtime_error("You are requesting multiple accessors for the same buffer, with later ones requiring a larger part of the buffer, "
+     "causing a backing buffer reallocation. "
+     "This is currently unsupported. Try changing the order of your calls to buffer::get_access.");
+ }
+
+ if(!access::mode_traits::is_consumer(*lock_info.earlier_access_mode) && access::mode_traits::is_consumer(mode)) {
+ // Accessing a buffer using a pure producer mode followed by a consumer mode breaks our coherence bookkeeping.
+ throw std::runtime_error("You are requesting multiple accessors for the same buffer, using a discarding access mode first, followed by a "
+     "non-discarding mode. This is currently unsupported. Try changing the order of your calls to buffer::get_access.");
+ }
+
+ // We only need to remember pure producer accesses.
+ if(!access::mode_traits::is_consumer(mode)) { lock_info.earlier_access_mode = mode; }
+ }
+
 } // namespace detail
 } // namespace celerity
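
For context, the pattern that audit_buffer_access rejects corresponds roughly to user code like the sketch below: a pure producer accessor followed by a consumer accessor for the same buffer within one task. This is illustrative only; the range-mapper and kernel syntax follow the Celerity API of roughly this period and may differ in other versions, and reading otherwise uninitialized data is deliberate to keep the example minimal.

    #include <celerity.h>

    int main() {
        celerity::distr_queue queue;
        celerity::buffer<float, 1> buf(cl::sycl::range<1>(1024));

        queue.submit([=](celerity::handler& cgh) {
            // Pure producer access first ...
            auto w = buf.get_access<cl::sycl::access::mode::discard_write>(cgh, celerity::access::one_to_one<1>());
            // ... then a consumer access to the same buffer within the same task.
            // When the live pass requests the backing buffers for this (locked) task,
            // audit_buffer_access throws the "discarding access mode first" error above
            // and suggests swapping the order of the two get_access calls.
            auto r = buf.get_access<cl::sycl::access::mode::read>(cgh, celerity::access::one_to_one<1>());
            cgh.parallel_for<class rejected_pattern>(cl::sycl::range<1>(1024), [=](cl::sycl::item<1> itm) { w[itm] = r[itm] + 1.f; });
        });
        return 0;
    }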

src/executor.cc

Lines changed: 5 additions & 5 deletions

@@ -22,8 +22,8 @@ namespace detail {
 running = false;
 }

- executor::executor(host_queue& h_queue, device_queue& d_queue, task_manager& tm, std::shared_ptr<logger> execution_logger)
-     : h_queue(h_queue), d_queue(d_queue), task_mngr(tm), execution_logger(execution_logger) {
+ executor::executor(host_queue& h_queue, device_queue& d_queue, task_manager& tm, buffer_manager& buffer_mngr, std::shared_ptr<logger> execution_logger)
+     : h_queue(h_queue), d_queue(d_queue), task_mngr(tm), buffer_mngr(buffer_mngr), execution_logger(execution_logger) {
 btm = std::make_unique<buffer_transfer_manager>(execution_logger);
 metrics.initial_idle.resume();
 }
@@ -155,7 +155,7 @@ namespace detail {
 bool executor::handle_command(const command_pkg& pkg, const std::vector<command_id>& dependencies) {
 switch(pkg.cmd) {
 case command_type::HORIZON: create_job<horizon_job>(pkg, dependencies); break;
- case command_type::PUSH: create_job<push_job>(pkg, dependencies, *btm); break;
+ case command_type::PUSH: create_job<push_job>(pkg, dependencies, *btm, buffer_mngr); break;
 case command_type::AWAIT_PUSH: create_job<await_push_job>(pkg, dependencies, *btm); break;
 case command_type::TASK: {
 const auto& data = std::get<task_data>(pkg.data);
@@ -165,9 +165,9 @@ namespace detail {

 auto tsk = task_mngr.get_task(data.tid);
 if(tsk->get_execution_target() == execution_target::HOST) {
- create_job<host_execute_job>(pkg, dependencies, h_queue, task_mngr);
+ create_job<host_execute_job>(pkg, dependencies, h_queue, task_mngr, buffer_mngr);
 } else {
- create_job<device_execute_job>(pkg, dependencies, d_queue, task_mngr);
+ create_job<device_execute_job>(pkg, dependencies, d_queue, task_mngr, buffer_mngr);
 }
 break;
 }

src/runtime.cc

Lines changed: 1 addition & 1 deletion

@@ -109,7 +109,7 @@ namespace detail {
 }
 });
 task_mngr = std::make_unique<task_manager>(num_nodes, h_queue.get(), is_master);
- exec = std::make_unique<executor>(*h_queue, *d_queue, *task_mngr, default_logger);
+ exec = std::make_unique<executor>(*h_queue, *d_queue, *task_mngr, *buffer_mngr, default_logger);
 if(is_master) {
 cdag = std::make_unique<command_graph>();
 ggen = std::make_shared<graph_generator>(num_nodes, *task_mngr, *cdag);

src/worker_job.cc

Lines changed: 18 additions & 0 deletions

@@ -81,10 +81,18 @@ namespace detail {

 bool push_job::execute(const command_pkg& pkg, std::shared_ptr<logger> logger) {
 if(data_handle == nullptr) {
+ const auto data = std::get<push_data>(pkg.data);
+ // Getting buffer data from the buffer manager may incur a host-side buffer reallocation.
+ // If any other tasks are currently using this buffer for reading, we run into problems.
+ // To avoid this, we use a very crude buffer locking mechanism for now.
+ // FIXME: Get rid of this, replace with finer grained approach.
+ if(buffer_mngr.is_locked(data.bid)) { return false; }
+
 logger->trace(logger_map({{"event", "Submit buffer to BTM"}}));
 data_handle = btm.push(pkg);
 logger->trace(logger_map({{"event", "Buffer submitted to BTM"}}));
 }
+
 return data_handle->complete;
 }

@@ -102,6 +110,9 @@ namespace detail {
 if(!submitted) {
 auto tsk = task_mngr.get_task(data.tid);
 assert(tsk->get_execution_target() == execution_target::HOST);
+
+ if(!buffer_mngr.try_lock(pkg.cid, tsk->get_buffer_access_map().get_accessed_buffers())) { return false; }
+
 logger->trace(logger_map({{"event", "Execute live-pass, scheduling host task in thread pool"}}));

 // Note that for host tasks, there is no indirection through a queue->submit step like there is for SYCL tasks. The CGF is executed directly,
@@ -118,6 +129,8 @@ namespace detail {

 assert(future.valid());
 if(future.wait_for(std::chrono::seconds(0)) == std::future_status::ready) {
+ buffer_mngr.unlock(pkg.cid);
+
 auto info = future.get();
 logger->trace(logger_map({{"event", fmt::format("Delta time submit -> start: {}us",
     std::chrono::duration_cast<std::chrono::microseconds>(info.start_time - info.submit_time).count())}}));
@@ -153,6 +166,9 @@ namespace detail {
 if(!submitted) {
 auto tsk = task_mngr.get_task(data.tid);
 assert(tsk->get_execution_target() == execution_target::DEVICE);
+
+ if(!buffer_mngr.try_lock(pkg.cid, tsk->get_buffer_access_map().get_accessed_buffers())) { return false; }
+
 logger->trace(logger_map({{"event", "Execute live-pass, submit kernel to SYCL"}}));

 event = queue.submit([tsk, sr = data.sr](cl::sycl::handler& handler, size_t forced_work_group_size) {
@@ -167,6 +183,8 @@ namespace detail {

 const auto status = event.get_info<cl::sycl::info::event::command_execution_status>();
 if(status == cl::sycl::info::event_command_status::complete) {
+ buffer_mngr.unlock(pkg.cid);
+
 #if !WORKAROUND(HIPSYCL, 0)
 if(queue.is_profiling_enabled()) {
 const auto queued = get_profiling_info(event.get(), CL_PROFILING_COMMAND_QUEUED);
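
Taken together, the job-side changes above form a simple acquire/defer/release protocol: a task job only starts once it can lock all buffers it accesses (using the command id pkg.cid as the lock id), and a push of a locked buffer is deferred until the lock is released. Below is a condensed sketch of that protocol, not the actual worker_job code; the executor's polling loop, logging, and command handling are omitted.

    #include <unordered_set>
    #include "buffer_manager.h"

    // Simplified sketch of the protocol followed by host_execute_job / device_execute_job.
    // execute() is polled repeatedly by the executor; returning false means "not done yet, poll again".
    bool execute_step(celerity::detail::buffer_manager& bm, size_t lock_id,
        const std::unordered_set<celerity::detail::buffer_id>& accessed_buffers, bool& submitted, bool finished) {
        if(!submitted) {
            // Defer the whole task while any of its buffers is still locked by another job.
            if(!bm.try_lock(lock_id, accessed_buffers)) return false;
            // ... submit the kernel / host task here ...
            submitted = true;
        }
        if(!finished) return false; // still running, keep polling
        bm.unlock(lock_id);         // done: release the buffers so deferred jobs (e.g. pushes) can proceed
        return true;
    }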
