Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions ninja/errors.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import logging
import traceback
from functools import partial
from typing import TYPE_CHECKING, Generic, List, Optional, TypeVar
from typing import TYPE_CHECKING, Generic, List, Optional

import pydantic
from django.conf import settings
from django.http import Http404, HttpRequest, HttpResponse
from pydantic import BaseModel, GetCoreSchemaHandler
from typing_extensions import TypeVar

from ninja.types import DictStrAny

Expand Down Expand Up @@ -58,7 +60,27 @@ def __init__(self, errors: List[DictStrAny]) -> None:
super().__init__(errors)


class HttpError(Exception):
class HttpErrorResponse(BaseModel):
detail: str


T = TypeVar("T", bound=HttpErrorResponse)


class HttpError(Exception, Generic[T]):
__response_schema__ = HttpErrorResponse

def __init_subclass__(cls, **kwargs):
"""
Sets the __response_schema__ for generic subclasses of HttpError, enabling
custom Pydantic response models via generic type parameters (e.g., HttpError[MySchema]).
"""
super().__init_subclass__(**kwargs)
if hasattr(cls, "__orig_bases__"):
for base in cls.__orig_bases__:
if hasattr(base, "__args__") and base.__origin__ is HttpError:
cls.__response_schema__ = base.__args__[0]

def __init__(self, status_code: int, message: str) -> None:
self.status_code = status_code
self.message = message
Expand All @@ -67,6 +89,16 @@ def __init__(self, status_code: int, message: str) -> None:
def __str__(self) -> str:
return self.message

@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler: GetCoreSchemaHandler):
"""
Returns the Pydantic core schema for the error response, allowing HttpError
subclasses to be used as valid response types in OpenAPI schema generation.
"""
return cls.__response_schema__.__get_pydantic_core_schema__(
source_type, handler
)


class AuthenticationError(HttpError):
def __init__(self, status_code: int = 401, message: str = "Unauthorized") -> None:
Expand Down
160 changes: 158 additions & 2 deletions tests/test_response_multiple.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,35 @@
from pydantic import ValidationError

from ninja import NinjaAPI, Schema
from ninja.errors import ConfigError
from ninja.errors import ConfigError, HttpError
from ninja.responses import codes_2xx, codes_3xx
from ninja.testing import TestClient

api = NinjaAPI()


class RequestErrorResponse(Schema):
detail: str
reason: list[str]


class ServerErrorResponse(Schema):
detail: str


class BasicHttpError(HttpError): ...


class RequestError(HttpError[RequestErrorResponse]):
status_code = 400
message = "Request Error"


class ServerError(HttpError[ServerErrorResponse]):
status_code = 500
message = "Server Error"


@api.get("/check_int", response={200: int})
def check_int(request):
return 200, "1"
Expand Down Expand Up @@ -46,6 +68,35 @@ def check_multiple_codes(request, code: int):
return code, "1"


@api.get(
"/excs_in_responses",
response={codes_2xx: str, 400: RequestError, 404: BasicHttpError, 500: ServerError},
)
def excs_in_responses(request, code: int):
if code == 400:
raise RequestError
if code == 500:
raise ServerError
return code, "1"


@api.get(
"/errors_in_docstring",
# response={400: RequestError, 500: ServerError}, # <-- not needed
)
def errors_in_docstring(request, code: int):
"""
Raises:
RequestError: you made a mistake in the request
ServerError: unexpected error on our server
"""
if code == 400:
raise RequestError
if code == 500:
raise ServerError
return code, "1"


class User:
def __init__(self, id, name, password):
self.id = id
Expand Down Expand Up @@ -87,6 +138,11 @@ def check_union(request, q: int):
client = TestClient(api)


@pytest.fixture
def schema():
return api.get_openapi_schema()


@pytest.mark.parametrize(
"path,expected_status,expected_response",
[
Expand All @@ -113,7 +169,7 @@ def test_responses(path, expected_status, expected_response):
assert response.json() == expected_response


def test_schema():
def test_schema(schema):
checks = [
("/api/check_int", {200}),
("/api/check_int2", {200}),
Expand Down Expand Up @@ -154,6 +210,106 @@ def test_schema():
}


def test_excs_as_responses(schema):
responses = schema["paths"]["/api/excs_in_responses"]["get"]["responses"]
schemas = schema["components"]["schemas"]

assert responses == {
200: {
"content": {
"application/json": {"schema": {"title": "Response", "type": "string"}}
},
"description": "OK",
},
201: {
"content": {
"application/json": {"schema": {"title": "Response", "type": "string"}}
},
"description": "Created",
},
202: {
"content": {
"application/json": {"schema": {"title": "Response", "type": "string"}}
},
"description": "Accepted",
},
203: {
"content": {
"application/json": {"schema": {"title": "Response", "type": "string"}}
},
"description": "Non-Authoritative Information",
},
204: {
"content": {
"application/json": {"schema": {"title": "Response", "type": "string"}}
},
"description": "No Content",
},
205: {
"content": {
"application/json": {"schema": {"title": "Response", "type": "string"}}
},
"description": "Reset Content",
},
206: {
"content": {
"application/json": {"schema": {"title": "Response", "type": "string"}}
},
"description": "Partial Content",
},
400: {
"content": {
"application/json": {
"schema": {"$ref": "#/components/schemas/RequestErrorResponse"}
}
},
"description": "Bad Request",
},
404: {
"content": {
"application/json": {
"schema": {"$ref": "#/components/schemas/HttpErrorResponse"}
}
},
"description": "Not Found",
},
500: {
"content": {
"application/json": {
"schema": {"$ref": "#/components/schemas/ServerErrorResponse"}
}
},
"description": "Internal Server Error",
},
}

assert "RequestErrorResponse" in schemas
assert "HttpErrorResponse" in schemas
assert "ServerErrorResponse" in schemas

assert schemas["RequestErrorResponse"] == {
"properties": {
"detail": {"title": "Detail", "type": "string"},
"reason": {"items": {"type": "string"}, "title": "Reason", "type": "array"},
},
"required": ["detail", "reason"],
"title": "RequestErrorResponse",
"type": "object",
}
assert schemas["HttpErrorResponse"] == {
"properties": {"detail": {"title": "Detail", "type": "string"}},
"required": ["detail"],
"title": "HttpErrorResponse",
"type": "object",
}
assert schemas["ServerErrorResponse"] == {
"properties": {"detail": {"title": "Detail", "type": "string"}},
"required": ["detail"],
"title": "ServerErrorResponse",
"type": "object",
}


def test_no_content():
response = client.get("/check_no_content?return_code=1")
assert response.status_code == 204
Expand Down