22
22
# https://github.com/kubeflow/katib/blob/master/examples/v1beta1/kubeflow-training-operator/tfjob-mnist-with-summaries.yaml#L16-L22
23
23
24
24
import tensorflow as tf
25
- from tensorboard .backend .event_processing .event_accumulator import EventAccumulator
25
+ from tensorboard .backend .event_processing .event_accumulator import EventAccumulator , TensorEvent
26
+ from tensorboard .backend .event_processing .tag_types import TENSORS
27
+ from tensorboard .compat .proto import tensor_pb2
26
28
import os
27
- from datetime import datetime
28
29
import rfc3339
30
+ from datetime import datetime
31
+ from dataclasses import is_dataclass
29
32
import api_pb2
30
33
from logging import getLogger , StreamHandler , INFO
31
34
from pkg .metricscollector .v1beta1 .common import const
@@ -41,27 +44,38 @@ def find_all_files(directory):
41
44
for f in files :
42
45
yield os .path .join (root , f )
43
46
47
@staticmethod
def new_metric_log(metric_name: str, wall_time: float, tensor: tensor_pb2.TensorProto) -> api_pb2.MetricLog:
    """Build one ``MetricLog`` entry from a single recorded tensor value.

    Args:
        metric_name: Name stored in the resulting ``Metric``.
        wall_time: POSIX timestamp (seconds) at which the value was recorded.
        tensor: Tensor proto holding the recorded value.

    Returns:
        A ``MetricLog`` whose ``time_stamp`` is the RFC 3339 rendering of
        ``wall_time`` and whose metric value is the string form of the
        ndarray decoded from ``tensor``.
    """
    # Timestamp is rendered first, then the value, mirroring the order in
    # which the two conversions are attempted.
    time_stamp = rfc3339.rfc3339(datetime.fromtimestamp(wall_time))
    rendered_value = str(tf.make_ndarray(tensor))
    metric = api_pb2.Metric(name=metric_name, value=rendered_value)
    return api_pb2.MetricLog(time_stamp=time_stamp, metric=metric)
56
+
44
57
def parse_summary(self, tfefile):
    """Extract metric logs for the configured metric names from one event file.

    Every tensor tag found in ``tfefile`` is matched against each entry of
    ``self.metric_names``. A metric name may carry a directory prefix
    (``"train/accuracy"``): the last path component must be a prefix of the
    tag, and the event file's directory must end with the metric's
    directory part. A metric name without a prefix matches any directory.

    Args:
        tfefile: Path to a TensorFlow event file.

    Returns:
        A list of ``MetricLog`` messages, one per recorded tensor event of
        every matching (tag, metric name) pair.
    """
    metric_logs = []
    # A size guidance of 0 asks the accumulator to keep every tensor event
    # instead of down-sampling them.
    event_accumulator = EventAccumulator(tfefile, size_guidance={TENSORS: 0})
    event_accumulator.Reload()

    # The event file's directory does not depend on the tag or metric name,
    # so compute it once.
    basedir_name = os.path.dirname(tfefile)

    for tag in event_accumulator.Tags()[TENSORS]:
        for m in self.metric_names:
            name_parts = m.split("/")
            tfefile_parent_dir = os.path.dirname(m) if len(name_parts) >= 2 else basedir_name
            if not tag.startswith(name_parts[-1]) or not basedir_name.endswith(tfefile_parent_dir):
                continue

            # Since TensorBoard v2.12.0 the 'TensorEvent' type changed from a
            # namedtuple to a dataclass, so it can no longer be unpacked.
            # REF: https://github.com/tensorflow/tensorboard/commit/1975529d953eff55d279ee036240e4db4cb0d57c
            # Both variants expose 'wall_time' and 'tensor_proto' as named
            # fields, so attribute access works on either version without a
            # runtime type check.
            for event in event_accumulator.Tensors(tag):
                metric_logs.append(self.new_metric_log(m, event.wall_time, event.tensor_proto))

    return metric_logs
67
81
0 commit comments