Skip to content

Commit d147a23

Browse files
Fix gfile compatibility bug (#536)
* Fix gfile compatibility bug. Co-authored-by: yanzhen1233 <[email protected]>
1 parent 0dca402 commit d147a23

File tree

11 files changed

+63
-42
lines changed

11 files changed

+63
-42
lines changed

easy_rec/python/core/sampler.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import math
99
import os
1010
import sys
11-
# import re
1211
import threading
1312

1413
import numpy as np
@@ -20,6 +19,11 @@
2019
from easy_rec.python.utils.config_util import process_multi_file_input_path
2120
from easy_rec.python.utils.tf_utils import get_tf_type
2221

22+
if tf.__version__.startswith('1.'):
23+
from tensorflow.python.platform import gfile
24+
else:
25+
import tensorflow.io.gfile as gfile
26+
2327

2428
# patch graph-learn string_attrs for utf-8
2529
@property
@@ -395,7 +399,7 @@ def _load_data(self, data_path, attr_delimiter):
395399
item_id_col = 0
396400
fea_id_col = 2
397401
print('NegativeSamplerInMemory: load sample feature from %s' % data_path)
398-
with tf.gfile.GFile(data_path, 'r') as fin:
402+
with gfile.GFile(data_path, 'r') as fin:
399403
for line_id, line_str in enumerate(fin):
400404
line_str = line_str.strip()
401405
cols = line_str.split('\t')

easy_rec/python/export.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,17 @@
55

66
import tensorflow as tf
77
from tensorflow.python.lib.io import file_io
8-
from tensorflow.python.platform import gfile
98

109
from easy_rec.python.main import export
1110
from easy_rec.python.protos.train_pb2 import DistributionStrategy
1211
from easy_rec.python.utils import config_util
1312
from easy_rec.python.utils import estimator_utils
1413

14+
if tf.__version__.startswith('1.'):
15+
from tensorflow.python.platform import gfile
16+
else:
17+
import tensorflow.io.gfile as gfile
18+
1519
if tf.__version__ >= '2.0':
1620
tf = tf.compat.v1
1721

easy_rec/python/input/criteo_input.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import logging
44

55
import tensorflow as tf
6-
from tensorflow.python.platform import gfile
76

87
from easy_rec.python.input.criteo_binary_reader import BinaryDataset
98
from easy_rec.python.input.input import Input
@@ -38,9 +37,9 @@ def __init__(self,
3837
for label_path, dense_path, category_path in zip(
3938
input_path.label_path, input_path.dense_path,
4039
input_path.category_path):
41-
label_paths = gfile.Glob(input_path.label_path)
42-
dense_paths = gfile.Glob(input_path.dense_path)
43-
category_paths = gfile.Glob(input_path.category_path)
40+
label_paths = tf.gfile.Glob(input_path.label_path)
41+
dense_paths = tf.gfile.Glob(input_path.dense_path)
42+
category_paths = tf.gfile.Glob(input_path.category_path)
4443
assert len(label_paths) == len(dense_paths) and len(label_paths) == \
4544
len(category_paths), 'label_path(%s) dense_path(%s) category_path(%s) ' + \
4645
'matched different number of files(%d %d %d)' % (

easy_rec/python/input/datahub_input.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@
66

77
import tensorflow as tf
88
from tensorflow.python.framework import dtypes
9-
from tensorflow.python.platform import gfile
109

1110
from easy_rec.python.input.input import Input
1211
from easy_rec.python.utils import odps_util
1312
from easy_rec.python.utils.config_util import parse_time
1413

14+
if tf.__version__.startswith('1.'):
15+
from tensorflow.python.platform import gfile
16+
else:
17+
import tensorflow.io.gfile as gfile
18+
1519
try:
1620
import common_io
1721
except Exception:

easy_rec/python/input/kafka_input.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@
66

77
import six
88
import tensorflow as tf
9-
from tensorflow.python.platform import gfile
109

1110
from easy_rec.python.input.input import Input
1211
from easy_rec.python.input.kafka_dataset import KafkaDataset
1312
from easy_rec.python.utils.config_util import parse_time
1413

14+
if tf.__version__.startswith('1.'):
15+
from tensorflow.python.platform import gfile
16+
else:
17+
import tensorflow.io.gfile as gfile
18+
1519
try:
1620
from kafka import KafkaConsumer, TopicPartition
1721
except ImportError:

easy_rec/python/input/odps_rtp_input_v2.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77

88
from easy_rec.python.input.odps_rtp_input import OdpsRTPInput
99

10+
if tf.__version__.startswith('1.'):
11+
from tensorflow.python.platform import gfile
12+
else:
13+
import tensorflow.io.gfile as gfile
1014
try:
1115
import pai
1216
import rtp_fg
@@ -45,7 +49,7 @@ def __init__(self,
4549
logging.info('fg config path: {}'.format(self._fg_config_path))
4650
if self._fg_config_path is None:
4751
raise ValueError('fg_json_path is not set')
48-
with tf.gfile.GFile(self._fg_config_path, 'r') as f:
52+
with gfile.GFile(self._fg_config_path, 'r') as f:
4953
self._fg_config = json.load(f)
5054

5155
def _parse_table(self, *fields):

easy_rec/python/input/parquet_input.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,18 @@
33
import logging
44
import multiprocessing
55
import queue
6-
# import threading
76
import time
87

9-
# import numpy as np
10-
# import pandas as pd
118
import tensorflow as tf
12-
# from tensorflow.python.framework import ops
139
from tensorflow.python.ops import array_ops
14-
# from tensorflow.python.ops import logging_ops
15-
# from tensorflow.python.ops import math_ops
16-
from tensorflow.python.platform import gfile
1710

1811
from easy_rec.python.compat import queues
1912
from easy_rec.python.input import load_parquet
2013
from easy_rec.python.input.input import Input
2114

15+
if tf.__version__ >= '2.0':
16+
tf = tf.compat.v1
17+
2218

2319
class ParquetInput(Input):
2420

@@ -40,7 +36,7 @@ def __init__(self,
4036

4137
self._input_files = []
4238
for sub_path in input_path.strip().split(','):
43-
self._input_files.extend(gfile.Glob(sub_path))
39+
self._input_files.extend(tf.gfile.Glob(sub_path))
4440
logging.info('parquet input_path=%s file_num=%d' %
4541
(input_path, len(self._input_files)))
4642
mp_ctxt = multiprocessing.get_context('spawn')

easy_rec/python/input/parquet_input_v3.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import logging
44

55
import tensorflow as tf
6-
from tensorflow.python.platform import gfile
76

87
from easy_rec.python.input.input import Input
98
from easy_rec.python.utils.input_utils import get_type_defaults
@@ -19,6 +18,9 @@
1918
_has_deep_rec = False
2019
pass
2120

21+
if tf.__version__ >= '2.0':
22+
tf = tf.compat.v1
23+
2224

2325
class ParquetInputV3(Input):
2426

@@ -114,7 +116,7 @@ def _parse_dataframe(self, df):
114116
def _build(self, mode, params):
115117
input_files = []
116118
for sub_path in self._input_path.strip().split(','):
117-
input_files.extend(gfile.Glob(sub_path))
119+
input_files.extend(tf.gfile.Glob(sub_path))
118120
file_num = len(input_files)
119121
logging.info('[task_index=%d] total_file_num=%d task_num=%d' %
120122
(self._task_index, file_num, self._task_num))

easy_rec/python/main.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import six
1515
import tensorflow as tf
1616
from tensorflow.core.protobuf import saved_model_pb2
17-
from tensorflow.python.platform import gfile
1817

1918
import easy_rec
2019
from easy_rec.python.builders import strategy_builder
@@ -240,27 +239,27 @@ def _metric_cmp_fn(best_eval_result, current_eval_result):
240239

241240
def _check_model_dir(model_dir, continue_train):
242241
if not continue_train:
243-
if not gfile.IsDirectory(model_dir):
244-
gfile.MakeDirs(model_dir)
242+
if not tf.gfile.IsDirectory(model_dir):
243+
tf.gfile.MakeDirs(model_dir)
245244
else:
246-
assert len(gfile.Glob(model_dir + '/model.ckpt-*.meta')) == 0, \
245+
assert len(tf.gfile.Glob(model_dir + '/model.ckpt-*.meta')) == 0, \
247246
'model_dir[=%s] already exists and not empty(if you ' \
248247
'want to continue train on current model_dir please ' \
249248
'delete dir %s or specify --continue_train[internal use only])' % (
250249
model_dir, model_dir)
251250
else:
252-
if not gfile.IsDirectory(model_dir):
251+
if not tf.gfile.IsDirectory(model_dir):
253252
logging.info('%s does not exists, create it automatically' % model_dir)
254-
gfile.MakeDirs(model_dir)
253+
tf.gfile.MakeDirs(model_dir)
255254

256255

257256
def _get_ckpt_path(pipeline_config, checkpoint_path):
258257
if checkpoint_path != '' and checkpoint_path is not None:
259-
if gfile.IsDirectory(checkpoint_path):
258+
if tf.gfile.IsDirectory(checkpoint_path):
260259
ckpt_path = estimator_utils.latest_checkpoint(checkpoint_path)
261260
else:
262261
ckpt_path = checkpoint_path
263-
elif gfile.IsDirectory(pipeline_config.model_dir):
262+
elif tf.gfile.IsDirectory(pipeline_config.model_dir):
264263
ckpt_path = estimator_utils.latest_checkpoint(pipeline_config.model_dir)
265264
logging.info('checkpoint_path is not specified, '
266265
'will use latest checkpoint %s from %s' %
@@ -284,7 +283,8 @@ def train_and_evaluate(pipeline_config_path, continue_train=False):
284283
Returns:
285284
None, the model will be saved into pipeline_config.model_dir
286285
"""
287-
assert gfile.Exists(pipeline_config_path), 'pipeline_config_path not exists'
286+
assert tf.gfile.Exists(
287+
pipeline_config_path), 'pipeline_config_path not exists'
288288
pipeline_config = config_util.get_configs_from_pipeline_file(
289289
pipeline_config_path)
290290

@@ -323,7 +323,7 @@ def _train_and_evaluate_impl(pipeline_config,
323323
if estimator_utils.is_chief():
324324
_check_model_dir(pipeline_config.model_dir, continue_train)
325325
config_util.save_pipeline_config(pipeline_config, pipeline_config.model_dir)
326-
with gfile.GFile(version_file, 'w') as f:
326+
with tf.gfile.GFile(version_file, 'w') as f:
327327
f.write(easy_rec.__version__ + '\n')
328328

329329
train_steps = None
@@ -509,7 +509,7 @@ def evaluate(pipeline_config,
509509
model_dir = pipeline_config.model_dir
510510
eval_result_file = os.path.join(model_dir, eval_result_filename)
511511
logging.info('save eval result to file %s' % eval_result_file)
512-
with gfile.GFile(eval_result_file, 'w') as ofile:
512+
with tf.gfile.GFile(eval_result_file, 'w') as ofile:
513513
result_to_write = {}
514514
for key in sorted(eval_result):
515515
# skip logging binary data
@@ -562,10 +562,10 @@ def distribute_evaluate(pipeline_config,
562562
return eval_result
563563
model_dir = get_model_dir_path(pipeline_config)
564564
eval_tmp_results_dir = os.path.join(model_dir, 'distribute_eval_tmp_results')
565-
if not gfile.IsDirectory(eval_tmp_results_dir):
565+
if not tf.gfile.IsDirectory(eval_tmp_results_dir):
566566
logging.info('create eval tmp results dir {}'.format(eval_tmp_results_dir))
567-
gfile.MakeDirs(eval_tmp_results_dir)
568-
assert gfile.IsDirectory(
567+
tf.gfile.MakeDirs(eval_tmp_results_dir)
568+
assert tf.gfile.IsDirectory(
569569
eval_tmp_results_dir), 'tmp results dir not create success.'
570570
os.environ['eval_tmp_results_dir'] = eval_tmp_results_dir
571571

@@ -679,7 +679,7 @@ def distribute_evaluate(pipeline_config,
679679
if cur_job_name == 'master':
680680
print('eval_result = ', eval_result)
681681
logging.info('eval_result = {0}'.format(eval_result))
682-
with gfile.GFile(eval_result_file, 'w') as ofile:
682+
with tf.gfile.GFile(eval_result_file, 'w') as ofile:
683683
result_to_write = {'eval_method': 'distribute'}
684684
for key in sorted(eval_result):
685685
# skip logging binary data
@@ -766,8 +766,8 @@ def export(export_dir,
766766
AssertionError, if:
767767
* pipeline_config_path does not exist
768768
"""
769-
if not gfile.Exists(export_dir):
770-
gfile.MakeDirs(export_dir)
769+
if not tf.gfile.Exists(export_dir):
770+
tf.gfile.MakeDirs(export_dir)
771771

772772
pipeline_config = config_util.get_configs_from_pipeline_file(pipeline_config)
773773
if pipeline_config.fg_json_path:
@@ -830,10 +830,10 @@ def export(export_dir,
830830
]
831831
export_ts = export_ts[-1]
832832
saved_pb_path = os.path.join(final_export_dir, 'saved_model.pb')
833-
with gfile.GFile(saved_pb_path, 'rb') as fin:
833+
with tf.gfile.GFile(saved_pb_path, 'rb') as fin:
834834
saved_model.ParseFromString(fin.read())
835835
saved_model.meta_graphs[0].meta_info_def.meta_graph_version = export_ts
836-
with gfile.GFile(saved_pb_path, 'wb') as fout:
836+
with tf.gfile.GFile(saved_pb_path, 'wb') as fout:
837837
fout.write(saved_model.SerializeToString())
838838

839839
logging.info('model has been exported to %s successfully' % final_export_dir)

easy_rec/python/test/hpo_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
from easy_rec.python.utils import test_utils
1616

1717
if tf.__version__ >= '2.0':
18-
gfile = tf.compat.v1.gfile
18+
import tensorflow.io.gfile as gfile
1919
from tensorflow.core.protobuf import config_pb2
2020

2121
ConfigProto = config_pb2.ConfigProto
2222
GPUOptions = config_pb2.GPUOptions
2323
else:
24-
gfile = tf.gfile
24+
from tensorflow.python.platform import gfile
2525
GPUOptions = tf.GPUOptions
2626
ConfigProto = tf.ConfigProto
2727

0 commit comments

Comments
 (0)