|
1 | 1 | # copied from https://github.com/swiss-ai/transformers/blob/1e0881a41d4fda838ece30f730130ddf10ba0913/src/transformers/models/swissai/configuration_swissai.py |
2 | | -from transformers import PretrainedConfig, AutoConfig |
3 | | - |
| 2 | +from transformers import PretrainedConfig |
| 3 | +from transformers.modeling_rope_utils import rope_config_validation |
4 | 4 |
|
5 | 5 | class SwissAIConfig(PretrainedConfig): |
6 | 6 | r""" |
@@ -128,34 +128,12 @@ def __init__( |
128 | 128 | self.use_cache = use_cache |
129 | 129 | self.rope_theta = rope_theta |
130 | 130 | self.rope_scaling = rope_scaling |
131 | | - self._rope_scaling_validation() |
| 131 | + if self.rope_scaling is not None and "type" in self.rope_scaling: |
| 132 | + self.rope_scaling["rope_type"] = self.rope_scaling["type"] |
| 133 | + rope_config_validation(self) |
132 | 134 | self.attention_bias = attention_bias |
133 | 135 | self.attention_dropout = attention_dropout |
134 | 136 |
|
135 | 137 | self.rms_norm_eps = rms_norm_eps |
136 | 138 |
|
137 | | - def _rope_scaling_validation(self): |
138 | | - """ |
139 | | - Validate the `rope_scaling` configuration. |
140 | | - """ |
141 | | - if self.rope_scaling is None: |
142 | | - return |
143 | | - |
144 | | - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: |
145 | | - raise ValueError( |
146 | | - f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}" |
147 | | - ) |
148 | | - rope_scaling_type = self.rope_scaling.get("type", None) |
149 | | - rope_scaling_factor = self.rope_scaling.get("factor", None) |
150 | | - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: |
151 | | - raise ValueError( |
152 | | - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" |
153 | | - ) |
154 | | - if ( |
155 | | - rope_scaling_factor is None |
156 | | - or not isinstance(rope_scaling_factor, float) |
157 | | - or rope_scaling_factor <= 1.0 |
158 | | - ): |
159 | | - raise ValueError( |
160 | | - f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}" |
161 | | - ) |
| 139 | +__all__ = ["SwissAIConfig"] |
0 commit comments