Skip to content

Commit d326844

Browse files
authored
Extend Api + VertexPlatform to support submitting training jobs + deployments (#1331)
## Summary Added a first cut implementation of adding the rails to submit training jobs & deployments with the Vertex platform. The implementation is fairly basic atm (e.g 0 to 1 traffic rollouts, support for additional deploy types like b/g). We can flesh these out in follow ups - the current implementation allows us to start building a basic end to end scaffolding and MVP. Pulled in the Vertex SDK to trigger these. We needed the HTTP client for the Vertex predict calls (to support custom & published models - this isn't as easy to configure in the SDK) but using the SDK for training and deploy is a lot more ergonomic. ## Testing Tested using the local integration test that submits a training run, creates and endpoint, uploads a model and deploys it to an endpoint. Was able to validate that the model is up and can be hit for serving. ``` INFO a.c.i.cloud_gcp.VertexOrchestration - Submitting training job for test_ctr_model-1.0; Python pkg: gs://zipline-warehouse-models/builds/test_ctr_model-1.0.tar.gz; Model output dir: gs://zipline-warehouse-models/training_output/test_ctr_model-1.0/2025-12-08 INFO a.c.i.cloud_gcp.VertexOrchestration - Training job submitted successfully: projects/703996152583/locations/us-central1/customJobs/5518405205760147456 INFO a.c.i.c.VertexOrchestrationTest - Training job name: projects/703996152583/locations/us-central1/customJobs/5518405205760147456 INFO a.c.i.c.VertexOrchestrationTest - Waiting for training job. Current state: UNKNOWN INFO a.c.i.c.VertexOrchestrationTest - Waiting for training job. Current state: PENDING ... INFO a.c.i.c.VertexOrchestrationTest - Waiting for training job. Current state: RUNNING ... INFO a.c.i.c.VertexOrchestrationTest - Step 2: Creating endpoint (if absent)... INFO a.c.i.c.VertexOrchestrationTest - ================================================================================ INFO a.c.i.cloud_gcp.VertexOrchestration - Found existing endpoint: projects/703996152583/locations/us-central1/endpoints/8483716271198699520 INFO a.c.i.c.VertexOrchestrationTest - Endpoint resource name: projects/703996152583/locations/us-central1/endpoints/8483716271198699520 INFO a.c.i.c.VertexOrchestrationTest - ================================================================================ INFO a.c.i.c.VertexOrchestrationTest - Step 2: Deploying model (endpoint creation + model upload + deployment)... INFO a.c.i.c.VertexOrchestrationTest - ================================================================================ INFO a.c.i.cloud_gcp.VertexOrchestration - Found existing endpoint: projects/703996152583/locations/us-central1/endpoints/8483716271198699520 INFO a.c.i.cloud_gcp.VertexOrchestration - Uploading model: test_ctr_model-1.0; Artifact URI: gs://zipline-warehouse-models/training_output/test_ctr_model-1.0/2025-12-08/model; Container Image: us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.2-1:latest INFO a.c.i.cloud_gcp.VertexOrchestration - Model upload completed: projects/703996152583/locations/us-central1/models/7906534239267454976 INFO a.c.i.cloud_gcp.VertexOrchestration - Deploying model to endpoint. Details: Model: projects/703996152583/locations/us-central1/models/7906534239267454976; Endpoint: projects/703996152583/locations/us-central1/endpoints/8483716271198699520; Replicas: 3-10 Machine type: n1-standard-4 INFO a.c.i.cloud_gcp.VertexOrchestration - Model deployment initiated. Operation: projects/703996152583/locations/us-central1/endpoints/8483716271198699520/operations/5139381969750065152 INFO a.c.i.c.VertexOrchestrationTest - Deployment ID: projects/703996152583/locations/us-central1/endpoints/8483716271198699520/operations/5139381969750065152 INFO a.c.i.c.VertexOrchestrationTest - Waiting for deployment. Current state: UNKNOWN INFO a.c.i.c.VertexOrchestrationTest - Waiting for deployment. Current state: RUNNING ... ``` ## Checklist - [X] Added Unit Tests - [ ] Covered by existing CI - [X] Integration tested - [ ] Documentation update <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Vertex AI orchestration: submit/monitor training jobs, create/manage endpoints, deploy models, and query operation/job status. * **New Utilities** * Centralized HTTP client and prediction request/response helpers for Vertex interactions. * **API** * Model platform API extended with training/deploy requests, operations, and job status types. * **Tests** * Added comprehensive tests covering orchestration, model construction, deployment, and status mapping. * **Chores** * Added GCP AI Platform client dependency. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 5a12a09 commit d326844

File tree

10 files changed

+917
-150
lines changed

10 files changed

+917
-150
lines changed

cloud_gcp/package.mill

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ trait CloudGcpModule extends Cross.Module[String] with build.BaseModule {
4747
mvn"org.apache.iceberg::iceberg-spark-runtime-3.5:1.10.0",
4848
mvn"com.google.cloud:google-cloud-bigquery:2.54.1",
4949
mvn"com.google.cloud:google-cloud-bigtable:2.57.1",
50+
mvn"com.google.cloud:google-cloud-aiplatform:3.79.0",
5051
mvn"io.vertx:vertx-web-client:4.5.22",
5152
mvn"com.google.auth:google-auth-library-oauth2-http:1.40.0",
5253
).map(excludeJackson) ++ Seq(
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
package ai.chronon.integrations.cloud_gcp
2+
3+
import com.google.auth.oauth2.GoogleCredentials
4+
import io.vertx.core.buffer.Buffer
5+
import io.vertx.core.json.{JsonArray, JsonObject}
6+
import io.vertx.ext.web.client.{HttpResponse, WebClient}
7+
8+
import scala.jdk.CollectionConverters._
9+
10+
sealed trait HttpMethod
11+
case object GetMethod extends HttpMethod
12+
case object PostMethod extends HttpMethod
13+
case object PutMethod extends HttpMethod
14+
case object DeleteMethod extends HttpMethod
15+
16+
class VertexHttpClient(client: WebClient) extends Serializable {
17+
18+
// Init google creds - we need this to set the Auth header
19+
private lazy val googleCredentials: GoogleCredentials =
20+
GoogleCredentials
21+
.getApplicationDefault()
22+
.createScoped(List("https://www.googleapis.com/auth/aiplatform").asJava)
23+
24+
private def getAccessToken: String = {
25+
googleCredentials.refreshIfExpired()
26+
googleCredentials.getAccessToken.getTokenValue
27+
}
28+
29+
def makeHttpRequest[T](url: String, method: HttpMethod, requestBody: Option[JsonObject] = None)(
30+
handler: HttpResponse[Buffer] => T): Unit = {
31+
val request = method match {
32+
case GetMethod => client.getAbs(url)
33+
case PostMethod => client.postAbs(url)
34+
case PutMethod => client.putAbs(url)
35+
case DeleteMethod => client.deleteAbs(url)
36+
case _ => throw new IllegalArgumentException(s"Currently unsupported HTTP method: $method")
37+
}
38+
39+
// Add common headers
40+
val requestWithHeaders = request
41+
.putHeader("Authorization", s"Bearer $getAccessToken")
42+
val finalRequest = if (requestBody.isDefined) {
43+
requestWithHeaders
44+
.putHeader("Content-Type", "application/json")
45+
.putHeader("Accept", "application/json")
46+
} else {
47+
requestWithHeaders
48+
}
49+
50+
// Send the request
51+
val responseFuture = requestBody match {
52+
case Some(body) => finalRequest.sendJsonObject(body)
53+
case None => finalRequest.send()
54+
}
55+
56+
responseFuture.onComplete { asyncResult =>
57+
if (asyncResult.succeeded()) {
58+
handler(asyncResult.result())
59+
} else {
60+
handler(null)
61+
}
62+
}
63+
}
64+
}
65+
66+
object VertexHttpUtils {
67+
68+
def convertToVertxJson(obj: Any): Any = {
69+
obj match {
70+
case map: Map[_, _] =>
71+
val jsonObject = new JsonObject()
72+
map.foreach { case (key, value) =>
73+
jsonObject.put(key.toString, convertToVertxJson(value))
74+
}
75+
jsonObject
76+
case seq: Seq[_] =>
77+
val jsonArray = new JsonArray()
78+
seq.foreach { item =>
79+
jsonArray.add(convertToVertxJson(item))
80+
}
81+
jsonArray
82+
case other => other
83+
}
84+
}
85+
86+
/** Build the request body for Vertex AI prediction requests.
87+
* Format is the same for both publisher and custom models:
88+
* { "instances": [ { ..req 1..}, { ... } ], "parameters": { ... } }
89+
*
90+
* Publisher: https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings
91+
* Custom: https://docs.cloud.google.com/vertex-ai/docs/predictions/get-online-predictions
92+
*/
93+
def createPredictionRequestBody(inputRequests: Seq[Map[String, AnyRef]],
94+
modelParams: Map[String, String]): JsonObject = {
95+
val instancesArray = new JsonArray()
96+
97+
inputRequests.foreach { inputRequest =>
98+
val instance = inputRequest("instance")
99+
val jsonInstance = convertToVertxJson(instance)
100+
instancesArray.add(jsonInstance)
101+
}
102+
103+
val requestBody = new JsonObject()
104+
requestBody.put("instances", instancesArray)
105+
106+
// Add parameters if present (exclude model_name and model_type)
107+
val additionalParams = modelParams.filterKeys(k => k != "model_name" && k != "model_type")
108+
if (additionalParams.nonEmpty) {
109+
val parametersObj = new JsonObject()
110+
additionalParams.foreach { case (key, value) =>
111+
parametersObj.put(key, value)
112+
}
113+
requestBody.put("parameters", parametersObj)
114+
}
115+
116+
requestBody
117+
}
118+
119+
/** Extract prediction results from Vertex AI prediction response.
120+
* Response is a JsonObject with "predictions": [ {...}, {...} ]
121+
*/
122+
def extractPredictionResults(responseBody: JsonObject): Seq[Map[String, AnyRef]] = {
123+
val predictions = responseBody.getJsonArray("predictions")
124+
125+
if (predictions == null) {
126+
throw new RuntimeException("No 'predictions' array found in response")
127+
}
128+
129+
(0 until predictions.size()).map { index =>
130+
val predictionJsonObject = predictions.getJsonObject(index)
131+
predictionJsonObject.getMap.asScala.toMap
132+
}
133+
}
134+
}

0 commit comments

Comments
 (0)