Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 19 additions & 33 deletions pytest_gee/dictionary_regression.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Implementation of the ``dictionary_regression`` fixture."""

import os
from contextlib import suppress
from typing import Optional

import ee
Expand Down Expand Up @@ -39,40 +38,27 @@ def check(
)
serialized_name = data_name.with_stem(f"serialized_{data_name.stem}").with_suffix(".yml")

# check the previously registered serialized call from GEE. If it matches the current call,
# we don't need to check the data
with suppress(BaseException):
is_serialized_equal = check_serialized(
object=data_dict,
path=serialized_name,
datadir=self.datadir,
request=self.request,
)

if is_serialized_equal:
# serialized is equal? -> pass test
# TODO: add proper logging
return
else:
data = round_data(data_dict.getInfo(), prescision)

super().check(data, fullpath=data_name)

# if we are here it means that the query result is equal but the serialized is not -> regenerate serialized
serialized_name.unlink(missing_ok=True)
check_serialized(
object=ee.Dictionary(data_dict),
object=data_dict,
path=serialized_name,
datadir=self.datadir,
original_datadir=self.original_datadir,
request=self.request,
with_test_class_names=self.with_test_class_names,
)
return

# delete the previously created file if wasn't successful
serialized_name.unlink(missing_ok=True)

# if it needs to be checked, we need to round the float values to the same precision as the
# reference file
data = round_data(data_dict.getInfo(), prescision)
try:
super().check(data, fullpath=data_name)

# IF we are here it means the data has been modified so we edit the API call accordingly
# to make sure next run will not be forced to call the API for a response.
with suppress(BaseException):
check_serialized(
object=data_dict,
path=data_name,
datadir=self.datadir,
original_datadir=self.original_datadir,
request=self.request,
with_test_class_names=self.with_test_class_names,
force_regen=True,
)

except BaseException as e:
raise e
64 changes: 25 additions & 39 deletions pytest_gee/feature_collection_regression.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Implementation of the ``feature_collection_regression`` fixture."""

import os
from contextlib import suppress
from typing import Optional

import ee
Expand Down Expand Up @@ -43,49 +42,36 @@ def check(
fullpath=fullpath,
with_test_class_names=self.with_test_class_names,
)

serialized_name = data_name.with_stem(f"serialized_{data_name.stem}").with_suffix(".yml")

# check the previously registered serialized call from GEE. If it matches the current call,
# we don't need to check the data
with suppress(BaseException):
is_serialized_equal = check_serialized(
object=data_fc,
path=serialized_name,
datadir=self.datadir,
request=self.request,
)

if is_serialized_equal:
# serialized is equal? -> pass test
# TODO: add proper logging
return
else:
# round the geometry using geopandas to make sre with use the specific number of decimal places
gdf = gpd.GeoDataFrame.from_features(data_fc.getInfo())
gdf.geometry = gdf.set_precision(grid_size=10 ** (-prescision)).remove_repeated_points()

# round any float value before serving the data to the check function
data = gdf.to_geo_dict()
data = round_data(data, prescision)

super().check(data, fullpath=data_name)

# if we are here it means that the query result is equal but the serialized is not -> regenerate serialized
serialized_name.unlink(missing_ok=True)
check_serialized(
object=data_fc,
path=serialized_name,
datadir=self.datadir,
original_datadir=self.original_datadir,
request=self.request,
with_test_class_names=self.with_test_class_names,
)
return

# delete the previously created file if wasn't successful
serialized_name.unlink(missing_ok=True)

# round the geometry using geopandas to make sre with use the specific number of decimal places
gdf = gpd.GeoDataFrame.from_features(data_fc.getInfo())
gdf.geometry = gdf.set_precision(grid_size=10 ** (-prescision)).remove_repeated_points()

# round any float value before serving the data to the check function
data = gdf.to_geo_dict()
data = round_data(data, prescision)

# if it needs to be checked, we need to round the float values to the same precision as the
# reference file
try:
super().check(data, fullpath=data_name)

# IF we are here it means the data has been modified so we edit the API call accordingly
# to make sure next run will not be forced to call the API for a response.
with suppress(BaseException):
check_serialized(
object=data_fc,
path=data_name,
datadir=self.datadir,
original_datadir=self.original_datadir,
request=self.request,
with_test_class_names=self.with_test_class_names,
force_regen=True,
)

