Skip to content

Commit ba06291

Browse files
authored
Added .devcontainer setup and updated versions (#59)
* dev container env * devcontainer working * include py311 and bump minor version * update torch to 2.5.0 * add cupy dependency * jax[cuda] working * tensor_split expects tensor_indices_or_sections to be on cpu * exclude py38 with torch25 Signed-off-by: theo-barfoot <[email protected]> --------- Signed-off-by: theo-barfoot <[email protected]>
1 parent 10a0ace commit ba06291

File tree

9 files changed

+129
-8
lines changed

9 files changed

+129
-8
lines changed

.devcontainer/Dockerfile

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Stage 1: NVIDIA CUDA Image
2+
ARG CUDA_VERSION=12.5.0
3+
FROM nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu22.04 AS cuda-base
4+
5+
# Stage 2: Miniconda setup from configuration
6+
FROM continuumio/miniconda3 AS miniconda-stage
7+
8+
# Stage 3: Final image combining CUDA and Miniconda
9+
FROM mcr.microsoft.com/devcontainers/base:ubuntu-22.04
10+
11+
# Copy from CUDA base
12+
COPY --from=cuda-base /usr/local/cuda /usr/local/cuda
13+
14+
# Copy Miniconda from the Miniconda stage
15+
COPY --from=miniconda-stage /opt/conda /opt/conda
16+
17+
# Set environment variables for Miniconda
18+
ENV PATH /opt/conda/bin:$PATH
19+
ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
20+
21+
# Install Python 3.10
22+
ARG PYTHON_VERSION=3.12
23+
RUN conda install python=${PYTHON_VERSION}
24+
25+
# Arguments for PyTorch and CUDA Toolkit versions
26+
ARG PYTORCH_VERSION=2.5.0
27+
ARG CUDATOOLKIT_VERSION=12.4
28+
29+
# Install PyTorch and other dependencies
30+
RUN conda install pytorch=${PYTORCH_VERSION} pytorch-cuda=${CUDATOOLKIT_VERSION} -c pytorch -c nvidia
31+
32+
# Handle environment.yml if it exists
33+
RUN echo env_change_20241021_2
34+
COPY environment.yml* noop.txt /tmp/conda-tmp/
35+
RUN if [ -f "/tmp/conda-tmp/environment.yml" ]; then \
36+
/opt/conda/bin/conda env update -n base -f /tmp/conda-tmp/environment.yml; \
37+
fi \
38+
&& rm -rf /tmp/conda-tmp
39+
40+
# Append Miniconda to PATH in .bashrc for interactive shells
41+
RUN echo ". /opt/conda/etc/profile.d/conda.sh" >> /root/.bashrc \
42+
&& echo "conda activate base" >> /root/.bashrc
43+
44+
# Final CMD or ENTRYPOINT
45+
CMD ["bash"]

.devcontainer/devcontainer.json

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
{
2+
"name": "torchsparsegradutils Dev Container",
3+
"build": {
4+
"dockerfile": "./Dockerfile",
5+
"context": ".",
6+
"args": {
7+
"CUDA_VERSION": "12.4.0",
8+
"PYTORCH_VERSION": "2.5.0",
9+
"CUDATOOLKIT_VERSION": "12.4",
10+
"PYTHON_VERSION": "3.12"
11+
}
12+
},
13+
"runArgs": [
14+
"--gpus",
15+
"all"
16+
],
17+
"remoteEnv": {
18+
"SSH_AUTH_SOCK": "/tmp/ssh-agent.sock"
19+
},
20+
"customizations": {
21+
"vscode": {
22+
"settings": {
23+
"python.defaultInterpreterPath": "/opt/conda/bin/python",
24+
"terminal.integrated.shell.linux": "/bin/bash",
25+
"terminal.integrated.env.linux": {
26+
"CONDA_DEFAULT_ENV": "base",
27+
"CONDA_PREFIX": "/opt/conda",
28+
"CONDA_PYTHON_EXE": "/opt/conda/bin/python",
29+
"PATH": "/opt/conda/bin:${env:PATH}"
30+
},
31+
"python.testing.pytestArgs": [
32+
"torchsparsegradutils/tests"
33+
],
34+
"python.testing.unittestEnabled": false,
35+
"python.testing.pytestEnabled": true
36+
},
37+
"extensions": [
38+
"dbaeumer.vscode-eslint",
39+
"ms-python.vscode-pylance",
40+
"ms-python.python",
41+
"github.copilot",
42+
"GitHub.vscode-pull-request-github",
43+
"GitHub.vscode-github-actions",
44+
"mhutchie.git-graph",
45+
"waderyan.gitblame"
46+
]
47+
}
48+
},
49+
"remoteUser": "vscode",
50+
"postCreateCommand": "echo 'Container is ready!'"
51+
}

.devcontainer/environment.yml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
name: base
2+
channels:
3+
- conda-forge
4+
- defaults
5+
dependencies:
6+
- numpy
7+
- cupy
8+
- scipy
9+
- pre-commit==3.7.1
10+
- black==24.4.2
11+
- flake8==7.1.0
12+
- parameterized==0.9.0
13+
- pytest==8.2.2
14+
- pytest-rerunfailures==14.0
15+
- pyyaml==6.0.1
16+
- conda-libmamba-solver
17+
- libmamba
18+
- libmambapy
19+
- libarchive
20+
- pip
21+
- pip:
22+
- "jax[cuda12]"

.devcontainer/noop.txt

Whitespace-only changes.

.github/workflows/python-package.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@ jobs:
1313
fail-fast: false
1414
matrix:
1515
python-version: ["3.8", "3.10", "3.12"]
16-
torch-version: ["1.13.1", "2.4.1"]
16+
torch-version: ["1.13.1", "2.5.0"]
1717
exclude:
1818
- python-version: "3.12"
1919
torch-version: "1.13.1"
20+
- python-version: "3.8"
21+
torch-version: "2.5.0"
2022

2123
steps:
2224
- uses: actions/checkout@v4

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
repos:
22
- repo: https://github.com/psf/black
3-
rev: 23.3.0
3+
rev: 24.4.2
44
hooks:
55
- id: black
66
language_version: python3.10
77

88
- repo: https://github.com/pycqa/flake8
9-
rev: 6.0.0
9+
rev: 7.1.0
1010
hooks:
1111
- id: flake8
1212

1313
- repo: https://github.com/pre-commit/pre-commit-hooks
14-
rev: v4.4.0
14+
rev: v4.6.0
1515
hooks:
1616
- id: trailing-whitespace

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.black]
22
line-length = 120
3-
target-version = ['py37', 'py38', 'py39', 'py310']
3+
target-version = ['py38', 'py39', 'py310', 'py311', 'py312']
44
include = '\.pyi?$'
55
exclude = '''
66
(
@@ -21,4 +21,4 @@ exclude = '''
2121
# also separately exclude other files if needed
2222
#| some_file
2323
)
24-
'''
24+
'''

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def readme():
88

99
setuptools.setup(
1010
name="torchsparsegradutils",
11-
version="0.1.2",
11+
version="0.1.3",
1212
description="A collection of utility functions to work with PyTorch sparse tensors",
1313
long_description=readme(),
1414
long_description_content_type="text/markdown",
@@ -18,6 +18,7 @@ def readme():
1818
"Programming Language :: Python :: 3.8",
1919
"Programming Language :: Python :: 3.9",
2020
"Programming Language :: Python :: 3.10",
21+
"Programming Language :: Python :: 3.11",
2122
"Programming Language :: Python :: 3.12",
2223
],
2324
python_requires=">=3.8",

torchsparsegradutils/indexed_matmul.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def segment_mm(a, b, seglen_a):
4545
if not a.shape[1] == D1 or not seglen_a.shape[0] == R:
4646
raise ValueError("Incompatible size for inputs")
4747

48-
segidx_a = torch.cumsum(seglen_a[:-1], dim=0)
48+
segidx_a = torch.cumsum(seglen_a[:-1], dim=0).cpu()
4949

5050
# Ideally the conversions below to nested tensor would be handled natively
5151
nested_a = torch.nested.as_nested_tensor(torch.tensor_split(a, segidx_a, dim=0))

0 commit comments

Comments
 (0)