Skip to content

Commit c5a2e89

Browse files
committed
trying to simplify
1 parent 3ef1337 commit c5a2e89

File tree

6 files changed

+65
-231
lines changed

6 files changed

+65
-231
lines changed

.github/workflows/nightly_build.yml

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,14 @@ jobs:
6060
run: |
6161
export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH"
6262
python3 -mpip install --pre torch --index-url https://download.pytorch.org/whl/nightly/${{ matrix.cuda_support[1] }}
63-
- name: Prepare nightly build
64-
env:
65-
TORCHRL_NIGHTLY: 1
66-
run: |
67-
export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH"
68-
python3 packaging/prepare_nightly_build.py
6963
- name: Build TorchRL Nightly
7064
env:
7165
TORCHRL_NIGHTLY: 1
7266
run: |
7367
rm -r dist || true
7468
export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH"
75-
python3 -mpip install wheel
76-
python3 setup.py bdist_wheel
69+
python3 -mpip install build wheel
70+
./build_nightly.sh
7771
find dist -name '*whl' -exec bash -c ' mv $0 ${0/linux/manylinux1}' {} \;
7872
# pytorch/pytorch binaries are also manylinux_2_17 compliant but they
7973
# pretend that they're manylinux1 compliant so we do the same.
@@ -233,20 +227,14 @@ jobs:
233227
shell: bash
234228
run: |
235229
python3 -mpip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
236-
- name: Prepare nightly build
237-
env:
238-
TORCHRL_NIGHTLY: 1
239-
shell: bash
240-
run: |
241-
python3 packaging/prepare_nightly_build.py
242230
- name: Build TorchRL nightly
243231
env:
244232
TORCHRL_NIGHTLY: 1
245233
shell: bash
246234
run: |
247235
rm -r dist || true
248-
python3 -mpip install wheel
249-
python3 setup.py bdist_wheel
236+
python3 -mpip install build wheel
237+
./build_nightly.sh
250238
- name: Upload wheel for the test-wheel job
251239
uses: actions/upload-artifact@v4
252240
with:

build_nightly.sh

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#!/bin/bash
2+
set -e
3+
4+
# Check if we're in a nightly build
5+
if [ "$TORCHRL_NIGHTLY" != "1" ]; then
6+
echo "Not a nightly build, exiting"
7+
exit 0
8+
fi
9+
10+
echo "Starting nightly build process..."
11+
12+
# Create backups of original files
13+
cp pyproject.toml pyproject.toml.backup
14+
cp version.txt version.txt.backup
15+
16+
# Function to restore original files
17+
restore_files() {
18+
echo "Restoring original files..."
19+
mv pyproject.toml.backup pyproject.toml
20+
mv version.txt.backup version.txt
21+
}
22+
23+
# Set up trap to restore files on exit (success or failure)
24+
trap restore_files EXIT
25+
26+
# Modify pyproject.toml for nightly build
27+
echo "Modifying pyproject.toml for nightly build..."
28+
sed -i.bak 's/name = "torchrl"/name = "torchrl-nightly"/' pyproject.toml
29+
30+
# Replace tensordict dependency with tensordict-nightly
31+
echo "Replacing tensordict with tensordict-nightly..."
32+
sed -i.bak 's/"tensordict[^"]*"/"tensordict-nightly"/g' pyproject.toml
33+
34+
# Clean up sed backup files
35+
rm -f pyproject.toml.bak
36+
37+
# Set nightly version (YYYY.MM.DD format)
38+
echo "Setting nightly version..."
39+
echo "$(date +%Y.%m.%d)" > version.txt
40+
41+
# Build the package
42+
echo "Building nightly package..."
43+
python -m build
44+
45+
echo "Nightly build completed successfully!"

packaging/prepare_nightly_build.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#!/usr/bin/env python3
22
import os
3+
import re
34
from pathlib import Path
45

56

@@ -20,9 +21,11 @@ def prepare_nightly_build():
2021
with open(pyproject_path) as f:
2122
content = f.read()
2223

