Skip to content

Commit 3b51f19

Browse files
committed
hotfix: read index as string for csv and tab
1 parent c55707d commit 3b51f19

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

spacebench/env.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Module for defining the SpaceEnvironment class"""
2+
23
import itertools
34
import json
45
import os
@@ -170,7 +171,7 @@ def remove_islands(self) -> "SpaceDataset":
170171
else:
171172
LOGGER.debug(f"Found {sum(islands)} islands. Removing them.")
172173
return self[~islands]
173-
174+
174175
def unmask(self, inplace: bool = False) -> "SpaceDataset":
175176
"""
176177
Returns a SpaceDataset with the masked covariate unmasked.
@@ -295,9 +296,11 @@ def __init__(
295296
# -- full data --
296297
ext = ".".join(glob(os.path.join(tgtdir, "synthetic_data.*"))[0].split(".")[1:])
297298
if ext == "csv":
298-
data = pd.read_csv(os.path.join(tgtdir, "synthetic_data.csv"), index_col=0)
299+
path = os.path.join(tgtdir, "synthetic_data.csv")
300+
data = pd.read_csv(path, index_col=0, dtype={0: str})
299301
elif ext in ("tab", "tsv"):
300-
data = pd.read_csv(os.path.join(tgtdir, "synthetic_data.tab"), sep="\t", index_col=0)
302+
path = os.path.join(tgtdir, "synthetic_data.tab")
303+
data = pd.read_csv(path, sep="\t", index_col=0, dtype={0: str})
301304
elif ext == "parquet":
302305
data = pd.read_parquet(os.path.join(tgtdir, "synthetic_data.parquet"))
303306
else:
@@ -336,7 +339,7 @@ def __init__(
336339
graph = nx.Graph()
337340
graph.add_nodes_from(coords.index)
338341
graph.add_edges_from(edges.values)
339-
342+
340343
node2id = {n: i for i, n in enumerate(data.index)}
341344
self.edge_list = [(node2id[e[0]], node2id[e[1]]) for e in graph.edges]
342345
self.graph = nx.from_edgelist(self.edge_list)

0 commit comments

Comments
 (0)