Skip to content

Support for input field extensions #3951

@metatick

Description

@metatick

Feature Request Type

  • Core functionality
  • Alteration (enhancement/optimization) of existing feature(s)
  • New behavior

Description

Providing a way for users to declaratively implement non-schema-based validation of input data would be useful for reducing boilerplate logic in mutations, especially when the same input types are reused across multiple mutations.

In my particular case, I have written a proof-of-concept, annotation-based input extension that, combined with a field extension, achieves this reasonably well. Any validation errors are raised as a custom exception type, which the field extension catches and returns in a Shopify-style `userErrors` field that all my mutations provide in their return type.

It's relatively untested and probably littered with bugs, because I don't want to overcommit to an approach that may end up breaking due to how it finds metadata. Still, it could start a discussion around the topic, or at least tell me whether the API I'm relying on is stable enough for further development.

Please let me know your thoughts :)

from typing import Annotated, get_origin, get_args, Union, List, Optional

import strawberry
from flask import Flask
from strawberry import UNSET
from strawberry.extensions import FieldExtension
from strawberry.flask.views import GraphQLView
from strawberry.schema.compat import is_input_type
from strawberry.utils.str_converters import to_camel_case


def convert_string_to_camel_case(value):
    """Return *value* converted to camelCase when it is a string.

    Any non-string value is returned unchanged.
    """
    return to_camel_case(value) if isinstance(value, str) else value


def convert_field_to_camel_case(field):
    """Convert every part of an error field path to a camelCase string.

    Args:
        field: Iterable of path parts (argument/field names, list indices),
            or None when the error is not tied to a field.

    Returns:
        A list where each part has been stringified and then camel-cased, or
        None when *field* is None.
    """
    if field is None:
        # UserError defaults ``field`` to None; iterating it would raise TypeError.
        return None
    return [convert_string_to_camel_case(str(part)) for part in field]


@strawberry.type
class UserError:
    # A user-facing error returned in a mutation payload's `user_errors` list.
    # `message` is human-readable, `code` is an optional machine-readable code,
    # and `field` is an optional path to the offending input (camel-cased by
    # default so it matches the GraphQL field names the client sent).
    def __init__(
            self,
            message: str,
            code: Optional[str] = None,
            field: Optional[List[str]] = None,
            camelcase: bool = True
    ):
        self.message = message
        self.code = code
        # Only an actual path is camel-cased; a missing path stays None
        # (iterating None during conversion would raise TypeError).
        if camelcase and field is not None:
            self.field = convert_field_to_camel_case(field)
        else:
            self.field = field

    message: str
    code: Optional[str] = None
    field: Optional[List[str]] = None


class UserErrorException(Exception):
    """Raised to report a single user-facing error.

    Args:
        message: Human-readable error message.
        code: Optional machine-readable error code.
        field: Optional path to the offending input field.
    """

    def __init__(self, message: str, code: Optional[str] = None, field: Optional[list[str]] = None) -> None:
        # Initialise Exception with the message so ``str(exc)`` and ``exc.args``
        # show it however the exception was constructed (keyword-only
        # construction would otherwise leave ``args`` empty).
        super().__init__(message)
        self.message = message
        self.code = code
        self.field = field

    def to_user_error(self) -> UserError:
        """Convert this exception into a ``UserError`` value."""
        return UserError(
            message=self.message,
            code=self.code,
            field=self.field
        )

    def to_user_errors(self) -> list[UserError]:
        """Return this error as a one-element list (same shape as ``UserErrorsException``)."""
        return [self.to_user_error()]


class UserErrorsException(Exception):
    """Raised to report several user-facing errors at once.

    Args:
        user_errors: The ``UserError`` values to hand back to the client.
    """

    def __init__(self, user_errors: list[UserError]) -> None:
        # Initialise Exception with the constructor's argument. Keyword
        # construction (``UserErrorsException(user_errors=...)``) would
        # otherwise leave ``args`` empty, so the repr would be blank and
        # unpickling would call the class with no arguments.
        super().__init__(user_errors)
        self.user_errors = user_errors

    def to_user_errors(self) -> list[UserError]:
        """Return the collected errors."""
        return self.user_errors


