Merged

33 commits
c775dbc
init design
yiliu30 Nov 8, 2023
ca6981e
add safe api
Nov 10, 2023
f73de0d
add ut
Nov 10, 2023
559ca24
use safe api
Dec 7, 2023
3994cfb
fixed the save issue
Dec 21, 2023
0b36bc5
Merge branch 'master' into ds_pruner
yiliu30 Dec 21, 2023
d2db2a6
clean code
yiliu30 Dec 21, 2023
02a634d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 21, 2023
70a0dc0
clean code
yiliu30 Dec 21, 2023
227e9ef
fixed conflicts
yiliu30 Dec 21, 2023
3fcce31
merge with master
yiliu30 Dec 27, 2023
d4a9486
update the docs
yiliu30 Dec 27, 2023
1593d12
rename folder
yiliu30 Dec 27, 2023
a4b8619
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 27, 2023
2c1265c
add deepspeed check
yiliu30 Dec 27, 2023
268d11f
resolve conflicts
yiliu30 Dec 27, 2023
305a1f7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 27, 2023
c24948b
add USE_DEEPSPEED flag
yiliu30 Dec 27, 2023
d50bd23
Merge branch 'ly/ds_p2' of https://github.com/intel/neural-compressor…
yiliu30 Dec 27, 2023
e24bfdb
fixed the ds check
yiliu30 Dec 27, 2023
ebd001a
fixed format
yiliu30 Dec 27, 2023
71aba5c
remove useless code
yiliu30 Dec 27, 2023
957237e
remove the cpu offload
yiliu30 Jan 4, 2024
5938ede
Merge branch 'ly/ds_p2' of https://github.com/intel/neural-compressor…
yiliu30 Jan 4, 2024
22f5bf8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 4, 2024
031420d
disable coverage check for some ds-related code
yiliu30 Jan 4, 2024
f592fd6
add pragma no cover
yiliu30 Jan 5, 2024
ef399ac
fixed conflicts
yiliu30 Jan 5, 2024
dd27eee
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 5, 2024
a4aca42
remove unused code
yiliu30 Jan 5, 2024
24c3fcb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 5, 2024
7920381
Merge branch 'master' into ly/ds_p2
yiliu30 Jan 5, 2024
ed6395b
update the import path
yiliu30 Jan 10, 2024
@@ -1,7 +1,7 @@
Step-by-Step
============

# single GPU
# Single GPU

```
export CUDA_VISIBLE_DEVICES=0
@@ -15,10 +15,11 @@ bash run.sh \
--pruning_frequency=1000
```

# multi GPU
# Multi GPU

we use `accelerate` and `deepspeed ZeRO Stage-2` to conduct weight magnitude pruning
We use `accelerate` and `deepspeed ZeRO` to conduct weight magnitude and SNIP pruning. Below are two usage examples: 1) magnitude pruning with ZeRO Stage-2, and 2) SNIP pruning with ZeRO Stage-3.

## Magnitude pruning with ZeRO Stage-2
### Accelerate DeepSpeed Plugin

On your machine(s) just run:
@@ -105,3 +106,82 @@ bash run_ds.sh \
--pruning_pattern=4x1 \
--pruning_frequency=1000
```


## SNIP pruning with ZeRO Stage-3

To configure `accelerate` to use DeepSpeed ZeRO Stage-3, run the following on your machine(s):
``` shell
accelerate config

compute_environment: LOCAL_MACHINE
deepspeed_config:
deepspeed_config_file: config/zero_stage3_config.json
zero3_init_flag: true
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
use_cpu: false
```
with the contents of `config/zero_stage3_config.json` being:

```
{
"train_batch_size": 64,
"train_micro_batch_size_per_gpu": 8,
"gradient_accumulation_steps": 4,
"fp16": {
"enabled": true,
"min_loss_scale": 1,
"opt_level": "O2"
},
"zero_optimization": {
"stage": 3,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"contiguous_gradients": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"torch_adam": true,
"adam_w_mode": true
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": 0.0,
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto",
"warmup_type": "cosine"
}
}
}
```

### Pruning
> Note: Because ZeRO Stage-3 partitions all three model states (optimizer states, gradients, and parameters), please set `pruning_scope` to `local`. Choosing `global` would require gathering all parameters to update the mask, which defeats the memory benefits of ZeRO Stage-3.


```
# 2 gpu cards example
export CUDA_VISIBLE_DEVICES=0,1 USE_DEEPSPEED=1
bash run_ds_z3.sh \
--model_name_or_path=facebook/opt-125m \
--dataset_name=NeelNanda/pile-10k \
--block_size=128 \
--output_dir=./test-clm \
--pruning_type=snip_momentum \
--pruning_scope=local \
--pruning_pattern=4x1 \
--pruning_frequency=1000
```
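
For orientation, the sketch below shows how the new `--pruning_scope` flag flows into the pruner configuration inside the training script. The dict keys mirror the `pruning_configs` entry changed in the diff further down; the `WeightPruningConfig`/`prepare_compression` wrapper is an assumption about the INC 2.x training API, not a verbatim excerpt of the script.

```python
# Hedged sketch: wiring --pruning_scope into an INC pruning config.
# The dict keys match the pruning_configs entry in the diff below;
# WeightPruningConfig / prepare_compression are assumed INC 2.x APIs.
from neural_compressor.training import WeightPruningConfig, prepare_compression

pruning_configs = [
    {
        "pruning_type": "snip_momentum",
        "pruning_scope": "local",  # keep scores per layer; required under ZeRO Stage-3
        "sparsity_decay_type": "exp",
        "excluded_op_names": ["pooler"],
        "pruning_op_types": ["Linear"],
    }
]
configs = WeightPruningConfig(pruning_configs, target_sparsity=0.8, pruning_frequency=1000)
compression_manager = prepare_compression(model=model, confs=configs)  # model: the prepared torch module
compression_manager.callbacks.on_train_begin()
```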
@@ -0,0 +1,34 @@
{
"train_batch_size": 64,
"train_micro_batch_size_per_gpu": 8,
"gradient_accumulation_steps": 4,
"fp16": {
"enabled": true,
"min_loss_scale": 1,
"opt_level": "O2"
},
"zero_optimization": {
"stage": 3,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"contiguous_gradients": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"torch_adam": true,
"adam_w_mode": true
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": 0.0,
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto",
"warmup_type": "cosine"
}
}
}
@@ -274,6 +274,13 @@ def parse_args():
help="pruning criteria to use.",
choices=["magnitude", "snip", "snip_momentum"],
)
parser.add_argument(
"--pruning_scope",
type=str,
default="global",
help="determine layers' scores should be gather together to sort.",
choices=["local", "global"],
)
parser.add_argument(
"--warm_epochs",
type=int,
@@ -688,7 +695,7 @@ def group_texts(examples):
pruning_configs=[
{
"pruning_type": args.pruning_type,
"pruning_scope": "global",
"pruning_scope": args.pruning_scope,
"sparsity_decay_type": "exp",
"excluded_op_names": ["pooler"],
"pruning_op_types": ["Linear"],
@@ -800,7 +807,8 @@ def group_texts(examples):

if args.output_dir is not None:
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
# fetch the DeepSpeed-wrapped model from the INC model wrapper
unwrapped_model = accelerator.unwrap_model(model.model)
unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
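
The extra `.model` hop above exists because INC wraps the accelerate/DeepSpeed-prepared module inside its own compression wrapper, so `accelerator.unwrap_model` must be handed the inner module. Below is a hedged sketch of the full save path, reusing the script's `accelerator`, `model`, and `args`; the `state_dict=accelerator.get_state_dict(...)` argument is not part of this diff and is shown only as accelerate's documented way to gather ZeRO Stage-3 shards.

```python
# Sketch (assumes the script's context):
#   model        -> INC compression wrapper
#   model.model  -> the accelerate/DeepSpeed-prepared module
if args.output_dir is not None:
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model.model)
    unwrapped_model.save_pretrained(
        args.output_dir,
        is_main_process=accelerator.is_main_process,
        save_function=accelerator.save,
        # Under ZeRO Stage-3, parameters are sharded across ranks;
        # get_state_dict gathers the full weights before saving.
        state_dict=accelerator.get_state_dict(model.model),
    )
```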
@@ -0,0 +1,94 @@
#!/bin/bash
set -x

function main {

init_params "$@"
run_pruning

}

# init params
function init_params {
dataset_name="NeelNanda/pile-10k"
model_name_or_path="facebook/opt-125m"
output_dir="./test-clm"
per_device_train_batch_size=8
block_size=128
gradient_accumulation_steps=4
num_train_epochs=3
target_sparsity=0.8
pruning_type="snip_momentum"
pruning_scope="local"
pruning_pattern="4x1"
pruning_frequency=1000
for var in "$@"
do
case $var in
--dataset_name=*)
dataset_name=$(echo $var |cut -f2 -d=)
;;
--model_name_or_path=*)
model_name_or_path=$(echo $var |cut -f2 -d=)
;;
--output_dir=*)
output_dir=$(echo $var |cut -f2 -d=)
;;
--per_device_train_batch_size=*)
per_device_train_batch_size=$(echo $var |cut -f2 -d=)
;;
--block_size=*)
block_size=$(echo $var |cut -f2 -d=)
;;
--gradient_accumulation_steps=*)
gradient_accumulation_steps=$(echo $var |cut -f2 -d=)
;;
--num_train_epochs=*)
num_train_epochs=$(echo $var |cut -f2 -d=)
;;
--target_sparsity=*)
target_sparsity=$(echo $var |cut -f2 -d=)
;;
--pruning_type=*)
pruning_type=$(echo $var |cut -f2 -d=)
;;
--pruning_scope=*)
pruning_scope=$(echo $var |cut -f2 -d=)
;;
--pruning_pattern=*)
pruning_pattern=$(echo $var |cut -f2 -d=)
;;
--pruning_frequency=*)
pruning_frequency=$(echo $var |cut -f2 -d=)
;;
*)
echo "Error: No such parameter: ${var}"
exit 1
;;
esac
done

}

# run pruning
function run_pruning {
accelerate launch --deepspeed_config_file config/ds_config.json --mixed_precision fp16 \
run_clm_no_trainer_deepspeed.py \
--dataset_name $dataset_name \
--model_name_or_path $model_name_or_path \
--block_size $block_size \
--per_device_train_batch_size $per_device_train_batch_size \
--gradient_accumulation_steps $gradient_accumulation_steps \
--output_dir $output_dir \
--do_prune \
--num_train_epochs $num_train_epochs \
--target_sparsity $target_sparsity \
--pruning_type $pruning_type \
--pruning_scope $pruning_scope \
--pruning_pattern $pruning_pattern \
--pruning_frequency $pruning_frequency

}

main "$@"

31 changes: 21 additions & 10 deletions neural_compressor/compression/pruner/criteria.py
@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .utils import torch
from .utils import safe_get_data, safe_get_grad, safe_get_shape, torch

CRITERIA = {}

@@ -96,7 +96,8 @@ def on_step_begin(self):
"""Calculate and store the pruning scores based on a magnitude criterion."""
with torch.no_grad():
for key in self.modules.keys():
p = self.modules[key].weight.data
param = self.modules[key].weight
p = safe_get_data(param)
if hasattr(self.pattern, "reduce_score"):
self.scores[key] = self.pattern.reduce_score(torch.abs(p), key)
else:
@@ -161,12 +162,15 @@ def on_before_optimizer_step(self):
"""Calculate and store the pruning scores based on snip criterion."""
with torch.no_grad():
for key in self.modules.keys():
p = self.modules[key].weight
# p = self.modules[key].weight
param = self.modules[key].weight
data = safe_get_data(param)
grad = safe_get_grad(param)
# self.scores[key] = torch.abs(p * p.grad)
if hasattr(self.pattern, "reduce_score"):
self.scores[key] = self.pattern.reduce_score(torch.abs(p * p.grad), key)
self.scores[key] = self.pattern.reduce_score(torch.abs(data * grad), key)
else:
self.scores[key] = torch.abs(p * p.grad)
self.scores[key] = torch.abs(data * grad)


@register_criterion("snip_momentum")
@@ -191,15 +195,19 @@ def __init__(self, modules, config, pattern):
super(SnipMomentumCriterion, self).__init__(modules, config, pattern)
assert self.config.end_step > 0, "please set end_step > 0 for gradient based criterion"
for key in modules.keys():
p = modules[key].weight
param = modules[key].weight
# p = modules[key].weight
param_shape = safe_get_shape(param)
dtype = torch.float32
if self.low_memory_usage:
dtype = torch.bfloat16 if p.device.type == "cpu" else torch.float16
dtype = torch.bfloat16 if param.device.type == "cpu" else torch.float16
# self.scores[key] = torch.zeros(p.shape, dtype=dtype).to(p.device)
if hasattr(self.pattern, "reduce_score"):
self.scores[key] = self.pattern.reduce_score(torch.zeros(p.shape, dtype=dtype).to(p.device), key)
self.scores[key] = self.pattern.reduce_score(
torch.zeros(param_shape, dtype=dtype).to(param.device), key
)
else:
self.scores[key] = torch.zeros(p.shape, dtype=dtype).to(p.device)
self.scores[key] = torch.zeros(param_shape, dtype=dtype).to(param.device)

self.alpha = 0.9
self.beta = 1.0
@@ -209,8 +217,11 @@ def on_before_optimizer_step(self):
with torch.no_grad():
for key in self.modules.keys():
p = self.modules[key].weight
param = self.modules[key].weight
data = safe_get_data(param)
grad = safe_get_grad(param)
self.scores[key] *= self.alpha
tmp = torch.abs(p * p.grad)
tmp = torch.abs(data * grad)
if hasattr(self.pattern, "reduce_score"):
tmp = self.pattern.reduce_score(tmp, key, force=True)
if self.low_memory_usage:
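
All three criteria now read weights and gradients through the `safe_get_*` helpers instead of touching `param.data` / `param.grad` directly: under ZeRO Stage-3 each rank holds only a shard and `param.grad` may be `None`, while SNIP-style scores need the full tensors (the score is |w * grad(w)| per weight). A minimal sketch of what such helpers can look like, assuming DeepSpeed's public `safe_get_full_fp32_param` / `safe_get_full_grad` utilities and the `ds_shape` attribute that ZeRO Stage-3 attaches to partitioned parameters; the real implementation lives in `neural_compressor/compression/pruner/utils.py` and may differ:

```python
# Hedged sketch of the safe accessors; not the exact INC implementation.
import os

import torch

USE_DEEPSPEED = os.environ.get("USE_DEEPSPEED", "0") == "1"  # same flag run_ds_z3.sh exports


def safe_get_data(param: torch.nn.Parameter) -> torch.Tensor:
    """Full weight tensor, even when ZeRO Stage-3 keeps only a shard locally."""
    if USE_DEEPSPEED:
        from deepspeed.utils import safe_get_full_fp32_param

        return safe_get_full_fp32_param(param)
    return param.data


def safe_get_grad(param: torch.nn.Parameter) -> torch.Tensor:
    """Full gradient; under ZeRO, param.grad alone is a shard or None."""
    if USE_DEEPSPEED:
        from deepspeed.utils import safe_get_full_grad

        return safe_get_full_grad(param)
    return param.grad


def safe_get_shape(param: torch.nn.Parameter) -> torch.Size:
    """Original full shape; ZeRO Stage-3 stores it as param.ds_shape."""
    if USE_DEEPSPEED and hasattr(param, "ds_shape"):
        return param.ds_shape
    return param.shape
```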
24 changes: 17 additions & 7 deletions neural_compressor/compression/pruner/patterns/base.py
@@ -20,7 +20,7 @@

import numpy as np

from ..utils import tf, torch
from ..utils import safe_get_data, safe_get_grad, safe_get_shape, tf, torch

PATTERNS = {}

@@ -75,12 +75,18 @@ def _reshape_2dims_to_orig(data, orig_shape):
Returns:
Reshaped data.
"""
if len(orig_shape) == 4:
if len(orig_shape) == 2:
return data
elif len(orig_shape) == 4:
data = data.reshape(orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[1])
data = data.permute(0, 3, 1, 2)
if len(orig_shape) == 3:
elif len(orig_shape) == 3:
data = data.reshape(orig_shape[0], orig_shape[2], orig_shape[1])
data = data.permute(0, 2, 1)
elif len(orig_shape) == 1:
data = data.reshape(orig_shape)
else:
raise NotImplementedError(f"unsupported original shape: {orig_shape}")
return data

# some util functions which can be used.
@@ -601,12 +607,16 @@ def get_pattern_lock_masks(self, modules):
"""
pattern_lock_masks = {}
for key in modules.keys():
weight = modules[key].weight
shape = weight.shape
# weight = modules[key].weight
# shape = weight.shape
param = modules[key].weight
data = safe_get_data(param)
shape = safe_get_shape(param)
mask = torch.ones(shape)
mask[weight == 0] = 0.0
# mask[weight == 0] = 0.0
mask[data == 0] = 0.0
mask = mask.bool()
pattern_lock_masks[key] = mask.to(weight.device)
pattern_lock_masks[key] = mask.to(param.device)

return pattern_lock_masks

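
The new `len(orig_shape) == 2` early return and the explicit `elif` chain make `_reshape_2dims_to_orig` a total inverse of the pattern's 2-D flattening, now covering 1-D and 2-D tensors and raising on anything else. A small round-trip check for the 4-D (conv) case, assuming the forward flattening is `permute(0, 2, 3, 1)` followed by a reshape, which is what the inverse above implies:

```python
import torch

# Hypothetical round trip for a conv weight (out_c, in_c, kh, kw).
orig = torch.randn(8, 3, 5, 5)

# Forward flattening to 2-D for pattern scoring (inferred from the inverse above):
flat = orig.permute(0, 2, 3, 1).reshape(orig.shape[0], -1)  # (8, 75)

# _reshape_2dims_to_orig's 4-D branch, inlined:
restored = flat.reshape(8, 5, 5, 3).permute(0, 3, 1, 2)  # (8, 3, 5, 5)
assert torch.equal(orig, restored)
```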