Commit 66349cc

fix_xpu_pp (#71500)
1 parent f743791 commit 66349cc

4 files changed: +17 -1 lines changed

paddle/phi/common/place.cc

Lines changed: 10 additions & 0 deletions
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "glog/logging.h"
 #include "paddle/common/exception.h"
 #include "paddle/phi/backends/gpu/gpu_info.h"
+#include "paddle/phi/backends/xpu/xpu_info.h"
 
 namespace phi {
 
@@ -274,4 +275,13 @@ GPUPlace DefaultGPUPlace() {
 #endif
 }
 
+phi::XPUPlace DefaultXPUPlace() {
+  return phi::XPUPlace(
+#ifdef PADDLE_WITH_XPU
+      phi::backends::xpu::GetXPUCurrentDeviceId());
+#else
+      0);
+#endif
+}
+
 }  // namespace paddle
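
The new paddle::DefaultXPUPlace() mirrors DefaultGPUPlace(): when built with XPU support it wraps the device id currently selected by the XPU runtime, otherwise it degrades to device 0. A minimal caller-side sketch, not taken from this commit (the PickDefaultDevice wrapper and its CPU fallback are illustrative assumptions), of how code built with or without XPU might use the helper:

// Illustrative sketch only: PickDefaultDevice is a hypothetical helper,
// not part of this commit; it assumes paddle/phi/common/place.h is on the
// include path and that PADDLE_WITH_XPU may or may not be defined.
#include "paddle/phi/common/place.h"

phi::Place PickDefaultDevice() {
#ifdef PADDLE_WITH_XPU
  // Uses the commit's new helper: the place of the current XPU device.
  return paddle::DefaultXPUPlace();
#else
  // Without XPU support, fall back to the CPU place.
  return phi::CPUPlace();
#endif
}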

paddle/phi/common/place.h

Lines changed: 2 additions & 0 deletions
@@ -261,4 +261,6 @@ PADDLE_API bool operator==(PlaceType place_type, const Place& place);
 
 PADDLE_API GPUPlace DefaultGPUPlace();
 
+PADDLE_API phi::XPUPlace DefaultXPUPlace();
+
 }  // namespace paddle

paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.cc

Lines changed: 4 additions & 0 deletions
@@ -202,6 +202,10 @@ Place GetDefaultPlace() {
   if (phi::backends::gpu::GetGPUDeviceCount() >= 0) {
     return paddle::DefaultGPUPlace();
   }
+#elif defined(PADDLE_WITH_XPU)
+  if (phi::backends::xpu::GetXPUDeviceCount() >= 0) {
+    return paddle::DefaultXPUPlace();
+  }
+#endif
   return paddle::CPUPlace();
 }
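
With this branch added, GetDefaultPlace() resolves the default device at compile time in the order GPU, then XPU, then CPU. A rough reconstruction of the whole function after the change is below; the opening #if guard is not shown in the diff, so the exact macros on that line are an assumption:

// Sketch of GetDefaultPlace() after the patch. The guard macros on the
// first #if line are assumed (the diff only shows the GPU branch body
// onward); everything else is taken from the hunk above.
Place GetDefaultPlace() {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (phi::backends::gpu::GetGPUDeviceCount() >= 0) {
    return paddle::DefaultGPUPlace();
  }
#elif defined(PADDLE_WITH_XPU)
  if (phi::backends::xpu::GetXPUDeviceCount() >= 0) {
    return paddle::DefaultXPUPlace();
  }
#endif
  return paddle::CPUPlace();
}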

python/paddle/nn/clip.py

Lines changed: 1 addition & 1 deletion
@@ -773,7 +773,7 @@ def async_add_n(var_list):
         global_norm_var = async_add_n(global_norm_var)
         global_norm_var = paddle.sqrt(global_norm_var)
         max_global_norm = paddle.full(
-            shape=[], dtype=sum_dtype, fill_value=self.clip_norm
+            shape=[1], dtype=sum_dtype, fill_value=self.clip_norm
         )
 
         need_clip = False
