Skip to content

Commit 81e935a

Browse files
authored
refactor: remove orchestrator abstraction from API (#289)
* refactor: remove orchestrator abstraction from API * Remove orchestrator in GPT-J config * Add `reward_fn` arg to NeMo constructor to match base trainer API * Initial support for `make_experience` in NeMo ILQL * Run pre-commit * Remove unused sampling util
1 parent eb62d08 commit 81e935a

27 files changed

+466
-607
lines changed

configs/ilql_config.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ train:
88
eval_interval: 100
99

1010
pipeline: "PromptPipeline"
11-
orchestrator: "OfflineOrchestrator"
1211
trainer: "AccelerateILQLTrainer"
1312
seed: 1000
1413

configs/nemo_ilql_config.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ train:
77
eval_interval: 20
88

99
pipeline: "PromptPipeline"
10-
orchestrator: "OfflineOrchestrator"
1110
trainer: "NeMoILQLTrainer"
1211
trainer_kwargs:
1312
pretrained_model: "/mnt/nvme/home/uwu/nemo-megatron-gpt-20B/"

configs/ppo_config.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ train:
88
eval_interval: 100
99

1010
pipeline: "PromptPipeline"
11-
orchestrator: "PPOOrchestrator"
1211
trainer: "AcceleratePPOTrainer"
1312

1413
model:

configs/ppo_gptj.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ train:
88
eval_interval: 16
99

1010
pipeline: "PromptPipeline"
11-
orchestrator: "PPOOrchestrator"
1211
trainer: "AcceleratePPOTrainer"
1312

1413
model:

configs/sft_config.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ train:
88
eval_interval: 100
99

1010
pipeline: "PromptPipeline"
11-
orchestrator: "PPOOrchestrator"
1211
trainer: "AccelerateSFTTrainer"
1312

1413
model:

configs/test_config.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ train:
88
eval_interval: 128 # eval interval
99

1010
pipeline: "PromptPipeline" # prompt pipeline to load
11-
orchestrator: "PPOOrchestrator" # orchestrator to load
1211
trainer: "AcceleratePPOTrainer" # Name of model trainer to load
1312

1413
model:
@@ -36,7 +35,7 @@ scheduler:
3635
method:
3736
name: "ppoconfig" # Name of RL method config
3837
num_rollouts: 128 # Number of rollouts to collect per epoch
39-
chunk_size: 128 # Number of rollouts to collect in one loop of orchestrator
38+
chunk_size: 128 # Number of rollouts to collect in one loop
4039
ppo_epochs: 4 # Number of ppo epochs
4140
init_kl_coef: 0.2 # init kl coefficient
4241
target: 6 # target kl coefficient, set None for fixed kl coef

docs/source/index.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ currently supports training using PPO or ILQL for models up to 20B using Acceler
1414

1515
data
1616
models
17-
orchestrator
1817
configs
1918
pipeline
2019
examples

docs/source/orchestrator.rst

Lines changed: 0 additions & 23 deletions
This file was deleted.

docs/source/pipeline.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Pipelines
44
************************
55

66
Pipelines are how you read from a dataset with trlX. Rollout stores are how models store experiences created
7-
for them by the orchestrator. It is these experiences in their rollout store that they are trained on.
7+
for them. It is these experiences in their rollout store that they are trained on.
88

99
**General**
1010

examples/experiments/grounded_program_synthesis/configs/trlx_ppo_config.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ train:
88
eval_interval: 16
99

1010
pipeline: "PromptPipeline"
11-
orchestrator: "PPOOrchestrator"
1211
trainer: "AcceleratePPOTrainer"
1312

1413
model:

0 commit comments

Comments
 (0)