1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the MIT license found in the
4
- # LICENSE file in the root directory of this source tree.
5
- import glob
6
- import logging
7
1
import os
8
- import shutil
9
2
import sys
3
+ import glob
4
+ import logging
10
5
from pathlib import Path
11
-
12
- from setuptools import Command , setup
6
+ from setuptools import setup
13
7
from torch .utils .cpp_extension import BuildExtension , CppExtension
14
8
15
- cwd = os .path .dirname (os .path .abspath (__file__ ))
16
- ROOT_DIR = Path (__file__ ).parent .resolve ()
17
-
18
-
19
- def write_version_file (version ):
20
- version_path = os .path .join (cwd , "torchrl" , "version.py" )
21
- logging .info (f"Writing version file to: { version_path } " )
22
- logging .info (f"Version to write: { version } " )
23
-
24
- # Get PyTorch version during build
25
- try :
26
- import torch
27
-
28
- pytorch_version = torch .__version__
29
- except ImportError :
30
- pytorch_version = "unknown"
31
-
32
- # Get git sha
33
- try :
34
- import subprocess
35
-
36
- sha = (
37
- subprocess .check_output (["git" , "rev-parse" , "HEAD" ], cwd = cwd )
38
- .decode ("ascii" )
39
- .strip ()
40
- )
41
- except Exception :
42
- sha = "Unknown"
43
-
44
- with open (version_path , "w" ) as f :
45
- f .write (f"__version__ = '{ version } '\n " )
46
- f .write (f"git_version = { repr (sha )} \n " )
47
- f .write (f"pytorch_version = '{ pytorch_version } '\n " )
48
-
49
- logging .info ("Version file written successfully" )
50
-
51
-
52
- class clean (Command ):
53
- user_options = []
54
-
55
- def initialize_options (self ):
56
- pass
57
-
58
- def finalize_options (self ):
59
- pass
60
-
61
- def run (self ):
62
- # Remove torchrl extension
63
- for path in (ROOT_DIR / "torchrl" ).glob ("**/*.so" ):
64
- logging .info (f"removing '{ path } '" )
65
- path .unlink ()
66
- # Remove build directory
67
- build_dirs = [
68
- ROOT_DIR / "build" ,
69
- ]
70
- for path in build_dirs :
71
- if path .exists ():
72
- logging .info (f"removing '{ path } ' (and everything under it)" )
73
- shutil .rmtree (str (path ), ignore_errors = True )
74
-
75
-
76
9
def get_extensions ():
77
10
extension = CppExtension
78
11
@@ -104,7 +37,7 @@ def get_extensions():
104
37
cpp_files = glob .glob (os .path .join (extensions_dir , "*.cpp" ))
105
38
sources = [os .path .relpath (f ) for f in cpp_files ]
106
39
107
- include_dirs = ["." ]
40
+ include_dirs = ["." , "torchrl/csrc" ]
108
41
python_include_dir = os .getenv ("PYTHON_INCLUDE_DIR" )
109
42
if python_include_dir is not None :
110
43
include_dirs .append (python_include_dir )
@@ -120,127 +53,13 @@ def get_extensions():
120
53
121
54
return ext_modules
122
55
123
-
124
- def _main ():
125
- # Always use "torchrl" as the project name for GitHub discovery
126
- # The version will be read from pyproject.toml
127
-
128
- # Handle nightly builds
129
- is_nightly = (
130
- any ("nightly" in arg for arg in sys .argv ) or os .getenv ("TORCHRL_NIGHTLY" ) == "1"
131
- )
132
- logging .info (f"is_nightly: { is_nightly } " )
133
-
134
- # Read version from version.txt
135
- version_txt = os .path .join (cwd , "version.txt" )
136
- with open (version_txt ) as f :
137
- base_version = f .readline ().strip ()
138
-
139
- if os .getenv ("TORCHRL_BUILD_VERSION" ):
140
- version = os .getenv ("TORCHRL_BUILD_VERSION" )
141
- elif is_nightly :
142
- from datetime import date
143
-
144
- today = date .today ()
145
- version = f"{ today .year } .{ today .month } .{ today .day } "
146
- logging .info (f"Using nightly version: { version } " )
147
- # Update version.txt for nightly builds
148
- with open (version_txt , "w" ) as f :
149
- f .write (f"{ version } \n " )
150
- else :
151
- # For regular builds, append git hash for development versions
152
- try :
153
- import subprocess
154
-
155
- git_sha = (
156
- subprocess .check_output (["git" , "rev-parse" , "HEAD" ], cwd = cwd )
157
- .decode ("ascii" )
158
- .strip ()[:7 ]
159
- )
160
- version = f"{ base_version } +{ git_sha } "
161
- logging .info (f"Using development version: { version } " )
162
- except Exception :
163
- version = base_version
164
- logging .info (f"Using base version: { version } " )
165
-
166
- # Always write the version file to ensure it's up to date
167
- write_version_file (version )
168
- logging .info (f"Building torchrl-{ version } " )
169
-
170
- # Verify the version file was written correctly
171
- try :
172
- with open (os .path .join (cwd , "torchrl" , "version.py" )) as f :
173
- content = f .read ()
174
- if f"__version__ = '{ version } '" in content :
175
- logging .info (f"Version file correctly contains: { version } " )
176
- else :
177
- logging .error (
178
- f"Version file does not contain expected version: { version } "
179
- )
180
- except Exception as e :
181
- logging .error (f"Failed to verify version file: { e } " )
182
-
183
- # Handle package name for nightly builds
184
- if is_nightly :
185
- package_name = "torchrl-nightly" # Use torchrl-nightly for PyPI uploads
186
- else :
187
- package_name = "torchrl" # Use torchrl for regular builds and GitHub discovery
188
-
56
+ def main ():
189
57
setup_kwargs = {
190
- "name" : package_name ,
191
- # Only C++ extension configuration
192
58
"ext_modules" : get_extensions (),
193
- "cmdclass" : {
194
- "build_ext" : BuildExtension .with_options (),
195
- "clean" : clean ,
196
- },
197
- "zip_safe" : False ,
198
- "package_data" : {
199
- "torchrl" : ["version.py" ],
200
- },
201
- "include_package_data" : True ,
202
- "packages" : ["torchrl" ],
59
+ "cmdclass" : {"build_ext" : BuildExtension .with_options ()},
203
60
}
204
-
205
- # Handle nightly tensordict dependency override
206
- if is_nightly :
207
- setup_kwargs ["install_requires" ] = [
208
- "torch>=2.1.0" ,
209
- "numpy" ,
210
- "packaging" ,
211
- "cloudpickle" ,
212
- "tensordict-nightly" ,
213
- ]
214
-
215
- # Override pyproject.toml settings for nightly builds
216
- if is_nightly :
217
- # Add all the metadata from pyproject.toml but override the name
218
- setup_kwargs .update (
219
- {
220
- "description" : "A modular, primitive-first, python-first PyTorch library for Reinforcement Learning" ,
221
- "long_description" : (Path (__file__ ).parent / "README.md" ).read_text (
222
- encoding = "utf8"
223
- ),
224
- "long_description_content_type" : "text/markdown" ,
225
- "author" : "torchrl contributors" ,
226
- "author_email" :
"[email protected] " ,
227
- "url" : "https://github.com/pytorch/rl" ,
228
- "classifiers" : [
229
- "Programming Language :: Python :: 3.9" ,
230
- "Programming Language :: Python :: 3.10" ,
231
- "Programming Language :: Python :: 3.11" ,
232
- "Programming Language :: Python :: 3.12" ,
233
- "Operating System :: OS Independent" ,
234
- "Development Status :: 4 - Beta" ,
235
- "Intended Audience :: Developers" ,
236
- "Intended Audience :: Science/Research" ,
237
- "Topic :: Scientific/Engineering :: Artificial Intelligence" ,
238
- ],
239
- }
240
- )
241
-
61
+
242
62
setup (** setup_kwargs )
243
63
244
-
245
64
if __name__ == "__main__" :
246
- _main ()
65
+ main ()
0 commit comments