Skip to content

Commit 2692447

Browse files
committed
Choices for Str not just list
In the case of the Str (which translates to a StrEnum for choices) the explicit list isinstance check is incorrect. There are many cases where it would be valid to send an `Iterable` such as with `dict.keys()` returning a `dict_keys`. As we do not want to explictly cast to `list` type this commit results in a check of iterable type instead of list.
1 parent ed7cc89 commit 2692447

File tree

4 files changed

+18
-2
lines changed

4 files changed

+18
-2
lines changed

python/cog/predictor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import types
99
import uuid
1010
from abc import ABC, abstractmethod
11-
from collections.abc import Iterator
11+
from collections.abc import Iterable, Iterator
1212
from pathlib import Path
1313
from typing import (
1414
Any,
@@ -354,7 +354,7 @@ def get_input_create_model_kwargs(signature: inspect.Signature) -> Dict[str, Any
354354
# In either case, remove it as an extra field because it will be
355355
# passed automatically as 'enum' in the schema
356356
if choices:
357-
if InputType == str and isinstance(choices, list): # noqa: E721
357+
if InputType == str and isinstance(choices, Iterable): # noqa: E721
358358

359359
class StringEnum(str, enum.Enum):
360360
pass
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from cog import BasePredictor, Input
2+
3+
4+
class Predictor(BasePredictor):
5+
def predict(self, text: str = Input(choices={"foo": "x", "bar": "y"}.keys())) -> str:
6+
assert type(text) == str
7+
return text

python/tests/server/test_http_input.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,14 @@ def test_choices_str(client):
214214
assert resp.status_code == 422
215215

216216

217+
@uses_predictor("input_choices_iterable")
218+
def test_choices_str(client):
219+
resp = client.post("/predictions", json={"input": {"text": "foo"}})
220+
assert resp.status_code == 200
221+
resp = client.post("/predictions", json={"input": {"text": "baz"}})
222+
assert resp.status_code == 422
223+
224+
217225
@uses_predictor("input_choices_integer")
218226
def test_choices_int(client):
219227
resp = client.post("/predictions", json={"input": {"x": 1}})

python/tests/server/test_predictor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
PREDICTOR_FIXTURES = [
1515
("input_choices", "Predictor", "predict"),
1616
("input_choices_integer", "Predictor", "predict"),
17+
("input_choices_iterable", "Predictor", "predict"),
1718
("input_file", "Predictor", "predict"),
1819
("function", "predict", "predict"),
1920
("input_ge_le", "Predictor", "predict"),

0 commit comments

Comments
 (0)