class InputExtensions:
    """Marker placed in ``Annotated[...]`` metadata to attach input extensions.

    Each extension is a callable invoked as ``extension(value, info)`` that
    returns the (possibly transformed) value.
    """

    def __init__(self, *callables):
        # Kept as the positional-args tuple, in declaration order.
        self.extensions = callables


def run_extensions(value, path, info, user_errors, extensions):
    """Apply *extensions* to *value* in order, collecting validation failures.

    Each extension is called as ``extension(current_value, info)`` and its
    result is fed to the next extension.

    Args:
        value: The input value to transform/validate.
        path: Field path (names and list indices) used to locate any error.
        info: Strawberry ``Info`` passed through to every extension.
        user_errors: List that failures are appended to (mutated in place).
        extensions: Sequence of extension callables, or a falsy value for none.

    Returns:
        ``(new_value, False)`` on success. ``(value, True)`` with the ORIGINAL
        value when an extension raised ``UserErrorException`` or ``ValueError``.
        Any other exception propagates.
    """
    if not extensions:
        return value, False

    current = value
    try:
        for ext in extensions:
            current = ext(current, info)
    except UserErrorException as exc:
        # The exception's own ``field`` is not used; the error is pinned to *path*.
        user_errors.append(UserError(message=exc.message, code=exc.code, field=path))
    except ValueError as exc:
        user_errors.append(UserError(message=str(exc), code='validation_error', field=path))
    else:
        return current, False
    return value, True


def get_field_extensions_from_metadata(field):
    """Collect the extensions from every ``InputExtensions`` in ``Annotated`` metadata.

    Returns:
        None when *field* has no (or empty) ``__metadata__``. Otherwise, a list
        of all extension callables found, in reverse declaration order, so the
        extension listed last is run first by ``run_extensions``. The list may
        be empty when the metadata contains no ``InputExtensions``.
    """
    metadata = getattr(field, '__metadata__', None)
    if not metadata:
        return None
    collected = [
        ext
        for marker in metadata
        if isinstance(marker, InputExtensions)
        for ext in marker.extensions
    ]
    return collected[::-1]


def create_scalar_resolver(field_type):
    """Return a resolver that runs *field_type*'s own extensions, or None if it has none.

    The resolver has the shared signature
    ``resolver(value, path, info, user_errors) -> (value, has_errors)``.
    """
    extensions = get_field_extensions_from_metadata(field_type)
    if not extensions:
        return None

    def resolver(value, path, info, user_errors):
        return run_extensions(value, path, info, user_errors, extensions)

    return resolver


def create_optional_resolver(field_type, child_resolver):
    """Build a resolver for an optional type that skips the child on null values.

    Args:
        field_type: The optional annotation, possibly ``Annotated`` with
            ``InputExtensions`` metadata.
        child_resolver: Resolver for the non-None inner type, or None.

    Returns:
        A ``resolver(value, path, info, user_errors) -> (value, has_errors)``.
        It runs *child_resolver* only on non-None values, then runs this type's
        own extensions, which do see None (e.g. ``NotNull``). Returns None when
        there is neither a child resolver nor any extension.
    """
    extensions = get_field_extensions_from_metadata(field_type)
    if not extensions and not child_resolver:
        return None

    # Always wrap, even when this type has no extensions of its own.
    # Returning the bare child resolver would hand it a null value, e.g. a
    # list resolver iterating None or an element extension calling .upper()
    # on None.
    def optional_resolver(value, path, info, user_errors):
        if value is not None and child_resolver:
            value, error = child_resolver(value, path, info, user_errors)
            if error:
                return value, True
        return run_extensions(value, path, info, user_errors, extensions)

    return optional_resolver


def create_list_resolver(field_type, child_resolver):
    """Build a resolver for a list type: resolve each element, then the list itself.

    Every element is resolved (so all element errors are collected, with the
    element index appended to the path). If any element fails, the original
    list is returned with ``has_errors=True``. Otherwise, the list's own
    extensions run on the list of resolved elements. Returns None when there
    is neither an element resolver nor any list-level extension.
    """
    extensions = get_field_extensions_from_metadata(field_type)
    if not (extensions or child_resolver):
        return None

    def list_resolver(value, path, info, user_errors):
        if child_resolver:
            outcomes = [
                child_resolver(item, path + [index], info, user_errors)
                for index, item in enumerate(value)
            ]
            if any(failed for _, failed in outcomes):
                return value, True
            value = [resolved for resolved, _ in outcomes]
        return run_extensions(value, path, info, user_errors, extensions)

    return list_resolver


