Skip to content

Commit e3244d2

Browse files
authored
🚀 Supporting deepspeed>=0.16.4's rename (#2963)
* Added else clause to avoid NameError on optimizer_offload
* Accounted for deepspeed's renaming in 0.16.4
* Switched to packaging.version.parse over the (broken) tuple split
* Moved from NotImplementedError to RuntimeError in else clause
1 parent 6a02c69 commit e3244d2

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

trl/models/utils.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
from typing import TYPE_CHECKING, Literal, Optional, Union
2020

2121
from accelerate.utils import is_deepspeed_available
22+
from packaging import version
2223
from transformers import PreTrainedModel, PreTrainedTokenizer
2324

2425
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
@@ -141,6 +142,8 @@ def remove_hooks(model: "DeepSpeedEngine") -> None:
141142
optimizer_offload = model.optimizer.parameter_offload
142143
elif model.optimizer is not None:
143144
optimizer_offload = model.optimizer
145+
else:
146+
raise RuntimeError("The model optimizer is None, which is not yet supported.")
144147

145148
for param in iter_params(optimizer_offload.module, recurse=True):
146149
param.ds_active_sub_modules.clear()
@@ -170,7 +173,13 @@ def add_hooks(model: "DeepSpeedEngine") -> None:
170173
optimizer_offload = model.optimizer.parameter_offload
171174
elif model.optimizer is not None:
172175
optimizer_offload = model.optimizer
173-
optimizer_offload._register_hooks_recursively(optimizer_offload.module)
176+
else:
177+
raise RuntimeError("The model optimizer is None, which is not yet supported.")
178+
if version.parse(deepspeed.__version__) >= version.parse("0.16.4"):
179+
# Account for renaming in https://github.com/deepspeedai/DeepSpeed/pull/6847
180+
optimizer_offload._register_deepspeed_module(optimizer_offload.module)
181+
else:
182+
optimizer_offload._register_hooks_recursively(optimizer_offload.module)
174183

175184

176185
@contextmanager

0 commit comments

Comments
 (0)