-
-
Notifications
You must be signed in to change notification settings - Fork 581
Description
Feature Request Type
- Core functionality
- Alteration (enhancement/optimization) of existing feature(s)
- New behavior
Description
Providing a way for users to declaratively implement non-schema-based validation of input data would help reduce boilerplate logic in mutations, especially when the same input types are reused across multiple mutations.
In my case, I have written a proof-of-concept, annotation-based input extension that, combined with a field extension, achieves this reasonably well. Any validation errors are raised as a custom exception type; the field extension catches it and returns the errors in a Shopify-style `userErrors` field that all my mutations include in their return type.
It is largely untested and probably contains bugs, since I do not want to overcommit to an approach that may break because of how it locates type metadata. Still, perhaps it can start a discussion on the topic, or at least help me find out whether the APIs it touches are stable enough to rely on for further development.
Please let me know your thoughts :)
import types
from typing import Annotated, List, Optional, Union, get_args, get_origin, get_type_hints

import strawberry
from flask import Flask
from strawberry import UNSET
from strawberry.extensions import FieldExtension
from strawberry.flask.views import GraphQLView
from strawberry.schema.compat import is_input_type
from strawberry.utils.str_converters import to_camel_case
def convert_string_to_camel_case(value):
    """Return *value* converted to camelCase when it is a ``str``.

    Any non-string value is handed back untouched.
    """
    return to_camel_case(value) if isinstance(value, str) else value
def convert_field_to_camel_case(field):
    """Convert every part of an error path to a camelCase string.

    Each part is passed through ``str()`` first, so integer list indices
    (e.g. ``0``) become ``"0"`` before the camelCase conversion.

    Args:
        field: An iterable of path parts, or ``None``.

    Returns:
        A new list of strings, or ``None`` when *field* is ``None``.
        ``UserError.field`` is optional and ``UserErrorException`` defaults
        ``field`` to ``None``, so ``None`` must be accepted here rather than
        raising ``TypeError`` on iteration.
    """
    if field is None:
        return None
    return [convert_string_to_camel_case(str(part)) for part in field]
@strawberry.type
class UserError:
    # A single user-facing error reported in a mutation payload.
    # `field` is the path to the offending input; by default each part is
    # camelCased so it matches the GraphQL (camelCase) argument names.
    message: str
    code: Optional[str] = None
    field: Optional[List[str]] = None

    def __init__(
        self,
        message: str,
        code: Optional[str] = None,
        field: Optional[List[str]] = None,
        camelcase: bool = True,
    ):
        self.message = message
        self.code = code
        self.field = convert_field_to_camel_case(field) if camelcase else field
class UserErrorException(Exception):
    """Raise to report a single user-facing error from a resolver or input extension.

    Args:
        message: Human-readable error text.
        code: Optional machine-readable error code.
        field: Optional path to the offending input.
    """

    def __init__(self, message: str, code: Optional[str] = None, field: Optional[list[str]] = None) -> None:
        # Initialise the base Exception so str(exc) / exc.args carry the message
        # (useful in logs and tracebacks, and required for the exception to be
        # re-created from its args, e.g. when pickled).
        super().__init__(message)
        self.message = message
        self.code = code
        self.field = field

    def to_user_error(self) -> UserError:
        """Convert this exception into a ``UserError``."""
        return UserError(
            message=self.message,
            code=self.code,
            field=self.field
        )

    def to_user_errors(self) -> list[UserError]:
        """Return this error as a one-element list, mirroring ``UserErrorsException``."""
        return [self.to_user_error()]
class UserErrorsException(Exception):
    """Raise to report several already-built ``UserError`` objects at once.

    Args:
        user_errors: The errors to return in the payload's ``user_errors`` field.
    """

    def __init__(self, user_errors: list[UserError]) -> None:
        # Initialise the base Exception (consistent with UserErrorException) so
        # exc.args / str(exc) expose the errors instead of being empty.
        super().__init__(user_errors)
        self.user_errors = user_errors

    def to_user_errors(self) -> list[UserError]:
        """Return the wrapped list of ``UserError`` objects."""
        return self.user_errors
