Skip to content

Commit 6c8f611

Browse files
authored
fix argument (modelscope#1066)
1 parent 783c243 commit 6c8f611

File tree

3 files changed

+6 -2 lines changed

3 files changed

+6 -2 lines changed

swift/llm/sft.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,7 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
5454
config_path = args.device_map_config_path if os.path.isabs(args.device_map_config_path) else os.path.join(
5555
cwd, args.device_map_config_path)
5656
with open(config_path, 'r') as json_file:
57-
model_kwargs['device_map'] = json.load(json_file)
57+
model_kwargs = {'device_map': json.load(json_file)}
5858
else:
5959
model_kwargs = {'low_cpu_mem_usage': True}
6060
if is_dist() and not is_ddp_plus_mp():

swift/llm/utils/argument.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -152,6 +152,10 @@ def select_bnb(self: Union['SftArguments', 'InferArguments']) -> Tuple[Optional[
152152
elif quantization_bit == 8:
153153
require_version('bitsandbytes')
154154
load_in_4bit, load_in_8bit = False, True
155+
else:
156+
logger.warning('bnb only support 4/8 bits quantization, you should assign --quantization_bit 4 or 8,\
157+
Or specify another quantization method; No quantization will be performed here.')
158+
load_in_4bit, load_in_8bit = False, False
155159
else:
156160
load_in_4bit, load_in_8bit = False, False
157161

swift/llm/utils/model.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@
1414
import transformers
1515
from modelscope import (AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,
1616
GenerationConfig, GPTQConfig, snapshot_download)
17-
from modelscope.hub.utils.utils import get_cache_dir
1817
from packaging import version
1918
from torch import Tensor
2019
from torch import dtype as Dtype
@@ -25,6 +24,7 @@
2524
from transformers.utils.versions import require_version
2625

2726
from swift import get_logger
27+
from swift.hub.utils.utils import get_cache_dir
2828
from swift.utils import get_dist_setting, safe_ddp_context, subprocess_run, use_torchacc
2929
from .template import TemplateType
3030
from .utils import get_max_model_len, is_unsloth_available

0 commit comments

Comments (0)