@@ -1,19 +1,27 @@
 import torch
 import glob
 import contextlib
-from typing import List, Generator, Tuple, Type
+from typing import List, Generator, Tuple, Type, Protocol
 from tqdm import tqdm
 import json
 import os
-from scratchpad.utils import snapshot_download, get_lock, DisabledTqdm
+from scratchpad.utils import (
+    snapshot_download,
+    get_lock,
+    DisabledTqdm,
+    is_pin_memory_available,
+)
 from safetensors.torch import safe_open
 from scratchpad.nn.models import ModelRegistry
 from scratchpad.config import ModelConfig, LoadConfig
 from scratchpad.nn.quantization import get_quantization_config, QuantizationConfig
 import huggingface_hub
 from torch import nn
+from torch.func import functional_call
 
 _BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
+_CPU_OFFLOAD_BYTES = 0
+_CPU_OFFLOAD_MAX_BYTES = 0
 
 
 @contextlib.contextmanager
@@ -208,3 +216,104 @@ def get_quant_config(
             )
 
     return quant_cls.from_config(config)
+
+
+def set_cpu_offload_max_bytes(max_bytes: int) -> None:
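+    """Set the CPU offload budget to `max_bytes` and reset the running counter."""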
+    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
+    _CPU_OFFLOAD_BYTES = 0
+    _CPU_OFFLOAD_MAX_BYTES = max_bytes
+
+
+def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
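+    """Offload this module's parameters to CPU until the global byte budget runs out.
+
+    Offloaded parameters are copied back to the original device on the fly each forward.
+    """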
+    device = next(module.parameters()).device
+
+    if device == torch.device("cpu"):
+        return module
+
+    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
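+    # stop offloading once the configured byte budget is exhausted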
+    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
+        return module
+
+    pin_memory = is_pin_memory_available()
+    # offload parameters to CPU
+    # use pin_memory if possible, which helps cudagraph capture speed
+    offloaded_parameters = False
+    for p in module.parameters():
+        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
+            # we use per-parameter offloading, so one module might have
+            # some parameters offloaded and some not
+            break
+
+        # `torch.empty_like` does not support the `pin_memory` argument
+        cpu_data = torch.empty_strided(
+            size=p.data.size(),
+            stride=p.data.stride(),
+            dtype=p.data.dtype,
+            layout=p.data.layout,
+            device="cpu",
+            pin_memory=pin_memory,
+        )
+        cpu_data.copy_(p.data)
+        p.data = cpu_data
+        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
+        offloaded_parameters = True
+
+    if offloaded_parameters:
+        original_forward = module.forward
+
+        def forward(*args, **kwargs):
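+            # restore the original forward first so that `functional_call`
+            # below does not recurse back into this wrapper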
+            module.forward = original_forward
+            device_state = {
+                # here we blindly call `to(device)`;
+                # if the parameter is already on the device, it will be a no-op
+                k: v.to(device, non_blocking=True)
+                for k, v in module.state_dict().items()
+            }
+            output = functional_call(module, device_state, args=args, kwargs=kwargs)
+            module.forward = forward
+            return output
+
+        module.forward = forward
+
+    return module
+
+
+class LayerFn(Protocol):
+    def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module:
+        ...
+
+
+def add_prefix(name: str, prefix: str) -> str:
+    """Add a weight-path prefix to a module name.
+
+    Args:
+        name: base module name.
+        prefix: weight prefix to prepend to `name`, joined by `.`.
+
+    Returns:
+        The string `prefix.name` if `prefix` is non-empty, otherwise just `name`.
+    """
+    return name if not prefix else f"{prefix}.{name}"
+
+
+def make_layers(
+    num_hidden_layers: int,
+    layer_fn: LayerFn,
+    prefix: str = "",
+) -> torch.nn.ModuleList:
+    """Build `num_hidden_layers` layers with `layer_fn`, offloading to CPU within the budget."""
+    modules = torch.nn.ModuleList(
+        [
+            maybe_offload_to_cpu(layer_fn(layer_id=idx, prefix=add_prefix(str(idx), prefix)))
+            for idx in range(num_hidden_layers)
+        ]
+    )
+    return modules
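
For orientation, a minimal usage sketch of the new helpers, assuming they are in scope; this is not part of the diff. `DecoderLayer`, the hidden size, and the 2 GiB budget are made up for illustration. Note that in the real loader the layers are constructed on the GPU, which is what makes the offload path matter: a module already on CPU is returned unchanged.

import torch
from torch import nn


class DecoderLayer(nn.Module):
    """Hypothetical stand-in for a transformer block."""

    def __init__(self, hidden_size: int, layer_id: int, prefix: str):
        super().__init__()
        self.prefix = prefix  # e.g. "model.layers.0"
        self.mlp = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)


# The budget must be set before the layers are built, because
# `maybe_offload_to_cpu` consumes the global counter at construction time.
set_cpu_offload_max_bytes(2 * 1024**3)  # allow up to 2 GiB of weights on CPU

layers = make_layers(
    num_hidden_layers=32,
    layer_fn=lambda layer_id, prefix: DecoderLayer(4096, layer_id, prefix),
    prefix="model.layers",
)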