@@ -310,7 +310,45 @@ def load_saved_model(model, saved_model_tags, input_tensor_names, output_tensor_
310310 return opt , input_tensor_names , output_tensor_names
311311
312312
313+ def _get_graph_from_saved_model_v3 (model , input_tensor_names , output_tensor_names ):
314+ """The version 3 function that get graph from saved_model.
315+
316+ Args:
317+ model (string or tf.keras.Model): model path or tf.keras.Model object.
318+ input_tensor_names (list of string): input tensor names of the model.
319+ output_tensor_names (list of string): output tensor names of the model.
320+
321+ Returns:
322+ graph_def (tf.compat.v1.Session): tf.compat.v1.Session object.
323+ inputs (list of string): validated input names.
324+ outputs (list of string): validated output names.
325+ """
326+ from neural_compressor .adaptor .tf_utils .util import parse_saved_model
327+
328+ if isinstance (model , tf .keras .Model ):
329+ tmp_dir = cfg .default_workspace + "/saved_model"
330+ model .save (tmp_dir )
331+ model = tmp_dir
332+ graph_def , _ , _ , _ , input_names , output_names = parse_saved_model (
333+ model , True , input_tensor_names , output_tensor_names
334+ )
335+
336+ return graph_def , input_names , output_names
337+
338+
313339def _get_graph_from_saved_model_v2 (saved_model_dir , input_tensor_names , output_tensor_names ):
340+ """The version 2 function that get graph from the original keras model.
341+
342+ Args:
343+ saved_model_dir (string): model path of a temporary saved_model.
344+ input_tensor_names (list of string): input tensor names of the model.
345+ output_tensor_names (list of string): output tensor names of the model.
346+
347+ Returns:
348+ graph_def (tf.compat.v1.Session): tf.compat.v1.Session object.
349+ input_names (list of string): validated input names.
350+ output_names (list of string): validated output names.
351+ """
314352 from tensorflow .python .saved_model import signature_constants , tag_constants
315353
316354 saved_model_exported_names = [signature_constants .DEFAULT_SERVING_SIGNATURE_DEF_KEY ]
@@ -319,7 +357,17 @@ def _get_graph_from_saved_model_v2(saved_model_dir, input_tensor_names, output_t
319357 return load_saved_model (saved_model_dir , saved_model_tags , input_tensor_names , output_tensor_names )
320358
321359
322- def _get_graph_from_original_keras_v2 (model , output_dir ):
360+ def _get_graph_from_original_keras_v2 (model ):
361+ """The version 2 function that get graph from the original keras model.
362+
363+ Args:
364+ model (string or tf.keras.Model): model path or tf.keras.Model object.
365+
366+ Returns:
367+ graph_def (tf.compat.v1.Session): tf.compat.v1.Session object.
368+ input_names (list of string): validated input names.
369+ output_names (list of string): validated output names.
370+ """
323371 from tensorflow .lite .python .convert import OpsSet
324372 from tensorflow .lite .python .util import (
325373 get_grappler_config ,
@@ -364,6 +412,17 @@ def _get_graph_from_original_keras_v2(model, output_dir):
364412
365413
366414def _check_keras_format (model , saved_model_dir ):
415+ """Decide which method will be used to get graph from the saved_model .
416+
417+ Args:
418+ model (string or tf.keras.Model): model path or tf.keras.Model object.
419+ saved_model_dir (string): the path to save a temporary saved_model.
420+
421+ Returns:
422+ graph_def (tf.compat.v1.Session): tf.compat.v1.Session object.
423+ inputs (list of string): validated input names.
424+ outputs (list of string): validated output names.
425+ """
367426 from tensorflow .python import saved_model
368427 from tensorflow .python .saved_model import save_options
369428 from tensorflow .python .saved_model .load import load
@@ -384,6 +443,16 @@ def _check_keras_format(model, saved_model_dir):
384443
385444
386445def _get_graph_from_saved_model_v1 (model ):
446+ """The version 1 function that get graph from saved_model.
447+
448+ Args:
449+ model (string or tf.keras.Model): model path or tf.keras.Model object.
450+
451+ Returns:
452+ graph_def (tf.compat.v1.Session): tf.compat.v1.Session object.
453+ inputs (list of string): validated input names.
454+ outputs (list of string): validated output names.
455+ """
387456 from tensorflow .lite .python .convert_saved_model import get_inputs_outputs , get_meta_graph_def , get_signature_def
388457 from tensorflow .python .client import session
389458 from tensorflow .python .framework import ops
@@ -424,6 +493,51 @@ def _get_graph_from_saved_model_v1(model):
424493 return graph_def , inputs , outputs
425494
426495
496+ def try_loading_keras (model , input_tensor_names , output_tensor_names ):
497+ """Try different ways of loading keras models.
498+
499+ Args:
500+ model (string or tf.keras.Model): model path or tf.keras.Model object.
501+ input_tensor_names (list of string): input tensor names of the model.
502+ output_tensor_names (list of string): output tensor names of the model.
503+
504+ Returns:
505+ graph_def (tf.compat.v1.Session): tf.compat.v1.Session object.
506+ input_names (list of string): validated input names.
507+ output_names (list of string): validated output names.
508+ """
509+ temp_dir = tempfile .mkdtemp ()
510+ if not isinstance (model , tf .keras .Model ):
511+ model = tf .keras .models .load_model (model )
512+ keras_format = _check_keras_format (model , temp_dir )
513+
514+ if keras_format == "saved_model_v2" :
515+ try :
516+ graph_def , input_names , output_names = _get_graph_from_saved_model_v2 (
517+ temp_dir , input_tensor_names , output_tensor_names
518+ )
519+ if "_FusedBatchNormEx" in [node .op for node in graph_def .node ]:
520+ keras_format = "trackable_object"
521+ except :
522+ keras_format = "trackable_object"
523+
524+ if keras_format == "trackable_object" :
525+ try :
526+ graph_def , input_names , output_names = _get_graph_from_original_keras_v2 (model )
527+ except :
528+ keras_format = "saved_model_v1"
529+
530+ if keras_format == "saved_model_v1" : # pragma: no cover
531+ try :
532+ tf .keras .backend .set_learning_phase (0 )
533+ graph_def , input_names , output_names = _get_graph_from_saved_model_v1 (model )
534+ except :
535+ raise ValueError ("Not supported keras model type..." )
536+
537+ shutil .rmtree (temp_dir , True )
538+ return graph_def , input_names , output_names
539+
540+
427541def keras_session (model , input_tensor_names , output_tensor_names , ** kwargs ):
428542 """Build session with keras model.
429543
@@ -434,49 +548,19 @@ def keras_session(model, input_tensor_names, output_tensor_names, **kwargs):
434548
435549 Returns:
436550 sess (tf.compat.v1.Session): tf.compat.v1.Session object.
437- input_tensor_names (list of string): validated input_tensor_names.
438- output_tensor_names (list of string): validated output_tensor_names.
439551 """
440- temp_dir = tempfile .mkdtemp ()
441552 if tf .version .VERSION > "2.1.0" :
442- if not isinstance (model , tf .keras .Model ):
443- model = tf .keras .models .load_model (model )
444- keras_format = _check_keras_format (model , temp_dir )
445- if keras_format == "saved_model_v2" :
446- try :
447- graph_def , input_names , output_names = _get_graph_from_saved_model_v2 (
448- temp_dir , input_tensor_names , output_tensor_names
449- )
450- if "_FusedBatchNormEx" in [node .op for node in graph_def .node ]:
451- keras_format = "trackable_object"
452- except :
453- keras_format = "trackable_object"
454- if keras_format == "trackable_object" :
455- try :
456- graph_def , input_names , output_names = _get_graph_from_original_keras_v2 (model , temp_dir )
457- except :
458- keras_format = "saved_model_v1"
459- if keras_format == "saved_model_v1" : # pragma: no cover
460- try :
461- tf .keras .backend .set_learning_phase (0 )
462- graph_def , input_names , output_names = _get_graph_from_saved_model_v1 (model )
463- except :
464- keras_format = "saved_model_general"
465- if keras_format == "saved_model_general" : # pargma: no cover
466- try :
467- from neural_compressor .adaptor .tf_utils .util import parse_saved_model
468-
469- graph_def , _saved_model , _ , _ , input_names , output_names = parse_saved_model (
470- temp_dir , True , input_tensor_names , output_tensor_names
471- )
472- except :
473- raise ValueError ("Not supported keras model type..." )
474-
553+ try :
554+ graph_def , input_names , output_names = _get_graph_from_saved_model_v3 (
555+ model , input_tensor_names , output_tensor_names
556+ )
557+ except :
558+ graph_def , input_names , output_names = try_loading_keras (model , input_tensor_names , output_tensor_names )
475559 # tensorflow 1.x use v1 convert method
476560 else :
477561 tf .keras .backend .set_learning_phase (0 )
478562 graph_def , input_names , output_names = _get_graph_from_saved_model_v1 (model )
479- shutil . rmtree ( temp_dir , True )
563+
480564 return graph_def_session (graph_def , input_names , output_names , ** kwargs )
481565
482566
@@ -645,12 +729,19 @@ def saved_model_session(model, input_tensor_names, output_tensor_names, **kwargs
645729 output_tensor_names (list of string): validated output_tensor_names.
646730 """
647731 try :
648- graph_def , input_names , output_names = _get_graph_from_saved_model_v2 (
732+ graph_def , input_names , output_names = _get_graph_from_saved_model_v3 (
649733 model , input_tensor_names , output_tensor_names
650734 )
651735 except :
652- graph_def , input_names , output_names = _get_graph_from_saved_model_v1 (model )
736+ try :
737+ graph_def , input_names , output_names = _get_graph_from_saved_model_v2 (
738+ model , input_tensor_names , output_tensor_names
739+ )
740+ except :
741+ graph_def , input_names , output_names = _get_graph_from_saved_model_v1 (model )
742+
653743 assert graph_def is not None , "Can not parse the saved model..."
744+
654745 return graph_def_session (graph_def , input_names , output_names , ** kwargs )
655746
656747
0 commit comments