@@ -93,38 +93,10 @@ def calculate_size(
9393 Returns the corrected shape of `inputs` and the size of
9494 a single element in bytes.
9595 """
96-
97- def nested_list_size (
98- inputs : Sequence [Any ] | torch .Tensor ,
99- ) -> tuple [list [int ], int ]:
100- """Flattens nested list size."""
101-
102- if hasattr (inputs , "tensors" ):
103- size , elem_bytes = nested_list_size (inputs .tensors )
104- elif isinstance (inputs , torch .Tensor ):
105- size , elem_bytes = list (inputs .size ()), inputs .element_size ()
106- elif not hasattr (inputs , "__getitem__" ) or not inputs :
107- size , elem_bytes = [], 0
108- elif isinstance (inputs , dict ):
109- size , elem_bytes = nested_list_size (list (inputs .values ()))
110- elif (
111- hasattr (inputs , "size" )
112- and callable (inputs .size )
113- and hasattr (inputs , "element_size" )
114- and callable (inputs .element_size )
115- ):
116- size , elem_bytes = list (inputs .size ()), inputs .element_size ()
117- elif isinstance (inputs , (list , tuple )):
118- size , elem_bytes = nested_list_size (inputs [0 ])
119- else :
120- size , elem_bytes = [], 0
121-
122- return size , elem_bytes
123-
12496 if inputs is None :
12597 size , elem_bytes = [], 0
12698
127- # pack_padded_seq and pad_packed_seq store feature into data attribute
99+ # pack_padded_seq and pad_packed_seq store feature into data attribute
128100 elif (
129101 isinstance (inputs , (list , tuple )) and inputs and hasattr (inputs [0 ], "data" )
130102 ):
@@ -337,6 +309,31 @@ def leftover_trainable_params(self) -> int:
337309 )
338310
339311
312+ def nested_list_size (inputs : Sequence [Any ] | torch .Tensor ) -> tuple [list [int ], int ]:
313+ """Flattens nested list size."""
314+ if hasattr (inputs , "tensors" ):
315+ size , elem_bytes = nested_list_size (inputs .tensors )
316+ elif isinstance (inputs , torch .Tensor ):
317+ size , elem_bytes = list (inputs .size ()), inputs .element_size ()
318+ elif not hasattr (inputs , "__getitem__" ) or not inputs :
319+ size , elem_bytes = [], 0
320+ elif isinstance (inputs , dict ):
321+ size , elem_bytes = nested_list_size (list (inputs .values ()))
322+ elif (
323+ hasattr (inputs , "size" )
324+ and callable (inputs .size )
325+ and hasattr (inputs , "element_size" )
326+ and callable (inputs .element_size )
327+ ):
328+ size , elem_bytes = list (inputs .size ()), inputs .element_size ()
329+ elif isinstance (inputs , (list , tuple )):
330+ size , elem_bytes = nested_list_size (inputs [0 ])
331+ else :
332+ size , elem_bytes = [], 0
333+
334+ return size , elem_bytes
335+
336+
340337def prod (num_list : Iterable [int ] | torch .Size ) -> int :
341338 result = 1
342339 if isinstance (num_list , Iterable ):
0 commit comments