Skip to content

Commit 1975529

Browse files
authored
cleanup: replace namedtuples with dataclasses in event_accumulator (#6121)
This PR reattempts the changes reverted in #6013 and fixes #5725. Googlers, see prerequisite internal changes at cl/b:234007753 and global internal test at cl/499266633.
1 parent 20cd56b commit 1975529

File tree

1 file changed

+130
-38
lines changed

1 file changed

+130
-38
lines changed

tensorboard/backend/event_processing/event_accumulator.py

Lines changed: 130 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515
"""Takes a generator of values, and accumulates them for a frontend."""
1616

1717
import collections
18+
import dataclasses
1819
import threading
1920

20-
from typing import Optional
21+
from typing import Optional, Sequence, Tuple
2122

2223
from tensorboard.backend.event_processing import directory_watcher
2324
from tensorboard.backend.event_processing import event_file_loader
@@ -30,47 +31,137 @@
3031
from tensorboard.compat.proto import event_pb2
3132
from tensorboard.compat.proto import graph_pb2
3233
from tensorboard.compat.proto import meta_graph_pb2
34+
from tensorboard.compat.proto import tensor_pb2
3335
from tensorboard.plugins.distribution import compressor
3436
from tensorboard.util import tb_logging
3537

3638

3739
logger = tb_logging.get_logger()
3840

39-
namedtuple = collections.namedtuple
40-
ScalarEvent = namedtuple("ScalarEvent", ["wall_time", "step", "value"])
41-
42-
CompressedHistogramEvent = namedtuple(
43-
"CompressedHistogramEvent",
44-
["wall_time", "step", "compressed_histogram_values"],
45-
)
46-
47-
HistogramEvent = namedtuple(
48-
"HistogramEvent", ["wall_time", "step", "histogram_value"]
49-
)
50-
51-
HistogramValue = namedtuple(
52-
"HistogramValue",
53-
["min", "max", "num", "sum", "sum_squares", "bucket_limit", "bucket"],
54-
)
55-
56-
ImageEvent = namedtuple(
57-
"ImageEvent",
58-
["wall_time", "step", "encoded_image_string", "width", "height"],
59-
)
60-
61-
AudioEvent = namedtuple(
62-
"AudioEvent",
63-
[
64-
"wall_time",
65-
"step",
66-
"encoded_audio_string",
67-
"content_type",
68-
"sample_rate",
69-
"length_frames",
70-
],
71-
)
72-
73-
TensorEvent = namedtuple("TensorEvent", ["wall_time", "step", "tensor_proto"])
41+
42+
@dataclasses.dataclass(frozen=True)
43+
class ScalarEvent:
44+
"""Contains information of a scalar event.
45+
46+
Attributes:
47+
wall_time: Timestamp of the event in seconds.
48+
step: Global step of the event.
49+
value: A float or int value of the scalar.
50+
"""
51+
52+
wall_time: float
53+
step: int
54+
value: float
55+
56+
57+
@dataclasses.dataclass(frozen=True)
58+
class CompressedHistogramEvent:
59+
"""Contains information of a compressed histogram event.
60+
61+
Attributes:
62+
wall_time: Timestamp of the event in seconds.
63+
step: Global step of the event.
64+
compressed_histogram_values: A sequence of tuples of basis points and
65+
associated values in a compressed histogram.
66+
"""
67+
68+
wall_time: float
69+
step: int
70+
compressed_histogram_values: Sequence[Tuple[float, float]]
71+
72+
73+
@dataclasses.dataclass(frozen=True)
74+
class HistogramValue:
75+
"""Holds the information of the histogram values.
76+
77+
Attributes:
78+
min: A float or int min value.
79+
max: A float or int max value.
80+
num: Total number of values.
81+
sum: Sum of all values.
82+
sum_squares: Sum of squares for all values.
83+
bucket_limit: Upper values per bucket.
84+
bucket: Numbers of values per bucket.
85+
"""
86+
87+
min: float
88+
max: float
89+
num: int
90+
sum: float
91+
sum_squares: float
92+
bucket_limit: Sequence[float]
93+
bucket: Sequence[int]
94+
95+
96+
@dataclasses.dataclass(frozen=True)
97+
class HistogramEvent:
98+
"""Contains information of a histogram event.
99+
100+
Attributes:
101+
wall_time: Timestamp of the event in seconds.
102+
step: Global step of the event.
103+
histogram_value: Information of the histogram values.
104+
"""
105+
106+
wall_time: float
107+
step: int
108+
histogram_value: HistogramValue
109+
110+
111+
@dataclasses.dataclass(frozen=True)
112+
class ImageEvent:
113+
"""Contains information of an image event.
114+
115+
Attributes:
116+
wall_time: Timestamp of the event in seconds.
117+
step: Global step of the event.
118+
encoded_image_string: Image content encoded in bytes.
119+
width: Width of the image.
120+
height: Height of the image.
121+
"""
122+
123+
wall_time: float
124+
step: int
125+
encoded_image_string: bytes
126+
width: int
127+
height: int
128+
129+
130+
@dataclasses.dataclass(frozen=True)
131+
class AudioEvent:
132+
"""Contains information of an audio event.
133+
134+
Attributes:
135+
wall_time: Timestamp of the event in seconds.
136+
step: Global step of the event.
137+
encoded_audio_string: Audio content encoded in bytes.
138+
content_type: A string describes the type of the audio content.
139+
sample_rate: Sample rate of the audio in Hz. Must be positive.
140+
length_frames: Length of the audio in frames (samples per channel).
141+
"""
142+
143+
wall_time: float
144+
step: int
145+
encoded_audio_string: bytes
146+
content_type: str
147+
sample_rate: float
148+
length_frames: int
149+
150+
151+
@dataclasses.dataclass(frozen=True)
152+
class TensorEvent:
153+
"""A tensor event.
154+
155+
Attributes:
156+
wall_time: Timestamp of the event in seconds.
157+
step: Global step of the event.
158+
tensor_proto: A `TensorProto`.
159+
"""
160+
161+
wall_time: float
162+
step: int
163+
tensor_proto: tensor_pb2.TensorProto
164+
74165

75166
## Different types of summary events handled by the event_accumulator
76167
SUMMARY_TYPES = {
@@ -697,7 +788,8 @@ def _CheckForOutOfOrderStepAndMaybePurge(self, event):
697788
self.most_recent_step = event.step
698789
self.most_recent_wall_time = event.wall_time
699790

700-
def _ConvertHistogramProtoToTuple(self, histo):
791+
def _ConvertHistogramProtoToPopo(self, histo):
792+
"""Converts histogram proto to Python object."""
701793
return HistogramValue(
702794
min=histo.min,
703795
max=histo.max,
@@ -710,7 +802,7 @@ def _ConvertHistogramProtoToTuple(self, histo):
710802

711803
def _ProcessHistogram(self, tag, wall_time, step, histo):
712804
"""Processes a proto histogram by adding it to accumulated state."""
713-
histo = self._ConvertHistogramProtoToTuple(histo)
805+
histo = self._ConvertHistogramProtoToPopo(histo)
714806
histo_ev = HistogramEvent(wall_time, step, histo)
715807
self.histograms.AddItem(tag, histo_ev)
716808
self.compressed_histograms.AddItem(

0 commit comments

Comments
 (0)