# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
- """
- This is a demo script showing how to use the
- PrithviGeospatialMAE model with vLLM
- This script is based on: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py # noqa
-
- Target model weights: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/resolve/main/Prithvi-EO-V2-300M-TL-Sen1Floods11.pt # noqa
-
- The requirements for running this script are:
- - Installing [terratorch, albumentations, rasterio] in your python environment
- - downloading the model weights in a 'model' folder local to the script
-   (temporary measure until the proper config.json file is uploaded to HF)
- - download an input example image (India_900498_S2Hand.tif) and place it in
-   the same folder with the script (or specify with the --data_file argument)
-
- Run the example:
- python prithvi_geospatial_mae.py
-
- """ # noqa: E501
-
import argparse
import datetime
import os
+ import re
from typing import Union

import albumentations
import numpy as np
import rasterio
- import regex as re
import torch
from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule

from vllm import LLM

+ torch.set_default_dtype(torch.float16)
+
NO_DATA = -9999
NO_DATA_FLOAT = 0.0001
OFFSET = 0
PERCENTILE = 99

- model_config = """{
-     "architectures": ["PrithviGeoSpatialMAE"],
-     "num_classes": 0,
-     "pretrained_cfg": {
-         "task_args": {
-             "task": "SemanticSegmentationTask",
-             "model_factory": "EncoderDecoderFactory",
-             "loss": "ce",
-             "ignore_index": -1,
-             "lr": 0.001,
-             "freeze_backbone": false,
-             "freeze_decoder": false,
-             "plot_on_val": 10,
-             "optimizer": "AdamW",
-             "scheduler": "CosineAnnealingLR"
-         },
-         "model_args": {
-             "backbone_pretrained": false,
-             "backbone": "prithvi_eo_v2_300_tl",
-             "decoder": "UperNetDecoder",
-             "decoder_channels": 256,
-             "decoder_scale_modules": true,
-             "num_classes": 2,
-             "rescale": true,
-             "backbone_bands": [
-                 "BLUE",
-                 "GREEN",
-                 "RED",
-                 "NIR_NARROW",
-                 "SWIR_1",
-                 "SWIR_2"
-             ],
-             "head_dropout": 0.1,
-             "necks": [
-                 {
-                     "name": "SelectIndices",
-                     "indices": [
-                         5,
-                         11,
-                         17,
-                         23
-                     ]
-                 },
-                 {
-                     "name": "ReshapeTokensToImage"
-                 }
-             ]
-         },
-         "optimizer_params": {
-             "lr": 5.0e-05,
-             "betas": [0.9, 0.999],
-             "eps": [1.0e-08],
-             "weight_decay": 0.05,
-             "amsgrad": false,
-             "maximize": false,
-             "capturable": false,
-             "differentiable": false
-         },
-         "scheduler_params": {
-             "T_max": 50,
-             "eta_min": 0,
-             "last_epoch": -1,
-             "verbose": "deprecated"
-         }
-     },
-
-
-     "torch_dtype": "float32"
- }
- """
-
- # Temporarily creating the "config.json" for the model.
- # This is going to disappear once the correct config.json is available on HF
- with open(
-     os.path.join(os.path.dirname(__file__), "./model/config.json"), "w"
- ) as config_file:
-     config_file.write(model_config)
-
datamodule_config = {
    "bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
    "batch_size": 16,
@@ -138,28 +43,24 @@


class PrithviMAE:
-     def __init__(self):
-         print("Initializing PrithviMAE model")
-         self.llm = LLM(
-             model=os.path.join(os.path.dirname(__file__), "./model"),
-             skip_tokenizer_init=True,
-             dtype="float32",
+     def __init__(self, model):
+         self.model = LLM(
+             model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True
        )

    def run(self, input_data, location_coords):
-         print("################ Running inference on vLLM ##############")
        # merge the inputs into one data structure
+         if input_data is not None and input_data.dtype == torch.float32:
+             input_data = input_data.to(torch.float16)
+         input_data = input_data[0]
+
        mm_data = {
-             "pixel_values": torch.empty(0) if input_data is None else input_data,
-             "location_coords": torch.empty(0)
-             if location_coords is None
-             else location_coords,
+             "pixel_values": input_data,
+             "location_coords": location_coords,
        }

        prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
-
-         outputs = self.llm.encode(prompt, use_tqdm=False)
-         print("################ Inference done (it took seconds) ##############")
+         outputs = self.model.encode(prompt, use_tqdm=False)

        return outputs[0].outputs.data

@@ -181,11 +82,12 @@ def process_channel_group(orig_img, channels):
    """
    Args:
        orig_img: torch.Tensor representing original image (reference)
-                 with shape = (bands, H, W).
+             with shape = (bands, H, W).
        channels: list of indices representing RGB channels.

    Returns:
-         torch.Tensor with shape (num_channels, height, width) for original image
+         torch.Tensor with shape (num_channels, height, width)
+             for original image
    """

    orig_img = orig_img[channels, ...]
@@ -260,10 +162,10 @@ def load_example(

    Args:
        file_paths: list of file paths.
-         mean: list containing mean values for each band in the images
-             in *file_paths*.
-         std: list containing std values for each band in the images
-             in *file_paths*.
+         mean: list containing mean values for each band in the
+             images in *file_paths*.
+         std: list containing std values for each band in the
+             images in *file_paths*.

    Returns:
        np.array containing created example
@@ -308,7 +210,7 @@ def load_example(
            print(f"Could not extract timestamp for {file} ({e})")

    imgs = np.stack(imgs, axis=0)  # num_frames, H, W, C
-     imgs = np.moveaxis(imgs, -1, 0).astype("float32")
+     imgs = np.moveaxis(imgs, -1, 0).astype("float32")  # C, num_frames, H, W
    imgs = np.expand_dims(imgs, axis=0)  # add batch dim

    return imgs, temporal_coords, location_coords, metas
@@ -332,8 +234,10 @@ def run_model(
    )

    # Build sliding window
+
    batch_size = 1
-     batch = torch.tensor(input_data, device="cpu")
+     # batch = torch.tensor(input_data, device="cpu")
+     batch = torch.tensor(input_data)
    windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
    h1, w1 = windows.shape[3:5]
    windows = rearrange(
@@ -344,34 +248,24 @@ def run_model(
    num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
    windows = torch.tensor_split(windows, num_batches, dim=0)

-     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-
    if temporal_coords:
-         temporal_coords = torch.tensor(temporal_coords, device=device).unsqueeze(0)
+         temporal_coords = torch.tensor(temporal_coords).unsqueeze(0)
    else:
        temporal_coords = None
    if location_coords:
-         location_coords = torch.tensor(location_coords[0], device=device).unsqueeze(0)
+         location_coords = torch.tensor(location_coords[0]).unsqueeze(0)
    else:
        location_coords = None

-     # Run model
+     # Run Prithvi-EO-V2-300M-TL-Sen1Floods11
    pred_imgs = []
    for x in windows:
        # Apply standardization
        x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1, 2, 0))
        x = datamodule.aug(x)["image"]

        with torch.no_grad():
-             x = x.to(device)
            pred = model.run(x, location_coords=location_coords)
-             if lightning_model:
-                 pred_lightning = lightning_model(
-                     x, temporal_coords=temporal_coords, location_coords=location_coords
-                 )
-                 pred_lightning = pred_lightning.output.detach().cpu()
-                 if not torch.equal(pred, pred_lightning):
-                     print("Inference output is not equal")
        y_hat = pred.argmax(dim=1)

        y_hat = torch.nn.functional.interpolate(
@@ -403,52 +297,18 @@ def run_model(
    return pred_imgs


- def parse_args():
-     parser = argparse.ArgumentParser("MAE run inference", add_help=False)
-
-     parser.add_argument(
-         "--data_file",
-         type=str,
-         default="./India_900498_S2Hand.tif",
-         help="Path to the file.",
-     )
-     parser.add_argument(
-         "--output_dir",
-         type=str,
-         default="output",
-         help="Path to the directory where to save outputs.",
-     )
-     parser.add_argument(
-         "--input_indices",
-         default=[1, 2, 3, 8, 11, 12],
-         type=int,
-         nargs="+",
-         help="0-based indices of the six Prithvi channels to be selected from the "
-         "input. By default selects [1,2,3,8,11,12] for S2L1C data.",
-     )
-     parser.add_argument(
-         "--rgb_outputs",
-         action="store_true",
-         help="If present, output files will only contain RGB channels. "
-         "Otherwise, all bands will be saved.",
-     )
-
-
def main(
    data_file: str,
+     model: str,
    output_dir: str,
    rgb_outputs: bool,
    input_indices: list[int] = None,
):
    os.makedirs(output_dir, exist_ok=True)

-     # Load model ---------------------------------------------------------------
-
-     model_obj = PrithviMAE()
+     model_obj = PrithviMAE(model=model)
    datamodule = generate_datamodule()
-     img_size = 256  # Size of Sen1Floods11
-
-     # Loading data -------------------------------------------------------------
+     img_size = 512  # Size of Sen1Floods11

    input_data, temporal_coords, location_coords, meta_data = load_example(
        file_paths=[data_file],
@@ -460,16 +320,13 @@ def main(
    if input_data.mean() > 1:
        input_data = input_data / 10000  # Convert to range 0-1

-     # Running model ------------------------------------------------------------
-
    channels = [
        datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"]
    ]  # BGR -> RGB

    pred = run_model(
        input_data, temporal_coords, location_coords, model_obj, datamodule, img_size
    )
-
    # Save pred
    meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
    pred_file = os.path.join(
@@ -487,6 +344,7 @@ def main(
        orig_img=torch.Tensor(input_data[0, :, 0, ...]),
        channels=channels,
    )
+     rgb_orig = rgb_orig.to(torch.float32)

    pred[pred == 0.0] = np.nan
    img_pred = rgb_orig * 0.7 + pred * 0.3
@@ -503,9 +361,10 @@ def main(

    # Save image rgb
    if rgb_outputs:
+         name_suffix = os.path.splitext(os.path.basename(data_file))[0]
        rgb_file = os.path.join(
            output_dir,
-             f"original_rgb_{os.path.splitext(os.path.basename(data_file))[0]}.tiff",
+             f"original_rgb_{name_suffix}.tiff",
        )
        save_geotiff(
            image=_convert_np_uint8(rgb_orig),
@@ -515,6 +374,42 @@ def main(


if __name__ == "__main__":
-     args = parse_args()
+     parser = argparse.ArgumentParser("MAE run inference", add_help=False)
+
+     parser.add_argument(
+         "--data_file",
+         type=str,
+         default="./India_900498_S2Hand.tif",
+         help="Path to the file.",
+     )
+     parser.add_argument(
+         "--model",
+         type=str,
+         default="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
+         help="Path to a checkpoint file to load from.",
+     )
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default="output",
+         help="Path to the directory where to save outputs.",
+     )
+     parser.add_argument(
+         "--input_indices",
+         default=[1, 2, 3, 8, 11, 12],
+         type=int,
+         nargs="+",
+         help="""
+         0-based indices of the six Prithvi channels to be selected from the input.
+         By default selects [1,2,3,8,11,12] for S2L1C data.
+         """,
+     )
+     parser.add_argument(
+         "--rgb_outputs",
+         action="store_true",
+         help="If present, output files will only contain RGB channels. "
+         "Otherwise, all bands will be saved.",
+     )
+     args = parser.parse_args()

    main(**vars(args))
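
For quick reference, a minimal standalone sketch of how the refactored PrithviMAE wrapper above can be exercised outside main(). The tensor shapes and coordinate values below are illustrative assumptions, not part of the diff; the checkpoint name is the --model default introduced above.

    import torch

    # Load the vLLM-backed encoder using the default checkpoint from the diff.
    mae = PrithviMAE(model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM")

    # Assumed window layout (batch, bands, frames, H, W); run() converts
    # float32 inputs to float16 and strips the batch dimension itself.
    pixels = torch.randn(1, 6, 1, 512, 512, dtype=torch.float32)
    coords = torch.tensor([[45.0, 9.0]], dtype=torch.float16)  # illustrative lat/lon

    embedding = mae.run(pixels, location_coords=coords)
    print(embedding.shape)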