
Commit e820eec

kashif authored
[ALST/Ulysses] Added ALST/Ulysses documentation (huggingface#4420)
Co-authored-by: Claude <[email protected]>
Co-authored-by: Sergio Paniego Blanco <[email protected]>
1 parent d250e4b commit e820eec

2 files changed: +256, -22 lines

docs/source/distributing_training.md

Lines changed: 211 additions & 22 deletions
@@ -52,32 +52,92 @@ Example, these configurations are equivalent, and should yield the same results:
> [!TIP]
> Having one model per GPU can lead to high memory usage, which may not be feasible for large models or low-memory GPUs. In such cases, you can leverage [DeepSpeed](https://github.com/deepspeedai/DeepSpeed), which provides optimizations like model sharding, Zero Redundancy Optimizer, mixed precision training, and offloading to CPU or NVMe. Check out our [DeepSpeed Integration](deepspeed_integration) guide for more details.

## Sequence Parallelism for Long Context Training

Sequence Parallelism (also called Context Parallelism) is a parallelization technique that enables training with longer sequences by splitting the sequence dimension across multiple GPUs. Each GPU processes a portion of the sequence, allowing you to train with sequences longer than what would fit on a single GPU's memory.

> [!NOTE]
> **Terminology clarification:** This section describes parallelism techniques for splitting sequences to enable longer context training:
> - **Context Parallelism (CP)**: Splits sequences across GPUs (implemented as Ring Attention with FSDP2)
> - **Sequence Parallelism (SP)**: Another form of sequence splitting (implemented as ALST/Ulysses with DeepSpeed)
>
> Both CP and SP are different from traditional Sequence Parallelism used with Tensor Parallelism (TP+SP) to reduce activation memory. With the techniques here, parallelism dimensions multiply: `TP=2` and `CP=2` would require 4 GPUs (2×2), whereas traditional `TP+SP=2` only needs 2 GPUs as they share the same ranks.
>
> In Accelerate's `ParallelismConfig`:
> - Use `cp_size` with `cp_backend="torch"` for Ring Attention (FSDP2)
> - Use `sp_size` with `sp_backend="deepspeed"` for ALST/Ulysses (DeepSpeed)
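
To make the naming concrete, here is a minimal sketch of how the two options map onto `ParallelismConfig`. The import path and the exact set of accepted keyword arguments depend on your Accelerate version; the YAML configs shown later in this section are the reference.

```python
# Hedged sketch: the two ways of describing sequence splitting discussed above.
from accelerate.parallelism_config import ParallelismConfig  # path may vary by Accelerate version

# Ring Attention / Context Parallelism (FSDP2 backend): the sequence is split across 2 GPUs.
ring_attention_cfg = ParallelismConfig(
    cp_size=2,           # context parallel degree
    cp_backend="torch",  # Ring Attention via PyTorch / FSDP2
)

# ALST/Ulysses (DeepSpeed backend): 2-way sequence parallelism plus 2-way data-parallel sharding.
ulysses_cfg = ParallelismConfig(
    sp_size=2,               # sequence parallel degree
    sp_backend="deepspeed",  # ALST/Ulysses via DeepSpeed
    dp_shard_size=2,         # combined with SP, this layout uses 4 GPUs in total
)
```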

Sequence parallelism is particularly useful when:

- You want to train with very long sequences (>32k tokens)
- Single GPU memory is insufficient for your desired sequence length
- You need to maintain sequence coherence across the full context

### Available Implementations

TRL supports two sequence parallelism implementations, each with different characteristics:

1. **Ring Attention (FSDP2)** - Uses ring-based communication for memory-efficient processing of extremely long sequences
2. **ALST/Ulysses (DeepSpeed)** - Uses attention head parallelism for faster training with high-bandwidth interconnects

> [!IMPORTANT]
> **Sequence Length Terminology:** When using Context Parallelism, the sequence is split across GPUs, introducing two concepts:
> - **Global sequence length**: The full sequence length before splitting across GPUs
> - **Micro sequence length**: The sequence length per GPU after splitting
>
> In TRL, `max_seq_length` (or `max_length`) refers to the **global sequence length**. The framework automatically handles splitting into micro sequences:
> - **Ring Attention (FSDP2)**: Uses `cp_size` to split sequences. With `max_seq_length=8192` and `cp_size=4`, each GPU processes 2048 tokens.
> - **ALST/Ulysses (DeepSpeed)**: Uses `sp_size` (with `sp_backend="deepspeed"`) to split sequences. With `max_seq_length=8192` and `sp_size=2`, each GPU processes 4096 tokens.
>
> The Trainer automatically accounts for context parallelism when calculating batch sizes and training metrics.
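
The split is just an even division of the global length by the parallel degree; for the two examples in the note above:

```python
# Illustrative arithmetic only: global sequence length -> per-GPU ("micro") sequence length.
def micro_seq_len(global_seq_len: int, degree: int) -> int:
    assert global_seq_len % degree == 0, "the global sequence length must divide evenly"
    return global_seq_len // degree

print(micro_seq_len(8192, 4))  # Ring Attention, cp_size=4 -> 2048 tokens per GPU
print(micro_seq_len(8192, 2))  # ALST/Ulysses,  sp_size=2 -> 4096 tokens per GPU
```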

### Choosing Between Ring Attention and Ulysses

The comparison table below highlights the key differences between the two approaches:

| Feature | Ring Attention (FSDP2) | ALST/Ulysses (DeepSpeed) |
|---------|------------------------|--------------------------|
| **Method** | Ring Self-Attention | Attention Head Parallelism |
| **Backend** | PyTorch FSDP2 | DeepSpeed ZeRO |
| **Attention** | SDPA only | Flash Attention 2 or SDPA |
| **Minimum Accelerate** | 1.11.0+ | 1.12.0+ |
| **Minimum DeepSpeed** | N/A | 0.18.1+ |
| **Sequence Divisibility** | `cp_size * 2` | `sp_size` |
| **ZeRO Stage** | N/A | ZeRO Stage 1/2/3 |

**Ring Attention is better when:**
- You need to handle extremely long sequences (1M+ tokens)
- The model has limited attention heads (Ring Attention is not constrained by head count)
- You want flexibility in scaling to any sequence length
- Network topology is limited (Ring Attention works with simple P2P ring communication)

**Ulysses is better when:**
- You have high-bandwidth, low-latency interconnects (NVLink, InfiniBand)
- The model has many attention heads that can be split across GPUs
- You want lower communication volume
- You want faster training speed for moderate sequence lengths (up to ~500k tokens)

**Key Trade-offs:**
- **Communication Volume:** Ulysses has lower communication volume, making it more efficient with good interconnects. Ring Attention has higher communication volume but is more flexible with different network topologies.
- **Attention Head Constraints:** Ulysses is limited by the number of attention heads (requires `num_heads >= sp_size`). Ring Attention scales with sequence length regardless of model architecture.
- **Network Sensitivity:** Ulysses all-to-all communication is sensitive to network latency. Ring Attention uses P2P ring communication, which is more tolerant of varying network conditions.

For a detailed comparison, see the [Ulysses and Ring Attention blog post](https://huggingface.co/blog/exploding-gradients/ulysses-ring-attention).
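
The constraints from the table and trade-offs above boil down to a couple of checks (a sketch of the stated rules, not an API):

```python
# Rule-of-thumb checks: Ulysses needs enough attention heads, and each approach
# has its own sequence-length divisibility requirement.
def ulysses_is_feasible(num_attention_heads: int, sp_size: int, seq_len: int) -> bool:
    return num_attention_heads >= sp_size and seq_len % sp_size == 0

def ring_attention_is_feasible(cp_size: int, seq_len: int) -> bool:
    return seq_len % (cp_size * 2) == 0

print(ulysses_is_feasible(num_attention_heads=32, sp_size=2, seq_len=8192))  # True
print(ring_attention_is_feasible(cp_size=4, seq_len=8192))                   # True: 8192 % 8 == 0
```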

### Ring Attention Implementation (FSDP2)

Ring Attention uses a ring-like communication pattern where each GPU processes a portion of the sequence and passes information to the next GPU in the ring.
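
The idea behind the ring can be illustrated without any distributed setup: attention over a long sequence can be computed one key/value chunk at a time and merged with the online-softmax trick, which is what each GPU does as the key/value blocks of the other ranks travel around the ring. The following is a single-process sketch for intuition only (not TRL's or PyTorch's implementation; no causal mask for brevity):

```python
import torch

def chunked_attention(q, k, v, num_chunks):
    """Attention for one head, consuming K/V in chunks (one 'ring step' per chunk)."""
    d = q.shape[-1]
    out = torch.zeros_like(q)
    running_max = torch.full((q.shape[0], 1), float("-inf"))
    running_den = torch.zeros(q.shape[0], 1)
    for k_c, v_c in zip(k.chunk(num_chunks), v.chunk(num_chunks)):
        scores = q @ k_c.T / d**0.5
        new_max = torch.maximum(running_max, scores.max(dim=-1, keepdim=True).values)
        correction = torch.exp(running_max - new_max)  # rescale what was accumulated so far
        p = torch.exp(scores - new_max)
        out = out * correction + p @ v_c
        running_den = running_den * correction + p.sum(dim=-1, keepdim=True)
        running_max = new_max
    return out / running_den

q, k, v = (torch.randn(8, 16) for _ in range(3))
reference = torch.softmax(q @ k.T / 16**0.5, dim=-1) @ v
assert torch.allclose(chunked_attention(q, k, v, num_chunks=4), reference, atol=1e-5)
```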

#### Requirements and Limitations

1. **Accelerate 1.11.0 or higher** is required for Ring Attention / Context Parallelism support
2. **FSDP2 (PyTorch FSDP v2)** is required as the distributed training backend
3. **SDPA attention** - Flash Attention is currently not supported
4. **Sequence length divisibility** - sequences must be divisible by `cp_size * 2`. This is automatically handled using the `pad_to_multiple_of` parameter in the data collator.
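
To make requirement 4 concrete: the padding multiple is just `cp_size * 2`. A hedged sketch of the relevant training arguments (illustrative values; assumes `cp_size=2` as in the 2-GPU config referenced below, and the argument names used elsewhere in this guide):

```python
from trl import SFTConfig

cp_size = 2  # must match parallelism_config_cp_size in the accelerate config
training_args = SFTConfig(
    max_seq_length=8192,             # global sequence length
    pad_to_multiple_of=cp_size * 2,  # 4, so every batch is divisible by cp_size * 2
    attn_implementation="sdpa",      # Ring Attention currently requires SDPA
    gradient_checkpointing=True,
    per_device_train_batch_size=1,
    # ...
)
```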

#### Configuration

##### Accelerate Configuration

Use one of the provided accelerate config files (e.g. [`context_parallel_2gpu.yaml`](https://github.com/huggingface/trl/blob/main/examples/accelerate_configs/context_parallel_2gpu.yaml) for 2 GPUs):

@@ -113,7 +173,7 @@ parallelism_config:
  parallelism_config_cp_size: 2 # Context parallel size
```

##### Training Configuration

```python
from trl import SFTConfig
@@ -137,7 +197,7 @@ Then, launch your training script with the appropriate accelerate config file:
accelerate launch --config_file context_parallel_2gpu.yaml train.py
```

#### Best Practices

1. **Use the `pad_to_multiple_of` parameter** - This is now the recommended way to ensure sequence length divisibility:
   - For `cp_size=2`: use `pad_to_multiple_of=4` (since `cp_size * 2 = 4`)
@@ -154,9 +214,9 @@ accelerate launch --config_file context_parallel_2gpu.yaml train.py

5. **Monitor memory usage** across all GPUs to ensure balanced workload

#### Benchmarking Ring Attention

We benchmarked Ring Attention to highlight its potential improvements in training efficiency.
Our experiments were conducted using **1, 2, 4, and 8 H100 GPUs**, though the results can be extended to larger clusters with more nodes and GPUs.

For the setup, we fine-tuned an **8B model** ([Qwen/Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B)) using the provided accelerate configuration
@@ -178,12 +238,141 @@ These results show that **Context Parallelism (CP) scales effectively with more
>
> You can learn more and explore configuration examples in the [Accelerate ND-parallelism guide](https://github.com/huggingface/accelerate/blob/main/examples/torch_native_parallelism/README.md#nd-parallelism).

### ALST/Ulysses Implementation (DeepSpeed)

ALST (Arctic Long Sequence Training) / Ulysses uses attention head parallelism to split long sequences across GPUs, working with DeepSpeed's ZeRO optimizer.

> [!NOTE]
> **Technical Note on Parallelism Configuration:**
> - **DeepSpeed ALST/Ulysses** uses `sp_size` with `sp_backend="deepspeed"` in both the YAML and Python APIs
> - **Ring Attention (FSDP2)** uses `cp_size` with `cp_backend="torch"`
>
> The Trainer automatically accounts for both CP and SP when calculating effective batch sizes and training metrics.
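
As an illustration of that accounting (simple arithmetic only, under the assumption that the ranks inside one SP group consume the same sequences, so only the data-parallel dimension multiplies the batch):

```python
# Hypothetical numbers for the 4-GPU layout used below (sp_size=2, dp_shard_size=2).
num_processes = 4
sp_size = 2
per_device_train_batch_size = 1
gradient_accumulation_steps = 1

data_parallel_ranks = num_processes // sp_size  # 2 groups see different data
sequences_per_optimizer_step = (per_device_train_batch_size
                                * data_parallel_ranks
                                * gradient_accumulation_steps)
print(data_parallel_ranks, sequences_per_optimizer_step)  # 2 2
```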

#### Requirements and Limitations

1. **DeepSpeed 0.18.1 or higher** is required
2. **Accelerate 1.12.0 or higher** is required for ALST/Ulysses sequence parallelism support
3. **Attention implementation** - Flash Attention 2 is recommended (clean output); SDPA works as a fallback
4. **Sequence length divisibility** - sequences must be divisible by `sp_size`. Use `pad_to_multiple_of` in your training config.
5. **Parallelism configuration** - You must ensure `dp_replicate_size × dp_shard_size × sp_size = num_processes`

#### Configuration

##### Accelerate Configuration

Use the provided accelerate config file ([`alst_ulysses_4gpu.yaml`](https://github.com/huggingface/trl/blob/main/examples/accelerate_configs/alst_ulysses_4gpu.yaml)):

```yaml
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  zero_stage: 3
  seq_parallel_communication_data_type: bf16
distributed_type: DEEPSPEED
mixed_precision: bf16
num_machines: 1
num_processes: 4 # Number of GPUs
parallelism_config:
  parallelism_config_dp_replicate_size: 1
  parallelism_config_dp_shard_size: 2 # Enables 2D parallelism with SP
  parallelism_config_tp_size: 1
  parallelism_config_sp_size: 2 # Sequence parallel size
  parallelism_config_sp_backend: deepspeed
  parallelism_config_sp_seq_length_is_variable: true
  parallelism_config_sp_attn_implementation: flash_attention_2
```

##### Training Configuration

```python
from trl import SFTConfig

training_args = SFTConfig(
    # required
    pad_to_multiple_of=2,  # Must equal sp_size
    # to get the most out of SP
    max_seq_length=4096,
    packing=True,
    gradient_checkpointing=True,
    attn_implementation="flash_attention_2",
    per_device_train_batch_size=1,
    # ...
)
```
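
These `training_args` plug into a regular `SFTTrainer`; a minimal, hypothetical `train.py` might look like the sketch below (the model, dataset, and output directory mirror the complete example later in this section; adjust argument names to your TRL version):

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")

training_args = SFTConfig(
    output_dir="output-alst-4gpu",
    pad_to_multiple_of=2,  # must equal sp_size
    max_seq_length=4096,
    packing=True,
    gradient_checkpointing=True,
    attn_implementation="flash_attention_2",
    per_device_train_batch_size=1,
)

trainer = SFTTrainer(
    model="Qwen/Qwen2-0.5B",  # any causal LM; a model id string or a preloaded model
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```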

Then, launch your training script with the appropriate accelerate config file:

```bash
accelerate launch --config_file examples/accelerate_configs/alst_ulysses_4gpu.yaml train.py
```

#### 2D Parallelism

The 4-GPU configuration above automatically enables 2D parallelism by combining Data Parallelism (DP) with Sequence Parallelism (SP). With `sp_size=2` and `dp_shard_size=2`, the 4 GPUs are organized as:
- 2 sequence parallel groups (processing the same data split across sequences)
- 2 data parallel groups (processing different data)

To adjust the parallelism for different GPU counts, modify the YAML config:

| GPUs | sp_size | dp_shard_size | Use Case | YAML Changes |
|------|---------|---------------|----------|--------------|
| 4 | 2 | 2 | Balanced - longer sequences + more data | `num_processes: 4`, `sp_size: 2`, `dp_shard_size: 2` |
| 4 | 4 | 1 | Pure SP for maximum sequence length | `num_processes: 4`, `sp_size: 4`, `dp_shard_size: 1` |
| 8 | 2 | 4 | Large-scale training | `num_processes: 8`, `sp_size: 2`, `dp_shard_size: 4` |
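
All rows follow the same rule: choose `sp_size`, then set `dp_shard_size = num_gpus // sp_size` so that `dp_replicate_size × dp_shard_size × sp_size` equals the number of processes. A small sketch:

```python
# Derive a 2D (DP x SP) layout from a GPU count and a chosen sp_size, mirroring
# the table above. dp_replicate_size and tp_size stay at 1 here.
def layout(num_gpus: int, sp_size: int) -> dict:
    assert num_gpus % sp_size == 0, "sp_size must divide the number of GPUs"
    dp_shard_size = num_gpus // sp_size
    assert 1 * dp_shard_size * sp_size == num_gpus  # dp_replicate x dp_shard x sp
    return {"num_processes": num_gpus, "sp_size": sp_size, "dp_shard_size": dp_shard_size}

print(layout(4, 2))  # {'num_processes': 4, 'sp_size': 2, 'dp_shard_size': 2}
print(layout(4, 4))  # {'num_processes': 4, 'sp_size': 4, 'dp_shard_size': 1}
print(layout(8, 2))  # {'num_processes': 8, 'sp_size': 2, 'dp_shard_size': 4}
```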

#### Best Practices

1. **Use `pad_to_multiple_of`** to ensure sequences are divisible by `sp_size`
2. **Use Flash Attention 2** for clean output (SDPA works but shows packing warnings)
3. **Start with `sp_size=2`** before scaling to larger values
4. **Use DeepSpeed ZeRO Stage 3** for large models
5. **Combine with memory optimizations** like Liger kernels and gradient checkpointing
6. **Validate parallelism config**: Ensure `dp_replicate_size × dp_shard_size × sp_size = num_processes`

#### Complete Example

Here's how to run ALST/Ulysses training using the built-in [`sft.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py) script with 4 GPUs:

```bash
accelerate launch --config_file examples/accelerate_configs/alst_ulysses_4gpu.yaml \
    trl/scripts/sft.py \
    --model_name_or_path Qwen/Qwen2-0.5B \
    --dataset_name trl-lib/Capybara \
    --learning_rate 2e-4 \
    --max_steps 100 \
    --max_seq_length 4096 \
    --packing \
    --packing_strategy wrapped \
    --torch_dtype bfloat16 \
    --gradient_checkpointing \
    --attn_implementation flash_attention_2 \
    --output_dir output-alst-4gpu \
    --logging_steps 10 \
    --report_to trackio
```

This command automatically:
- Configures 2D parallelism (SP=2, DP=2) across 4 GPUs
- Uses Flash Attention 2 for clean training
- Enables packing with automatic padding to ensure sequence divisibility
- Leverages DeepSpeed ZeRO Stage 3 for memory efficiency

### Further Reading

#### General Resources
- [Hugging Face Blog: Understanding Ulysses and Ring Attention](https://huggingface.co/blog/exploding-gradients/ulysses-ring-attention) - Detailed comparison of Ring Attention vs Ulysses approaches
- [Accelerate: Context Parallelism Guide](https://huggingface.co/docs/accelerate/concept_guides/context_parallelism)
- [Hugging Face Blog: Enabling Long-Context Training with Sequence Parallelism in Axolotl](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl)

#### Ring Attention (FSDP2)
- [Ultrascale Playbook - Context Parallelism](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=context_parallelism)
- [Accelerate Example: 128k Sequence Length](https://github.com/huggingface/accelerate/blob/main/examples/torch_native_parallelism/README.md#context-parallelism-128k-sequence-length)
- [Accelerate ND-parallelism Guide](https://github.com/huggingface/accelerate/blob/main/examples/torch_native_parallelism/README.md#nd-parallelism)

#### ALST/Ulysses (DeepSpeed)
- [DeepSpeed Sequence Parallelism Documentation](https://www.deepspeed.ai/tutorials/ds-sequence/)
- [Snowflake Engineering Blog: Arctic Long Sequence Training (ALST)](https://www.snowflake.com/en/engineering-blog/arctic-long-sequence-training-multi-million-token-ai/)

## Multi-Node Training
189378

examples/accelerate_configs/alst_ulysses_4gpu.yaml

Lines changed: 45 additions & 0 deletions
# ALST/Ulysses Sequence Parallelism with 2D Parallelism (DP + SP) for 4 GPUs
#
# This configuration enables 2D parallelism:
# - Sequence Parallelism (sp_size=2): Sequences split across 2 GPUs using ALST/Ulysses
# - Data Parallelism (dp_shard_size=2): Model/optimizer sharded across 2 GPUs
# - Total: 4 GPUs (2 × 2)
#
# Set parallelism_config in your training script:
# parallelism_config = ParallelismConfig(
#     sp_backend="deepspeed",
#     sp_size=2,
#     dp_shard_size=2,  # Calculated as: num_gpus // sp_size
#     sp_handler=DeepSpeedSequenceParallelConfig(...)
# )

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  zero_stage: 3
  seq_parallel_communication_data_type: bf16
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero3_save_16bit_model: false
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4 # Total number of GPUs
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
parallelism_config:
  parallelism_config_dp_replicate_size: 1
  parallelism_config_dp_shard_size: 2 # Enables 2D parallelism with SP
  parallelism_config_tp_size: 1
  parallelism_config_sp_size: 2 # Sequence parallel size
  parallelism_config_sp_backend: deepspeed
  parallelism_config_sp_seq_length_is_variable: true
  parallelism_config_sp_attn_implementation: flash_attention_2
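
Before launching, the constraint from the comments above (`dp_replicate_size × dp_shard_size × sp_size = num_processes`) can be sanity-checked with a few lines (a sketch; assumes PyYAML is installed and that the file sits at the path used in the launch command):

```python
import yaml  # pip install pyyaml

with open("examples/accelerate_configs/alst_ulysses_4gpu.yaml") as f:
    cfg = yaml.safe_load(f)

pc = cfg["parallelism_config"]
product = (pc["parallelism_config_dp_replicate_size"]
           * pc["parallelism_config_dp_shard_size"]
           * pc["parallelism_config_sp_size"])
assert product == cfg["num_processes"], (product, cfg["num_processes"])
print("parallelism layout is consistent:", product, "processes")
```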