class InputExtensions:
    """Marker placed in ``Annotated[...]`` metadata to attach input extensions.

    Each extension is a callable ``(value, info) -> new_value``; they are
    stored as the tuple passed to the constructor.
    """

    def __init__(self, *callables):
        self.extensions = callables
def run_extensions(value, path, info, user_errors, extensions):
    """Pipe *value* through *extensions* in order, collecting validation errors.

    Each extension is called as ``extension(current, info)`` and its return
    value feeds the next one.

    Args:
        value: The input value to process.
        path: Error path recorded on any ``UserError`` produced here.
        info: Passed through to every extension.
        user_errors: List that receives a ``UserError`` on failure (mutated).
        extensions: Callables to apply; ``None`` or empty means "no-op".

    Returns:
        ``(processed_value, False)`` on success, or ``(value, True)`` with the
        original, unprocessed value when an extension raised
        ``UserErrorException`` or ``ValueError``. Any other exception propagates.
    """
    if not extensions:
        return value, False
    current = value
    try:
        for step in extensions:
            current = step(current, info)
    except UserErrorException as exc:
        user_errors.append(UserError(message=exc.message, code=exc.code, field=path))
        return value, True
    except ValueError as exc:
        user_errors.append(UserError(message=str(exc), code='validation_error', field=path))
        return value, True
    return current, False
def get_field_extensions_from_metadata(field):
    """Collect input-extension callables from an ``Annotated`` type's metadata.

    Returns ``None`` when *field* carries no ``__metadata__`` (i.e. it is not
    ``Annotated``). Otherwise returns a list of every callable from all
    ``InputExtensions`` entries, in reverse of their declaration order. The
    list is empty if no ``InputExtensions`` is present. Because of the
    reversal, the last-listed extension runs first.
    """
    metadata = getattr(field, '__metadata__', None)
    if not metadata:
        return None
    collected = [
        ext
        for entry in metadata
        if isinstance(entry, InputExtensions)
        for ext in entry.extensions
    ]
    return collected[::-1]
def create_scalar_resolver(field_type):
    """Return a resolver running *field_type*'s own input extensions, or None if it has none."""
    extensions = get_field_extensions_from_metadata(field_type)
    if not extensions:
        return None

    def resolver(value, path, info, user_errors):
        return run_extensions(value, path, info, user_errors, extensions)

    return resolver
def create_optional_resolver(field_type, child_resolver):
    """Build a resolver for an ``Optional[...]`` type.

    The child resolver (for the non-``None`` member) is only invoked for
    non-``None`` values, so a null input never reaches extensions that expect
    a real value. Previously, when the Optional itself carried no extensions,
    the bare child resolver was returned and was called with ``None``.
    Extensions attached to the Optional itself (e.g. ``NotNull()``) always
    run, including on ``None``.

    Args:
        field_type: The (possibly ``Annotated``) optional type.
        child_resolver: Resolver for the inner type, or ``None``.

    Returns:
        A resolver returning ``(value, has_error)``, or ``None`` when there is
        nothing to run at either level.
    """
    extensions = get_field_extensions_from_metadata(field_type)
    if not extensions and not child_resolver:
        return None

    def optional_resolver(value, path, info, user_errors):
        if value is not None and child_resolver:
            value, error = child_resolver(value, path, info, user_errors)
            if error:
                return value, True
        # run_extensions treats None/empty extensions as a successful no-op.
        return run_extensions(value, path, info, user_errors, extensions)

    return optional_resolver
def create_list_resolver(field_type, child_resolver):
    """Build a resolver for a list type.

    Every item is run through *child_resolver* (path extended with the item's
    index). If any item fails, the original list is returned with the error
    flag. Otherwise the list of resolved items is passed to the list's own
    extensions. Returns ``None`` when there is neither a child resolver nor
    list-level extensions.
    """
    extensions = get_field_extensions_from_metadata(field_type)
    if not (extensions or child_resolver):
        return None

    def list_resolver(value, path, info, user_errors):
        if child_resolver:
            outcomes = [
                child_resolver(item, path + [position], info, user_errors)
                for position, item in enumerate(value)
            ]
            if any(failed for _, failed in outcomes):
                return value, True
            value = [resolved for resolved, _ in outcomes]
        return run_extensions(value, path, info, user_errors, extensions)

    return list_resolver
