Skip to content

Commit 65aafdd

Browse files
fix(modeling_ilql): single q-head indexing (#471)
* fix(modeling_ilql): single q-head indexing in `.generate` * fix(modeling_ilql): same error for seq2seq --------- Co-authored-by: reciprocated <[email protected]>
1 parent 51832a6 commit 65aafdd

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

trlx/models/modeling_ilql.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def generate(
289289
if self.two_qs:
290290
qs = torch.minimum(target_qs[0][:, -1, :], target_qs[1][:, -1, :])
291291
else:
292-
qs = target_qs[:, -1, :]
292+
qs = target_qs[0][:, -1, :]
293293

294294
logits = logits[:, -1, :]
295295
vs = vs[:, -1, :]
@@ -469,7 +469,7 @@ def generate(
469469
if self.two_qs:
470470
qs = torch.minimum(target_qs[0][:, -1, :], target_qs[1][:, -1, :])
471471
else:
472-
qs = target_qs[:, -1, :]
472+
qs = target_qs[0][:, -1, :]
473473

474474
logits = logits[:, -1, :]
475475
vs = vs[:, -1, :]

0 commit comments

Comments
 (0)