Skip to content

Commit f428d95

Browse files
authored
Merge branch 'master' into fix/openapi31-nullable-choice-fields-regression
2 parents bd68603 + 39cb3d5 commit f428d95

25 files changed

+233
-54
lines changed

.github/ISSUE_TEMPLATE/bug_report.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ assignees: ''
88
---
99

1010
**Describe the bug**
11-
A clear and concise description of what the bug is.
11+
<!--A clear and concise description of what the bug is.-->
1212

1313
**To Reproduce**
14-
It would be most helpful to provide a small snippet to see how the bug was provoked.
14+
<!--It would be most helpful to provide a small snippet to see how the bug was provoked.-->
1515

1616
**Expected behavior**
17-
A clear and concise description of what you expected to happen.
17+
<!--A clear and concise description of what you expected to happen.-->

README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ Requirements
6363

6464
- Python >= 3.7
6565
- Django (2.2, 3.2, 4.0, 4.1, 4.2, 5.0, 5.1, 5.2)
66-
- Django REST Framework (3.10.3, 3.11, 3.12, 3.13, 3.14, 3.15)
66+
- Django REST Framework (3.10.3, 3.11, 3.12, 3.13, 3.14, 3.15, 3.16)
6767

6868
Installation
6969
------------

docs/customization.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,8 @@ On rare occasions (e.g. envelope serializers), overriding list detection with ``
136136
'songs': {'top10': True},
137137
'single': {'top10': True}
138138
},
139-
request_only=True, # signal that example only applies to requests
140-
response_only=True, # signal that example only applies to responses
139+
# request_only=True, # signal that example only applies to requests
140+
# response_only=True, # signal that example only applies to responses
141141
),
142142
]
143143
)

drf_spectacular/contrib/django_filters.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,9 @@ def _get_model_field(self, filter_field, model):
234234
if not filter_field.field_name:
235235
return None
236236
path = filter_field.field_name.split('__')
237+
to_field_name = filter_field.extra.get("to_field_name")
238+
if to_field_name is not None:
239+
path.append(to_field_name)
237240
return follow_field_source(model, path, emit_warnings=False)
238241

239242
def _get_schema_from_model_field(self, auto_schema, filter_field, model):

drf_spectacular/contrib/djangorestframework_camel_case.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import re
2-
from typing import Optional
2+
from typing import MutableMapping, Optional
33

44
from django.utils.module_loading import import_string
55

@@ -20,20 +20,18 @@ def has_middleware_installed():
2020
except ImportError:
2121
return False
2222

23-
for middleware in [import_string(m) for m in settings.MIDDLEWARE]:
24-
try:
25-
if issubclass(middleware, CamelCaseMiddleWare):
26-
return True
27-
except TypeError:
28-
pass
23+
return any(
24+
isinstance(m, type) and issubclass(m, CamelCaseMiddleWare)
25+
for m in map(import_string, settings.MIDDLEWARE)
26+
)
2927

3028
def camelize_str(key: str) -> str:
3129
new_key = re.sub(camelize_re, underscore_to_camel, key) if "_" in key else key
3230
if key in ignore_keys or new_key in ignore_keys:
3331
return key
3432
return new_key
3533

36-
def camelize_component(schema: dict, name: Optional[str] = None) -> dict:
34+
def camelize_component(schema: MutableMapping, name: Optional[str] = None) -> MutableMapping:
3735
if name is not None and (name in ignore_fields or camelize_str(name) in ignore_fields):
3836
return schema
3937
elif schema.get('type') == 'object':
@@ -44,7 +42,7 @@ def camelize_component(schema: dict, name: Optional[str] = None) -> dict:
4442
}
4543
if 'required' in schema:
4644
schema['required'] = [camelize_str(field) for field in schema['required']]
47-
elif schema.get('type') == 'array':
45+
elif schema.get('type') == 'array' and isinstance(schema['items'], MutableMapping):
4846
camelize_component(schema['items'])
4947
return schema
5048

