Skip to content

Commit 1e135f3

Browse files
authored
Merge pull request #66 from boutproject/fix-squashoutput-global-attributes
Copy file attributes in squashoutput()
2 parents 0f4913e + 88b4ab0 commit 1e135f3

File tree

5 files changed

+93
-2
lines changed

5 files changed

+93
-2
lines changed

boutdata/data.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,6 +1102,12 @@ def __init__(
11021102
for i in range(self.grid_info["npes"])
11031103
]
11041104

1105+
if self._DataFileCaching or self._parallel:
1106+
# Keep reference to 0'th file, for reading attributes
1107+
self._file0 = DataFile(self._file_list[0])
1108+
else:
1109+
self._file0 = None
1110+
11051111
if self._info:
11061112
print(
11071113
"mxsub = {} mysub = {} mz = {}\n".format(
@@ -1145,6 +1151,8 @@ def __del__(self):
11451151
connection.send(None)
11461152
worker.join()
11471153
connection.close()
1154+
if self._file0 is not None:
1155+
self._file0.close()
11481156

11491157
def _init_caching(self):
11501158
"""
@@ -1224,6 +1232,64 @@ def evolvingVariables(self):
12241232
"""Return a list of names of time-evolving variables"""
12251233
return self.grid_info["evolvingVariableNames"]
12261234

1235+
def get_attribute(self, variable, attrname):
1236+
"""Get an attribute of a variable
1237+
1238+
Parameters
1239+
----------
1240+
variable : str
1241+
Name of variable to get attribute from
1242+
attrname : str
1243+
Name of attribute
1244+
1245+
Returns
1246+
-------
1247+
Value of attribute
1248+
"""
1249+
if self._file0 is None:
1250+
with DataFile(self._file_list[0]) as f:
1251+
return f.attributes(variable)[attrname]
1252+
else:
1253+
return self._file0.attributes(variable)[attrname]
1254+
1255+
def get_file_attribute(self, attrname):
1256+
"""Get an attribute of the output files.
1257+
1258+
Attribute is taken from the rank-0 file. No checking is done that the attribute
1259+
is consistent between all the output files.
1260+
1261+
Parameters
1262+
----------
1263+
attrname : str
1264+
Name of attribute
1265+
1266+
Returns
1267+
-------
1268+
Value of attribute
1269+
"""
1270+
if self._file0 is None:
1271+
with DataFile(self._file_list[0]) as f:
1272+
return f.read_file_attribute(attrname)
1273+
else:
1274+
return self._file0.read_file_attribute(attrname)
1275+
1276+
def list_file_attributes(self):
1277+
"""List all file attributes of output files
1278+
1279+
List is taken from the rank-0 file. No checking is done that the file attributes
1280+
are consistent between all the output files.
1281+
1282+
Returns
1283+
-------
1284+
List of str
1285+
Names of the file attributes
1286+
"""
1287+
if self._file0 is None:
1288+
with DataFile(self._file_list[0]) as f:
1289+
return f.list_file_attributes()
1290+
else:
1291+
return self._file0.list_file_attributes()
1292+
12271293
def redistribute(self, npes, nxpe=None, mxg=2, myg=2, include_restarts=True):
12281294
"""Create a new set of dump files for npes processors.
12291295

boutdata/squashoutput.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,12 @@ def squashoutput(
226226
var = None
227227
gc.collect()
228228

229+
# Copy file attributes
230+
for attrname in outputs.list_file_attributes():
231+
attrval = outputs.get_file_attribute(attrname)
232+
for f in files:
233+
f.write_file_attribute(attrname, attrval)
234+
229235
for f in files:
230236
f.close()
231237

boutdata/tests/make_test_data.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@
7575
},
7676
}
7777

78+
expected_file_attributes = {
79+
"global_str_attribute": "foobar",
80+
"global_int_attribute": 42,
81+
"global_float_attribute": 7.0,
82+
}
83+
7884

7985
def make_grid_info(
8086
*, mxg=2, myg=2, nxpe=1, nype=1, ixseps1=None, ixseps2=None, xpoints=0
@@ -321,6 +327,9 @@ def createScalar(name, value):
321327
createScalar("PE_YIND", i // nxpe)
322328
createScalar("MYPE", i)
323329

330+
for attrname, attr in expected_file_attributes.items():
331+
setattr(outputfile, attrname, attr)
332+
324333
return result
325334

326335

boutdata/tests/test_collect.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66

77
from boutdata.collect import collect
88
from boutdata.squashoutput import squashoutput
9+
from boututils.datafile import DataFile
910

1011
from boutdata.tests.make_test_data import (
1112
apply_slices,
1213
create_dump_file,
1314
concatenate_data,
1415
expected_attributes,
16+
expected_file_attributes,
1517
make_grid_info,
1618
remove_xboundaries,
1719
remove_yboundaries,
@@ -57,7 +59,7 @@ def check_collected_data(
5759
Arrays should be global (not per-process).
5860
fieldperp_global_yind : int
5961
Global y-index where FieldPerps are expected to be defined.
60-
path : pathlib.Path or str
62+
path : pathlib.Path
6163
Path to collect data from.
6264
squash : bool
6365
If True, call `squashoutput()` and delete the `BOUT.dmp.*.nc` files (so that we
@@ -104,6 +106,14 @@ def check_collected_data(
104106
fieldperp_global_yind,
105107
)
106108

109+
if squash:
110+
filename = path.joinpath("boutdata.nc")
111+
else:
112+
filename = path.joinpath("BOUT.dmp.0.nc")
113+
with DataFile(str(filename)) as f:
114+
for attrname, attr in expected_file_attributes.items():
115+
assert f.read_file_attribute(attrname) == attr
116+
107117

108118
def check_variable(
109119
varname, actual, expected_data, expected_attributes, fieldperp_global_yind

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
"numpy",
4646
"matplotlib",
4747
"scipy",
48-
"boututils",
48+
"boututils>=0.1.9",
4949
"importlib-metadata ; python_version<'3.8'",
5050
],
5151
classifiers=[

0 commit comments

Comments
 (0)