@@ -19,6 +19,7 @@
 from typing import TYPE_CHECKING, Literal, Optional, Union
 
 from accelerate.utils import is_deepspeed_available
+from packaging import version
 from transformers import PreTrainedModel, PreTrainedTokenizer
 
 from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
@@ -141,6 +142,8 @@ def remove_hooks(model: "DeepSpeedEngine") -> None:
         optimizer_offload = model.optimizer.parameter_offload
     elif model.optimizer is not None:
         optimizer_offload = model.optimizer
+    else:
+        raise RuntimeError("The model optimizer is None, which is not yet supported.")
 
     for param in iter_params(optimizer_offload.module, recurse=True):
         param.ds_active_sub_modules.clear()
@@ -170,7 +173,13 @@ def add_hooks(model: "DeepSpeedEngine") -> None:
         optimizer_offload = model.optimizer.parameter_offload
     elif model.optimizer is not None:
         optimizer_offload = model.optimizer
-    optimizer_offload._register_hooks_recursively(optimizer_offload.module)
+    else:
+        raise RuntimeError("The model optimizer is None, which is not yet supported.")
+    if version.parse(deepspeed.__version__) >= version.parse("0.16.4"):
+        # Account for renaming in https://github.com/deepspeedai/DeepSpeed/pull/6847
+        optimizer_offload._register_deepspeed_module(optimizer_offload.module)
+    else:
+        optimizer_offload._register_hooks_recursively(optimizer_offload.module)
 
 
 @contextmanager
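Note on the new `packaging` import: the version gate relies on `version.parse`, which implements PEP 440 ordering, so the `>= 0.16.4` check stays correct for pre-releases and multi-digit components where naive string comparison would fail. A minimal standalone illustration (not part of the patch; the version strings are arbitrary examples):

```python
from packaging import version

# PEP 440 ordering: numeric components compare numerically,
# so "0.16.4" is newer than "0.9.0" even though it sorts lower as a string.
assert version.parse("0.16.4") > version.parse("0.9.0")
assert "0.16.4" < "0.9.0"  # naive string comparison gets this wrong

# Pre-releases sort before the corresponding final release.
assert version.parse("0.16.4rc1") < version.parse("0.16.4")
```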