Skip to content

Commit 8dba3d0

Browse files
authored
[Cherry-pick] Two detail fix prs (#41728)
* [Eager] Remove elementwise add in conv (#41515) * remove elementwise add in conv * use reshape * fix warpctc grad kernel dep eror (#41598)
1 parent 0053ba8 commit 8dba3d0

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

paddle/phi/kernels/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ kernel_library(triangular_solve_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_re
6161
kernel_library(rnn_kernel DEPS ${COMMON_KERNEL_DEPS} concat_and_split_functor lstm_compute gru_compute)
6262
kernel_library(rnn_grad_kernel DEPS ${COMMON_KERNEL_DEPS} concat_and_split_functor lstm_compute gru_compute)
6363
kernel_library(warpctc_kernel DEPS ${COMMON_KERNEL_DEPS} phi_dynload_warpctc sequence_padding sequence_scale)
64-
kernel_library(warpctc_grad_kernel DEPS ${COMMON_KERNEL_DEPS} sequence_padding sequence_scale)
64+
kernel_library(warpctc_grad_kernel DEPS ${COMMON_KERNEL_DEPS} phi_dynload_warpctc sequence_padding sequence_scale)
6565

6666
# 4. auto parse and build kernel targets by cmake
6767
register_kernels(EXCLUDES ${COMMON_BAISC_KERNELS} ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS} ${COMMON_BAISC_KERNELS} )

python/paddle/nn/functional/conv.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,12 @@ def _conv_nd(x,
127127
x, weight, stride, padding, padding_algorithm, groups, dilation,
128128
data_format, False, -1, False)
129129
if bias is not None:
130-
out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
131-
return out
130+
channel_dim = channel_dim + len(
131+
x.shape) if channel_dim < 0 else channel_dim
132+
tmp_bias = _C_ops.final_state_reshape(
133+
bias, bias.shape +
134+
[1 for i in range(len(x.shape) - channel_dim - 1)])
135+
return _C_ops.final_state_add(pre_bias, tmp_bias)
132136
else:
133137
return pre_bias
134138
if in_dynamic_mode():

0 commit comments

Comments
 (0)