Skip to content

Commit fca3022

Browse files
authored
Merge pull request #785 from ACEsuit/develop
Develop
2 parents 6dce504 + d5e8a38 commit fca3022

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

mace/modules/radial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
131131

132132
@staticmethod
133133
def calculate_envelope(
134-
x: torch.Tensor, r_max: torch.Tensor, p: int
134+
x: torch.Tensor, r_max: torch.Tensor, p: torch.Tensor
135135
) -> torch.Tensor:
136136
r_over_r_max = x / r_max
137137
envelope = (

tests/test_foundations.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch.nn.functional
77
from ase.build import molecule
88
from e3nn import o3
9+
from e3nn.util import jit
910
from scipy.spatial.transform import Rotation as R
1011

1112
from mace import data, modules, tools
@@ -176,6 +177,33 @@ def test_multi_reference():
176177
)
177178

178179

180+
@pytest.mark.parametrize(
181+
"calc",
182+
[
183+
mace_mp(device="cpu", default_dtype="float64"),
184+
mace_mp(model="small", device="cpu", default_dtype="float64"),
185+
mace_mp(model="medium", device="cpu", default_dtype="float64"),
186+
mace_mp(model="large", device="cpu", default_dtype="float64"),
187+
mace_mp(model=MODEL_PATH, device="cpu", default_dtype="float64"),
188+
mace_off(model="small", device="cpu", default_dtype="float64"),
189+
mace_off(model="medium", device="cpu", default_dtype="float64"),
190+
mace_off(model="large", device="cpu", default_dtype="float64"),
191+
mace_off(model=MODEL_PATH, device="cpu", default_dtype="float64"),
192+
],
193+
)
194+
def test_compile_foundation(calc):
195+
model = calc.models[0]
196+
atoms = molecule("CH4")
197+
atoms.positions += np.random.randn(*atoms.positions.shape) * 0.1
198+
batch = calc._atoms_to_batch(atoms)
199+
output_1 = model(batch.to_dict())
200+
model_compiled = jit.compile(model)
201+
output = model_compiled(batch.to_dict())
202+
for key in output_1.keys():
203+
if isinstance(output_1[key], torch.Tensor):
204+
assert torch.allclose(output_1[key], output[key], atol=1e-5)
205+
206+
179207
@pytest.mark.parametrize(
180208
"model",
181209
[

0 commit comments

Comments
 (0)