@@ -26,6 +26,11 @@ def predict_compressed_model(model_dir,
2626 Returns:
2727 latency_dict(dict): The latency of the model under various compression strategies.
2828 """
29+ local_rank = paddle .distributed .get_rank ()
30+ quant_model_path = f'quant_model/rank_{ local_rank } '
31+ prune_model_path = f'prune_model/rank_{ local_rank } '
32+ sparse_model_path = f'sparse_model/rank_{ local_rank } '
33+
2934 latency_dict = {}
3035
3136 model_file = os .path .join (model_dir , model_filename )
@@ -43,13 +48,13 @@ def predict_compressed_model(model_dir,
4348 model_dir = model_dir ,
4449 model_filename = model_filename ,
4550 params_filename = params_filename ,
46- save_model_path = 'quant_model' ,
51+ save_model_path = quant_model_path ,
4752 quantizable_op_type = ["conv2d" , "depthwise_conv2d" , "mul" ],
4853 is_full_quantize = False ,
4954 activation_bits = 8 ,
5055 weight_bits = 8 )
51- quant_model_file = os .path .join ('quant_model' , model_filename )
52- quant_param_file = os .path .join ('quant_model' , params_filename )
56+ quant_model_file = os .path .join (quant_model_path , model_filename )
57+ quant_param_file = os .path .join (quant_model_path , params_filename )
5358
5459 latency = predictor .predict (
5560 model_file = quant_model_file ,
@@ -62,9 +67,9 @@ def predict_compressed_model(model_dir,
6267 model_file = model_file ,
6368 param_file = param_file ,
6469 ratio = prune_ratio ,
65- save_path = 'prune_model' )
66- prune_model_file = os .path .join ('prune_model' , model_filename )
67- prune_param_file = os .path .join ('prune_model' , params_filename )
70+ save_path = prune_model_path )
71+ prune_model_file = os .path .join (prune_model_path , model_filename )
72+ prune_param_file = os .path .join (prune_model_path , params_filename )
6873
6974 latency = predictor .predict (
7075 model_file = prune_model_file ,
@@ -74,16 +79,16 @@ def predict_compressed_model(model_dir,
7479
7580 post_quant_fake (
7681 exe ,
77- model_dir = 'prune_model' ,
82+ model_dir = prune_model_path ,
7883 model_filename = model_filename ,
7984 params_filename = params_filename ,
80- save_model_path = 'quant_model' ,
85+ save_model_path = quant_model_path ,
8186 quantizable_op_type = ["conv2d" , "depthwise_conv2d" , "mul" ],
8287 is_full_quantize = False ,
8388 activation_bits = 8 ,
8489 weight_bits = 8 )
85- quant_model_file = os .path .join ('quant_model' , model_filename )
86- quant_param_file = os .path .join ('quant_model' , params_filename )
90+ quant_model_file = os .path .join (quant_model_path , model_filename )
91+ quant_param_file = os .path .join (quant_model_path , params_filename )
8792
8893 latency = predictor .predict (
8994 model_file = quant_model_file ,
@@ -96,9 +101,9 @@ def predict_compressed_model(model_dir,
96101 model_file = model_file ,
97102 param_file = param_file ,
98103 ratio = sparse_ratio ,
99- save_path = 'sparse_model' )
100- sparse_model_file = os .path .join ('sparse_model' , model_filename )
101- sparse_param_file = os .path .join ('sparse_model' , params_filename )
104+ save_path = sparse_model_path )
105+ sparse_model_file = os .path .join (sparse_model_path , model_filename )
106+ sparse_param_file = os .path .join (sparse_model_path , params_filename )
102107
103108 latency = predictor .predict (
104109 model_file = sparse_model_file ,
@@ -108,25 +113,28 @@ def predict_compressed_model(model_dir,
108113
109114 post_quant_fake (
110115 exe ,
111- model_dir = 'sparse_model' ,
116+ model_dir = sparse_model_path ,
112117 model_filename = model_filename ,
113118 params_filename = params_filename ,
114- save_model_path = 'quant_model' ,
119+ save_model_path = quant_model_path ,
115120 quantizable_op_type = ["conv2d" , "depthwise_conv2d" , "mul" ],
116121 is_full_quantize = False ,
117122 activation_bits = 8 ,
118123 weight_bits = 8 )
119- quant_model_file = os .path .join ('quant_model' , model_filename )
120- quant_param_file = os .path .join ('quant_model' , params_filename )
124+ quant_model_file = os .path .join (quant_model_path , model_filename )
125+ quant_param_file = os .path .join (quant_model_path , params_filename )
121126
122127 latency = predictor .predict (
123128 model_file = quant_model_file ,
124129 param_file = quant_param_file ,
125130 data_type = 'int8' )
126131 latency_dict .update ({f'sparse_{ sparse_ratio } _int8' : latency })
127132
128- # Delete temporary model files
129- shutil .rmtree ('./quant_model' )
130- shutil .rmtree ('./prune_model' )
131- shutil .rmtree ('./sparse_model' )
133+ # NOTE: Delete temporary model files
134+ if os .path .exists ('quant_model' ):
135+ shutil .rmtree ('quant_model' , ignore_errors = True )
136+ if os .path .exists ('prune_model' ):
137+ shutil .rmtree ('prune_model' , ignore_errors = True )
138+ if os .path .exists ('sparse_model' ):
139+ shutil .rmtree ('sparse_model' , ignore_errors = True )
132140 return latency_dict
0 commit comments