|
@@ -1,5 +1,6 @@
 # setting of data generation

+import os
 import pickle as pkl
 import random
 import sys
@@ -10,6 +11,7 @@
 import networkx as nx
 import numpy as np
 import pandas as pd
+import requests
 import scipy.sparse as sp
 import torch
 import torch_geometric
@@ -178,6 +180,31 @@ def NC_parse_index_file(filename: str) -> list: |
     return index


+def download_file_from_github(url: str, save_path: str):
+    """
+    Downloads a file from a GitHub URL and saves it to a specified local path.
+    Note
+    ----
+    - The function downloads files in chunks to handle large files efficiently.
+    - If the file already exists at `save_path`, it will not be downloaded again.
+    """
+    if not os.path.exists(save_path):
+        print(f"Downloading {url} to {save_path}...")
+        response = requests.get(url, stream=True)
+        if response.status_code == 200:
+            with open(save_path, "wb") as f:
+                for chunk in response.iter_content(chunk_size=1024):
+                    if chunk:
+                        f.write(chunk)
+            print(f"Downloaded {save_path}")
+        else:
+            raise Exception(
+                f"Failed to download {url}. HTTP Status Code: {response.status_code}"
+            )
+    else:
+        print(f"File already exists: {save_path}")
+
+
 def NC_load_data(dataset_str: str) -> tuple:
     """
     Loads input data from 'gcn/data' directory and processes these datasets into a format
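The new helper is self-contained enough to try in isolation. Below is a minimal usage sketch showing its idempotent behavior; the module name `data_generation` and the concrete file are illustrative assumptions, not taken from this commit:

```python
# Sketch: fetch one raw Planetoid file with download_file_from_github.
# "data_generation" is a hypothetical module name for the file being patched;
# the URL follows the BASE_URL pattern introduced in the hunk below.
import os

from data_generation import download_file_from_github

url = "https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x"
save_path = "./data/cora/raw/ind.cora.x"

os.makedirs(os.path.dirname(save_path), exist_ok=True)
download_file_from_github(url, save_path)  # downloads and writes the file in 1 KiB chunks
download_file_from_github(url, save_path)  # prints "File already exists: ..." and skips
```

Because the helper checks `os.path.exists(save_path)` before touching the network, re-running data generation only fetches files that are missing.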
@@ -217,21 +244,38 @@ def NC_load_data(dataset_str: str) -> tuple: |
     """
     if dataset_str in ["cora", "citeseer", "pubmed"]:
-        # download dataset from torch_geometric
-        dataset = torch_geometric.datasets.Planetoid("./data", dataset_str)
-        names = ["x", "y", "tx", "ty", "allx", "ally", "graph"]
+        # download the raw Planetoid files directly from the source GitHub repository
+        BASE_URL = "https://github.com/kimiyoung/planetoid/raw/master/data"
+        DATA_DIR = f"./data/{dataset_str}/raw/"
+        os.makedirs(DATA_DIR, exist_ok=True)
+
+        filenames = [
+            f"ind.{dataset_str}.x",
+            f"ind.{dataset_str}.tx",
+            f"ind.{dataset_str}.allx",
+            f"ind.{dataset_str}.y",
+            f"ind.{dataset_str}.ty",
+            f"ind.{dataset_str}.ally",
+            f"ind.{dataset_str}.graph",
+            f"ind.{dataset_str}.test.index",
+        ]
+
+        for filename in filenames:
+            file_url = f"{BASE_URL}/{filename}"
+            save_path = os.path.join(DATA_DIR, filename)
+            download_file_from_github(file_url, save_path)
+
         objects = []
-        for i in range(len(names)):
-            with open(
-                "data/{}/raw/ind.{}.{}".format(dataset_str, dataset_str, names[i]), "rb"
-            ) as f:
+        for name in ["x", "y", "tx", "ty", "allx", "ally", "graph"]:
+            file_path = os.path.join(DATA_DIR, f"ind.{dataset_str}.{name}")
+            with open(file_path, "rb") as f:
                 if sys.version_info > (3, 0):
                     objects.append(pkl.load(f, encoding="latin1"))
                 else:
                     objects.append(pkl.load(f))

         x, y, tx, ty, allx, ally, graph = tuple(objects)
         test_idx_reorder = NC_parse_index_file(
-            "data/{}/raw/ind.{}.test.index".format(dataset_str, dataset_str)
+            os.path.join(DATA_DIR, f"ind.{dataset_str}.test.index")
         )
         test_idx_range = np.sort(test_idx_reorder)

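Taken together, the two hunks swap the implicit `torch_geometric.datasets.Planetoid` download for an explicit fetch of the eight raw files, while keeping the same `./data/<dataset>/raw/` layout the old loader read from. A hedged end-to-end sketch, again treating `data_generation` as a stand-in module name (the unpacking of the return tuple is not shown because the rest of the function is outside this diff):

```python
# Sketch: the first call materializes ind.cora.* under ./data/cora/raw/,
# subsequent calls reuse the cached files.
import os

from data_generation import NC_load_data

result = NC_load_data("cora")  # downloads the raw files on the first run only

print(sorted(os.listdir("./data/cora/raw/")))
# Expected: the eight ind.cora.* files enumerated in the diff, i.e.
# ['ind.cora.allx', 'ind.cora.ally', 'ind.cora.graph', 'ind.cora.test.index',
#  'ind.cora.tx', 'ind.cora.ty', 'ind.cora.x', 'ind.cora.y']
```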
|
|