def get_input_base_type(field_type):
    """Strip one ``Annotated[...]`` layer from *field_type*, if present."""
    if get_origin(field_type) is not Annotated:
        return field_type
    base, *_metadata = get_args(field_type)
    return base
def create_input_resolver(field_type):
    """Build a resolver for an input-object type.

    Child resolvers are built from the input class's type hints. At resolve
    time, every field whose value is not ``UNSET`` is run through its child
    resolver. If all succeed, the new values are written back onto the object
    with ``setattr``, and then the type's own extensions run on the object.

    Returns:
        A resolver returning ``(value, has_error)``, or ``None`` when neither
        the type nor any of its fields carries input extensions.

    NOTE(review): a self-referencing input type appears to recurse forever
    through get_field_resolver; untested.
    """
    extensions = get_field_extensions_from_metadata(field_type)
    input_base_type = get_input_base_type(field_type)
    try:
        # get_type_hints resolves string / postponed annotations (e.g. under
        # `from __future__ import annotations`) and includes fields inherited
        # from base input classes; include_extras keeps the Annotated metadata
        # that carries InputExtensions. Raw __annotations__ misses both.
        annotations = get_type_hints(input_base_type, include_extras=True)
    except (NameError, TypeError, SyntaxError):
        # Unresolvable forward references or non-class objects: fall back to
        # the raw per-class annotations, as before.
        annotations = getattr(input_base_type, '__annotations__', {})

    child_resolvers = {}
    for field_name, field in annotations.items():
        resolver = get_field_resolver(field)
        if resolver:
            child_resolvers[field_name] = resolver

    if not extensions and not child_resolvers:
        return None

    def input_resolver(value, path, info, user_errors):
        if child_resolvers:
            has_errors = False
            new_values = {}
            for field_name, resolver in child_resolvers.items():
                field_value = getattr(value, field_name, UNSET)
                if field_value is UNSET:
                    # Omitted fields are left alone; only provided values are validated.
                    continue
                new_values[field_name], error = resolver(
                    field_value, path + [field_name], info, user_errors
                )
                if error:
                    has_errors = True
            if has_errors:
                return value, True
            # Only mutate the input object once every field has succeeded.
            for field_name, field_value in new_values.items():
                setattr(value, field_name, field_value)
        return run_extensions(value, path, info, user_errors, extensions)

    return input_resolver
def get_field_type(field_type):
    """Return ``get_origin`` of *field_type*, looking through one ``Annotated`` layer."""
    target = field_type
    if get_origin(target) is Annotated:
        target = get_args(target)[0]
    return get_origin(target)
def is_optional(field):
    """Return True if *field* is an optional type (a union that includes ``None``).

    ``Annotated[...]`` wrappers are looked through. Both ``typing.Union`` /
    ``Optional[X]`` and PEP 604 ``X | None`` unions are recognised.
    ``types.UnionType`` only exists on Python 3.10+. On older versions the
    lookup falls back to ``Union``, so only the typing forms match there.
    """
    origin = get_origin(field)
    if origin is Annotated:
        return is_optional(get_args(field)[0])
    if origin is Union or origin is getattr(types, "UnionType", Union):
        return type(None) in get_args(field)
    return False
def get_optional_child_type(field):
    """Return the non-``None`` member of an optional type, or None if *field* is not a union.

    ``Annotated[...]`` wrappers are looked through, and ``Union``/``Optional``
    and PEP 604 ``X | None`` are both handled. ``NoneType`` is excluded by
    identity: type objects are always truthy, so a plain truthiness filter
    would keep it, and ``Union[None, X]`` would yield ``NoneType``. For a
    union with several non-``None`` members, only the first is returned.
    """
    origin = get_origin(field)
    if origin is Annotated:
        return get_optional_child_type(get_args(field)[0])
    if origin is Union or origin is getattr(types, "UnionType", Union):
        none_type = type(None)
        return next((arg for arg in get_args(field) if arg is not none_type), None)
    return None
