Skip to content

Commit 1394e08

Browse files
authored
Support ONNX export for causal LM sequence classifiers (#27450)
support onnx for causal lm sequence classification
1 parent: 06343b0 · commit: 1394e08

File tree

14 files changed

+14
-14
lines changed

14 files changed

+14
-14
lines changed

src/transformers/models/ctrl/modeling_ctrl.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -796,7 +796,7 @@ def forward(
796796
sequence_lengths = -1
797797
else:
798798
if input_ids is not None:
799-
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
799+
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
800800
logits.device
801801
)
802802
else:

src/transformers/models/deprecated/open_llama/modeling_open_llama.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -924,7 +924,7 @@ def forward(
924924
sequence_lengths = -1
925925
else:
926926
if input_ids is not None:
927-
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
927+
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
928928
logits.device
929929
)
930930
else:

src/transformers/models/gpt2/modeling_gpt2.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -1451,7 +1451,7 @@ def forward(
14511451
sequence_lengths = -1
14521452
else:
14531453
if input_ids is not None:
1454-
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
1454+
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
14551455
logits.device
14561456
)
14571457
else:

src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -1184,7 +1184,7 @@ def forward(
11841184
sequence_lengths = -1
11851185
else:
11861186
if input_ids is not None:
1187-
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
1187+
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
11881188
logits.device
11891189
)
11901190
else:

src/transformers/models/gpt_neo/modeling_gpt_neo.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -1090,7 +1090,7 @@ def forward(
10901090
sequence_lengths = -1
10911091
else:
10921092
if input_ids is not None:
1093-
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
1093+
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
10941094
logits.device
10951095
)
10961096
else:

src/transformers/models/gpt_neox/modeling_gpt_neox.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -948,7 +948,7 @@ def forward(
948948
sequence_lengths = -1
949949
else:
950950
if input_ids is not None:
951-
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
951+
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
952952
logits.device
953953
)
954954
else:

src/transformers/models/gptj/modeling_gptj.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -1001,7 +1001,7 @@ def forward(
10011001
sequence_lengths = -1
10021002
else:
10031003
if input_ids is not None:
1004-
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
1004+
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
10051005
logits.device
10061006
)
10071007
else:

src/transformers/models/llama/modeling_llama.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -1204,7 +1204,7 @@ def forward(
12041204
sequence_lengths = -1
12051205
else:
12061206
if input_ids is not None:
1207-
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
1207+
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
12081208
logits.device
12091209
)
12101210
else:

src/transformers/models/mistral/modeling_mistral.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -1174,7 +1174,7 @@ def forward(
11741174
sequence_lengths = -1
11751175
else:
11761176
if input_ids is not None:
1177-
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
1177+
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
11781178
logits.device
11791179
)
11801180
else:

src/transformers/models/openai/modeling_openai.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -814,7 +814,7 @@ def forward(
814814
sequence_lengths = -1
815815
else:
816816
if input_ids is not None:
817-
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
817+
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
818818
logits.device
819819
)
820820
else:

0 commit comments

Comments (0)