Skip to content

Commit 486a1d2

Browse files
zhouzaidaxuuyangg
authored andcommitted
[Fix]fix multiprocessing fails to launch on Ascend NPU
1 parent 488fddc commit 486a1d2

File tree

4 files changed

+9
-5
lines changed

4 files changed

+9
-5
lines changed

.github/workflows/merge_stage_test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ on:
1212
- "CONTRIBUTING_zh-CN.md"
1313
- ".pre-commit-config.yaml"
1414
- ".pre-commit-config-zh-cn.yaml"
15+
- "examples/**"
1516
branches:
1617
- main
1718

.github/workflows/pr_stage_test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ on:
1212
- "CONTRIBUTING_zh-CN.md"
1313
- ".pre-commit-config.yaml"
1414
- ".pre-commit-config-zh-cn.yaml"
15+
- "examples/**"
1516

1617
concurrency:
1718
group: ${{ github.workflow }}-${{ github.ref }}

mmengine/device/utils.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
import os
32
from typing import Optional
43

54
import torch
@@ -8,10 +7,6 @@
87
import torch_npu # noqa: F401
98
import torch_npu.npu.utils as npu_utils
109

11-
# Enable operator support for dynamic shape and
12-
# binary operator support on the NPU.
13-
npu_jit_compile = bool(os.getenv('NPUJITCompile', False))
14-
torch.npu.set_compile_mode(jit_compile=npu_jit_compile)
1510
IS_NPU_AVAILABLE = hasattr(torch, 'npu') and torch.npu.is_available()
1611
except Exception:
1712
IS_NPU_AVAILABLE = False

mmengine/model/base_model/base_model.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
import os
23
from abc import abstractmethod
34
from collections import OrderedDict
45
from typing import Dict, Optional, Tuple, Union
@@ -267,6 +268,12 @@ def _set_device(self, device: torch.device) -> None:
267268
buffers in this module.
268269
"""
269270

271+
if device.type == 'npu':
272+
# Enable operator support for dynamic shape and
273+
# binary operator support on the NPU.
274+
npu_jit_compile = bool(os.getenv('NPUJITCompile', False))
275+
torch.npu.set_compile_mode(jit_compile=npu_jit_compile)
276+
270277
def apply_fn(module):
271278
if not isinstance(module, BaseDataPreprocessor):
272279
return

0 commit comments

Comments
 (0)