def is_list(field):
    """Return True if *field* (after unwrapping ``Annotated``) is a parameterised list type."""
    while get_origin(field) is Annotated:
        field = get_args(field)[0]
    return get_origin(field) in (list, List)
def get_list_child_type(field):
    """Return the item type of a list type, looking through ``Annotated`` wrappers."""
    unwrapped = field
    while get_origin(unwrapped) is Annotated:
        unwrapped = get_args(unwrapped)[0]
    item_type, *_ = get_args(unwrapped)
    return item_type
def get_field_resolver(field):
    """Recursively build the validation resolver for a type annotation.

    Checks, in order: optional, list, input object (via strawberry's
    ``is_input_type``), and finally scalar. Returns ``None`` when nothing in
    the type tree carries input extensions.
    """
    if is_optional(field):
        inner = get_field_resolver(get_optional_child_type(field))
        return create_optional_resolver(field, inner)
    if is_list(field):
        inner = get_field_resolver(get_list_child_type(field))
        return create_list_resolver(field, inner)
    if is_input_type(field):
        return create_input_resolver(field)
    return create_scalar_resolver(field)
def get_argument_resolver(resolver, key):
    """Return the cached validation resolver for argument *key* of *resolver*.

    Results (including ``None``) are memoised in a ``__argument_resolvers__``
    dict stored on the resolver object itself, so each argument's resolver is
    built only once.

    NOTE(review): ``resolver.annotations`` is presumably strawberry's mapping
    of Python parameter names to annotations; a *key* missing from it raises
    ``KeyError`` — confirm kwargs keys always match.
    """
    try:
        cache = resolver.__argument_resolvers__
    except AttributeError:
        cache = {}
        resolver.__argument_resolvers__ = cache
    if key not in cache:
        cache[key] = get_field_resolver(resolver.annotations[key])
    return cache[key]
def process_mutation_kwargs(resolver, info, kwargs):
    """Run input extensions over every provided argument of a field.

    Arguments equal to ``UNSET`` are skipped. Each processed value replaces
    the original in *kwargs*, which is updated in place and returned.

    Args:
        resolver: The field's base resolver; used to look up argument annotations.
        info: Passed through to every extension.
        kwargs: The field's keyword arguments.

    Raises:
        UserErrorsException: if any argument produced validation errors; all
            errors from all arguments are collected first.
    """
    user_errors = []
    for key, value in kwargs.items():
        if value is UNSET:
            continue
        # Keep the argument resolver in its own name: rebinding `resolver`
        # here would replace the field's base resolver for all remaining
        # arguments, breaking any field with more than one argument.
        argument_resolver = get_argument_resolver(resolver, key)
        if argument_resolver:
            kwargs[key], _ = argument_resolver(value, [key], info, user_errors)
    if user_errors:
        raise UserErrorsException(user_errors=user_errors)
    return kwargs
class UserErrorsExtension(FieldExtension):
    """Field extension that turns user-error exceptions into a ``user_errors`` payload.

    When ``input_extensions`` is true, the field's arguments are first run
    through their ``InputExtensions`` via ``process_mutation_kwargs``. A
    ``UserErrorException`` or ``UserErrorsException`` raised there or by the
    resolver is caught. The field then returns ``info.return_type(user_errors=[...])``.

    NOTE(review): ``info._field`` is underscore-prefixed, so it is presumably
    private strawberry API that may change between versions — confirm.
    NOTE(review): the return type must accept a ``user_errors`` keyword; for a
    field returning e.g. ``str`` the error path would raise ``TypeError``.
    NOTE(review): only the synchronous ``resolve`` hook is implemented;
    presumably async resolvers also need ``resolve_async`` — confirm.
    """

    def __init__(self, input_extensions=True):
        self.input_extensions = input_extensions

    def resolve(self, next_, root, info, **kwargs):
        try:
            if self.input_extensions:
                field_resolver = info._field.base_resolver
                kwargs = process_mutation_kwargs(field_resolver, info, kwargs)
            return next_(root, info, **kwargs)
        except (UserErrorException, UserErrorsException) as caught:
            return info.return_type(user_errors=caught.to_user_errors())
