Skip to content

Conversation

rmitsch
Copy link
Collaborator

@rmitsch rmitsch commented Nov 9, 2023

Description

Improve HF device handling:

  • Consider device/device_map conflicts.
  • Always move inputs to model device for models using AutoModelForCausalLM.

Context/motivation: #324 (reply in thread), #324. This PR should help with both.

Corresponding documentation PR

-

Types of change

Checklist

  • I confirm that I have the right to submit this contribution under the project's MIT license.
  • I ran all tests in tests and usage_examples/tests, and all new and existing tests passed. This includes
    • all external tests (i. e. pytest ran with --external)
    • all tests requiring a GPU (i. e. pytest ran with --gpu)
  • My changes don't require a change to the documentation, or if they do, I've added all required information.

@rmitsch rmitsch added bug Something isn't working feat/model Feature: models labels Nov 9, 2023
@rmitsch rmitsch marked this pull request as draft November 9, 2023 16:08
@rmitsch rmitsch added the Test GPU Run GPU tests label Nov 10, 2023
@rmitsch rmitsch marked this pull request as ready for review November 10, 2023 09:28
@adrianeboyd
Copy link
Contributor

I do think it's a good idea to test with accelerate, at least in the future. It looks like that torch bug will be fixed in torch v2.1.1.

@rmitsch
Copy link
Collaborator Author

rmitsch commented Nov 13, 2023

I do think it's a good idea to test with accelerate, at least in the future. It looks like that torch bug will be fixed in torch v2.1.1.

I'll do that in a follow-up PR as I want this to be included in the upcoming v0.6.3 release (release today or tomorrow).

@svlandeg
Copy link
Contributor

There's no documentation update necessary after this?

@rmitsch
Copy link
Collaborator Author

rmitsch commented Nov 13, 2023

There's no documentation update necessary after this?

I wouldn't think so - this is part bugfix and part the ability to process torch_dtype in the config, which should have been there anyway, as we claim all HF args can be set and are passed on to the HF model. I'd say describing that in the release notes should be sufficient, but we can make this explicit in the docs as well.

@rmitsch rmitsch merged commit 7687d44 into main Nov 13, 2023
@svlandeg svlandeg deleted the fix/hf-device-handling branch November 13, 2023 16:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working feat/model Feature: models Test GPU Run GPU tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants