Commit 5d0155d

christian-pinto authored and zixi-qi committed
[Core][Model] PrithviMAE Enablement on vLLM v1 engine (vllm-project#20577)
Signed-off-by: Christian Pinto <[email protected]>
Signed-off-by: qizixi <[email protected]>
1 parent 09c7ebb commit 5d0155d

File tree

15 files changed (+704, -238 lines)

examples/offline_inference/prithvi_geospatial_mae.py

Lines changed: 70 additions & 175 deletions
@@ -1,122 +1,27 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-This is a demo script showing how to use the
-PrithviGeospatialMAE model with vLLM
-This script is based on: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py # noqa
-
-Target model weights: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/resolve/main/Prithvi-EO-V2-300M-TL-Sen1Floods11.pt # noqa
-
-The requirements for running this script are:
-- Installing [terratorch, albumentations, rasterio] in your python environment
-- downloading the model weights in a 'model' folder local to the script
-  (temporary measure until the proper config.json file is uploaded to HF)
-- download an input example image (India_900498_S2Hand.tif) and place it in
-  the same folder with the script (or specify with the --data_file argument)
-
-Run the example:
-python prithvi_geospatial_mae.py
-
-""" # noqa: E501
-
 import argparse
 import datetime
 import os
+import re
 from typing import Union
 
 import albumentations
 import numpy as np
 import rasterio
-import regex as re
 import torch
 from einops import rearrange
 from terratorch.datamodules import Sen1Floods11NonGeoDataModule
 
 from vllm import LLM
 
+torch.set_default_dtype(torch.float16)
+
 NO_DATA = -9999
 NO_DATA_FLOAT = 0.0001
 OFFSET = 0
 PERCENTILE = 99
 
-model_config = """{
-    "architectures": ["PrithviGeoSpatialMAE"],
-    "num_classes": 0,
-    "pretrained_cfg": {
-        "task_args": {
-            "task": "SemanticSegmentationTask",
-            "model_factory": "EncoderDecoderFactory",
-            "loss": "ce",
-            "ignore_index": -1,
-            "lr": 0.001,
-            "freeze_backbone": false,
-            "freeze_decoder": false,
-            "plot_on_val": 10,
-            "optimizer": "AdamW",
-            "scheduler": "CosineAnnealingLR"
-        },
-        "model_args": {
-            "backbone_pretrained": false,
-            "backbone": "prithvi_eo_v2_300_tl",
-            "decoder": "UperNetDecoder",
-            "decoder_channels": 256,
-            "decoder_scale_modules": true,
-            "num_classes": 2,
-            "rescale": true,
-            "backbone_bands": [
-                "BLUE",
-                "GREEN",
-                "RED",
-                "NIR_NARROW",
-                "SWIR_1",
-                "SWIR_2"
-            ],
-            "head_dropout": 0.1,
-            "necks": [
-                {
-                    "name": "SelectIndices",
-                    "indices": [
-                        5,
-                        11,
-                        17,
-                        23
-                    ]
-                },
-                {
-                    "name": "ReshapeTokensToImage"
-                }
-            ]
-        },
-        "optimizer_params" : {
-            "lr": 5.0e-05,
-            "betas": [0.9, 0.999],
-            "eps": [1.0e-08],
-            "weight_decay": 0.05,
-            "amsgrad": false,
-            "maximize": false,
-            "capturable": false,
-            "differentiable": false
-        },
-        "scheduler_params" : {
-            "T_max": 50,
-            "eta_min": 0,
-            "last_epoch": -1,
-            "verbose": "deprecated"
-        }
-    },
-
-
-    "torch_dtype": "float32"
-}
-"""
-
-# Temporarily creating the "config.json" for the model.
-# This is going to disappear once the correct config.json is available on HF
-with open(
-    os.path.join(os.path.dirname(__file__), "./model/config.json"), "w"
-) as config_file:
-    config_file.write(model_config)
-
 datamodule_config = {
     "bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
     "batch_size": 16,
@@ -138,28 +43,24 @@
 
 
 class PrithviMAE:
-    def __init__(self):
-        print("Initializing PrithviMAE model")
-        self.llm = LLM(
-            model=os.path.join(os.path.dirname(__file__), "./model"),
-            skip_tokenizer_init=True,
-            dtype="float32",
+    def __init__(self, model):
+        self.model = LLM(
+            model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True
         )
 
     def run(self, input_data, location_coords):
-        print("################ Running inference on vLLM ##############")
         # merge the inputs into one data structure
+        if input_data is not None and input_data.dtype == torch.float32:
+            input_data = input_data.to(torch.float16)
+        input_data = input_data[0]
+
         mm_data = {
-            "pixel_values": torch.empty(0) if input_data is None else input_data,
-            "location_coords": torch.empty(0)
-            if location_coords is None
-            else location_coords,
+            "pixel_values": input_data,
+            "location_coords": location_coords,
         }
 
         prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
-
-        outputs = self.llm.encode(prompt, use_tqdm=False)
-        print("################ Inference done (it took seconds) ##############")
+        outputs = self.model.encode(prompt, use_tqdm=False)
 
         return outputs[0].outputs.data
 
@@ -181,11 +82,12 @@ def process_channel_group(orig_img, channels):
     """
     Args:
         orig_img: torch.Tensor representing original image (reference)
-            with shape = (bands, H, W).
+            with shape = (bands, H, W).
         channels: list of indices representing RGB channels.
 
     Returns:
-        torch.Tensor with shape (num_channels, height, width) for original image
+        torch.Tensor with shape (num_channels, height, width)
+            for original image
     """
 
     orig_img = orig_img[channels, ...]
@@ -260,10 +162,10 @@ def load_example(
 
     Args:
        file_paths: list of file paths .
-        mean: list containing mean values for each band in the images
-            in *file_paths*.
-        std: list containing std values for each band in the images
-            in *file_paths*.
+        mean: list containing mean values for each band in the
+            images in *file_paths*.
+        std: list containing std values for each band in the
+            images in *file_paths*.
 
     Returns:
         np.array containing created example
@@ -308,7 +210,7 @@ def load_example(
             print(f"Could not extract timestamp for {file} ({e})")
 
     imgs = np.stack(imgs, axis=0)  # num_frames, H, W, C
-    imgs = np.moveaxis(imgs, -1, 0).astype("float32")
+    imgs = np.moveaxis(imgs, -1, 0).astype("float32")  # C, num_frames, H, W
     imgs = np.expand_dims(imgs, axis=0)  # add batch dim
 
     return imgs, temporal_coords, location_coords, metas
@@ -332,8 +234,10 @@ def run_model(
     )
 
     # Build sliding window
+
     batch_size = 1
-    batch = torch.tensor(input_data, device="cpu")
+    # batch = torch.tensor(input_data, device="cpu")
+    batch = torch.tensor(input_data)
     windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
     h1, w1 = windows.shape[3:5]
     windows = rearrange(
@@ -344,34 +248,24 @@
     num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
     windows = torch.tensor_split(windows, num_batches, dim=0)
 
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-
     if temporal_coords:
-        temporal_coords = torch.tensor(temporal_coords, device=device).unsqueeze(0)
+        temporal_coords = torch.tensor(temporal_coords).unsqueeze(0)
     else:
         temporal_coords = None
     if location_coords:
-        location_coords = torch.tensor(location_coords[0], device=device).unsqueeze(0)
+        location_coords = torch.tensor(location_coords[0]).unsqueeze(0)
     else:
         location_coords = None
 
-    # Run model
+    # Run Prithvi-EO-V2-300M-TL-Sen1Floods11
     pred_imgs = []
     for x in windows:
         # Apply standardization
         x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1, 2, 0))
         x = datamodule.aug(x)["image"]
 
         with torch.no_grad():
-            x = x.to(device)
             pred = model.run(x, location_coords=location_coords)
-            if lightning_model:
-                pred_lightning = lightning_model(
-                    x, temporal_coords=temporal_coords, location_coords=location_coords
-                )
-                pred_lightning = pred_lightning.output.detach().cpu()
-                if not torch.equal(pred, pred_lightning):
-                    print("Inference output is not equal")
         y_hat = pred.argmax(dim=1)
 
         y_hat = torch.nn.functional.interpolate(
@@ -403,52 +297,18 @@ def run_model(
     return pred_imgs
 
 
-def parse_args():
-    parser = argparse.ArgumentParser("MAE run inference", add_help=False)
-
-    parser.add_argument(
-        "--data_file",
-        type=str,
-        default="./India_900498_S2Hand.tif",
-        help="Path to the file.",
-    )
-    parser.add_argument(
-        "--output_dir",
-        type=str,
-        default="output",
-        help="Path to the directory where to save outputs.",
-    )
-    parser.add_argument(
-        "--input_indices",
-        default=[1, 2, 3, 8, 11, 12],
-        type=int,
-        nargs="+",
-        help="0-based indices of the six Prithvi channels to be selected from the "
-        "input. By default selects [1,2,3,8,11,12] for S2L1C data.",
-    )
-    parser.add_argument(
-        "--rgb_outputs",
-        action="store_true",
-        help="If present, output files will only contain RGB channels. "
-        "Otherwise, all bands will be saved.",
-    )
-
-
 def main(
     data_file: str,
+    model: str,
     output_dir: str,
     rgb_outputs: bool,
     input_indices: list[int] = None,
 ):
     os.makedirs(output_dir, exist_ok=True)
 
-    # Load model ---------------------------------------------------------------
-
-    model_obj = PrithviMAE()
+    model_obj = PrithviMAE(model=model)
     datamodule = generate_datamodule()
-    img_size = 256  # Size of Sen1Floods11
-
-    # Loading data -------------------------------------------------------------
+    img_size = 512  # Size of Sen1Floods11
 
     input_data, temporal_coords, location_coords, meta_data = load_example(
         file_paths=[data_file],
@@ -460,16 +320,13 @@ def main(
     if input_data.mean() > 1:
         input_data = input_data / 10000  # Convert to range 0-1
 
-    # Running model ------------------------------------------------------------
-
     channels = [
         datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"]
     ]  # BGR -> RGB
 
     pred = run_model(
         input_data, temporal_coords, location_coords, model_obj, datamodule, img_size
     )
-
     # Save pred
     meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
     pred_file = os.path.join(
@@ -487,6 +344,7 @@ def main(
         orig_img=torch.Tensor(input_data[0, :, 0, ...]),
         channels=channels,
     )
+    rgb_orig = rgb_orig.to(torch.float32)
 
     pred[pred == 0.0] = np.nan
     img_pred = rgb_orig * 0.7 + pred * 0.3
@@ -503,9 +361,10 @@ def main(
 
     # Save image rgb
     if rgb_outputs:
+        name_suffix = os.path.splitext(os.path.basename(data_file))[0]
         rgb_file = os.path.join(
             output_dir,
-            f"original_rgb_{os.path.splitext(os.path.basename(data_file))[0]}.tiff",
+            f"original_rgb_{name_suffix}.tiff",
         )
         save_geotiff(
             image=_convert_np_uint8(rgb_orig),
@@ -515,6 +374,42 @@
 
 
 if __name__ == "__main__":
-    args = parse_args()
+    parser = argparse.ArgumentParser("MAE run inference", add_help=False)
+
+    parser.add_argument(
+        "--data_file",
+        type=str,
+        default="./India_900498_S2Hand.tif",
+        help="Path to the file.",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
+        help="Path to a checkpoint file to load from.",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="output",
+        help="Path to the directory where to save outputs.",
+    )
+    parser.add_argument(
+        "--input_indices",
+        default=[1, 2, 3, 8, 11, 12],
+        type=int,
+        nargs="+",
+        help="""
+        0-based indices of the six Prithvi channels to be selected from the input.
+        By default selects [1,2,3,8,11,12] for S2L1C data.
+        """,
+    )
+    parser.add_argument(
+        "--rgb_outputs",
+        action="store_true",
+        help="If present, output files will only contain RGB channels. "
+        "Otherwise, all bands will be saved.",
+    )
+    args = parser.parse_args()
 
     main(**vars(args))
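Taken together, the change drops the hand-written config.json and local weights in favor of a Hugging Face checkpoint and moves the example onto the vLLM v1 engine. A minimal sketch of the resulting usage pattern, distilled from the diff above (the dummy tensor shapes and values below are illustrative assumptions; real inputs come from load_example/run_model):

import torch
from vllm import LLM

# Vision-only checkpoint with no tokenizer, hence skip_tokenizer_init;
# float16 and enforce_eager mirror the settings used in the example.
llm = LLM(
    model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    skip_tokenizer_init=True,
    dtype="float16",
    enforce_eager=True,
)

# Placeholder inputs: the (bands, frames, H, W) pixel shape and the
# (1, 2) lat/lon shape are assumptions for illustration only.
mm_data = {
    "pixel_values": torch.zeros(6, 1, 512, 512, dtype=torch.float16),
    "location_coords": torch.tensor([[0.0, 0.0]], dtype=torch.float16),
}

# The model consumes no text, so a single dummy token id stands in as the prompt.
prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
outputs = llm.encode(prompt, use_tqdm=False)
pred = outputs[0].outputs.data  # per-window logits; argmax(dim=1) yields the mask

With the new argparse defaults, the example runs end to end as python prithvi_geospatial_mae.py, optionally overriding --model, --data_file, --output_dir, --input_indices, and --rgb_outputs.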

requirements/test.in

Lines changed: 1 addition & 0 deletions
@@ -54,3 +54,4 @@ runai-model-streamer==0.11.0
 runai-model-streamer-s3==0.11.0
 fastsafetensors>=0.1.10
 pydantic>=2.10 # 2.9 leads to error on python 3.10
+terratorch==1.1rc2 # required for PrithviMAE test
