Skip to content

Commit 258b9ca

Browse files
committed
[Feature]Add a NPU Hook
1 parent a483dba commit 258b9ca

File tree

6 files changed

+56
-8
lines changed

6 files changed

+56
-8
lines changed

mmengine/device/utils.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
import os
32
from typing import Optional
43

54
import torch
@@ -8,10 +7,6 @@
87
import torch_npu # noqa: F401
98
import torch_npu.npu.utils as npu_utils
109

11-
# Enable operator support for dynamic shape and
12-
# binary operator support on the NPU.
13-
npu_jit_compile = bool(os.getenv('NPUJITCompile', False))
14-
torch.npu.set_compile_mode(jit_compile=npu_jit_compile)
1510
IS_NPU_AVAILABLE = hasattr(torch, 'npu') and torch.npu.is_available()
1611
except Exception:
1712
IS_NPU_AVAILABLE = False

mmengine/hooks/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .iter_timer_hook import IterTimerHook
88
from .logger_hook import LoggerHook
99
from .naive_visualization_hook import NaiveVisualizationHook
10+
from .npu_hook import NPUHook
1011
from .param_scheduler_hook import ParamSchedulerHook
1112
from .profiler_hook import NPUProfilerHook, ProfilerHook
1213
from .runtime_info_hook import RuntimeInfoHook
@@ -18,5 +19,5 @@
1819
'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
1920
'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook', 'LoggerHook',
2021
'NaiveVisualizationHook', 'EMAHook', 'RuntimeInfoHook', 'ProfilerHook',
21-
'PrepareTTAHook', 'NPUProfilerHook', 'EarlyStoppingHook'
22+
'PrepareTTAHook', 'NPUProfilerHook', 'EarlyStoppingHook', 'NPUHook'
2223
]

mmengine/hooks/npu_hook.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import os
3+
4+
import torch
5+
6+
from mmengine.registry import HOOKS
7+
from mmengine.utils import try_import
8+
from .hook import Hook
9+
10+
torch_npu = try_import('torch_npu')
11+
12+
13+
@HOOKS.register_module()
14+
class NPUHook(Hook):
15+
"""A hook specific for NPU device."""
16+
17+
priority = 'HIGHEST'
18+
19+
def __init__(self) -> None:
20+
super().__init__()
21+
if torch_npu is None:
22+
raise ImportError(f'For availability of {self.__class__.__name__},'
23+
'pleasse install ascend pytorch first.')
24+
25+
def before_run(self, runner) -> None:
26+
27+
# Enable operator support for dynamic shape and
28+
# binary operator support on the NPU.
29+
npu_jit_compile = bool(os.getenv('NPUJITCompile', False))
30+
torch.npu.set_compile_mode(jit_compile=npu_jit_compile)

mmengine/runner/runner.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,10 @@ def __init__(
423423
self._hooks: List[Hook] = []
424424
# register hooks to `self._hooks`
425425
self.register_hooks(default_hooks, custom_hooks)
426+
427+
if get_device() == 'npu':
428+
self.register_hook({'type': 'NPUHook'})
429+
426430
# log hooks information
427431
self.logger.info(f'Hooks will be executed in the following '
428432
f'order:\n{self.get_hooks_info()}')

mmengine/utils/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
is_method_overridden, is_seq_of, is_str, is_tuple_of,
88
iter_cast, list_cast, requires_executable, requires_package,
99
slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple,
10-
to_ntuple, tuple_cast)
10+
to_ntuple, try_import, tuple_cast)
1111
from .package_utils import (call_command, get_installed_path, install_package,
1212
is_installed)
1313
from .path import (check_file_exist, fopen, is_abs, is_filepath,
@@ -29,5 +29,5 @@
2929
'get_git_hash', 'ManagerMeta', 'ManagerMixin', 'Timer', 'check_time',
3030
'TimerError', 'ProgressBar', 'track_iter_progress',
3131
'track_parallel_progress', 'track_progress', 'deprecated_function',
32-
'apply_to', 'get_object_from_string'
32+
'apply_to', 'get_object_from_string', 'try_import'
3333
]

mmengine/utils/misc.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from importlib import import_module
1212
from inspect import getfullargspec, ismodule
1313
from itertools import repeat
14+
from types import ModuleType
1415
from typing import Any, Callable, Optional, Type, Union
1516

1617

@@ -540,3 +541,20 @@ def get_object_from_string(obj_name: str):
540541
return obj_cls
541542
except AttributeError:
542543
return None
544+
545+
546+
def try_import(name: str) -> Optional[ModuleType]:
547+
"""Try to import a module.
548+
549+
Args:
550+
name (str): Specifies what module to import in absolute or relative
551+
terms (e.g. either pkg.mod or ..mod).
552+
553+
Returns:
554+
ModuleType or None: If importing successfully, returns the imported
555+
module, otherwise returns None.
556+
"""
557+
try:
558+
return import_module(name)
559+
except ImportError:
560+
return None

0 commit comments

Comments
 (0)