Commit 6e8992d (parent: 963e364)

Reduce the performance influence of record events in Python (PaddlePaddle#42040)

* optimize performance
* fix
* improve coverage
* fix
* fix

File tree: 5 files changed, +80 −20 lines


python/paddle/fluid/dataloader/dataloader_iter.py (15 additions, 10 deletions)
@@ -31,6 +31,7 @@
 
 import paddle
 import paddle.profiler as profiler
+from paddle.profiler.utils import in_profiler_mode
 from .. import core, layers
 from ..framework import _non_static_mode, in_dygraph_mode, _in_legacy_dygraph
 from ..multiprocess_utils import _set_SIGCHLD_handler, MP_STATUS_CHECK_INTERVAL, CleanupFuncRegistrar
@@ -252,10 +253,11 @@ def _thread_loop(self, legacy_expected_place):
         self._exit_thread_expectedly()
 
     def __next__(self):
-        trace_event = profiler.RecordEvent(
-            name="_DataLoaderIterSingleProcess",
-            event_type=profiler.TracerEventType.Dataloader)
-        trace_event.begin()
+        if in_profiler_mode():
+            trace_event = profiler.RecordEvent(
+                name="_DataLoaderIterSingleProcess",
+                event_type=profiler.TracerEventType.Dataloader)
+            trace_event.begin()
         try:
             benchmark().check_if_need_record(self)
             benchmark().before_reader()
@@ -294,7 +296,8 @@ def __next__(self):
             self._try_shutdown_all()
             six.reraise(*sys.exc_info())
         finally:
-            trace_event.end()
+            if in_profiler_mode():
+                trace_event.end()
 
     def _shutdown_thread(self):
         if self._thread:
@@ -708,10 +711,11 @@ def _shutdown_on_exit(self):
         self._try_shutdown_all(1)
 
     def __next__(self):
-        trace_event = profiler.RecordEvent(
-            name="_DataLoaderIterMultiProcess",
-            event_type=profiler.TracerEventType.Dataloader)
-        trace_event.begin()
+        if in_profiler_mode():
+            trace_event = profiler.RecordEvent(
+                name="_DataLoaderIterMultiProcess",
+                event_type=profiler.TracerEventType.Dataloader)
+            trace_event.begin()
         try:
             benchmark().check_if_need_record(self)
             benchmark().before_reader()
@@ -765,7 +769,8 @@ def __next__(self):
             self._try_shutdown_all()
             six.reraise(*sys.exc_info())
         finally:
-            trace_event.end()
+            if in_profiler_mode():
+                trace_event.end()
 
     # python2 compatibility
     def next(self):
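The change above makes the Dataloader RecordEvent conditional: when no profiler is running, `__next__` no longer constructs and begins a RecordEvent on every batch. A minimal sketch of the same guard pattern, using only the `paddle.profiler` calls shown in this diff; the function name `fetch_batch` and its argument are hypothetical:

import paddle.profiler as profiler
from paddle.profiler.utils import in_profiler_mode

def fetch_batch(reader):
    # Pay the RecordEvent construction cost only while a Profiler is active.
    if in_profiler_mode():
        trace_event = profiler.RecordEvent(
            name="fetch_batch",
            event_type=profiler.TracerEventType.Dataloader)
        trace_event.begin()
    try:
        return next(reader)
    finally:
        # Mirrors the patch: end() is guarded by the same check.
        if in_profiler_mode():
            trace_event.end()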

python/paddle/fluid/dygraph/layers.py (7 additions, 3 deletions)
@@ -26,6 +26,7 @@
 
 import paddle
 import paddle.profiler as profiler
+from paddle.profiler.utils import in_profiler_mode
 
 from . import parallel_helper
 from .. import unique_name
@@ -906,8 +907,11 @@ def _dygraph_call_func(self, *inputs, **kwargs):
 
         self._built = True
 
-        with profiler.RecordEvent(self.full_name(),
-                                  profiler.TracerEventType.Forward):
+        if in_profiler_mode():
+            with profiler.RecordEvent(self.full_name(),
+                                      profiler.TracerEventType.Forward):
+                outputs = self.forward(*inputs, **kwargs)
+        else:
             outputs = self.forward(*inputs, **kwargs)
 
         for forward_post_hook in self._forward_post_hooks.values():
@@ -919,7 +923,7 @@ def _dygraph_call_func(self, *inputs, **kwargs):
 
     def __call__(self, *inputs, **kwargs):
         if (not in_declarative_mode()) and (not self._forward_pre_hooks) \
-                and (not self._forward_post_hooks) and (not self._built) and in_dygraph_mode():
+                and (not self._forward_post_hooks) and (not self._built) and in_dygraph_mode() and (not in_profiler_mode()):
             self._build_once(*inputs, **kwargs)
             return self.forward(*inputs, **kwargs)
         else:
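Two changes cooperate here: `_dygraph_call_func` creates the Forward RecordEvent only in profiler mode, and `__call__` adds `not in_profiler_mode()` to its fast-path condition, so an active profiler routes calls through `_dygraph_call_func`, where the event is actually recorded. A minimal sketch of the observable behavior, assuming the `paddle.profiler` API shown in this commit; the layer and shapes are arbitrary:

import paddle
import paddle.profiler as profiler

linear = paddle.nn.Linear(4, 4)
x = paddle.randn([2, 4])

y = linear(x)  # no profiler running: no Forward RecordEvent is created

prof = profiler.Profiler(on_trace_ready=lambda prof: None)
prof.start()
y = linear(x)  # profiler active: forward() runs inside a Forward RecordEvent
prof.step()
prof.stop()
prof.summary()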

python/paddle/fluid/dygraph/varbase_patch_methods.py (7 additions, 4 deletions)
@@ -30,6 +30,7 @@
 from paddle.fluid.data_feeder import convert_dtype, _PADDLE_DTYPE_2_NUMPY_DTYPE
 import paddle.utils.deprecated as deprecated
 import paddle.profiler as profiler
+from paddle.profiler.utils import in_profiler_mode
 from paddle import _C_ops
 
 _grad_scalar = None
@@ -247,9 +248,10 @@ def backward(self, grad_tensor=None, retain_graph=False):
 
         """
         if framework._non_static_mode():
-            record_event = profiler.RecordEvent(
-                "Gradient Backward", profiler.TracerEventType.Backward)
-            record_event.begin()
+            if in_profiler_mode():
+                record_event = profiler.RecordEvent(
+                    "Gradient Backward", profiler.TracerEventType.Backward)
+                record_event.begin()
             if grad_tensor is not None:
                 if framework._in_eager_mode_:
                     assert isinstance(
@@ -288,7 +290,8 @@ def backward(self, grad_tensor=None, retain_graph=False):
                 core.dygraph_run_backward([self], [grad_tensor],
                                           retain_graph,
                                           framework._dygraph_tracer())
-            record_event.end()
+            if in_profiler_mode():
+                record_event.end()
         else:
             raise ValueError(
                 "Variable.backward() is only available in DyGraph mode")

python/paddle/fluid/tests/unittests/test_newprofiler.py (36 additions, 0 deletions)
@@ -134,6 +134,42 @@ def my_sheduler1(num_step):
         prof.export(path='./test_profiler_pb.pb', format='pb')
         prof.summary()
         result = profiler.utils.load_profiler_result('./test_profiler_pb.pb')
+        prof = None
+        dataset = RandomDataset(10 * 4)
+        simple_net = SimpleNet()
+        opt = paddle.optimizer.SGD(learning_rate=1e-3,
+                                   parameters=simple_net.parameters())
+        loader = DataLoader(
+            dataset, batch_size=4, shuffle=True, drop_last=True, num_workers=2)
+        prof = profiler.Profiler(on_trace_ready=lambda prof: None)
+        prof.start()
+        for i, (image, label) in enumerate(loader()):
+            out = simple_net(image)
+            loss = F.cross_entropy(out, label)
+            avg_loss = paddle.mean(loss)
+            avg_loss.backward()
+            opt.minimize(avg_loss)
+            simple_net.clear_gradients()
+            prof.step()
+        prof.stop()
+        prof.summary()
+        prof = None
+        dataset = RandomDataset(10 * 4)
+        simple_net = SimpleNet()
+        loader = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True)
+        opt = paddle.optimizer.Adam(
+            learning_rate=1e-3, parameters=simple_net.parameters())
+        prof = profiler.Profiler(on_trace_ready=lambda prof: None)
+        prof.start()
+        for i, (image, label) in enumerate(loader()):
+            out = simple_net(image)
+            loss = F.cross_entropy(out, label)
+            avg_loss = paddle.mean(loss)
+            avg_loss.backward()
+            opt.step()
+            simple_net.clear_gradients()
+            prof.step()
+        prof.stop()
 
 
 class TestNvprof(unittest.TestCase):
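The two added training loops target the coverage gap the guards introduce: the first uses a DataLoader with num_workers=2 and `opt.minimize`, driving the guarded `_DataLoaderIterMultiProcess.__next__`, while the second uses the single-process iterator and the wrapped `Adam.step`, so both RecordEvent paths run with an active profiler.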

python/paddle/profiler/utils.py (15 additions, 3 deletions)
@@ -21,6 +21,7 @@
 from paddle.fluid.core import (_RecordEvent, TracerEventType)
 
 _is_profiler_used = False
+_has_optimizer_wrapped = False
 
 _AllowedEventTypeList = [
     TracerEventType.Dataloader, TracerEventType.ProfileStep,
@@ -154,20 +155,31 @@ def load_profiler_result(filename: str):
     return core.load_profiler_result(filename)
 
 
+def in_profiler_mode():
+    return _is_profiler_used == True
+
+
 def wrap_optimizers():
     def optimizer_warpper(func):
         @functools.wraps(func)
         def warpper(*args, **kwargs):
-            with RecordEvent(
-                    'Optimization Step',
-                    event_type=TracerEventType.Optimization):
+            if in_profiler_mode():
+                with RecordEvent(
+                        'Optimization Step',
+                        event_type=TracerEventType.Optimization):
+                    return func(*args, **kwargs)
+            else:
                 return func(*args, **kwargs)
 
         return warpper
 
+    global _has_optimizer_wrapped
+    if _has_optimizer_wrapped == True:
+        return
     import paddle.optimizer as optimizer
     for classname in optimizer.__all__:
         if classname != 'Optimizer':
             classobject = getattr(optimizer, classname)
             if getattr(classobject, 'step', None) != None:
                 classobject.step = optimizer_warpper(classobject.step)
+    _has_optimizer_wrapped = True
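`in_profiler_mode()` exposes the existing `_is_profiler_used` flag, and the new `_has_optimizer_wrapped` flag makes `wrap_optimizers()` idempotent, so repeated profiler sessions do not stack `optimizer_warpper` layers on each optimizer's `step`. A sketch of the resulting guarantee; the double call below is purely illustrative:

from paddle.profiler.utils import wrap_optimizers, in_profiler_mode

wrap_optimizers()  # patches step() on every class in paddle.optimizer.__all__
wrap_optimizers()  # returns early: _has_optimizer_wrapped is already True

# Outside a profiler session the wrapper adds no RecordEvent:
assert in_profiler_mode() is False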
