Skip to content

Commit fea0c43

Browse files
committed
Merge commit 'refs/pull/56/head' of github.com:eth-easl/Scratchpad into dev
2 parents ad11191 + 4381a45 commit fea0c43

File tree

1 file changed

+6
-28
lines changed

1 file changed

+6
-28
lines changed

scratchpad/nn/models/swissai/config.py

Lines changed: 6 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
# copied from https://github.com/swiss-ai/transformers/blob/1e0881a41d4fda838ece30f730130ddf10ba0913/src/transformers/models/swissai/configuration_swissai.py
2-
from transformers import PretrainedConfig, AutoConfig
3-
2+
from transformers import PretrainedConfig
3+
from transformers.modeling_rope_utils import rope_config_validation
44

55
class SwissAIConfig(PretrainedConfig):
66
r"""
@@ -128,34 +128,12 @@ def __init__(
128128
self.use_cache = use_cache
129129
self.rope_theta = rope_theta
130130
self.rope_scaling = rope_scaling
131-
self._rope_scaling_validation()
131+
if self.rope_scaling is not None and "type" in self.rope_scaling:
132+
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
133+
rope_config_validation(self)
132134
self.attention_bias = attention_bias
133135
self.attention_dropout = attention_dropout
134136

135137
self.rms_norm_eps = rms_norm_eps
136138

137-
def _rope_scaling_validation(self):
138-
"""
139-
Validate the `rope_scaling` configuration.
140-
"""
141-
if self.rope_scaling is None:
142-
return
143-
144-
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
145-
raise ValueError(
146-
f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}"
147-
)
148-
rope_scaling_type = self.rope_scaling.get("type", None)
149-
rope_scaling_factor = self.rope_scaling.get("factor", None)
150-
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
151-
raise ValueError(
152-
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
153-
)
154-
if (
155-
rope_scaling_factor is None
156-
or not isinstance(rope_scaling_factor, float)
157-
or rope_scaling_factor <= 1.0
158-
):
159-
raise ValueError(
160-
f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}"
161-
)
139+
__all__ = ["SwissAIConfig"]

0 commit comments

Comments
 (0)