
Conversation

maxreciprocate
Collaborator

WIP on #7 #15

@LouisCastricato
Contributor

Make sure that you update your code to follow @shahbuland's method of documentation. We also need to update the Read the Docs documentation after this merge.

V = vs[:, 1:].squeeze() * terminal_mask
Q_ = rewards + self.gamma * V

if self.two_qs:
Contributor

We still need comments explaining what this is

Collaborator

I agree more comments would be useful
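For reference, a commented version of this block might look roughly like the following (a sketch only; what happens inside the two_qs branch is an assumption in the spirit of clipped double Q-learning, not taken from the diff):

    V = vs[:, 1:].squeeze() * terminal_mask  # next-state value estimate, zeroed at terminal steps
    Q_ = rewards + self.gamma * V            # one-step TD target for the Q-head(s)

    if self.two_qs:
        # With two Q-heads, the targets/losses would typically use both estimates
        # (e.g. their element-wise minimum) to reduce overestimation bias.
        ...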

from trlx.pipeline.offline_pipeline import OfflinePipeline


def train(
Contributor

This seems to exclusively assume offline...? No?

Collaborator Author

soon I'll add online as well

)

model.learn()
trlx.train(walks, lengths, eval_prompts=eval_prompts, metric_fn=metric_fn, config=config, logit_mask=logit_mask)
Collaborator

Should trlx.train return the model?

Collaborator Author

yes, that's the plan
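A sketch of the usage this would enable, assuming trlx.train returns the trained model as planned (the final save call and its path are illustrative, not part of the diff):

    # Hypothetical usage once trlx.train returns the trained model:
    model = trlx.train(
        walks,
        lengths,
        eval_prompts=eval_prompts,
        metric_fn=metric_fn,
        config=config,
        logit_mask=logit_mask,
    )
    model.save("checkpoints/randomwalks")  # illustrative follow-up use of the returned model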



if __name__ == "__main__":
walks, logit_mask, metric_fn = generate_random_walks(seed=1000)
Collaborator

I'd like to somehow move the code for generating graph data outside the run file. Perhaps this belongs in some pipeline?

Collaborator Author

For usage examples, I think it's better to go without any subclasses.

eval_dataloader = self.eval_pipeline.create_loader(
    self.config.train.batch_size, shuffle=False
)
train_dataloader = self.train_store.create_loader(self.config.train.batch_size)
Collaborator

Where is train_store defined?

Collaborator Author

I haven't decided on a proper way yet, so for now they are defined dynamically.
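A minimal sketch of the "defined dynamically" pattern being described, using attribute names that appear elsewhere in this diff (the OfflineRolloutStorage constructor arguments are elided since they are not shown here):

    # Attributes are attached to the model from the run script rather than in __init__:
    model.train_store = OfflineRolloutStorage(...)  # constructor args elided
    model.eval_pipeline = OfflinePipeline(model.tokenizer, eval_prompts)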

for prompts in eval_dataloader:
    with torch.no_grad():
        samples, _ = self.model.sample(
for beta in self.config.method.betas:
Collaborator

Why do we have multiple betas?

Collaborator Author

To compare against plain finetuning (beta=0).
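A small, self-contained illustration of why beta = 0 serves as the finetune baseline (this mirrors the ILQL idea of shifting the language-model logits by beta times the advantage; it is not the trlx implementation, and the tensors are dummy data):

    import torch

    torch.manual_seed(0)
    logits = torch.randn(1, 50257)     # LM logits for one decoding step (dummy values)
    advantage = torch.randn(1, 50257)  # Q(s, a) - V(s) from the value heads (dummy values)

    for beta in [0.0, 1.0, 4.0]:       # as with config.method.betas; beta = 0 is the finetune baseline
        adjusted = logits + beta * advantage
        probs = torch.softmax(adjusted, dim=-1)
        print(beta, probs.argmax().item())  # the preferred token shifts as beta grows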

)

self.model.train()
generate_time = time() - generate_time
Collaborator

Do you think we can try to use the Clock object @shahbuland made?

Collaborator Author

it doesn't support granular measurements like these

self.model, self.opt, self.scheduler, rollout_loader = self.accelerator.prepare(
    self.model, self.opt, self.scheduler, rollout_loader
)
self.store.clear_history()
Collaborator

It turns out that when passing a dataloader into DeepSpeed, it is required to be non-empty. I had a hack that loaded a dummy prompt and then cleared it once things were loaded. Is this still in the code? I cannot seem to find it.

Collaborator Author

Are you talking about the clear_history call?
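In pseudocode, the workaround being described would look roughly like this (a sketch based on the comment above and the excerpt; push() and the dummy element are illustrative names, not necessarily what is in the code):

    # DeepSpeed requires the dataloader passed to accelerator.prepare to be non-empty,
    # so push a single dummy element first, prepare, then clear it out.
    self.store.push(dummy_element)
    rollout_loader = self.store.create_loader(self.config.train.batch_size)
    self.model, self.opt, self.scheduler, rollout_loader = self.accelerator.prepare(
        self.model, self.opt, self.scheduler, rollout_loader
    )
    self.store.clear_history()  # drop the dummy element once everything is wrapped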

self.scheduler.step()
self.iter_count += 1

if self.iter_count % self.config.train.checkpoint_interval == 0:
Collaborator

Probably this can be put in the accelerate base model? My hope is that all models inheriting from AccelerateRLModel can use the default training loop, with any changes made via the post_batch and post_epoch callback functions.

Collaborator Author

agree
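A structural sketch of the callback idea described above (post_epoch_callback appears elsewhere in this diff; the loop body and the placement of post_batch_callback here are assumptions):

    class AccelerateRLModel:
        def learn(self):
            for epoch in range(self.config.train.epochs):
                for batch in self.train_dataloader:
                    loss = self.loss(batch)
                    # ... backward pass, optimizer/scheduler step, logging, checkpointing ...
                    self.post_batch_callback()
                self.post_epoch_callback()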

* terminal_mask
).sum() / n_nonterminal

loss = loss_q + loss_v + self.cql_scale * loss_cql + self.awac_scale * loss_awac
Collaborator

Do you think it makes more sense to put the loss in the nn.Module classes or the accelerator trainer classes? I thought we would want the loss defined in the accelerator classes, but I'm open to something different if you have a strong opinion.

Collaborator Author

I don't have a strong opinion, but forward and generate are also specific here and would have to be decoupled, and I'm not sure there is a reason for that just yet.
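For reference, an annotated version of the total loss might read as follows (a sketch; the per-term descriptions are assumptions based on standard ILQL/CQL/AWAC formulations rather than comments taken from the diff):

    loss = (
        loss_q                         # TD error of the Q-head(s) against rewards + gamma * V
        + loss_v                       # value-head loss toward the Q targets
        + self.cql_scale * loss_cql    # conservative (CQL) regularizer on out-of-data actions
        + self.awac_scale * loss_awac  # advantage-weighted behavior-cloning term on the LM logits
    )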

return DataLoader(self, batch_size=batch_size, collate_fn=collate_fn)


class OfflineRolloutStorage(BaseRolloutStore):
Collaborator

Perhaps it makes more sense to construct and load our reward-labeled datasets in the rollout storage init? I am unsure, but I do think the datasets should be separated from the run scripts.

)
model.eval_pipeline = OfflinePipeline(model.tokenizer, eval_prompts)

model.learn()
Collaborator

I think we should return the model

Collaborator

I like these changes! Having one simple trlx trainer is a good direction.

The README will need to be updated when this PR is finished.

Before we commit to master, make sure we've tested both the PPO and ILQL pipelines on the sentiment task.

@LouisCastricato
Contributor

@dmarx do you mind checking out the architecture choices of this PR?

* Had to add py_modules=trlx to setup.

* Added a save strategy.

* Cleaned up a few things.

* Added save_steps to ilql_config.yaml and a save-steps strategy to accelerate_ilql_model.py for consistency. The save_steps parameter must be set now because of how TrainConfig.from_dict operates. If no save_steps parameter is given in the configs, it throws an error.

* Adding minimal changes to enable a step-based save strategy in configs/ppo_config.yml, trlx/data/configs.py, and trlx/model_accelerate_ppo_model.py

* Some problems crept in despite merge check. This fixes them.

* Realized I am merging into stage-api, not main, so fixed an issue with ilql_config.yml

count -= 1
nrooms.append(count)

return {'nrooms': nrooms}
Collaborator

this could be { 'nrooms': [-sample.count(':') for sample in samples] }

@cat-state
Collaborator

cat-state commented Oct 18, 2022

Could you make this black formatted?

@cat-state cat-state mentioned this pull request Oct 18, 2022
@maxreciprocate maxreciprocate requested a review from Dahoas October 19, 2022 21:39
@maxreciprocate maxreciprocate marked this pull request as ready for review October 19, 2022 23:48
@maxreciprocate
Collaborator Author

This PR mainly addresses

@Dahoas
Collaborator

Dahoas commented Oct 20, 2022

Can you attach wandb runs verifying PPO and ILQL still perform appropriately? (I know you have them, just want to make this a standard process.)

)
return components

def save(self, directory=None):
Contributor

all of these need doc strings and comments


if self.reward_fn:
    rewards = torch.as_tensor(self.reward_fn(samples), dtype=torch.float)
    mean_reward = rewards.mean()
Contributor

This needs comments


@abstractmethod
def get_arch(config: TRLConfig):
def get_arch(self, config: TRLConfig):
Contributor

doc strings

if self.iter_count % self.config.method.steps_for_target_q_sync == 0:
    self.accelerator.unwrap_model(self.model).sync_target_q_heads()

def loss(self, batch):
Contributor

comments

"""
Additional exploration can happen here
"""
def post_epoch_callback(self):
Contributor

End in new line

def loss(
    self, query_tensors, response_tensors, all_logprobs, all_values, all_rewards
):
def loss(self, batch):
Contributor

Comments


@torch.inference_mode()
def sample(
def generate(
Contributor

Doc string

return outputs.logits
return outputs

def forward(
Contributor

Doc string

self.model = model
self.split_token = split_token

def make_experience(self, samples, rewards):
Contributor

Doc string and comments


@register_orchestrator
class PPOOrchestrator(Orchestrator):
def __init__(
Contributor

Doc string

init_kl_coef: 0.2 # init kl coefficient
target: 6 # target kl coefficient, set None for fixed kl coef
horizon: 10000 # PPO horizon
gamma: 0.99 # PPO discount
Collaborator

Why the change?

input_ids: TensorType["query_size"]
attention_mask: TensorType["query_size"]
rewards: TensorType["reward_size"]
states_ixs: TensorType["states_size"]
Collaborator

What are these for?

query_tensors = data.tokens.to(
    self.accelerator.device
)  # [B, N] #TODO(dahoas): This may need to be changed
def generate(self, input_ids, attention_mask=None, **kwargs):
Collaborator

I noticed we are no longer loading the model into accelerate at init time. If we just want to do large-model (>20B) inference, do we still need to load with accelerate?

@@ -1,180 +1,179 @@
import os
from typing import Dict, Iterable
from typing import Iterable, Union
Collaborator

Unfortunately I really cannot review ILQL, so I trust things are fine here.

Perhaps it would be a good idea to write some unit tests for the ILQL implementation moving forward.

)
)

def learn(self, log_fn=None, save_fn=None, eval_fn=None):
Collaborator

Does the PPO implementation still have multiple ppo_epochs per batch? I see a new variable defining this (n_updated_per_batch), but since we are relying on the base_model's training loop, I am not seeing where it gets used.

Perhaps, if we do not want to override base_model's learn method, we should write something in the post_backward callback.

torch.stack([elem.rewards for elem in elems]),
return PPORLBatch(
# Left padding of already left-padded queries
pad_sequence(
Collaborator

This is a bit funny but I suppose necessary

batch_first=True,
).flip(1),
# Right pad the rest, to have a single horizontal query/response split
pad_sequence(
Collaborator

OK, seeing this now I assume the padded values are handled in the loss computation by the attention mask?
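A self-contained demonstration of the general flip trick for left padding with pad_sequence, which only right-pads (whether the per-sequence reversal happens exactly like this in the diff is not visible from the excerpt, so treat it as illustrative; the tensors below are dummy data):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    queries = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
    left_padded = pad_sequence(
        [q.flip(0) for q in queries],  # reverse each query
        batch_first=True,
        padding_value=0,
    ).flip(1)                          # flip back so the padding ends up on the left
    print(left_padded)                 # tensor([[5, 6, 7], [0, 8, 9]])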

@LouisCastricato
Contributor

The example in the README is weird... what is it supposed to do? The simulacra example is also kind of odd; there's no explanation of what it is supposed to do.

@LouisCastricato
Contributor

Looks good to me... Ready to merge

@LouisCastricato LouisCastricato merged commit 06cd30f into master Oct 21, 2022
@maxreciprocate maxreciprocate deleted the stage-api branch October 21, 2022 22:26