
Commit de79b37

Revert "Merge pull request #89 from CarperAI/loss-refactor"

This reverts commit da935b0, reversing changes made to ea38a94.

1 parent da935b0 · commit de79b37

File tree: 10 files changed, +239 −241 lines

configs/ilql_config.yml
Lines changed: 4 additions & 4 deletions

```diff
@@ -1,7 +1,7 @@
 model:
-  model_path: "gpt2"
+  model_path : "gpt2"
   tokenizer_path: "gpt2"
-  model_type: "ILQLModel"
+  model_type : "ILQLModel"
   num_layers_unfrozen: -1
 
 train:
@@ -19,8 +19,8 @@ train:
   checkpoint_interval: 1000
   eval_interval: 128
 
-  pipeline: "OfflinePipeline"
-  orchestrator: "OfflineOrchestrator"
+  pipeline : "OfflinePipeline"
+  orchestrator : "OfflineOrchestrator"
   seed: 1000
 
 method:
```
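Note on the restored spacing: `model_path : "gpt2"` (space before the colon) is still legal YAML for block mappings, so the reverted and refactored spellings load to the same config. A minimal check, assuming a PyYAML-style loader (the loader itself is not part of this diff):

```python
import yaml

# Both spellings parse to the same mapping: the space before the
# colon is not part of the key in a YAML block mapping.
refactored = yaml.safe_load('model_path: "gpt2"')
reverted = yaml.safe_load('model_path : "gpt2"')
assert refactored == reverted == {"model_path": "gpt2"}
```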

examples/ilql_sentiments.py
Lines changed: 2 additions & 0 deletions

```diff
@@ -1,3 +1,5 @@
+# Generates positive movie reviews by learning from sentiment-labeled IMDB dataset
+
 from datasets import load_dataset
 from transformers import pipeline
 
```
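The restored header comment states the script's goal. As a point of reference for the two imports it sits above, here is a standalone sketch of scoring IMDB text with an off-the-shelf sentiment pipeline; the dataset slice and default model are illustrative, not taken from the script body (which this diff does not show):

```python
from datasets import load_dataset
from transformers import pipeline

# Illustrative only: score a few IMDB reviews with a default
# sentiment classifier; the ILQL example learns from such labels.
imdb = load_dataset("imdb", split="train[:4]")
sentiment_fn = pipeline("sentiment-analysis")
print(sentiment_fn([text[:512] for text in imdb["text"]]))
```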

trlx/data/ilql_types.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 
-from torchtyping import TensorType  # type: ignore
+from torchtyping import TensorType
 
 
 @dataclass
```
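The revert only drops the `# type: ignore` pragma on this import. For context, `TensorType` is torchtyping's shape-annotation type; a small sketch of how such annotated dataclass fields look (the field names here are hypothetical, not the actual `ILQLBatch` definition):

```python
from dataclasses import dataclass

from torchtyping import TensorType


@dataclass
class ExampleBatch:
    # Hypothetical fields in the annotated style ilql_types.py uses.
    input_ids: TensorType["batch", "seq"]
    rewards: TensorType["batch", "reward"]
```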

trlx/data/method_configs.py
Lines changed: 43 additions & 8 deletions

```diff
@@ -3,7 +3,7 @@
 from typing import Any, Callable, Dict, List
 
 # specifies a dictionary of method configs
-_METHODS: Dict[str, Any] = {}  # registry
+_METHODS: Dict[str, any] = {}  # registry
 
 
 def register_method(name):
@@ -28,6 +28,17 @@ def register_class(cls, name):
         return cls
 
 
+def get_method(name: str) -> Callable:
+    """
+    Return constructor for specified method config
+    """
+    name = name.lower()
+    if name in _METHODS:
+        return _METHODS[name]
+    else:
+        raise Exception("Error: Trying to access a method that has not been registered")
+
+
 @dataclass
 @register_method
 class MethodConfig:
@@ -45,12 +56,36 @@ def from_dict(cls, config: Dict[str, Any]):
         return cls(**config)
 
 
-def get_method(name: str) -> MethodConfig:
+@dataclass
+@register_method
+class ILQLConfig(MethodConfig):
     """
-    Return constructor for specified method config
+    Config for ILQL method
+
+    :param tau: Control tradeoff in value loss between punishing value network for underestimating the target Q (i.e. Q value corresponding to the action taken) (high tau) and overestimating the target Q (low tau)
+    :type tau: float
+
+    :param gamma: Discount factor for future rewards
+    :type gamma: float
+
+    :param cql_scale: Weight for CQL loss term
+    :type cql_scale: float
+
+    :param awac_scale: Weight for AWAC loss term
+    :type awac_scale: float
+
+    :param steps_for_target_q_sync: Number of steps to wait before syncing target Q network with Q network
+    :type steps_for_target_q_sync: int
+
+    :param two_qs: Use minimum of two Q-value estimates
+    :type two_qs: bool
     """
-    name = name.lower()
-    if name in _METHODS:
-        return _METHODS[name]
-    else:
-        raise Exception("Error: Trying to access a method that has not been registered")
+
+    tau: float
+    gamma: float
+    cql_scale: float
+    awac_scale: float
+    alpha: float
+    steps_for_target_q_sync: int
+    betas: List[float]
+    two_qs: bool
```
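The hunks above move `get_method` back above `MethodConfig` and restore `ILQLConfig` as a plain field container. The registry mechanics are unchanged: `@register_method` stores a config class in `_METHODS` under a lowercased name, and `get_method` looks it up. A simplified, self-contained sketch of that pattern (the decorator here is reduced to its bare-class form; the trlx original also handles explicit names):

```python
from typing import Any, Callable, Dict

_METHODS: Dict[str, Any] = {}  # registry


def register_method(cls):
    # Register the class under its lowercased name, as trlx's registry does.
    _METHODS[cls.__name__.lower()] = cls
    return cls


def get_method(name: str) -> Callable:
    name = name.lower()
    if name in _METHODS:
        return _METHODS[name]
    raise Exception("Error: Trying to access a method that has not been registered")


@register_method
class MethodConfig:
    pass


assert get_method("MethodConfig") is MethodConfig
```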

trlx/model/__init__.py
Lines changed: 4 additions & 8 deletions

```diff
@@ -1,7 +1,7 @@
 import os
 import sys
 from abc import abstractmethod
-from typing import Any, Callable, Dict, Iterable
+from typing import Callable, Dict, Iterable
 
 import torch
 
@@ -11,11 +11,11 @@
 from trlx.utils import safe_mkdir
 
 # specifies a dictionary of architectures
-_MODELS: Dict[str, Any] = {}  # registry
+_MODELS: Dict[str, any] = {}  # registry
 
 
 def register_model(name):
-    """Decorator used register an architecture
+    """Decorator used register a CARP architecture
     Args:
         name: Name of the architecture
     """
@@ -46,10 +46,6 @@ def __init__(self, config: TRLConfig, train_mode=False):
     def push_to_store(self, data):
         self.store.push(data)
 
-    def add_eval_pipeline(self, eval_pipeline):
-        """Adds pipeline from with validation prompts"""
-        self.eval_pipeline = eval_pipeline
-
     @abstractmethod
     def act(self, data: RLElement) -> RLElement:
         """
@@ -96,7 +92,7 @@ def learn(
         pass
 
     @abstractmethod
-    def get_components(self) -> Dict[str, Any]:
+    def get_components(self) -> Dict[str, any]:
         """
         Get pytorch components (mainly for saving/loading)
         """
```

trlx/model/accelerate_base_model.py
Lines changed: 4 additions & 7 deletions

```diff
@@ -2,11 +2,11 @@
 import os
 from abc import abstractmethod
 from time import time
-from typing import Any, Dict, Iterable, Sequence, Tuple, Union
+from typing import Dict, Iterable, Tuple
 
 import torch
 import torch.nn.functional as F
-from accelerate import Accelerator  # type: ignore
+from accelerate import Accelerator
 from transformers import AutoTokenizer
 
 import wandb
@@ -92,13 +92,10 @@ def __init__(self, config, train_mode=True):
             eta_min=self.config.train.lr_target,
         )
 
-    def tokenize(self, text: Union[Sequence[str], Sequence[torch.LongTensor]]):
+    def tokenize(self, text: Iterable[str]):
         """
         Tokenize a batch of text after adding bos token to each of the samples
         """
-        if isinstance(text[0], torch.LongTensor):
-            return text
-
         text = [self.tokenizer.bos_token + txt for txt in text]
         return self.tokenizer(
             text,
@@ -120,7 +117,7 @@ def generate(self, input_ids, attention_mask=None, **kwargs):
             input_ids=input_ids, attention_mask=attention_mask, **kwargs
         )
 
-    def get_components(self) -> Dict[str, Any]:
+    def get_components(self) -> Dict[str, any]:
         components = (
             {"model": self.model, "opt": self.opt, "scheduler": self.scheduler}
             if self.train_mode
```
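The reverted `tokenize` drops the `LongTensor` passthrough and keeps only the string path: prepend `bos_token` to each sample, then call the tokenizer (the hunk cuts off before the call's remaining kwargs). A standalone sketch of that behavior; the padding arguments below are assumptions, not shown in the diff:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

text = ["a movie for the ages", "four hours of nothing"]
# Prepend the BOS token to each sample, as the reverted tokenize() does.
text = [tokenizer.bos_token + txt for txt in text]
batch = tokenizer(text, padding=True, return_tensors="pt")
print(batch["input_ids"].shape)
```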

trlx/model/accelerate_ilql_model.py
Lines changed: 109 additions & 21 deletions

```diff
@@ -1,14 +1,10 @@
-from typing import Iterable, Sequence, Union, cast
+from typing import Iterable, Union
 
 import torch
 import torch.nn.functional as F
 
-
 from trlx.model import register_model
-from trlx.model.nn.ilql_models import ILQLConfig, CausalLMWithValueHeads
-from trlx.data.ilql_types import ILQLBatch
-from trlx.data.configs import TRLConfig
-from trlx.utils import to_device
+from trlx.model.nn.ilql_models import CausalLMWithValueHeads
 
 from .accelerate_base_model import AccelerateRLModel
 
@@ -17,7 +13,7 @@
 class AccelerateILQLModel(AccelerateRLModel):
     def __init__(
         self,
-        config: TRLConfig,
+        config,
         logit_mask=None,
         metric_fn=None,
         train_mode=True,
@@ -26,20 +22,16 @@ def __init__(
         self.logit_mask = logit_mask
         self.metric_fn = metric_fn
         self.reward_fn = None
-
-        if not isinstance(config.method, ILQLConfig):
-            raise ValueError("config.method must be ILQLConfig")
-
-        self.ilql: ILQLConfig = cast(ILQLConfig, config.method)
+        self.params = config.method
 
     def get_arch(self, config):
         return CausalLMWithValueHeads(
             config.model.model_path,
-            ilql_config=config.method,
+            params=config.method,
             num_layers_unfrozen=config.model.num_layers_unfrozen,
         )
 
-    def tokenize(self, texts: Union[Sequence[str], Sequence[torch.LongTensor]]):
+    def tokenize(self, texts: Union[Iterable[str], Iterable[torch.LongTensor]]):
         if isinstance(texts[0], torch.LongTensor):
             return texts
 
@@ -55,17 +47,113 @@ def post_backward_callback(self):
         if self.iter_count % self.config.method.steps_for_target_q_sync == 0:
             self.accelerator.unwrap_model(self.model).sync_target_q_heads()
 
-    def loss(self, batch: ILQLBatch):
-        batch = to_device(batch, self.accelerator.device)
+    def loss(self, batch):
+        input_ids = batch.input_ids.to(self.accelerator.device)
+        attn = batch.attention_mask.to(self.accelerator.device)
+        rewards = batch.rewards.to(self.accelerator.device)
+        states_ixs = batch.states_ixs.to(self.accelerator.device)
+        actions_ixs = batch.actions_ixs.to(self.accelerator.device)
+        dones = batch.dones.to(self.accelerator.device)
 
         logits, qs, target_qs, vs, _ = self.model(
-            input_ids=batch.input_ids,
-            attention_mask=batch.attention_mask,
-            actions_ixs=batch.actions_ixs,
-            states_ixs=batch.states_ixs,
+            input_ids=input_ids,
+            attention_mask=attn,
+            actions_ixs=actions_ixs,
+            states_ixs=states_ixs,
         )
 
-        return self.ilql.loss((logits, (qs, target_qs, vs)), batch)
+        actions = input_ids[:, 1:].gather(dim=1, index=actions_ixs).unsqueeze(-1)
+        bsize, ntokens, dsize = logits.shape
+
+        # compute two separate q-value estimates, to then select minimum values from both
+        if self.params.two_qs:
+            Q1 = qs[0].gather(-1, actions).squeeze(-1)
+            Q2 = qs[1].gather(-1, actions).squeeze(-1)
+
+            targetQ1 = target_qs[0].gather(-1, actions).squeeze(-1).detach()
+            targetQ2 = target_qs[1].gather(-1, actions).squeeze(-1).detach()
+            targetQ = torch.minimum(targetQ1, targetQ2)
+        else:
+            Q = qs.gather(-1, actions).squeeze(-1)
+            targetQ = target_qs.gather(-1, actions).squeeze(-1).detach()
+
+        terminal_mask = dones[:, :-1]
+        n_nonterminal = max(1, terminal_mask.sum())
+
+        # values of current states
+        V = vs[:, :-1].squeeze()
+        # values of next states
+        Vnext = vs[:, 1:].squeeze() * dones[:, 1:]
+        # target to fit Q
+        Q_ = rewards + self.params.gamma * Vnext.detach()
+
+        if self.params.two_qs:
+            loss_q1 = ((Q1 - Q_) * terminal_mask).pow(2).sum() / n_nonterminal
+            loss_q2 = ((Q2 - Q_) * terminal_mask).pow(2).sum() / n_nonterminal
+            loss_q = loss_q1 + loss_q2
+        else:
+            loss_q = ((Q - Q_) * terminal_mask).pow(2).sum() / n_nonterminal
+
+        targetQ = targetQ.detach()
+
+        loss_v = (
+            (
+                (targetQ >= V).int() * self.params.tau * (targetQ - V).pow(2)
+                + (targetQ < V).int() * (1 - self.params.tau) * (targetQ - V).pow(2)
+            )
+            * terminal_mask
+        ).sum() / n_nonterminal
+
+        if self.params.two_qs:
+            nactions = qs[0].shape[1]
+            loss_cql_q1 = (
+                F.cross_entropy(
+                    qs[0].reshape(-1, dsize),
+                    actions.reshape(-1),
+                    reduction="none",
+                ).reshape(bsize, nactions)
+                * terminal_mask
+            ).sum() / n_nonterminal
+            loss_cql_q2 = (
+                F.cross_entropy(
+                    qs[1].reshape(-1, dsize),
+                    actions.reshape(-1),
+                    reduction="none",
+                ).reshape(bsize, nactions)
+                * terminal_mask
+            ).sum() / n_nonterminal
+            loss_cql = loss_cql_q1 + loss_cql_q2
+        else:
+            nactions = qs.shape[1]
+            loss_cql = (
+                F.cross_entropy(
+                    qs.reshape(-1, dsize), actions.reshape(-1), reduction="none"
+                ).reshape(bsize, nactions)
+                * terminal_mask
+            ).sum() / n_nonterminal
+
+        loss_awac = (
+            F.cross_entropy(
+                logits[:, :-1, :].reshape(-1, dsize),
+                input_ids[:, 1:].reshape(-1),
+                reduction="none",
+            ).reshape(bsize, ntokens - 1)
+            * attn[:, 1:]
+        ).sum() / attn[:, 1:].sum()
+
+        loss = (
+            loss_q
+            + loss_v
+            + self.params.cql_scale * loss_cql
+            + self.params.awac_scale * loss_awac
+        )
+        stats = {
+            f"losses/{k}": v
+            for k, v in locals().items()
+            if k in ["loss", "loss_v", "loss_q", "loss_cql", "loss_awac"]
+        }
+
+        return loss, stats
 
     def prepare_learning(self):
         train_dataloader = self.store.create_loader(self.config.train.batch_size)
```