Skip to content

Commit ef36590

Browse files
authored
[Data] Add category field to embedding to track the type of the embedding. (#111)
1 parent 178b47f commit ef36590

File tree

9 files changed

+201
-9
lines changed

9 files changed

+201
-9
lines changed

eyepop/data/arrow/eyepop/annotations.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ def table_from_eyepop_annotations(annotations: list[AssetAnnotationResponse], sc
8080
embeddings.append([])
8181
else:
8282
embeddings.append(table_from_eyepop_predicted_embeddings(
83-
e.annotation.embeddings
83+
e.annotation.embeddings,
84+
schema=pa.schema(schema.field(8).type.value_type), # schema for "embeddings" field
8485
).to_struct_array())
8586
if timestamps is not None:
8687
timestamps.append(e.annotation.timestamp)
@@ -125,7 +126,6 @@ def table_from_eyepop_annotations(annotations: list[AssetAnnotationResponse], sc
125126
columns.append(offsets)
126127
if offset_durations is not None:
127128
columns.append(offset_durations)
128-
129129
return pa.Table.from_arrays(columns, schema=schema)
130130

131131

@@ -187,7 +187,7 @@ def eyepop_annotations_from_table(table: pa.Table) -> list[AssetAnnotationRespon
187187
elif len(embeddings[j]) == 0:
188188
child_embeddings = []
189189
else:
190-
child_embeddings = eyepop_predicted_embeddings_from_pylist(texts[j])
190+
child_embeddings = eyepop_predicted_embeddings_from_pylist(embeddings[j])
191191

192192
annotations.append(AssetAnnotationResponse(
193193
type=types[j],

eyepop/data/arrow/eyepop/predictions.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,23 +313,32 @@ def eyepop_predicted_key_points_from_pylist(py_list: list[dict[str, any]],
313313
def table_from_eyepop_predicted_embeddings(predicted_embeddings: list[PredictedEmbedding],
314314
schema: Schema = EMBEDDING_SCHEMA) -> pa.Table:
315315
embeddings = []
316+
categories = []
316317
x_coordinates = []
317318
y_coordinates = []
318319
for predicted_embedding in predicted_embeddings:
319320
embeddings.append(predicted_embedding.embedding)
320321
x_coordinates.append(predicted_embedding.x)
321322
y_coordinates.append(predicted_embedding.y)
322-
return pa.Table.from_arrays([
323+
categories.append(predicted_embedding.category)
324+
325+
columns = [
323326
pa.array(embeddings),
324327
pa.array(x_coordinates),
325328
pa.array(y_coordinates),
326-
], schema=schema)
329+
]
330+
331+
if "category" in schema.names:
332+
columns.append(pa.array(categories).dictionary_encode())
333+
334+
return pa.Table.from_arrays(columns, schema=schema)
327335

328336
def eyepop_predicted_embeddings_from_pylist(py_list: list[dict[str, any]]) -> list[PredictedEmbedding]:
329337
predicted_embeddings: list[PredictedEmbedding | None] = [None] * len(py_list)
330338
for i, predicted_embedding in enumerate(py_list):
331339
predicted_embeddings[i] = PredictedEmbedding(
332340
embedding=predicted_embedding["embedding"],
341+
category=predicted_embedding.get("category", None),
333342
x=_round_float_like(predicted_embedding.get("x", None), COORDINATE_N_DIGITS),
334343
y=_round_float_like(predicted_embedding.get("y", None), COORDINATE_N_DIGITS),
335344
)

eyepop/data/arrow/schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from pyarrow._compute import CastOptions
22
import pyarrow as pa
33

4-
from . import schema_1_4 as schema_latest
4+
from . import schema_1_5 as schema_latest
55

66
""" The latest official Arrow schema for the EyePop Dataset API.
77
8-
The latest officially supported schema is: 1.4
8+
The latest officially supported schema is: 1.5
99
1010
These are references to the types and schemas that are currently
1111
supported. For backward compatibility, we keep schemas versioned

eyepop/data/arrow/schema_1_5.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
import pyarrow as pa
2+
3+
from eyepop.data.data_types import MIME_TYPE_APACHE_ARROW_FILE
4+
5+
""" Arrow schema for Asset export/import form Data API. """
6+
7+
MIME_TYPE_APACHE_ARROW_FILE_VERSIONED = f"{MIME_TYPE_APACHE_ARROW_FILE};version=1.4"
8+
9+
10+
# BEGIN: Extension since v1.3
11+
_embedding_fields = [
12+
pa.field(name="embedding", type=pa.list_(pa.float16())),
13+
pa.field(name="x", type=pa.float16()),
14+
pa.field(name="y", type=pa.float16()),
15+
# BEGIN: Extension since v1.5
16+
pa.field(name="category", type=pa.dictionary(pa.int32(), pa.string())),
17+
# END: Extension since v1.5
18+
]
19+
EMBEDDING_STRUCT = pa.struct(_embedding_fields)
20+
EMBEDDING_SCHEMA = pa.schema(_embedding_fields)
21+
# END: Extension since v1.3
22+
23+
# BEGIN: Extension since v1.2
24+
_text_fields = [
25+
pa.field(name="confidence", type=pa.float16()),
26+
pa.field(name="text", type=pa.string()),
27+
pa.field(name="category", type=pa.dictionary(pa.int32(), pa.string())),
28+
]
29+
TEXT_STRUCT = pa.struct(_text_fields)
30+
TEXT_SCHEMA = pa.schema(_text_fields)
31+
# END: Extension since v1.2
32+
33+
# BEGIN: Extension since v1.1
34+
_key_point_fields = [
35+
pa.field(name="classLabel", type=pa.dictionary(pa.int32(), pa.string())),
36+
pa.field(name="confidence", type=pa.float16()),
37+
pa.field(name="x", type=pa.float16()),
38+
pa.field(name="y", type=pa.float16()),
39+
# optional z coordinate in pixel coordinate system, null="unknown"
40+
pa.field(name="z", type=pa.float16()),
41+
#optional flag true=visible, false=invisible, null="unknown"
42+
pa.field(name="visible", type=pa.bool_()),
43+
# BEGIN: Extension since v1.2
44+
pa.field(name="category", type=pa.dictionary(pa.int32(), pa.string())),
45+
# END: Extension since v1.2
46+
]
47+
KEY_POINT_STRUCT = pa.struct(_key_point_fields)
48+
KEY_POINT_SCHEMA = pa.schema(_key_point_fields)
49+
_key_points_fields = [
50+
# optional
51+
pa.field(name="type", type=pa.dictionary(pa.int32(), pa.string())),
52+
pa.field(name="points", type=pa.list_(KEY_POINT_STRUCT)),
53+
# BEGIN: Extension since v1.2
54+
pa.field(name="category", type=pa.dictionary(pa.int32(), pa.string())),
55+
# END: Extension since v1.2
56+
]
57+
KEY_POINTS_STRUCT = pa.struct(_key_points_fields)
58+
KEY_POINTS_SCHEMA = pa.schema(_key_points_fields)
59+
# END: Extension since v1.1
60+
61+
_object_fields = [
62+
pa.field(name="classLabel", type=pa.dictionary(pa.int32(), pa.string())),
63+
pa.field(name="confidence", type=pa.float16()),
64+
pa.field(name="x", type=pa.float16()),
65+
pa.field(name="y", type=pa.float16()),
66+
pa.field(name="width", type=pa.float16()),
67+
pa.field(name="height", type=pa.float16()),
68+
# from eyepop.data.data_types import UserReview
69+
pa.field(name="user_review", type=pa.dictionary(pa.int8(), pa.string())),
70+
# BEGIN: Extension since v1.1
71+
pa.field(name="keyPoints", type=pa.list_(KEY_POINTS_STRUCT)),
72+
# END: Extension since v1.1
73+
# BEGIN: Extension since v1.2
74+
pa.field(name="category", type=pa.dictionary(pa.int32(), pa.string())),
75+
pa.field(name="texts", type=pa.list_(TEXT_STRUCT)),
76+
# END: Extension since v1.2
77+
]
78+
79+
OBJECT_STRUCT = pa.struct(_object_fields)
80+
OBJECT_SCHEMA = pa.schema(_object_fields)
81+
82+
_class_fields = [
83+
pa.field(name="classLabel", type=pa.dictionary(pa.int32(), pa.string())),
84+
pa.field(name="confidence", type=pa.float16()),
85+
# from eyepop.data.data_types import UserReview
86+
pa.field(name="user_review", type=pa.dictionary(pa.int8(), pa.string())),
87+
# BEGIN: Extension since v1.2
88+
pa.field(name="category", type=pa.dictionary(pa.int32(), pa.string())),
89+
# END: Extension since v1.2
90+
]
91+
92+
CLASS_STRUCT = pa.struct(_class_fields)
93+
CLASS_SCHEMA = pa.schema(_class_fields)
94+
95+
_annotation_fields = [
96+
# from eyepop.data.data_types import AnnotationType
97+
pa.field(name="type", type=pa.dictionary(pa.int8(), pa.string())),
98+
# from eyepop.data.data_types import AutoAnnotate
99+
pa.field(name="source", type=pa.dictionary(pa.int32(), pa.string())),
100+
# from eyepop.data.data_types import UserReview
101+
pa.field(name="user_review", type=pa.dictionary(pa.int8(), pa.string())),
102+
pa.field(name="objects", type=pa.list_(OBJECT_STRUCT)),
103+
pa.field(name="classes", type=pa.list_(CLASS_STRUCT)),
104+
# read/write, optional, the model that produced this annotation
105+
pa.field(name="source_model_uuid", type=pa.dictionary(pa.int8(), pa.string())),
106+
# BEGIN: Extension since v1.1
107+
pa.field(name="keyPoints", type=pa.list_(KEY_POINTS_STRUCT)),
108+
# END: Extension since v1.1
109+
# BEGIN: Extension since v1.2
110+
pa.field(name="texts", type=pa.list_(TEXT_STRUCT)),
111+
# END: Extension since v1.2
112+
# BEGIN: Extension since v1.3
113+
pa.field(name="embeddings", type=pa.list_(EMBEDDING_STRUCT)),
114+
# END: Extension since v1.3
115+
# BEGIN: Extension since v1.4
116+
pa.field(name="timestamp", type=pa.int64()),
117+
pa.field(name="duration", type=pa.int64()),
118+
pa.field(name="offset", type=pa.int64()),
119+
pa.field(name="offset_duration", type=pa.int64()),
120+
# END: Extension since v1.4
121+
]
122+
123+
ANNOTATION_STRUCT = pa.struct(_annotation_fields)
124+
125+
ANNOTATION_SCHEMA = pa.schema(_annotation_fields)
126+
127+
_asset_fields = [
128+
pa.field(name="uuid", type=pa.string()),
129+
pa.field(name="external_id", type=pa.string()),
130+
pa.field(name="created_at", type=pa.timestamp("ms")),
131+
pa.field(name="updated_at", type=pa.timestamp("ms")),
132+
pa.field(name="asset_url", type=pa.string()),
133+
pa.field(name="original_image_width", type=pa.uint16()),
134+
pa.field(name="original_image_height", type=pa.uint16()),
135+
pa.field(name="partition", type=pa.dictionary(pa.int32(), pa.string())),
136+
pa.field(name="review_priority", type=pa.float16()),
137+
pa.field(name="model_relevance", type=pa.float16()),
138+
pa.field(name="annotations", type=pa.list_(ANNOTATION_STRUCT)),
139+
]
140+
141+
ASSET_STRUCT = pa.struct(_asset_fields)
142+
143+
ASSET_SCHEMA = pa.schema(_asset_fields)

eyepop/data/data_normalize.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
from eyepop.data.data_types import AssetAnnotationResponse, Prediction, PredictedObject, PredictedClass
1+
from eyepop.data.data_types import AssetAnnotationResponse, Prediction, PredictedObject, PredictedClass, PredictedEmbedding
22

33
# Confidence and coordinates are represented as float16 in the arrow format but Python lacks support for 2-byte floats.
44
# To avoid "changing" 4-bytes floats when converted back and forth, we will always round to these precisions.
55

66
CONFIDENCE_N_DIGITS = 3
77
COORDINATE_N_DIGITS = 3
8+
EMBEDDING_N_DIGITS = 1
89

910
def normalize_eyepop_annotations(
1011
annotations: list[AssetAnnotationResponse]
@@ -22,6 +23,8 @@ def normalize_eyepop_prediction(
2223
prediction.objects, prediction.source_width, prediction.source_height)
2324
if prediction.classes:
2425
normalize_predicted_classes(prediction.classes)
26+
if prediction.embeddings:
27+
normalize_predicted_embeddings(prediction.embeddings)
2528
prediction.source_width = 1.0
2629
prediction.source_height = 1.0
2730

@@ -50,3 +53,12 @@ def normalize_predicted_classes(
5053
for c in predicted_classes:
5154
if c.confidence is not None:
5255
c.confidence = round(c.confidence, CONFIDENCE_N_DIGITS)
56+
57+
58+
def normalize_predicted_embeddings(
59+
predicted_embeddings: list[PredictedEmbedding]
60+
):
61+
for e in predicted_embeddings:
62+
if e.embedding is not None:
63+
for i in range(len(e.embedding)):
64+
e.embedding[i] = round(e.embedding[i], EMBEDDING_N_DIGITS)

eyepop/data/data_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ class PredictedEmbedding(BaseModel):
167167
x: float | None = None
168168
y: float | None = None
169169
embedding: List[float]
170+
category: str | None = None
170171

171172
class PredictedText(BaseModel):
172173
id: int | None = None

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"
99

1010
[project]
1111
name = "eyepop"
12-
version = "1.17.0"
12+
version = "1.18.0"
1313
description="EyePop.ai Python SDK"
1414
readme = "README.md"
1515
license.file = "./LICENSE"
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"source_width": 100,
3+
"source_height": 100,
4+
"embeddings": [
5+
{
6+
"category": "image_patch",
7+
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
8+
"x": 10.0,
9+
"y": 10.0
10+
},
11+
{
12+
"category": "text",
13+
"embedding": [0.6, 0.7, 0.8, 0.9, 1.0]
14+
}
15+
]
16+
}

tests/data/test_arrow_to_from_annotation.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from eyepop.data.arrow.schema_1_1 import ASSET_SCHEMA as ASSET_SCHEMA_1_1
1111
from eyepop.data.arrow.schema_1_2 import ASSET_SCHEMA as ASSET_SCHEMA_1_2
1212
from eyepop.data.arrow.schema_1_3 import ASSET_SCHEMA as ASSET_SCHEMA_1_3
13+
from eyepop.data.arrow.schema_1_5 import ASSET_SCHEMA as ASSET_SCHEMA_1_5
1314

1415
from eyepop.data.arrow.eyepop.annotations import table_from_eyepop_annotations, eyepop_annotations_from_table
1516
from eyepop.data.data_normalize import normalize_eyepop_annotations, normalize_eyepop_prediction
@@ -27,6 +28,8 @@ class TestArrowToFromAnnotation:
2728
("prediction_2_keypoints_2_objects.json", 6),
2829
("prediction_11_timestamp.json", 7),
2930
("prediction_12_texts.json", 8),
31+
("prediction_1_embeddings.json", 9),
32+
("prediction_2_embeddings.json", 10),
3033
])
3134
def test_prediction_from_file(self, file_name, n):
3235
test_json = resources.files(files) / file_name
@@ -168,6 +171,14 @@ def test_1_3_to_1_2(self):
168171
if column_name != "annotations":
169172
assert target_table.column(column_name) == source_table.column(column_name)
170173

174+
def test_1_5(self):
175+
""" verify that the new field `category` in `embeddings` in 1.5 are converted """
176+
source_table = create_test_table(schema=ASSET_SCHEMA_1_5, test_file_name="prediction_2_embeddings.json")
177+
target_assets = eyepop_assets_from_table(source_table, schema=ASSET_SCHEMA_1_5)
178+
target_table = table_from_eyepop_assets(target_assets, schema=ASSET_SCHEMA_1_5)
179+
assert target_table.schema == source_table.schema
180+
assert target_table == source_table
181+
171182
def test_assets(self):
172183
""" verify that the new denormalized fields account_uuid and dataset_uuids are being filled """
173184
source_table = create_test_table(schema=ASSET_SCHEMA_1_3, test_file_name="prediction_1_embeddings.json")

0 commit comments

Comments
 (0)