class ToUpperCase:
    """Input extension that upper-cases a string value.

    ``None`` is passed through unchanged. This lets the extension sit on an
    ``Optional[str]`` field without a preceding ``NotNull()``; calling
    ``.upper()`` on ``None`` would raise ``AttributeError``, which
    ``run_extensions`` does not turn into a user error.
    """

    def __call__(self, value, info):
        if value is None:
            return None
        return value.upper()
class NotNull:
    """Input extension that rejects an explicit ``null``.

    Raises ``ValueError``, which ``run_extensions`` reports as a
    ``validation_error`` user error; any other value is returned as-is.
    """

    def __call__(self, value, info):
        if value is not None:
            return value
        raise ValueError('Value cannot be null')
@strawberry.interface
class UserErrorsInterface:
    # Shared base for mutation payloads. UserErrorsExtension builds the error
    # response as info.return_type(user_errors=...), so payload types returned
    # from fields using that extension need this `user_errors` field.
    user_errors: Optional[list[UserError]] = None
@strawberry.input
class Mutation2Input:
    # name can either be a string, or not provided. null is invalid
    # it will also be uppercased by the extension
    # Extensions within one InputExtensions(...) run last-to-first (the
    # collected list is reversed), so NotNull() is checked before ToUpperCase().
    name: Annotated[Optional[str], InputExtensions(ToUpperCase(), NotNull())] = UNSET
@strawberry.type
class Mutation2Payload(UserErrorsInterface):
    # Result of mutation2; inherits `user_errors` from UserErrorsInterface.
    name: Optional[str] = None
@strawberry.input
class Mutation3Input:
    # nested annotations work too
    # names must be either a list of strings, or not provided. it cannot be null.
    # the strings will be uppercased too.
    # The outer InputExtensions(NotNull()) is attached to the Optional and
    # guards the list itself; the inner ToUpperCase() runs per list item, with
    # the item index appended to the error path.
    names: Annotated[
        Optional[
            List[
                Annotated[str, InputExtensions(ToUpperCase())]
            ]
        ], InputExtensions(NotNull())
    ] = UNSET
@strawberry.type
class Mutation3Payload(UserErrorsInterface):
    # Result of mutation3; inherits `user_errors` from UserErrorsInterface.
    joined_names: Optional[str] = None
@strawberry.type
class Mutation:
    # it also works directly on arguments
    # NOTE(review): this field returns `str`, so if UserErrorsExtension ever
    # caught an error here, info.return_type(user_errors=...) would fail.
    @strawberry.mutation(extensions=[UserErrorsExtension()])
    def mutation1(self, name: Annotated[str, InputExtensions(ToUpperCase())]) -> str:
        return f"Hello {name}"

    @strawberry.mutation(extensions=[UserErrorsExtension()])
    def mutation2(self, input: Mutation2Input) -> Mutation2Payload:
        provided = input.name
        if provided is UNSET:
            return Mutation2Payload(name='Name was not provided')
        return Mutation2Payload(name=provided)

    @strawberry.mutation(extensions=[UserErrorsExtension()])
    def mutation3(self, input: Mutation3Input) -> Mutation3Payload:
        names = input.names
        joined = 'No names provided' if names is UNSET else ', '.join(names)
        return Mutation3Payload(joined_names=joined)
@strawberry.type
class Query:
    # at least one query is needed to keep strawberry happy :)
    # (a GraphQL schema needs a query root type even when only mutations matter)
    @strawberry.field()
    def hello(self) -> str:
        return "Hello"
schema = strawberry.Schema(query=Query, mutation=Mutation)

app = Flask(__name__)
app.add_url_rule(
    "/graphql",
    view_func=GraphQLView.as_view("graphql_view", schema=schema),
)

if __name__ == "__main__":
    # Start the development server only when run as a script; an unguarded
    # app.run() blocks any importer (tests, a WSGI server, `flask run`).
    app.run()