Commit 1a793ae

refactor: make ilql respect the config (#22)
1 parent 8362840 · commit 1a793ae

5 files changed (+65, -76 lines)

examples/ilql_randomwalks.py

Lines changed: 2 additions & 2 deletions
@@ -95,12 +95,12 @@ def reward_fn(samples):
 
         return rewards
 
-    gpt_config_or_path = GPT2Config(
+    config.model.model_path = GPT2Config(
         n_layer=4, n_embd=144, vocab_size=logit_mask.shape[0]
     )
 
     model = ILQLModel(
-        config=config, gpt_config_or_path=gpt_config_or_path, logit_mask=logit_mask
+        config=config, logit_mask=logit_mask
     )
 
     orch = OfflineOrchestrator(
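
This hunk shows the commit's point: the model architecture now travels through the config rather than a separate gpt_config_or_path argument. A minimal sketch of the resulting call pattern, using only the fields that appear in the diff above (the vocab_size value here is illustrative):

    from transformers import GPT2Config
    from trlx.data.configs import TRLConfig
    from trlx.model.accelerate_ilql_model import ILQLModel

    config = TRLConfig.load_yaml("configs/ilql_config.yml")
    # model_path may hold a ready-made PretrainedConfig instead of a checkpoint name
    config.model.model_path = GPT2Config(n_layer=4, n_embd=144, vocab_size=10)
    model = ILQLModel(config=config)  # architecture is read from config.model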

examples/ilql_sentiments.py

Lines changed: 12 additions & 26 deletions
@@ -1,52 +1,38 @@
11
import math
2-
from typing import Callable, Iterable, List
2+
from typing import List
33

44
import numpy as np
55
import torch
66
from datasets import load_dataset
7-
from tqdm import tqdm
87
from transformers import AutoTokenizer, pipeline
98

109
from trlx.data.configs import TRLConfig
1110
from trlx.model.accelerate_ilql_model import ILQLModel
1211
from trlx.orchestrator.offline_orchestrator import OfflineOrchestrator
1312

14-
15-
def batch_map(fn: Callable, xs: Iterable, bsize: int, desc=None):
16-
out = []
17-
for ind in tqdm(range(math.ceil(len(xs) / bsize)), desc=desc, disable=not desc):
18-
batch = xs[ind * bsize : min(len(xs), (ind + 1) * bsize)]
19-
out.extend(fn(batch))
20-
21-
return out
22-
23-
2413
if __name__ == "__main__":
2514
config = TRLConfig.load_yaml("configs/ilql_config.yml")
2615
sentiment_pipe = pipeline(
2716
"sentiment-analysis", "lvwerra/distilbert-imdb", device=torch.device(0)
2817
)
2918

30-
gpt_config_or_path = "gpt2"
31-
tokenizer = AutoTokenizer.from_pretrained(gpt_config_or_path)
19+
tokenizer = AutoTokenizer.from_pretrained(config.model.tokenizer_path)
3220
tokenizer.pad_token = tokenizer.eos_token
3321

3422
def reward_fn(samples: List[str]) -> List[float]:
3523
if isinstance(samples[0], torch.Tensor):
3624
samples = tokenizer.batch_decode(samples, skip_special_tokens=True)
3725

38-
desc = "sentiment pipeline" if len(samples) > 1024 else None
39-
sentiments = batch_map(
40-
lambda batch: sentiment_pipe(batch), samples, bsize=1024, desc=desc
41-
)
42-
return [
43-
1 - s["score"] if s["label"] == "NEGATIVE" else s["score"]
44-
for s in sentiments
45-
]
46-
47-
model = ILQLModel(
48-
config=config, gpt_config_or_path=gpt_config_or_path, tokenizer=tokenizer
49-
)
26+
sent_kwargs = {
27+
"return_all_scores": True,
28+
"function_to_apply": None,
29+
"batch_size": 1024,
30+
}
31+
pipe_outputs = sentiment_pipe(samples, **sent_kwargs)
32+
scores = torch.tensor([output[1]["score"] for output in pipe_outputs])
33+
return scores
34+
35+
model = ILQLModel(config=config, tokenizer=tokenizer)
5036

5137
n_prompts = 128
5238
eval_prompts = torch.tensor([model.tokenizer.bos_token_id] * n_prompts).view(
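
With return_all_scores=True and function_to_apply=None, the pipeline emits one {label, score} dict per class for every sample, so output[1]["score"] reads the second class's raw score directly. A rough illustration of the shape the new reward_fn assumes (scores invented, ordering [NEGATIVE, POSITIVE] as the indexing implies):

    pipe_outputs = [
        [{"label": "NEGATIVE", "score": -1.2}, {"label": "POSITIVE", "score": 2.3}],
        [{"label": "NEGATIVE", "score": 0.8}, {"label": "POSITIVE", "score": -0.5}],
    ]
    scores = [output[1]["score"] for output in pipe_outputs]  # [2.3, -0.5]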

trlx/model/accelerate_ilql_model.py

Lines changed: 11 additions & 8 deletions
@@ -4,13 +4,13 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
+import wandb
 from accelerate import Accelerator
 from torch.utils.data import DataLoader
 from transformers import AutoConfig, AutoTokenizer
 
-import wandb
 from trlx.model import BaseRLModel, register_model
-from trlx.model.nn.ilql_models import QVModel
+from trlx.model.nn.ilql_models import CausalLMWithValueHeads
 from trlx.pipeline.offline_pipeline import (OfflinePipeline,
                                             OfflineRolloutStorage)
 from trlx.utils import Clock, rampup_decay, safe_mkdir, topk_mask
@@ -24,20 +24,23 @@ class ILQLModel(BaseRLModel):
     def __init__(
         self,
         config,
-        gpt_config_or_path,
         tokenizer=None,
         logit_mask=None,
         train_mode=True,
     ):
         super().__init__(config, train_mode)
 
-        self.model = QVModel(gpt_config_or_path, config.method)
+        self.model = CausalLMWithValueHeads(
+            config.model.model_path,
+            params=config.method,
+            num_layers_unfrozen=config.model.num_layers_unfrozen,
+        )
         self.max_length = config.train.gen_size
-        self.tokenizer = tokenizer
+
         self.logit_mask = logit_mask
+        self.tokenizer = tokenizer
 
         self.accelerator = Accelerator(log_with="wandb")
-        self.accelerator.print(os.environ)
 
         if WORLD_SIZE > 1:
             torch.distributed.barrier(device_ids=[LOCAL_RANK])
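
ILQLModel no longer takes gpt_config_or_path; both the base model and the freezing policy are read from the config. A minimal sketch of the new call, assuming configs/ilql_config.yml defines the model.model_path and model.num_layers_unfrozen fields used above:

    from trlx.data.configs import TRLConfig
    from trlx.model.accelerate_ilql_model import ILQLModel

    config = TRLConfig.load_yaml("configs/ilql_config.yml")
    # e.g. model_path: gpt2, num_layers_unfrozen: -1 in the yaml
    model = ILQLModel(config=config)  # builds CausalLMWithValueHeads internally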
@@ -121,7 +124,7 @@ def learn(self):
             samples = self.accelerator.gather(torch.vstack(all_samples))
 
             if self.accelerator.is_main_process:
-                rewards = torch.tensor(self.reward_fn(samples), dtype=float)
+                rewards = torch.as_tensor(self.reward_fn(samples), dtype=float)
                 reward = rewards.mean()
 
                 if self.stats_fn:
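
The switch to torch.as_tensor matters because reward_fn may now return a torch.Tensor (as the updated sentiments example does) rather than a list: torch.tensor always copies and warns when handed a tensor, while as_tensor accepts either. A small illustration:

    import torch

    torch.as_tensor([0.1, 0.9], dtype=float)                # list -> float64 tensor
    torch.as_tensor(torch.tensor([0.1, 0.9]), dtype=float)  # tensor passes through, no copy warning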
@@ -134,7 +137,7 @@ def learn(self):
                     )
                     pairs = list(zip(texts, rewards))
                     logs["samples"] = wandb.Table(
-                        columns=["samples", "reward"], rows=pairs[:16]
+                        columns=["samples", "reward"], rows=pairs[:128]
                     )
                     if os.environ.get("DEBUG"):
                         print(

trlx/model/nn/ilql_models.py

Lines changed: 40 additions & 39 deletions
@@ -7,25 +7,17 @@
 import accelerate
 import deepspeed
 import numpy as np
-import torch as th
+import torch
 import torch.nn.functional as F
 import transformers
 from accelerate.utils import compute_module_sizes
 from torch import nn, tensor
 from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig
 
 
-def topk_mask(xs: th.FloatTensor, k: int):
-    mintop = th.topk(xs, k)[0][:, -1].unsqueeze(-1)
-    return th.where(xs < mintop, -np.inf * th.ones_like(xs, dtype=xs.dtype), xs)
-
-
-class QVOutput(Tuple):
-    logits: th.FloatTensor
-    qs: th.FloatTensor
-    target_qs: th.FloatTensor
-    vs: th.FloatTensor
-    past_key_values: Tuple[th.FloatTensor]
+def topk_mask(xs: torch.FloatTensor, k: int):
+    mintop = torch.topk(xs, k)[0][:, -1].unsqueeze(-1)
+    return torch.where(xs < mintop, -np.inf * torch.ones_like(xs, dtype=xs.dtype), xs)
 
 
 def make_head(n_embd: int, out: int):
@@ -34,8 +26,12 @@ def make_head(n_embd: int, out: int):
     )
 
 
-class QVModel(nn.Module):
-    def __init__(self, config: Union[PretrainedConfig, str], params):
+class CausalLMWithValueHeads(nn.Module):
+    """This is a wrapper around huggingface AutoModelForCausalLM with two additional scalar heads"""
+
+    def __init__(
+        self, config: Union[PretrainedConfig, str], params, num_layers_unfrozen=-1
+    ):
         super().__init__()
 
         # enable zero3 init within from_pretrained
@@ -49,15 +45,26 @@ def __init__(self, config: Union[PretrainedConfig, str], params):
         else:
             self.gpt = AutoModelForCausalLM.from_pretrained(config)
 
-        for block in self.gpt.transformer.h:
-            block.requires_grad_(False)
-
-        if hasattr(self.gpt.config, "hidden_size"):
+        if hasattr(self.gpt, "gpt_neox"):
+            self.gpt.transformer = self.gpt.gpt_neox
+            self.gpt.lm_head = self.gpt.embed_out
             self.n_embd = self.gpt.config.hidden_size
+            gpt_blocks = self.gpt.gpt_neox.layers
         else:
             self.n_embd = self.gpt.config.n_embd
-        self.vocab_size = self.gpt.config.vocab_size
+            gpt_blocks = self.gpt.transformer.h
+
+        if num_layers_unfrozen == 0:
+            gpt_blocks_to_freeze = list(gpt_blocks)
+        elif num_layers_unfrozen > 0:
+            gpt_blocks_to_freeze = list(gpt_blocks)[:-num_layers_unfrozen]
+        else:
+            gpt_blocks_to_freeze = []
+
+        for m in gpt_blocks_to_freeze:
+            m.requires_grad_(False)
 
+        self.vocab_size = self.gpt.config.vocab_size
         self.v_head = make_head(self.n_embd, 1)
         self.q1_head = make_head(self.n_embd, self.vocab_size)
         self.target_q1_head = deepcopy(self.q1_head)
@@ -77,11 +84,7 @@ def __init__(self, config: Union[PretrainedConfig, str], params):
         self.target_q2_head.requires_grad_(False)
 
     def forward(self, **x):
-        if hasattr(self.gpt, "gpt_neox"):
-            out = self.gpt.gpt_neox(**x)
-        else:
-            out = self.gpt.transformer(**x)
-
+        out = self.gpt.transformer(**x)
         hs = out.last_hidden_state
 
         if self.two_qs:
@@ -91,12 +94,10 @@
             qs = self.q1_head(hs)
             target_qs = self.target_q1_head(hs)
 
-        if hasattr(self.gpt, "gpt_neox"):
-            logits = self.gpt.embed_out(hs)
-        else:
-            logits = self.gpt.lm_head(hs)
+        logits = self.gpt.lm_head(hs)
+        vs = self.v_head(hs)
 
-        return QVOutput((logits, qs, target_qs, self.v_head(hs), out.past_key_values))
+        return logits, qs, target_qs, vs, out.past_key_values
 
     def loss(self, batch):
         tokens = batch.input_ids.to(self.device)
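
Since forward now returns a plain tuple rather than a QVOutput, call sites unpack it positionally, exactly as sample() does further down. A hypothetical call (input_ids invented):

    logits, qs, target_qs, vs, past_key_values = model(input_ids=input_ids)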
@@ -115,7 +116,7 @@ def loss(self, batch):
 
             targetQ1 = target_qs[0][:, :-1].gather(-1, actions).squeeze(-1).detach()
             targetQ2 = target_qs[1][:, :-1].gather(-1, actions).squeeze(-1).detach()
-            targetQ = th.minimum(targetQ1, targetQ2)
+            targetQ = torch.minimum(targetQ1, targetQ2)
         else:
             Q = qs[:, :-1].gather(-1, actions).squeeze(-1)
             targetQ = target_qs[:, :-1].gather(-1, actions).squeeze(-1).detach()
@@ -212,7 +213,7 @@ def sync_target_q_heads(self):
         else:
             self._sync_target_q_heads(self.alpha)
 
-    @th.inference_mode()
+    @torch.inference_mode()
     def sample(
         self,
         query,
@@ -228,32 +229,32 @@
         past_key_values = None
         tensors = defaultdict(list)
 
-        finished = th.zeros(input.shape[0], 1, dtype=th.long, device=query.device)
+        finished = torch.zeros(input.shape[0], 1, dtype=torch.long, device=query.device)
 
         for _ in range(max_length - 1):
             logits, _, target_qs, vs, past_key_values = self.forward(
                 input_ids=input, past_key_values=past_key_values
             )
 
             if self.two_qs:
-                qs = th.minimum(target_qs[0][:, -1], target_qs[1][:, -1])
+                qs = torch.minimum(target_qs[0][:, -1], target_qs[1][:, -1])
             else:
                 qs = target_qs[:, -1]
 
             logits = logits[:, -1]
 
             if logit_mask is not None:
-                logits[th.where(logit_mask[input[:, -1]])] = -np.inf
+                logits[torch.where(logit_mask[input[:, -1]])] = -np.inf
 
             adv = qs - vs[:, -1, :]
             pi = F.log_softmax(logits, -1)
             modpi = topk_mask(pi + beta * adv, top_k)
             ps = F.softmax(modpi / temperature, -1)
 
-            tokens = th.multinomial(ps, 1)
+            tokens = torch.multinomial(ps, 1)
             tokens = (1 - finished) * tokens + finished * eos_token_id
 
-            query = th.hstack((query, tokens))
+            query = torch.hstack((query, tokens))
 
             input = tokens
             finished = (tokens == eos_token_id).long()
@@ -265,21 +266,21 @@
 
         stats = {}
         for name, xs in tensors.items():
-            xs = th.vstack(xs)
+            xs = torch.vstack(xs)
             stats.update(
                 {
                     f"{name}-min": xs.min(),
                     f"{name}-max": xs.max(),
                     f"{name}-std": xs.std(),
-                    f"{name}-avg": xs.mean(),
+                    f"{name}-mean": xs.mean(),
                 }
             )
 
         return query, stats
 
     @property
     def dummy_inputs(self):
-        return {"input_ids": th.ones(1, 1, device=self.gpt.device, dtype=th.long)}
+        return {"input_ids": torch.ones(1, 1, device=self.gpt.device, dtype=torch.long)}
 
     @property
     def device(self):
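
The num_layers_unfrozen argument replaces the old loop that froze every transformer block. A small standalone sketch of the selection rule it implements, with a plain list standing in for the blocks:

    def blocks_to_freeze(blocks, num_layers_unfrozen=-1):
        # 0 freezes every block; n > 0 keeps the last n blocks trainable;
        # a negative value (the default) freezes nothing
        if num_layers_unfrozen == 0:
            return list(blocks)
        elif num_layers_unfrozen > 0:
            return list(blocks)[:-num_layers_unfrozen]
        return []

    blocks = ["h0", "h1", "h2", "h3"]
    assert blocks_to_freeze(blocks, 2) == ["h0", "h1"]
    assert blocks_to_freeze(blocks, 0) == blocks
    assert blocks_to_freeze(blocks, -1) == []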

trlx/pipeline/offline_pipeline.py

Lines changed: 0 additions & 1 deletion
@@ -1,4 +1,3 @@
-from functools import partial, reduce
 from typing import Callable, Iterable, Tuple
 
 import torch
