Skip to content

Commit d1d47bc

Browse files
authored
Merge branch 'main' into tchow/emr-submission
2 parents 5a41c39 + 675f203 commit d1d47bc

File tree

55 files changed

+2835
-281
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+2835
-281
lines changed
-467 Bytes
Binary file not shown.

python/src/ai/chronon/group_by.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -233,15 +233,7 @@ def Aggregation(
233233
if isinstance(operation, tuple):
234234
operation, arg_map = operation[0], operation[1]
235235

236-
def normalize(w: Union[common.Window, str]) -> common.Window:
237-
if isinstance(w, str):
238-
return window_utils._from_str(w)
239-
elif isinstance(w, common.Window):
240-
return w
241-
else:
242-
raise Exception("window should be either a string like '7d', '24h', or a Window type")
243-
244-
norm_windows = [normalize(w) for w in windows] if windows else None
236+
norm_windows = [window_utils.normalize_window(w) for w in windows] if windows else None
245237

246238
agg = ttypes.Aggregation(input_column, operation, arg_map, norm_windows, buckets)
247239

python/src/ai/chronon/model.py

Lines changed: 168 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import inspect
22
from dataclasses import dataclass
3-
from typing import Dict, List, Optional
3+
from typing import Dict, List, Optional, Union
44

55
import gen_thrift.api.ttypes as ttypes
6+
import gen_thrift.common.ttypes as common
67

78
from ai.chronon import utils
9+
from ai.chronon import windows as window_utils
810
from ai.chronon.data_types import DataType, FieldsType
911
from ai.chronon.utils import ANY_SOURCE_TYPE, normalize_source
1012

@@ -14,6 +16,17 @@ class ModelBackend:
1416
SAGEMAKER = ttypes.ModelBackend.SageMaker
1517

1618

19+
class DeploymentStrategyType:
    """Named aliases for the thrift ``DeploymentStrategyType`` values."""

    # Blue-green: bring the model up on another endpoint (~2x capacity)
    # and gradually ramp traffic over to it.
    BLUE_GREEN = ttypes.DeploymentStrategyType.BLUE_GREEN

    # Rolling: gradually scale down existing instances while scaling up new ones.
    ROLLING = ttypes.DeploymentStrategyType.ROLLING

    # Immediate: deploy straight to the endpoint, with no traffic ramping.
    IMMEDIATE = ttypes.DeploymentStrategyType.IMMEDIATE
28+
29+
1730
@dataclass
1831
class ResourceConfig:
1932
min_replica_count: Optional[int] = None
@@ -38,20 +51,150 @@ def to_thrift(self):
3851
resource_config_thrift = None
3952
if self.resource_config:
4053
resource_config_thrift = self.resource_config.to_thrift()
41-
54+
4255
return ttypes.InferenceSpec(
4356
modelBackend=self.model_backend,
4457
modelBackendParams=self.model_backend_params,
4558
resourceConfig=resource_config_thrift,
4659
)
4760

4861

62+
@dataclass
class TrainingSpec:
    """Python-side description of a training job, convertible to ``ttypes.TrainingSpec``.

    Every field is optional and defaults to ``None``.
    """

    training_data_source: Optional[ANY_SOURCE_TYPE] = None
    # Either a ready-made Window or a string such as "30d" / "24h".
    training_data_window: Optional[Union[common.Window, str]] = None
    schedule: Optional[str] = None
    image: Optional[str] = None
    python_module: Optional[str] = None
    resource_config: Optional[ResourceConfig] = None
    job_configs: Optional[Dict[str, str]] = None

    def to_thrift(self):
        """Build the ``ttypes.TrainingSpec`` equivalent of this spec.

        Nested values that are unset (falsy) are forwarded as ``None``; the
        scalar fields are forwarded unchanged.
        """
        resources = self.resource_config.to_thrift() if self.resource_config else None
        source = normalize_source(self.training_data_source) if self.training_data_source else None
        # A string like "30d" or "24h" is converted to a common.Window here.
        window = (
            window_utils.normalize_window(self.training_data_window)
            if self.training_data_window
            else None
        )

        return ttypes.TrainingSpec(
            trainingDataSource=source,
            trainingDataWindow=window,
            schedule=self.schedule,
            image=self.image,
            pythonModule=self.python_module,
            resourceConfig=resources,
            jobConfigs=self.job_configs,
        )
95+
96+
97+
@dataclass
class ServingContainerConfig:
    """Serving-container settings, convertible to ``ttypes.ServingContainerConfig``.

    Every field is optional and defaults to ``None``.
    """

    image: Optional[str] = None
    serving_health_route: Optional[str] = None
    serving_predict_route: Optional[str] = None
    serving_container_env_vars: Optional[Dict[str, str]] = None

    def to_thrift(self):
        """Return a ``ttypes.ServingContainerConfig`` carrying the same field values."""
        thrift_kwargs = {
            "image": self.image,
            "servingHealthRoute": self.serving_health_route,
            "servingPredictRoute": self.serving_predict_route,
            "servingContainerEnvVars": self.serving_container_env_vars,
        }
        return ttypes.ServingContainerConfig(**thrift_kwargs)
111+
112+
113+
@dataclass
class EndpointConfig:
    """Endpoint settings, convertible to ``ttypes.EndpointConfig``.

    Both fields are optional and default to ``None``.
    """

    endpoint_name: Optional[str] = None
    additional_configs: Optional[Dict[str, str]] = None

    def to_thrift(self):
        """Return a ``ttypes.EndpointConfig`` carrying the same field values."""
        thrift_endpoint = ttypes.EndpointConfig(
            endpointName=self.endpoint_name,
            additionalConfigs=self.additional_configs,
        )
        return thrift_endpoint
123+
124+
125+
@dataclass
class Metric:
    """A metric name paired with a threshold, convertible to ``ttypes.Metric``.

    Both fields are optional and default to ``None``.
    """

    name: Optional[str] = None
    threshold: Optional[float] = None

    def to_thrift(self):
        """Return a ``ttypes.Metric`` carrying the same name and threshold."""
        thrift_metric = ttypes.Metric(name=self.name, threshold=self.threshold)
        return thrift_metric
135+
136+
137+
@dataclass
class RolloutStrategy:
    """Rollout settings for a deployment, convertible to ``ttypes.RolloutStrategy``.

    Every field is optional and defaults to ``None``.
    """

    # One of the constants exposed on DeploymentStrategyType.
    rollout_type: Optional[DeploymentStrategyType] = None
    validation_traffic_percent_ramps: Optional[List[int]] = None
    validation_traffic_duration_mins: Optional[List[int]] = None
    rollout_metric_thresholds: Optional[List[Metric]] = None

    def to_thrift(self):
        """Build the ``ttypes.RolloutStrategy`` equivalent of this strategy.

        Each entry of ``rollout_metric_thresholds`` is converted via its own
        ``to_thrift``; a missing or empty list is forwarded as ``None``.
        """
        thresholds = self.rollout_metric_thresholds
        thrift_thresholds = [m.to_thrift() for m in thresholds] if thresholds else None

        return ttypes.RolloutStrategy(
            rolloutType=self.rollout_type,
            validationTrafficPercentRamps=self.validation_traffic_percent_ramps,
            validationTrafficDurationMins=self.validation_traffic_duration_mins,
            rolloutMetricThresholds=thrift_thresholds,
        )
155+
156+
157+
@dataclass
class DeploymentSpec:
    """Deployment settings for a model, convertible to ``ttypes.DeploymentSpec``.

    Every field is optional and defaults to ``None``.
    """

    container_config: Optional[ServingContainerConfig] = None
    endpoint_config: Optional[EndpointConfig] = None
    resource_config: Optional[ResourceConfig] = None
    rollout_strategy: Optional[RolloutStrategy] = None

    def to_thrift(self):
        """Build the ``ttypes.DeploymentSpec`` equivalent of this spec.

        Each nested config that is set is converted via its own ``to_thrift``;
        unset (falsy) ones are forwarded as ``None``.
        """
        container = self.container_config.to_thrift() if self.container_config else None
        endpoint = self.endpoint_config.to_thrift() if self.endpoint_config else None
        resources = self.resource_config.to_thrift() if self.resource_config else None
        rollout = self.rollout_strategy.to_thrift() if self.rollout_strategy else None

        return ttypes.DeploymentSpec(
            containerConfig=container,
            endpointConfig=endpoint,
            resourceConfig=resources,
            rolloutStrategy=rollout,
        )
187+
188+
49189
def Model(
50190
version: str,
51191
inference_spec: Optional[InferenceSpec] = None,
52192
input_mapping: Optional[Dict[str, str]] = None,
53193
output_mapping: Optional[Dict[str, str]] = None,
54194
value_fields: Optional[FieldsType] = None,
195+
model_artifact_base_uri: Optional[str] = None,
196+
training_conf: Optional[TrainingSpec] = None,
197+
deployment_conf: Optional[DeploymentSpec] = None,
55198
output_namespace: Optional[str] = None,
56199
table_properties: Optional[Dict[str, str]] = None,
57200
tags: Optional[Dict[str, str]] = None,
@@ -76,6 +219,15 @@ def Model(
76219
If provided, creates a STRUCT schema that will be set as the model's valueSchema.
77220
Example: [('score', DataType.DOUBLE), ('category', DataType.STRING)]
78221
:type value_fields: FieldsType
222+
:param model_artifact_base_uri:
223+
Base URI where trained model artifacts are stored
224+
:type model_artifact_base_uri: str
225+
:param training_conf:
226+
Configs related to orchestrating model training jobs
227+
:type training_conf: TrainingSpec
228+
:param deployment_conf:
229+
Configs related to orchestrating model deployment
230+
:type deployment_conf: DeploymentSpec
79231
:param output_namespace:
80232
Namespace for the model output
81233
:type output_namespace: str
@@ -114,14 +266,27 @@ def Model(
114266
if value_fields:
115267
schema_name = "model_value_schema"
116268
value_schema = DataType.STRUCT(schema_name, *value_fields)
117-
269+
270+
# Convert training_conf to thrift if provided
271+
training_conf_thrift = None
272+
if training_conf:
273+
training_conf_thrift = training_conf.to_thrift()
274+
275+
# Convert deployment_conf to thrift if provided
276+
deployment_conf_thrift = None
277+
if deployment_conf:
278+
deployment_conf_thrift = deployment_conf.to_thrift()
279+
118280
# Create and return the Model object
119281
model = ttypes.Model(
120282
metaData=meta_data,
121283
inferenceSpec=inference_spec_thrift,
122284
inputMapping=input_mapping,
123285
outputMapping=output_mapping,
124286
valueSchema=value_schema,
287+
modelArtifactBaseUri=model_artifact_base_uri,
288+
trainingConf=training_conf_thrift,
289+
deploymentConf=deployment_conf_thrift,
125290
)
126291

127292
return model

python/src/ai/chronon/windows.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Union
2+
13
import gen_thrift.common.ttypes as common
24

35

@@ -46,3 +48,30 @@ def _from_str(s: str) -> common.Window:
4648
if "invalid literal for int()" in str(e):
4749
raise ValueError(f"Invalid numeric value in duration: {value}") from e
4850
raise e from None
51+
52+
53+
def normalize_window(w: Union[common.Window, str]) -> common.Window:
    """
    Coerce a window specification into a common.Window object.

    Shared by callers that accept either form, e.g. GroupBy aggregations
    and TrainingSpec.

    Args:
        w: A string such as "7d" or "24h", which is parsed with _from_str,
            or a common.Window, which is returned unchanged.

    Returns:
        common.Window: The window corresponding to ``w``.

    Raises:
        TypeError: If ``w`` is neither a string nor a common.Window.
    """
    # Strings are parsed; Window objects pass straight through.
    if isinstance(w, str):
        return _from_str(w)
    if isinstance(w, common.Window):
        return w

    received = type(w).__name__
    raise TypeError(
        f"Window should be either a string like '7d', '24h', or a Window type, got {received}"
    )

python/test/canary/compiled/group_bys/gcp/dim_listings.v1__0

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

python/test/canary/compiled/group_bys/gcp/dim_merchants.v1__0

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

python/test/canary/compiled/group_bys/gcp/item_event_canary.actions_pubsub_v2__0

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

python/test/canary/compiled/group_bys/gcp/item_event_canary.actions_v1__0

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

python/test/canary/compiled/group_bys/gcp/logging_schema.v1__2

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

python/test/canary/compiled/group_bys/gcp/purchases.v1_dev__0

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)