except BaseException as e:
raise e
79 changes: 33 additions & 46 deletions pytest_gee/image_regression.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""implementation of the ``image_regression`` fixture."""

import os
from contextlib import suppress
from typing import Optional

import ee
Expand Down Expand Up @@ -56,54 +55,42 @@ def check(
)
serialized_name = data_name.with_stem(f"serialized_{data_name.stem}").with_suffix(".yml")

# check the previously registered serialized call from GEE. If it matches the current call,
# we don't need to check the data
with suppress(BaseException):
is_serialized_equal = check_serialized(
object=data_image,
path=serialized_name,
datadir=self.datadir,
request=self.request,
)

if is_serialized_equal:
# serialized is equal? -> pass test
# TODO: add proper logging
return
else:
# extract min and max for visualization
minMax = data_image.reduceRegion(ee.Reducer.minMax(), geometry, scale)

# create visualization parameters based on the computed minMax values
if viz_params is None:
nbBands = ee.Algorithms.If(data_image.bandNames().size().gte(3), 3, 1)
bands = data_image.bandNames().slice(0, ee.Number(nbBands))
min = bands.map(lambda b: minMax.get(ee.String(b).cat("_min")))
max = bands.map(lambda b: minMax.get(ee.String(b).cat("_max")))
viz_params = ee.Dictionary({"bands": bands, "min": min, "max": max}).getInfo()

# get the thumbnail image
thumb_url = data_image.getThumbURL(params=viz_params)
byte_data = requests.get(thumb_url).content

# if it needs to be checked, we need to round the float values to the same precision as the
# reference file
super().check(byte_data, diff_threshold, expect_equal, fullpath=data_name)

# if we are here it means that the query result is equal but the serialized is not -> regenerate serialized
serialized_name.unlink(missing_ok=True)
check_serialized(
object=data_image,
path=serialized_name,
datadir=self.datadir,
original_datadir=self.original_datadir,
request=self.request,
with_test_class_names=self.with_test_class_names,
)
return

# delete the previously created file if wasn't successful
serialized_name.unlink(missing_ok=True)

# extract min and max for visualization
minMax = data_image.reduceRegion(ee.Reducer.minMax(), geometry, scale)

# create visualization parameters based on the computed minMax values
if viz_params is None:
nbBands = ee.Algorithms.If(data_image.bandNames().size().gte(3), 3, 1)
bands = data_image.bandNames().slice(0, ee.Number(nbBands))
min = bands.map(lambda b: minMax.get(ee.String(b).cat("_min")))
max = bands.map(lambda b: minMax.get(ee.String(b).cat("_max")))
viz_params = ee.Dictionary({"bands": bands, "min": min, "max": max}).getInfo()

# get the thumbnail image
thumb_url = data_image.getThumbURL(params=viz_params)
byte_data = requests.get(thumb_url).content

# if it needs to be checked, we need to round the float values to the same precision as the
# reference file
try:
super().check(byte_data, diff_threshold, expect_equal, fullpath=data_name)

# IF we are here it means the data has been modified so we edit the API call accordingly
# to make sure next run will not be forced to call the API for a response.
with suppress(BaseException):
check_serialized(
object=data_image,
path=data_name,
datadir=self.datadir,
original_datadir=self.original_datadir,
request=self.request,
with_test_class_names=self.with_test_class_names,
force_regen=True,
)

except BaseException as e:
raise e
53 changes: 21 additions & 32 deletions pytest_gee/list_regression.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Implementation of the ``list_regression`` fixture."""

import os
from contextlib import suppress
from typing import Optional

import ee
Expand Down Expand Up @@ -39,40 +38,30 @@ def check(
)
serialized_name = data_name.with_stem(f"serialized_{data_name.stem}").with_suffix(".yml")

# check the previously registered serialized call from GEE. If it matches the current call,
# we don't need to check the data
with suppress(BaseException):
is_serialized_equal = check_serialized(
object=data_list,
path=serialized_name,
datadir=self.datadir,
request=self.request,
)

if is_serialized_equal:
# serialized is equal? -> pass test
# TODO: add proper logging
return
else:
# if it needs to be checked, we need to round the float values to the same precision as the
# reference file
data = round_data(data_list.getInfo(), prescision)

# check query result
super().check(data, fullpath=data_name)

# if we are here it means that the query result is equal but the serialized is not -> regenerate serialized
serialized_name.unlink(missing_ok=True)
check_serialized(
object=data_list,
path=serialized_name,
datadir=self.datadir,
original_datadir=self.original_datadir,
request=self.request,
with_test_class_names=self.with_test_class_names,
)
return

# delete the previously created file if wasn't successful
serialized_name.unlink(missing_ok=True)

# if it needs to be checked, we need to round the float values to the same precision as the
# reference file
data = round_data(data_list.getInfo(), prescision)
try:
super().check(data, fullpath=data_name)

# IF we are here it means the data has been modified so we edit the API call accordingly
# to make sure next run will not be forced to call the API for a response.
with suppress(BaseException):
check_serialized(
object=data_list,
path=data_name,
datadir=self.datadir,
original_datadir=self.original_datadir,
request=self.request,
with_test_class_names=self.with_test_class_names,
force_regen=True,
)

except BaseException as e:
raise e
Loading
Loading