[Core][Model] PrithviMAE Enablement on vLLM v1 engine #20577


Merged
Changes from all commits
29 commits
c623c60
Better support for skip_tokenizer_init=True
christian-pinto Jun 5, 2025
146b3b2
Support for attention free models in V1
christian-pinto Jun 5, 2025
7392c45
Last few changes after rebasing to latest branch version
christian-pinto Jun 6, 2025
9a06b55
Support passing raw multimodal data to model
christian-pinto Jun 5, 2025
9aa5533
latest changes to align with the original branch
christian-pinto Jun 6, 2025
5ac66e7
Latest changes to adapt to upstream master
christian-pinto Jun 24, 2025
8f27e28
Some reformatting to make the pre-commit hooks succeed
christian-pinto Jun 25, 2025
5fe55fd
Few more changes to solve some other pre-commit hooks failures
christian-pinto Jun 25, 2025
c992ea3
Some style changes
christian-pinto Jun 26, 2025
137ec29
Simple code refactoring
christian-pinto Jun 27, 2025
e59d7dc
Rebased to master
christian-pinto Jul 16, 2025
9afe699
Ensure pre-commit checks succeed
christian-pinto Jul 16, 2025
7b3081b
Few more style edits
christian-pinto Jul 16, 2025
c0f3907
Skip tokenizer init in async LLM engine
christian-pinto Jul 16, 2025
670fe89
Merge branch 'main' into prithvi_v1_embeddings_zero_kv_cache_group
christian-pinto Jul 18, 2025
91f7b24
Changes after review
christian-pinto Jul 18, 2025
1e751d0
Further review round
christian-pinto Jul 18, 2025
cebfc59
Further edits
christian-pinto Jul 18, 2025
59a8248
One last if condition to be flipped
christian-pinto Jul 18, 2025
0f9bd48
pre-commit hooks failing
christian-pinto Jul 18, 2025
c66c92e
Added tokenizer check in AsyncLLM.get_tokenizer()
christian-pinto Jul 18, 2025
d8b6566
Merge branch 'main' into prithvi_v1_embeddings_zero_kv_cache_group
christian-pinto Jul 21, 2025
aaf247c
Added PrithviMAE test requirements
christian-pinto Jul 22, 2025
0bc48a3
Testing test deadlock fix
christian-pinto Jul 22, 2025
482f639
Merge branch 'main' into prithvi_v1_embeddings_zero_kv_cache_group
christian-pinto Jul 22, 2025
2cdd666
Pre-commit of course
christian-pinto Jul 22, 2025
31d6de7
Fix GPUModelRunner._may_reorder_batch to support Mamba models
christian-pinto Jul 23, 2025
93e1cd1
Updated prithvi geospatial MAE test to avoid OOM during warmup
christian-pinto Jul 23, 2025
7a87907
Updated requirements for PrithviGeospatialMAE tests
christian-pinto Jul 23, 2025
245 changes: 70 additions & 175 deletions examples/offline_inference/prithvi_geospatial_mae.py
@@ -1,122 +1,27 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This is a demo script showing how to use the
PrithviGeospatialMAE model with vLLM.
This script is based on: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py # noqa

Target model weights: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/resolve/main/Prithvi-EO-V2-300M-TL-Sen1Floods11.pt # noqa

The requirements for running this script are:
- Install [terratorch, albumentations, rasterio] in your Python environment
- Download the model weights into a 'model' folder local to the script
(a temporary measure until the proper config.json file is uploaded to HF)
- Download an input example image (India_900498_S2Hand.tif) and place it in
the same folder as the script (or specify one with the --data_file argument)

Run the example:
python prithvi_geospatial_mae.py

""" # noqa: E501

import argparse
import datetime
import os
import re
from typing import Union

import albumentations
import numpy as np
import rasterio
import regex as re
import torch
from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule

from vllm import LLM

torch.set_default_dtype(torch.float16)

NO_DATA = -9999
NO_DATA_FLOAT = 0.0001
OFFSET = 0
PERCENTILE = 99

model_config = """{
"architectures": ["PrithviGeoSpatialMAE"],
"num_classes": 0,
"pretrained_cfg": {
"task_args": {
"task": "SemanticSegmentationTask",
"model_factory": "EncoderDecoderFactory",
"loss": "ce",
"ignore_index": -1,
"lr": 0.001,
"freeze_backbone": false,
"freeze_decoder": false,
"plot_on_val": 10,
"optimizer": "AdamW",
"scheduler": "CosineAnnealingLR"
},
"model_args": {
"backbone_pretrained": false,
"backbone": "prithvi_eo_v2_300_tl",
"decoder": "UperNetDecoder",
"decoder_channels": 256,
"decoder_scale_modules": true,
"num_classes": 2,
"rescale": true,
"backbone_bands": [
"BLUE",
"GREEN",
"RED",
"NIR_NARROW",
"SWIR_1",
"SWIR_2"
],
"head_dropout": 0.1,
"necks": [
{
"name": "SelectIndices",
"indices": [
5,
11,
17,
23
]
},
{
"name": "ReshapeTokensToImage"
}
]
},
"optimizer_params" : {
"lr": 5.0e-05,
"betas": [0.9, 0.999],
"eps": [1.0e-08],
"weight_decay": 0.05,
"amsgrad": false,
"maximize": false,
"capturable": false,
"differentiable": false
},
"scheduler_params" : {
"T_max": 50,
"eta_min": 0,
"last_epoch": -1,
"verbose": "deprecated"
}
},


"torch_dtype": "float32"
}
"""

# Temporarily creating the "config.json" for the model.
# This is going to disappear once the correct config.json is available on HF
with open(
os.path.join(os.path.dirname(__file__), "./model/config.json"), "w"
) as config_file:
config_file.write(model_config)

datamodule_config = {
"bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
"batch_size": 16,
@@ -138,28 +43,24 @@


class PrithviMAE:
def __init__(self):
print("Initializing PrithviMAE model")
self.llm = LLM(
model=os.path.join(os.path.dirname(__file__), "./model"),
skip_tokenizer_init=True,
dtype="float32",
def __init__(self, model):
self.model = LLM(
model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True
)

def run(self, input_data, location_coords):
print("################ Running inference on vLLM ##############")
# merge the inputs into one data structure
if input_data is not None and input_data.dtype == torch.float32:
input_data = input_data.to(torch.float16)
input_data = input_data[0]

mm_data = {
"pixel_values": torch.empty(0) if input_data is None else input_data,
"location_coords": torch.empty(0)
if location_coords is None
else location_coords,
"pixel_values": input_data,
"location_coords": location_coords,
}

prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}

outputs = self.llm.encode(prompt, use_tqdm=False)
print("################ Inference done (it took seconds) ##############")
outputs = self.model.encode(prompt, use_tqdm=False)

return outputs[0].outputs.data
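
# A minimal usage sketch for the wrapper above (the model name is this
# script's --model default; the tensor shapes are illustrative
# assumptions, not values taken from this PR):
_mae = PrithviMAE(model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM")
_pixels = torch.randn(1, 6, 1, 512, 512, dtype=torch.float16)  # (B, bands, frames, H, W)
_coords = torch.tensor([[37.0, -122.0]])  # assumed (lat, lon) pair
_embedding = _mae.run(_pixels, _coords)  # pooled tensor from LLM.encode()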

@@ -181,11 +82,12 @@ def process_channel_group
"""
Args:
orig_img: torch.Tensor representing original image (reference)
with shape = (bands, H, W).
with shape = (bands, H, W).
channels: list of indices representing RGB channels.

Returns:
torch.Tensor with shape (num_channels, height, width) for original image
torch.Tensor with shape (num_channels, height, width)
for original image
"""

orig_img = orig_img[channels, ...]
@@ -260,10 +162,10 @@ def load_example

Args:
file_paths: list of file paths.
mean: list containing mean values for each band in the images
in *file_paths*.
std: list containing std values for each band in the images
in *file_paths*.
mean: list containing mean values for each band in the
images in *file_paths*.
std: list containing std values for each band in the
images in *file_paths*.

Returns:
np.array containing created example
@@ -308,7 +210,7 @@ def load_example
print(f"Could not extract timestamp for {file} ({e})")

imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
imgs = np.moveaxis(imgs, -1, 0).astype("float32")
imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W
imgs = np.expand_dims(imgs, axis=0) # add batch dim

return imgs, temporal_coords, location_coords, metas
@@ -332,8 +234,10 @@ def run_model
)

# Build sliding window

batch_size = 1
batch = torch.tensor(input_data, device="cpu")
# batch = torch.tensor(input_data, device="cpu")
batch = torch.tensor(input_data)
windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
h1, w1 = windows.shape[3:5]
windows = rearrange(
@@ -344,34 +248,24 @@
num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
windows = torch.tensor_split(windows, num_batches, dim=0)
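
# Sketch of the sliding-window math above under assumed sizes: one
# 512x512 scene with six bands and img_size=256 tiles into a 2x2 grid.
_demo = torch.zeros(1, 6, 1, 512, 512)  # (B, C, T, H, W), assumed shape
_tiles = _demo.unfold(3, 256, 256).unfold(4, 256, 256)  # (1, 6, 1, 2, 2, 256, 256)
_tiles = rearrange(_tiles, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=256, w=256)
assert _tiles.shape == (4, 6, 1, 256, 256)  # one entry per window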

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

if temporal_coords:
temporal_coords = torch.tensor(temporal_coords, device=device).unsqueeze(0)
temporal_coords = torch.tensor(temporal_coords).unsqueeze(0)
else:
temporal_coords = None
if location_coords:
location_coords = torch.tensor(location_coords[0], device=device).unsqueeze(0)
location_coords = torch.tensor(location_coords[0]).unsqueeze(0)
else:
location_coords = None

# Run model
# Run Prithvi-EO-V2-300M-TL-Sen1Floods11
pred_imgs = []
for x in windows:
# Apply standardization
x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1, 2, 0))
x = datamodule.aug(x)["image"]

with torch.no_grad():
x = x.to(device)
pred = model.run(x, location_coords=location_coords)
if lightning_model:
pred_lightning = lightning_model(
x, temporal_coords=temporal_coords, location_coords=location_coords
)
pred_lightning = pred_lightning.output.detach().cpu()
if not torch.equal(pred, pred_lightning):
print("Inference output is not equal")
y_hat = pred.argmax(dim=1)

y_hat = torch.nn.functional.interpolate(
@@ -403,52 +297,18 @@ def run_model
return pred_imgs


def parse_args():
parser = argparse.ArgumentParser("MAE run inference", add_help=False)

parser.add_argument(
"--data_file",
type=str,
default="./India_900498_S2Hand.tif",
help="Path to the file.",
)
parser.add_argument(
"--output_dir",
type=str,
default="output",
help="Path to the directory where to save outputs.",
)
parser.add_argument(
"--input_indices",
default=[1, 2, 3, 8, 11, 12],
type=int,
nargs="+",
help="0-based indices of the six Prithvi channels to be selected from the "
"input. By default selects [1,2,3,8,11,12] for S2L1C data.",
)
parser.add_argument(
"--rgb_outputs",
action="store_true",
help="If present, output files will only contain RGB channels. "
"Otherwise, all bands will be saved.",
)


def main(
data_file: str,
model: str,
output_dir: str,
rgb_outputs: bool,
input_indices: list[int] = None,
):
os.makedirs(output_dir, exist_ok=True)

# Load model ---------------------------------------------------------------

model_obj = PrithviMAE()
model_obj = PrithviMAE(model=model)
datamodule = generate_datamodule()
img_size = 256 # Size of Sen1Floods11

# Loading data -------------------------------------------------------------
img_size = 512 # Size of Sen1Floods11

input_data, temporal_coords, location_coords, meta_data = load_example(
file_paths=[data_file],
@@ -460,16 +320,13 @@ def main
if input_data.mean() > 1:
input_data = input_data / 10000 # Convert to range 0-1

# Running model ------------------------------------------------------------

channels = [
datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"]
] # BGR -> RGB

pred = run_model(
input_data, temporal_coords, location_coords, model_obj, datamodule, img_size
)

# Save pred
meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
pred_file = os.path.join(
@@ -487,6 +344,7 @@ def main
orig_img=torch.Tensor(input_data[0, :, 0, ...]),
channels=channels,
)
rgb_orig = rgb_orig.to(torch.float32)

pred[pred == 0.0] = np.nan
img_pred = rgb_orig * 0.7 + pred * 0.3
@@ -503,9 +361,10 @@ def main

# Save image rgb
if rgb_outputs:
name_suffix = os.path.splitext(os.path.basename(data_file))[0]
rgb_file = os.path.join(
output_dir,
f"original_rgb_{os.path.splitext(os.path.basename(data_file))[0]}.tiff",
f"original_rgb_{name_suffix}.tiff",
)
save_geotiff(
image=_convert_np_uint8(rgb_orig),
@@ -515,6 +374,42 @@ def main


if __name__ == "__main__":
args = parse_args()
parser = argparse.ArgumentParser("MAE run inference", add_help=False)

parser.add_argument(
"--data_file",
type=str,
default="./India_900498_S2Hand.tif",
help="Path to the file.",
)
parser.add_argument(
"--model",
type=str,
default="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
help="Path to a checkpoint file to load from.",
)
parser.add_argument(
"--output_dir",
type=str,
default="output",
help="Path to the directory where to save outputs.",
)
parser.add_argument(
"--input_indices",
default=[1, 2, 3, 8, 11, 12],
type=int,
nargs="+",
help="""
0-based indices of the six Prithvi channels to be selected from the input.
By default selects [1,2,3,8,11,12] for S2L1C data.
""",
)
parser.add_argument(
"--rgb_outputs",
action="store_true",
help="If present, output files will only contain RGB channels. "
"Otherwise, all bands will be saved.",
)
args = parser.parse_args()

main(**vars(args))
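
For reference, a typical invocation of the updated example, using the script's default data file and the flags defined above:

python prithvi_geospatial_mae.py \
    --data_file ./India_900498_S2Hand.tif \
    --model christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM \
    --output_dir output \
    --rgb_outputs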
1 change: 1 addition & 0 deletions requirements/test.in
@@ -54,3 +54,4 @@ runai-model-streamer==0.11.0
runai-model-streamer-s3==0.11.0
fastsafetensors>=0.1.10
pydantic>=2.10 # 2.9 leads to error on python 3.10
terratorch==1.1rc2 # required for PrithviMAE test