Dev #3982 (Merged)

64 commits:
c7fb5b1
SDNQ fix VAE quant
Disty0 Jun 5, 2025
e25890b
SDNQ INT8 matmul support for Conv2d
Disty0 Jun 5, 2025
ad2a4ad
Cleanup
Disty0 Jun 5, 2025
9b55ffe
SDNQ fix VAE x2
Disty0 Jun 5, 2025
1a00517
SDNQ FP8 matmul support for Conv2d
Disty0 Jun 5, 2025
8c03f78
Fix bias is None
Disty0 Jun 5, 2025
413cf54
Update changelog
Disty0 Jun 5, 2025
778ca04
Update wiki
Disty0 Jun 5, 2025
6bcd335
Update changelog and wiki
Disty0 Jun 5, 2025
976f0ba
Cleanup
Disty0 Jun 5, 2025
06fcc3c
SDNQ add quantized matmul support for Conv1d and Conv3d too
Disty0 Jun 5, 2025
9a54efd
Cleanup
Disty0 Jun 5, 2025
b5d6b57
Update PyTorch to 2.7.1
Disty0 Jun 6, 2025
2ccc76a
Increase medvram mode to 12 GB and update wiki
Disty0 Jun 6, 2025
7679028
Override CPU to use FP32 by default
Disty0 Jun 6, 2025
c039ba9
Fix Meissonic by adding multiple generator support
Disty0 Jun 6, 2025
089e437
Don't set attention processors with models outside of SD 1.5 and SDXL
Disty0 Jun 6, 2025
5624671
Fix PixArt Sigma Small and Large
Disty0 Jun 6, 2025
2f7aff5
Fix TAESD previews with PixArt
Disty0 Jun 6, 2025
8e08ef0
Fix VAE Tiling with non-default tile sizes
Disty0 Jun 6, 2025
92d2379
Relax Python version check with Zluda
Disty0 Jun 9, 2025
bd2d9d1
Python 3.13 support
Disty0 Jun 9, 2025
92dbf39
Update changelog
Disty0 Jun 9, 2025
58b646e
SDNQ add 5-bit and 3-bit quantization support
Disty0 Jun 9, 2025
5eed913
Split SDNQ into multiple files and linting
Disty0 Jun 10, 2025
5bd7a08
don't use inplace ops in quant layer
Disty0 Jun 10, 2025
33fadf9
SDNQ add 7 bit support
Disty0 Jun 10, 2025
f5b575d
Update changelog and wiki
Disty0 Jun 10, 2025
d2ffee1
ROCm don't override user set HSA_OVERRIDE_GFX_VERSION
Disty0 Jun 10, 2025
a6b58ef
ROCm 6.4 support with --use-nightly
Disty0 Jun 10, 2025
7ccd94e
Force upgrade pip when installing Torch
Disty0 Jun 10, 2025
4436a58
Cleanup
Disty0 Jun 10, 2025
78f99ab
SDNQ use group_size / 2 for convs
Disty0 Jun 10, 2025
c81b712
Make VAE options not require model reload
Disty0 Jun 10, 2025
64f49fb
ROCm log HSA_OVERRIDE_GFX_VERSION skip
Disty0 Jun 11, 2025
df6b13e
Don't set gfx override with RX 9000 and above
Disty0 Jun 11, 2025
fd0c5b0
Update changelog
Disty0 Jun 11, 2025
71be3c7
ROCm don't override gfx with gfx1100 and gfx1101 + rocm 6.4
Disty0 Jun 11, 2025
6aa5c08
Cleanup and update changelog
Disty0 Jun 11, 2025
74b6edf
revert gfx1101
Disty0 Jun 11, 2025
5cefa64
SDNQ update accepted dtypes
Disty0 Jun 11, 2025
dd84fb5
Always set sdpa params
Disty0 Jun 11, 2025
26545b6
Add warning for incompatible attention processors
Disty0 Jun 11, 2025
2d05396
SDNQ simplify sym scale formula
Disty0 Jun 11, 2025
5e013fb
SDNQ optimize input quantization and use the word quantize instead of…
Disty0 Jun 12, 2025
41f14df
Fix TAESD and double downloading with Lumina2
Disty0 Jun 12, 2025
c8f9478
IPEX fix Lumina2
Disty0 Jun 12, 2025
cb4684c
SDNQ add separate quant mode option for Text Encoders
Disty0 Jun 13, 2025
e68f927
Disable custom attention processors for non-SD 1.5 / SDXL models
Disty0 Jun 13, 2025
1fca565
Cleanup
Disty0 Jun 13, 2025
fb7280c
Flux quanto fix logged dtype
Disty0 Jun 13, 2025
90e76b2
Cleanup
Disty0 Jun 13, 2025
45827a9
IPEX fix torch.cuda.set_device
Disty0 Jun 13, 2025
fb72c6f
Zluda use exact torch version
Disty0 Jun 13, 2025
2ba64ab
Cleanup
Disty0 Jun 13, 2025
8f8e5ce
Cleanup x2
Disty0 Jun 13, 2025
c01802d
SDNQ fix transformers llm
Disty0 Jun 13, 2025
fd58352
Update requirements
Disty0 Jun 14, 2025
2419420
Fix OmniGen
Disty0 Jun 14, 2025
25fc009
SDNQ use quantize_device and return_device args and fix decompress_fp…
Disty0 Jun 14, 2025
d31df8c
SDNQ fuse bias into dequantizer with matmul
Disty0 Jun 14, 2025
223a01d
Update changelog
Disty0 Jun 15, 2025
c307906
Update CHANGELOG.md
Disty0 Jun 15, 2025
c5b233b
Merge pull request #3981 from vladmandic/master
Disty0 Jun 15, 2025
42 changes: 41 additions & 1 deletion CHANGELOG.md
@@ -1,5 +1,46 @@
# Change Log for SD.Next

## Update for 2025-06-15

- **Feature**
- Support for Python 3.13

- **Changes**
- Increase the medvram mode threshold from 8GB to 12GB
- Set CPU backend to use FP32 by default
- Relax Python version checks for Zluda
- Make VAE options not require model reload
- Add warning about incompatible attention processors

- **Torch**
- Set default to `torch==2.7.1`
- Force upgrade pip when installing Torch

- **ROCm**
- Support ROCm 6.4 with `--use-nightly`
- Don't override user-set gfx version
- Don't override gfx version with RX 9000 and above
- Fix flash-attn repo

- **SDNQ Quantization**
- Add group size support for convolutional layers
- Add quantized matmul support for convolutional layers (see the sketch below)
- Add 7-bit, 5-bit and 3-bit quantization support
- Add separate quant mode option for Text Encoders
- Fix forced FP32 with tensorwise FP8 matmul
- Fix PyTorch <= 2.4 compatibility with FP8 matmul
- Fix VAE with conv quant
- Don't ignore the Quantize with GPU option with offload modes `none` and `model`
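
As a rough illustration of what group-size weight quantization for a convolutional layer involves, here is a minimal sketch; the function name, padding strategy, and symmetric scale formula are assumptions for illustration, not SDNQ's actual code (the commits indicate convs use half the usual group size):

```python
import torch
import torch.nn.functional as F

def quantize_conv_weight(w: torch.Tensor, group_size: int = 32, bits: int = 8):
    # w: (out_ch, in_ch, kH, kW); quantize each output channel in groups
    qmax = 2 ** (bits - 1) - 1  # 127 for 8-bit, 3 for 3-bit, etc.
    out_ch = w.shape[0]
    flat = w.reshape(out_ch, -1)
    pad = (-flat.shape[1]) % group_size  # pad so each row splits evenly into groups
    if pad:
        flat = F.pad(flat, (0, pad))
    groups = flat.reshape(out_ch, -1, group_size)
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / qmax
    q = torch.round(groups / scale).clamp_(-qmax - 1, qmax).to(torch.int8)
    return q, scale  # dequantize with q.float() * scale, then un-pad and reshape
```

Smaller groups cost more scale storage but track local weight statistics more closely, which matters most for the low-bit (3/5/7-bit) modes added here.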

- **Fixes**
- Meissonic with multiple generators
- OmniGen with new transformers
- Invalid attention processors
- PixArt Sigma Small and Large loading
- TAESD previews with PixArt and Lumina 2
- VAE Tiling with non-default tile sizes
- Lumina 2 with IPEX

## Update for 2025-06-02

### Highlights for 2025-06-02
@@ -31,7 +72,6 @@ Take a look at [Docs](https://github.com/vladmandic/sdnext/wiki/Docs), [Hints](h
- `INT4` -> `uint4`
- Add `float8_e4m3fn`, `float8_e5m2`, `float8_e4m3fnuz`, `float8_e5m2fnuz`, `int6`, `uint6`, `int2`, `uint2` and `uint1` support
- Add quantized matmul support for `float8_e4m3fn` and `float8_e5m2`
- Add group size support for convolutional layers
- Set the default quant mode to `pre`
- Use per token input quant with int8 and fp8 quantized matmul
- Implement better layer hijacks
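
The per-token input quantization mentioned in the entry above pairs one scale per token with per-output-channel weight scales; a minimal sketch of the idea (names and the exact scale formula are assumptions, not SDNQ's API):

```python
import torch

def quantize_per_token(x: torch.Tensor):
    # one symmetric scale per token (each row of the last dimension)
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.round(x / scale).clamp_(-128, 127).to(torch.int8)
    return x_q, scale

def int8_linear(x: torch.Tensor, w_q: torch.Tensor, w_scale: torch.Tensor):
    # w_q: (out, in) int8 weights, w_scale: (out, 1) per-channel scales
    x_q, x_scale = quantize_per_token(x)
    # the integer matmul is emulated in float here for portability;
    # a real kernel accumulates in int32 (e.g. torch._int_mm on CUDA)
    y = x_q.float() @ w_q.float().t()
    return y * x_scale * w_scale.t()
```
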
64 changes: 38 additions & 26 deletions installer.py
@@ -513,7 +513,7 @@ def get_platform():
def check_python(supported_minors=[], experimental_minors=[], reason=None):
if supported_minors is None or len(supported_minors) == 0:
supported_minors = [9, 10, 11, 12]
experimental_minors = []
experimental_minors = [13]
t_start = time.time()
if args.quick:
return
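
The hunk above adds Python 3.13 as an experimental minor; a stripped-down sketch of the gate it feeds (assumed shape only; the real check_python also takes per-backend reason strings, honors args.quick, and logs through the installer):

```python
import sys

SUPPORTED_MINORS = [9, 10, 11, 12]
EXPERIMENTAL_MINORS = [13]

def check_python_minor() -> None:
    minor = sys.version_info.minor
    if minor in EXPERIMENTAL_MINORS:
        print(f'warning: python 3.{minor} support is experimental')
    elif minor not in SUPPORTED_MINORS:
        raise RuntimeError(f'python 3.{minor} is not supported')
```
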
@@ -546,7 +546,7 @@ def check_diffusers():
t_start = time.time()
if args.skip_all or args.skip_git or args.experimental:
return
sha = '6508da6f06a0da1054ae6a808d0025c04b70f0e8' # diffusers commit hash
sha = '8adc6003ba4dbf5b61bb4f1ce571e9e55e145a99' # diffusers commit hash
pkg = pkg_resources.working_set.by_key.get('diffusers', None)
minor = int(pkg.version.split('.')[1] if pkg is not None else 0)
cur = opts.get('diffusers_version', '') if minor > 0 else ''
@@ -582,7 +582,7 @@ def install_cuda():
cmd = os.environ.get('TORCH_COMMAND', '--upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 --extra-index-url https://download.pytorch.org/whl/nightly/cu126')
else:
# cmd = os.environ.get('TORCH_COMMAND', 'torch==2.6.0+cu126 torchvision==0.21.0+cu126 --index-url https://download.pytorch.org/whl/cu126')
cmd = os.environ.get('TORCH_COMMAND', 'torch==2.7.0+cu128 torchvision==0.22.0+cu128 --index-url https://download.pytorch.org/whl/cu128')
cmd = os.environ.get('TORCH_COMMAND', 'torch==2.7.1+cu128 torchvision==0.22.1+cu128 --index-url https://download.pytorch.org/whl/cu128')
return cmd


@@ -633,7 +633,7 @@ def install_rocm_zluda():
log.info(msg)

if sys.platform == "win32": # TODO install: enable ROCm for windows when available
check_python(supported_minors=[10, 11], reason='ZLUDA backend requires Python 3.10 or 3.11')
#check_python(supported_minors=[9, 10, 11, 12, 13], reason='ZLUDA backend requires a Python version between 3.9 and 3.13')

if args.device_id is not None:
if os.environ.get('HIP_VISIBLE_DEVICES', None) is not None:
@@ -655,30 +655,32 @@ def install_rocm_zluda():
if error is None:
try:
zluda_installer.load()
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.7.0 torchvision --index-url https://download.pytorch.org/whl/cu118')
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.7.1+cu118 torchvision==0.22.1+cu118 --index-url https://download.pytorch.org/whl/cu118')
except Exception as e:
error = e
log.warning(f'Failed to load ZLUDA: {e}')
if error is not None:
log.info('Using CPU-only torch')
torch_command = os.environ.get('TORCH_COMMAND', 'torch torchvision')
else:
check_python(supported_minors=[9, 10, 11, 12], reason='ROCm backend requires a Python version between 3.9 and 3.12')
#check_python(supported_minors=[9, 10, 11, 12, 13], reason='ROCm backend requires a Python version between 3.9 and 3.13')

if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", None) is None:
os.environ.setdefault('TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL', '1')

if args.use_nightly:
if rocm.version is None or float(rocm.version) >= 6.3: # assume the latest if version check fails
if rocm.version is None or float(rocm.version) >= 6.4: # assume the latest if version check fails
torch_command = os.environ.get('TORCH_COMMAND', '--upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm6.4')
elif rocm.version == "6.3":
torch_command = os.environ.get('TORCH_COMMAND', '--upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm6.3')
else: # oldest rocm version on nightly is 6.2.4
torch_command = os.environ.get('TORCH_COMMAND', '--upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm6.2.4')
else:
if rocm.version is None or float(rocm.version) >= 6.3: # assume the latest if version check fails
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.7.0+rocm6.3 torchvision==0.22.0+rocm6.3 --index-url https://download.pytorch.org/whl/rocm6.3')
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.7.1+rocm6.3 torchvision==0.22.1+rocm6.3 --index-url https://download.pytorch.org/whl/rocm6.3')
elif rocm.version == "6.2":
# use rocm 6.2.4 instead of 6.2 as torch==2.7.0+rocm6.2 doesn't exist
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.7.0+rocm6.2.4 torchvision==0.22.0+rocm6.2.4 --index-url https://download.pytorch.org/whl/rocm6.2.4')
# use rocm 6.2.4 instead of 6.2 as torch==2.7.1+rocm6.2 doesn't exist
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.7.1+rocm6.2.4 torchvision==0.22.1+rocm6.2.4 --index-url https://download.pytorch.org/whl/rocm6.2.4')
elif rocm.version == "6.1":
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.6.0+rocm6.1 torchvision==0.21.0+rocm6.1 --index-url https://download.pytorch.org/whl/rocm6.1')
elif rocm.version == "6.0":
@@ -696,22 +698,21 @@ def install_rocm_zluda():
log.debug(f'ROCm hipBLASLt: arch={device.name} available={device.blaslt_supported}')
rocm.set_blaslt_enabled(device.blaslt_supported)

if device is None:
log.debug('ROCm: HSA_OVERRIDE_GFX_VERSION auto config skipped')
if device is None or os.environ.get("HSA_OVERRIDE_GFX_VERSION", None) is not None:
log.info(f'ROCm: HSA_OVERRIDE_GFX_VERSION auto config skipped: device={device.name if device is not None else None} version={os.environ.get("HSA_OVERRIDE_GFX_VERSION", None)}')
else:
gfx_ver = device.get_gfx_version()
if gfx_ver is not None:
os.environ.setdefault('HSA_OVERRIDE_GFX_VERSION', gfx_ver)
else:
log.warning(f'ROCm: device={device.name} could not auto-detect HSA version')
log.info(f'ROCm: HSA_OVERRIDE_GFX_VERSION config overridden: device={device.name} version={os.environ.get("HSA_OVERRIDE_GFX_VERSION", None)}')

ts('amd', t_start)
return torch_command
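
Condensed, the new override-skip logic amounts to the following standalone sketch (function name is illustrative):

```python
import os
from typing import Optional

def maybe_set_gfx_override(detected_gfx: Optional[str]) -> None:
    # only auto-configure when the variable is unset and a gfx version
    # was detected; a user-set value always wins, as in the hunk above
    if os.environ.get('HSA_OVERRIDE_GFX_VERSION') is not None or detected_gfx is None:
        return
    os.environ['HSA_OVERRIDE_GFX_VERSION'] = detected_gfx
```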


def install_ipex(torch_command):
def install_ipex():
t_start = time.time()
check_python(supported_minors=[9, 10, 11, 12], reason='IPEX backend requires a Python version between 3.9 and 3.12')
#check_python(supported_minors=[9, 10, 11, 12, 13], reason='IPEX backend requires a Python version between 3.9 and 3.13')
args.use_ipex = True # pylint: disable=attribute-defined-outside-init
log.info('IPEX: Intel OneAPI toolkit detected')

@@ -736,20 +737,20 @@ def install_ipex(torch_command):
if args.use_nightly:
torch_command = os.environ.get('TORCH_COMMAND', '--upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/xpu')
else:
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.7.0+xpu torchvision==0.22.0+xpu --index-url https://download.pytorch.org/whl/xpu')
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.7.1+xpu torchvision==0.22.1+xpu --index-url https://download.pytorch.org/whl/xpu')

ts('ipex', t_start)
return torch_command


def install_openvino(torch_command):
def install_openvino():
t_start = time.time()
check_python(supported_minors=[9, 10, 11, 12], reason='OpenVINO backend requires a Python version between 3.9 and 3.12')
#check_python(supported_minors=[9, 10, 11, 12, 13], reason='OpenVINO backend requires a Python version between 3.9 and 3.13')
log.info('OpenVINO: selected')
if sys.platform == 'darwin':
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.7.0 torchvision==0.22.0')
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.7.1 torchvision==0.22.1')
else:
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.7.0+cpu torchvision==0.22.0+cpu --index-url https://download.pytorch.org/whl/cpu')
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.7.1+cpu torchvision==0.22.1+cpu --index-url https://download.pytorch.org/whl/cpu')

install(os.environ.get('OPENVINO_COMMAND', 'openvino==2025.1.0'), 'openvino')
install(os.environ.get('NNCF_COMMAND', 'nncf==2.16.0'), 'nncf')
@@ -840,15 +841,15 @@ def check_torch():
elif is_rocm_available and (args.use_rocm or args.use_zluda): # prioritize rocm
torch_command = install_rocm_zluda()
elif allow_ipex and args.use_ipex: # prioritize ipex
torch_command = install_ipex(torch_command)
torch_command = install_ipex()
elif allow_openvino and args.use_openvino: # prioritize openvino
torch_command = install_openvino(torch_command)
torch_command = install_openvino()
elif is_cuda_available:
torch_command = install_cuda()
elif is_rocm_available:
torch_command = install_rocm_zluda()
elif is_ipex_available:
torch_command = install_ipex(torch_command)
torch_command = install_ipex()
else:
machine = platform.machine()
if sys.platform == 'darwin':
@@ -867,6 +868,7 @@ def check_torch():
if 'torch' in torch_command and not args.version:
if not installed('torch'):
log.info(f'Torch: download and install in progress... cmd="{torch_command}"')
install('--upgrade pip', 'pip', reinstall=True) # pytorch rocm is too large for older pip
install(torch_command, 'torch torchvision', quiet=True)
else:
try:
@@ -1155,8 +1157,8 @@ def update_setuptools():
def install_optional():
t_start = time.time()
log.info('Installing optional requirements...')
install('basicsr')
install('gfpgan')
install('git+https://github.com/Disty0/BasicSR@2b6a12c28e0c81bfb13b7e984144f0b0f5461484', 'basicsr')
install('git+https://github.com/Disty0/GFPGAN@09b1190eabbc77e5f15c61fa7c38a2064b403e20', 'gfpgan')
install('clean-fid')
install('pillow-jxl-plugin==1.3.3', ignore=True)
install('optimum-quanto==0.2.7', ignore=True)
@@ -1188,6 +1190,16 @@ def install_requirements():
pr.enable()
if args.skip_requirements and not args.requirements:
return
if int(sys.version_info.minor) >= 13:
install('audioop-lts')
# gcc 15 patch
backup_cmake_policy = os.environ.get('CMAKE_POLICY_VERSION_MINIMUM', None)
backup_cxxflags = os.environ.get('CXXFLAGS', None)
os.environ.setdefault('CMAKE_POLICY_VERSION_MINIMUM', '3.5')
os.environ.setdefault('CXXFLAGS', '-include cstdint')
install('git+https://github.com/google/sentencepiece#subdirectory=python', 'sentencepiece')
os.environ.setdefault('CMAKE_POLICY_VERSION_MINIMUM', backup_cmake_policy)
os.environ.setdefault('CXXFLAGS', backup_cxxflags)
if not installed('diffusers', quiet=True): # diffusers are not installed, so run initial installation
global quick_allowed # pylint: disable=global-statement
quick_allowed = False
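One note on the Python 3.13 block above: os.environ.setdefault never overwrites an existing key and raises TypeError when handed None, so a save/patch/restore cycle around a build is usually written with explicit assignment and removal, along these lines (a sketch; the install call in the comment reuses the repo's helper purely as an illustration):

```python
import os
from contextlib import contextmanager

@contextmanager
def patched_env(**overrides):
    # save current values, apply overrides, restore on exit
    backup = {key: os.environ.get(key) for key in overrides}
    os.environ.update(overrides)
    try:
        yield
    finally:
        for key, value in backup.items():
            if value is None:
                os.environ.pop(key, None)  # was unset before: unset again
            else:
                os.environ[key] = value

# hypothetical usage around the gcc 15 sentencepiece build:
# with patched_env(CMAKE_POLICY_VERSION_MINIMUM='3.5', CXXFLAGS='-include cstdint'):
#     install('git+https://github.com/google/sentencepiece#subdirectory=python', 'sentencepiece')
```
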
6 changes: 2 additions & 4 deletions modules/devices.py
@@ -331,7 +331,7 @@ def test_fp16():
if fp16_ok is not None:
return fp16_ok
if opts.cuda_dtype != 'FP16': # don't override if the user sets it
if sys.platform == "darwin" or backend == 'openvino': # override
if sys.platform == "darwin" or backend in {'openvino', 'cpu'}: # override
fp16_ok = False
return fp16_ok
elif backend == 'rocm':
@@ -362,7 +362,7 @@ def test_bf16():
if bf16_ok is not None:
return bf16_ok
if opts.cuda_dtype != 'BF16': # don't override if the user sets it
if sys.platform == "darwin" or backend == 'openvino' or backend == 'directml': # override
if sys.platform == "darwin" or backend in {'openvino', 'directml', 'cpu'}: # override
bf16_ok = False
return bf16_ok
elif backend == 'rocm' or backend == 'zluda':
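
test_fp16 and test_bf16 combine static backend rules like the ones changed above with runtime checks; a minimal version of such a probe might look like this (a sketch; the real tests exercise model-relevant ops and cache the result):

```python
import torch

def probe_dtype(device: str, dtype: torch.dtype) -> bool:
    # some backends advertise a dtype but fail or return garbage at
    # runtime, so run a tiny matmul and check the output is finite
    try:
        x = torch.ones((2, 2), device=device, dtype=dtype)
        return bool(torch.isfinite(x @ x).all())
    except Exception:
        return False

# e.g. probe_dtype('cpu', torch.bfloat16)
```
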
@@ -426,8 +426,6 @@ def override_ipex_math():

def set_sdpa_params():
try:
if opts.cross_attention_optimization != "Scaled-Dot-Product":
return
try:
global sdpa_original # pylint: disable=global-statement
if sdpa_original is not None:
1 change: 0 additions & 1 deletion modules/intel/ipex/__init__.py
@@ -39,7 +39,6 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.is_available = torch.xpu.is_available
torch.cuda.is_initialized = torch.xpu.is_initialized
torch.cuda.is_current_stream_capturing = lambda: False
torch.cuda.set_device = torch.xpu.set_device
torch.cuda.stream = torch.xpu.stream
torch.cuda.Event = torch.xpu.Event
torch.cuda.Stream = torch.xpu.Stream
32 changes: 32 additions & 0 deletions modules/intel/ipex/diffusers.py
@@ -81,14 +81,46 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
return emb


def apply_rotary_emb(x, freqs_cis, use_real: bool = True, use_real_unbind_dim: int = -1):
if use_real:
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)

if use_real_unbind_dim == -1:
# Used for flux, cogvideox, hunyuan-dit
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2:
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
else:
# used for lumina
# force cpu with Alchemist
x_rotated = torch.view_as_complex(x.to("cpu").float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.to("cpu").unsqueeze(2)
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
return x_out.type_as(x).to(x.device)


def ipex_diffusers(device_supports_fp64=False):
diffusers.utils.torch_utils.fourier_filter = fourier_filter
if not device_supports_fp64:
# get around lazy imports
from diffusers.models import embeddings as diffusers_embeddings # pylint: disable=import-error, unused-import # noqa: F401
from diffusers.models import transformers as diffusers_transformers # pylint: disable=import-error, unused-import # noqa: F401
from diffusers.models import controlnets as diffusers_controlnets # pylint: disable=import-error, unused-import # noqa: F401
diffusers.models.embeddings.get_1d_sincos_pos_embed_from_grid = get_1d_sincos_pos_embed_from_grid
diffusers.models.embeddings.FluxPosEmbed = FluxPosEmbed
diffusers.models.embeddings.apply_rotary_emb = apply_rotary_emb
diffusers.models.transformers.transformer_flux.FluxPosEmbed = FluxPosEmbed
diffusers.models.transformers.transformer_lumina2.apply_rotary_emb = apply_rotary_emb
diffusers.models.controlnets.controlnet_flux.FluxPosEmbed = FluxPosEmbed
diffusers.models.transformers.transformer_hidream_image.rope = hidream_rope
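
For reference, both branches of apply_rotary_emb compute the standard rotary position embedding,

```latex
\text{out} = x \odot \cos\theta + \operatorname{rot}(x) \odot \sin\theta,
\qquad \operatorname{rot}(x_{2i-1}, x_{2i}) = (-x_{2i},\; x_{2i-1})
```

with the rotated pairs taken from interleaved components (use_real_unbind_dim=-1) or from the two halves of the feature dimension (use_real_unbind_dim=-2); the use_real=False branch performs the same rotation as a complex multiply, moved to the CPU to work around the Alchemist issue noted in its comment.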
7 changes: 7 additions & 0 deletions modules/intel/ipex/hijacks.py
@@ -379,6 +379,12 @@ def torch_cuda_device(device):
else:
return torch.xpu.device(device)

@wraps(torch.cuda.set_device)
def torch_cuda_set_device(device):
if check_cuda(device):
torch.xpu.set_device(return_xpu(device))
else:
torch.xpu.set_device(device)

# torch.Generator has to be a class for isinstance checks
original_torch_Generator = torch.Generator
@@ -412,6 +418,7 @@ def ipex_hijacks():
torch.load = torch_load
torch.cuda.synchronize = torch_cuda_synchronize
torch.cuda.device = torch_cuda_device
torch.cuda.set_device = torch_cuda_set_device

torch.Generator = torch_Generator
torch._C.Generator = torch_Generator
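The set_device hijack added above follows the same @wraps redirect pattern used throughout hijacks.py; a self-contained sketch of the idiom, with the module's check_cuda and return_xpu helpers stubbed out (an XPU-enabled torch build is assumed for torch.xpu to exist):

```python
from functools import wraps
import torch

def check_cuda(device) -> bool:
    # stand-in for the module helper that spots 'cuda'-flavored device specs
    return device is not None and 'cuda' in str(device)

def return_xpu(device):
    # stand-in for the helper that rewrites 'cuda:N' to 'xpu:N'
    return str(device).replace('cuda', 'xpu')

original_set_device = torch.cuda.set_device

@wraps(original_set_device)
def torch_cuda_set_device(device):
    # route CUDA-style device requests to the XPU backend
    if check_cuda(device):
        torch.xpu.set_device(return_xpu(device))
    else:
        torch.xpu.set_device(device)

# applied once at startup, alongside the other redirects:
# torch.cuda.set_device = torch_cuda_set_device
```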