Commit d785ada
Added simple workarounds for gather_mm and segment_mm (#57)
* Added simple workarounds for gather_mm and segment_mm. See #56
* Bumped Python and PyTorch versions in CI
* Enabled black on notebooks in CI
* Updated GitHub Actions to avoid deprecation warnings
1 parent 0f92297 commit d785ada

File tree

6 files changed: +237 −8 lines

.github/workflows/black.yml

Lines changed: 3 additions & 1 deletion

@@ -6,5 +6,7 @@ jobs:
   lint:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - uses: psf/black@stable
+        with:
+          jupyter: true

.github/workflows/python-package.yml

Lines changed: 12 additions & 5 deletions

@@ -12,21 +12,28 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.8", "3.9", "3.10"]
-        torch-version: ["1.13.1", "2.0.1"]
+        python-version: ["3.8", "3.10", "3.12"]
+        torch-version: ["1.13.1", "2.4.1"]
+        exclude:
+          - python-version: "3.12"
+            torch-version: "1.13.1"

     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v3
+        uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
           pip install torch==${{ matrix.torch-version }}
-          python -m pip install flake8 black
+          python -m pip install flake8 black[jupyter]
           if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
+      - name: numpy downgrade for pytorch 1.x
+        if: startsWith(matrix.torch-version, '1.')
+        run: |
+          pip install "numpy<2"
       - name: Lint check with flake8
         run: |
           # stop the build if there are Python syntax errors or undefined names

setup.py

Lines changed: 2 additions & 1 deletion

@@ -18,8 +18,9 @@ def readme():
         "Programming Language :: Python :: 3.8",
         "Programming Language :: Python :: 3.9",
         "Programming Language :: Python :: 3.10",
+        "Programming Language :: Python :: 3.12",
     ],
-    python_requires=">=3.8, <3.11",
+    python_requires=">=3.8",
     keywords="sparse torch utility",
     url="https://github.com/cai4cai/torchsparsegradutils",
     author="CAI4CAI research group",

torchsparsegradutils/__init__.py

Lines changed: 9 additions & 1 deletion

@@ -1,5 +1,13 @@
 from .sparse_matmul import sparse_mm
+from .indexed_matmul import gather_mm, segment_mm
 from .sparse_solve import sparse_triangular_solve, sparse_generic_solve
 from .sparse_lstsq import sparse_generic_lstsq

-__all__ = ["sparse_mm", "sparse_triangular_solve", "sparse_generic_solve", "sparse_generic_lstsq"]
+__all__ = [
+    "sparse_mm",
+    "gather_mm",
+    "segment_mm",
+    "sparse_triangular_solve",
+    "sparse_generic_solve",
+    "sparse_generic_lstsq",
+]

torchsparsegradutils/indexed_matmul.py (new file)

Lines changed: 117 additions & 0 deletions

@@ -0,0 +1,117 @@
+import torch
+
+try:
+    import dgl.ops as dglops
+
+    dgl_installed = True
+except ImportError:
+    dgl_installed = False
+
+
+def segment_mm(a, b, seglen_a):
+    """
+    Performs matrix multiplication according to segments.
+    See https://docs.dgl.ai/generated/dgl.ops.segment_mm.html
+
+    Suppose ``seglen_a == [10, 5, 0, 3]``, the operator will perform
+    four matrix multiplications::
+
+        a[0:10] @ b[0], a[10:15] @ b[1],
+        a[15:15] @ b[2], a[15:18] @ b[3]
+
+    Args:
+        a (torch.Tensor): The left operand, 2-D tensor of shape ``(N, D1)``
+        b (torch.Tensor): The right operand, 3-D tensor of shape ``(R, D1, D2)``
+        seglen_a (torch.Tensor): An integer tensor of shape ``(R,)``. Each element
+            is the length of a segment of input ``a``. The sum of all elements
+            must equal ``N``.
+
+    Returns:
+        torch.Tensor: The output dense matrix of shape ``(N, D2)``
+    """
+    if torch.__version__ < (2, 4):
+        raise NotImplementedError("PyTorch version is too old for nested tensors")
+
+    if dgl_installed:
+        # DGL is probably more computationally efficient
+        # See https://github.com/pytorch/pytorch/issues/136747
+        return dglops.segment_mm(a, b, seglen_a)
+
+    # Dependency-free fallback
+    if not a.dim() == 2 or not b.dim() == 3 or not seglen_a.dim() == 1:
+        raise ValueError("Input tensors have unexpected dimensions")
+
+    N, _ = a.shape
+    R, D1, D2 = b.shape
+
+    # Sanity check sizes
+    if not a.shape[1] == D1 or not seglen_a.shape[0] == R:
+        raise ValueError("Incompatible size for inputs")
+
+    segidx_a = torch.cumsum(seglen_a[:-1], dim=0)
+
+    # Ideally the conversions below to nested tensors would be handled natively
+    nested_a = torch.nested.as_nested_tensor(torch.tensor_split(a, segidx_a, dim=0))
+    nested_b = torch.nested.as_nested_tensor(list(map(torch.squeeze, torch.split(b, 1, dim=0))))
+
+    # The actual segment matmul computation
+    nested_ab = torch.matmul(nested_a, nested_b)
+
+    # Convert back to a regular tensor; again, ideally this would be handled natively
+    ab = torch.cat(nested_ab.unbind(), dim=0)
+    return ab
+
+
+def gather_mm(a, b, idx_b):
+    """
+    Gather data according to the given indices and perform matrix multiplication.
+    See https://docs.dgl.ai/generated/dgl.ops.gather_mm.html
+
+    Let the result tensor be ``c``; the operator conducts the following computation:
+
+        c[i] = a[i] @ b[idx_b[i]]
+
+    where ``len(c) == len(idx_b)``.
+
+    Args:
+        a (torch.Tensor): A 2-D tensor of shape ``(N, D1)``
+        b (torch.Tensor): A 3-D tensor of shape ``(R, D1, D2)``
+        idx_b (torch.Tensor): A 1-D integer tensor of shape ``(N,)``.
+
+    Returns:
+        torch.Tensor: The output dense matrix of shape ``(N, D2)``
+    """
+    if torch.__version__ < (2, 4):
+        raise NotImplementedError("PyTorch version is too old for nested tensors")
+
+    if dgl_installed:
+        # DGL is more computationally efficient
+        # See https://github.com/pytorch/pytorch/issues/136747
+        return dglops.gather_mm(a, b, idx_b)
+
+    # Dependency-free fallback
+    if not isinstance(a, torch.Tensor) or not isinstance(b, torch.Tensor) or not isinstance(idx_b, torch.Tensor):
+        raise ValueError("Inputs should be instances of torch.Tensor")
+
+    if not a.dim() == 2 or not b.dim() == 3 or not idx_b.dim() == 1:
+        raise ValueError("Input tensors have unexpected dimensions")
+
+    N = idx_b.shape[0]
+    R, D1, D2 = b.shape
+
+    # Sanity check sizes
+    if not a.shape[0] == N or not a.shape[1] == D1:
+        raise ValueError("Incompatible size for inputs")
+
+    torchdevice = a.device
+    src_idx = torch.arange(N, device=torchdevice)
+
+    # Ideally the conversions below to nested tensors would be handled without for loops and without copies
+    nested_a = torch.nested.as_nested_tensor([a[idx_b == i, :] for i in range(R)])
+    src_idx_reshuffled = torch.cat([src_idx[idx_b == i] for i in range(R)])
+    nested_b = torch.nested.as_nested_tensor([b[i, :, :].squeeze() for i in range(R)])
+
+    # The actual gather matmul computation
+    nested_ab = torch.matmul(nested_a, nested_b)
+
+    # Convert back to a regular tensor; again, ideally this would be handled natively with no copy
+    ab_segmented = torch.cat(nested_ab.unbind(), dim=0)
+    ab = torch.empty((N, D2), device=torchdevice, dtype=a.dtype)  # match input dtype
+    ab[src_idx_reshuffled] = ab_segmented
+    return ab
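
For readers of this commit, here is a minimal usage sketch of the two new functions. It is not part of the diff; it assumes PyTorch >= 2.4, and the shapes and segment lengths below are illustrative choices:

import torch
from torchsparsegradutils import gather_mm, segment_mm

N, R, D1, D2 = 18, 4, 5, 3
a = torch.randn(N, D1)
b = torch.randn(R, D1, D2)

# segment_mm: consecutive row segments of `a` are each multiplied by one matrix of `b`
seglen_a = torch.tensor([6, 5, 4, 3])  # R segment lengths summing to N
ab_seg = segment_mm(a, b, seglen_a)  # shape (N, D2)
assert torch.allclose(ab_seg[0:6], a[0:6] @ b[0], atol=1e-6, rtol=1e-4)

# gather_mm: row i of `a` is multiplied by b[idx_b[i]]
idx_b = torch.randint(0, R, (N,))
ab_gat = gather_mm(a, b, idx_b)  # shape (N, D2)
assert torch.allclose(ab_gat[3], a[3] @ b[idx_b[3]], atol=1e-6, rtol=1e-4)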

New test module for gather_mm and segment_mm

Lines changed: 94 additions & 0 deletions

@@ -0,0 +1,94 @@
+import torch
+import pytest
+
+if torch.__version__ < (2, 4):
+    pytest.skip(
+        "Skipping tests based on nested tensors since an old version of PyTorch is used", allow_module_level=True
+    )
+
+from torchsparsegradutils import gather_mm, segment_mm
+
+# Identify testing parameters
+DEVICES = [torch.device("cpu")]
+if torch.cuda.is_available():
+    DEVICES.append(torch.device("cuda"))
+
+TEST_DATA = [
+    # name,    N,  R, D1, D2
+    ("small", 100, 32, 7, 10),
+]
+
+INDEX_DTYPES = [torch.int32, torch.int64]
+VALUE_DTYPES = [torch.float32, torch.float64]
+
+ATOL = 1e-6  # relaxed tolerance to allow for float32
+RTOL = 1e-4
+
+
+# Define test names:
+def data_id(shapes):
+    return shapes[0]
+
+
+def device_id(device):
+    return str(device)
+
+
+def dtype_id(dtype):
+    return str(dtype).split(".")[-1]
+
+
+# Define fixtures
+
+
+@pytest.fixture(params=TEST_DATA, ids=[data_id(d) for d in TEST_DATA])
+def shapes(request):
+    return request.param
+
+
+@pytest.fixture(params=VALUE_DTYPES, ids=[dtype_id(d) for d in VALUE_DTYPES])
+def value_dtype(request):
+    return request.param
+
+
+@pytest.fixture(params=INDEX_DTYPES, ids=[dtype_id(d) for d in INDEX_DTYPES])
+def index_dtype(request):
+    return request.param
+
+
+@pytest.fixture(params=DEVICES, ids=[device_id(d) for d in DEVICES])
+def device(request):
+    return request.param
+
+
+# Define tests
+# NOTE: the value_dtype and index_dtype fixtures are requested but not yet
+# applied to the tensors generated below.
+
+
+def test_segment_mm(device, value_dtype, index_dtype, shapes):
+    _, N, R, D1, D2 = shapes
+
+    a = torch.randn((N, D1), device=device)
+    b = torch.randn((R, D1, D2), device=device)
+    seglen_a = torch.randint(low=1, high=int(N / R), size=(R,), device=device)
+    seglen_a[-1] = N - seglen_a[:-1].sum()
+
+    ab = segment_mm(a, b, seglen_a)
+
+    # Check each output row against the corresponding dense matmul
+    k = 0
+    for i in range(R):
+        for j in range(seglen_a[i]):
+            assert torch.allclose(ab[k, :].squeeze(), a[k, :].squeeze() @ b[i, :, :].squeeze(), atol=ATOL, rtol=RTOL)
+            k += 1
+
+
+def test_gather_mm(device, value_dtype, index_dtype, shapes):
+    _, N, R, D1, D2 = shapes
+
+    a = torch.randn((N, D1), device=device)
+    b = torch.randn((R, D1, D2), device=device)
+    idx_b = torch.randint(low=0, high=R, size=(N,), device=device)
+
+    ab = gather_mm(a, b, idx_b)
+
+    # Check each output row against the corresponding dense matmul
+    for i in range(N):
+        assert torch.allclose(ab[i, :].squeeze(), a[i, :].squeeze() @ b[idx_b[i], :, :].squeeze(), atol=ATOL, rtol=RTOL)
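
As a complement to the row-by-row loops in these tests, gather_mm also admits a one-line dense reference that could serve as a test oracle. A minimal sketch, not part of the commit, reusing the "small" shapes above:

import torch
from torchsparsegradutils import gather_mm

N, R, D1, D2 = 100, 32, 7, 10
a = torch.randn(N, D1)
b = torch.randn(R, D1, D2)
idx_b = torch.randint(0, R, (N,))

# Dense reference: b[idx_b] materialises an (N, D1, D2) tensor, so this is
# only practical as a correctness check, not as an efficient implementation.
expected = torch.einsum("nd,nde->ne", a, b[idx_b])
assert torch.allclose(gather_mm(a, b, idx_b), expected, atol=1e-6, rtol=1e-4)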
