Skip to content

Commit 2aebcd8

Browse files
authored
Add FLAGS_force_sync_ops for executor (#68467)
* Add sync_op_after_launch config for executor * Update code * Update code * Update code
1 parent 1f9ff43 commit 2aebcd8

File tree

7 files changed

+127
-6
lines changed

7 files changed

+127
-6
lines changed

paddle/fluid/framework/new_executor/instruction/instruction_base.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,6 @@ InstructionBase::InstructionBase(size_t id, const phi::Place& place)
197197
no_need_buffer_values_() {
198198
id_ = id;
199199

200-
is_artificial_ = false;
201-
202200
if (phi::is_cpu_place(place)) {
203201
type_ = OpFuncType::kCpuSync;
204202
} else {

paddle/fluid/framework/new_executor/instruction/instruction_base.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ class InstructionBase {
4343
bool IsArtificial() const { return is_artificial_; }
4444
void SetArtificial(bool is_artificial) { is_artificial_ = is_artificial; }
4545

46+
bool IsSyncAfterLaunch() const { return sync_after_launch_; }
47+
void SetSyncAfterLaunch(bool sync) { sync_after_launch_ = sync; }
48+
4649
OpFuncType KernelType() const;
4750
void SetKernelType(OpFuncType type) { type_ = type; }
4851

@@ -176,8 +179,12 @@ class InstructionBase {
176179
protected:
177180
size_t id_;
178181

179-
bool is_artificial_; // Instruction is artificial means that it is only used
180-
// to assist scheduling and no need to be executed.
182+
bool is_artificial_{
183+
false}; // Instruction is artificial means that it is only used
184+
// to assist scheduling and no need to be executed.
185+
186+
bool sync_after_launch_{false};
187+
181188
OpFuncType type_;
182189

183190
// dist attrs:lower value, higher priority

paddle/fluid/framework/new_executor/interpreter/execution_config.cc

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,25 @@
1717
#include <set>
1818
#include <thread>
1919

20+
#include "paddle/common/flags.h"
2021
#include "paddle/fluid/platform/device/ipu/ipu_info.h"
2122
#include "paddle/phi/backends/device_manager.h"
2223
#include "paddle/phi/backends/gpu/gpu_info.h"
2324
#include "paddle/phi/backends/xpu/xpu_info.h"
25+
#include "paddle/utils/string/string_helper.h"
26+
27+
// FLAGS_force_sync_ops gives finer-grained control over op-sync in the executor.
28+
// The format is: "micro_batch_id, job_name, op_id, op_name | micro_batch_id,
29+
// job_name, op_id, op_name | ...". Leave a field blank to match any id/name. Examples:
30+
// 1. sync the recv_v2 op in the second backward-job of 1F1B scheduling:
31+
// FLAGS_force_sync_ops="1, backward, , recv_v2"
32+
// 2. sync the full op with op_id=5: FLAGS_force_sync_ops=" , , 5, full"
33+
// 3. sync all ops in the first default-job: FLAGS_force_sync_ops="0,default,,"
34+
// 4. sync all ops in the forward-job and backward-job: FLAGS_force_sync_ops=" ,
35+
// forward, , | , backward, , "
36+
PHI_DEFINE_EXPORTED_string(force_sync_ops,
37+
"",
38+
"Pattern to force sync ops in executor.");
2439

2540
PD_DECLARE_bool(new_executor_serial_run);
2641

@@ -149,4 +164,65 @@ void ExecutionConfig::Log(int log_level) {
149164
VLOG(log_level) << log_str.str();
150165
}
151166

167+
std::set<std::pair<int, std::string>> GetForceSyncOps(
168+
int micro_batch_id, const std::string& job_name) {
169+
std::set<std::pair<int, std::string>> force_sync_ops;
170+
std::stringstream ss(paddle::string::erase_spaces(FLAGS_force_sync_ops));
171+
std::string item;
172+
173+
while (std::getline(ss, item, '|')) {
174+
item += ","; // The comma at the end of the string will be ignored in
175+
// std::getline
176+
std::stringstream item_stream(item);
177+
std::vector<std::string> tokens;
178+
std::string token;
179+
while (std::getline(item_stream, token, ',')) {
180+
VLOG(1) << "token: " << token;
181+
tokens.push_back(token);
182+
}
183+
184+
PADDLE_ENFORCE_EQ(
185+
tokens.size(),
186+
4,
187+
phi::errors::InvalidArgument("Invalid force_sync_ops format: \"%s\", "
188+
"FLAGS_force_sync_ops=\"%s\"",
189+
item,
190+
FLAGS_force_sync_ops));
191+
192+
int micro_batch_id_;
193+
if (tokens[0] == "") {
194+
micro_batch_id_ = -1;
195+
} else {
196+
micro_batch_id_ = std::stoi(tokens[0]);
197+
}
198+
if (micro_batch_id_ != micro_batch_id && micro_batch_id_ != -1) {
199+
continue;
200+
}
201+
202+
if (tokens[1] != job_name && tokens[1] != "") {
203+
continue;
204+
}
205+
206+
int op_id;
207+
if (tokens[2] == "") {
208+
op_id = -1;
209+
} else {
210+
op_id = std::stoi(tokens[2]);
211+
}
212+
std::string op_name = tokens[3];
213+
force_sync_ops.insert({op_id, op_name});
214+
}
215+
216+
if (!force_sync_ops.empty()) {
217+
std::stringstream ss;
218+
ss << "job_name: " << job_name << ", micro_batch_id: " << micro_batch_id
219+
<< ", force_sync_ops: ";
220+
for (auto& pair : force_sync_ops) {
221+
ss << "(" << pair.first << ", " << pair.second << ") ";
222+
}
223+
VLOG(6) << ss.str();
224+
}
225+
return force_sync_ops;
226+
}
227+
152228
} // namespace paddle::framework::interpreter

paddle/fluid/framework/new_executor/interpreter/execution_config.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ namespace interpreter {
2626

2727
struct ExecutionConfig {
2828
bool create_local_scope{true};
29-
3029
bool used_for_cinn{false};
3130
bool used_for_control_flow_op{false};
3231
bool used_for_jit{false};
@@ -35,6 +34,10 @@ struct ExecutionConfig {
3534
size_t device_num_threads{0};
3635
size_t host_num_threads{0};
3736

37+
std::set<std::pair<int, std::string>>
38+
force_sync_ops; // set{pair<op_id, name>}, -1 matches any op_id, ""
39+
// matches any name
40+
3841
std::set<std::string> force_root_scope_vars;
3942
std::set<std::string> jit_input_vars;
4043
std::set<std::string> skip_gc_vars;
@@ -43,6 +46,9 @@ struct ExecutionConfig {
4346
void Log(int log_level);
4447
};
4548

49+
std::set<std::pair<int, std::string>> GetForceSyncOps(
50+
int micro_batch_id, const std::string& job_name);
51+
4652
} // namespace interpreter
4753
} // namespace framework
4854
} // namespace paddle

paddle/fluid/framework/new_executor/pir_interpreter.cc

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,33 @@ void PirInterpreter::AnalyseExecuteOrderForTrace(
765765
}
766766
}
767767

768+
void PirInterpreter::AnalyzeForceSyncOps() {
769+
for (auto& ins : vec_instruction_base_) {
770+
ins->SetSyncAfterLaunch(FLAGS_benchmark);
771+
772+
// Analyze force sync op set by FLAGS_force_sync_op
773+
int op_id = ins->Id();
774+
std::string op_name = ins->Name();
775+
std::string unused_prefix = "pd_op.";
776+
auto pos = op_name.find(unused_prefix);
777+
if (pos != std::string::npos) {
778+
op_name.erase(pos, unused_prefix.size());
779+
}
780+
781+
for (auto& pair : execution_config_.force_sync_ops) {
782+
int sync_op_id = pair.first;
783+
std::string sync_op_name = pair.second;
784+
if ((sync_op_id == op_id || sync_op_id == -1) &&
785+
(sync_op_name == op_name || sync_op_name == "")) {
786+
VLOG(8) << "Force sync op: "
787+
<< "sync_op_id=" << sync_op_id << ", op_id=" << op_id
788+
<< ", sync_op_name=" << sync_op_name << ", op_name=" << op_name;
789+
ins->SetSyncAfterLaunch(true);
790+
}
791+
}
792+
}
793+
}
794+
768795
void PirInterpreter::BuildInstruction() {
769796
VLOG(6) << "Build Instructions for pir ... ";
770797
vec_instruction_base_.clear();
@@ -1900,7 +1927,7 @@ void PirInterpreter::RunInstructionBase(InstructionBase* instr_node) {
19001927
instr_node->Run();
19011928
}
19021929

1903-
if (FLAGS_benchmark) {
1930+
if (instr_node->IsSyncAfterLaunch()) {
19041931
instr_node->DeviceContext().Wait();
19051932
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
19061933
PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError());
@@ -2003,6 +2030,9 @@ void PirInterpreter::PreAnalysis() {
20032030
ir_instruction_scheduling_priority_less);
20042031
VLOG(4) << "Done AnalyseExecuteOrderForTrace";
20052032

2033+
AnalyzeForceSyncOps();
2034+
VLOG(4) << "Done AnalyzeForceSyncOps";
2035+
20062036
UpdateSyncOpNum();
20072037
VLOG(4) << "Done UpdateSyncOpNum";
20082038

paddle/fluid/framework/new_executor/pir_interpreter.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,11 @@ class PirInterpreter : public InterpreterBaseImpl {
129129
void UpdateSyncOpNum();
130130
void UpdateNcclOpNum();
131131
void UpdateOneDNNOpNum();
132+
132133
void AnalyseExecuteOrderForTrace(
133134
std::map<size_t, std::set<size_t>> op_downstream_map,
134135
InstructionSchedulingPriorityLess compare);
136+
void AnalyzeForceSyncOps();
135137
void ConstructEventForJitInput();
136138
void CalculateLastLiveOps();
137139

paddle/fluid/framework/new_executor/standalone_executor.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ StandaloneExecutor::StandaloneExecutor(const phi::Place& place,
8787
interpreter::ExecutionConfig execution_config;
8888
execution_config.create_local_scope = false;
8989
execution_config.skip_gc_vars = job->SkipGcVars();
90+
execution_config.force_sync_ops =
91+
interpreter::GetForceSyncOps(micro_batch_id, job_type);
9092

9193
// TODO(phlrain) we only support cpu for now
9294
if (FLAGS_enable_pir_in_executor) {

0 commit comments

Comments
 (0)