Skip to content

Commit 097c3b6

Browse files
committed
Move task wait mechanism from worker jobs to executor
Since we now need to access the task object from within the exector to decide what job type to spawn, we need to be sure the task exists at that point already.
1 parent 436f4c9 commit 097c3b6

File tree

4 files changed

+13
-25
lines changed

4 files changed

+13
-25
lines changed

include/executor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ namespace detail {
9797
}
9898

9999
void run();
100-
void handle_command(const command_pkg& pkg, const std::vector<command_id>& dependencies);
100+
bool handle_command(const command_pkg& pkg, const std::vector<command_id>& dependencies);
101101

102102
void update_metrics();
103103
};

include/worker_job.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@ namespace detail {
108108
detail::host_queue& queue;
109109
detail::task_manager& task_mngr;
110110
std::future<detail::host_queue::execution_info> future;
111-
bool did_log_task_wait = false;
112111
bool submitted = false;
113112

114113
bool execute(const command_pkg& pkg, std::shared_ptr<logger> logger) override;
@@ -130,7 +129,6 @@ namespace detail {
130129
detail::device_queue& queue;
131130
detail::task_manager& task_mngr;
132131
cl::sycl::event event;
133-
bool did_log_task_wait = false;
134132
bool submitted = false;
135133

136134
std::future<void> computecpp_workaround_future;

src/executor.cc

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,15 +132,18 @@ namespace detail {
132132

133133
if(syncing_on_id == NOT_SYNCING && jobs.size() < MAX_CONCURRENT_JOBS && !command_queue.empty()) {
134134
const auto info = command_queue.front();
135-
command_queue.pop();
136135
if(info.pkg.cmd == command_type::SHUTDOWN) {
137-
assert(command_queue.empty());
136+
assert(command_queue.size() == 1);
138137
done = true;
139138
} else if(info.pkg.cmd == command_type::SYNC) {
140139
syncing_on_id = std::get<sync_data>(info.pkg.data).sync_id;
141140
} else {
142-
handle_command(info.pkg, info.dependencies);
141+
if(!handle_command(info.pkg, info.dependencies)) {
142+
// In case the command couldn't be handled, don't pop it from the queue.
143+
continue;
144+
}
143145
}
146+
command_queue.pop();
144147
}
145148

146149
if(first_command_received) { update_metrics(); }
@@ -149,25 +152,28 @@ namespace detail {
149152
assert(running_device_compute_jobs == 0);
150153
}
151154

152-
void executor::handle_command(const command_pkg& pkg, const std::vector<command_id>& dependencies) {
155+
bool executor::handle_command(const command_pkg& pkg, const std::vector<command_id>& dependencies) {
153156
switch(pkg.cmd) {
154157
case command_type::HORIZON: create_job<horizon_job>(pkg, dependencies); break;
155158
case command_type::PUSH: create_job<push_job>(pkg, dependencies, *btm); break;
156159
case command_type::AWAIT_PUSH: create_job<await_push_job>(pkg, dependencies, *btm); break;
157160
case command_type::TASK: {
158161
const auto& data = std::get<task_data>(pkg.data);
162+
163+
// A bit of a hack: We cannot be sure the main thread has reached the task definition yet, so we have to check it here
164+
if(!task_mngr.has_task(data.tid)) { return false; }
165+
159166
auto tsk = task_mngr.get_task(data.tid);
160167
if(tsk->get_execution_target() == execution_target::HOST) {
161168
create_job<host_execute_job>(pkg, dependencies, h_queue, task_mngr);
162-
break;
163169
} else {
164170
create_job<device_execute_job>(pkg, dependencies, d_queue, task_mngr);
165-
break;
166171
}
167172
break;
168173
}
169174
default: assert(!"Unexpected command");
170175
}
176+
return true;
171177
}
172178

173179
void executor::update_metrics() {

src/worker_job.cc

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,6 @@ namespace detail {
9898

9999
bool host_execute_job::execute(const command_pkg& pkg, std::shared_ptr<logger> logger) {
100100
const auto data = std::get<task_data>(pkg.data);
101-
// A bit of a hack: We cannot be sure the main thread has reached the task definition yet, so we have to check it here
102-
if(!task_mngr.has_task(data.tid)) {
103-
if(!did_log_task_wait) {
104-
logger->trace(logger_map({{"event", "Waiting for task definition"}}));
105-
did_log_task_wait = true;
106-
}
107-
return false;
108-
}
109101

110102
if(!submitted) {
111103
auto tsk = task_mngr.get_task(data.tid);
@@ -157,14 +149,6 @@ namespace detail {
157149

158150
bool device_execute_job::execute(const command_pkg& pkg, std::shared_ptr<logger> logger) {
159151
const auto data = std::get<task_data>(pkg.data);
160-
// A bit of a hack: We cannot be sure the main thread has reached the task definition yet, so we have to check it here
161-
if(!task_mngr.has_task(data.tid)) {
162-
if(!did_log_task_wait) {
163-
logger->trace(logger_map({{"event", "Waiting for task definition"}}));
164-
did_log_task_wait = true;
165-
}
166-
return false;
167-
}
168152

169153
if(!submitted) {
170154
auto tsk = task_mngr.get_task(data.tid);

0 commit comments

Comments
 (0)