23-
# Replace tensordict dependency with tensordict-nightly
24-
if "tensordict>=0.9.0,<0.10.0" in content:
25-
content = content.replace("tensordict>=0.9.0,<0.10.0", "tensordict-nightly")
24+
# Replace tensordict dependency with tensordict-nightly using regex
25+
# This pattern matches "tensordict" followed by any version constraints
26+
tensordict_pattern = r'tensordict[^,\]]*'
27+
if re.search(tensordict_pattern, content):
28+
content = re.sub(tensordict_pattern, "tensordict-nightly", content)
2629
print("Replaced tensordict with tensordict-nightly in pyproject.toml")
2730
else:
2831
print("tensordict dependency not found in pyproject.toml")

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ classifiers = [
2020
"Programming Language :: Python :: 3.10",
2121
"Programming Language :: Python :: 3.11",
2222
"Programming Language :: Python :: 3.12",
23+
"Programming Language :: Python :: 3.13",
2324
"Operating System :: OS Independent",
2425
"Development Status :: 4 - Beta",
2526
"Intended Audience :: Developers",

setup.py

Lines changed: 8 additions & 189 deletions
Original file line numberDiff line numberDiff line change
@@ -1,78 +1,11 @@
1-
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
#
3-
# This source code is licensed under the MIT license found in the
4-
# LICENSE file in the root directory of this source tree.
5-
import glob
6-
import logging
71
import os
8-
import shutil
92
import sys
3+
import glob
4+
import logging
105
from pathlib import Path
11-
12-
from setuptools import Command, setup
6+
from setuptools import setup
137
from torch.utils.cpp_extension import BuildExtension, CppExtension
148

15-
cwd = os.path.dirname(os.path.abspath(__file__))
16-
ROOT_DIR = Path(__file__).parent.resolve()
17-
18-
19-
def write_version_file(version):
20-
version_path = os.path.join(cwd, "torchrl", "version.py")
21-
logging.info(f"Writing version file to: {version_path}")
22-
logging.info(f"Version to write: {version}")
23-
24-
# Get PyTorch version during build
25-
try:
26-
import torch
27-
28-
pytorch_version = torch.__version__
29-
except ImportError:
30-
pytorch_version = "unknown"
31-
32-
# Get git sha
33-
try:
34-
import subprocess
35-
36-
sha = (
37-
subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd)
38-
.decode("ascii")
39-
.strip()
40-
)
41-
except Exception:
42-
sha = "Unknown"
43-
44-
with open(version_path, "w") as f:
45-
f.write(f"__version__ = '{version}'\n")
46-
f.write(f"git_version = {repr(sha)}\n")
47-
f.write(f"pytorch_version = '{pytorch_version}'\n")
48-
49-
logging.info("Version file written successfully")
50-
51-
52-
class clean(Command):
53-
user_options = []
54-
55-
def initialize_options(self):
56-
pass
57-
58-
def finalize_options(self):
59-
pass
60-
61-
def run(self):
62-
# Remove torchrl extension
63-
for path in (ROOT_DIR / "torchrl").glob("**/*.so"):
64-
logging.info(f"removing '{path}'")
65-
path.unlink()
66-
# Remove build directory
67-
build_dirs = [
68-
ROOT_DIR / "build",
69-
]
70-
for path in build_dirs:
71-
if path.exists():
72-
logging.info(f"removing '{path}' (and everything under it)")
73-
shutil.rmtree(str(path), ignore_errors=True)
74-
75-
769
def get_extensions():
7710
extension = CppExtension
7811

@@ -104,7 +37,7 @@ def get_extensions():
10437
cpp_files = glob.glob(os.path.join(extensions_dir, "*.cpp"))
10538
sources = [os.path.relpath(f) for f in cpp_files]
10639

107-
include_dirs = ["."]
40+
include_dirs = [".", "torchrl/csrc"]
10841
python_include_dir = os.getenv("PYTHON_INCLUDE_DIR")
10942
if python_include_dir is not None:
11043
include_dirs.append(python_include_dir)
@@ -120,127 +53,13 @@ def get_extensions():
12053

12154
return ext_modules
12255

