Skip to content

Commit 4449d59

Browse files
committed
Update [release]
1 parent 6e4ce4a commit 4449d59

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

src/trainers/_train_hf_base.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import TYPE_CHECKING, Any, Type, cast
88

99
import torch
10+
1011
from datasets.fingerprint import Hasher
1112

1213
from .. import DataDreamer
@@ -489,9 +490,9 @@ def export_to_disk(self, path: str, adapter_only: bool = False) -> PreTrainedMod
489490
from .train_hf_finetune import TrainHFFineTune
490491
from .train_setfit_classifier import TrainSetFitClassifier
491492

492-
assert (
493-
not adapter_only or self.peft_config
494-
), "`adapter_only` can only be used if a `peft_config` was provided."
493+
assert not adapter_only or self.peft_config, (
494+
"`adapter_only` can only be used if a `peft_config` was provided."
495+
)
495496

496497
# Clear the directory
497498
clear_dir(path)
@@ -624,9 +625,9 @@ def publish_to_hf_hub( # noqa: C901
624625
from .train_hf_finetune import TrainHFFineTune
625626
from .train_setfit_classifier import TrainSetFitClassifier
626627

627-
assert (
628-
not adapter_only or self.peft_config
629-
), "`adapter_only` can only be used if a `peft_config` was provided."
628+
assert not adapter_only or self.peft_config, (
629+
"`adapter_only` can only be used if a `peft_config` was provided."
630+
)
630631

631632
# Login
632633
api = hf_hub_login(token=token)
@@ -759,9 +760,9 @@ def publish_to_hf_hub( # noqa: C901
759760
f""" The training arguments can be found [here](training_args.json)."""
760761
)
761762
if os.path.exists(self.model_name) and os.path.isdir(self.model_name):
762-
readme_contents = readme_contents.replace("{base_model}", self.model_name)
763-
else:
764763
readme_contents = readme_contents.replace("\nbase_model: {base_model}", "")
764+
else:
765+
readme_contents = readme_contents.replace("{base_model}", self.model_name)
765766
tags = tags + model_names
766767
tags = (
767768
tags

0 commit comments

Comments
 (0)