drf_spectacular/extensions.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,18 +106,30 @@ class OpenApiViewExtension(OpenApiGeneratorExtension['OpenApiViewExtension']):
106106
Extension for replacing discovered views with a more schema-appropriate/annotated version.
107107
108108
``view_replacement()`` is expected to return a subclass of ``APIView`` (which includes
109-
``ViewSet`` et al.). The discovered original view instance can be accessed with
110-
``self.target`` and be subclassed if desired.
109+
``ViewSet`` et al.). The discovered original view callback can be accessed with
110+
``self.target_callback``, while the discovered original view class can be accessed
111+
with ``self.target`` and can be subclassed if desired.
111112
"""
112113
_registry: List[Type['OpenApiViewExtension']] = []
113114

115+
def __init__(self, target_callback):
116+
super().__init__(target_callback.cls)
117+
self.target_callback = target_callback
118+
114119
@classmethod
115120
def _load_class(cls):
116121
super()._load_class()
117122
# special case @api_view: view class is nested in the cls attr of the function object
118123
if hasattr(cls.target_class, 'cls'):
119124
cls.target_class = cls.target_class.cls
120125

126+
@classmethod
127+
def get_match(cls, target_callback) -> 'Optional[OpenApiViewExtension]':
128+
for extension in sorted(cls._registry, key=lambda e: e.priority, reverse=True):
129+
if extension._matches(target_callback.cls):
130+
return extension(target_callback)
131+
return None
132+
121133
@abstractmethod
122134
def view_replacement(self) -> 'Type[APIView]':
123135
pass # pragma: no cover

drf_spectacular/generators.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
import os
1+
import posixpath
22
import re
3+
import weakref
34

45
from django.urls import URLPattern, URLResolver
56
from rest_framework import views, viewsets
@@ -106,6 +107,7 @@ def __init__(self, *args, **kwargs):
106107
self.registry = ComponentRegistry()
107108
self.api_version = kwargs.pop('api_version', None)
108109
self.inspector = None
110+
self.schemas_storage = []
109111
super().__init__(*args, **kwargs)
110112

111113
def coerce_path(self, path, method, view):
@@ -126,7 +128,7 @@ def create_view(self, callback, method, request=None):
126128
decorating plain views like retrieve, this initialization logic is not running.
127129
Therefore forcefully set the schema if @extend_schema decorator was used.
128130
"""
129-
override_view = OpenApiViewExtension.get_match(callback.cls)
131+
override_view = OpenApiViewExtension.get_match(callback)
130132
if override_view:
131133
original_cls = callback.cls
132134
callback.cls = override_view.view_replacement()
@@ -179,7 +181,7 @@ def create_view(self, callback, method, request=None):
179181
) + view_schema_class.__mro__
180182
action_schema_class = type('ExtendedRearrangedSchema', mro, {})
181183

182-
view.schema = action_schema_class()
184+
self._set_schema_to_view(view, action_schema_class())
183185
return view
184186

185187
def _initialise_endpoints(self):
@@ -210,7 +212,7 @@ def parse(self, input_request, public):
210212
# than one view to prevent emission of erroneous and unnecessary fallback names.
211213
non_trivial_prefix = len(set([view.__class__ for _, _, _, view in endpoints])) > 1
212214
if non_trivial_prefix:
213-
path_prefix = os.path.commonpath([path for path, _, _, _ in endpoints])
215+
path_prefix = posixpath.commonpath([path for path, _, _, _ in endpoints])
214216
path_prefix = re.escape(path_prefix) # guard for RE special chars in path
215217
else:
216218
path_prefix = '/'
@@ -291,3 +293,11 @@ def get_schema(self, request=None, public=False):
291293
result = hook(result=result, generator=self, request=request, public=public)
292294

293295
return sanitize_result_object(normalize_result_object(result))
296+
297+
def _set_schema_to_view(self, view, schema):
298+
# The 'schema' argument is used to store the schema and view instance in the global scope,
299+
# as 'schema' is a descriptor. To facilitate garbage collection of these objects,
300+
# we wrap the schema in a weak reference and store it within the SchemaGenerator instance to keep it alive.
301+
# Thus, the lifetime of both the view and the schema is tied to the lifetime of the SchemaGenerator instance.
302+
view.schema = weakref.proxy(schema)
303+
self.schemas_storage.append(schema)

drf_spectacular/hooks.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import re
22
from collections import defaultdict
3+
from collections.abc import MutableMapping
34

45
from inflection import camelize
56
from rest_framework.settings import api_settings
@@ -66,8 +67,8 @@ def extract_hash(schema):
6667
for component_name, props in iter_prop_containers(schemas):
6768
for prop_name, prop_schema in props.items():
6869
if prop_schema.get('type') == 'array':
69-
prop_schema = prop_schema.get('items', {})
70-
if 'enum' not in prop_schema:
70+
prop_schema = prop_schema.get('items')
71+
if not isinstance(prop_schema, MutableMapping) or 'enum' not in prop_schema:
7172
continue
7273

