
Very temporary fix to use LoRA with fp8 weight enabled #209

@wangziyao318

Description


This is only for those who want to lower VRAM usage a bit (to be able to use a larger batch size) when using TensorRT with sd webui, at the cost of some accuracy.

As far as I can tell, the fp8 optimization (currently available in the sd webui dev branch, under Settings/Optimizations) only slightly reduces VRAM usage when used together with TensorRT (from 10.9 GB to 9.6 GB to train a certain SDXL model, compared with 9.7 GB down to 6.8 GB without TensorRT), because the TensorRT side still stores data in fp16. VRAM usage would drop further if TensorRT also had an option to store data in fp8.

LoRA can't be converted to TensorRT under fp8 due to a dtype cast issue. Here's a very temporary and dirty fix (for the dev branch) to get it working.

In model_helper.py, line 178:

wt = wt.cpu().detach().half().numpy().astype(np.float16)  # .half() upcasts fp8 to fp16 before the numpy conversion

In exporter.py, lines 80 and 82:

wt_hash = hash(wt.cpu().detach().half().numpy().astype(np.float16).data.tobytes())

delta = wt.half() - torch.tensor(onnx_data_mapping[initializer_name]).to(wt.device)  # upcast before subtracting the fp16 ONNX weights

The idea is to add .half() to cast the fp8 tensors to fp16 so they can be used in calculations with other fp16 values. Also note that "Cache FP16 weight for LoRA" in Settings/Optimizations doesn't work with this fix, so you need to apply a higher weight to the fp8 LoRA to achieve the same effect as the fp16 LoRA.
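
For context, here is a minimal standalone sketch of the cast issue the fix works around (assuming PyTorch 2.1+ with float8 support; wt and ref are illustrative stand-ins, not the extension's actual variables):

import torch
import numpy as np

# fp8 weight, as stored when the fp8 optimization is enabled
wt = torch.randn(4, 4).to(torch.float8_e4m3fn)

# wt.cpu().detach().numpy()  # fails: numpy has no fp8 dtype
wt_np = wt.cpu().detach().half().numpy().astype(np.float16)  # upcast to fp16 first, then export

ref = torch.randn(4, 4).half()  # fp16 reference weight, e.g. loaded from the ONNX data
# delta = wt - ref  # fails: PyTorch doesn't promote fp8 in binary ops with fp16
delta = wt.half() - ref  # explicit upcast makes the subtraction well-defined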

By the way, if you check out the sd webui dev branch, which uses cu121, you can switch the TensorRT package to 9.0.1.post12.dev4 or the newer 9.2.0.post12.dev5 builds for CUDA 12. (Building the wheel for 9.1.0.post12.dev4 failed on my PC, so I don't suggest it.) Make sure to modify install.py to update the version number (TensorRT still works even if you don't change it).
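
If it helps, this is roughly the kind of line to touch in the extension's install.py (a hypothetical sketch only: the actual file may pin the package differently, and the index URL and surrounding code here are assumptions):

import launch  # sd webui's launch helpers, available to extensions

if not launch.is_installed("tensorrt"):
    # bump the pinned wheel to a CUDA 12 (post12) build
    launch.run_pip(
        "install --pre --extra-index-url https://pypi.nvidia.com tensorrt==9.2.0.post12.dev5 --no-cache-dir",
        "tensorrt",
        live=True,
    )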
