Commit 3fc4ca3

Support kwargs in forward pre hook (#71283)
* support kwargs in forward pre hook
* fix args passing
* move test case to test_imperative_hook_for_layer.py
* add test after remove hook
* let auto parallel pp llama test case use kwargs after global layer
* rename var and delete useless code
* format code
* fix bug
* fix bug
* remove kwargs flag when removing hook
* fix hook id
1 parent c3f8dc3 commit 3fc4ca3

4 files changed: +103 −23 lines
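In short, this commit lets `register_forward_pre_hook` accept a `with_kwargs=True` flag; such a hook receives `(layer, args, kwargs)` and may return a `(new_args, new_kwargs)` pair that replaces the call's positional and keyword inputs. A minimal usage sketch mirroring the new test case below (the layer class and variable names here are illustrative, not part of the change):

```python
import paddle


def forward_pre_hook_with_kwargs(layer, args, kwargs):
    # Modify a keyword input and hand back the (new_args, new_kwargs) pair.
    kwargs['x'] = kwargs['x'] * 2
    return (args, kwargs)


class AddNet(paddle.nn.Layer):
    def forward(self, x, y):
        return x + y


net = AddNet()
handle = net.register_forward_pre_hook(forward_pre_hook_with_kwargs, with_kwargs=True)

x, y = paddle.randn((2, 3)), paddle.randn((2, 3))
out = net(x=x, y=y)   # equals x * 2 + y while the hook is registered
handle.remove()       # removing the hook also clears its kwargs flag
out = net(x=x, y=y)   # back to x + y
```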

python/paddle/distributed/auto_parallel/intermediate/pipeline_parallel.py

Lines changed: 19 additions & 13 deletions
@@ -217,29 +217,35 @@ def forward_post_hook(layer, input, output):
                     "layer output can only be tensor or list/tuple of tensor"
                 )
 
-        def forward_pre_hook(layer, input):
+        def forward_pre_hook(layer, args, kwargs):
             pp_idx = getattr(layer, "pipeline_stage_index", 0)
-            new_input = []
-            for t in input:
+            new_args = []
+            new_kwargs = {}
+
+            def reshard_tensor_args(t):
                 if is_tensor(t) and t.is_dist() and t.process_mesh == g_mesh:
-                    new_input.append(
-                        dist.reshard(
-                            t,
-                            self.get_mesh(pp_idx),
-                            [dist.Replicate(), dist.Replicate()],
-                        )
+                    return dist.reshard(
+                        t,
+                        self.get_mesh(pp_idx),
+                        [dist.Replicate(), dist.Replicate()],
                     )
-                else:
-                    new_input.append(t)
-            return tuple(new_input)
+                return t
+
+            for arg in args:
+                new_args.append(reshard_tensor_args(arg))
+
+            for key, arg in kwargs.items():
+                new_kwargs[key] = reshard_tensor_args(arg)
+
+            return (new_args, new_kwargs)
 
         for layer_name in self.global_spec:
             layer = self.get_layer_by_name(layer_name)
             layer.register_forward_post_hook(forward_post_hook)
 
         for layer_name in self.pipeline_layers:
             layer = self.get_layer_by_name(layer_name)
-            layer.register_forward_pre_hook(forward_pre_hook)
+            layer.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
 
 
 def pipeline_parallel(model, optimizer=None, config=None):
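The pipeline pre-hook above now applies the same reshard transform to every positional and keyword input and returns the `(new_args, new_kwargs)` pair that the kwargs-aware hook protocol expects. A stripped-down sketch of that pattern, with a hypothetical `transform` callable standing in for the `dist.reshard` logic:

```python
def make_kwargs_pre_hook(transform):
    # Build a pre-hook (for use with with_kwargs=True) that applies `transform`
    # to every positional and keyword input and returns (new_args, new_kwargs).
    def forward_pre_hook(layer, args, kwargs):
        new_args = [transform(a) for a in args]
        new_kwargs = {k: transform(v) for k, v in kwargs.items()}
        return (new_args, new_kwargs)

    return forward_pre_hook
```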

python/paddle/nn/layer/layers.py

Lines changed: 46 additions & 9 deletions
@@ -338,17 +338,29 @@ class HookRemoveHelper:
     next_hook_id: int = 0
 
     def __init__(
-        self, hooks: typing.OrderedDict[int, Callable[..., Any]]
+        self,
+        hooks: typing.OrderedDict[int, Callable[..., Any]],
+        *,
+        extra_hook_dict: Any = None,
     ) -> None:
         self._hooks_ref = weakref.ref(hooks)
         self._hook_id = HookRemoveHelper.next_hook_id
         HookRemoveHelper.next_hook_id += 1
 
+        self._extra_hooks_ref = None
+        if extra_hook_dict is not None:
+            self._extra_hooks_ref = weakref.ref(extra_hook_dict)
+
     def remove(self) -> None:
         hooks = self._hooks_ref()
         if hooks is not None and self._hook_id in hooks:
             del hooks[self._hook_id]
 
+        if self._extra_hooks_ref is not None:
+            extra_hooks = self._extra_hooks_ref()
+            if extra_hooks is not None and self._hook_id in extra_hooks:
+                del extra_hooks[self._hook_id]
+
 
 class Layer:
     """
@@ -437,6 +449,9 @@ def __init__(
         self._forward_post_hooks: typing.OrderedDict[int, _ForwardPostHook] = (
             OrderedDict()
         )
+        self._forward_pre_hooks_with_kwargs_flag: typing.OrderedDict[
+            int, bool
+        ] = OrderedDict()
 
         # only used in AMP Training
         self._cast_to_low_precision = True
@@ -696,7 +711,7 @@ def register_forward_post_hook(
         return hook_remove_helper
 
     def register_forward_pre_hook(
-        self, hook: _ForwardPreHook
+        self, hook: _ForwardPreHook, *, with_kwargs: bool = False
    ) -> HookRemoveHelper:
         """
@@ -748,8 +763,15 @@ def register_forward_pre_hook(
             >>> # hook change the linear's input to input * 2, so out0 is equal to out1.
             >>> assert (out0.numpy() == out1.numpy()).any()
         """
-        hook_remove_helper = HookRemoveHelper(self._forward_pre_hooks)
+        hook_remove_helper = HookRemoveHelper(
+            self._forward_pre_hooks,
+            extra_hook_dict=self._forward_pre_hooks_with_kwargs_flag,
+        )
         self._forward_pre_hooks[hook_remove_helper._hook_id] = hook
+        if with_kwargs:
+            self._forward_pre_hooks_with_kwargs_flag[
+                hook_remove_helper._hook_id
+            ] = True
         return hook_remove_helper
 
     def create_parameter(
@@ -1490,12 +1512,27 @@ def _build_once(self, *args: Any, **kwargs: Any) -> None:
         pass
 
     def _dygraph_call_func(self, *inputs: Any, **kwargs: Any) -> Any:
-        for forward_pre_hook in self._forward_pre_hooks.values():
-            hook_result = forward_pre_hook(self, inputs)
-            if hook_result is not None:
-                if not isinstance(hook_result, tuple):
-                    hook_result = (hook_result,)
-                inputs = hook_result
+
+        for hook_id, forward_pre_hook in self._forward_pre_hooks.items():
+            if hook_id in self._forward_pre_hooks_with_kwargs_flag:
+                args_kwargs_result = forward_pre_hook(self, inputs, kwargs)
+                if args_kwargs_result is not None:
+                    if (
+                        isinstance(args_kwargs_result, tuple)
+                        and len(args_kwargs_result) == 2
+                    ):
+                        inputs, kwargs = args_kwargs_result
+                    else:
+                        raise RuntimeError(
+                            "forward pre-hook must return None or a tuple "
+                            f"of (new_args, new_kwargs), but got {args_kwargs_result}."
+                        )
+            else:
+                hook_result = forward_pre_hook(self, inputs)
+                if hook_result is not None:
+                    if not isinstance(hook_result, tuple):
+                        hook_result = (hook_result,)
+                    inputs = hook_result
 
         if not self._built:
             self._build_once(*inputs, **kwargs)
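As the dispatch above shows, a hook registered with `with_kwargs=True` must return either `None` or a two-element `(new_args, new_kwargs)` tuple; any other return value raises a `RuntimeError`. A small sketch of the two accepted styles (hook names and the `'x'` keyword are illustrative, not part of the API):

```python
# Style 1: mutate kwargs in place and return None; the mutated dict is still
# used for the forward call, since no replacement tuple is provided.
def inplace_pre_hook(layer, args, kwargs):
    kwargs['x'] = kwargs['x'] + 1
    return None


# Style 2: return a (new_args, new_kwargs) pair that replaces both input groups.
def rebuild_pre_hook(layer, args, kwargs):
    new_args = tuple(a * 2 for a in args)
    new_kwargs = {k: v * 2 for k, v in kwargs.items()}
    return (new_args, new_kwargs)
```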

test/auto_parallel/hybrid_strategy/single_llama_model.py

Lines changed: 3 additions & 1 deletion
@@ -205,7 +205,9 @@ def forward(self, input_ids):
         global_tensor = self.global_layer(None)
 
         for idx, (decoder_layer) in enumerate(self.layers):
-            hidden_states = decoder_layer(hidden_states, global_tensor)
+            hidden_states = decoder_layer(
+                hidden_states=hidden_states, global_tensor=global_tensor
+            )
 
         hidden_states = self.norm(hidden_states)
 

test/legacy_test/test_imperative_hook_for_layer.py

Lines changed: 35 additions & 0 deletions
@@ -224,5 +224,40 @@ def test_forward_hook(self):
                 self.assertFalse(call_forward_pre_hook)
 
 
+def forward_pre_hook_with_kwargs(layer, args, kwargs):
+    kwargs['x'] = kwargs['x'] * 2
+    return (args, kwargs)
+
+
+class SimpleNetWithKWArgs(paddle.nn.Layer):
+    def __init__(
+        self,
+    ):
+        super().__init__()
+
+    def forward(self, x, y):
+        z = x + y
+
+        return z
+
+
+class TestHookWithKWArgs(unittest.TestCase):
+    def test_kwargs_hook(self):
+        net = SimpleNetWithKWArgs()
+        remove_handler = net.register_forward_pre_hook(
+            forward_pre_hook_with_kwargs, with_kwargs=True
+        )
+
+        x = paddle.randn((2, 3))
+        y = paddle.randn((2, 3))
+
+        out = net(x=x, y=y)
+        np.testing.assert_allclose(out.numpy(), (x * 2 + y).numpy())
+
+        remove_handler.remove()
+        out = net(x=x, y=y)
+        np.testing.assert_allclose(out.numpy(), (x + y).numpy())
+
+
 if __name__ == '__main__':
     unittest.main()