def get_input_base_type(field_type):
    """Strip an outer ``Annotated[...]`` wrapper and return the underlying type."""
    if get_origin(field_type) is not Annotated:
        return field_type
    base, *_ = get_args(field_type)
    return base


def create_input_resolver(field_type):
    """Build a resolver for a strawberry input type and its fields.

    The returned resolver first resolves every field that is set (not UNSET)
    on the input object with that field's own resolver. If no field failed,
    it writes the resolved values back onto the object and then runs the
    extensions attached to *field_type* itself.

    Returns:
        A ``resolver(value, path, info, user_errors) -> (value, has_errors)``,
        or None when neither the type nor any of its fields needs resolving.
    """
    extensions = get_field_extensions_from_metadata(field_type)
    input_base_type = get_input_base_type(field_type)

    # Merge annotations across the class hierarchy, base classes first.
    # A class's own __annotations__ lists only the fields declared in its body,
    # so inherited input fields would otherwise be skipped. A subclass that
    # re-declares a field overrides the parent's annotation. Non-classes fall
    # back to their own __annotations__ only.
    annotations = {}
    for klass in reversed(getattr(input_base_type, '__mro__', (input_base_type,))):
        annotations.update(getattr(klass, '__annotations__', {}))

    # NOTE(review): there is no cycle detection. A field annotated directly
    # (not via a string forward reference) with this same input class would
    # recurse here without bound. Confirm that self-referencing inputs are
    # not declared that way.
    child_resolvers = {}
    for field_name, field in annotations.items():
        resolver = get_field_resolver(field)
        if resolver:
            child_resolvers[field_name] = resolver

    if not extensions and not child_resolvers:
        return None

    def input_resolver(value, path, info, user_errors):
        if child_resolvers:
            has_errors = False
            new_values = {}
            for field_name, resolver in child_resolvers.items():
                # UNSET means the client omitted the field: nothing to resolve.
                if (field_value := getattr(value, field_name, UNSET)) is not UNSET:
                    new_values[field_name], error = resolver(field_value, path + [field_name], info, user_errors)
                    if error:
                        has_errors = True
            if has_errors:
                # Leave the input object untouched when any field failed.
                return value, True
            for field_name, field_value in new_values.items():
                setattr(value, field_name, field_value)
        return run_extensions(value, path, info, user_errors, extensions)

    return input_resolver


def get_field_type(field_type):
    """Return the generic origin of *field_type*, looking through one ``Annotated`` layer.

    Returns None for non-generic types. Not referenced elsewhere in this listing.
    """
    if get_origin(field_type) is Annotated:
        field_type = get_args(field_type)[0]
    return get_origin(field_type)


def is_optional(field):
    """Return True if *field* is a union that includes None, looking through ``Annotated``.

    Recognises ``Optional[X]``, ``Union[..., None]`` and, on Python 3.10+, the
    PEP 604 form ``X | None`` (whose origin is ``types.UnionType``, not ``Union``).
    """
    import types

    origin = get_origin(field)
    # types.UnionType only exists on 3.10+; fall back to Union so the
    # comparison is a harmless duplicate on older interpreters.
    if origin is Union or origin is getattr(types, 'UnionType', Union):
        return type(None) in get_args(field)
    if origin is Annotated:
        return is_optional(get_args(field)[0])
    return False


def get_optional_child_type(field):
    """Return the non-None part of an optional type, looking through ``Annotated``.

    ``Optional[X]``, ``Union[None, X]`` and ``X | None`` all yield ``X``. A union
    with several non-None members yields the ``Union`` of those members.
    Returns None when *field* is not a union.
    """
    import types

    origin = get_origin(field)
    if origin is Union or origin is getattr(types, 'UnionType', Union):
        # Filter NoneType by identity: type objects are always truthy, so a
        # plain truthiness filter would keep NoneType.
        members = tuple(arg for arg in get_args(field) if arg is not type(None))
        # Union[(X,)] collapses to X, so the single-member case needs no branch.
        return Union[members]
    if origin is Annotated:
        return get_optional_child_type(get_args(field)[0])
    return None


