Skip to content
This repository was archived by the owner on Mar 19, 2024. It is now read-only.

Commit d372304

Browse files
prigoyalfacebook-github-bot
authored andcommitted
Add geolocalization test to vissl (#510)
Summary: Pull Request resolved: #510 as title Reviewed By: QuentinDuval Differential Revision: D33794851 fbshipit-source-id: b2afc2ea908d21bd532ea02a8e233ca002724440
1 parent 1260872 commit d372304

File tree

4 files changed

+264
-0
lines changed

4 files changed

+264
-0
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ parameterized==0.7.4
99
scikit-learn==0.24.1
1010
submitit==1.3.3
1111
tabulate==0.8.9
12+
pandas

tools/geolocalization_test.py

Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import logging
7+
import os
8+
import sys
9+
from argparse import Namespace
10+
from typing import Any, List
11+
12+
import pandas as pd
13+
import torch
14+
from hydra.experimental import compose, initialize_config_module
15+
from iopath.common.file_io import g_pathmgr
16+
from vissl.config import AttrDict
17+
from vissl.data.dataset_catalog import get_data_files
18+
from vissl.hooks import default_hook_generator
19+
from vissl.utils.checkpoint import get_checkpoint_folder
20+
from vissl.utils.distributed_launcher import launch_distributed
21+
from vissl.utils.env import set_env_vars
22+
from vissl.utils.hydra_config import convert_to_attrdict, is_hydra_available, print_cfg
23+
from vissl.utils.io import load_file, save_file
24+
from vissl.utils.logger import setup_logging, shutdown_logging
25+
26+
27+
PARTITIONIG_MAP = {
28+
"cells_50_5000": "coarse",
29+
"cells_50_2000": "middle",
30+
"cells_50_1000": "fine",
31+
}
32+
33+
34+
# Adapted from
35+
# https://github.com/TIBHannover/GeoEstimation/blob/8dfc2a96741f496587fb598d9627b294058d4c28/classification/s2_utils.py#L20 # NOQA
36+
class Partitioning:
37+
def __init__(
38+
self,
39+
csv_file: str,
40+
skiprows=2,
41+
index_col="class_label",
42+
col_class_label="hex_id",
43+
col_latitute="latitude_mean",
44+
col_longitude="longitude_mean",
45+
):
46+
"""
47+
Required information in CSV:
48+
- class_indexes from 0 to n
49+
- respective class labels i.e. hexid
50+
- latitude and longitude
51+
"""
52+
with g_pathmgr.open(csv_file, "r") as fopen:
53+
self._df = pd.read_csv(fopen, index_col=index_col, skiprows=skiprows)
54+
self._df = self._df.sort_index()
55+
56+
self._nclasses = len(self._df.index)
57+
self._col_class_label = col_class_label
58+
self._col_latitude = col_latitute
59+
self._col_longitude = col_longitude
60+
61+
# map class label (hexid) to index
62+
self._label2index = dict(
63+
zip(self._df[self._col_class_label].tolist(), list(self._df.index))
64+
)
65+
self.name = os.path.splitext(os.path.basename(csv_file))[0]
66+
self.shortname = PARTITIONIG_MAP[self.name]
67+
68+
def __len__(self):
69+
return self._nclasses
70+
71+
def __repr__(self):
72+
return f"{self.name} short: {self.shortname} n: {self._nclasses}"
73+
74+
def get_class_label(self, idx):
75+
return self._df.iloc[idx][self._col_class_label]
76+
77+
def get_lat_lng(self, idx):
78+
x = self._df.iloc[idx]
79+
return float(x[self._col_latitude]), float(x[self._col_longitude])
80+
81+
def contains(self, class_label):
82+
if class_label in self._label2index:
83+
return True
84+
return False
85+
86+
def label2index(self, class_label):
87+
try:
88+
return self._label2index[class_label]
89+
except KeyError:
90+
raise KeyError(f"unknown label {class_label} in {self}")
91+
92+
93+
# Code from:
94+
# https://github.com/TIBHannover/GeoEstimation/blob/8dfc2a96741f496587fb598d9627b294058d4c28/classification/utils_global.py#L66 # NOQA
95+
def vectorized_gc_distance(latitudes, longitudes, latitudes_gt, longitudes_gt):
96+
R = 6371
97+
factor_rad = 0.01745329252
98+
longitudes = factor_rad * longitudes
99+
longitudes_gt = factor_rad * longitudes_gt
100+
latitudes = factor_rad * latitudes
101+
latitudes_gt = factor_rad * latitudes_gt
102+
delta_long = longitudes_gt - longitudes
103+
delta_lat = latitudes_gt - latitudes
104+
subterm0 = torch.sin(delta_lat / 2) ** 2
105+
subterm1 = torch.cos(latitudes) * torch.cos(latitudes_gt)
106+
subterm2 = torch.sin(delta_long / 2) ** 2
107+
subterm1 = subterm1 * subterm2
108+
a = subterm0 + subterm1
109+
c = 2 * torch.asin(torch.sqrt(a))
110+
gcd = R * c
111+
return gcd
112+
113+
114+
# Code from:
115+
# https://github.com/TIBHannover/GeoEstimation/blob/8dfc2a96741f496587fb598d9627b294058d4c28/classification/utils_global.py#L66 # NOQA
116+
def gcd_threshold_eval(gc_dists, thresholds):
117+
# calculate accuracy for given gcd thresolds
118+
results = {}
119+
for thres in thresholds:
120+
results[thres] = torch.true_divide(
121+
torch.sum(gc_dists <= thres), len(gc_dists)
122+
).item()
123+
return results
124+
125+
126+
def geolocalization_test(cfg: AttrDict, layer_name: str = "heads", topk: int = 1):
127+
output_dir = get_checkpoint_folder(cfg)
128+
logging.info(f"Output dir: {output_dir} ...")
129+
130+
############################################################################
131+
# Step 1: Load the mapping file and partition it
132+
# Also load the test images and targets (latitude/longitude)
133+
# lastly, load the model predictions
134+
logging.info(
135+
f"Loading the label partitioning file: {cfg.GEO_LOCALIZATION.TRAIN_LABEL_MAPPING}"
136+
)
137+
partitioning = Partitioning(cfg.GEO_LOCALIZATION.TRAIN_LABEL_MAPPING)
138+
139+
data_files, label_files = get_data_files("TEST", cfg.DATA)
140+
test_image_paths = load_file(data_files[0])
141+
target_lat_long = load_file(label_files[0])
142+
logging.info(
143+
f"Loaded val image paths: {test_image_paths.shape}, "
144+
f"ground truth latitude/longitude: {target_lat_long.shape}"
145+
)
146+
147+
prediction_image_indices_filepath = f"{output_dir}/rank0_test_{layer_name}_inds.npy"
148+
predictions_filepath = f"{output_dir}/rank0_test_{layer_name}_predictions.npy"
149+
predictions = load_file(predictions_filepath)
150+
predictions_inds = load_file(prediction_image_indices_filepath)
151+
logging.info(
152+
f"Loaded predictions: {predictions.shape}, inds: {predictions_inds.shape}"
153+
)
154+
155+
############################################################################
156+
# Step 2: Convert the predicted classes to latitude/longitude and compute
157+
# accuracy at different km thresholds.
158+
gt_latitudes, gt_longitudes, predicted_lats, predicted_longs = [], [], [], []
159+
output_metadata = {}
160+
num_images = len(test_image_paths)
161+
num_images = min(num_images, len(predictions))
162+
for idx in range(num_images):
163+
img_index = predictions_inds[idx]
164+
inp_img_path = test_image_paths[img_index]
165+
gt_latitude = float(target_lat_long[img_index][0])
166+
gt_longitude = float(target_lat_long[img_index][1])
167+
pred_cls = int(predictions[idx][:topk])
168+
pred_lat, pred_long = partitioning.get_lat_lng(pred_cls)
169+
output_metadata[inp_img_path] = {
170+
"target_lat": gt_latitude,
171+
"target_long": gt_longitude,
172+
"pred_lat": pred_lat,
173+
"pred_long": pred_long,
174+
"pred_cls": pred_cls,
175+
}
176+
gt_latitudes.append(gt_latitude)
177+
gt_longitudes.append(gt_longitude)
178+
predicted_lats.append(pred_lat)
179+
predicted_longs.append(pred_long)
180+
181+
predicted_lats = torch.tensor(predicted_lats, dtype=torch.float)
182+
predicted_longs = torch.tensor(predicted_longs, dtype=torch.float)
183+
gt_latitudes = torch.tensor(gt_latitudes, dtype=torch.float)
184+
gt_longitudes = torch.tensor(gt_longitudes, dtype=torch.float)
185+
distances = vectorized_gc_distance(
186+
predicted_lats,
187+
predicted_longs,
188+
gt_latitudes,
189+
gt_longitudes,
190+
)
191+
192+
# accuracy for all distances (in km)
193+
acc_dict = gcd_threshold_eval(
194+
distances, thresholds=cfg.GEO_LOCALIZATION.ACC_KM_THRESHOLDS
195+
)
196+
gcd_dict = {}
197+
for gcd_thres, acc in acc_dict.items():
198+
gcd_dict[f"{gcd_thres}"] = round(acc * 100.0, 4)
199+
logging.info(f"acc dist in percentage: {gcd_dict}")
200+
save_file(
201+
output_metadata,
202+
f"{output_dir}/output_metadata_predictions.json",
203+
append_to_json=False,
204+
)
205+
save_file(
206+
gcd_dict,
207+
f"{output_dir}/metrics.json",
208+
append_to_json=False,
209+
)
210+
return output_metadata, acc_dict
211+
212+
213+
def main(args: Namespace, config: AttrDict):
214+
# setup logging
215+
setup_logging(__name__)
216+
217+
# print the coniguration used
218+
print_cfg(config)
219+
220+
# setup the environment variables
221+
set_env_vars(local_rank=0, node_id=0, cfg=config)
222+
223+
# extract the label predictions on the test set
224+
launch_distributed(
225+
config,
226+
args.node_id,
227+
engine_name="extract_label_predictions",
228+
hook_generator=default_hook_generator,
229+
)
230+
231+
geolocalization_test(config)
232+
233+
# close the logging streams including the filehandlers
234+
shutdown_logging()
235+
236+
237+
def hydra_main(overrides: List[Any]):
238+
with initialize_config_module(config_module="vissl.config"):
239+
cfg = compose("defaults", overrides=overrides)
240+
args, config = convert_to_attrdict(cfg)
241+
main(args, config)
242+
243+
244+
if __name__ == "__main__":
245+
overrides = sys.argv[1:]
246+
assert is_hydra_available(), "Make sure to install hydra"
247+
hydra_main(overrides=overrides)

vissl/config/defaults.yaml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1464,3 +1464,15 @@ config:
14641464
FEATURES:
14651465
# If empty, will run the extract features, if not, use the path to find the features
14661466
PATH: ""
1467+
1468+
# ----------------------------------------------------------------------------------- #
1469+
# Geo Localization (benchmark)
1470+
# ----------------------------------------------------------------------------------- #
1471+
GEO_LOCALIZATION:
1472+
# Benchmark Details:
1473+
# Step1: Take a model and extract the model label predictions on test data.
1474+
# Step2: Find the corresponding latitude/longitude predictions using the json mapping for train set.
1475+
# Step3: find the ground truth latitute/longitude and compute the metric following the code.
1476+
TRAIN_LABEL_MAPPING: "/path/to/.json"
1477+
# [1, 25, 200, 750, 2500] -> [street, city, region, country, continent]
1478+
ACC_KM_THRESHOLDS: [1, 25, 200, 750, 2500]

vissl/utils/io.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from urllib.parse import urlparse
1313

1414
import numpy as np
15+
import pandas as pd
1516
import yaml
1617
from iopath.common.download import download
1718
from iopath.common.file_io import g_pathmgr, file_lock
@@ -130,6 +131,9 @@ def load_file(filename, mmap_mode=None):
130131
elif file_ext == ".yaml":
131132
with g_pathmgr.open(filename, "r") as fopen:
132133
data = yaml.load(fopen, Loader=yaml.FullLoader)
134+
elif file_ext == ".csv":
135+
with g_pathmgr.open(filename, "r") as fopen:
136+
data = pd.read_csv(fopen)
133137
else:
134138
raise Exception(f"Reading from {file_ext} is not supported yet")
135139
return data

0 commit comments

Comments
 (0)