|  | 
|  | 1 | +# Copyright (c) Facebook, Inc. and its affiliates. | 
|  | 2 | + | 
|  | 3 | +# This source code is licensed under the MIT license found in the | 
|  | 4 | +# LICENSE file in the root directory of this source tree. | 
|  | 5 | + | 
|  | 6 | +import logging | 
|  | 7 | +import os | 
|  | 8 | +import sys | 
|  | 9 | +from argparse import Namespace | 
|  | 10 | +from typing import Any, List | 
|  | 11 | + | 
|  | 12 | +import pandas as pd | 
|  | 13 | +import torch | 
|  | 14 | +from hydra.experimental import compose, initialize_config_module | 
|  | 15 | +from iopath.common.file_io import g_pathmgr | 
|  | 16 | +from vissl.config import AttrDict | 
|  | 17 | +from vissl.data.dataset_catalog import get_data_files | 
|  | 18 | +from vissl.hooks import default_hook_generator | 
|  | 19 | +from vissl.utils.checkpoint import get_checkpoint_folder | 
|  | 20 | +from vissl.utils.distributed_launcher import launch_distributed | 
|  | 21 | +from vissl.utils.env import set_env_vars | 
|  | 22 | +from vissl.utils.hydra_config import convert_to_attrdict, is_hydra_available, print_cfg | 
|  | 23 | +from vissl.utils.io import load_file, save_file | 
|  | 24 | +from vissl.utils.logger import setup_logging, shutdown_logging | 
|  | 25 | + | 
|  | 26 | + | 
|  | 27 | +PARTITIONIG_MAP = { | 
|  | 28 | +    "cells_50_5000": "coarse", | 
|  | 29 | +    "cells_50_2000": "middle", | 
|  | 30 | +    "cells_50_1000": "fine", | 
|  | 31 | +} | 
|  | 32 | + | 
|  | 33 | + | 
|  | 34 | +# Adapted from | 
|  | 35 | +# https://github.com/TIBHannover/GeoEstimation/blob/8dfc2a96741f496587fb598d9627b294058d4c28/classification/s2_utils.py#L20 # NOQA | 
|  | 36 | +class Partitioning: | 
|  | 37 | +    def __init__( | 
|  | 38 | +        self, | 
|  | 39 | +        csv_file: str, | 
|  | 40 | +        skiprows=2, | 
|  | 41 | +        index_col="class_label", | 
|  | 42 | +        col_class_label="hex_id", | 
|  | 43 | +        col_latitute="latitude_mean", | 
|  | 44 | +        col_longitude="longitude_mean", | 
|  | 45 | +    ): | 
|  | 46 | +        """ | 
|  | 47 | +        Required information in CSV: | 
|  | 48 | +            - class_indexes from 0 to n | 
|  | 49 | +            - respective class labels i.e. hexid | 
|  | 50 | +            - latitude and longitude | 
|  | 51 | +        """ | 
|  | 52 | +        with g_pathmgr.open(csv_file, "r") as fopen: | 
|  | 53 | +            self._df = pd.read_csv(fopen, index_col=index_col, skiprows=skiprows) | 
|  | 54 | +        self._df = self._df.sort_index() | 
|  | 55 | + | 
|  | 56 | +        self._nclasses = len(self._df.index) | 
|  | 57 | +        self._col_class_label = col_class_label | 
|  | 58 | +        self._col_latitude = col_latitute | 
|  | 59 | +        self._col_longitude = col_longitude | 
|  | 60 | + | 
|  | 61 | +        # map class label (hexid) to index | 
|  | 62 | +        self._label2index = dict( | 
|  | 63 | +            zip(self._df[self._col_class_label].tolist(), list(self._df.index)) | 
|  | 64 | +        ) | 
|  | 65 | +        self.name = os.path.splitext(os.path.basename(csv_file))[0] | 
|  | 66 | +        self.shortname = PARTITIONIG_MAP[self.name] | 
|  | 67 | + | 
|  | 68 | +    def __len__(self): | 
|  | 69 | +        return self._nclasses | 
|  | 70 | + | 
|  | 71 | +    def __repr__(self): | 
|  | 72 | +        return f"{self.name} short: {self.shortname} n: {self._nclasses}" | 
|  | 73 | + | 
|  | 74 | +    def get_class_label(self, idx): | 
|  | 75 | +        return self._df.iloc[idx][self._col_class_label] | 
|  | 76 | + | 
|  | 77 | +    def get_lat_lng(self, idx): | 
|  | 78 | +        x = self._df.iloc[idx] | 
|  | 79 | +        return float(x[self._col_latitude]), float(x[self._col_longitude]) | 
|  | 80 | + | 
|  | 81 | +    def contains(self, class_label): | 
|  | 82 | +        if class_label in self._label2index: | 
|  | 83 | +            return True | 
|  | 84 | +        return False | 
|  | 85 | + | 
|  | 86 | +    def label2index(self, class_label): | 
|  | 87 | +        try: | 
|  | 88 | +            return self._label2index[class_label] | 
|  | 89 | +        except KeyError: | 
|  | 90 | +            raise KeyError(f"unknown label {class_label} in {self}") | 
|  | 91 | + | 
|  | 92 | + | 
|  | 93 | +# Code from: | 
|  | 94 | +# https://github.com/TIBHannover/GeoEstimation/blob/8dfc2a96741f496587fb598d9627b294058d4c28/classification/utils_global.py#L66 # NOQA | 
|  | 95 | +def vectorized_gc_distance(latitudes, longitudes, latitudes_gt, longitudes_gt): | 
|  | 96 | +    R = 6371 | 
|  | 97 | +    factor_rad = 0.01745329252 | 
|  | 98 | +    longitudes = factor_rad * longitudes | 
|  | 99 | +    longitudes_gt = factor_rad * longitudes_gt | 
|  | 100 | +    latitudes = factor_rad * latitudes | 
|  | 101 | +    latitudes_gt = factor_rad * latitudes_gt | 
|  | 102 | +    delta_long = longitudes_gt - longitudes | 
|  | 103 | +    delta_lat = latitudes_gt - latitudes | 
|  | 104 | +    subterm0 = torch.sin(delta_lat / 2) ** 2 | 
|  | 105 | +    subterm1 = torch.cos(latitudes) * torch.cos(latitudes_gt) | 
|  | 106 | +    subterm2 = torch.sin(delta_long / 2) ** 2 | 
|  | 107 | +    subterm1 = subterm1 * subterm2 | 
|  | 108 | +    a = subterm0 + subterm1 | 
|  | 109 | +    c = 2 * torch.asin(torch.sqrt(a)) | 
|  | 110 | +    gcd = R * c | 
|  | 111 | +    return gcd | 
|  | 112 | + | 
|  | 113 | + | 
|  | 114 | +# Code from: | 
|  | 115 | +# https://github.com/TIBHannover/GeoEstimation/blob/8dfc2a96741f496587fb598d9627b294058d4c28/classification/utils_global.py#L66 # NOQA | 
|  | 116 | +def gcd_threshold_eval(gc_dists, thresholds): | 
|  | 117 | +    # calculate accuracy for given gcd thresolds | 
|  | 118 | +    results = {} | 
|  | 119 | +    for thres in thresholds: | 
|  | 120 | +        results[thres] = torch.true_divide( | 
|  | 121 | +            torch.sum(gc_dists <= thres), len(gc_dists) | 
|  | 122 | +        ).item() | 
|  | 123 | +    return results | 
|  | 124 | + | 
|  | 125 | + | 
|  | 126 | +def geolocalization_test(cfg: AttrDict, layer_name: str = "heads", topk: int = 1): | 
|  | 127 | +    output_dir = get_checkpoint_folder(cfg) | 
|  | 128 | +    logging.info(f"Output dir: {output_dir} ...") | 
|  | 129 | + | 
|  | 130 | +    ############################################################################ | 
|  | 131 | +    # Step 1: Load the mapping file and partition it | 
|  | 132 | +    # Also load the test images and targets (latitude/longitude) | 
|  | 133 | +    # lastly, load the model predictions | 
|  | 134 | +    logging.info( | 
|  | 135 | +        f"Loading the label partitioning file: {cfg.GEO_LOCALIZATION.TRAIN_LABEL_MAPPING}" | 
|  | 136 | +    ) | 
|  | 137 | +    partitioning = Partitioning(cfg.GEO_LOCALIZATION.TRAIN_LABEL_MAPPING) | 
|  | 138 | + | 
|  | 139 | +    data_files, label_files = get_data_files("TEST", cfg.DATA) | 
|  | 140 | +    test_image_paths = load_file(data_files[0]) | 
|  | 141 | +    target_lat_long = load_file(label_files[0]) | 
|  | 142 | +    logging.info( | 
|  | 143 | +        f"Loaded val image paths: {test_image_paths.shape}, " | 
|  | 144 | +        f"ground truth latitude/longitude: {target_lat_long.shape}" | 
|  | 145 | +    ) | 
|  | 146 | + | 
|  | 147 | +    prediction_image_indices_filepath = f"{output_dir}/rank0_test_{layer_name}_inds.npy" | 
|  | 148 | +    predictions_filepath = f"{output_dir}/rank0_test_{layer_name}_predictions.npy" | 
|  | 149 | +    predictions = load_file(predictions_filepath) | 
|  | 150 | +    predictions_inds = load_file(prediction_image_indices_filepath) | 
|  | 151 | +    logging.info( | 
|  | 152 | +        f"Loaded predictions: {predictions.shape}, inds: {predictions_inds.shape}" | 
|  | 153 | +    ) | 
|  | 154 | + | 
|  | 155 | +    ############################################################################ | 
|  | 156 | +    # Step 2: Convert the predicted classes to latitude/longitude and compute | 
|  | 157 | +    # accuracy at different km thresholds. | 
|  | 158 | +    gt_latitudes, gt_longitudes, predicted_lats, predicted_longs = [], [], [], [] | 
|  | 159 | +    output_metadata = {} | 
|  | 160 | +    num_images = len(test_image_paths) | 
|  | 161 | +    num_images = min(num_images, len(predictions)) | 
|  | 162 | +    for idx in range(num_images): | 
|  | 163 | +        img_index = predictions_inds[idx] | 
|  | 164 | +        inp_img_path = test_image_paths[img_index] | 
|  | 165 | +        gt_latitude = float(target_lat_long[img_index][0]) | 
|  | 166 | +        gt_longitude = float(target_lat_long[img_index][1]) | 
|  | 167 | +        pred_cls = int(predictions[idx][:topk]) | 
|  | 168 | +        pred_lat, pred_long = partitioning.get_lat_lng(pred_cls) | 
|  | 169 | +        output_metadata[inp_img_path] = { | 
|  | 170 | +            "target_lat": gt_latitude, | 
|  | 171 | +            "target_long": gt_longitude, | 
|  | 172 | +            "pred_lat": pred_lat, | 
|  | 173 | +            "pred_long": pred_long, | 
|  | 174 | +            "pred_cls": pred_cls, | 
|  | 175 | +        } | 
|  | 176 | +        gt_latitudes.append(gt_latitude) | 
|  | 177 | +        gt_longitudes.append(gt_longitude) | 
|  | 178 | +        predicted_lats.append(pred_lat) | 
|  | 179 | +        predicted_longs.append(pred_long) | 
|  | 180 | + | 
|  | 181 | +    predicted_lats = torch.tensor(predicted_lats, dtype=torch.float) | 
|  | 182 | +    predicted_longs = torch.tensor(predicted_longs, dtype=torch.float) | 
|  | 183 | +    gt_latitudes = torch.tensor(gt_latitudes, dtype=torch.float) | 
|  | 184 | +    gt_longitudes = torch.tensor(gt_longitudes, dtype=torch.float) | 
|  | 185 | +    distances = vectorized_gc_distance( | 
|  | 186 | +        predicted_lats, | 
|  | 187 | +        predicted_longs, | 
|  | 188 | +        gt_latitudes, | 
|  | 189 | +        gt_longitudes, | 
|  | 190 | +    ) | 
|  | 191 | + | 
|  | 192 | +    # accuracy for all distances (in km) | 
|  | 193 | +    acc_dict = gcd_threshold_eval( | 
|  | 194 | +        distances, thresholds=cfg.GEO_LOCALIZATION.ACC_KM_THRESHOLDS | 
|  | 195 | +    ) | 
|  | 196 | +    gcd_dict = {} | 
|  | 197 | +    for gcd_thres, acc in acc_dict.items(): | 
|  | 198 | +        gcd_dict[f"{gcd_thres}"] = round(acc * 100.0, 4) | 
|  | 199 | +    logging.info(f"acc dist in percentage: {gcd_dict}") | 
|  | 200 | +    save_file( | 
|  | 201 | +        output_metadata, | 
|  | 202 | +        f"{output_dir}/output_metadata_predictions.json", | 
|  | 203 | +        append_to_json=False, | 
|  | 204 | +    ) | 
|  | 205 | +    save_file( | 
|  | 206 | +        gcd_dict, | 
|  | 207 | +        f"{output_dir}/metrics.json", | 
|  | 208 | +        append_to_json=False, | 
|  | 209 | +    ) | 
|  | 210 | +    return output_metadata, acc_dict | 
|  | 211 | + | 
|  | 212 | + | 
|  | 213 | +def main(args: Namespace, config: AttrDict): | 
|  | 214 | +    # setup logging | 
|  | 215 | +    setup_logging(__name__) | 
|  | 216 | + | 
|  | 217 | +    # print the coniguration used | 
|  | 218 | +    print_cfg(config) | 
|  | 219 | + | 
|  | 220 | +    # setup the environment variables | 
|  | 221 | +    set_env_vars(local_rank=0, node_id=0, cfg=config) | 
|  | 222 | + | 
|  | 223 | +    # extract the label predictions on the test set | 
|  | 224 | +    launch_distributed( | 
|  | 225 | +        config, | 
|  | 226 | +        args.node_id, | 
|  | 227 | +        engine_name="extract_label_predictions", | 
|  | 228 | +        hook_generator=default_hook_generator, | 
|  | 229 | +    ) | 
|  | 230 | + | 
|  | 231 | +    geolocalization_test(config) | 
|  | 232 | + | 
|  | 233 | +    # close the logging streams including the filehandlers | 
|  | 234 | +    shutdown_logging() | 
|  | 235 | + | 
|  | 236 | + | 
|  | 237 | +def hydra_main(overrides: List[Any]): | 
|  | 238 | +    with initialize_config_module(config_module="vissl.config"): | 
|  | 239 | +        cfg = compose("defaults", overrides=overrides) | 
|  | 240 | +    args, config = convert_to_attrdict(cfg) | 
|  | 241 | +    main(args, config) | 
|  | 242 | + | 
|  | 243 | + | 
|  | 244 | +if __name__ == "__main__": | 
|  | 245 | +    overrides = sys.argv[1:] | 
|  | 246 | +    assert is_hydra_available(), "Make sure to install hydra" | 
|  | 247 | +    hydra_main(overrides=overrides) | 
0 commit comments