- from typing import Iterable, Sequence, Union, cast
+ from typing import Iterable, Union

import torch
import torch.nn.functional as F

-
from trlx.model import register_model
- from trlx.model.nn.ilql_models import ILQLConfig, CausalLMWithValueHeads
- from trlx.data.ilql_types import ILQLBatch
- from trlx.data.configs import TRLConfig
- from trlx.utils import to_device
+ from trlx.model.nn.ilql_models import CausalLMWithValueHeads

from .accelerate_base_model import AccelerateRLModel

@@ -17,7 +13,7 @@
class AccelerateILQLModel(AccelerateRLModel):
    def __init__(
        self,
-         config: TRLConfig,
+         config,
        logit_mask=None,
        metric_fn=None,
        train_mode=True,
@@ -26,20 +22,16 @@ def __init__(
        self.logit_mask = logit_mask
        self.metric_fn = metric_fn
        self.reward_fn = None
-
-         if not isinstance(config.method, ILQLConfig):
-             raise ValueError("config.method must be ILQLConfig")
-
-         self.ilql: ILQLConfig = cast(ILQLConfig, config.method)
+         self.params = config.method

    def get_arch(self, config):
        return CausalLMWithValueHeads(
            config.model.model_path,
-             ilql_config=config.method,
+             params=config.method,
            num_layers_unfrozen=config.model.num_layers_unfrozen,
        )

-     def tokenize(self, texts: Union[Sequence[str], Sequence[torch.LongTensor]]):
+     def tokenize(self, texts: Union[Iterable[str], Iterable[torch.LongTensor]]):
        if isinstance(texts[0], torch.LongTensor):
            return texts

@@ -55,17 +47,113 @@ def post_backward_callback(self):
        if self.iter_count % self.config.method.steps_for_target_q_sync == 0:
            self.accelerator.unwrap_model(self.model).sync_target_q_heads()

-     def loss(self, batch: ILQLBatch):
-         batch = to_device(batch, self.accelerator.device)
+     def loss(self, batch):
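+         # move every tensor in the ILQL batch onto the accelerator device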
+         input_ids = batch.input_ids.to(self.accelerator.device)
+         attn = batch.attention_mask.to(self.accelerator.device)
+         rewards = batch.rewards.to(self.accelerator.device)
+         states_ixs = batch.states_ixs.to(self.accelerator.device)
+         actions_ixs = batch.actions_ixs.to(self.accelerator.device)
+         dones = batch.dones.to(self.accelerator.device)

        logits, qs, target_qs, vs, _ = self.model(
-             input_ids=batch.input_ids,
-             attention_mask=batch.attention_mask,
-             actions_ixs=batch.actions_ixs,
-             states_ixs=batch.states_ixs,
+             input_ids=input_ids,
+             attention_mask=attn,
+             actions_ixs=actions_ixs,
+             states_ixs=states_ixs,
        )

-         return self.ilql.loss((logits, (qs, target_qs, vs)), batch)
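+         # ids of the tokens actually taken as actions (the next token at each action index)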
+         actions = input_ids[:, 1:].gather(dim=1, index=actions_ixs).unsqueeze(-1)
+         bsize, ntokens, dsize = logits.shape
+
+         # compute two separate q-value estimates, to then select minimum values from both
+         if self.params.two_qs:
+             Q1 = qs[0].gather(-1, actions).squeeze(-1)
+             Q2 = qs[1].gather(-1, actions).squeeze(-1)
+
+             targetQ1 = target_qs[0].gather(-1, actions).squeeze(-1).detach()
+             targetQ2 = target_qs[1].gather(-1, actions).squeeze(-1).detach()
+             targetQ = torch.minimum(targetQ1, targetQ2)
+         else:
+             Q = qs.gather(-1, actions).squeeze(-1)
+             targetQ = target_qs.gather(-1, actions).squeeze(-1).detach()
+
+         terminal_mask = dones[:, :-1]
+         n_nonterminal = max(1, terminal_mask.sum())
+
+         # values of current states
+         V = vs[:, :-1].squeeze()
+         # values of next states
+         Vnext = vs[:, 1:].squeeze() * dones[:, 1:]
+         # target to fit Q
+         Q_ = rewards + self.params.gamma * Vnext.detach()
+
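+         # regress Q of the taken actions (both heads when two_qs) onto the one-step
+         # target Q_ above, averaged over non-terminal positions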
+         if self.params.two_qs:
+             loss_q1 = ((Q1 - Q_) * terminal_mask).pow(2).sum() / n_nonterminal
+             loss_q2 = ((Q2 - Q_) * terminal_mask).pow(2).sum() / n_nonterminal
+             loss_q = loss_q1 + loss_q2
+         else:
+             loss_q = ((Q - Q_) * terminal_mask).pow(2).sum() / n_nonterminal
+
+         targetQ = targetQ.detach()
+
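+         # expectile regression of V toward targetQ: the (targetQ >= V) side is weighted
+         # by tau, the (targetQ < V) side by (1 - tau)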
+         loss_v = (
+             (
+                 (targetQ >= V).int() * self.params.tau * (targetQ - V).pow(2)
+                 + (targetQ < V).int() * (1 - self.params.tau) * (targetQ - V).pow(2)
+             )
+             * terminal_mask
+         ).sum() / n_nonterminal
+
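+         # CQL regularizer: treat the q-values over the vocabulary as logits and push up
+         # the Q of the taken action relative to all other tokens via cross-entropy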
+         if self.params.two_qs:
+             nactions = qs[0].shape[1]
+             loss_cql_q1 = (
+                 F.cross_entropy(
+                     qs[0].reshape(-1, dsize),
+                     actions.reshape(-1),
+                     reduction="none",
+                 ).reshape(bsize, nactions)
+                 * terminal_mask
+             ).sum() / n_nonterminal
+             loss_cql_q2 = (
+                 F.cross_entropy(
+                     qs[1].reshape(-1, dsize),
+                     actions.reshape(-1),
+                     reduction="none",
+                 ).reshape(bsize, nactions)
+                 * terminal_mask
+             ).sum() / n_nonterminal
+             loss_cql = loss_cql_q1 + loss_cql_q2
+         else:
+             nactions = qs.shape[1]
+             loss_cql = (
+                 F.cross_entropy(
+                     qs.reshape(-1, dsize), actions.reshape(-1), reduction="none"
+                 ).reshape(bsize, nactions)
+                 * terminal_mask
+             ).sum() / n_nonterminal
+
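+         # AWAC term: standard next-token cross-entropy on the LM logits, masked and
+         # normalized by the attention mask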
+         loss_awac = (
+             F.cross_entropy(
+                 logits[:, :-1, :].reshape(-1, dsize),
+                 input_ids[:, 1:].reshape(-1),
+                 reduction="none",
+             ).reshape(bsize, ntokens - 1)
+             * attn[:, 1:]
+         ).sum() / attn[:, 1:].sum()
+
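+         # total objective: Q regression + expectile V loss + scaled CQL and AWAC terms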
+         loss = (
+             loss_q
+             + loss_v
+             + self.params.cql_scale * loss_cql
+             + self.params.awac_scale * loss_awac
+         )
+         stats = {
+             f"losses/{k}": v
+             for k, v in locals().items()
+             if k in ["loss", "loss_v", "loss_q", "loss_cql", "loss_awac"]
+         }
+
+         return loss, stats

    def prepare_learning(self):
        train_dataloader = self.store.create_loader(self.config.train.batch_size)