|
6 | 6 | import torch.nn.functional
|
7 | 7 | from ase.build import molecule
|
8 | 8 | from e3nn import o3
|
| 9 | +from e3nn.util import jit |
9 | 10 | from scipy.spatial.transform import Rotation as R
|
10 | 11 |
|
11 | 12 | from mace import data, modules, tools
|
@@ -176,6 +177,33 @@ def test_multi_reference():
|
176 | 177 | )
|
177 | 178 |
|
178 | 179 |
|
| 180 | +@pytest.mark.parametrize( |
| 181 | + "calc", |
| 182 | + [ |
| 183 | + mace_mp(device="cpu", default_dtype="float64"), |
| 184 | + mace_mp(model="small", device="cpu", default_dtype="float64"), |
| 185 | + mace_mp(model="medium", device="cpu", default_dtype="float64"), |
| 186 | + mace_mp(model="large", device="cpu", default_dtype="float64"), |
| 187 | + mace_mp(model=MODEL_PATH, device="cpu", default_dtype="float64"), |
| 188 | + mace_off(model="small", device="cpu", default_dtype="float64"), |
| 189 | + mace_off(model="medium", device="cpu", default_dtype="float64"), |
| 190 | + mace_off(model="large", device="cpu", default_dtype="float64"), |
| 191 | + mace_off(model=MODEL_PATH, device="cpu", default_dtype="float64"), |
| 192 | + ], |
| 193 | +) |
| 194 | +def test_compile_foundation(calc): |
| 195 | + model = calc.models[0] |
| 196 | + atoms = molecule("CH4") |
| 197 | + atoms.positions += np.random.randn(*atoms.positions.shape) * 0.1 |
| 198 | + batch = calc._atoms_to_batch(atoms) |
| 199 | + output_1 = model(batch.to_dict()) |
| 200 | + model_compiled = jit.compile(model) |
| 201 | + output = model_compiled(batch.to_dict()) |
| 202 | + for key in output_1.keys(): |
| 203 | + if isinstance(output_1[key], torch.Tensor): |
| 204 | + assert torch.allclose(output_1[key], output[key], atol=1e-5) |
| 205 | + |
| 206 | + |
179 | 207 | @pytest.mark.parametrize(
|
180 | 208 | "model",
|
181 | 209 | [
|
|
0 commit comments