11import inspect
22from dataclasses import dataclass
3- from typing import Dict , List , Optional
3+ from typing import Dict , List , Optional , Union
44
55import gen_thrift .api .ttypes as ttypes
6+ import gen_thrift .common .ttypes as common
67
78from ai .chronon import utils
9+ from ai .chronon import windows as window_utils
810from ai .chronon .data_types import DataType , FieldsType
911from ai .chronon .utils import ANY_SOURCE_TYPE , normalize_source
1012
@@ -14,6 +16,17 @@ class ModelBackend:
1416 SAGEMAKER = ttypes .ModelBackend .SageMaker
1517
1618
19+ class DeploymentStrategyType :
20+ # deploys the model in a blue-green fashion (~2x capacity) to another endpoint and gradually ramps traffic
21+ BLUE_GREEN = ttypes .DeploymentStrategyType .BLUE_GREEN
22+
23+ # deploys the model in a rolling manner by gradually scaling down existing instances and scaling up new instances
24+ ROLLING = ttypes .DeploymentStrategyType .ROLLING
25+
26+ # deploys the model immediately to the endpoint without any traffic ramping
27+ IMMEDIATE = ttypes .DeploymentStrategyType .IMMEDIATE
28+
29+
1730@dataclass
1831class ResourceConfig :
1932 min_replica_count : Optional [int ] = None
@@ -38,20 +51,150 @@ def to_thrift(self):
3851 resource_config_thrift = None
3952 if self .resource_config :
4053 resource_config_thrift = self .resource_config .to_thrift ()
41-
54+
4255 return ttypes .InferenceSpec (
4356 modelBackend = self .model_backend ,
4457 modelBackendParams = self .model_backend_params ,
4558 resourceConfig = resource_config_thrift ,
4659 )
4760
4861
62+ @dataclass
63+ class TrainingSpec :
64+ training_data_source : Optional [ANY_SOURCE_TYPE ] = None
65+ training_data_window : Optional [Union [common .Window , str ]] = None
66+ schedule : Optional [str ] = None
67+ image : Optional [str ] = None
68+ python_module : Optional [str ] = None
69+ resource_config : Optional [ResourceConfig ] = None
70+ job_configs : Optional [Dict [str , str ]] = None
71+
72+ def to_thrift (self ):
73+ resource_config_thrift = None
74+ if self .resource_config :
75+ resource_config_thrift = self .resource_config .to_thrift ()
76+
77+ training_data_source_thrift = None
78+ if self .training_data_source :
79+ training_data_source_thrift = normalize_source (self .training_data_source )
80+
81+ # Normalize window - convert string like "30d" or "24h" to common.Window
82+ training_data_window_thrift = None
83+ if self .training_data_window :
84+ training_data_window_thrift = window_utils .normalize_window (self .training_data_window )
85+
86+ return ttypes .TrainingSpec (
87+ trainingDataSource = training_data_source_thrift ,
88+ trainingDataWindow = training_data_window_thrift ,
89+ schedule = self .schedule ,
90+ image = self .image ,
91+ pythonModule = self .python_module ,
92+ resourceConfig = resource_config_thrift ,
93+ jobConfigs = self .job_configs ,
94+ )
95+
96+
97+ @dataclass
98+ class ServingContainerConfig :
99+ image : Optional [str ] = None
100+ serving_health_route : Optional [str ] = None
101+ serving_predict_route : Optional [str ] = None
102+ serving_container_env_vars : Optional [Dict [str , str ]] = None
103+
104+ def to_thrift (self ):
105+ return ttypes .ServingContainerConfig (
106+ image = self .image ,
107+ servingHealthRoute = self .serving_health_route ,
108+ servingPredictRoute = self .serving_predict_route ,
109+ servingContainerEnvVars = self .serving_container_env_vars ,
110+ )
111+
112+
113+ @dataclass
114+ class EndpointConfig :
115+ endpoint_name : Optional [str ] = None
116+ additional_configs : Optional [Dict [str , str ]] = None
117+
118+ def to_thrift (self ):
119+ return ttypes .EndpointConfig (
120+ endpointName = self .endpoint_name ,
121+ additionalConfigs = self .additional_configs ,
122+ )
123+
124+
125+ @dataclass
126+ class Metric :
127+ name : Optional [str ] = None
128+ threshold : Optional [float ] = None
129+
130+ def to_thrift (self ):
131+ return ttypes .Metric (
132+ name = self .name ,
133+ threshold = self .threshold ,
134+ )
135+
136+
137+ @dataclass
138+ class RolloutStrategy :
139+ rollout_type : Optional [DeploymentStrategyType ] = None
140+ validation_traffic_percent_ramps : Optional [List [int ]] = None
141+ validation_traffic_duration_mins : Optional [List [int ]] = None
142+ rollout_metric_thresholds : Optional [List [Metric ]] = None
143+
144+ def to_thrift (self ):
145+ rollout_metric_thresholds_thrift = None
146+ if self .rollout_metric_thresholds :
147+ rollout_metric_thresholds_thrift = [metric .to_thrift () for metric in self .rollout_metric_thresholds ]
148+
149+ return ttypes .RolloutStrategy (
150+ rolloutType = self .rollout_type ,
151+ validationTrafficPercentRamps = self .validation_traffic_percent_ramps ,
152+ validationTrafficDurationMins = self .validation_traffic_duration_mins ,
153+ rolloutMetricThresholds = rollout_metric_thresholds_thrift ,
154+ )
155+
156+
157+ @dataclass
158+ class DeploymentSpec :
159+ container_config : Optional [ServingContainerConfig ] = None
160+ endpoint_config : Optional [EndpointConfig ] = None
161+ resource_config : Optional [ResourceConfig ] = None
162+ rollout_strategy : Optional [RolloutStrategy ] = None
163+
164+ def to_thrift (self ):
165+ container_config_thrift = None
166+ if self .container_config :
167+ container_config_thrift = self .container_config .to_thrift ()
168+
169+ endpoint_config_thrift = None
170+ if self .endpoint_config :
171+ endpoint_config_thrift = self .endpoint_config .to_thrift ()
172+
173+ resource_config_thrift = None
174+ if self .resource_config :
175+ resource_config_thrift = self .resource_config .to_thrift ()
176+
177+ rollout_strategy_thrift = None
178+ if self .rollout_strategy :
179+ rollout_strategy_thrift = self .rollout_strategy .to_thrift ()
180+
181+ return ttypes .DeploymentSpec (
182+ containerConfig = container_config_thrift ,
183+ endpointConfig = endpoint_config_thrift ,
184+ resourceConfig = resource_config_thrift ,
185+ rolloutStrategy = rollout_strategy_thrift ,
186+ )
187+
188+
49189def Model (
50190 version : str ,
51191 inference_spec : Optional [InferenceSpec ] = None ,
52192 input_mapping : Optional [Dict [str , str ]] = None ,
53193 output_mapping : Optional [Dict [str , str ]] = None ,
54194 value_fields : Optional [FieldsType ] = None ,
195+ model_artifact_base_uri : Optional [str ] = None ,
196+ training_conf : Optional [TrainingSpec ] = None ,
197+ deployment_conf : Optional [DeploymentSpec ] = None ,
55198 output_namespace : Optional [str ] = None ,
56199 table_properties : Optional [Dict [str , str ]] = None ,
57200 tags : Optional [Dict [str , str ]] = None ,
@@ -76,6 +219,15 @@ def Model(
76219 If provided, creates a STRUCT schema that will be set as the model's valueSchema.
77220 Example: [('score', DataType.DOUBLE), ('category', DataType.STRING)]
78221 :type value_fields: FieldsType
222+ :param model_artifact_base_uri:
223+ Base URI where trained model artifacts are stored
224+ :type model_artifact_base_uri: str
225+ :param training_conf:
226+ Configs related to orchestrating model training jobs
227+ :type training_conf: TrainingSpec
228+ :param deployment_conf:
229+ Configs related to orchestrating model deployment
230+ :type deployment_conf: DeploymentSpec
79231 :param output_namespace:
80232 Namespace for the model output
81233 :type output_namespace: str
@@ -114,14 +266,27 @@ def Model(
114266 if value_fields :
115267 schema_name = "model_value_schema"
116268 value_schema = DataType .STRUCT (schema_name , * value_fields )
117-
269+
270+ # Convert training_conf to thrift if provided
271+ training_conf_thrift = None
272+ if training_conf :
273+ training_conf_thrift = training_conf .to_thrift ()
274+
275+ # Convert deployment_conf to thrift if provided
276+ deployment_conf_thrift = None
277+ if deployment_conf :
278+ deployment_conf_thrift = deployment_conf .to_thrift ()
279+
118280 # Create and return the Model object
119281 model = ttypes .Model (
120282 metaData = meta_data ,
121283 inferenceSpec = inference_spec_thrift ,
122284 inputMapping = input_mapping ,
123285 outputMapping = output_mapping ,
124286 valueSchema = value_schema ,
287+ modelArtifactBaseUri = model_artifact_base_uri ,
288+ trainingConf = training_conf_thrift ,
289+ deploymentConf = deployment_conf_thrift ,
125290 )
126291
127292 return model
0 commit comments