Skip to content

Commit 06c1a45

Browse files
Fix device map in limited mem (#1190)
1 parent 1be2d72 commit 06c1a45

File tree

4 files changed

+7
-3
lines changed

4 files changed

+7
-3
lines changed

.dev_scripts/dockerci.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ MODELSCOPE_CACHE_DIR_IN_CONTAINER=/modelscope_cache
33
CODE_DIR=$PWD
44
CODE_DIR_IN_CONTAINER=/swift
55
echo "$USER"
6-
gpus='0,1 2,3 4,5 6,7'
7-
cpu_sets='0-15 16-31 32-47 48-63'
6+
gpus='0,1 2,3'
7+
cpu_sets='0-15 16-31'
88
cpu_sets_arr=($cpu_sets)
99
is_get_file_lock=false
1010
CI_COMMAND=${CI_COMMAND:-bash .dev_scripts/ci_container_test.sh python tests/run.py --parallel 2 --run_config tests/run_config.yaml}

swift/llm/infer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def prepare_model_template(args: InferArguments,
140140
else:
141141
logger.info(f'device_count: {torch.cuda.device_count()}')
142142
if device_map is None:
143-
device_map = 'auto'
143+
device_map = 'auto' if torch.cuda.device_count() > 1 else 'cuda:0'
144144
if device_map == 'auto':
145145
model_kwargs['low_cpu_mem_usage'] = True
146146
model_kwargs['device_map'] = device_map

swift/llm/rlhf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
4949
model_kwargs = {'low_cpu_mem_usage': True}
5050
if is_dist() and not is_ddp_plus_mp():
5151
model_kwargs['device_map'] = {'': local_rank}
52+
elif torch.cuda.device_count() == 1:
53+
model_kwargs['device_map'] = 'cuda:0'
5254
else:
5355
model_kwargs['device_map'] = 'auto'
5456

swift/llm/sft.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
5656
model_kwargs = {'low_cpu_mem_usage': True}
5757
if is_dist() and not is_ddp_plus_mp():
5858
model_kwargs['device_map'] = {'': local_rank}
59+
elif torch.cuda.device_count() == 1:
60+
model_kwargs['device_map'] = 'cuda:0'
5961
elif not use_torchacc():
6062
model_kwargs['device_map'] = 'auto'
6163

0 commit comments

Comments (0)