Skip to content

Commit d8e60b8

Browse files
authored
Fix TF saved_model issues (#1659)
Signed-off-by: zehao-intel <[email protected]>
1 parent d07175c commit d8e60b8

File tree

1 file changed

+131
-40
lines changed

1 file changed

+131
-40
lines changed

neural_compressor/model/tensorflow_model.py

Lines changed: 131 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,45 @@ def load_saved_model(model, saved_model_tags, input_tensor_names, output_tensor_
310310
return opt, input_tensor_names, output_tensor_names
311311

312312

313+
def _get_graph_from_saved_model_v3(model, input_tensor_names, output_tensor_names):
314+
"""The version 3 function that get graph from saved_model.
315+
316+
Args:
317+
model (string or tf.keras.Model): model path or tf.keras.Model object.
318+
input_tensor_names (list of string): input tensor names of the model.
319+
output_tensor_names (list of string): output tensor names of the model.
320+
321+
Returns:
322+
graph_def (tf.compat.v1.Session): tf.compat.v1.Session object.
323+
inputs (list of string): validated input names.
324+
outputs (list of string): validated output names.
325+
"""
326+
from neural_compressor.adaptor.tf_utils.util import parse_saved_model
327+
328+
if isinstance(model, tf.keras.Model):
329+
tmp_dir = cfg.default_workspace + "/saved_model"
330+
model.save(tmp_dir)
331+
model = tmp_dir
332+
graph_def, _, _, _, input_names, output_names = parse_saved_model(
333+
model, True, input_tensor_names, output_tensor_names
334+
)
335+
336+
return graph_def, input_names, output_names
337+
338+
313339
def _get_graph_from_saved_model_v2(saved_model_dir, input_tensor_names, output_tensor_names):
340+
"""The version 2 function that get graph from the original keras model.
341+
342+
Args:
343+
saved_model_dir (string): model path of a temporary saved_model.
344+
input_tensor_names (list of string): input tensor names of the model.
345+
output_tensor_names (list of string): output tensor names of the model.
346+
347+
Returns:
348+
graph_def (tf.compat.v1.Session): tf.compat.v1.Session object.
349+
input_names (list of string): validated input names.
350+
output_names (list of string): validated output names.
351+
"""
314352
from tensorflow.python.saved_model import signature_constants, tag_constants
315353

316354
saved_model_exported_names = [signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
@@ -319,7 +357,17 @@ def _get_graph_from_saved_model_v2(saved_model_dir, input_tensor_names, output_t
319357
return load_saved_model(saved_model_dir, saved_model_tags, input_tensor_names, output_tensor_names)
320358

321359

322-
def _get_graph_from_original_keras_v2(model, output_dir):
360+
def _get_graph_from_original_keras_v2(model):
361+
"""The version 2 function that get graph from the original keras model.
362+
363+
Args:
364+
model (string or tf.keras.Model): model path or tf.keras.Model object.
365+
366+
Returns:
367+
graph_def (tf.compat.v1.Session): tf.compat.v1.Session object.
368+
input_names (list of string): validated input names.
369+
output_names (list of string): validated output names.
370+
"""
323371
from tensorflow.lite.python.convert import OpsSet
324372
from tensorflow.lite.python.util import (
325373
get_grappler_config,
@@ -364,6 +412,17 @@ def _get_graph_from_original_keras_v2(model, output_dir):
364412

365413

366414
def _check_keras_format(model, saved_model_dir):
415+
"""Decide which method will be used to get graph from the saved_model .
416+
417+
Args:
418+
model (string or tf.keras.Model): model path or tf.keras.Model object.
419+
saved_model_dir (string): the path to save a temporary saved_model.
420+
421+
Returns:
422+
graph_def (tf.compat.v1.Session): tf.compat.v1.Session object.
423+
inputs (list of string): validated input names.
424+
outputs (list of string): validated output names.
425+
"""
367426
from tensorflow.python import saved_model
368427
from tensorflow.python.saved_model import save_options
369428
from tensorflow.python.saved_model.load import load
@@ -384,6 +443,16 @@ def _check_keras_format(model, saved_model_dir):
384443

385444

386445
def _get_graph_from_saved_model_v1(model):
446+
"""The version 1 function that get graph from saved_model.
447+
448+
Args:
449+
model (string or tf.keras.Model): model path or tf.keras.Model object.
450+
451+
Returns:
452+
graph_def (tf.compat.v1.Session): tf.compat.v1.Session object.
453+
inputs (list of string): validated input names.
454+
outputs (list of string): validated output names.
455+
"""
387456
from tensorflow.lite.python.convert_saved_model import get_inputs_outputs, get_meta_graph_def, get_signature_def
388457
from tensorflow.python.client import session
389458
from tensorflow.python.framework import ops
@@ -424,6 +493,51 @@ def _get_graph_from_saved_model_v1(model):
424493
return graph_def, inputs, outputs
425494

426495

496+
def try_loading_keras(model, input_tensor_names, output_tensor_names):
497+
"""Try different ways of loading keras models.
498+
499+
Args:
500+
model (string or tf.keras.Model): model path or tf.keras.Model object.
501+
input_tensor_names (list of string): input tensor names of the model.
502+
output_tensor_names (list of string): output tensor names of the model.
503+
504+
Returns:
505+
graph_def (tf.compat.v1.Session): tf.compat.v1.Session object.
506+
input_names (list of string): validated input names.
507+
output_names (list of string): validated output names.
508+
"""
509+
temp_dir = tempfile.mkdtemp()
510+
if not isinstance(model, tf.keras.Model):
511+
model = tf.keras.models.load_model(model)
512+
keras_format = _check_keras_format(model, temp_dir)
513+
514+
if keras_format == "saved_model_v2":
515+
try:
516+
graph_def, input_names, output_names = _get_graph_from_saved_model_v2(
517+
temp_dir, input_tensor_names, output_tensor_names
518+
)
519+
if "_FusedBatchNormEx" in [node.op for node in graph_def.node]:
520+
keras_format = "trackable_object"
521+
except:
522+
keras_format = "trackable_object"
523+
524+
if keras_format == "trackable_object":
525+
try:
526+
graph_def, input_names, output_names = _get_graph_from_original_keras_v2(model)
527+
except:
528+
keras_format = "saved_model_v1"
529+
530+
if keras_format == "saved_model_v1": # pragma: no cover
531+
try:
532+
tf.keras.backend.set_learning_phase(0)
533+
graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model)
534+
except:
535+
raise ValueError("Not supported keras model type...")
536+
537+
shutil.rmtree(temp_dir, True)
538+
return graph_def, input_names, output_names
539+
540+
427541
def keras_session(model, input_tensor_names, output_tensor_names, **kwargs):
428542
"""Build session with keras model.
429543
@@ -434,49 +548,19 @@ def keras_session(model, input_tensor_names, output_tensor_names, **kwargs):
434548
435549
Returns:
436550
sess (tf.compat.v1.Session): tf.compat.v1.Session object.
437-
input_tensor_names (list of string): validated input_tensor_names.
438-
output_tensor_names (list of string): validated output_tensor_names.
439551
"""
440-
temp_dir = tempfile.mkdtemp()
441552
if tf.version.VERSION > "2.1.0":
442-
if not isinstance(model, tf.keras.Model):
443-
model = tf.keras.models.load_model(model)
444-
keras_format = _check_keras_format(model, temp_dir)
445-
if keras_format == "saved_model_v2":
446-
try:
447-
graph_def, input_names, output_names = _get_graph_from_saved_model_v2(
448-
temp_dir, input_tensor_names, output_tensor_names
449-
)
450-
if "_FusedBatchNormEx" in [node.op for node in graph_def.node]:
451-
keras_format = "trackable_object"
452-
except:
453-
keras_format = "trackable_object"
454-
if keras_format == "trackable_object":
455-
try:
456-
graph_def, input_names, output_names = _get_graph_from_original_keras_v2(model, temp_dir)
457-
except:
458-
keras_format = "saved_model_v1"
459-
if keras_format == "saved_model_v1": # pragma: no cover
460-
try:
461-
tf.keras.backend.set_learning_phase(0)
462-
graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model)
463-
except:
464-
keras_format = "saved_model_general"
465-
if keras_format == "saved_model_general": # pargma: no cover
466-
try:
467-
from neural_compressor.adaptor.tf_utils.util import parse_saved_model
468-
469-
graph_def, _saved_model, _, _, input_names, output_names = parse_saved_model(
470-
temp_dir, True, input_tensor_names, output_tensor_names
471-
)
472-
except:
473-
raise ValueError("Not supported keras model type...")
474-
553+
try:
554+
graph_def, input_names, output_names = _get_graph_from_saved_model_v3(
555+
model, input_tensor_names, output_tensor_names
556+
)
557+
except:
558+
graph_def, input_names, output_names = try_loading_keras(model, input_tensor_names, output_tensor_names)
475559
# tensorflow 1.x use v1 convert method
476560
else:
477561
tf.keras.backend.set_learning_phase(0)
478562
graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model)
479-
shutil.rmtree(temp_dir, True)
563+
480564
return graph_def_session(graph_def, input_names, output_names, **kwargs)
481565

482566

@@ -645,12 +729,19 @@ def saved_model_session(model, input_tensor_names, output_tensor_names, **kwargs
645729
output_tensor_names (list of string): validated output_tensor_names.
646730
"""
647731
try:
648-
graph_def, input_names, output_names = _get_graph_from_saved_model_v2(
732+
graph_def, input_names, output_names = _get_graph_from_saved_model_v3(
649733
model, input_tensor_names, output_tensor_names
650734
)
651735
except:
652-
graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model)
736+
try:
737+
graph_def, input_names, output_names = _get_graph_from_saved_model_v2(
738+
model, input_tensor_names, output_tensor_names
739+
)
740+
except:
741+
graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model)
742+
653743
assert graph_def is not None, "Can not parse the saved model..."
744+
654745
return graph_def_session(graph_def, input_names, output_names, **kwargs)
655746

656747

0 commit comments

Comments
 (0)