Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,6 @@ InstructionBase::InstructionBase(size_t id, const phi::Place& place)
no_need_buffer_values_() {
id_ = id;

is_artificial_ = false;

if (phi::is_cpu_place(place)) {
type_ = OpFuncType::kCpuSync;
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ class InstructionBase {
bool IsArtificial() const { return is_artificial_; }
void SetArtificial(bool is_artificial) { is_artificial_ = is_artificial; }

// Whether the executor must wait on the device context right after this
// instruction is launched (set from FLAGS_benchmark and the
// FLAGS_force_sync_ops filters during pre-analysis).
bool IsSyncAfterLaunch() const { return sync_after_launch_; }
void SetSyncAfterLaunch(bool sync) { sync_after_launch_ = sync; }

OpFuncType KernelType() const;
void SetKernelType(OpFuncType type) { type_ = type; }

Expand Down Expand Up @@ -176,8 +179,12 @@ class InstructionBase {
protected:
size_t id_;

bool is_artificial_; // Instruction is artificial means that it is only used
// to assist scheduling and no need to be executed.
bool is_artificial_{
false}; // Instruction is artificial means that it is only used
// to assist scheduling and no need to be executed.

bool sync_after_launch_{false};

OpFuncType type_;

// dist attrs:lower value, higher priority
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,25 @@
#include <set>
#include <thread>

#include "paddle/common/flags.h"
#include "paddle/fluid/platform/device/ipu/ipu_info.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/xpu/xpu_info.h"
#include "paddle/utils/string/string_helper.h"

// FLAGS_force_sync_ops is used to finer control the op-sync in executor.
// The format is: "micro_batch_id, job_name, op_id, op_name | micro_batch_id,
// job_name, op_id, op_name | ...". Keep spaces to syncs all name/id. Example:
// 1. sync the recv_v2 op in the second backward-job of 1F1B scheduling:
// FLAGS_force_sync_ops="1, backward, , recv_v2"
// 2. sync the full op with op_id=5: FLAGS_force_sync_ops=" , , 5, full"
// 3. sync all ops in the first default-job: FLAGS_force_sync_ops="0,default,,
// 4. sync all ops in the forward-job and backward-job: FLAGS_force_sync_ops=" ,
// forward, , | , backward, , , "
PHI_DEFINE_EXPORTED_string(force_sync_ops,
"",
"Pattern to force sync ops in executor.");

PD_DECLARE_bool(new_executor_serial_run);

Expand Down Expand Up @@ -149,4 +164,65 @@ void ExecutionConfig::Log(int log_level) {
VLOG(log_level) << log_str.str();
}

std::set<std::pair<int, std::string>> GetForceSyncOps(
    int micro_batch_id, const std::string& job_name) {
  // Parse FLAGS_force_sync_ops, a '|'-separated list of
  // "micro_batch_id,job_name,op_id,op_name" entries, and return the
  // (op_id, op_name) filters that apply to this micro_batch_id/job_name.
  // Empty fields are wildcards: they become op_id == -1 / op_name == "".
  std::set<std::pair<int, std::string>> force_sync_ops;
  std::stringstream ss(paddle::string::erase_spaces(FLAGS_force_sync_ops));
  std::string item;

  while (std::getline(ss, item, '|')) {
    // Append a trailing comma so std::getline still yields a (possibly
    // empty) fourth token; it would otherwise drop an empty final field.
    item += ",";
    std::stringstream item_stream(item);
    std::vector<std::string> tokens;
    std::string token;
    while (std::getline(item_stream, token, ',')) {
      // VLOG(6), not VLOG(1): per-token tracing is debug noise and should
      // match the verbosity used for the summary log below.
      VLOG(6) << "token: " << token;
      tokens.push_back(std::move(token));
    }

    PADDLE_ENFORCE_EQ(
        tokens.size(),
        4,
        phi::errors::InvalidArgument("Invalid force_sync_ops format: \"%s\", "
                                     "FLAGS_force_sync_ops=\"%s\"",
                                     item,
                                     FLAGS_force_sync_ops));

    // Locals use plain names (no trailing underscore, which is reserved for
    // members in this codebase).
    const int item_micro_batch_id =
        tokens[0].empty() ? -1 : std::stoi(tokens[0]);
    if (item_micro_batch_id != -1 && item_micro_batch_id != micro_batch_id) {
      continue;
    }

    if (!tokens[1].empty() && tokens[1] != job_name) {
      continue;
    }

    const int op_id = tokens[2].empty() ? -1 : std::stoi(tokens[2]);
    force_sync_ops.emplace(op_id, std::move(tokens[3]));
  }

  if (!force_sync_ops.empty()) {
    // Separate stream name: the original shadowed the outer `ss`.
    std::stringstream log_ss;
    log_ss << "job_name: " << job_name
           << ", micro_batch_id: " << micro_batch_id << ", force_sync_ops: ";
    for (const auto& pair : force_sync_ops) {
      log_ss << "(" << pair.first << ", " << pair.second << ") ";
    }
    VLOG(6) << log_ss.str();
  }
  return force_sync_ops;
}

} // namespace paddle::framework::interpreter
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ namespace interpreter {

struct ExecutionConfig {
bool create_local_scope{true};

bool used_for_cinn{false};
bool used_for_control_flow_op{false};
bool used_for_jit{false};
Expand All @@ -35,6 +34,10 @@ struct ExecutionConfig {
size_t device_num_threads{0};
size_t host_num_threads{0};

std::set<std::pair<int, std::string>>
force_sync_ops; // set{pair<op_id, name>}, -1 matches any op_id, ""
// matches any name

std::set<std::string> force_root_scope_vars;
std::set<std::string> jit_input_vars;
std::set<std::string> skip_gc_vars;
Expand All @@ -43,6 +46,9 @@ struct ExecutionConfig {
void Log(int log_level);
};

std::set<std::pair<int, std::string>> GetForceSyncOps(
int micro_batch_id, const std::string& job_name);

} // namespace interpreter
} // namespace framework
} // namespace paddle
32 changes: 31 additions & 1 deletion paddle/fluid/framework/new_executor/pir_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,33 @@ void PirInterpreter::AnalyseExecuteOrderForTrace(
}
}

void PirInterpreter::AnalyzeForceSyncOps() {
  // Decide per instruction whether the executor must wait on the device
  // after launch. The default comes from FLAGS_benchmark; entries parsed
  // from FLAGS_force_sync_ops (execution_config_.force_sync_ops) can force
  // additional syncs. op_id == -1 and op_name == "" act as wildcards.
  for (auto& ins : vec_instruction_base_) {
    ins->SetSyncAfterLaunch(FLAGS_benchmark);

    const int op_id = ins->Id();
    std::string op_name = ins->Name();
    // Strip the dialect prefix so users can match on e.g. "recv_v2"
    // instead of "pd_op.recv_v2".
    static constexpr char kUnusedPrefix[] = "pd_op.";
    const auto pos = op_name.find(kUnusedPrefix);
    if (pos != std::string::npos) {
      op_name.erase(pos, sizeof(kUnusedPrefix) - 1);
    }

    // Bind the pair by const reference: the original copied sync_op_name
    // into a fresh std::string for every (instruction, filter) pair.
    for (const auto& [sync_op_id, sync_op_name] :
         execution_config_.force_sync_ops) {
      const bool id_matches = (sync_op_id == op_id || sync_op_id == -1);
      const bool name_matches =
          (sync_op_name == op_name || sync_op_name.empty());
      if (id_matches && name_matches) {
        VLOG(8) << "Force sync op: "
                << "sync_op_id=" << sync_op_id << ", op_id=" << op_id
                << ", sync_op_name=" << sync_op_name
                << ", op_name=" << op_name;
        ins->SetSyncAfterLaunch(true);
      }
    }
  }
}

void PirInterpreter::BuildInstruction() {
VLOG(6) << "Build Instructions for pir ... ";
vec_instruction_base_.clear();
Expand Down Expand Up @@ -1900,7 +1927,7 @@ void PirInterpreter::RunInstructionBase(InstructionBase* instr_node) {
instr_node->Run();
}

if (FLAGS_benchmark) {
if (instr_node->IsSyncAfterLaunch()) {
instr_node->DeviceContext().Wait();
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError());
Expand Down Expand Up @@ -1992,6 +2019,9 @@ void PirInterpreter::PreAnalysis() {
ir_instruction_scheduling_priority_less);
VLOG(4) << "Done AnalyseExecuteOrderForTrace";

AnalyzeForceSyncOps();
VLOG(4) << "Done AnalyzeForceSyncOps";

UpdateSyncOpNum();
VLOG(4) << "Done UpdateSyncOpNum";

Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/new_executor/pir_interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,11 @@ class PirInterpreter : public InterpreterBaseImpl {
void UpdateSyncOpNum();
void UpdateNcclOpNum();
void UpdateOneDNNOpNum();

void AnalyseExecuteOrderForTrace(
std::map<size_t, std::set<size_t>> op_downstream_map,
InstructionSchedulingPriorityLess compare);
void AnalyzeForceSyncOps();
void ConstructEventForJitInput();
void CalculateLastLiveOps();

Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/new_executor/standalone_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ StandaloneExecutor::StandaloneExecutor(const phi::Place& place,
interpreter::ExecutionConfig execution_config;
execution_config.create_local_scope = false;
execution_config.skip_gc_vars = job->SkipGcVars();
execution_config.force_sync_ops =
interpreter::GetForceSyncOps(micro_batch_id, job_type);

// TODO(phlrain) we only support cpu for now
if (FLAGS_enable_pir_in_executor) {
Expand Down