Simplify api #24
Conversation
Make sure that you update your code to follow @shahbuland's method of documentation. We also need to update the Read the Docs documentation after this merge.
V = vs[:, 1:].squeeze() * terminal_mask
Q_ = rewards + self.gamma * V

if self.two_qs:
We still need comments explaining what this is
I agree more comments would be useful
from trlx.pipeline.offline_pipeline import OfflinePipeline


def train(
This seems to exclusively assume offline...? No?
soon I'll add online as well
examples/ilql_randomwalks.py
)

model.learn()
trlx.train(walks, lengths, eval_prompts=eval_prompts, metric_fn=metric_fn, config=config, logit_mask=logit_mask)
Should trlx.train return the model?
yes, that's the plan
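A sketch of what that could look like from the caller's side, assuming `trlx.train` keeps the keyword arguments used in this example and that the returned object exposes `generate`/`save` (both assumptions, not part of this diff):

```python
# Hypothetical usage once trlx.train returns the trained model; walks, lengths,
# eval_prompts, metric_fn, config and logit_mask are the objects built earlier
# in examples/ilql_randomwalks.py.
import trlx

model = trlx.train(
    walks,
    lengths,
    eval_prompts=eval_prompts,
    metric_fn=metric_fn,
    config=config,
    logit_mask=logit_mask,
)

# Assumed follow-up calls -- names are illustrative, not defined by this PR.
samples = model.generate(eval_prompts)
model.save("checkpoints/ilql_randomwalks")
```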


if __name__ == "__main__":
    walks, logit_mask, metric_fn = generate_random_walks(seed=1000)
Somehow I'd like to move the code for generating graph data outside the run file. Perhaps this belongs in some pipeline?
For usage examples I think it's better to go without any subclasses.
trlx/model/accelerate_ilql_model.py
eval_dataloader = self.eval_pipeline.create_loader(
    self.config.train.batch_size, shuffle=False
)
train_dataloader = self.train_store.create_loader(self.config.train.batch_size)
Where is train_store defined?
I haven't decided on a proper way yet, so for now they are defined dynamically.
trlx/model/accelerate_ilql_model.py
for prompts in eval_dataloader:
    with torch.no_grad():
        samples, _ = self.model.sample(
for beta in self.config.method.betas:
Why do we have multiple betas?
to compare against plain finetuning (beta=0)
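A minimal sketch of the idea behind evaluating several betas (my reading of the reply above, not code from this PR): beta scales how strongly the learned advantages shift the language-model logits at sampling time, so beta = 0 reduces to the plain finetuned model and acts as a baseline.

```python
import torch


def ilql_sampling_logits(lm_logits: torch.Tensor,
                         advantages: torch.Tensor,
                         beta: float) -> torch.Tensor:
    """Shift LM logits by beta-weighted advantages (Q - V)."""
    return lm_logits + beta * advantages


# Dummy tensors standing in for real model outputs: [batch, seq, vocab].
lm_logits = torch.randn(1, 8, 50257)
advantages = torch.randn(1, 8, 50257)

for beta in (0.0, 1.0, 4.0):  # beta=0.0 reproduces plain finetuning
    logits = ilql_sampling_logits(lm_logits, advantages, beta)
    next_token = torch.distributions.Categorical(logits=logits[:, -1]).sample()
```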
trlx/model/accelerate_ilql_model.py
)

self.model.train()
generate_time = time() - generate_time
Do you think we can try to use the Clock object @shahbuland made?
it doesn't support granular measurements like these
self.model, self.opt, self.scheduler, rollout_loader = self.accelerator.prepare(
    self.model, self.opt, self.scheduler, rollout_loader
)
self.store.clear_history()
It turns out that when passing a dataloader into DeepSpeed it is required to be non-empty. I had a hack that loaded a dummy prompt and then cleared it once things were loaded. Is this still in the code? I cannot seem to find it.
Are you talking about the clear_history call?
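For reference, a hedged reconstruction of the workaround being described (the names are assumed; the real code may differ): seed the store with a single dummy element so DeepSpeed sees a non-empty dataloader at prepare() time, then clear it before real rollouts are collected.

```python
import torch
from torch.utils.data import DataLoader


class DummySeededStore:
    """Toy stand-in for the rollout store described above."""

    def __init__(self):
        # One dummy element so the loader passed to accelerator.prepare()
        # (and hence to DeepSpeed) is never empty.
        self.history = [torch.zeros(1, dtype=torch.long)]

    def create_loader(self, batch_size: int) -> DataLoader:
        return DataLoader(self.history, batch_size=batch_size)

    def clear_history(self):
        self.history = []


store = DummySeededStore()
rollout_loader = store.create_loader(batch_size=1)
# model, opt, scheduler, rollout_loader = accelerator.prepare(
#     model, opt, scheduler, rollout_loader
# )
store.clear_history()  # store is emptied; fresh rollouts are collected into it later
```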
trlx/model/accelerate_ppo_model.py
self.scheduler.step()
self.iter_count += 1

if self.iter_count % self.config.train.checkpoint_interval == 0:
Probably this can be put in the accelerate base model? My hope is that all models inheriting AccelerateRLModel can use the default training loop, with any changes made via the post_batch and post_epoch callback functions.
agree
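A minimal sketch of that direction, under assumed class and hook names (not the PR's actual code): the base class owns the loop and checkpointing, and subclasses customize behaviour only through the callbacks.

```python
class AccelerateRLModelBase:
    """Toy base class illustrating a callback-driven training loop."""

    def __init__(self, checkpoint_interval: int = 1000):
        self.iter_count = 0
        self.checkpoint_interval = checkpoint_interval

    def loss(self, batch):
        raise NotImplementedError  # PPO / ILQL subclasses define their loss

    def post_batch_callback(self, batch):
        """Per-batch hook, e.g. target-Q syncing or extra PPO epochs."""

    def post_epoch_callback(self):
        """Per-epoch hook, e.g. collecting fresh rollouts."""

    def save(self, directory: str = "ckpts"):
        print(f"checkpoint at step {self.iter_count} -> {directory}")

    def learn(self, dataloader, epochs: int = 1):
        for _ in range(epochs):
            for batch in dataloader:
                _ = self.loss(batch)  # backward/optimizer step omitted in this sketch
                self.iter_count += 1
                if self.iter_count % self.checkpoint_interval == 0:
                    self.save()
                self.post_batch_callback(batch)
            self.post_epoch_callback()
```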
trlx/model/nn/ilql_models.py
    * terminal_mask
).sum() / n_nonterminal

loss = loss_q + loss_v + self.cql_scale * loss_cql + self.awac_scale * loss_awac
Do you think it makes more sense to put the loss in the nn.Module classes or the accelerator trainer classes? I thought we would want the loss defined in the accelerator classes, but I'm open to something different if you have a strong opinion.
I don't have a strong opinion, but forward and generate are also model-specific here and would have to be decoupled, and I'm not sure there is a reason for that just yet.
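Purely to make the trade-off concrete (an illustration, not a proposal from this PR): if the loss moved to the trainer, the nn.Module would only expose forward (and generate), and the trainer would combine its outputs into the objective. The loss terms below are simplified placeholders.

```python
import torch
import torch.nn as nn


class ILQLHeadsOnly(nn.Module):
    """Model side: just forward (and generate in the real code)."""

    def __init__(self, hidden_size: int = 16):
        super().__init__()
        self.q_head = nn.Linear(hidden_size, 1)
        self.v_head = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states: torch.Tensor):
        return self.q_head(hidden_states), self.v_head(hidden_states)


class ILQLTrainerOwnsLoss:
    """Trainer side: combines model outputs into the training objective."""

    def __init__(self, model: ILQLHeadsOnly, cql_scale: float = 0.1):
        self.model = model
        self.cql_scale = cql_scale

    def loss(self, hidden_states: torch.Tensor, returns: torch.Tensor):
        qs, vs = self.model(hidden_states)
        loss_q = ((qs.squeeze(-1) - returns) ** 2).mean()
        loss_v = ((vs.squeeze(-1) - returns) ** 2).mean()
        return loss_q + loss_v  # CQL / AWAC terms omitted in this sketch
```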
trlx/pipeline/offline_pipeline.py
return DataLoader(self, batch_size=batch_size, collate_fn=collate_fn)


class OfflineRolloutStorage(BaseRolloutStore):
Perhaps it makes more sense to construct and load our reward-labeled datasets in the rollout storage init? I am unsure, but I do think the datasets should be separated from the run scripts.
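A hedged sketch of that suggestion (class name, fields, and tokenizer interface are all assumptions): the storage could take raw samples plus rewards and build the tokenized, reward-labelled elements inside __init__, so run scripts never touch dataset construction.

```python
from dataclasses import dataclass
from typing import List

import torch
from torch.utils.data import DataLoader, Dataset


@dataclass
class OfflineElement:
    input_ids: torch.Tensor
    rewards: torch.Tensor


class RewardLabelledStorage(Dataset):
    """Hypothetical storage that builds its dataset in __init__."""

    def __init__(self, tokenizer, samples: List[str], rewards: List[float]):
        # `tokenizer` is assumed to follow the HF interface (encode -> token ids).
        self.history = [
            OfflineElement(
                input_ids=torch.tensor(tokenizer.encode(text)),
                rewards=torch.tensor([reward]),
            )
            for text, reward in zip(samples, rewards)
        ]

    def __len__(self):
        return len(self.history)

    def __getitem__(self, ix: int) -> OfflineElement:
        return self.history[ix]

    def create_loader(self, batch_size: int) -> DataLoader:
        # Trivial collate: return the list of elements as-is.
        return DataLoader(self, batch_size=batch_size,
                          collate_fn=lambda elems: elems)
```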
)
model.eval_pipeline = OfflinePipeline(model.tokenizer, eval_prompts)

model.learn()
I think we should return the model
I like these changes! Having one simple trlx trainer is a good direction.
The README will need to be updated when this PR is finished.
Before we commit to master, make sure we've tested both the PPO and ILQL pipelines on the sentiment task.
@dmarx do you mind checking out the architecture choices of this PR?
* Had to add py_modules=trlx to setup.
* Added a save strategy.
* Cleaned up a few things.
* Added save_steps to ilql_config.yaml and a save-steps strategy to accelerate_ilql_model.py for consistency. The save_steps parameter must be set now because of how TrainConfig.from_dict operates; if no save_steps parameter is given in the configs it throws an error.
* Added minimal changes to enable the step-based save strategy in configs/ppo_config.yml, trlx/data/configs.py, and trlx/model/accelerate_ppo_model.py.
* Some problems crept in despite the merge check. This fixes them.
* Realized I am merging into stage-api, not main, so fixed an issue with ilql_config.yml.
examples/ilql_architext.py
count -= 1
nrooms.append(count)

return {'nrooms': nrooms}
this could be { 'nrooms': [-sample.count(':') for sample in samples] }
Could you make this black formatted?
transformers 4.23.1 complains if .generate() starts with a single bos token when bos=eos=pad token.
This PR mainly addresses
Can you attach wandb runs verifying PPO and ILQL still perform appropriately? (I know you have them, I just want to make this a standard process.)
)
return components

def save(self, directory=None):
all of these need doc strings and comments

if self.reward_fn:
    rewards = torch.as_tensor(self.reward_fn(samples), dtype=torch.float)
    mean_reward = rewards.mean()
This needs comments

@abstractmethod
def get_arch(config: TRLConfig):
def get_arch(self, config: TRLConfig):
doc strings
if self.iter_count % self.config.method.steps_for_target_q_sync == 0:
    self.accelerator.unwrap_model(self.model).sync_target_q_heads()

def loss(self, batch):
comments
""" | ||
Additional exploration can happen here | ||
""" | ||
def post_epoch_callback(self): |
End with a newline.
def loss(
    self, query_tensors, response_tensors, all_logprobs, all_values, all_rewards
):
def loss(self, batch):
Comments

@torch.inference_mode()
def sample(
def generate(
Doc string
return outputs.logits
return outputs

def forward(
Doc string
self.model = model
self.split_token = split_token

def make_experience(self, samples, rewards):
Doc string and comments

@register_orchestrator
class PPOOrchestrator(Orchestrator):
    def __init__(
Doc string
configs/ppo_config.yml
init_kl_coef: 0.2 # init kl coefficient
target: 6 # target kl coefficient, set None for fixed kl coef
horizon: 10000 # PPO horizon
gamma: 0.99 # PPO discount
Why the change?
input_ids: TensorType["query_size"]
attention_mask: TensorType["query_size"]
rewards: TensorType["reward_size"]
states_ixs: TensorType["states_size"]
What are these for?
query_tensors = data.tokens.to(
    self.accelerator.device
)  # [B, N] #TODO(dahoas): This may need to be changed
def generate(self, input_ids, attention_mask=None, **kwargs):
I noticed we are no longer loading the model into accelerate at init time. If we just want to do large-model (>20B) inference, do we still need to load with accelerate?
@@ -1,180 +1,179 @@
import os
from typing import Dict, Iterable
from typing import Iterable, Union
Unfortunately I really cannot review ILQL, so I trust things are fine here.
Perhaps it would be a good idea to write some unit tests for the ILQL implementation moving forward.
)
)

def learn(self, log_fn=None, save_fn=None, eval_fn=None):
Does the PPO implementation still have multiple ppo_epochs per batch? I see a new variable defining this (n_updated_per_batch), but since we are relying on the base model's training loop, I am not seeing where it gets used.
Perhaps if we do not want to override the base model's learn method, we should write something in the post_backward callback.
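Sketching that suggestion (hook and attribute names are assumptions, echoing the callback idea discussed earlier in this thread): the remaining PPO inner epochs could run inside a per-batch hook instead of overriding learn().

```python
class PPOUpdateHook:
    """Toy mixin illustrating extra PPO inner epochs in a per-batch hook."""

    def __init__(self, n_updates_per_batch: int = 4):
        self.n_updates_per_batch = n_updates_per_batch

    def loss(self, batch):
        return 0.0  # stand-in for the clipped PPO surrogate loss

    def post_batch_callback(self, batch):
        # The first optimisation pass is assumed to happen in the base training
        # loop; repeat the update on the same rollout batch for the remaining
        # PPO epochs here.
        for _ in range(self.n_updates_per_batch - 1):
            _ = self.loss(batch)  # backward/optimizer step omitted in this sketch
```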
torch.stack([elem.rewards for elem in elems]),
return PPORLBatch(
    # Left padding of already left-padded queries
    pad_sequence(
This is a bit funny but I suppose necessary
        batch_first=True,
    ).flip(1),
    # Right pad the rest, to have a single horizontal query/response split
    pad_sequence(
OK, so seeing this now I assume the padded values are handled in the loss computation by the attention mask?
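For reference, a minimal illustration of that assumption (not the PR's actual loss code): multiplying per-token terms by the attention mask before averaging keeps padded positions out of the loss.

```python
import torch

# Dummy per-token losses for a batch of 2 sequences padded to length 6.
per_token_loss = torch.randn(2, 6)
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 0, 0, 0]], dtype=torch.float)

# Zero out padded positions and average only over real tokens.
loss = (per_token_loss * attention_mask).sum() / attention_mask.sum()
```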
The example in the README is weird... What is it supposed to do? The Simulacra example is kind of odd too; there's no explanation of what it is supposed to do.
Looks good to me... Ready to merge.
WIP on #7 #15