123-
124-
def _main():
125-
# Always use "torchrl" as the project name for GitHub discovery
126-
# The version will be read from pyproject.toml
127-
128-
# Handle nightly builds
129-
is_nightly = (
130-
any("nightly" in arg for arg in sys.argv) or os.getenv("TORCHRL_NIGHTLY") == "1"
131-
)
132-
logging.info(f"is_nightly: {is_nightly}")
133-
134-
# Read version from version.txt
135-
version_txt = os.path.join(cwd, "version.txt")
136-
with open(version_txt) as f:
137-
base_version = f.readline().strip()
138-
139-
if os.getenv("TORCHRL_BUILD_VERSION"):
140-
version = os.getenv("TORCHRL_BUILD_VERSION")
141-
elif is_nightly:
142-
from datetime import date
143-
144-
today = date.today()
145-
version = f"{today.year}.{today.month}.{today.day}"
146-
logging.info(f"Using nightly version: {version}")
147-
# Update version.txt for nightly builds
148-
with open(version_txt, "w") as f:
149-
f.write(f"{version}\n")
150-
else:
151-
# For regular builds, append git hash for development versions
152-
try:
153-
import subprocess
154-
155-
git_sha = (
156-
subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd)
157-
.decode("ascii")
158-
.strip()[:7]
159-
)
160-
version = f"{base_version}+{git_sha}"
161-
logging.info(f"Using development version: {version}")
162-
except Exception:
163-
version = base_version
164-
logging.info(f"Using base version: {version}")
165-
166-
# Always write the version file to ensure it's up to date
167-
write_version_file(version)
168-
logging.info(f"Building torchrl-{version}")
169-
170-
# Verify the version file was written correctly
171-
try:
172-
with open(os.path.join(cwd, "torchrl", "version.py")) as f:
173-
content = f.read()
174-
if f"__version__ = '{version}'" in content:
175-
logging.info(f"Version file correctly contains: {version}")
176-
else:
177-
logging.error(
178-
f"Version file does not contain expected version: {version}"
179-
)
180-
except Exception as e:
181-
logging.error(f"Failed to verify version file: {e}")
182-
183-
# Handle package name for nightly builds
184-
if is_nightly:
185-
package_name = "torchrl-nightly" # Use torchrl-nightly for PyPI uploads
186-
else:
187-
package_name = "torchrl" # Use torchrl for regular builds and GitHub discovery
188-
56+
def main():
18957
setup_kwargs = {
190-
"name": package_name,
191-
# Only C++ extension configuration
19258
"ext_modules": get_extensions(),
193-
"cmdclass": {
194-
"build_ext": BuildExtension.with_options(),
195-
"clean": clean,
196-
},
197-
"zip_safe": False,
198-
"package_data": {
199-
"torchrl": ["version.py"],
200-
},
201-
"include_package_data": True,
202-
"packages": ["torchrl"],
59+
"cmdclass": {"build_ext": BuildExtension.with_options()},
20360
}
204-
205-
# Handle nightly tensordict dependency override
206-
if is_nightly:
207-
setup_kwargs["install_requires"] = [
208-
"torch>=2.1.0",
209-
"numpy",
210-
"packaging",
211-
"cloudpickle",
212-
"tensordict-nightly",
213-
]
214-
215-
# Override pyproject.toml settings for nightly builds
216-
if is_nightly:
217-
# Add all the metadata from pyproject.toml but override the name
218-
setup_kwargs.update(
219-
{
220-
"description": "A modular, primitive-first, python-first PyTorch library for Reinforcement Learning",
221-
"long_description": (Path(__file__).parent / "README.md").read_text(
222-
encoding="utf8"
223-
),
224-
"long_description_content_type": "text/markdown",
225-
"author": "torchrl contributors",
226-
"author_email": "[email protected]",
227-
"url": "https://github.com/pytorch/rl",
228-
"classifiers": [
229-
"Programming Language :: Python :: 3.9",
230-
"Programming Language :: Python :: 3.10",
231-
"Programming Language :: Python :: 3.11",
232-
"Programming Language :: Python :: 3.12",
233-
"Operating System :: OS Independent",
234-
"Development Status :: 4 - Beta",
235-
"Intended Audience :: Developers",
236-
"Intended Audience :: Science/Research",
237-
"Topic :: Scientific/Engineering :: Artificial Intelligence",
238-
],
239-
}
240-
)
241-
61+
24262
setup(**setup_kwargs)
24363

244-
24564
if __name__ == "__main__":
246-
_main()
65+
main()

torchrl/__init__.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,29 +25,7 @@
2525
try:
2626
from .version import __version__
2727
except ImportError:
28-
# Fallback: try to read version from version.txt and append git hash
29-
try:
30-
import os
31-
import subprocess
32-
33-
version_file = os.path.join(os.path.dirname(__file__), "..", "version.txt")
34-
with open(version_file) as f:
35-
base_version = f.read().strip()
36-
37-
# Try to get git hash
38-
try:
39-
git_sha = (
40-
subprocess.check_output(
41-
["git", "rev-parse", "HEAD"], cwd=os.path.dirname(version_file)
42-
)
43-
.decode("ascii")
44-
.strip()[:7]
45-
)
46-
__version__ = f"{base_version}+{git_sha}"
47-
except Exception:
48-
__version__ = base_version
49-
except Exception:
50-
__version__ = "unknown"
28+
__version__ = "0.0.0+unknown"
5129

5230
try:
5331
from torch.compiler import is_dynamo_compiling

0 commit comments

Comments
 (0)