def is_list(field):
    """Return True if *field* is a ``list[...]``/``List[...]`` generic, looking through ``Annotated``."""
    while get_origin(field) is Annotated:
        field = get_args(field)[0]
    return get_origin(field) in (list, List)


def get_list_child_type(field):
    """Return the element type of a list annotation, looking through ``Annotated``.

    Raises IndexError when the list type has no type arguments.
    """
    while get_origin(field) is Annotated:
        field = get_args(field)[0]
    element, *_ = get_args(field) or (None,)[:0] or get_args(field)[0:1] or [get_args(field)[0]]
    return element


def get_field_resolver(field):
    """Build the input resolver for a type annotation by recursing through its structure.

    Optional and list wrappers are unwrapped first, with their inner resolver
    built recursively. Strawberry input types come next, and anything else is
    treated as a scalar. The result may be None when no resolver is needed.
    """
    if is_optional(field):
        return create_optional_resolver(field, get_field_resolver(get_optional_child_type(field)))
    if is_list(field):
        return create_list_resolver(field, get_field_resolver(get_list_child_type(field)))
    if is_input_type(field):
        return create_input_resolver(field)
    return create_scalar_resolver(field)


def get_argument_resolver(resolver, key):
    """Return the input resolver for argument *key*, memoized on *resolver*.

    The cache is a ``__argument_resolvers__`` dict attribute created on the
    resolver object the first time it is needed. A None result (no resolver
    required) is cached as well, so each annotation is analysed only once.
    """
    try:
        cache = resolver.__argument_resolvers__
    except AttributeError:
        cache = {}
        resolver.__argument_resolvers__ = cache
    if key not in cache:
        cache[key] = get_field_resolver(resolver.annotations[key])
    return cache[key]


def process_mutation_kwargs(resolver, info, kwargs):
    """Run input extensions over a resolver's keyword arguments.

    Args:
        resolver: The field's base resolver. Its ``annotations`` supply each
            argument's type annotation.
        info: Strawberry ``Info``, forwarded to every extension.
        kwargs: Argument values keyed by argument name. Updated in place.

    Returns:
        *kwargs*, with every set (non-UNSET) argument replaced by its resolved value.

    Raises:
        UserErrorsException: if any extension reported an error. Errors from
            all arguments are collected before raising.
    """
    user_errors = []
    for key, value in kwargs.items():
        if value is UNSET:
            continue
        # Keep the field resolver in its own name. Rebinding ``resolver`` here
        # would make the next argument's lookup run against an argument
        # resolver (or None) and fail with AttributeError.
        arg_resolver = get_argument_resolver(resolver, key)
        if arg_resolver:
            # Overwriting an existing key while iterating is safe (size unchanged).
            kwargs[key], _ = arg_resolver(value, [key], info, user_errors)
    if user_errors:
        raise UserErrorsException(user_errors=user_errors)
    return kwargs


class UserErrorsExtension(FieldExtension):
    """Field extension that turns user-error exceptions into a payload.

    When ``input_extensions`` is true (the default), the input extensions
    declared on the resolver's argument annotations run before the resolver.
    A ``UserErrorException`` or ``UserErrorsException`` raised by either step
    is converted into ``info.return_type(user_errors=[...])``.
    """

    def __init__(self, input_extensions=True):
        self.input_extensions = input_extensions

    def resolve(self, next_, root, info, **kwargs):
        try:
            arguments = self._prepare_arguments(info, kwargs)
            return next_(root, info, **arguments)
        except (UserErrorException, UserErrorsException) as err:
            # Assumes the field's return type accepts a ``user_errors`` keyword
            # (e.g. a payload implementing UserErrorsInterface).
            return info.return_type(user_errors=err.to_user_errors())

    def _prepare_arguments(self, info, kwargs):
        """Return *kwargs*, run through the input extensions when enabled."""
        if not self.input_extensions:
            return kwargs
        # NOTE(review): ``info._field`` is a private strawberry attribute and
        # may change between releases.
        return process_mutation_kwargs(info._field.base_resolver, info, kwargs)


