Skip to content

Commit 50a3f6d

Browse files
authored
Fix regression for tuples with mixed primitives (#341)
* Fix regression for tuples with mixed primitives * Docstring * Python 3.8
1 parent 975f782 commit 50a3f6d

File tree

4 files changed

+148
-14
lines changed

4 files changed

+148
-14
lines changed

src/tyro/constructors/_primitive_spec.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,7 @@ def union_rule(
634634

635635
# General unions, eg Union[int, bool]. We'll try to convert these from left to
636636
# right.
637-
option_specs: dict[TypeForm, PrimitiveConstructorSpec] = {}
637+
option_specs: dict[TypeForm[object], PrimitiveConstructorSpec] = {}
638638
choices: tuple[str, ...] | None = ()
639639
nargs: int | tuple[int, ...] | Literal["*"] = 1
640640
first = True
@@ -758,7 +758,9 @@ def str_from_instance(instance: Any) -> list[str]:
758758
nargs=nargs,
759759
metavar=metavar,
760760
instance_from_str=union_instantiator,
761-
is_instance=lambda x: any(spec.is_instance(x) for spec in option_specs),
761+
is_instance=lambda x: any(
762+
spec.is_instance(x) for spec in option_specs.values()
763+
),
762764
str_from_instance=str_from_instance,
763765
choices=None if choices is None else tuple(set(choices)),
764766
)

tests/test_mixed_primitive.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
"""Adapted from @mirceamironenco.
2+
3+
https://github.com/brentyi/tyro/issues/340"""
4+
5+
from __future__ import annotations
6+
7+
from dataclasses import dataclass
8+
9+
from helptext_utils import get_helptext_with_checks
10+
11+
import tyro
12+
13+
14+
@dataclass(frozen=True)
15+
class HFSFTDatasetConfig:
16+
name: str
17+
split: str | None = None
18+
data_files: str | None = None
19+
20+
completions_only: bool = True
21+
packed_seqlen: int | None = None
22+
apply_chat_template: bool = False
23+
max_seq_len: int | None = None
24+
25+
columns: tuple[str, str, str | None, str | None] | None = None
26+
"""instruction, completion, input, system"""
27+
28+
29+
@dataclass
30+
class Config:
31+
train_data: HFSFTDatasetConfig = HFSFTDatasetConfig(
32+
name="yahma/alpaca-cleaned",
33+
completions_only=True,
34+
packed_seqlen=4097,
35+
max_seq_len=2048,
36+
apply_chat_template=True,
37+
columns=("instruction", "output", "input", None),
38+
)
39+
40+
41+
def test_mixed_primitive() -> None:
42+
assert (
43+
tyro.cli(
44+
Config,
45+
args=["--train-data.name", "foo", "--train-data.completions-only"],
46+
).train_data.name
47+
== "foo"
48+
)
49+
assert "None}|{STR STR {None}|STR {None}|STR}" in get_helptext_with_checks(Config)

tests/test_py311_generated/test_add_help_generated.py

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# mypy: ignore-errors
12
"""Tests for add_help parameter functionality."""
23

34
import dataclasses
@@ -22,15 +23,15 @@ def simple_function(x: int, y: str = "hello") -> str:
2223
return f"{x}:{y}"
2324

2425

25-
def test_cli_add_help_false():
26+
def test_cli_add_help_false() -> None:
2627
"""Test that add_help=False prevents help from being added."""
2728
# Should raise an error when --help is provided with add_help=False
2829
with pytest.raises(SystemExit) as excinfo:
2930
tyro.cli(SimpleConfig, args=["--help"], add_help=False)
3031
assert excinfo.value.code == 2 # Should fail with parsing error, not help display
3132

3233

33-
def test_cli_add_help_true():
34+
def test_cli_add_help_true() -> None:
3435
"""Test that add_help=True (default) allows help."""
3536
# Should show help and exit with code 0
3637
with pytest.raises(SystemExit) as excinfo:
@@ -44,7 +45,7 @@ def test_cli_add_help_true():
4445
assert excinfo.value.code == 0 # Should exit cleanly after showing help
4546

4647

47-
def test_cli_default_add_help():
48+
def test_cli_default_add_help() -> None:
4849
"""Test that add_help defaults to True."""
4950
# Should show help and exit with code 0
5051
with pytest.raises(SystemExit) as excinfo:
@@ -57,14 +58,14 @@ def test_cli_default_add_help():
5758
assert excinfo.value.code == 0
5859

5960

60-
def test_get_parser_add_help_false():
61+
def test_get_parser_add_help_false() -> None:
6162
"""Test that get_parser with add_help=False doesn't add help option."""
6263
parser = tyro.extras.get_parser(SimpleConfig, add_help=False)
6364
assert "-h" not in parser._option_string_actions
6465
assert "--help" not in parser._option_string_actions
6566

6667

67-
def test_get_parser_add_help_true():
68+
def test_get_parser_add_help_true() -> None:
6869
"""Test that get_parser with add_help=True adds help option."""
6970
parser = tyro.extras.get_parser(SimpleConfig, add_help=True)
7071
assert (
@@ -73,7 +74,7 @@ def test_get_parser_add_help_true():
7374
)
7475

7576

76-
def test_get_parser_default_add_help():
77+
def test_get_parser_default_add_help() -> None:
7778
"""Test that get_parser defaults to add_help=True."""
7879
parser = tyro.extras.get_parser(SimpleConfig)
7980
assert (
@@ -82,7 +83,7 @@ def test_get_parser_default_add_help():
8283
)
8384

8485

85-
def test_function_cli_add_help():
86+
def test_function_cli_add_help() -> None:
8687
"""Test add_help works with function targets."""
8788
# Test with add_help=False
8889
result = tyro.cli(simple_function, args=["--x", "42"], add_help=False)
@@ -94,7 +95,7 @@ def test_function_cli_add_help():
9495
assert excinfo.value.code == 2
9596

9697

97-
def test_subcommand_app_add_help():
98+
def test_subcommand_app_add_help() -> None:
9899
"""Test add_help parameter with SubcommandApp."""
99100
from tyro.extras import SubcommandApp
100101

@@ -118,7 +119,7 @@ def cmd2(y: str) -> str:
118119
assert excinfo.value.code == 2
119120

120121

121-
def test_subcommand_cli_from_dict_add_help():
122+
def test_subcommand_cli_from_dict_add_help() -> None:
122123
"""Test add_help parameter with subcommand_cli_from_dict."""
123124

124125
def cmd1(x: int) -> int:
@@ -143,7 +144,7 @@ def cmd2(y: str) -> str:
143144
assert excinfo.value.code == 2
144145

145146

146-
def test_overridable_config_cli_add_help():
147+
def test_overridable_config_cli_add_help() -> None:
147148
"""Test add_help parameter with overridable_config_cli."""
148149

149150
@dataclasses.dataclass
@@ -167,7 +168,7 @@ class Config:
167168
assert excinfo.value.code == 2
168169

169170

170-
def test_subparsers_respect_add_help():
171+
def test_subparsers_respect_add_help() -> None:
171172
"""Test that subparsers inherit the add_help setting from parent parser."""
172173

173174
@dataclasses.dataclass
@@ -199,7 +200,7 @@ class ConfigB:
199200
assert result.value_a == 1
200201

201202

202-
def test_return_unknown_args_with_add_help_false():
203+
def test_return_unknown_args_with_add_help_false() -> None:
203204
"""Test that --help/-h are returned as unknown args when add_help=False and return_unknown_args=True."""
204205

205206
@dataclasses.dataclass
@@ -241,3 +242,36 @@ class Config:
241242
finally:
242243
sys.stdout = old_stdout
243244
assert excinfo.value.code == 0 # Should exit cleanly after showing help
245+
246+
247+
def test_error_messages_respect_add_help() -> None:
248+
"""Test that error messages don't suggest --help when add_help=False."""
249+
import contextlib
250+
251+
@dataclasses.dataclass
252+
class Config:
253+
required_field: int
254+
255+
# Capture stderr to check error messages
256+
captured_output = io.StringIO()
257+
258+
# Test with add_help=False - should not mention --help
259+
with pytest.raises(SystemExit):
260+
with contextlib.redirect_stderr(captured_output):
261+
tyro.cli(Config, args=[], add_help=False, console_outputs=True)
262+
263+
error_message = captured_output.getvalue()
264+
assert "--help" not in error_message, (
265+
f"Error message should not mention --help when add_help=False. Got: {error_message}"
266+
)
267+
268+
# Test with add_help=True - should mention --help
269+
captured_output = io.StringIO()
270+
with pytest.raises(SystemExit):
271+
with contextlib.redirect_stderr(captured_output):
272+
tyro.cli(Config, args=[], add_help=True, console_outputs=True)
273+
274+
error_message = captured_output.getvalue()
275+
assert "--help" in error_message, (
276+
f"Error message should mention --help when add_help=True. Got: {error_message}"
277+
)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
"""Adapted from @mirceamironenco.
2+
3+
https://github.com/brentyi/tyro/issues/340"""
4+
5+
from __future__ import annotations
6+
7+
from dataclasses import dataclass
8+
9+
from helptext_utils import get_helptext_with_checks
10+
11+
import tyro
12+
13+
14+
@dataclass(frozen=True)
15+
class HFSFTDatasetConfig:
16+
name: str
17+
split: str | None = None
18+
data_files: str | None = None
19+
20+
completions_only: bool = True
21+
packed_seqlen: int | None = None
22+
apply_chat_template: bool = False
23+
max_seq_len: int | None = None
24+
25+
columns: tuple[str, str, str | None, str | None] | None = None
26+
"""instruction, completion, input, system"""
27+
28+
29+
@dataclass
30+
class Config:
31+
train_data: HFSFTDatasetConfig = HFSFTDatasetConfig(
32+
name="yahma/alpaca-cleaned",
33+
completions_only=True,
34+
packed_seqlen=4097,
35+
max_seq_len=2048,
36+
apply_chat_template=True,
37+
columns=("instruction", "output", "input", None),
38+
)
39+
40+
41+
def test_mixed_primitive() -> None:
42+
assert (
43+
tyro.cli(
44+
Config,
45+
args=["--train-data.name", "foo", "--train-data.completions-only"],
46+
).train_data.name
47+
== "foo"
48+
)
49+
assert "None}|{STR STR {None}|STR {None}|STR}" in get_helptext_with_checks(Config)

0 commit comments

Comments
 (0)