Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions src/magicgui/type_map/_type_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ class MissingWidget(RuntimeError):
"""Raised when a backend widget cannot be found."""


_RETURN_CALLBACKS: DefaultDict[type, list[ReturnCallback]] = defaultdict(list)
_RETURN_CALLBACKS: DefaultDict[frozenset[type], list[ReturnCallback]] = defaultdict(
list
)
_TYPE_DEFS: dict[type, WidgetTuple] = {}


Expand Down Expand Up @@ -426,11 +428,17 @@ def _deco(type_: _T) -> _T:
# if the type is a Union, add the callback to all of the types in the union
# (except NoneType)
if get_origin(resolved_type) is Union:
for t in get_args(resolved_type):
if not _is_none_type(t):
_RETURN_CALLBACKS[t].append(return_callback)
type_set = frozenset(get_args(resolved_type))
else:
type_set = frozenset([resolved_type])

for keyset in sorted(_RETURN_CALLBACKS, key=len, reverse=True):
if type_set.issubset(keyset):
if return_callback not in _RETURN_CALLBACKS[keyset]:
_RETURN_CALLBACKS[keyset].append(return_callback)
break
else:
_RETURN_CALLBACKS[resolved_type].append(return_callback)
_RETURN_CALLBACKS[type_set].append(return_callback)

_options = cast(dict, options)

Expand Down Expand Up @@ -539,14 +547,21 @@ def type2callback(type_: type) -> list[ReturnCallback]:

# look for direct hits ...
# if it's an Optional, we need to look for the type inside the Optional
_, type_ = _is_optional(resolve_single_type(type_))
if type_ in _RETURN_CALLBACKS:
return _RETURN_CALLBACKS[type_]
type_ = resolve_single_type(type_)
if _is_none_type(type_):
return []

sorted_callbacks = sorted(_RETURN_CALLBACKS, key=len, reverse=True)
types = frozenset(get_args(type_) if get_origin(type_) is Union else [type_])
for keyset in sorted_callbacks:
if types.issubset(keyset):
return _RETURN_CALLBACKS[keyset]

# look for subclasses
for registered_type in _RETURN_CALLBACKS: # sourcery skip: use-next
if safe_issubclass(type_, registered_type):
return _RETURN_CALLBACKS[registered_type]
for keyset in sorted_callbacks:
for registered_type in keyset:
if safe_issubclass(type_, registered_type):
return _RETURN_CALLBACKS[keyset]
return []


Expand Down
41 changes: 28 additions & 13 deletions tests/test_magicgui.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from magicgui import magicgui, register_type, type_map, widgets
from magicgui.signature import MagicSignature, magic_signature
from magicgui.type_map import _type_map


def func(a: str = "works", b: int = 3, c=7.1) -> str:
Expand Down Expand Up @@ -483,10 +484,7 @@ def func2(a=1) -> Sub:

func2()
finally:
from magicgui.type_map._type_map import _RETURN_CALLBACKS

_RETURN_CALLBACKS.pop(int)
_RETURN_CALLBACKS.pop(Base)
_type_map._RETURN_CALLBACKS.clear()


# @pytest.mark.skip(reason="need to rethink how to test this")
Expand Down Expand Up @@ -882,22 +880,39 @@ class A:
)


@pytest.mark.parametrize("optional", [True, False])
def test_call_union_return_type(optional: bool):
"""registering Optional[type] should imply registering"""
@pytest.mark.parametrize("register", ["optional", "both", "required"])
@pytest.mark.parametrize("return_optional", [True, False])
def test_call_union_return_type(return_optional: bool, register: str):
"""registering Optional[X] should imply registering X"""
mock = Mock()
_type_map._RETURN_CALLBACKS.clear()

NewInt = NewType("NewInt", int)
register_type(Optional[NewInt], return_callback=mock)

ReturnType = Optional[NewInt] if optional else NewInt
# registering both forms should not result in 2 calls
if register in {"optional", "both"}:
register_type(Optional[NewInt], return_callback=mock)
if register in {"required", "both"}:
register_type(NewInt, return_callback=mock)

ReturnType = Optional[NewInt] if return_optional else NewInt

@magicgui
def func_optional(a: bool) -> ReturnType:
return NewInt(1) if a else None

func_optional(a=True)
mock.assert_called_once_with(func_optional, 1, ReturnType)
mock.reset_mock()
func_optional(a=False)
mock.assert_called_once_with(func_optional, None, ReturnType)
# if the function returns Optional[X] and we only registered X, we should
# not get a callback. (i.e. if the return type Union is not a subset of the
# registered types, we should not get a callback)
if return_optional and register not in {"optional", "both"}:
mock.assert_not_called()
else:
# otherwise, regardless of whether we registered Optional[X] or X,
# the callback should be called just once per function call
mock.assert_called_once_with(func_optional, 1, ReturnType)
mock.reset_mock()
func_optional(a=False)
mock.assert_called_once_with(func_optional, None, ReturnType)

_type_map._RETURN_CALLBACKS.clear()