Skip to content

Commit 706ed63

Browse files
IanHoang and gkamat
authored and committed
Add Pydantic to SDG for stronger validation, error handling, and extensibility (#931)
Signed-off-by: Ian Hoang <[email protected]>
1 parent 763427a commit 706ed63

File tree

12 files changed

+226
-127
lines changed

12 files changed

+226
-127
lines changed

.pylintrc

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -166,14 +166,21 @@ disable=print-statement,
166166
too-many-instance-attributes,
167167
too-many-statements,
168168
inconsistent-return-statements,
169-
C0302,
169+
too-many-lines,
170170
C4001,
171-
R0916,
172-
W0201,
173-
W0613,
174-
W0621,
171+
too-many-boolean-expressions,
172+
attribute-defined-outside-init,
173+
unused-argument,
174+
redefined-outer-name,
175175
invalid-docstring-quote,
176-
raise-missing-from
176+
raise-missing-from,
177+
consider-using-with,
178+
duplicate-code,
179+
consider-using-from-import,
180+
bad-option-value,
181+
consider-using-dict-items,
182+
unused-private-member,
183+
use-a-generator
177184

178185

179186
# Enable the message, report, category or checker with the given id(s). You can

osbenchmark/synthetic_data_generator/helpers.py

Lines changed: 13 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@
1717
import yaml
1818

1919
from osbenchmark.utils import console
20-
from osbenchmark.exceptions import SystemSetupError, ExecutorError
20+
from osbenchmark import exceptions
2121
from osbenchmark.synthetic_data_generator.strategies.strategy import DataGenerationStrategy
22-
from osbenchmark.synthetic_data_generator.types import DEFAULT_GENERATION_SETTINGS, SyntheticDataGeneratorMetadata, GB_TO_BYTES
22+
from osbenchmark.synthetic_data_generator.models import SyntheticDataGeneratorMetadata, SDGConfig, GB_TO_BYTES
2323

2424
def load_user_module(file_path):
2525
allowed_extensions = ['.py']
2626
extension = os.path.splitext(file_path)[1]
2727
if extension not in allowed_extensions:
28-
raise SystemSetupError(f"User provided module with file extension [{extension}]. Python modules must have {allowed_extensions} extension.")
28+
raise exceptions.SystemSetupError(f"User provided module with file extension [{extension}]. Python modules must have {allowed_extensions} extension.")
2929

3030
spec = importlib.util.spec_from_file_location("user_module", file_path)
3131
user_module = importlib.util.module_from_spec(spec)
@@ -82,7 +82,7 @@ def remove_existing_files(existing_files_found: list):
8282
for file in existing_files_found:
8383
os.remove(file)
8484
except OSError as e:
85-
raise ExecutorError("OSB could not remove existing files for SDG: ", e)
85+
raise exceptions.ExecutorError("OSB could not remove existing files for SDG: ", e)
8686

8787
def host_has_available_disk_storage(sdg_metadata: SyntheticDataGeneratorMetadata) -> bool:
8888
logger = logging.getLogger(__name__)
@@ -98,18 +98,23 @@ def host_has_available_disk_storage(sdg_metadata: SyntheticDataGeneratorMetadata
9898
logger.error("Error checking disk space.")
9999
return False
100100

101-
def load_config(config_path):
101+
def load_config(config_path: str) -> SDGConfig:
102102
try:
103103
allowed_extensions = ['.yml', '.yaml']
104104

105105
extension = os.path.splitext(config_path)[1]
106106
if extension not in allowed_extensions:
107-
raise SystemSetupError(f"User provided config with extension [{extension}]. Config must have a {allowed_extensions} extension.")
107+
raise exceptions.ConfigError(f"User provided config with extension [{extension}]. Config must have a {allowed_extensions} extension.")
108108
else:
109109
with open(config_path, 'r') as file:
110-
return yaml.safe_load(file)
110+
config_details = yaml.safe_load(file)
111+
112+
return SDGConfig(**config_details) if config_details else SDGConfig()
113+
114+
except yaml.YAMLError as e:
115+
raise exceptions.ConfigError(f"Error when loading config due to YAML error: {e}")
111116
except TypeError:
112-
raise SystemSetupError("Error when loading config. Please ensure that the proper config was provided")
117+
raise exceptions.SystemSetupError("Error when loading config. Please ensure that the proper config was provided")
113118

114119
def write_chunk(data, file_path):
115120
written_bytes = 0
@@ -130,29 +135,6 @@ def calculate_avg_doc_size(strategy: DataGenerationStrategy):
130135

131136
return size
132137

133-
def get_generation_settings(input_config: dict) -> dict:
134-
'''
135-
Grabs the user's config's generation settings and compares it with the default generation settings.
136-
If there are missing fields in the user's config, it populates it with the default values
137-
'''
138-
generation_settings = DEFAULT_GENERATION_SETTINGS
139-
if input_config is None: # if user did not provide a custom config
140-
return generation_settings
141-
142-
user_generation_settings = input_config.get('settings', {})
143-
144-
if not user_generation_settings: # If user provided custom config but did not include settings
145-
return generation_settings
146-
else:
147-
# Traverse and update valid settings that user specified.
148-
for k in generation_settings:
149-
if k in user_generation_settings and user_generation_settings[k] is not None:
150-
generation_settings[k] = user_generation_settings[k]
151-
else:
152-
continue
153-
154-
return generation_settings
155-
156138
def format_size(bytes):
157139
for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
158140
if bytes < 1024:

osbenchmark/synthetic_data_generator/input_processor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import logging
1010

11-
from osbenchmark.synthetic_data_generator.types import SyntheticDataGeneratorMetadata
11+
from osbenchmark.synthetic_data_generator.models import SyntheticDataGeneratorMetadata
1212
from osbenchmark.exceptions import ConfigError
1313

1414
logger = logging.getLogger(__name__)
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
#
3+
# The OpenSearch Contributors require contributions made to
4+
# this file be licensed under the Apache-2.0 license or a
5+
# compatible open source license.
6+
# Modifications Copyright OpenSearch Contributors. See
7+
# GitHub history for details.
8+
9+
import os
10+
from typing import Optional, Dict, List, Any, Union
11+
import re
12+
13+
from pydantic import BaseModel, Field, field_validator
14+
15+
GB_TO_BYTES = 1024 ** 3
16+
17+
class SettingsConfig(BaseModel):
18+
workers: Optional[int] = Field(default_factory=os.cpu_count) # Number of workers recommended to not exceed CPU count
19+
max_file_size_gb: Optional[int] = 40 # Default because some CloudProviders limit the size of files stored
20+
docs_per_chunk: Optional[int] = 10000 # Default based on testing
21+
22+
# pylint: disable = no-self-argument
23+
@field_validator('workers', 'max_file_size_gb', 'docs_per_chunk')
24+
def validate_values_are_positive_integers(cls, v):
25+
if v is not None and v <= 0:
26+
raise ValueError(f"Value '{v}' in Settings portion must be a positive integer.")
27+
28+
return v
29+
30+
class CustomGenerationValuesConfig(BaseModel):
31+
custom_lists: Optional[Dict[str, List[Any]]] = None
32+
custom_providers: Optional[List[Any]] = None
33+
34+
# pylint: disable = no-self-argument
35+
@field_validator('custom_lists')
36+
def validate_custom_lists(cls, v):
37+
if v is not None:
38+
for key, value in v.items():
39+
if not isinstance(key, str):
40+
raise ValueError(f"All keys within custom_lists of CustomGenerationValues section must be strings. '{key}' is not a string")
41+
if not isinstance(value, list):
42+
raise ValueError(f"Value for key '{key}' must be a list.")
43+
return v
44+
45+
class GeneratorParams(BaseModel):
46+
# Integer / Long Params
47+
min: Optional[Union[int, float]] = None
48+
max: Optional[Union[int, float]] = None
49+
50+
# Date Params
51+
start_date: Optional[str] = None
52+
end_date: Optional[str] = None
53+
format: Optional[str] = None
54+
55+
# Text / Keywords Params
56+
must_include: Optional[List[str]] = None
57+
choices: Optional[List[str]] = None
58+
59+
class Config:
60+
extra = 'forbid'
61+
62+
class FieldOverride(BaseModel):
63+
generator: str
64+
params: GeneratorParams
65+
66+
# pylint: disable = no-self-argument
67+
@field_validator('generator')
68+
def validate_generator_name(cls, v):
69+
valid_generators = [
70+
'generate_text',
71+
'generate_keyword',
72+
'generate_integer',
73+
'generate_long',
74+
'generate_short',
75+
'generate_byte',
76+
'generate_float',
77+
'generate_double',
78+
'generate_boolean',
79+
'generate_date',
80+
'generate_ip',
81+
'generate_geopoint',
82+
'generate_object',
83+
'generate_nested'
84+
]
85+
86+
if v not in valid_generators:
87+
raise ValueError(f"Generator '{v}' mentioned in FieldOverrides not among valid generators: {valid_generators}")
88+
return v
89+
90+
class MappingGenerationValuesConfig(BaseModel):
91+
generator_overrides: Optional[Dict[str, GeneratorParams]] = None
92+
field_overrides: Optional[Dict[str, FieldOverride]] = None
93+
94+
# pylint: disable = no-self-argument
95+
@field_validator('generator_overrides')
96+
def validate_generator_types(cls, v):
97+
if v is not None:
98+
valid_generator_types = ['integer', 'long', 'float', 'double', 'date', 'text', 'keyword', 'short', 'byte', 'ip', 'geopoint', 'nested', 'boolean']
99+
100+
for generator_type in v.keys():
101+
if generator_type not in valid_generator_types:
102+
raise ValueError(f"Invalid Generator Type '{generator_type}. Must be one of: {valid_generator_types}'")
103+
104+
return v
105+
106+
# pylint: disable = no-self-argument
107+
@field_validator('field_overrides')
108+
def validate_field_names(cls, v):
109+
if v is not None:
110+
for field_name in v.keys():
111+
if not re.match(r'^[a-zA-Z][a-zA-Z0-9_.]*$', field_name):
112+
raise ValueError(f"Invalid Field Name '{field_name}' in FieldOverrides. Only alphanumeric characters, underscores and periods are allowed.")
113+
114+
return v
115+
116+
class SyntheticDataGeneratorMetadata(BaseModel):
117+
index_name: Optional[str] = None
118+
index_mappings_path: Optional[str] = None
119+
custom_module_path: Optional[str] = None
120+
custom_config_path: Optional[str] = None
121+
output_path: Optional[str] = None
122+
total_size_gb: Optional[int] = None
123+
124+
class Config:
125+
extra = 'forbid'
126+
127+
class SDGConfig(BaseModel):
128+
# If the user does not provide a YAML file, or provides a YAML file without all settings fields, the default generation settings are used.
129+
settings: Optional[SettingsConfig] = Field(default_factory=SettingsConfig)
130+
CustomGenerationValues: Optional[CustomGenerationValuesConfig] = None
131+
MappingGenerationValues: Optional[MappingGenerationValuesConfig] = None
132+
133+
class Config:
134+
extra = 'forbid'

osbenchmark/synthetic_data_generator/strategies/custom_module_strategy.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,29 +16,38 @@
1616
from mimesis.random import Random
1717
from mimesis.providers.base import BaseProvider
1818

19-
from osbenchmark.exceptions import ConfigError
19+
from osbenchmark import exceptions
2020
from osbenchmark.synthetic_data_generator.strategies import DataGenerationStrategy
21-
from osbenchmark.synthetic_data_generator.types import SyntheticDataGeneratorMetadata
21+
from osbenchmark.synthetic_data_generator.models import SyntheticDataGeneratorMetadata, SDGConfig
2222

2323
class CustomModuleStrategy(DataGenerationStrategy):
24-
def __init__(self, sdg_metadata: SyntheticDataGeneratorMetadata, sdg_config: dict, custom_module: ModuleType) -> None:
24+
def __init__(self, sdg_metadata: SyntheticDataGeneratorMetadata, sdg_config: SDGConfig, custom_module: ModuleType) -> None:
2525
self.sdg_metadata = sdg_metadata
2626
self.sdg_config = sdg_config
2727
self.custom_module = custom_module
2828
self.logger = logging.getLogger(__name__)
2929

3030
if not hasattr(self.custom_module, 'generate_synthetic_document'):
3131
msg = f"Custom module at [{self.sdg_metadata.custom_module_path}] does not define a function called generate_synthetic_document(). Ensure that this method is defined."
32-
raise ConfigError(msg)
32+
raise exceptions.ConfigError(msg)
3333

3434
# Fetch settings and custom module components from sdg-config.yml
35-
custom_module_values = self.sdg_config.get('CustomGenerationValues', {})
36-
try:
37-
self.custom_lists = custom_module_values.get('custom_lists', {})
38-
self.custom_providers = {name: getattr(self.custom_module, name) for name in custom_module_values.get('custom_providers', [])}
39-
except TypeError:
40-
msg = "Synthetic Data Generator Config has custom_lists and custom_providers pointing to null values. Either populate or remove."
41-
raise ConfigError(msg)
35+
if self.sdg_config.CustomGenerationValues is None:
36+
self.custom_lists = {}
37+
self.custom_providers = {}
38+
else:
39+
try:
40+
self.custom_lists = self.sdg_config.CustomGenerationValues.custom_lists or {}
41+
provider_names = self.sdg_config.CustomGenerationValues.custom_providers or []
42+
self.custom_providers = {
43+
name: getattr(self.custom_module, name) for name in provider_names
44+
}
45+
except AttributeError as e:
46+
msg = f"Error when setting up custom lists and custom providers: {e}"
47+
raise exceptions.ConfigError(msg)
48+
except TypeError:
49+
msg = "Synthetic Data Generator Config has custom_lists and custom_providers pointing to null values. Either populate or remove."
50+
4251

4352
# pylint: disable=arguments-differ
4453
def generate_data_chunks_across_workers(self, dask_client: Client, docs_per_chunk: int, seeds: list ) -> list:
@@ -51,7 +60,7 @@ def generate_data_chunks_across_workers(self, dask_client: Client, docs_per_chun
5160
self.generate_data_chunk_from_worker, self.custom_module.generate_synthetic_document,
5261
docs_per_chunk, seed) for seed in seeds]
5362

54-
63+
# pylint: disable=arguments-renamed
5564
def generate_data_chunk_from_worker(self, generate_synthetic_document: Callable, docs_per_chunk: int, seed: Optional[int]) -> list:
5665
"""
5766
This method is submitted to Dask worker and can be thought of as the worker performing a job, which is calling the
@@ -83,7 +92,7 @@ def generate_test_document(self):
8392
msg = "Encountered AttributeError when setting up custom_providers and custom_lists. " + \
8493
"It seems that your module might be using custom_lists and custom_providers." + \
8594
f"Please ensure you have provided a custom config with custom_providers and custom_lists: {e}"
86-
raise ConfigError(msg)
95+
raise exceptions.ConfigError(msg)
8796
return document
8897

8998
def _instantiate_all_providers(self, custom_providers):

osbenchmark/synthetic_data_generator/strategies/mapping_strategy.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919

2020
from osbenchmark.exceptions import ConfigError, MappingsError
2121
from osbenchmark.synthetic_data_generator.strategies import DataGenerationStrategy
22-
from osbenchmark.synthetic_data_generator.types import SyntheticDataGeneratorMetadata
22+
from osbenchmark.synthetic_data_generator.models import SyntheticDataGeneratorMetadata, SDGConfig, MappingGenerationValuesConfig
2323

2424
class MappingStrategy(DataGenerationStrategy):
25-
def __init__(self, sdg_metadata: SyntheticDataGeneratorMetadata, sdg_config: dict, index_mapping: dict) -> None:
25+
def __init__(self, sdg_metadata: SyntheticDataGeneratorMetadata, sdg_config: SDGConfig, index_mapping: dict) -> None:
2626
self.sdg_metadata = sdg_metadata
2727
self.sdg_config = sdg_config # Optional YAML-based config for value constraints
2828
self.index_mapping = index_mapping # OpenSearch Mapping
29-
self.mapping_generation_values = self.sdg_config.get("MappingGenerationValues", {}) if self.sdg_config else {}
29+
self.mapping_generation_values = (self.sdg_config.MappingGenerationValues or {}) if self.sdg_config else {}
3030

3131
self.logger = logging.getLogger(__name__)
3232

@@ -224,8 +224,12 @@ def transform_mapping_to_generators(self, mapping_dict: Dict[str, Any], field_pa
224224
transformed_mapping = {}
225225

226226
# Extract configuration settings (both default generators and field overrides) from config user provided
227-
# TODO: Set self.mapping_config to automatically point to MappingConverter
228-
config = self.mapping_config
227+
# Convert the sdg_config's MappingGenerationValues section into a dictionary and access the overrides
228+
if isinstance(self.mapping_config, MappingGenerationValuesConfig):
229+
config = self.mapping_config.model_dump()
230+
else:
231+
config = self.mapping_config
232+
229233
generator_overrides = config.get("generator_overrides", {})
230234
field_overrides = config.get("field_overrides", {})
231235

0 commit comments

Comments
 (0)