@@ -52,7 +52,7 @@ class E(B): pass
5252from sphinx .util .docutils import SphinxDirective
5353
5454if TYPE_CHECKING :
55- from collections .abc import Iterable , Sequence
55+ from collections .abc import Collection , Iterable , Iterator , Sequence , Set
5656 from typing import Any , ClassVar , Final
5757
5858 from docutils .nodes import Node
@@ -106,7 +106,7 @@ def try_import(objname: str) -> Any:
106106 return None
107107
108108
109- def import_classes (name : str , currmodule : str ) -> Any :
109+ def import_classes (name : str , currmodule : str ) -> list [ type [ Any ]] :
110110 """Import a class using its fully-qualified *name*."""
111111 target = None
112112
@@ -156,37 +156,45 @@ def __init__(
156156 private_bases : bool = False ,
157157 parts : int = 0 ,
158158 aliases : dict [str , str ] | None = None ,
159- top_classes : Sequence [Any ] = (),
159+ top_classes : Set [str ] = frozenset (),
160+ include_subclasses : bool = False ,
160161 ) -> None :
161162 """*class_names* is a list of child classes to show bases from.
162163
163164 If *show_builtins* is True, then Python builtins will be shown
164165 in the graph.
165166 """
166167 self .class_names = class_names
167- classes = self ._import_classes (class_names , currmodule )
168+ classes : Collection [type [Any ]] = self ._import_classes (class_names , currmodule )
169+ if include_subclasses :
170+ classes_set = {* classes }
171+ for cls in tuple (classes_set ):
172+ classes_set .update (_subclasses (cls ))
173+ classes = classes_set
168174 self .class_info = self ._class_info (
169175 classes , show_builtins , private_bases , parts , aliases , top_classes
170176 )
171177 if not self .class_info :
172178 msg = 'No classes found for inheritance diagram'
173179 raise InheritanceException (msg )
174180
175- def _import_classes (self , class_names : list [str ], currmodule : str ) -> list [Any ]:
181+ def _import_classes (
182+ self , class_names : list [str ], currmodule : str
183+ ) -> Sequence [type [Any ]]:
176184 """Import a list of classes."""
177- classes : list [Any ] = []
185+ classes : list [type [ Any ] ] = []
178186 for name in class_names :
179187 classes .extend (import_classes (name , currmodule ))
180188 return classes
181189
182190 def _class_info (
183191 self ,
184- classes : list [ Any ],
192+ classes : Collection [ type [ Any ] ],
185193 show_builtins : bool ,
186194 private_bases : bool ,
187195 parts : int ,
188196 aliases : dict [str , str ] | None ,
189- top_classes : Sequence [ Any ],
197+ top_classes : Set [ str ],
190198 ) -> list [tuple [str , str , Sequence [str ], str | None ]]:
191199 """Return name and bases for all classes that are ancestors of
192200 *classes*.
@@ -205,7 +213,7 @@ def _class_info(
205213 """
206214 all_classes = {}
207215
208- def recurse (cls : Any ) -> None :
216+ def recurse (cls : type [ Any ] ) -> None :
209217 if not show_builtins and cls in PY_BUILTINS :
210218 return
211219 if not private_bases and cls .__name__ .startswith ('_' ):
@@ -248,7 +256,7 @@ def recurse(cls: Any) -> None:
248256 ]
249257
250258 def class_name (
251- self , cls : Any , parts : int = 0 , aliases : dict [str , str ] | None = None
259+ self , cls : type [ Any ] , parts : int = 0 , aliases : dict [str , str ] | None = None
252260 ) -> str :
253261 """Given a class object, return a fully-qualified name.
254262
@@ -377,6 +385,7 @@ class InheritanceDiagram(SphinxDirective):
377385 'private-bases' : directives .flag ,
378386 'caption' : directives .unchanged ,
379387 'top-classes' : directives .unchanged_required ,
388+ 'include-subclasses' : directives .flag ,
380389 }
381390
382391 def run (self ) -> list [Node ]:
@@ -387,11 +396,11 @@ def run(self) -> list[Node]:
387396 # Store the original content for use as a hash
388397 node ['parts' ] = self .options .get ('parts' , 0 )
389398 node ['content' ] = ', ' .join (class_names )
390- node ['top-classes' ] = []
391- for cls in self . options . get ( 'top-classes' , '' ). split ( ',' ):
392- cls = cls . strip ( )
393- if cls :
394- node [ 'top-classes' ]. append ( cls )
399+ node ['top-classes' ] = frozenset ({
400+ cls_stripped
401+ for cls in self . options . get ( 'top-classes' , '' ). split ( ',' )
402+ if ( cls_stripped := cls . strip ())
403+ } )
395404
396405 # Create a graph starting with the list of classes
397406 try :
@@ -402,6 +411,7 @@ def run(self) -> list[Node]:
402411 private_bases = 'private-bases' in self .options ,
403412 aliases = self .config .inheritance_alias ,
404413 top_classes = node ['top-classes' ],
414+ include_subclasses = 'include-subclasses' in self .options ,
405415 )
406416 except InheritanceException as err :
407417 return [node .document .reporter .warning (err , line = self .lineno )]
@@ -428,6 +438,12 @@ def run(self) -> list[Node]:
428438 return [figure ]
429439
430440
441+ def _subclasses (cls : type [Any ]) -> Iterator [type [Any ]]:
442+ yield cls
443+ for sub_cls in cls .__subclasses__ ():
444+ yield from _subclasses (sub_cls )
445+
446+
431447def get_graph_hash (node : inheritance_diagram ) -> str :
432448 encoded = (node ['content' ] + str (node ['parts' ])).encode ()
433449 return hashlib .md5 (encoded , usedforsecurity = False ).hexdigest ()[- 10 :]
0 commit comments