1
+ import functools
1
2
import re
2
3
import sys
3
4
from argparse import Action
29
30
to_yaml ,
30
31
unsafe_merge ,
31
32
)
33
+ from .field_utils import field
32
34
from .flexyclasses import is_flexyclass
33
35
from .logging import configure_logging
34
36
from .nicknames import NicknameRegistry
42
44
43
45
# parameters for the main function
44
46
MP = ParamSpec ("MP" )
47
+ NP = ParamSpec ("NP" )
45
48
46
49
# type for the configuration
47
50
CT = TypeVar ("CT" )
@@ -92,10 +95,14 @@ def add_argparse(self, parser: RichArgumentParser) -> Action:
92
95
def __str__ (self ) -> str :
93
96
return f"{ self .short } /{ self .long } "
94
97
98
+ @classmethod
99
+ def field (cls , * args , ** kwargs ) -> "Flag" :
100
+ return field (default_factory = lambda : cls (* args , ** kwargs ))
101
+
95
102
96
103
@dataclass
97
104
class CliFlags :
98
- config : Flag = Flag (
105
+ config : Flag = Flag . field (
99
106
name = "config" ,
100
107
help = (
101
108
"either a path to a YAML file containing a configuration, or "
@@ -107,22 +114,22 @@ class CliFlags:
107
114
action = "append" ,
108
115
metavar = "/path/to/config.yaml" ,
109
116
)
110
- options : Flag = Flag (
117
+ options : Flag = Flag . field (
111
118
name = "options" ,
112
119
help = "print all default options and CLI flags." ,
113
120
action = "store_true" ,
114
121
)
115
- inputs : Flag = Flag (
122
+ inputs : Flag = Flag . field (
116
123
name = "inputs" ,
117
124
help = "print the input configuration." ,
118
125
action = "store_true" ,
119
126
)
120
- parsed : Flag = Flag (
127
+ parsed : Flag = Flag . field (
121
128
name = "parsed" ,
122
129
help = "print the parsed configuration." ,
123
130
action = "store_true" ,
124
131
)
125
- log_level : Flag = Flag (
132
+ log_level : Flag = Flag . field (
126
133
name = "log-level" ,
127
134
help = (
128
135
"logging level to use for this program; can be one of "
@@ -131,30 +138,30 @@ class CliFlags:
131
138
default = "WARNING" ,
132
139
choices = ["CRITICAL" , "ERROR" , "WARNING" , "INFO" , "DEBUG" ],
133
140
)
134
- debug : Flag = Flag (
141
+ debug : Flag = Flag . field (
135
142
name = "debug" ,
136
143
help = "enable debug mode; equivalent to '--log-level DEBUG'" ,
137
144
action = "store_true" ,
138
145
)
139
- quiet : Flag = Flag (
146
+ quiet : Flag = Flag . field (
140
147
name = "quiet" ,
141
148
help = "if provided, it does not print the configuration when running" ,
142
149
action = "store_true" ,
143
150
)
144
- resolvers : Flag = Flag (
151
+ resolvers : Flag = Flag . field (
145
152
name = "resolvers" ,
146
153
help = (
147
154
"print all registered resolvers in OmegaConf, "
148
155
"Springs, and current codebase"
149
156
),
150
157
action = "store_true" ,
151
158
)
152
- nicknames : Flag = Flag (
159
+ nicknames : Flag = Flag . field (
153
160
name = "nicknames" ,
154
161
help = "print all registered nicknames in Springs" ,
155
162
action = "store_true" ,
156
163
)
157
- save : Flag = Flag (
164
+ save : Flag = Flag . field (
158
165
name = "save" ,
159
166
help = "save the configuration to a YAML file and exit" ,
160
167
default = None ,
@@ -430,10 +437,8 @@ def wrap_main_method(
430
437
def cli (
431
438
config_node_cls : Optional [Type [CT ]] = None ,
432
439
) -> Callable [
433
- [
434
- # this is a main method that takes as first input a parsed config
435
- Callable [Concatenate [CT , MP ], RT ]
436
- ],
440
+ # this is a main method that takes as first input a parsed config
441
+ [Callable [Concatenate [CT , MP ], RT ]],
437
442
# the decorated method doesn't expect the parsed config as first input,
438
443
# since that will be parsed from the command line
439
444
Callable [MP , RT ],
@@ -487,6 +492,7 @@ def main(cfg: Config):
487
492
name = config_node_cls .__name__
488
493
489
494
def wrapper (func : Callable [Concatenate [CT , MP ], RT ]) -> Callable [MP , RT ]:
495
+ @functools .wraps (func )
490
496
def wrapping (* args : MP .args , ** kwargs : MP .kwargs ) -> RT :
491
497
# I could have used a functools.partial here, but defining
492
498
# my own function instead allows me to provide nice typing
@@ -501,4 +507,8 @@ def wrapping(*args: MP.args, **kwargs: MP.kwargs) -> RT:
501
507
502
508
return wrapping
503
509
504
- return wrapper
510
+ # TODO: figure out why mypy complains with the following error:
511
+ # Incompatible return value type (got "Callable[[Arg(Callable[[CT,
512
+ # **MP], RT], 'func')], Callable[MP, RT]]", expected
513
+ # "Callable[[Callable[[CT, **MP], RT]], Callable[MP, RT]]")
514
+ return wrapper # type: ignore
0 commit comments