
Commit 90db9c6
[AutoParallel]:fix fused_dropout_add bug (#68736)

* [AutoParallel]:fix fused_dropout_add bug
* [AutoParallel]:fix fused_dropout_add bug
1 parent 2b6ce84 commit 90db9c6

File tree

paddle/phi/infermeta/spmd_rules/fused_dropout_add.cc
paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu

2 files changed: +4 -2 lines changed


paddle/phi/infermeta/spmd_rules/fused_dropout_add.cc

Lines changed: 3 additions & 1 deletion

@@ -77,7 +77,9 @@ SpmdInfo FusedDropoutAddSpmd(const DistMetaTensor& x,
                              const std::string& mode,
                              int seed,
                              bool fix_seed) {
-  return FusedDropoutAddSpmdBase(x, y);
+  auto dropout_info = FusedDropoutAddSpmdBase(x, y);
+  dropout_info.first.push_back(seed_tensor.dist_attr());
+  return dropout_info;
 }

 SpmdInfo FusedDropoutAddSpmdReverse(const DistMetaTensor& x,
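
For context, a minimal reading of this hunk: FusedDropoutAddSpmdBase derives distributed attributes only for x and y, while the op also takes a seed tensor input, so the rule now appends that input's dist_attr to keep one entry per operator input. The sketch below illustrates the pattern with hypothetical stand-in types (MockDistAttr, MockSpmdInfo, BaseRule, FusedRule are assumptions for illustration, not Paddle's actual SpmdInfo / DistMetaTensor API).

// Minimal sketch with hypothetical stand-in types; not Paddle's real API.
#include <string>
#include <utility>
#include <vector>

struct MockDistAttr { std::string desc; };

// first: per-input dist attrs, second: per-output dist attrs.
using MockSpmdInfo =
    std::pair<std::vector<MockDistAttr>, std::vector<MockDistAttr>>;

MockSpmdInfo BaseRule(const MockDistAttr& x, const MockDistAttr& y) {
  // Derives attrs for x and y only; it knows nothing about seed_tensor.
  return {{x, y}, {x}};
}

MockSpmdInfo FusedRule(const MockDistAttr& x,
                       const MockDistAttr& y,
                       const MockDistAttr& seed_tensor) {
  auto info = BaseRule(x, y);
  // The fix mirrors this step: append an entry for the extra input so the
  // number of input attrs matches the number of operator inputs.
  info.first.push_back(seed_tensor);
  return info;
}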

paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu

Lines changed: 1 addition & 1 deletion

@@ -263,5 +263,5 @@ PD_REGISTER_KERNEL(fused_dropout_add_grad,
                    double,
                    phi::dtype::bfloat16,
                    phi::dtype::float16) {
-  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset
+  kernel->InputAt(0).SetBackend(phi::Backend::CPU); // seed_offset
 }
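
For context, this hunk pins the grad kernel's seed_offset input (input index 0) to the CPU backend instead of accepting any backend, presumably so the seed/offset values are read from host memory. A rough sketch of the idea with hypothetical stand-ins (Backend, MockKernel, ConfigureFusedDropoutAddGradKernel are illustrative assumptions, not phi's real registration machinery):

#include <map>

// Hypothetical stand-ins; not phi's real kernel-registration API.
enum class Backend { ALL_BACKEND, CPU, GPU };

struct MockKernel {
  // Per-input backend overrides; inputs not listed keep the default.
  std::map<int, Backend> input_backend;
  void SetInputBackend(int index, Backend b) { input_backend[index] = b; }
};

// Registration-time hook: pin seed_offset (input 0) to CPU, while the
// remaining inputs keep the kernel's default placement.
void ConfigureFusedDropoutAddGradKernel(MockKernel* kernel) {
  kernel->SetInputBackend(0, Backend::CPU);  // seed_offset
}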
