|
1 | 1 | """Module for defining the SpaceEnvironment class""" |
| 2 | + |
2 | 3 | import itertools |
3 | 4 | import json |
4 | 5 | import os |
@@ -170,7 +171,7 @@ def remove_islands(self) -> "SpaceDataset": |
170 | 171 | else: |
171 | 172 | LOGGER.debug(f"Found {sum(islands)} islands. Removing them.") |
172 | 173 | return self[~islands] |
173 | | - |
| 174 | + |
174 | 175 | def unmask(self, inplace: bool = False) -> "SpaceDataset": |
175 | 176 | """ |
176 | 177 | Returns a SpaceDataset with the masked covariate unmasked. |
@@ -295,9 +296,11 @@ def __init__( |
295 | 296 | # -- full data -- |
296 | 297 | ext = ".".join(glob(os.path.join(tgtdir, "synthetic_data.*"))[0].split(".")[1:]) |
297 | 298 | if ext == "csv": |
298 | | - data = pd.read_csv(os.path.join(tgtdir, "synthetic_data.csv"), index_col=0) |
| 299 | + path = os.path.join(tgtdir, "synthetic_data.csv") |
| 300 | + data = pd.read_csv(path, index_col=0, dtype={0: str}) |
299 | 301 | elif ext in ("tab", "tsv"): |
300 | | - data = pd.read_csv(os.path.join(tgtdir, "synthetic_data.tab"), sep="\t", index_col=0) |
| 302 | + path = os.path.join(tgtdir, "synthetic_data.tab") |
| 303 | + data = pd.read_csv(path, sep="\t", index_col=0, dtype={0: str}) |
301 | 304 | elif ext == "parquet": |
302 | 305 | data = pd.read_parquet(os.path.join(tgtdir, "synthetic_data.parquet")) |
303 | 306 | else: |
@@ -336,7 +339,7 @@ def __init__( |
336 | 339 | graph = nx.Graph() |
337 | 340 | graph.add_nodes_from(coords.index) |
338 | 341 | graph.add_edges_from(edges.values) |
339 | | - |
| 342 | + |
340 | 343 | node2id = {n: i for i, n in enumerate(data.index)} |
341 | 344 | self.edge_list = [(node2id[e[0]], node2id[e[1]]) for e in graph.edges] |
342 | 345 | self.graph = nx.from_edgelist(self.edge_list) |
|
0 commit comments