Skip to content

Commit 0bcbdb6

Browse files
authored
Merge pull request #786 from ACEsuit/develop
make cueq optional dep and add special test
2 parents fca3022 + 140d250 commit 0bcbdb6

File tree

4 files changed

+31
-9
lines changed

4 files changed

+31
-9
lines changed

.github/workflows/unittest.yaml

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,46 @@ on:
55
branches: [main]
66

77
jobs:
8-
pytest-container:
8+
pytest-general:
99
runs-on: ubuntu-latest
10-
1110
steps:
1211
- uses: actions/checkout@v4
1312
- uses: actions/setup-python@v5
1413
with:
1514
python-version: "3.10"
1615
cache: "pip"
1716

18-
- name: Install requirements
17+
- name: Install requirements (general tests)
1918
run: |
2019
pip install -U pip
2120
pip install .[dev]
2221
23-
- name: Log installed environment
22+
- name: Log installed environment (general tests)
23+
run: |
24+
python3 -m pip freeze
25+
26+
- name: Run general unit tests
27+
run: |
28+
pytest tests --ignore=tests/test_cueq.py
29+
30+
pytest-cueq:
31+
runs-on: ubuntu-latest
32+
steps:
33+
- uses: actions/checkout@v4
34+
- uses: actions/setup-python@v5
35+
with:
36+
python-version: "3.10"
37+
cache: "pip"
38+
39+
- name: Install requirements (with cueq)
40+
run: |
41+
pip install -U pip
42+
pip install ".[dev, cueq]"
43+
44+
- name: Log installed environment (with cueq)
2445
run: |
2546
python3 -m pip freeze
2647
27-
- name: Run unit tests
48+
- name: Run cueq-specific tests
2849
run: |
29-
pytest tests
50+
pytest tests/test_cueq.py tests/test_calculator.py

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ install_requires =
2929
GitPython
3030
pyYAML
3131
tqdm
32-
cuequivariance-torch
3332
# for plotting:
3433
matplotlib
3534
pandas
@@ -60,5 +59,6 @@ dev =
6059
pytest-benchmark
6160
pylint
6261
schedulefree = schedulefree
62+
cueq = cuequivariance-torch
6363
cueq-cuda-11 = cuequivariance-ops-torch-cu11
6464
cueq-cuda-12 = cuequivariance-ops-torch-cu12

tests/test_calculator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,7 @@ def test_calculator_descriptor(fitting_configs, trained_equivariant_model):
586586
assert not np.allclose(desc, desc_rotated, atol=1e-6)
587587

588588

589+
@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed")
589590
def test_calculator_descriptor_cueq(fitting_configs, trained_equivariant_model_cueq):
590591
at = fitting_configs[2].copy()
591592
at_rotated = fitting_configs[2].copy()

tests/test_cueq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def test_bidirectional_conversion(
143143
loss_e3nn_back.backward()
144144

145145
# Compare gradients for all conversions
146-
tol = 1e-4 if default_dtype == torch.float32 else 1e-8
146+
tol = 1e-4 if default_dtype == torch.float32 else 1e-7
147147

148148
def print_gradient_diff(name1, p1, name2, p2, conv_type):
149149
if p1.grad is not None and p1.grad.shape == p2.grad.shape:
@@ -152,7 +152,7 @@ def print_gradient_diff(name1, p1, name2, p2, conv_type):
152152
print(
153153
f"{conv_type} - Parameter {name1}/{name2}, Max error: {error.max()}"
154154
)
155-
torch.testing.assert_close(p1.grad, p2.grad, atol=tol, rtol=1e-10)
155+
torch.testing.assert_close(p1.grad, p2.grad, atol=tol, rtol=tol)
156156

157157
# E3nn to CuEq gradients
158158
for (name_e3nn, p_e3nn), (name_cueq, p_cueq) in zip(

0 commit comments

Comments
 (0)