Skip to content

Commit 2a1b6a0

Browse files
committed
count size
1 parent fb4c19b commit 2a1b6a0

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

olive/passes/onnx/common.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
logger = logging.getLogger(__name__)
2323

24+
_LARGE_IR_MODEL_THRESHOLD = 1536 * 1024 * 1024 # 1536MB
25+
2426

2527
def get_external_data_config() -> Dict[str, PassConfigParam]:
2628
return {
@@ -201,6 +203,11 @@ def model_proto_to_olive_model(
201203
return olive_model
202204

203205

206+
def _count_initializer_size(graph: ir.Graph) -> int:
207+
"""Count the total size of the initializers in bytes."""
208+
return sum(v.const_value.nbytes for v in graph.initializers.values() if v.const_value is not None)
209+
210+
204211
def ir_model_to_olive_model(
205212
model: ir.Model,
206213
output_model_path: Union[str, Path],
@@ -224,20 +231,28 @@ def ir_model_to_olive_model(
224231
external_data_config = external_data_config.dict()
225232

226233
save_as_external_data = external_data_config.get("save_as_external_data")
234+
# Save as external data if requested or if the model is large
235+
# Since we do not have a true estimate of the model architecture size for IR Model,
236+
# we count the size of all initializers and limit that to 1.5GB.
237+
initializer_size = _count_initializer_size(model.graph)
238+
is_large_model = initializer_size > _LARGE_IR_MODEL_THRESHOLD
239+
if is_large_model:
240+
logger.debug("Model is large (%s), saving as external data", initializer_size)
241+
save_as_external_data = save_as_external_data or is_large_model
227242

228243
if save_as_external_data:
229244
external_data_name = _get_external_data_name(
230245
Path(output_model_path), external_data_config.get("external_data_name")
231246
)
232247
ir.save(model, output_model_path, external_data=external_data_name)
233-
else:
234-
ir.save(model, output_model_path)
235248

236-
if external_data_name:
237249
logger.debug("Model was saved with external data: %s", external_data_name)
238250
model_path = LocalFolder({"path": Path(output_model_path).parent})
239251
onnx_file_name = Path(output_model_path).name
252+
240253
else:
254+
ir.save(model, output_model_path)
255+
241256
logger.debug("Model was not saved with external data")
242257
model_path = LocalFile({"path": output_model_path})
243258
onnx_file_name = None

0 commit comments

Comments
 (0)