
Commit 38ef847

sergiopaniego and BjarniHaukur authored and committed
💎 Gemma 3 VLM SFT example script for single-image and multi-image (huggingface#3131)
Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>

Squashed commit messages:
- log answer key to wandb
- all Table HTML logging table
- bump patch
- hmm formatting
- html esacape
- reward isnt string
- [Liger] Liger KTO support (huggingface#2812) Co-authored-by: Kashif Rasul <[email protected]> Co-authored-by: Quentin Gallouédec <[email protected]>
- 🏃 Migrate CI to self-hosted runners (huggingface#3174)
- ❤️‍🩹 [CI] fix transformers dev CI failure (huggingface#3176) Co-authored-by: Quentin Gallouédec <[email protected]>
- ⏯️ Fix: handle None inputs when resuming GRPO Trainer from checkpoint (huggingface#3148) Co-authored-by: Quentin Gallouédec <[email protected]>
- 📎 Fix is_clipped to compute the effective clip_ratio (huggingface#3175) Co-authored-by: Quentin Gallouédec <[email protected]> Co-authored-by: Quentin Gallouédec <[email protected]>
- Fix breaking typo for flash_attention reducing_memory_usage.md (huggingface#3190)
- Show unique prompts in GRPO WandB tables (huggingface#3191)
- 🐗 [CI] Fix trufflehog false positives (huggingface#3192)
- [GRPO] Improve completion length logging (huggingface#3188)
- preliminary openai compatible endpoint
- early concept, needs refining
- dedupe debug print
- some slop to work on
- unslop, missing hist
- almost valid pseudocode
- middle-ware monkey patch in mp.Pool()...
- remove unused
- More accurate .md
- need gpu renting lambda again
- much nicer
- small aider-chat and datasets conflict
- risky reqs change
- should work, but hacky
- some insights, but monkeypatching probably wont suffice
- refactor: Rewrite test script to use SWE-bench dataset with MultiProcessAider
- refactor: Remove logging statements from test.py
- one step closer
- finally, the correct abstraction
- doc todo
- unslop
- unslop
- undo accidental black
- cleaner abstraction
- new abstraction
1 parent d625c55 commit 38ef847

File tree

8 files changed (+514, -36 lines changed)


PLAN.md

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
# PLAN.md

## Agent Manager Architecture

### Core Concept

The `AgentManager` coordinates ephemeral agents that exist only for the duration of a single training example, replacing `vllm_client.generate()` in the GRPO trainer with an orchestration layer that captures full conversation histories for more effective reinforcement learning.

### Current Implementation

1. **API Middleware Proxy**
   - Lightweight FastAPI server that intercepts API calls between agents and vLLM
   - Injects and tracks conversation via custom `X-Agent-ID` headers
   - Maintains thread-safe conversation history per agent
   - Captures complete request/response pairs for RL training signal

2. **Multiprocessing Approach**
   - Uses `multiprocessing.Pool` for parallel agent execution
   - Each agent runs in an isolated process to prevent state contamination
   - Monkey-patches the `requests` library to inject agent identification headers
   - Process isolation ensures clean environment for each agent instance

3. **Agent Deployment Flow**

   ```
   GRPO Training Step
   └── AgentManager.deploy(prompts)
       ├── Generate unique agent_id for each prompt
       ├── Deploy agents via multiprocessing.Pool
       │   ├── Each process runs _process_one(agent_id, prompt)
       │   │   ├── Monkey-patch requests to add X-Agent-ID
       │   │   └── Call process_one() (e.g., Aider instance)
       │   └── Multiple agents run in parallel
       ├── API Proxy tracks all vLLM interactions
       ├── Await completion with timeout
       ├── Collect conversation histories for all agent_ids
       └── Return structured completions to GRPO trainer
   ```

4. **GRPO Integration**
   - GRPO trainer uses AgentManager.deploy() for generating completions
   - Should properly convert agent completions to token IDs for the training loop
   - Maintains compatibility with both direct vLLM and agent-based generation
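
To make the flow above concrete, here is a minimal sketch of the deploy path. It is not the committed implementation: the helper names, the placeholder agent call, and the timeout default are illustrative assumptions, and the patch targets `requests.Session.request` so that both module-level helpers and `Session`-based clients pick up the `X-Agent-ID` header.

```python
import functools
import multiprocessing
import uuid

import requests


def _patch_requests(agent_id: str) -> None:
    """Stamp every outgoing HTTP call from this process with the agent's ID."""
    original = requests.Session.request

    @functools.wraps(original)
    def tagged(self, method, url, **kwargs):
        headers = kwargs.pop("headers", None) or {}
        headers["X-Agent-ID"] = agent_id  # lets the proxy attribute the call
        return original(self, method, url, headers=headers, **kwargs)

    requests.Session.request = tagged


def _process_one(job):
    """Runs inside an isolated worker process: patch, then drive one agent."""
    agent_id, prompt = job
    _patch_requests(agent_id)
    # Placeholder for the real agent invocation (e.g. launching an Aider run).
    return agent_id, f"finished: {prompt[:40]}"


class AgentManager:
    """Deploys one ephemeral agent per prompt and collects their results."""

    def deploy(self, prompts: list[str], timeout: float = 600.0) -> dict[str, str]:
        jobs = [(str(uuid.uuid4()), prompt) for prompt in prompts]
        with multiprocessing.Pool(processes=len(jobs)) as pool:
            results = pool.map_async(_process_one, jobs).get(timeout=timeout)
        # In the real flow, conversation histories would now be fetched from
        # the proxy by agent_id and converted into trainer-ready completions.
        return dict(results)


if __name__ == "__main__":
    print(AgentManager().deploy(["Fix the failing test in utils.py"]))
```

Running one worker process per prompt keeps each agent's monkey-patch from leaking into its siblings, which is the point of the process isolation described above.
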
## Challenges and Solutions

### Conversation Tracking Challenges

1. **Asynchronous API Calls**
   - Agents make varying numbers of API calls at unpredictable times
   - Solution: Thread-safe conversation tracking with unique agent IDs
   - Thread-safe locking ensures proper history capture even with concurrent requests

2. **Process Management**
   - Challenge: Ensuring clean process termination and resource cleanup
   - Solution: Pool-based multiprocessing with timeout handling
   - Proper cleanup in finally blocks ensures resources are released

3. **Proxy Synchronization**
   - Challenge: Background tasks in FastAPI may create race conditions
   - Solution: Consider making conversation tracking synchronous in the API endpoint
   - More robust synchronization mechanisms for production environments

4. **Conversation Continuity**
   - Challenge: Ensuring continuous context across multiple API calls
   - Solution: Implement validation in the ConversationTracker
   - Track and report potential discontinuities that could indicate information loss
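
`ConversationTracker` is referenced here but is not part of this commit; the following is a minimal thread-safe sketch under that assumption, recording synchronously, keyed by `agent_id`, and with `get_completion_history()` (listed later as a priority) returning a defensive copy.

```python
import threading
from collections import defaultdict
from typing import Any


class ConversationTracker:
    """Thread-safe store of request/response pairs, keyed by agent_id."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._histories: dict[str, list[dict[str, Any]]] = defaultdict(list)

    def record(self, agent_id: str, request: dict, response: dict) -> None:
        # Called synchronously from the proxy endpoint to avoid race conditions.
        with self._lock:
            self._histories[agent_id].append({"request": request, "response": response})

    def get_completion_history(self, agent_id: str) -> list[dict[str, Any]]:
        # Return a copy so callers cannot mutate the shared state.
        with self._lock:
            return list(self._histories[agent_id])
```
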
### Technical Considerations

1. **Monkey-Patching Approach**
   - Current: Patch `requests.request` in each worker process to add custom headers
   - Pros: Isolated impact, minimal invasiveness to agent frameworks
   - Alternative: Require direct configuration of agent framework

2. **Conversation Collection**
   - Current: API Proxy collects all conversations by agent_id
   - Challenge: Ensuring all API calls are captured before retrieving history
   - Solution: Consider small delay or synchronization primitive before retrieval

3. **Error Handling**
   - Challenge: Individual agent failures shouldn't crash the entire batch
   - Solution: Improved error handling in AgentManager.deploy()
   - Graceful degradation for failed agents while allowing others to continue
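
On the proxy side, a hedged sketch of what the middleware endpoint could look like: it forwards OpenAI-style chat-completion requests to the vLLM backend and records each exchange synchronously under the caller's `X-Agent-ID`. The backend URL, the route names, and the in-process history store are assumptions for illustration, not the code in this commit.

```python
import threading
from collections import defaultdict

import httpx
from fastapi import FastAPI, Request

VLLM_URL = "http://localhost:8000/v1/chat/completions"  # assumed backend address

app = FastAPI()
_lock = threading.Lock()
_histories: dict[str, list[dict]] = defaultdict(list)  # agent_id -> exchanges


@app.post("/v1/chat/completions")
async def proxy_chat_completions(request: Request) -> dict:
    payload = await request.json()
    agent_id = request.headers.get("X-Agent-ID", "unknown")

    # Forward the request unchanged to the vLLM OpenAI-compatible server.
    async with httpx.AsyncClient(timeout=600.0) as client:
        upstream = await client.post(VLLM_URL, json=payload)
        upstream.raise_for_status()
        body = upstream.json()

    # Record synchronously (not as a background task) so the exchange is stored
    # before the agent's next call can arrive.
    with _lock:
        _histories[agent_id].append({"request": payload, "response": body})
    return body


@app.get("/history/{agent_id}")
async def get_history(agent_id: str) -> list[dict]:
    with _lock:
        return list(_histories[agent_id])
```
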
## Conclusions and Next Steps

The current implementation successfully achieves:

1. **Process Isolation**: Clean separation of agent environments
2. **Conversation Tracking**: Complete history capture for RL training
3. **Parallel Execution**: Efficient handling of multiple agents
4. **Resource Management**: Proper cleanup of temporary resources

Next development priorities:

1. **Implement ConversationTracker.get_completion_history()**: Properly extract and format the complete history
2. **Address race conditions**: Ensure background tasks complete before history retrieval
3. **Enhance error handling**: Improve robustness to individual agent failures
4. **Performance optimization**: Evaluate and optimize latency introduced by the proxy
5. **Testing**: Develop comprehensive tests for conversation tracking accuracy

README.md

Lines changed: 13 additions & 0 deletions
@@ -1,3 +1,16 @@
+# TRL Fork: Agent-In-The-Loop Reinforcement Trainer (AITLRT)
+
+This fork enhances the TRL (Transformer Reinforcement Learning) library with agentic capabilities, focusing on training and reinforcing multi-turn coding agents:
+
+- **OpenAI-Compatible vLLM Endpoint**: Drop-in replacement for OpenAI API enabling seamless integration with existing tools and agents
+- **Direct Agent Integration**: Use existing agent scaffolding and applications directly in the training loop without modification
+- **Enterprise-Ready Solutions**: Leverage production-ready agentic frameworks rather than building custom implementations
+- **Parallel Agent Execution**: Run multiple instances of the same agent architecture in parallel during training
+
+For original TRL documentation, see below.
+
+---
+
 # TRL - Transformer Reinforcement Learning
 
 <div style="text-align: center">
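
Because the endpoint speaks the OpenAI chat-completions protocol, it can be exercised with the standard `openai` client. A minimal sketch, assuming the server is listening on `localhost:8000` and accepts any bearer token; the `deployed_model` name matches what `SimpleClient` in this commit sends.

```python
from openai import OpenAI

# Point the standard OpenAI client at the local vLLM-backed endpoint.
# The base URL and API key are assumptions for illustration.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

response = client.chat.completions.create(
    model="deployed_model",
    messages=[
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "Write a haiku about reinforcement learning."},
    ],
    max_tokens=64,
)
print(response.choices[0].message.content)
```
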

trl/cli.py

Lines changed: 5 additions & 2 deletions
@@ -25,8 +25,8 @@
 from .scripts.kto import make_parser as make_kto_parser
 from .scripts.sft import make_parser as make_sft_parser
 from .scripts.utils import TrlParser
-from .scripts.vllm_serve import main as vllm_serve_main
-from .scripts.vllm_serve import make_parser as make_vllm_serve_parser
+from .scripts.vllm_serve_sync import main as vllm_serve_main
+from .scripts.vllm_serve_sync import make_parser as make_vllm_serve_parser
 
 
 def main():
@@ -93,7 +93,10 @@ def main():
     elif args.command == "vllm-serve":
         (script_args,) = parser.parse_args_and_config()
         vllm_serve_main(script_args)
+
+    # Make the vllm-serve-openai-endpoint subparser
 
 
 if __name__ == "__main__":
     main()
+

trl/extras/vllm_client.py

Lines changed: 82 additions & 31 deletions
@@ -15,7 +15,8 @@
 import atexit
 import logging
 import time
-from typing import Optional
+from typing import Any, Optional
+from abc import ABC, abstractmethod
 
 import torch
 from torch import nn
@@ -36,7 +37,7 @@
 logger = logging.getLogger(__name__)
 
 
-class VLLMClient:
+class VLLMClient(ABC):
     """
     A client class to interact with a vLLM server.
 
@@ -131,24 +132,21 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
 
     def generate(
         self,
-        prompts: list[str],
-        n: int = 1,
+        data: list[dict[str, Any]],
         repetition_penalty: float = 1.0,
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
         min_p: float = 0.0,
         max_tokens: int = 16,
         guided_decoding_regex: Optional[str] = None,
-    ) -> list[list[int]]:
+    ) -> list[dict[str, Any]]:
         """
-        Generates model completions for the provided prompts.
+        Generates model completions for the provided data.
 
         Args:
-            prompts (`list[str]`):
-                List of text prompts for which the model will generate completions.
-            n (`int`, *optional*, defaults to `1`):
-                Number of completions to generate for each prompt.
+            data (`list[dict[str, Any]]`):
+                List of dataset entries.
             repetition_penalty (`float`, *optional*, defaults to `1.0`):
                 Parameter for repetition penalty. 1.0 means no penalty.
             temperature (`float`, *optional*, defaults to `1.0`):
@@ -165,28 +163,10 @@ def generate(
                 Regular expression to guide the decoding process.
 
         Returns:
-            `list[list[int]]`:
-                List of lists of token IDs representing the model-generated completions for each prompt.
+            `list[dict[str, Any]]`:
+                List of dataset entries with the generated completions added.
         """
-        url = f"http://{self.host}:{self.server_port}/generate/"
-        response = self.session.post(
-            url,
-            json={
-                "prompts": prompts,
-                "n": n,
-                "repetition_penalty": repetition_penalty,
-                "temperature": temperature,
-                "top_p": top_p,
-                "top_k": top_k,
-                "min_p": min_p,
-                "max_tokens": max_tokens,
-                "guided_decoding_regex": guided_decoding_regex,
-            },
-        )
-        if response.status_code == 200:
-            return response.json()["completion_ids"]
-        else:
-            raise Exception(f"Request failed: {response.status_code}, {response.text}")
+        pass
 
     def init_communicator(self):
         """
@@ -269,6 +249,79 @@ def close_communicator(self):
         else:
             if response.status_code != 200:
                 raise Exception(f"Request failed: {response.status_code}, {response.text}")
+
+import concurrent.futures
+
+
+class SimpleClient(VLLMClient):
+    def generate(
+        self,
+        data: list[dict[str, Any]],
+        repetition_penalty: float = 1.0,
+        temperature: float = 1.0,
+        top_p: float = 1.0,
+        top_k: int = -1,
+        min_p: float = 0.0,
+        max_tokens: int = 16,
+        guided_decoding_regex: Optional[str] = None,
+    ) -> list[dict[str, Any]]:
+        """
+        Generates model completions for the provided data.
+
+        Args:
+            data (`list[dict[str, Any]]`):
+                List of dataset entries.
+            repetition_penalty (`float`, *optional*, defaults to `1.0`):
+                Parameter for repetition penalty. 1.0 means no penalty.
+            temperature (`float`, *optional*, defaults to `1.0`):
+                Temperature parameter for sampling. Higher values increase diversity.
+            top_p (`float`, *optional*, defaults to `1.0`):
+                Top-p sampling parameter. `1.0` means no truncation.
+            top_k (`int`, *optional*, defaults to `-1`):
+                Top-k sampling parameter. `-1` means no truncation.
+            min_p (`float`, *optional*, defaults to `0.0`):
+                Minimum probability for sampling.
+            max_tokens (`int`, *optional*, defaults to `16`):
+                Maximum number of tokens to generate for each prompt.
+            guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
+                Regular expression to guide the decoding process.
+
+        Returns:
+            `list[dict[str, Any]]`:
+                List of dataset entries with the generated completions added.
+        """
+        url = f"http://{self.host}:{self.server_port}/v1/chat/completions"
+        headers = {"Authorization": "Bearer dummy"}
+
+        def get_answer(item):
+            messages = [
+                {"role": "system", "content": "You are a helpful AI assistant."},
+                {"role": "user", "content": item["prompt"]},
+            ]
+            payload = {
+                "model": "deployed_model",
+                "messages": messages,
+                "temperature": temperature,
+                "max_tokens": max_tokens,
+                "repetition_penalty": repetition_penalty,
+                "top_p": top_p,
+                "top_k": top_k,
+                "min_p": min_p,
+                "stream": False,
+            }
+            if guided_decoding_regex is not None:
+                payload["guided_decoding_regex"] = guided_decoding_regex
+
+            resp = requests.post(url, json=payload, headers=headers, timeout=600)  # per-request timeout in seconds
+            resp.raise_for_status()
+            resp_data = resp.json()
+            return resp_data["choices"][0]["message"]["content"]
+
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            futures = [executor.submit(get_answer, item) for item in data]
+            for item, future in zip(data, futures):
+                item["answer"] = future.result()
+
+        return data
 
 
 # Example usage
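
A hedged sketch of how the new `SimpleClient` might be driven from user code, assuming the synchronous OpenAI-compatible server is already running and reachable at the host and port defaults inherited from `VLLMClient`; the prompts are illustrative.

```python
from trl.extras.vllm_client import SimpleClient

# Assumes the server started via `trl vllm-serve` (wired to the synchronous
# endpoint in trl/cli.py above) is already listening on the default host/port.
client = SimpleClient()
batch = [
    {"prompt": "Summarize what GRPO optimizes in one sentence."},
    {"prompt": "List two failure modes of multi-turn coding agents."},
]
results = client.generate(batch, temperature=0.7, max_tokens=128)
for item in results:
    print(item["answer"])
```
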

trl/import_utils.py

Lines changed: 11 additions & 2 deletions
@@ -36,6 +36,7 @@
 _rich_available = _is_package_available("rich")
 _unsloth_available = _is_package_available("unsloth")
 _uvicorn_available = _is_package_available("uvicorn")
+_uvloop_available = _is_package_available("uvloop")
 _vllm_available = _is_package_available("vllm")
 _joblib_available = _is_package_available("joblib")
 
@@ -84,6 +85,10 @@ def is_uvicorn_available() -> bool:
     return _uvicorn_available
 
 
+def is_uvloop_available() -> bool:
+    return _uvloop_available
+
+
 def is_vllm_available() -> bool:
     return _vllm_available
 
@@ -99,15 +104,19 @@ class _LazyModule(ModuleType):
 
     # Very heavily inspired by optuna.integration._IntegrationModule
    # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
-    def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
+    def __init__(
+        self, name, module_file, import_structure, module_spec=None, extra_objects=None
+    ):
        super().__init__(name)
        self._modules = set(import_structure.keys())
        self._class_to_module = {}
        for key, values in import_structure.items():
            for value in values:
                self._class_to_module[value] = key
        # Needed for autocompletion in an IDE
-        self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
+        self.__all__ = list(import_structure.keys()) + list(
+            chain(*import_structure.values())
+        )
        self.__file__ = module_file
        self.__spec__ = module_spec
        self.__path__ = [os.path.dirname(module_file)]
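
The new `is_uvloop_available()` helper mirrors the existing availability checks. A small sketch of the usual guard pattern at a call site; the call site itself is an assumption, not part of this commit.

```python
from trl.import_utils import is_uvloop_available

if is_uvloop_available():
    # uvloop swaps in a faster asyncio event loop; fall back silently otherwise.
    import uvloop

    uvloop.install()
```
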
