Commit 290f2ab

feat: Add ruff formatting checks to CI
1 parent 7146c89 commit 290f2ab

103 files changed: +5111, -4486 lines changed

(Large commits have some content hidden by default; only a subset of the 103 changed files is shown below.)

.github/workflows/ruff.yml

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+name: Ruff Format Check
+
+on:
+  pull_request:
+    branches: [ main ]
+  push:
+    branches: [ main ]
+
+jobs:
+  ruff:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.10'
+
+      - name: Install ruff
+        run: pip install ruff
+
+      - name: Check formatting with ruff
+        run: ruff format --check .
+
+      - name: Check linting with ruff
+        run: ruff check .

CONTRIBUTING.md

Lines changed: 22 additions & 2 deletions
@@ -3,8 +3,8 @@
 Clone the repository:
 
 ```bash
-git clone https://github.com/OpenPipe/agent-reinforcement-training.git
-cd agent-reinforcement-training
+git clone https://github.com/OpenPipe/ART.git
+cd ART
 ```
 
 Install the dependencies:
@@ -13,6 +13,26 @@ Install the dependencies:
 uv sync
 ```
 
+### Code Formatting and Linting
+
+This project uses [ruff](https://github.com/astral-sh/ruff) for both code formatting and linting. Before submitting a pull request, please ensure your code passes both checks:
+
+```bash
+# Check code formatting
+uv run ruff format --check .
+
+# Run linting checks
+uv run ruff check .
+
+# To automatically fix formatting issues
+uv run ruff format .
+
+# To automatically fix some linting issues
+uv run ruff check --fix .
+```
+
+These checks are automatically run in CI for all pull requests. You can also install ruff as a pre-commit hook if desired.
+
 Then follow the SkyPilot or Local Training instructions below.
 
 > **Warning:** There is currently a bug with tool use functionality. The issue appears to be that vLLM does not return all the token log probabilities for tool use. Further investigation is needed to determine the exact cause. For now, teaching use case-specific tool use with non-tool use models is the recommended workaround.
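
For contributors who want the pre-commit hook mentioned in the new CONTRIBUTING section, a minimal `.pre-commit-config.yaml` sketch is shown below. It is not part of this commit; it assumes the astral-sh/ruff-pre-commit hooks, and the pinned `rev` is a placeholder to be matched to whatever ruff version the project uses.

```yaml
# .pre-commit-config.yaml (hypothetical example, not included in this commit)
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.4.4  # placeholder; pin to the ruff release you actually use
    hooks:
      - id: ruff          # lint, mirrors `ruff check .` in CI
        args: [--fix]
      - id: ruff-format   # format, mirrors `ruff format --check .` in CI
```

Installing the hook with `pip install pre-commit` followed by `pre-commit install` would then run the same checks locally before each commit.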

dev/new_models/benchmark_inference.py

Lines changed: 33 additions & 10 deletions
@@ -5,7 +5,7 @@
 and requests approximately 1000 output tokens (max_tokens=1000), repeating
 for 10 iterations. It measures per-request latencies and summarizes statistics.
 """
-import os
+
 import time
 import asyncio
 import statistics
@@ -14,6 +14,8 @@
 from art.local import LocalBackend
 
 load_dotenv()
+
+
 async def timed_request(client, model_name, prompt, max_tokens, temperature):
     """Execute a single model request and measure elapsed time and token usage."""
     start = time.perf_counter()
@@ -31,11 +33,19 @@ async def timed_request(client, model_name, prompt, max_tokens, temperature):
     usage = response.usage
     prompt_tokens = getattr(usage, "prompt_tokens", None)
     completion_tokens = getattr(usage, "completion_tokens", None)
-    return {"response": response, "elapsed": elapsed, "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens}
+    return {
+        "response": response,
+        "elapsed": elapsed,
+        "prompt_tokens": prompt_tokens,
+        "completion_tokens": completion_tokens,
+    }
+
 
 async def main():
     # Define prompt (approx 1000 input tokens) and model
-    prompt = ("Hello world. " * 500).strip() + "Please repeat the entire prompt back to me verbatim"
+    prompt = (
+        "Hello world. " * 500
+    ).strip() + "Please repeat the entire prompt back to me verbatim"
     # Output tokens to request
     max_tokens = 1000
     temperature = 1.0
@@ -59,7 +69,9 @@ async def main():
     per_request_completion_tokens = []
 
     for i in range(1, iterations + 1):
-        print(f"Iteration {i}/{iterations}: sending {concurrency} concurrent requests...")
+        print(
+            f"Iteration {i}/{iterations}: sending {concurrency} concurrent requests..."
+        )
         iteration_start = time.perf_counter()
         # launch concurrent requests and time each individually
         tasks = [
@@ -92,11 +104,21 @@ async def main():
     pr_min = min(per_request_durations) if per_request_durations else 0.0
     pr_max = max(per_request_durations) if per_request_durations else 0.0
     pr_avg = statistics.mean(per_request_durations) if per_request_durations else 0.0
-    pr_std = statistics.stdev(per_request_durations) if len(per_request_durations) > 1 else 0.0
-    avg_prompt_tokens = (statistics.mean(per_request_prompt_tokens)
-                         if per_request_prompt_tokens else None)
-    avg_completion_tokens = (statistics.mean(per_request_completion_tokens)
-                             if per_request_completion_tokens else None)
+    pr_std = (
+        statistics.stdev(per_request_durations)
+        if len(per_request_durations) > 1
+        else 0.0
+    )
+    avg_prompt_tokens = (
+        statistics.mean(per_request_prompt_tokens)
+        if per_request_prompt_tokens
+        else None
+    )
+    avg_completion_tokens = (
+        statistics.mean(per_request_completion_tokens)
+        if per_request_completion_tokens
+        else None
+    )
 
     # Report results
     print("\nInference benchmark results:")
@@ -118,5 +140,6 @@ async def main():
     if avg_completion_tokens is not None:
         print(f"  Avg completion tokens: {avg_completion_tokens:.2f}")
 
+
 if __name__ == "__main__":
-    asyncio.run(main())
+    asyncio.run(main())

dev/new_models/gemma3.py

Lines changed: 8 additions & 3 deletions
@@ -3,10 +3,10 @@
 import art
 from art.local import LocalBackend
 from dotenv import load_dotenv
-import openai
 
 load_dotenv()
 
+
 async def rollout(model: art.TrainableModel, prompt: str) -> art.Trajectory:
     messages: art.Messages = [
         {
@@ -16,7 +16,10 @@ async def rollout(model: art.TrainableModel, prompt: str) -> art.Trajectory:
     ]
     client = model.openai_client()
     chat_completion = await client.chat.completions.create(
-        messages=messages, model=model.name, max_tokens=100, timeout=100,
+        messages=messages,
+        model=model.name,
+        max_tokens=100,
+        timeout=100,
     )
     choice = chat_completion.choices[0]
     content = choice.message.content
@@ -31,6 +34,7 @@ async def rollout(model: art.TrainableModel, prompt: str) -> art.Trajectory:
         reward = 0.0
     return art.Trajectory(messages_and_choices=[*messages, choice], reward=reward)
 
+
 async def main():
     with open("dev/new_models/prompts.json", "r") as f:
         prompts = json.load(f)
@@ -61,5 +65,6 @@ async def main():
         config=art.TrainConfig(learning_rate=1e-4),
     )
 
+
 if __name__ == "__main__":
-    asyncio.run(main())
+    asyncio.run(main())

dev/new_models/qwen3_try.ipynb

Lines changed: 8 additions & 13 deletions
@@ -40,12 +40,10 @@
     }
    ],
    "source": [
-    "import asyncio\n",
     "import json\n",
     "import art\n",
     "from art.local import LocalBackend\n",
     "from dotenv import load_dotenv\n",
-    "import openai\n",
     "\n",
     "load_dotenv()"
    ]
@@ -65,7 +63,11 @@
     "    ]\n",
     "    client = model.openai_client()\n",
     "    chat_completion = await client.chat.completions.create(\n",
-    "        messages=messages, model=model.name, max_tokens=100, timeout=100, extra_body={\"chat_template_kwargs\": {\"enable_thinking\": False}},\n",
+    "        messages=messages,\n",
+    "        model=model.name,\n",
+    "        max_tokens=100,\n",
+    "        timeout=100,\n",
+    "        extra_body={\"chat_template_kwargs\": {\"enable_thinking\": False}},\n",
     "    )\n",
     "    choice = chat_completion.choices[0]\n",
     "    content = choice.message.content\n",
@@ -399,26 +401,19 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "for _ in range(await model.get_step(), 1_000):\n",
+    "for _ in range(await qwen3.get_step(), 1_000):\n",
     "    train_groups = await art.gather_trajectory_groups(\n",
     "        (\n",
-    "            art.TrajectoryGroup(rollout(model, prompt) for _ in range(32))\n",
+    "            art.TrajectoryGroup(rollout(qwen3, prompt) for _ in range(32))\n",
     "            for prompt in prompts\n",
     "        ),\n",
     "        pbar_desc=\"gather\",\n",
     "    )\n",
-    "    await model.train(\n",
+    "    await qwen3.train(\n",
     "        train_groups,\n",
     "        config=art.TrainConfig(learning_rate=1e-4),\n",
     "    )"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
  ],
  "metadata": {

dev/new_models/qwen3_try.py

Lines changed: 9 additions & 3 deletions
@@ -3,10 +3,10 @@
 import art
 from art.local import LocalBackend
 from dotenv import load_dotenv
-import openai
 
 load_dotenv()
 
+
 async def rollout(model: art.TrainableModel, prompt: str) -> art.Trajectory:
     messages: art.Messages = [
         {
@@ -16,7 +16,11 @@ async def rollout(model: art.TrainableModel, prompt: str) -> art.Trajectory:
     ]
     client = model.openai_client()
     chat_completion = await client.chat.completions.create(
-        messages=messages, model=model.name, max_tokens=100, timeout=100, extra_body={"chat_template_kwargs": {"enable_thinking": False}},
+        messages=messages,
+        model=model.name,
+        max_tokens=100,
+        timeout=100,
+        extra_body={"chat_template_kwargs": {"enable_thinking": False}},
     )
     choice = chat_completion.choices[0]
     content = choice.message.content
@@ -31,6 +35,7 @@ async def rollout(model: art.TrainableModel, prompt: str) -> art.Trajectory:
         reward = 0.0
     return art.Trajectory(messages_and_choices=[*messages, choice], reward=reward)
 
+
 async def main():
     with open("dev/new_models/prompts.json", "r") as f:
         prompts = json.load(f)
@@ -57,5 +62,6 @@ async def main():
         config=art.TrainConfig(learning_rate=1e-4),
     )
 
+
 if __name__ == "__main__":
-    asyncio.run(main())
+    asyncio.run(main())
