
Commit f6e273f

[fluid_ops] Replace paddle::platform::DeviceContextPool in fluid/imperative (#65837)
* Fix
* ci
1 parent: 335e445 · commit: f6e273f

24 files changed: +81 −100 lines
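Every hunk below applies the same mechanical substitution: the DeviceContextPool singleton is now taken from the phi namespace instead of paddle::platform, while the Instance()/Get(place) interface is unchanged. A minimal sketch of the pattern; the helper function below is hypothetical and only the pool calls themselves are taken from the hunks:

// Hypothetical helper, for illustration only; not part of this commit.
// Before: platform::DeviceContextPool::Instance().Get(place)
// After:  phi::DeviceContextPool::Instance().Get(place)
phi::DeviceContext* LookUpContext(const phi::Place& place) {
  // The pool owns the returned context; callers must not delete it.
  return phi::DeviceContextPool::Instance().Get(place);
}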

paddle/fluid/imperative/all_reduce.cc

Lines changed: 2 additions & 2 deletions

@@ -96,7 +96,7 @@ static void AllReduce(const phi::SelectedRows &src,
   auto dtype = framework::TransToProtoVarType(src_tensor.dtype());
   auto nccl_dtype = platform::ToNCCLDataType(dtype);
   auto *dev_ctx = static_cast<phi::GPUContext *>(
-      platform::DeviceContextPool::Instance().Get(place));
+      phi::DeviceContextPool::Instance().Get(place));

   bool use_calc_stream = (dev_ctx->stream() == stream);
   VLOG(4) << "Is use calculate stream: " << use_calc_stream;

@@ -221,7 +221,7 @@ void AllReduce(const framework::Variable &src,
                bool use_calc_stream) {
   const auto &place = GetVarPlace(src);
   auto *dev_ctx = static_cast<phi::GPUContext *>(
-      platform::DeviceContextPool::Instance().Get(place));
+      phi::DeviceContextPool::Instance().Get(place));
   platform::NCCLComm *comm =
       platform::NCCLCommContext::Instance().Get(ring_id, place);
   gpuStream_t stream = (use_calc_stream ? dev_ctx->stream() : comm->stream());
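CUDA-only call sites such as these downcast the pooled context to phi::GPUContext to reach its stream. A sketch of the idiom, assuming place is a GPU place (names as in the hunks above):

auto *dev_ctx = static_cast<phi::GPUContext *>(
    phi::DeviceContextPool::Instance().Get(place));
// Valid only because a GPU place maps to a phi::GPUContext in the pool.
gpuStream_t stream = dev_ctx->stream();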

paddle/fluid/imperative/basic_engine.cc

Lines changed: 2 additions & 3 deletions

@@ -107,8 +107,7 @@ void BasicEngine::Init(
     VLOG(6) << "init loss grad:" << var->GradVarBase()->Name()
             << " as stop_gradient false";
     var->GradVarBase()->InnerSetOverriddenStopGradient(false);
-    auto* dev_ctx =
-        platform::DeviceContextPool::Instance().Get(fwd_var.place());
+    auto* dev_ctx = phi::DeviceContextPool::Instance().Get(fwd_var.place());
     if (grad_tensor == nullptr) {
       grad_var->Resize(fwd_var.dims());
       grad_var->mutable_data(fwd_var.place(), fwd_var.type());

@@ -158,7 +157,7 @@ void BasicEngine::CheckBackwardInputs(const OpBase& op) {
     }

     if (tensor && !tensor->IsInitialized()) {
-      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(op.place());
+      auto* dev_ctx = phi::DeviceContextPool::Instance().Get(op.place());
       // NOTE(zhiqiu): since grad variable is ungenerated, so the dtype is not
       // correct. var->DataType() returns the default dtype, which is float32.
       // Here, we use the type of the corresponding forward datatype.

paddle/fluid/imperative/bkcl_context.cc

Lines changed: 4 additions & 4 deletions

@@ -157,7 +157,7 @@ void BKCLParallelContext::AllReduceByStream(const framework::Variable &src,
   auto place = place_;

   auto *dev_ctx = static_cast<platform::XPUDeviceContext *>(
-      platform::DeviceContextPool::Instance().Get(place));
+      phi::DeviceContextPool::Instance().Get(place));
   platform::BKCLComm *comm =
       platform::BKCLCommContext::Instance().Get(ring_id, place);
   XPUStream stream =

@@ -223,7 +223,7 @@ void BKCLParallelContext::WaitCompute(int ring_id) {
                         ring_id,
                         strategy_.nrings_));
   auto compute_stream = static_cast<platform::XPUDeviceContext *>(
-                            platform::DeviceContextPool::Instance().Get(place_))
+                            phi::DeviceContextPool::Instance().Get(place_))
                             ->stream();
   auto comm_stream = platform::BKCLCommContext::Instance()
                          .Get(ring_id, place_)

@@ -253,7 +253,7 @@ void BKCLParallelContext::WaitComm(int ring_id) {
                         ->dev_context()
                         ->stream();
   auto compute_stream = static_cast<platform::XPUDeviceContext *>(
-                            platform::DeviceContextPool::Instance().Get(place_))
+                            phi::DeviceContextPool::Instance().Get(place_))
                             ->stream();
   auto event = compute_events_[ring_id].get();

@@ -264,7 +264,7 @@ void BKCLParallelContext::WaitComm(int ring_id) {

 void BKCLParallelContext::SynchronizeCompute() {
   auto compute_dev_ctx = static_cast<platform::XPUDeviceContext *>(
-      platform::DeviceContextPool::Instance().Get(place_));
+      phi::DeviceContextPool::Instance().Get(place_));
   compute_dev_ctx->Wait();
 }

paddle/fluid/imperative/gradient_accumulator.cc

Lines changed: 16 additions & 16 deletions

@@ -84,7 +84,7 @@ void XPUTensorAddFunctor(const platform::Place& place,
                          phi::DenseTensor* dst) {
   using XPUType = typename XPUTypeTrait<T>::Type;
   platform::XPUDeviceContext* ctx = dynamic_cast<platform::XPUDeviceContext*>(
-      platform::DeviceContextPool::Instance().Get(place));
+      phi::DeviceContextPool::Instance().Get(place));
   const XPUType* x = reinterpret_cast<const XPUType*>(src.data<T>());
   XPUType* y = reinterpret_cast<XPUType*>(dst->mutable_data<T>(place));
   int r = -1;

@@ -201,8 +201,8 @@ void TensorAdd(const VarType& src, VarType* dst) {
 // check requiring input dtypes to be the same have been removed.
 #define PADDLE_TENSOR_ADD(T, CONTEXT)                                          \
   if (data_type == framework::DataTypeTrait<T>::DataType()) {                  \
-    auto cpu_ctx = static_cast<CONTEXT*>(                                      \
-        platform::DeviceContextPool::Instance().Get(place));                   \
+    auto cpu_ctx =                                                             \
+        static_cast<CONTEXT*>(phi::DeviceContextPool::Instance().Get(place));  \
     phi::AddKernel<T, CONTEXT>(*cpu_ctx, *dst_tensor, src_tensor, dst_tensor); \
     return;                                                                    \
   }

@@ -218,13 +218,13 @@ void TensorAdd(const VarType& src, VarType* dst) {
 #endif
 }

-#define TENSOR_ADD_EIGEN(T)                                \
-  auto cpu_ctx = static_cast<phi::CPUContext*>(            \
-      platform::DeviceContextPool::Instance().Get(place)); \
-  auto in = phi::EigenVector<T>::Flatten(src_tensor);      \
-  auto out = phi::EigenVector<T>::Flatten(*dst_tensor);    \
-  auto& p = *(cpu_ctx->eigen_device());                    \
-  out.device(p) = out + in;                                \
+#define TENSOR_ADD_EIGEN(T)                             \
+  auto cpu_ctx = static_cast<phi::CPUContext*>(         \
+      phi::DeviceContextPool::Instance().Get(place));   \
+  auto in = phi::EigenVector<T>::Flatten(src_tensor);   \
+  auto out = phi::EigenVector<T>::Flatten(*dst_tensor); \
+  auto& p = *(cpu_ctx->eigen_device());                 \
+  out.device(p) = out + in;                             \
   return;

   if (phi::is_cpu_place(place)) {

@@ -244,7 +244,7 @@ void TensorAdd(const VarType& src, VarType* dst) {
   if (data_type == framework::DataTypeTrait<T>::DataType()) {    \
     platform::CustomDeviceContext* ctx =                         \
         static_cast<platform::CustomDeviceContext*>(             \
-            platform::DeviceContextPool::Instance().Get(place)); \
+            phi::DeviceContextPool::Instance().Get(place));      \
     phi::stream::Stream stream(place, ctx->stream());            \
     auto device = phi::DeviceManager::GetDeviceWithPlace(place); \
     device->BlasAXPBY<T>(stream,                                 \

@@ -313,7 +313,7 @@ void SelectedRowsAddToTensor(const VarType& src, VarType* dst) {
   auto place = dst_tensor->place();
   auto data_type =
       framework::TransToProtoVarType(src_selected_rows.value().dtype());
-  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+  phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();

 #define PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(dev_ctx_type, cpp_type)   \
   if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) { \

@@ -363,7 +363,7 @@ void SelectedRowsAddTensor(const VarType& src_selected_rows_var,

   const auto& place = src_tensor.place();
   auto data_type = framework::TransToProtoVarType(src_tensor.dtype());
-  auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+  auto* dev_ctx = phi::DeviceContextPool::Instance().Get(place);

   phi::DenseTensor* dst_tensor =
       GetInnerMutableTensor<phi::DenseTensor>(dst_tensor_var);

@@ -426,7 +426,7 @@ std::shared_ptr<ReturnVarType> SelectedRowsMerge(const VarType& src1,
   auto place = src_selected_rows1.value().place();
   auto data_type =
       framework::TransToProtoVarType(src_selected_rows1.value().dtype());
-  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+  phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();

   std::vector<const phi::SelectedRows*> src_selected_rows;
   src_selected_rows.emplace_back(&src_selected_rows1);

@@ -667,7 +667,7 @@ void EagerGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
   if (!dst_var->Var().IsInitialized() ||
       !dst_var->Var().Get<phi::DenseTensor>().IsInitialized()) {
     VLOG(6) << "Set StopGradient Grad: " << dst_var->Name() << " as zero ";
-    auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+    auto* dev_ctx = phi::DeviceContextPool::Instance().Get(place);
     if (!dst_var->Var().IsInitialized()) {
       auto* tensor = dst_var->MutableVar()->GetMutable<phi::DenseTensor>();
       VLOG(6) << "Dims of " << dst_var->Name()

@@ -807,7 +807,7 @@ void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
   if (!dst_var->Var().IsInitialized() ||
       !dst_var->Var().Get<phi::DenseTensor>().IsInitialized()) {
     VLOG(6) << "Set StopGradient Grad: " << var->Name() << " as zero";
-    auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+    auto* dev_ctx = phi::DeviceContextPool::Instance().Get(place);
     if (!dst_var->Var().IsInitialized()) {
       auto* tensor = dst_var->MutableVar()->GetMutable<phi::DenseTensor>();
       VLOG(6) << "Dims of " << dst_var->Name()
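For reference, PADDLE_TENSOR_ADD expands once per supported dtype inside TensorAdd and dispatches on the runtime data_type. A sketch of one expansion after this change, assuming data_type, place, src_tensor, and dst_tensor are in scope as they are in TensorAdd:

// Roughly what PADDLE_TENSOR_ADD(float, phi::CPUContext) expands to.
if (data_type == framework::DataTypeTrait<float>::DataType()) {
  auto cpu_ctx = static_cast<phi::CPUContext*>(
      phi::DeviceContextPool::Instance().Get(place));
  phi::AddKernel<float, phi::CPUContext>(
      *cpu_ctx, *dst_tensor, src_tensor, dst_tensor);
  return;
}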

paddle/fluid/imperative/layer.cc

Lines changed: 6 additions & 6 deletions

@@ -240,7 +240,7 @@ void VarBase::ClearGradient(bool set_to_zero) {
     if (grad_t->IsInitialized()) {
       if (set_to_zero) {
         auto* dev_ctx =
-            platform::DeviceContextPool::Instance().Get(grad_t->place());
+            phi::DeviceContextPool::Instance().Get(grad_t->place());
         phi::funcs::set_constant(*dev_ctx, grad_t, 0.0f);
       } else {
         grad_t->clear();

@@ -302,10 +302,10 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
     new_var->SetType(Type());
     framework::TensorCopy(src_tensor, dst_place, dst_tensor);
     if (blocking) {
-      platform::DeviceContextPool::Instance().Get(dst_place)->Wait();
+      phi::DeviceContextPool::Instance().Get(dst_place)->Wait();
       auto src_place = src_tensor.place();
       if (!(src_place == dst_place)) {
-        platform::DeviceContextPool::Instance().Get(src_place)->Wait();
+        phi::DeviceContextPool::Instance().Get(src_place)->Wait();
       }
     }
     VLOG(4) << "copy tensor " << Name() << " from " << Place() << " to "

@@ -323,10 +323,10 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
                            dst_place,
                            dst_selected_rows->mutable_value());
     if (blocking) {
-      platform::DeviceContextPool::Instance().Get(dst_place)->Wait();
+      phi::DeviceContextPool::Instance().Get(dst_place)->Wait();
       auto src_place = src_selected_rows.place();
       if (!(src_place == dst_place)) {
-        platform::DeviceContextPool::Instance().Get(src_place)->Wait();
+        phi::DeviceContextPool::Instance().Get(src_place)->Wait();
       }
     }
     dst_selected_rows->set_height(src_selected_rows.height());

@@ -413,7 +413,7 @@ void VarBase::CopyFrom(const VarBase& src, const bool blocking) {
       framework::TensorCopy(src_tensor, place, dst_tensor);
     }
     if (blocking) {
-      platform::DeviceContextPool::Instance().Get(place)->Wait();
+      phi::DeviceContextPool::Instance().Get(place)->Wait();
     }
   }
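The blocking paths in NewVarBase and CopyFrom drain the destination context after the copy, and the source context too when it differs; only the pool's namespace changes. A condensed sketch of that logic, with tensor and place names as they appear in the hunks above:

framework::TensorCopy(src_tensor, dst_place, dst_tensor);
if (blocking) {
  phi::DeviceContextPool::Instance().Get(dst_place)->Wait();
  auto src_place = src_tensor.place();
  if (!(src_place == dst_place)) {
    // Also wait on the source context so the read side of the copy is done.
    phi::DeviceContextPool::Instance().Get(src_place)->Wait();
  }
}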

paddle/fluid/imperative/nccl_context.cc

Lines changed: 3 additions & 3 deletions

@@ -175,7 +175,7 @@ void NCCLParallelContext::WaitCompute(int ring_id) {
                         compute_events_.size()));

   auto compute_stream = static_cast<phi::GPUContext *>(
-                            platform::DeviceContextPool::Instance().Get(place_))
+                            phi::DeviceContextPool::Instance().Get(place_))
                             ->stream();
   auto comm_stream =
       platform::NCCLCommContext::Instance().Get(ring_id, place_)->stream();

@@ -205,7 +205,7 @@ void NCCLParallelContext::WaitComm(int ring_id) {
                         comm_events_.size()));

   auto compute_stream = static_cast<phi::GPUContext *>(
-                            platform::DeviceContextPool::Instance().Get(place_))
+                            phi::DeviceContextPool::Instance().Get(place_))
                             ->stream();
   auto comm_stream =
       platform::NCCLCommContext::Instance().Get(ring_id, place_)->stream();

@@ -223,7 +223,7 @@ void NCCLParallelContext::WaitComm(int ring_id) {

 void NCCLParallelContext::SynchronizeCompute() {
   auto *compute_dev_ctx = static_cast<phi::GPUContext *>(
-      platform::DeviceContextPool::Instance().Get(place_));
+      phi::DeviceContextPool::Instance().Get(place_));
   compute_dev_ctx->Wait();
 }

paddle/fluid/imperative/partial_grad_engine.cc

Lines changed: 1 addition & 1 deletion

@@ -322,7 +322,7 @@ static void FillConstantLike(const VariableWrapper &ref_var,
                              float value) {
   auto &ref_tensor = ref_var.Var().Get<phi::DenseTensor>();
   auto *dst_tensor = dst_var->MutableVar()->GetMutable<phi::DenseTensor>();
-  auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+  auto *dev_ctx = phi::DeviceContextPool::Instance().Get(place);
   dst_tensor->Resize(ref_tensor.dims());
   // TODO(jiabin): Ugly fix here we have fwd_data_type_ and data_type, since in
   // grad mission

paddle/fluid/imperative/prepared_operator.cc

Lines changed: 1 addition & 1 deletion

@@ -161,7 +161,7 @@ PreparedOp PrepareImpl(
     const phi::KernelFactory& phi_kernel_factory,
     const phi::OpUtilsMap& phi_op_utils_map,
     const phi::DefaultKernelSignatureMap& default_phi_kernel_sig_map) {
-  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+  phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
   auto* dev_ctx = pool.Get(place);

 #ifdef PADDLE_WITH_DNNL

paddle/fluid/imperative/reducer.cc

Lines changed: 4 additions & 4 deletions

@@ -756,7 +756,7 @@ void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) {
 #ifdef PADDLE_WITH_XPU_BKCL
     if (phi::is_xpu_place(group_tensor.place())) {
       auto dev_ctx = static_cast<platform::XPUDeviceContext *>(
-          platform::DeviceContextPool::Instance().Get(place_));
+          phi::DeviceContextPool::Instance().Get(place_));
       if (HasGrad(var_index)) {
         auto var_base = vars_[var_index]->GradVarBase();
         auto tensor = var_base->MutableVar()->GetMutable<phi::DenseTensor>();

@@ -773,7 +773,7 @@ void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) {
       }
     }
 #else
-    auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place_);
+    auto *dev_ctx = phi::DeviceContextPool::Instance().Get(place_);
     if (HasGrad(var_index)) {
       auto var_base = vars_[var_index]->GradVarBase();
       auto tensor = var_base->MutableVar()->GetMutable<phi::DenseTensor>();

@@ -924,7 +924,7 @@ void Reducer::ProcessUnusedDenseVars() {
   // avoid conflicts with communication.
   VLOG(3) << "Local used vars : "
           << string::join_strings(local_used_vars_, ',');
-  const auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place_);
+  const auto *dev_ctx = phi::DeviceContextPool::Instance().Get(place_);
   // H2D is to allreduce the local_used_vars_
   auto *global_used_tensor = global_used_vars_.GetMutable<phi::DenseTensor>();
   framework::TensorFromVector<int>(

@@ -976,7 +976,7 @@ void Reducer::ProcessUnusedDenseVars() {
     // 4. set grad tensor
     auto *dest_grad_tensor =
         grad_var_base_tmp->MutableVar()->GetMutable<phi::DenseTensor>();
-    const auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place_);
+    const auto *dev_ctx = phi::DeviceContextPool::Instance().Get(place_);
     paddle::framework::TensorCopy(
         src_tensor, place_, *dev_ctx, dest_grad_tensor);
     dest_grad_tensor->Resize(dest_dims);

paddle/fluid/imperative/tracer.cc

Lines changed: 2 additions & 2 deletions

@@ -470,7 +470,7 @@ void Tracer::TraceOp(const std::string& type,
                 default_attrs,
                 use_default_attr_map);

-    auto dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place);
+    auto dev_ctx = phi::DeviceContextPool::Instance().Get(place);
     for (auto& iter : need_backup_inputs2outputs) {
       iter.first->ResetHolder(need_backup_inputs2holder[iter.first]);
       iter.first->set_strides(need_backup_inputs2strides[iter.first]);

@@ -613,7 +613,7 @@ phi::KernelSignature Tracer::GetExpectedKernelSignature(
     framework::AttributeMap attrs) const {
   auto op = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
   framework::RuntimeContext ctx({}, {});
-  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+  phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
   auto* dev_ctx = pool.Get(phi::CPUPlace());
   const auto& op_info = op->Info();
   auto* attr_checker = op_info.Checker();