7374
prop_enum_cleaned_hash = extract_hash(prop_schema)
@@ -117,9 +118,9 @@ def extract_hash(schema):
117118
for prop_name, prop_schema in props.items():
118119
is_array = prop_schema.get('type') == 'array'
119120
if is_array:
120-
prop_schema = prop_schema.get('items', {})
121+
prop_schema = prop_schema.get('items')
121122

122-
if 'enum' not in prop_schema:
123+
if not isinstance(prop_schema, MutableMapping) or 'enum' not in prop_schema:
123124
continue
124125

125126
prop_enum_original_list = prop_schema['enum']

drf_spectacular/openapi.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,8 @@ def _map_serializer_field(self, field, direction, bypass_extensions=False):
689689
schema = self._map_serializer_field(field.child_relation, direction)
690690
# remove hand-over initkwargs applying only to outer scope
691691
schema.pop('readOnly', None)
692+
# similarly from outer scope, default value is invalid inside 'items'
693+
schema.pop('default', None)
692694
if meta.get('description') == schema.get('description'):
693695
schema.pop('description', None)
694696
return append_meta(build_array_type(schema), meta)
@@ -713,7 +715,7 @@ def _map_serializer_field(self, field, direction, bypass_extensions=False):
713715
if isinstance(field.parent, serializers.ManyRelatedField):
714716
model = field.parent.parent.Meta.model
715717
source = field.parent.source.split('.')
716-
elif hasattr(field.parent, 'Meta'):
718+
elif hasattr(field.parent, 'Meta') and hasattr(field.parent.Meta, 'model'):
717719
model = field.parent.Meta.model
718720
source = field.source.split('.')
719721
else:
@@ -815,14 +817,14 @@ def _map_serializer_field(self, field, direction, bypass_extensions=False):
815817
if isinstance(field, serializers.DecimalField):
816818
if getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING):
817819
content = {**build_basic_type(OpenApiTypes.STR), 'format': 'decimal'}
818-
if field.max_whole_digits:
820+
if field.max_whole_digits is not None:
819821
content['pattern'] = (
820-
fr'^-?\d{{0,{field.max_whole_digits}}}'
821-
fr'(?:\.\d{{0,{field.decimal_places}}})?$'
822+
r'^-?0?' if field.max_whole_digits == 0 else fr'^-?\d{{0,{field.max_whole_digits}}}'
822823
)
824+
content['pattern'] += fr'(?:\.\d{{0,{field.decimal_places}}})?$'
823825
else:
824826
content = build_basic_type(OpenApiTypes.DECIMAL)
825-
if field.max_whole_digits:
827+
if field.max_whole_digits is not None:
826828
value = 10 ** field.max_whole_digits
827829
content.update({
828830
'maximum': value,
@@ -1076,7 +1078,7 @@ def _map_basic_serializer(self, serializer, direction):
10761078
return build_object_type(
10771079
properties=properties,
10781080
required=required,
1079-
description=get_doc(serializer.__class__),
1081+
description=get_override(serializer, 'description', get_doc(serializer.__class__)),
10801082
)
10811083

10821084
def _insert_field_validators(self, field, schema):

drf_spectacular/plumbing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ def get_lib_doc_excludes():
200200
return [
201201
object,
202202
dict,
203+
Generic,
203204
views.APIView,
204205
*[getattr(serializers, c) for c in dir(serializers) if c.endswith('Serializer')],
205206
*[getattr(viewsets, c) for c in dir(viewsets) if c.endswith('ViewSet')],
@@ -939,7 +940,7 @@ def _load_enum_name_overrides(language: str):
939940

940941

941942
def list_hash(lst: Any) -> str:
942-
return hashlib.sha256(json.dumps(list(lst), sort_keys=True, cls=JSONEncoder).encode()).hexdigest()[:16]
943+
return hashlib.sha256(json.dumps(sorted(lst), sort_keys=True, cls=JSONEncoder).encode()).hexdigest()[:16]
943944

944945

945946
def anchor_pattern(pattern: str) -> str:

0 commit comments

Comments
 (0)