class ToUpperCase:
    """Input extension that returns ``value.upper()``; ``info`` is unused."""

    def __call__(self, value, info):
        transformed = value.upper()
        return transformed


class NotNull:
    """Input extension that rejects an explicit null (None) value.

    Raises ValueError, which ``run_extensions`` records as a
    ``validation_error`` user error.
    """

    def __call__(self, value, info):
        if value is not None:
            return value
        raise ValueError('Value cannot be null')


@strawberry.interface
class UserErrorsInterface:
    # Shared shape for mutation payloads. UserErrorsExtension builds the
    # payload with `user_errors=...` when a user error is raised, so payload
    # types implementing this interface can carry those errors; on success
    # the field is left as None.
    user_errors: Optional[list[UserError]] = None


@strawberry.input
class Mutation2Input:
    # name can either be a string, or not provided (UNSET). null is invalid.
    # InputExtensions are applied in reverse of the order listed, so NotNull()
    # runs before ToUpperCase(): a null is rejected before .upper() could fail
    # on it, and a valid string is then uppercased.
    name: Annotated[Optional[str], InputExtensions(ToUpperCase(), NotNull())] = UNSET


@strawberry.type
class Mutation2Payload(UserErrorsInterface):
    # Result of mutation2: the processed name (or a placeholder message),
    # plus the inherited `user_errors` field.
    name: Optional[str] = None


@strawberry.input
class Mutation3Input:
    # nested annotations work too:
    # - the outer NotNull() applies to the whole value: names must be either a
    #   list of strings, or not provided (UNSET). It cannot be null.
    # - the inner ToUpperCase() applies to each list element, so the strings
    #   are uppercased too. Elements are resolved before the outer extensions run.
    names: Annotated[
        Optional[
            List[
                Annotated[str, InputExtensions(ToUpperCase())]
            ]
        ], InputExtensions(NotNull())
    ] = UNSET


@strawberry.type
class Mutation3Payload(UserErrorsInterface):
    # Result of mutation3: the processed names joined with ", " (or a
    # placeholder message), plus the inherited `user_errors` field.
    joined_names: Optional[str] = None


@strawberry.type
class Mutation:
    # Input extensions also work directly on resolver arguments.
    # NOTE(review): mutation1 returns `str`. If a user error were ever raised
    # here, UserErrorsExtension would call `info.return_type(user_errors=...)`,
    # presumably `str(user_errors=...)`, which fails. ToUpperCase never
    # raises on a str, so that path is not reachable today.
    @strawberry.mutation(extensions=[UserErrorsExtension()])
    def mutation1(self, name: Annotated[str, InputExtensions(ToUpperCase())]) -> str:
        greeting = f"Hello {name}"
        return greeting

    # Echo the processed name, or a placeholder when it was omitted.
    @strawberry.mutation(extensions=[UserErrorsExtension()])
    def mutation2(self, input: Mutation2Input) -> Mutation2Payload:
        name = input.name
        if name is UNSET:
            name = 'Name was not provided'
        return Mutation2Payload(name=name)

    # Join the processed names, or report that none were provided.
    @strawberry.mutation(extensions=[UserErrorsExtension()])
    def mutation3(self, input: Mutation3Input) -> Mutation3Payload:
        if input.names is UNSET:
            return Mutation3Payload(joined_names='No names provided')
        return Mutation3Payload(joined_names=', '.join(input.names))


@strawberry.type
class Query:
    # at least one query is needed to keep strawberry happy :)
    # (a GraphQL schema must define a Query root type even when every
    # interesting operation is a mutation)
    @strawberry.field()
    def hello(self) -> str:
        return "Hello"


schema = strawberry.Schema(query=Query, mutation=Mutation)

# Flask application exposing the schema at /graphql.
app = Flask(__name__)
app.add_url_rule(
    "/graphql",
    view_func=GraphQLView.as_view("graphql_view", schema=schema),
)

# Start the development server only when this file is run as a script.
# Importing the module (e.g. from a WSGI server or tests) must not block
# inside app.run().
if __name__ == "__main__":
    app.run()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions