Skip to content

Commit f581f5b

Browse files
authored
[new-exec] fix bug that no thread is waked up when adding task to threadpool (#41567)
* fix bug that no thread is waked up when adding task to threadpool * fix typo
1 parent b3e7973 commit f581f5b

File tree

3 files changed

+19
-7
lines changed

3 files changed

+19
-7
lines changed

paddle/fluid/framework/new_executor/interpretercore_util.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ constexpr size_t kPrepareWorkQueueIdx = 2;
3939

4040
void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
4141
std::function<void()> fn) {
42+
VLOG(4) << "Add task: " << static_cast<size_t>(op_func_type) << " ";
4243
// NOTE(zhiqiu): use thhe second queue of size of, so only one thread is used.
4344
if (FLAGS_new_executor_sequential_run) {
4445
VLOG(4) << "FLAGS_new_executor_sequential_run:"

paddle/fluid/framework/new_executor/workqueue/event_count.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
#include <cstdlib>
5555
#include <mutex>
5656
#include <vector>
57+
#include "glog/logging.h"
5758

5859
namespace paddle {
5960
namespace framework {
@@ -255,6 +256,7 @@ class EventCount {
255256
std::unique_lock<std::mutex> lock(w->mu);
256257
while (w->state != Waiter::kSignaled) {
257258
w->state = Waiter::kWaiting;
259+
VLOG(10) << "Go to wait " << &(w->cv);
258260
w->cv.wait(lock);
259261
}
260262
}
@@ -270,7 +272,10 @@ class EventCount {
270272
w->state = Waiter::kSignaled;
271273
}
272274
// Avoid notifying if it wasn't waiting.
273-
if (state == Waiter::kWaiting) w->cv.notify_one();
275+
if (state == Waiter::kWaiting) {
276+
VLOG(10) << "Go to notify " << &(w->cv);
277+
w->cv.notify_one();
278+
}
274279
}
275280
}
276281
};

paddle/fluid/framework/new_executor/workqueue/nonblocking_threadpool.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ class ThreadPoolTempl {
5353
all_coprimes_.reserve(num_threads_);
5454
for (int i = 1; i <= num_threads_; ++i) {
5555
all_coprimes_.emplace_back();
56-
all_coprimes_.back().push_back(i);
5756
ComputeCoprimes(i, &(all_coprimes_.back()));
5857
}
5958
for (int i = 0; i < num_threads_; i++) {
@@ -130,8 +129,11 @@ class ThreadPoolTempl {
130129
// this. We expect that such scenario is prevented by program, that is,
131130
// this is kept alive while any threads can potentially be in Schedule.
132131
if (!t.f) {
133-
if (num_tasks > num_threads_ - blocked_.load(std::memory_order_relaxed)) {
132+
if (num_tasks > num_threads_ - blocked_) {
133+
VLOG(6) << "Add task, Notify";
134134
ec_.Notify(false);
135+
} else {
136+
VLOG(6) << "Add task, No Notify";
135137
}
136138
} else {
137139
num_tasks_.fetch_sub(1, std::memory_order_relaxed);
@@ -376,17 +378,21 @@ class ThreadPoolTempl {
376378
ec_.CancelWait();
377379
return false;
378380
}
381+
382+
// Number of blocked threads is used as termination condition.
383+
// If we are shutting down and all worker threads blocked without work,
384+
// that's we are done.
385+
blocked_++;
386+
379387
// Now do a reliable emptiness check.
380388
int victim = NonEmptyQueueIndex();
381389
if (victim != -1) {
382390
ec_.CancelWait();
383391
*t = thread_data_[victim].queue.PopBack();
392+
blocked_--;
384393
return true;
385394
}
386-
// Number of blocked threads is used as termination condition.
387-
// If we are shutting down and all worker threads blocked without work,
388-
// that's we are done.
389-
blocked_++;
395+
390396
if (done_ && blocked_ == static_cast<unsigned>(num_threads_)) {
391397
ec_.CancelWait();
392398
// Almost done, but need to re-check queues.

0 commit comments

Comments
 (0)