src/transformers/pipelines/base.py (7 changes: 5 additions & 2 deletions)
@@ -24,7 +24,7 @@
 
 from transformers.file_utils import add_end_docstrings, is_tf_available, is_torch_available
 from transformers.modelcard import ModelCard
-from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.tokenization_utils import PreTrainedTokenizer, TruncationStrategy
 from transformers.utils import logging
 

@@ -577,7 +577,9 @@ def check_model_type(self, supported_models: Union[List[str], dict]):
                 f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}",
             )
 
-    def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
+    def _parse_and_tokenize(
+        self, inputs, padding=True, add_special_tokens=True, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs
+    ):
         """
         Parse arguments and tokenize
         """
@@ -587,6 +589,7 @@ def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
             add_special_tokens=add_special_tokens,
             return_tensors=self.framework,
             padding=padding,
+            truncation=truncation,
         )
 
         return inputs
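For context, here is a minimal sketch of what the new truncation argument changes at the tokenizer call. It is not part of the PR itself; the checkpoint name ("bert-base-uncased") and the input text are illustrative stand-ins.

# Illustrative sketch only, not the pipeline's actual code path.
from transformers import AutoTokenizer
from transformers.tokenization_utils import TruncationStrategy

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
long_text = "a very long user input " * 200

# New pipeline default: nothing is cut off implicitly.
untruncated = tokenizer(
    long_text,
    return_tensors="pt",
    padding=True,
    truncation=TruncationStrategy.DO_NOT_TRUNCATE,
)

# Callers or pipeline subclasses can now opt in to truncation explicitly.
truncated = tokenizer(
    long_text,
    return_tensors="pt",
    padding=True,
    truncation=TruncationStrategy.LONGEST_FIRST,
    max_length=tokenizer.model_max_length,
)
print(untruncated["input_ids"].shape)  # full length, may exceed model_max_length
print(truncated["input_ids"].shape)    # capped at model_max_length (512 for BERT)

Keeping DO_NOT_TRUNCATE as the default appears to preserve the pipelines' existing behavior, while individual pipelines or callers can override it per call.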
src/transformers/pipelines/conversational.py (7 changes: 5 additions & 2 deletions)
@@ -2,6 +2,7 @@
 from typing import List, Optional, Union
 
 from transformers.file_utils import add_end_docstrings, is_tf_available, is_torch_available
+from transformers.tokenization_utils import TruncationStrategy
 from transformers.utils import logging
 
 from .base import PIPELINE_INIT_ARGS, Pipeline
@@ -274,12 +275,14 @@ def __call__(
         else:
             return output
 
-    def _parse_and_tokenize(self, inputs, **kwargs):
+    def _parse_and_tokenize(
+        self, inputs, add_special_tokens=False, padding=False, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs
+    ):
         """
         Parse arguments and tokenize, adding an EOS token at the end of the user input
         """
         # Parse arguments
-        inputs = self.tokenizer(inputs, add_special_tokens=False, padding=False).get("input_ids", [])
+        inputs = self.tokenizer(inputs, add_special_tokens=add_special_tokens, padding=padding).get("input_ids", [])
         for input in inputs:
             input.append(self.tokenizer.eos_token_id)
         return inputs
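To make the conversational path concrete, here is a small sketch of what this override produces at runtime. It is not from the PR; the DialoGPT checkpoint is an illustrative stand-in for whatever model the pipeline wraps.

# Illustrative sketch of the conversational tokenization above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")

turns = ["Hello, how are you?", "I am fine, and you?"]
# Tokenize each turn without special tokens or padding ...
input_ids = tokenizer(turns, add_special_tokens=False, padding=False)["input_ids"]
# ... then append EOS so the model can tell where each user input ends.
for ids in input_ids:
    ids.append(tokenizer.eos_token_id)
# input_ids is now a list of token-id lists, each terminated by EOS.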