
Commit 33cf3fd

[PIR] Update output place for memcpy op (#70201)
1 parent 5e8544c commit 33cf3fd

File tree

3 files changed: +51 -5 lines changed


paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc

Lines changed: 20 additions & 2 deletions
@@ -176,6 +176,19 @@ const std::unordered_map<std::string, uint32_t> NoBufferRelatedOps = {
     {paddle::dialect::BatchNorm_Op::name(), /*reserve_space*/ 5U},
 };
 
+// Please keep the consistency with paddle/phi/kernels/memcpy_kernel.cc
+const std::unordered_map<int, phi::Place> MemcpyOpAttr2Place = {
+    {0, phi::CPUPlace()},
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+    {1, phi::GPUPlace()},
+    {2, phi::GPUPinnedPlace()},
+#elif defined(PADDLE_WITH_XPU)
+    {3, phi::XPUPlace()},
+#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
+    {4, phi::CustomPlace()}
+#endif
+};
+
 static bool NeedSkipPlaceTransfer(const pir::Operation* op) {
   bool need_skip = false;
   if (op->isa<paddle::dialect::FetchOp>()) {
@@ -2179,8 +2192,6 @@ void HandleForSpecialOp(
   // only deal with single output
   if (op_item->num_results() > 0) {
     for (size_t i = 0; i < op_item->num_results(); ++i) {
-      VLOG(6) << "2816:" << op_item->result(i).type();
-      VLOG(6) << "2817:" << op->result(i).type();
       (*map_value_pair)[op_item->result(i)] = op->result(i);
     }
   }
@@ -2508,6 +2519,13 @@ std::vector<pir::Type> BuildOutputs(
          (!IsLegacyOp(op_item->name())) && phi_kernel.IsValid()) {
       out_place = phi::TransToPhiPlace(output_defs[i].backend);
     }
+    if (op_item->isa<MemcpyOp>()) {
+      // If the op is MemcpyOp, the output type is determined by the
+      // attribute "dst_place_type".
+      out_place = MemcpyOpAttr2Place.at(op_item->attribute("dst_place_type")
+                                            .dyn_cast<pir::Int32Attribute>()
+                                            .data());
+    }
     PushBackOutputTypes(ctx,
                         op_item,
                         op_item->result(i).type(),
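
For quick reference, here is an illustrative Python mirror of the new MemcpyOpAttr2Place table; the C++ map above is authoritative, and codes 1-4 only exist in builds with the matching #if branch:

# Illustrative sketch only: mirrors MemcpyOpAttr2Place from
# pd_op_to_kernel_pass.cc. Which codes are available depends on
# the build flags (CUDA/HIP, XPU, or custom device).
MEMCPY_ATTR_TO_PLACE = {
    0: "CPUPlace",
    1: "GPUPlace",        # PADDLE_WITH_CUDA / PADDLE_WITH_HIP
    2: "GPUPinnedPlace",  # PADDLE_WITH_CUDA / PADDLE_WITH_HIP
    3: "XPUPlace",        # PADDLE_WITH_XPU
    4: "CustomPlace",     # PADDLE_WITH_CUSTOM_DEVICE
}

def memcpy_output_place(dst_place_type: int) -> str:
    # Mirrors the MemcpyOpAttr2Place.at(...) lookup in BuildOutputs:
    # the memcpy op's output place now comes from its "dst_place_type"
    # attribute instead of the kernel's registered backend.
    return MEMCPY_ATTR_TO_PLACE[dst_place_type]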

paddle/phi/kernels/memcpy_kernel.cc

Lines changed: 3 additions & 3 deletions
@@ -106,12 +106,12 @@ void MemcpyKernel(const Context& dev_ctx,
   PADDLE_ENFORCE_GE(
       dst_place_type,
       0,
-      errors::OutOfRange("dst_place_type only support 0-2, but got: %d",
+      errors::OutOfRange("dst_place_type only support 0-4, but got: %d",
                          dst_place_type));
   PADDLE_ENFORCE_LE(
       dst_place_type,
-      2,
-      errors::OutOfRange("dst_place_type only support 0-2, but got: %d",
+      4,
+      errors::OutOfRange("dst_place_type only support 0-4, but got: %d",
                          dst_place_type));
   switch (dst_place_type) {
     case 0: /* CPUPlace */
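
A minimal Python sketch of the widened range check, for illustration only; the real validation is the PADDLE_ENFORCE_GE/PADDLE_ENFORCE_LE pair above:

def check_dst_place_type(dst_place_type: int) -> None:
    # Mirrors the widened bounds check in MemcpyKernel: codes 0-4
    # are now accepted (previously only 0-2).
    if not (0 <= dst_place_type <= 4):
        raise ValueError(
            f"dst_place_type only support 0-4, but got: {dst_place_type}"
        )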

test/dygraph_to_static/test_tensor_memcpy_on_gpu.py

Lines changed: 28 additions & 0 deletions
@@ -45,6 +45,12 @@ def tensor_copy_to_cuda_with_warning(x, device_id=None, blocking=True):
     return y
 
 
+def tensor_copy_to_cpu_with_compute(x):
+    x = paddle.to_tensor(x)
+    y = x.cpu()
+    return y + 1
+
+
 class TestTensorCopyToCpuOnDefaultGPU(Dy2StTestBase):
     def _run(self):
         x1 = paddle.ones([1, 2, 3])
@@ -124,5 +130,27 @@ def test_with_warning_on_gpu(self):
         self.assertIn('math_op_patch.py', cm.filename)
 
 
+class TestTensorCopyToCPUWithComputeOnDefaultGPU(Dy2StTestBase):
+    def _run(self):
+        x1 = paddle.ones([1, 2, 3])
+        x2 = paddle.jit.to_static(tensor_copy_to_cpu_with_compute)(x1)
+        return x1.place, x2.place, x2.numpy()
+
+    def test_tensor_cpu_with_compute_on_default_gpu(self):
+        if not paddle.is_compiled_with_cuda():
+            return
+        place = paddle.CUDAPlace(int(os.environ.get('FLAGS_selected_gpus', 0)))
+        paddle.framework._set_expected_place(place)
+        with enable_to_static_guard(False):
+            dygraph_x1_place, dygraph_place, dygraph_res = self._run()
+
+        static_x1_place, static_place, static_res = self._run()
+        np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05)
+        self.assertTrue(dygraph_x1_place.is_gpu_place())
+        self.assertTrue(static_x1_place.is_gpu_place())
+        self.assertTrue(dygraph_place.is_cpu_place())
+        self.assertTrue(static_place.is_cpu_place())
+
+
 if __name__ == '__main__':
     unittest.main()
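
To run only the added case (file and class name as in the diff; this assumes a CUDA-enabled build and that the script is launched from test/dygraph_to_static/ so the import resolves), something like:

# Sketch: load and run just the new test class.
import unittest

from test_tensor_memcpy_on_gpu import (
    TestTensorCopyToCPUWithComputeOnDefaultGPU,
)

suite = unittest.TestLoader().loadTestsFromTestCase(
    TestTensorCopyToCPUWithComputeOnDefaultGPU
)
unittest.TextTestRunner(verbosity=2).run(suite)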
