Skip to content

Commit 24e58dd

Browse files
authored
Merge pull request #22 from bnewm0609/main
Upgrade to Python 3.11
2 parents 882cca4 + 5ed3f54 commit 24e58dd

14 files changed

+116
-105
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "springs"
3-
version = "1.12.3"
3+
version = "1.13.0"
44
description = """\
55
A set of utilities to create and manage typed configuration files \
66
effectively, built on top of OmegaConf.\

src/springs/__init__.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@
3333
debug_logger,
3434
fdict,
3535
flist,
36-
fobj,
37-
fval,
3836
get_nickname,
3937
make_flexy,
4038
make_target,
@@ -49,7 +47,6 @@
4947
__version__ = get_version()
5048

5149
__all__ = [
52-
"add_help",
5350
"all_resolvers",
5451
"cast",
5552
"cli",
@@ -63,8 +60,6 @@
6360
"field",
6461
"flexyclass",
6562
"flist",
66-
"fobj",
67-
"fval",
6863
"from_dataclass",
6964
"from_dict",
7065
"from_file",

src/springs/commandline.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import functools
12
import re
23
import sys
34
from argparse import Action
@@ -29,6 +30,7 @@
2930
to_yaml,
3031
unsafe_merge,
3132
)
33+
from .field_utils import field
3234
from .flexyclasses import is_flexyclass
3335
from .logging import configure_logging
3436
from .nicknames import NicknameRegistry
@@ -42,6 +44,7 @@
4244

4345
# parameters for the main function
4446
MP = ParamSpec("MP")
47+
NP = ParamSpec("NP")
4548

4649
# type for the configuration
4750
CT = TypeVar("CT")
@@ -92,10 +95,14 @@ def add_argparse(self, parser: RichArgumentParser) -> Action:
9295
def __str__(self) -> str:
9396
return f"{self.short}/{self.long}"
9497

98+
@classmethod
99+
def field(cls, *args, **kwargs) -> "Flag":
100+
return field(default_factory=lambda: cls(*args, **kwargs))
101+
95102

96103
@dataclass
97104
class CliFlags:
98-
config: Flag = Flag(
105+
config: Flag = Flag.field(
99106
name="config",
100107
help=(
101108
"either a path to a YAML file containing a configuration, or "
@@ -107,22 +114,22 @@ class CliFlags:
107114
action="append",
108115
metavar="/path/to/config.yaml",
109116
)
110-
options: Flag = Flag(
117+
options: Flag = Flag.field(
111118
name="options",
112119
help="print all default options and CLI flags.",
113120
action="store_true",
114121
)
115-
inputs: Flag = Flag(
122+
inputs: Flag = Flag.field(
116123
name="inputs",
117124
help="print the input configuration.",
118125
action="store_true",
119126
)
120-
parsed: Flag = Flag(
127+
parsed: Flag = Flag.field(
121128
name="parsed",
122129
help="print the parsed configuration.",
123130
action="store_true",
124131
)
125-
log_level: Flag = Flag(
132+
log_level: Flag = Flag.field(
126133
name="log-level",
127134
help=(
128135
"logging level to use for this program; can be one of "
@@ -131,30 +138,30 @@ class CliFlags:
131138
default="WARNING",
132139
choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"],
133140
)
134-
debug: Flag = Flag(
141+
debug: Flag = Flag.field(
135142
name="debug",
136143
help="enable debug mode; equivalent to '--log-level DEBUG'",
137144
action="store_true",
138145
)
139-
quiet: Flag = Flag(
146+
quiet: Flag = Flag.field(
140147
name="quiet",
141148
help="if provided, it does not print the configuration when running",
142149
action="store_true",
143150
)
144-
resolvers: Flag = Flag(
151+
resolvers: Flag = Flag.field(
145152
name="resolvers",
146153
help=(
147154
"print all registered resolvers in OmegaConf, "
148155
"Springs, and current codebase"
149156
),
150157
action="store_true",
151158
)
152-
nicknames: Flag = Flag(
159+
nicknames: Flag = Flag.field(
153160
name="nicknames",
154161
help="print all registered nicknames in Springs",
155162
action="store_true",
156163
)
157-
save: Flag = Flag(
164+
save: Flag = Flag.field(
158165
name="save",
159166
help="save the configuration to a YAML file and exit",
160167
default=None,
@@ -430,10 +437,8 @@ def wrap_main_method(
430437
def cli(
431438
config_node_cls: Optional[Type[CT]] = None,
432439
) -> Callable[
433-
[
434-
# this is a main method that takes as first input a parsed config
435-
Callable[Concatenate[CT, MP], RT]
436-
],
440+
# this is a main method that takes as first input a parsed config
441+
[Callable[Concatenate[CT, MP], RT]],
437442
# the decorated method doesn't expect the parsed config as first input,
438443
# since that will be parsed from the command line
439444
Callable[MP, RT],
@@ -487,6 +492,7 @@ def main(cfg: Config):
487492
name = config_node_cls.__name__
488493

489494
def wrapper(func: Callable[Concatenate[CT, MP], RT]) -> Callable[MP, RT]:
495+
@functools.wraps(func)
490496
def wrapping(*args: MP.args, **kwargs: MP.kwargs) -> RT:
491497
# I could have used a functools.partial here, but defining
492498
# my own function instead allows me to provide nice typing
@@ -501,4 +507,8 @@ def wrapping(*args: MP.args, **kwargs: MP.kwargs) -> RT:
501507

502508
return wrapping
503509

504-
return wrapper
510+
# TODO: figure out why mypy complains with the following error:
511+
# Incompatible return value type (got "Callable[[Arg(Callable[[CT,
512+
# **MP], RT], 'func')], Callable[MP, RT]]", expected
513+
# "Callable[[Callable[[CT, **MP], RT]], Callable[MP, RT]]")
514+
return wrapper # type: ignore

src/springs/flexyclasses.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,18 @@
99

1010
from .utils import get_annotations
1111

12-
C = TypeVar("C", bound=Any)
12+
_C = TypeVar("_C", bound=Any)
1313

1414

15-
class FlexyClass(dict, Generic[C]):
15+
class FlexyClass(dict, Generic[_C]):
1616
"""A FlexyClass is a dictionary with some default values assigned to it
1717
FlexyClasses are generally not used directly, but rather creating using
1818
the `flexyclass` decorator.
1919
2020
NOTE: When instantiating a new FlexyClass object directly, the constructor
21-
actually returns a `dataclasses.Field` object. This is for API consistency
22-
with how dataclasses are used in a structured configuration. If you want to
23-
access values in the FlexyClass directly, use FlexyClass.defaults property.
21+
actually returns a `dict` object. This is for API consistency with how
22+
dataclasses are used in a structured configuration. If you want to access
23+
values in the FlexyClass directly, use FlexyClass.defaults property.
2424
"""
2525

2626
__origin__: type = dict
@@ -60,7 +60,8 @@ def __new__(cls, **kwargs):
6060
# to use flexyclasses in the same way they would use a dataclass.
6161
factory_dict: Dict[str, Any] = {}
6262
factory_dict = {**cls.defaults(), **kwargs}
63-
return field(default_factory=lambda: factory_dict)
63+
return factory_dict
64+
# return field(default_factory=lambda: factory_dict)
6465

6566
@classmethod
6667
def to_dict_config(cls, **kwargs: Any) -> DictConfig:
@@ -70,7 +71,7 @@ def to_dict_config(cls, **kwargs: Any) -> DictConfig:
7071
return from_dict({**cls.defaults(), **kwargs})
7172

7273
@classmethod
73-
def flexyclass(cls, target_cls: Type[C]) -> Type["FlexyClass"]:
74+
def flexyclass(cls, target_cls: Type[_C]) -> Type["FlexyClass[_C]"]:
7475
"""Decorator to create a FlexyClass from a class"""
7576

7677
if is_dataclass(target_cls):
@@ -86,15 +87,16 @@ def flexyclass(cls, target_cls: Type[C]) -> Type["FlexyClass"]:
8687
for f_name, f_value in attributes_iterator
8788
}
8889

89-
return type(
90+
rt = type(
9091
target_cls.__name__,
9192
(FlexyClass,),
9293
{"__flexyclass_defaults__": defaults},
9394
)
95+
return rt
9496

9597

96-
@dataclass_transform()
97-
def flexyclass(cls: Type[C]) -> Type[FlexyClass[C]]:
98+
@dataclass_transform(field_specifiers=(Field, field))
99+
def flexyclass(cls: Type[_C]) -> Type[FlexyClass[_C]]:
98100
"""Alias for FlexyClass.flexyclass"""
99101
return FlexyClass.flexyclass(cls)
100102

src/springs/rich_utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,17 @@
22
import re
33
from argparse import SUPPRESS, ArgumentParser
44
from dataclasses import dataclass
5-
from typing import IO, Any, Dict, Generator, List, Optional, Sequence, Union
5+
from typing import (
6+
IO,
7+
Any,
8+
Dict,
9+
Generator,
10+
List,
11+
Optional,
12+
Sequence,
13+
Tuple,
14+
Union,
15+
)
616

717
from omegaconf import DictConfig, ListConfig
818
from rich import box
@@ -153,7 +163,7 @@ def format_usage(self):
153163
for ag in self._action_groups:
154164
for act in ag._group_actions:
155165
if isinstance(act.metavar, str):
156-
metavar = (act.metavar,)
166+
metavar: Tuple[str, ...] = (act.metavar,)
157167
elif act.metavar is None:
158168
metavar = (act.dest.upper(),)
159169
else:

src/springs/shortcuts.py

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -73,37 +73,14 @@ def make_flexy(cls_: Any) -> Any:
7373
return flexyclass(cls_)
7474

7575

76-
def fval(value: T, **kwargs) -> T:
77-
"""Shortcut for creating a Field with a default value.
78-
79-
Args:
80-
value: value returned by default factory"""
81-
82-
return field(default=value, **kwargs)
83-
84-
85-
def fobj(object: T, **kwargs) -> T:
86-
"""Shortcut for creating a Field with a default_factory that returns
87-
a specific object.
88-
89-
Args:
90-
obj: object returned by default factory"""
91-
92-
def _factory_fn() -> T:
93-
# make a copy so that the same object isn't returned
94-
# (it's a factory, not a singleton!)
95-
return copy.deepcopy(object)
96-
97-
return field(default_factory=_factory_fn, **kwargs)
98-
99-
10076
def fdict(**kwargs: Any) -> Dict[str, Any]:
10177
"""Shortcut for creating a Field with a default_factory that returns
10278
a dictionary.
10379
10480
Args:
10581
**kwargs: values for the dictionary returned by default factory"""
106-
return fobj(kwargs)
82+
kwargs = copy.deepcopy(kwargs)
83+
return field(default_factory=lambda: kwargs)
10784

10885

10986
def flist(*args: Any) -> List[Any]:
@@ -112,7 +89,8 @@ def flist(*args: Any) -> List[Any]:
11289
11390
Args:
11491
*args: values for the list returned by default factory"""
115-
return fobj(list(args))
92+
l_args = list(copy.deepcopy(args))
93+
return field(default_factory=lambda: l_args)
11694

11795

11896
def debug_logger(*args: Any, **kwargs: Any) -> Logger:

0 commit comments

Comments
 (0)