Description
Requirements
We want the API to be simple and expressive enough to hand off inference and training to external model platforms and to establish a complete lineage graph across a wide variety of cases.
- Base inference patterns, for both trained and pretrained (LLM) models:
  - Batch: enrich a feature table with a model output column (see the sketch after these lists).
  - Online: fetch a model inference given a set of entity keys, where feature fetching and model calling happen under the hood.
- Transfer learning inference patterns:
  - Apply a parent model on a parent dataset to generate a score or an embedding, then use that embedding as a feature in a downstream model:
    - in batch processing mode
    - in real-time processing mode
- Base training patterns:
  - Labelled / supervised
  - Clustering / unsupervised
- Transfer learning training patterns:
  - Pretrained parent
  - Trained parent
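The batch bullet above amounts to applying a model as a column transformation over a feature table. As a rough, chronon-independent illustration of that pattern, here is a minimal PySpark sketch; the toy table, column names, and placeholder scoring logic are assumptions for illustration only.

```python
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf

spark = SparkSession.builder.getOrCreate()

# Toy feature table standing in for the output of a chronon join.
features_df = spark.createDataFrame(
    [("u_123", 42, 18.5)], ["user_id", "txn_count_7d", "avg_amount_30d"]
)

# Placeholder scorer: a real pipeline would load the trained model artifact here.
@pandas_udf("double")
def fraud_score(txn_count_7d: pd.Series, avg_amount_30d: pd.Series) -> pd.Series:
    return 0.01 * txn_count_7d + 0.001 * avg_amount_30d

enriched_df = features_df.withColumn(
    "fraud_score", fraud_score("txn_count_7d", "avg_amount_30d")
)
```

The proposal below aims to make this enrichment, and its online counterpart, declarative rather than hand-written.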
Personas
1. Users can do model training and inference out of band, in which case they do not define the Model API proposed below. This is how Chronon is typically used today, and we will continue to support it even with this proposal.
2. Users want to do model training out of band, but want Chronon to pick up the trained model, mount it, and use it to respond to the `fetcher.fetchInference()` method, where we use Avro to shrink the data, pass it over the wire to the model prediction service, gather an inference, and respond to the user. This is needed to support embeddings and transfer learning in general. We will count on the user to manage model versioning and automatic retraining on their own.
3. Users want Chronon to trigger both model training and serving. In this case they will define the trainer code below, and Chronon will handle model versioning, model training, trained-artifact serialization, artifact-to-endpoint mounting, and inference as described in 2. Chronon will also guarantee data-leak-free training of the entire model chain in transfer learning use cases.
Proposal
We want to introduce a Model API object that supports the patterns above. These API objects will live in their own `models/` folder, similar to `joins`, `group_bys` and `staging_queries`.
Below is pseudo-code; the parameters are described in detail afterwards.
```python
from chronon import Model

risk_model = Model(
    platform=...,
    id=...,

    # persona 2 will specify a join here
    # persona 3 will specify a labelled_table here,
    # we will walk the graph upstream to find relevant joins for inference
    source=Source(...),

    # needed to trigger inference
    prediction_schema=StructType(...),
    key_columns=List[...],
    feature_columns=List[...],
    inference_conf=ModelConfig(
        inference_image=...,
        packages=...,
        props_json=...,
    ),

    # optional (persona 3), needed to trigger training
    training_data_window=Window(x, DAYS),
    label_columns=List[...],
    trainer_class=<PythonModule>,  # see trainer below
    training_conf=ModelConfig(
        inference_image=...,
        packages=...,
        props_json=...,
    ),
)
```
- `platform`: points to an implementation of the `ModelPlatform` chronon API. We will natively support `vertex` and `sagemaker`, but people can also plug in their own platform implementations and make them available to chronon as a jar on the batch/streaming/fetching class path. We describe the `ModelPlatform` API below.
- `id`: the identifier of this particular model in the platform. The platform implementation will be passed this id. In GCP this will be of the form `<project_name>.<model_display_name>`.
- `inference_conf` | `training_conf`: parameters used for inference or training on the underlying platform.
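To make these parameters concrete, here is a hedged, persona-2-style sketch of a Model that only does inference against an upstream join. The join name, column names, image URI, and props are illustrative assumptions, not part of the proposal.

```python
from chronon import Model, ModelConfig, Source, StructType

# Hypothetical upstream join defined under joins/ (assumed to exist).
from joins.risk import txn_features_join

# Persona 2: inference only; training and retraining happen out of band.
txn_risk_model = Model(
    platform="vertex",                               # natively supported platform
    id="my_project.txn_risk_model",                  # <project_name>.<model_display_name>
    source=Source(join=txn_features_join),           # assumed way to point the source at a join
    prediction_schema=StructType(...),               # e.g. a single double column "fraud_score"
    key_columns=["user_id"],
    feature_columns=["txn_count_7d", "avg_amount_30d"],
    inference_conf=ModelConfig(
        inference_image="us-docker.pkg.dev/my_project/risk/serve:latest",
        packages=["xgboost==2.0.3"],
        props_json='{"machine_type": "n1-standard-4"}',
    ),
)
```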
```scala
// we will implement this for Vertex AI and SageMaker to begin with
abstract class ModelPlatform(model: Model) {

  // when the flag `encodeAvro` is set to true, chronon will wrap this with avro encoding
  // to significantly shrink the on-wire size; the receiver will need to call
  // fetcher.fetchPayloadSchema(model_id: str) to pull the payload schema to unpack it
  def fetchInference(keys: Map[String, Any],
                     features: Map[String, Any],
                     timestamp: Long,
                     encodeAvro: Boolean = true): Map[String, Any]

  // date indicates that the model is trained with data from [date - window, date]
  def createInferenceEndpoint(date: String): Future[EndpointCreationResponse]

  // optional, only needed if you want chronon to trigger training (persona 3)
  // and manage timestamped versions of models
  def launchModelTrainingJob(date: String): Future[ModelTrainingResponse]
}
```
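The `encodeAvro` path is the standard schemaless Avro pattern: the field names live in the schema (pulled once via `fetcher.fetchPayloadSchema`), so only the values travel over the wire. Below is a minimal, self-contained Python sketch of that round trip using `fastavro`; the library choice and the payload schema are illustrative assumptions.

```python
import io
from fastavro import parse_schema, schemaless_reader, schemaless_writer

# Illustrative payload schema; in practice this is what
# fetcher.fetchPayloadSchema(model_id) would return.
schema = parse_schema({
    "type": "record",
    "name": "RiskModelPayload",
    "fields": [
        {"name": "user_id", "type": "string"},
        {"name": "txn_count_7d", "type": "long"},
        {"name": "avg_amount_30d", "type": "double"},
    ],
})

# Sender (chronon fetcher side): schemaless encoding writes only the values.
buf = io.BytesIO()
schemaless_writer(buf, schema, {"user_id": "u_123", "txn_count_7d": 42, "avg_amount_30d": 18.5})
wire_bytes = buf.getvalue()

# Receiver (model prediction service): unpack with the fetched schema.
features = schemaless_reader(io.BytesIO(wire_bytes), schema)
print(features)  # {'user_id': 'u_123', 'txn_count_7d': 42, 'avg_amount_30d': 18.5}
```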
Example of specifying embedding chains
Note: we dropped the label-joining API in Joins and substituted it with a `recompute_days` flag in `staging_query`; all labeling now happens with staging queries. The main driver was the universal difficulty of reasoning about label offsets in the previous API, and the new API can also express a wider variety of labelling patterns than the previous one. The staging-query way of computing labels also turned out to be close to 10x faster and cheaper.
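As a rough sketch of what labeling through a staging query might look like (the `StagingQuery` parameter names, the SQL, and the 90-day value are assumptions for illustration, not part of this note):

```python
from chronon import StagingQuery

# Hypothetical labeling query: recompute trailing partitions so labels that
# mature late (e.g. chargebacks) are picked up on each run.
txn_labels = StagingQuery(
    query="""
        SELECT t.user_id,
               t.txn_id,
               t.ds,
               IF(c.txn_id IS NOT NULL, 1, 0) AS is_fraud
        FROM transactions t
        LEFT JOIN chargebacks c ON t.txn_id = c.txn_id
    """,
    recompute_days=90,  # proposed flag: re-materialize the trailing 90 days of label partitions
)
```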
During compile we will register upstream joins to the model, and during inference we will automatically invoke `fetchJoin` on them. It is also possible to have multiple upstream joins.
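Putting the pieces together, here is a hedged sketch of an embedding chain: a parent (pretrained) embedding model feeding a downstream model. The join names, the way the parent's output surfaces as a column in the child's join, and all other identifiers are assumptions about how the chaining could be spelled; the proposal does not fix this syntax.

```python
from chronon import Model, ModelConfig, Source, StructType

# Parent: a pretrained text-embedding model applied over an upstream join
# of raw merchant descriptions (merchant_description_join is assumed).
merchant_embedding_model = Model(
    platform="vertex",
    id="my_project.merchant_text_embedding",
    source=Source(join=merchant_description_join),
    prediction_schema=StructType(...),            # e.g. embedding: array<double>
    key_columns=["merchant_id"],
    feature_columns=["description"],
    inference_conf=ModelConfig(inference_image=..., packages=..., props_json=...),
)

# Child: a fraud model whose upstream join includes the parent's embedding column.
# At compile time chronon registers both upstream joins against the child model;
# online, fetchJoin is invoked on them automatically before the child inference runs.
fraud_model = Model(
    platform="vertex",
    id="my_project.txn_fraud_model",
    source=Source(join=txn_with_embedding_join),  # join that brings in "merchant_embedding"
    prediction_schema=StructType(...),
    key_columns=["user_id", "merchant_id"],
    feature_columns=["merchant_embedding", "txn_count_7d"],
    inference_conf=ModelConfig(inference_image=..., packages=..., props_json=...),
)
```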
For Trained Models
```python
# Persona 3
class MyModelImpl(chronon.BaseModel):
    def train(self, training_data: TablePointer) -> chronon.BaseTrainedModel: ...
```
In this case, we will trigger training and associate versions and column lineage with the resulting artifacts.
We will implicitly convert the batch Iceberg data into Arrow dataframes and pass them into the train method.
Online, we will manage the Avro conversion of the payload, unpack it, and pass it to the underlying XGBoost, TF, or PyTorch models, etc.
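For illustration, a hedged sketch of what a persona-3 trainer could look like, assuming the Arrow handoff described above and using XGBoost purely as an example; the `TablePointer` and `BaseTrainedModel` surfaces shown here (`to_arrow()`, the `artifact=` constructor) are assumptions, not part of the proposal.

```python
import chronon
import xgboost as xgb

class MyModelImpl(chronon.BaseModel):
    def train(self, training_data: chronon.TablePointer) -> chronon.BaseTrainedModel:
        # Assumption: the table pointer exposes the materialized Arrow data;
        # chronon has already restricted it to training_data_window.
        df = training_data.to_arrow().to_pandas()
        features = df[["txn_count_7d", "avg_amount_30d"]]  # feature_columns from the Model
        labels = df["is_fraud"]                            # label_columns from the Model
        booster = xgb.XGBClassifier(n_estimators=200, max_depth=6)
        booster.fit(features, labels)
        # chronon versions and serializes the returned artifact, then mounts it
        # on the inference endpoint created by the ModelPlatform.
        return chronon.BaseTrainedModel(artifact=booster)
```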