
Commit 3900999

rerun #89 (#92)
This commit refactors out the ILQL loss function and model additions so they can be reused with other accelerator libraries. It also makes the loss slightly clearer and fixes some type errors. First part of #75. W&B run: https://wandb.ai/carperai/trlx/runs/3tam2www
1 parent ea38a94 commit 3900999

File tree

10 files changed: +241, −239 lines


configs/ilql_config.yml

Lines changed: 4 additions & 4 deletions
@@ -1,7 +1,7 @@
 model:
-  model_path : "gpt2"
+  model_path: "gpt2"
   tokenizer_path: "gpt2"
-  model_type : "ILQLModel"
+  model_type: "ILQLModel"
   num_layers_unfrozen: -1

 train:
@@ -19,8 +19,8 @@ train:
   checkpoint_interval: 1000
   eval_interval: 128

-  pipeline : "OfflinePipeline"
-  orchestrator : "OfflineOrchestrator"
+  pipeline: "OfflinePipeline"
+  orchestrator: "OfflineOrchestrator"
   seed: 1000

 method:
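For context, this YAML feeds trlx's typed config loader, which resolves `model_type` and the method section through the registries touched below. A minimal sketch of reading it, assuming `TRLConfig.load_yaml` is the loader entry point in `trlx.data.configs`:

```python
from trlx.data.configs import TRLConfig

# A minimal sketch, assuming TRLConfig.load_yaml reads the YAML above
# into a typed config object.
config = TRLConfig.load_yaml("configs/ilql_config.yml")
assert config.model.model_path == "gpt2"
assert config.model.model_type == "ILQLModel"
```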

examples/ilql_sentiments.py

Lines changed: 0 additions & 2 deletions
@@ -1,5 +1,3 @@
-# Generates positive movie reviews by learning from sentiment-labeled IMDB dataset
-
 from datasets import load_dataset
 from transformers import pipeline

trlx/data/ilql_types.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 from dataclasses import dataclass

-from torchtyping import TensorType
+from torchtyping import TensorType # type: ignore


 @dataclass
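For reference, the dataclass behind this import looks roughly like the sketch below. The field names follow the attribute accesses in `AccelerateILQLModel.loss` later in this commit; the annotated shapes are illustrative assumptions, not taken from this diff:

```python
from dataclasses import dataclass

from torchtyping import TensorType  # type: ignore


@dataclass
class ILQLBatch:
    # Shapes are illustrative assumptions based on how the loss indexes the batch.
    input_ids: TensorType["batch", "tokens"]
    attention_mask: TensorType["batch", "tokens"]
    rewards: TensorType["batch", "actions"]
    states_ixs: TensorType["batch", "states"]
    actions_ixs: TensorType["batch", "actions"]
    dones: TensorType["batch", "states"]
```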

trlx/data/method_configs.py

Lines changed: 8 additions & 43 deletions
@@ -3,7 +3,7 @@
 from typing import Any, Callable, Dict, List

 # specifies a dictionary of method configs
-_METHODS: Dict[str, any] = {} # registry
+_METHODS: Dict[str, Any] = {} # registry


 def register_method(name):
@@ -28,17 +28,6 @@ def register_class(cls, name):
     return cls


-def get_method(name: str) -> Callable:
-    """
-    Return constructor for specified method config
-    """
-    name = name.lower()
-    if name in _METHODS:
-        return _METHODS[name]
-    else:
-        raise Exception("Error: Trying to access a method that has not been registered")
-
-
 @dataclass
 @register_method
 class MethodConfig:
@@ -56,36 +45,12 @@ def from_dict(cls, config: Dict[str, Any]):
         return cls(**config)


-@dataclass
-@register_method
-class ILQLConfig(MethodConfig):
+def get_method(name: str) -> MethodConfig:
     """
-    Config for ILQL method
-
-    :param tau: Control tradeoff in value loss between punishing value network for underestimating the target Q (i.e. Q value corresponding to the action taken) (high tau) and overestimating the target Q (low tau)
-    :type tau: float
-
-    :param gamma: Discount factor for future rewards
-    :type gamma: float
-
-    :param cql_scale: Weight for CQL loss term
-    :type cql_scale: float
-
-    :param awac_scale: Weight for AWAC loss term
-    :type awac_scale: float
-
-    :param steps_for_target_q_sync: Number of steps to wait before syncing target Q network with Q network
-    :type steps_for_target_q_sync: int
-
-    :param two_qs: Use minimum of two Q-value estimates
-    :type two_qs: bool
+    Return constructor for specified method config
     """
-
-    tau: float
-    gamma: float
-    cql_scale: float
-    awac_scale: float
-    alpha: float
-    steps_for_target_q_sync: int
-    betas: List[float]
-    two_qs: bool
+    name = name.lower()
+    if name in _METHODS:
+        return _METHODS[name]
+    else:
+        raise Exception("Error: Trying to access a method that has not been registered")
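The relocated `get_method` pairs with the `register_method` decorator to resolve a name from YAML into a config class. A simplified, self-contained sketch of the pattern (the demo class is hypothetical; in trlx the registrants are `MethodConfig` subclasses, and the real decorator also accepts an explicit name string):

```python
from dataclasses import dataclass
from typing import Any, Dict

_METHODS: Dict[str, Any] = {}  # name -> config class, as in method_configs.py


def register_method(cls):
    # Simplified: register under the lowercased class name.
    _METHODS[cls.__name__.lower()] = cls
    return cls


@dataclass
@register_method
class DemoConfig:  # hypothetical stand-in for a MethodConfig subclass
    lr: float = 1e-4


def get_method(name: str):
    name = name.lower()
    if name in _METHODS:
        return _METHODS[name]
    raise Exception("Error: Trying to access a method that has not been registered")


print(get_method("DemoConfig")(lr=3e-4))  # DemoConfig(lr=0.0003)
```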

trlx/model/__init__.py

Lines changed: 8 additions & 4 deletions
@@ -1,7 +1,7 @@
 import os
 import sys
 from abc import abstractmethod
-from typing import Callable, Dict, Iterable
+from typing import Any, Callable, Dict, Iterable

 import torch

@@ -11,11 +11,11 @@
 from trlx.utils import safe_mkdir

 # specifies a dictionary of architectures
-_MODELS: Dict[str, any] = {} # registry
+_MODELS: Dict[str, Any] = {} # registry


 def register_model(name):
-    """Decorator used register a CARP architecture
+    """Decorator used to register an architecture
     Args:
         name: Name of the architecture
     """
@@ -46,6 +46,10 @@ def __init__(self, config: TRLConfig, train_mode=False):
     def push_to_store(self, data):
         self.store.push(data)

+    def add_eval_pipeline(self, eval_pipeline):
+        """Adds a pipeline with validation prompts"""
+        self.eval_pipeline = eval_pipeline
+
     @abstractmethod
     def act(self, data: RLElement) -> RLElement:
         """
@@ -92,7 +96,7 @@ def learn(
         pass

     @abstractmethod
-    def get_components(self) -> Dict[str, any]:
+    def get_components(self) -> Dict[str, Any]:
         """
         Get pytorch components (mainly for saving/loading)
         """

trlx/model/accelerate_base_model.py

Lines changed: 7 additions & 4 deletions
@@ -2,11 +2,11 @@
 import os
 from abc import abstractmethod
 from time import time
-from typing import Dict, Iterable, Tuple
+from typing import Any, Dict, Iterable, Sequence, Tuple, Union

 import torch
 import torch.nn.functional as F
-from accelerate import Accelerator
+from accelerate import Accelerator # type: ignore
 from transformers import AutoTokenizer

 import wandb
@@ -92,10 +92,13 @@ def __init__(self, config, train_mode=True):
             eta_min=self.config.train.lr_target,
         )

-    def tokenize(self, text: Iterable[str]):
+    def tokenize(self, text: Union[Sequence[str], Sequence[torch.LongTensor]]):
         """
         Tokenize a batch of text after adding bos token to each of the samples
         """
+        if isinstance(text[0], torch.LongTensor):
+            return text
+
         text = [self.tokenizer.bos_token + txt for txt in text]
         return self.tokenizer(
             text,
@@ -117,7 +120,7 @@ def generate(self, input_ids, attention_mask=None, **kwargs):
             input_ids=input_ids, attention_mask=attention_mask, **kwargs
         )

-    def get_components(self) -> Dict[str, any]:
+    def get_components(self) -> Dict[str, Any]:
         components = (
             {"model": self.model, "opt": self.opt, "scheduler": self.scheduler}
             if self.train_mode
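The `tokenize` change makes the method a no-op for inputs that are already token tensors, so callers can pass raw strings or pre-tokenized prompts interchangeably. A minimal standalone sketch of the behavior (the pad-token setup and the tokenizer kwargs are assumptions, since the diff truncates the tokenizer call):

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token


def tokenize(text):
    # Already-tokenized samples pass through unchanged, as in this commit
    if isinstance(text[0], torch.LongTensor):
        return text
    text = [tokenizer.bos_token + txt for txt in text]
    return tokenizer(text, truncation=True, padding=True, return_tensors="pt")


print(tokenize(["a short movie review"]).input_ids)  # batch of token ids
print(tokenize([torch.LongTensor([50256, 64])]))     # returned as-is
```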

trlx/model/accelerate_ilql_model.py

Lines changed: 21 additions & 109 deletions
@@ -1,10 +1,14 @@
-from typing import Iterable, Union
+from typing import Iterable, Sequence, Union, cast

 import torch
 import torch.nn.functional as F

+
 from trlx.model import register_model
-from trlx.model.nn.ilql_models import CausalLMWithValueHeads
+from trlx.model.nn.ilql_models import ILQLConfig, CausalLMWithValueHeads
+from trlx.data.ilql_types import ILQLBatch
+from trlx.data.configs import TRLConfig
+from trlx.utils import to_device

 from .accelerate_base_model import AccelerateRLModel

@@ -13,7 +17,7 @@
 class AccelerateILQLModel(AccelerateRLModel):
     def __init__(
         self,
-        config,
+        config: TRLConfig,
         logit_mask=None,
         metric_fn=None,
         train_mode=True,
@@ -22,16 +26,20 @@ def __init__(
         self.logit_mask = logit_mask
         self.metric_fn = metric_fn
         self.reward_fn = None
-        self.params = config.method
+
+        if not isinstance(config.method, ILQLConfig):
+            raise ValueError("config.method must be ILQLConfig")
+
+        self.ilql: ILQLConfig = cast(ILQLConfig, config.method)

     def get_arch(self, config):
         return CausalLMWithValueHeads(
             config.model.model_path,
-            params=config.method,
+            ilql_config=config.method,
             num_layers_unfrozen=config.model.num_layers_unfrozen,
         )

-    def tokenize(self, texts: Union[Iterable[str], Iterable[torch.LongTensor]]):
+    def tokenize(self, texts: Union[Sequence[str], Sequence[torch.LongTensor]]):
         if isinstance(texts[0], torch.LongTensor):
             return texts

@@ -47,113 +55,17 @@ def post_backward_callback(self):
         if self.iter_count % self.config.method.steps_for_target_q_sync == 0:
             self.accelerator.unwrap_model(self.model).sync_target_q_heads()

-    def loss(self, batch):
-        input_ids = batch.input_ids.to(self.accelerator.device)
-        attn = batch.attention_mask.to(self.accelerator.device)
-        rewards = batch.rewards.to(self.accelerator.device)
-        states_ixs = batch.states_ixs.to(self.accelerator.device)
-        actions_ixs = batch.actions_ixs.to(self.accelerator.device)
-        dones = batch.dones.to(self.accelerator.device)
+    def loss(self, batch: ILQLBatch):
+        batch = to_device(batch, self.accelerator.device)

         logits, qs, target_qs, vs, _ = self.model(
-            input_ids=input_ids,
-            attention_mask=attn,
-            actions_ixs=actions_ixs,
-            states_ixs=states_ixs,
+            input_ids=batch.input_ids,
+            attention_mask=batch.attention_mask,
+            actions_ixs=batch.actions_ixs,
+            states_ixs=batch.states_ixs,
         )

-        actions = input_ids[:, 1:].gather(dim=1, index=actions_ixs).unsqueeze(-1)
-        bsize, ntokens, dsize = logits.shape
-
-        # compute two separate q-value estimates, to then select minimum values from both
-        if self.params.two_qs:
-            Q1 = qs[0].gather(-1, actions).squeeze(-1)
-            Q2 = qs[1].gather(-1, actions).squeeze(-1)
-
-            targetQ1 = target_qs[0].gather(-1, actions).squeeze(-1).detach()
-            targetQ2 = target_qs[1].gather(-1, actions).squeeze(-1).detach()
-            targetQ = torch.minimum(targetQ1, targetQ2)
-        else:
-            Q = qs.gather(-1, actions).squeeze(-1)
-            targetQ = target_qs.gather(-1, actions).squeeze(-1).detach()
-
-        terminal_mask = dones[:, :-1]
-        n_nonterminal = max(1, terminal_mask.sum())
-
-        # values of current states
-        V = vs[:, :-1].squeeze()
-        # values of next states
-        Vnext = vs[:, 1:].squeeze() * dones[:, 1:]
-        # target to fit Q
-        Q_ = rewards + self.params.gamma * Vnext.detach()
-
-        if self.params.two_qs:
-            loss_q1 = ((Q1 - Q_) * terminal_mask).pow(2).sum() / n_nonterminal
-            loss_q2 = ((Q2 - Q_) * terminal_mask).pow(2).sum() / n_nonterminal
-            loss_q = loss_q1 + loss_q2
-        else:
-            loss_q = ((Q - Q_) * terminal_mask).pow(2).sum() / n_nonterminal
-
-        targetQ = targetQ.detach()
-
-        loss_v = (
-            (
-                (targetQ >= V).int() * self.params.tau * (targetQ - V).pow(2)
-                + (targetQ < V).int() * (1 - self.params.tau) * (targetQ - V).pow(2)
-            )
-            * terminal_mask
-        ).sum() / n_nonterminal
-
-        if self.params.two_qs:
-            nactions = qs[0].shape[1]
-            loss_cql_q1 = (
-                F.cross_entropy(
-                    qs[0].reshape(-1, dsize),
-                    actions.reshape(-1),
-                    reduction="none",
-                ).reshape(bsize, nactions)
-                * terminal_mask
-            ).sum() / n_nonterminal
-            loss_cql_q2 = (
-                F.cross_entropy(
-                    qs[1].reshape(-1, dsize),
-                    actions.reshape(-1),
-                    reduction="none",
-                ).reshape(bsize, nactions)
-                * terminal_mask
-            ).sum() / n_nonterminal
-            loss_cql = loss_cql_q1 + loss_cql_q2
-        else:
-            nactions = qs.shape[1]
-            loss_cql = (
-                F.cross_entropy(
-                    qs.reshape(-1, dsize), actions.reshape(-1), reduction="none"
-                ).reshape(bsize, nactions)
-                * terminal_mask
-            ).sum() / n_nonterminal
-
-        loss_awac = (
-            F.cross_entropy(
-                logits[:, :-1, :].reshape(-1, dsize),
-                input_ids[:, 1:].reshape(-1),
-                reduction="none",
-            ).reshape(bsize, ntokens - 1)
-            * attn[:, 1:]
-        ).sum() / attn[:, 1:].sum()
-
-        loss = (
-            loss_q
-            + loss_v
-            + self.params.cql_scale * loss_cql
-            + self.params.awac_scale * loss_awac
-        )
-        stats = {
-            f"losses/{k}": v
-            for k, v in locals().items()
-            if k in ["loss", "loss_v", "loss_q", "loss_cql", "loss_awac"]
-        }
-
-        return loss, stats
+        return self.ilql.loss((logits, (qs, target_qs, vs)), batch)

     def prepare_learning(self):
         train_dataloader = self.store.create_loader(self.config.train.batch_size)
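Two things carry the weight of this refactor. First, the new `to_device` utility replaces six per-field `.to(device)` calls. A hedged sketch of such a helper for dataclass batches (the actual implementation in `trlx.utils` may differ):

```python
from dataclasses import fields, is_dataclass, replace

import torch


def to_device(batch, device):
    """Move every tensor field of a dataclass batch to `device`.

    Sketch of the helper imported from trlx.utils; the real one may differ.
    """
    assert is_dataclass(batch)
    moved = {
        f.name: getattr(batch, f.name).to(device)
        for f in fields(batch)
        if torch.is_tensor(getattr(batch, f.name))
    }
    return replace(batch, **moved)
```

Second, the deleted body above is exactly what the new `ILQLConfig.loss` in `trlx/model/nn/ilql_models.py` now encapsulates: a Q loss fitting Q(s, a) to r + γ·V(s′), an expectile value loss weighted by τ, a CQL cross-entropy regularizer scaled by `cql_scale`, and an AWAC behavior-cloning term scaled by `awac_scale`, summed into the total loss.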
