
Conversation

maxreciprocate (Collaborator) commented Feb 6, 2023

cat-state (Collaborator) commented Feb 6, 2023

Thanks @reciprocated and @ZHAOTING. Does this increase memory usage, or are the model weights still being cast to fp16 somewhere else? I see in the report that the bf16 and fp16 runs use about the same amount of memory, but there's a run which seems to use more memory than the baseline:

[screenshot: memory usage comparison from the wandb report]

maxreciprocate (Collaborator, Author)

No, it doesn't increase the memory usage. For the baseline I accidentally took a prior bf16 run, while the run labeled as the fix is one without accelerate, so the difference here is between bf16 and fp32 (it's not possible to make a fp16 run on the main). Just in case, I made a few more runs with fp32, zero-fp32, zero-fp16, zero-bf16:
https://wandb.ai/sorry/trlx/reports/Set-deepspeed-s-fp16-auto_cast-to-false-279--VmlldzozNDk0OTMz
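
For context, the flag this PR changes is the auto_cast option inside DeepSpeed's fp16 config section. Here is a minimal sketch of what that fragment looks like with auto-casting disabled, written as a Python dict; the surrounding keys and how trlx assembles the full config are only illustrative assumptions:

# Sketch only: the fp16 section of a DeepSpeed config with auto_cast disabled.
# Other fp16 options (loss scaling, etc.) are omitted for brevity.
ds_config = {
    "fp16": {
        "enabled": True,      # train in fp16
        "auto_cast": False,   # don't let DeepSpeed automatically cast forward inputs to fp16
    },
}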

cat-state (Collaborator)

Ah okay, got it. This LGTM!

cat-state merged commit f5a7d78 into main on Feb 6, 2023
cat-state deleted the fix-ppo-fp16 branch on February 6, 2023 17:13
cat-state restored the fix-ppo-fp16 branch on February 6, 2023 17:13
Jiaxin-Wen commented Feb 8, 2023

> it's not possible to make a fp16 run on the main

Hi, I have some related problems.
Specifically, I am running the summarize_rlhf example. I observe that "mixed_precision" is set to "no" in summarize_rlhf/configs/default_accelerate_config.yaml, while "fp16" is enabled in summarize_rlhf/configs/ds_config_trlx_gptj_summarize.json.

First, without any modification, the code fails with the following error:

ValueError: When using `deepspeed_config_file`, the following accelerate config variables will be ignored: ['gradient_accumulation_steps', 'gradient_clipping', 'zero_stage', 'offload_optimizer_device', 'offload_param_device', 
'zero3_save_16bit_model', 'mixed_precision'].

And after I delete "mixed_precision" from summarize_rlhf/configs/default_accelerate_config.yaml, the code fails again with the following error:

/data/wenjiaxin/home/trlx/trlx/trainer/nn/ppo_models.py:220 in loss

    217                 value_loss=vf_loss.item(),
    218             ),
    219             values=dict(
❱   220                 get_tensor_stats(values, mask, n),
    221                 values_error=torch.sum(((values - returns) * mask) ** 2) / n,
    222                 clipfrac=vf_clipfrac,
    223             ),

/data/wenjiaxin/home/trlx/trlx/utils/modeling.py:242 in get_tensor_stats

    239     mean = (xs * mask).sum() / n
    240     return dict(
    241         mean=mean,
❱   242         min=torch.where(mask.bool(), xs, np.inf).min(),
    243         max=torch.where(mask.bool(), xs, -np.inf).max(),
    244         std=torch.sqrt(((xs - mean) * mask).pow(2).sum() / n),
    245     )

RuntimeError: expected scalar type c10::Half but found double
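
The failure seems to come down to mixing the fp16 values tensor with the plain float np.inf inside torch.where. A minimal sketch of the pattern that reproduces the error on my setup (this is my assumption about the root cause; only the np.inf argument is taken from the traceback above):

import numpy as np
import torch

# xs plays the role of `values`, which ends up in fp16 under the DeepSpeed fp16 config.
xs = torch.randn(4).half()
mask = torch.tensor([True, True, False, True])

# On my setup this raises "RuntimeError: expected scalar type c10::Half but found double",
# presumably because np.inf is handled as a double here; newer PyTorch releases seem to
# promote the scalar instead and keep the fp16 dtype.
print(torch.where(mask, xs, np.inf).min())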

maxreciprocate (Collaborator, Author)

Hi @XWwwwww! I'm able to reproduce the first part; it's related to the newest accelerate release, which started raising errors in this case. However, I'm not getting the runtime error you're seeing, maybe it's related to some particular version combination (at least that's what a quick Google search suggests). Are you having the same problem with other examples as well?

Jiaxin-Wen

Hi @reciprocated!
For the second problem, I fixed it as follows:

# replace the plain np.inf float with a tensor that matches the dtype and device of xs
tmp = np.inf * torch.ones_like(xs)
min=torch.where(mask.bool(), xs, tmp).min(),
max=torch.where(mask.bool(), xs, -tmp).max(),

and my torch version is 1.10.1
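
An equivalent variant of the same workaround, avoiding the extra ones_like allocation, would be to build the inf value as a 0-dim tensor with the dtype and device of xs. The helper name below is just for illustration and is not part of trlx:

import torch

def masked_min_max(xs: torch.Tensor, mask: torch.Tensor) -> dict:
    # Hypothetical variant of the fix above: materialize inf with xs's dtype and device,
    # so torch.where never has to mix a Half tensor with a double scalar.
    inf = torch.tensor(float("inf"), dtype=xs.dtype, device=xs.device)
    return dict(
        min=torch.where(mask.bool(), xs, inf).min(),
        max=torch.where(mask.bool(), xs, -inf).max(),
    )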

PhungVanDuy (Collaborator)

Can you try upgrading your PyTorch version? ref: #264 (comment)
