
[WIP] RFC: Model API design #1181

@nikhil-zlai

Description


Requirements

We want the API to be simple and expressive enough to hand off inference and training to external model platforms and to establish a complete lineage graph across a wide variety of cases.

  1. Base inference patterns (trained + pretrained / LLM):

     1. Batch: enrich a feature table with a model output column.
     2. Online: fetch a model inference given a set of entity keys, where feature fetching and model invocation happen under the hood.
  2. Transfer learning inference patterns:

     1. Apply a parent model to a parent dataset to generate a score or an embedding, then use that embedding as a feature in a downstream model:
        1. in batch processing mode
        2. in real-time processing mode
  3. Base training patterns:

     1. Labelled / supervised
     2. Clustering / unsupervised
  4. Transfer learning training patterns:

     1. Pretrained parent
     2. Trained parent

Personas

  1. Users can do model training and inference out of band, in which case they don't define the Model API proposed below. This is how Chronon is typically used today, and we will continue to support it even with this proposal.

  2. Users want to do model training out of band, but want Chronon to pick up the trained model, mount it, and use it to respond to the fetcher.fetchInference() method, where we use Avro to shrink the data, pass it over the wire to the model prediction service, gather an inference, and respond to the user. This is needed to support embeddings and transfer learning in general. We will count on the user to manage model versioning and automatic retraining on their own.

  3. Users want Chronon to trigger both model training and serving. In this case they will define the trainer code below, and Chronon will handle model versioning, model training, trained-artifact serialization, mounting the artifact to an endpoint, and inference as described in 2. Here Chronon will guarantee data-leak-free training of the entire model chain in transfer learning use cases.

Proposal

We want to introduce a Model API object that allows for the above patterns. These API objects will live in their own models/ folder, similar to joins, group_bys and staging_queries.

Below is pseudocode; the parameters are described in detail afterward.

from chronon import Model, ModelConfig, Source, StructType, Window

risk_model = Model(
    platform = ...,
    id = ...,
    # persona 2 will specify a join here
    # persona 3 will specify a labelled_table here,
    #                    we will walk the graph upstream to find relevant joins for inference
    source = Source(...),

    # needed to trigger inference
    prediction_schema = StructType(...),
    key_columns = [...],
    feature_columns = [...],
    inference_conf = ModelConfig(
        inference_image = ...,
        packages = ...,
        props_json = ...,
    ),

    # optional (Persona 3), needed to trigger training
    training_data_window = Window(x, DAYS),
    label_columns = [...],
    trainer_class = ...,  # a Python module implementing the trainer; see "For Trained Models" below
    training_conf = ModelConfig(
        inference_image = ...,
        packages = ...,
        props_json = ...,
    ),
)
  • platform points to an implementation of the ModelPlatform Chronon API. We will natively support Vertex AI and SageMaker, but people can also plug in their own platform implementations and make them available to Chronon as a jar on the batch/streaming/fetching classpath. The ModelPlatform API is described below.

  • id: the identifier of this particular model in the platform. The platform implementation will be passed this id. In GCP it will be of the form <project_name>.<model_display_name>.

  • inference_conf | training_conf: parameters used by the underlying platform for inference or training. A concrete (hypothetical) example follows.
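
To make the parameters concrete, here is a rough, hypothetical Persona 2 configuration (inference only, trained out of band). Every identifier below (the project, join, image, columns, and the Source(join=...) and StructType/StructField shapes) is illustrative, not part of the proposal.

from chronon import Model, ModelConfig, Source, StructType, StructField, DoubleType

# Hypothetical Persona 2 model: trained out of band, served through Chronon.
# All names (project, join, image, columns) are illustrative.
txn_risk = Model(
    platform = "vertex",                              # built-in Vertex AI ModelPlatform impl
    id = "my_project.txn_risk_v3",                    # <project_name>.<model_display_name> on GCP
    source = Source(join = "risk.txn_features_v1"),   # Persona 2: point at an upstream join
    prediction_schema = StructType([StructField("risk_score", DoubleType())]),
    key_columns = ["user_id", "merchant_id"],
    feature_columns = ["txn_amount_sum_7d", "merchant_chargeback_rate_30d"],
    inference_conf = ModelConfig(
        inference_image = "us-docker.pkg.dev/my_project/serving/xgb-serve:1.2",
        packages = ["xgboost==2.0.3"],
        props_json = '{"machine_type": "n1-standard-4"}',
    ),
    # no training_* fields: Persona 2 manages training and versioning out of band
)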

// We will implement this for Vertex AI and SageMaker to begin with.
abstract class ModelPlatform(model: Model) {

  // When the flag `encodeAvro` is set to true:
  //   Chronon will wrap this with Avro encoding to significantly shrink the on-wire size;
  //   the receiver will need to call fetcher.fetchPayloadSchema(model_id: str) to pull the payload schema to unpack.
  def fetchInference(keys: Map[String, Any], features: Map[String, Any], timestamp: Long, encodeAvro: Boolean = true): Map[String, Any]

  // `date` indicates that the model is trained with data from [date - window, date].
  def createInferenceEndpoint(date: String): Future[EndpointCreationResponse]

  // Optional, only needed if you want Chronon to trigger training (Persona 3)
  // and manage timestamped versions of models.
  def launchModelTrainingJob(date: String): Future[ModelTrainingResponse]
}
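
As a sketch of the encodeAvro path on the receiving side, assuming fetchPayloadSchema returns the payload schema as an Avro schema JSON string and using the fastavro library (neither of which is mandated by this RFC):

import io
import json

import fastavro

# Receiver-side sketch for fetchInference with encodeAvro = true.
# Assumes the payload schema arrives as an Avro schema JSON string via
# fetcher.fetchPayloadSchema(model_id); transport details are out of scope here.
def decode_payload(payload_bytes: bytes, payload_schema_json: str) -> dict:
    schema = fastavro.parse_schema(json.loads(payload_schema_json))
    return fastavro.schemaless_reader(io.BytesIO(payload_bytes), schema)

# The decoded record carries the keys, features and timestamp that
# fetchInference was invoked with, ready to hand to the model server.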

Example of specifying embedding chains

Note: we dropped the label-joining API in Joins and substituted it with a recompute_days flag in staging_query; all labeling now happens with staging queries. The main driver was the universal difficulty of reasoning about label offsets in the previous API; the new API can also express a wider variety of labeling patterns. We also found the staging_query way of computing labels to be close to 10x faster and cheaper.
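
For reference, labeling via a staging query might look roughly like this; recompute_days is the flag mentioned above, while the import path, query, and other details are illustrative assumptions:

from chronon import StagingQuery

# Hypothetical label staging query: late-arriving chargeback labels are picked up
# by recomputing the trailing window on every run.
chargeback_labels = StagingQuery(
    query = """
        SELECT txn_id, user_id, ts,
               IF(chargeback_ts IS NOT NULL, 1, 0) AS is_chargeback
        FROM fct_transactions
        WHERE ds BETWEEN '{{ start_date }}' AND '{{ end_date }}'
    """,
    recompute_days = 90,  # re-materialize the last 90 days so late labels land
)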

(Diagram illustrating the embedding chain omitted.)

During compilation we will register upstream joins to the model, and during inference we will automatically invoke fetchJoin on them. It is also possible to have multiple joins upstream.
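
As a rough sketch of such a chain (assuming, purely as an illustration, that the child model's upstream join consumes the parent model's output as a feature; all names are hypothetical):

from chronon import Model, ModelConfig, Source

# Parent: produces an item embedding from a join over raw item features.
item_embedder = Model(
    platform = "vertex",
    id = "my_project.item_embedder_v2",
    source = Source(join = "catalog.item_features_v1"),   # registered as an upstream join at compile time
    prediction_schema = ...,                               # embedding vector schema
    key_columns = ["item_id"],
    feature_columns = ["title_tokens", "category"],
    inference_conf = ModelConfig(inference_image = ..., packages = ..., props_json = ...),
)

# Child: its source join includes the parent's embedding as a feature, so fetching
# the child transitively invokes fetchJoin on that join (and the parent model) upstream.
purchase_ranker = Model(
    platform = "vertex",
    id = "my_project.purchase_ranker_v5",
    source = Source(join = "ranking.purchase_features_v3"),  # this join consumes item_embedder's output
    prediction_schema = ...,
    key_columns = ["user_id", "item_id"],
    feature_columns = ["item_embedding", "user_purchase_count_30d"],
    inference_conf = ModelConfig(inference_image = ..., packages = ..., props_json = ...),
)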

For Trained Models

# Persona 3
import chronon

class MyModelImpl(chronon.BaseModel):
    def train(self, training_data: TablePointer) -> chronon.BaseTrainedModel:
        ...

In this case, we will trigger training and associate versions and column lineage with the resulting artifacts.
We will implicitly convert the batch Iceberg data into Arrow dataframes and pass them into the train method.
Online, we will manage the Avro conversion of the payload, unpack it, and pass it to the underlying XGBoost, TF, or PyTorch model.
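
For illustration, a Persona 3 trainer could look roughly like this, assuming the TablePointer exposes the Arrow data via a to_arrow() accessor and that BaseTrainedModel wraps the fitted artifact; XGBoost and every column name here are just one possible choice, not part of the proposal:

import chronon
import xgboost as xgb

class TxnRiskModel(chronon.BaseModel):
    # Persona 3 trainer sketch. Chronon converts the batch Iceberg data to Arrow
    # and passes it in; to_arrow() and the BaseTrainedModel constructor shown here
    # are illustrative assumptions.
    def train(self, training_data: "chronon.TablePointer") -> "chronon.BaseTrainedModel":
        df = training_data.to_arrow().to_pandas()
        features = df[["txn_amount_sum_7d", "merchant_chargeback_rate_30d"]]
        labels = df["is_chargeback"]

        booster = xgb.XGBClassifier(n_estimators = 200, max_depth = 6)
        booster.fit(features, labels)

        # Chronon would serialize this artifact, version it, and mount it behind the endpoint.
        return chronon.BaseTrainedModel(artifact = booster)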
