7
7
import accelerate
8
8
import deepspeed
9
9
import numpy as np
10
- import torch as th
10
+ import torch
11
11
import torch .nn .functional as F
12
12
import transformers
13
13
from accelerate .utils import compute_module_sizes
14
14
from torch import nn , tensor
15
15
from transformers import AutoConfig , AutoModelForCausalLM , PretrainedConfig
16
16
17
17
18
- def topk_mask (xs : th .FloatTensor , k : int ):
19
- mintop = th .topk (xs , k )[0 ][:, - 1 ].unsqueeze (- 1 )
20
- return th .where (xs < mintop , - np .inf * th .ones_like (xs , dtype = xs .dtype ), xs )
21
-
22
-
23
- class QVOutput (Tuple ):
24
- logits : th .FloatTensor
25
- qs : th .FloatTensor
26
- target_qs : th .FloatTensor
27
- vs : th .FloatTensor
28
- past_key_values : Tuple [th .FloatTensor ]
18
+ def topk_mask (xs : torch .FloatTensor , k : int ):
19
+ mintop = torch .topk (xs , k )[0 ][:, - 1 ].unsqueeze (- 1 )
20
+ return torch .where (xs < mintop , - np .inf * torch .ones_like (xs , dtype = xs .dtype ), xs )
29
21
30
22
31
23
def make_head (n_embd : int , out : int ):
@@ -34,8 +26,12 @@ def make_head(n_embd: int, out: int):
34
26
)
35
27
36
28
37
- class QVModel (nn .Module ):
38
- def __init__ (self , config : Union [PretrainedConfig , str ], params ):
29
+ class CausalLMWithValueHeads (nn .Module ):
30
+ """This is a wrapper around huggingface AutoModelForCausalLM with two additional scalar heads"""
31
+
32
+ def __init__ (
33
+ self , config : Union [PretrainedConfig , str ], params , num_layers_unfrozen = - 1
34
+ ):
39
35
super ().__init__ ()
40
36
41
37
# enable zero3 init within from_pretrained
@@ -49,15 +45,26 @@ def __init__(self, config: Union[PretrainedConfig, str], params):
49
45
else :
50
46
self .gpt = AutoModelForCausalLM .from_pretrained (config )
51
47
52
- for block in self .gpt .transformer .h :
53
- block .requires_grad_ (False )
54
-
55
- if hasattr (self .gpt .config , "hidden_size" ):
48
+ if hasattr (self .gpt , "gpt_neox" ):
49
+ self .gpt .transformer = self .gpt .gpt_neox
50
+ self .gpt .lm_head = self .gpt_embed_out
56
51
self .n_embd = self .gpt .config .hidden_size
52
+ gpt_blocks = self .gpt .gpt_neox .layers
57
53
else :
58
54
self .n_embd = self .gpt .config .n_embd
59
- self .vocab_size = self .gpt .config .vocab_size
55
+ gpt_blocks = self .gpt .transformer .h
56
+
57
+ if num_layers_unfrozen == 0 :
58
+ gpt_blocks_to_freeze = list (gpt_blocks )
59
+ elif num_layers_unfrozen > 0 :
60
+ gpt_blocks_to_freeze = list (gpt_blocks )[:- num_layers_unfrozen ]
61
+ else :
62
+ gpt_blocks_to_freeze = []
63
+
64
+ for m in gpt_blocks_to_freeze :
65
+ m .requires_grad_ (False )
60
66
67
+ self .vocab_size = self .gpt .config .vocab_size
61
68
self .v_head = make_head (self .n_embd , 1 )
62
69
self .q1_head = make_head (self .n_embd , self .vocab_size )
63
70
self .target_q1_head = deepcopy (self .q1_head )
@@ -77,11 +84,7 @@ def __init__(self, config: Union[PretrainedConfig, str], params):
77
84
self .target_q2_head .requires_grad_ (False )
78
85
79
86
def forward (self , ** x ):
80
- if hasattr (self .gpt , "gpt_neox" ):
81
- out = self .gpt .gpt_neox (** x )
82
- else :
83
- out = self .gpt .transformer (** x )
84
-
87
+ out = self .gpt .transformer (** x )
85
88
hs = out .last_hidden_state
86
89
87
90
if self .two_qs :
@@ -91,12 +94,10 @@ def forward(self, **x):
91
94
qs = self .q1_head (hs )
92
95
target_qs = self .target_q1_head (hs )
93
96
94
- if hasattr (self .gpt , "gpt_neox" ):
95
- logits = self .gpt .embed_out (hs )
96
- else :
97
- logits = self .gpt .lm_head (hs )
97
+ logits = self .gpt .lm_head (hs )
98
+ vs = self .v_head (hs )
98
99
99
- return QVOutput (( logits , qs , target_qs , self . v_head ( hs ) , out .past_key_values ))
100
+ return logits , qs , target_qs , vs , out .past_key_values
100
101
101
102
def loss (self , batch ):
102
103
tokens = batch .input_ids .to (self .device )
@@ -115,7 +116,7 @@ def loss(self, batch):
115
116
116
117
targetQ1 = target_qs [0 ][:, :- 1 ].gather (- 1 , actions ).squeeze (- 1 ).detach ()
117
118
targetQ2 = target_qs [1 ][:, :- 1 ].gather (- 1 , actions ).squeeze (- 1 ).detach ()
118
- targetQ = th .minimum (targetQ1 , targetQ2 )
119
+ targetQ = torch .minimum (targetQ1 , targetQ2 )
119
120
else :
120
121
Q = qs [:, :- 1 ].gather (- 1 , actions ).squeeze (- 1 )
121
122
targetQ = target_qs [:, :- 1 ].gather (- 1 , actions ).squeeze (- 1 ).detach ()
@@ -212,7 +213,7 @@ def sync_target_q_heads(self):
212
213
else :
213
214
self ._sync_target_q_heads (self .alpha )
214
215
215
- @th .inference_mode ()
216
+ @torch .inference_mode ()
216
217
def sample (
217
218
self ,
218
219
query ,
@@ -228,32 +229,32 @@ def sample(
228
229
past_key_values = None
229
230
tensors = defaultdict (list )
230
231
231
- finished = th .zeros (input .shape [0 ], 1 , dtype = th .long , device = query .device )
232
+ finished = torch .zeros (input .shape [0 ], 1 , dtype = torch .long , device = query .device )
232
233
233
234
for _ in range (max_length - 1 ):
234
235
logits , _ , target_qs , vs , past_key_values = self .forward (
235
236
input_ids = input , past_key_values = past_key_values
236
237
)
237
238
238
239
if self .two_qs :
239
- qs = th .minimum (target_qs [0 ][:, - 1 ], target_qs [1 ][:, - 1 ])
240
+ qs = torch .minimum (target_qs [0 ][:, - 1 ], target_qs [1 ][:, - 1 ])
240
241
else :
241
242
qs = target_qs [:, - 1 ]
242
243
243
244
logits = logits [:, - 1 ]
244
245
245
246
if logit_mask is not None :
246
- logits [th .where (logit_mask [input [:, - 1 ]])] = - np .inf
247
+ logits [torch .where (logit_mask [input [:, - 1 ]])] = - np .inf
247
248
248
249
adv = qs - vs [:, - 1 , :]
249
250
pi = F .log_softmax (logits , - 1 )
250
251
modpi = topk_mask (pi + beta * adv , top_k )
251
252
ps = F .softmax (modpi / temperature , - 1 )
252
253
253
- tokens = th .multinomial (ps , 1 )
254
+ tokens = torch .multinomial (ps , 1 )
254
255
tokens = (1 - finished ) * tokens + finished * eos_token_id
255
256
256
- query = th .hstack ((query , tokens ))
257
+ query = torch .hstack ((query , tokens ))
257
258
258
259
input = tokens
259
260
finished = (tokens == eos_token_id ).long ()
@@ -265,21 +266,21 @@ def sample(
265
266
266
267
stats = {}
267
268
for name , xs in tensors .items ():
268
- xs = th .vstack (xs )
269
+ xs = torch .vstack (xs )
269
270
stats .update (
270
271
{
271
272
f"{ name } -min" : xs .min (),
272
273
f"{ name } -max" : xs .max (),
273
274
f"{ name } -std" : xs .std (),
274
- f"{ name } -avg " : xs .mean (),
275
+ f"{ name } -mean " : xs .mean (),
275
276
}
276
277
)
277
278
278
279
return query , stats
279
280
280
281
@property
281
282
def dummy_inputs (self ):
282
- return {"input_ids" : th .ones (1 , 1 , device = self .gpt .device , dtype = th .long )}
283
+ return {"input_ids" : torch .ones (1 , 1 , device = self .gpt .device , dtype = torch .long )}
283
284
284
285
@property
285
286
def device (self ):
0 commit comments