Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions paddlenlp/peft/lora/lora_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
RowParallelLinear,
)

from ...transformers import linear_utils

try:
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
AllGatherOp,
Expand All @@ -33,15 +35,19 @@
mark_as_sequence_parallel_parameter,
)
except:
pass
AllGatherOp = None
ReduceScatterOp = None
mark_as_sequence_parallel_parameter = None
ColumnSequenceParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowSequenceParallelLinear = linear_utils.RowSequenceParallelLinear


from paddlenlp.transformers.mc2_parallel_linear import (
from ...transformers.mc2_parallel_linear import (
MC2ColumnParallelCoreLinear,
MC2ColumnSeqParallelCoreLinear,
MC2RowParallelCoreLinear,
MC2RowSeqParallelCoreLinear,
)

from .lora_quick_layers import quick_lora


Expand Down
21 changes: 9 additions & 12 deletions paddlenlp/peft/lora/lora_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,8 @@
PipelineLayer,
RowParallelLinear,
)
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
ColumnSequenceParallelLinear,
RowSequenceParallelLinear,
)

from ...transformers import linear_utils
from ...transformers.conversion_utils import ConversionMixin
from ...transformers.model_utils import (
PretrainedModel,
Expand Down Expand Up @@ -100,7 +97,7 @@ def get_lora_layers():
LoRALinear = lora_layers["LoRALinear"]
RowParallelLoRALinear = lora_layers["RowParallelLoRALinear"]
RowSequenceParallelLoRALinear = lora_layers["RowSequenceParallelLoRALinear"]
AVALIABLE_LAYERS = [
AVAILABLE_LAYERS = [
ColumnParallelLoRALinear,
ColumnSequenceParallelLoRALinear,
LoRAConv2D,
Expand All @@ -120,7 +117,7 @@ def get_lora_layers():
RowParallelQuantizationLoRALinear,
)

AVALIABLE_LAYERS += [
AVAILABLE_LAYERS += [
ColumnParallelQuantizationLoRALinear,
QuantizationLoRALinear,
RowParallelQuantizationLoRALinear,
Expand Down Expand Up @@ -510,7 +507,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=False)
self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
elif isinstance(module, ColumnSequenceParallelLinear):
elif isinstance(module, linear_utils.ColumnSequenceParallelLinear):
# recover the original output_features
output_features = module.weight.shape[1] * module.world_size
lora_module = ColumnSequenceParallelLoRALinear(
Expand All @@ -536,7 +533,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=True)
self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
elif isinstance(module, RowSequenceParallelLinear):
elif isinstance(module, linear_utils.RowSequenceParallelLinear):
# recover the original output_features
lora_module = RowSequenceParallelLoRALinear(
in_features=module.weight.shape[0] * module.world_size,
Expand Down Expand Up @@ -842,20 +839,20 @@ def save_to_aistudio(

def disable_lora(self):
for _, layer in self.model.named_sublayers():
if any(isinstance(layer, lora_layer) for lora_layer in AVALIABLE_LAYERS):
if any(isinstance(layer, lora_layer) for lora_layer in AVAILABLE_LAYERS):
layer.disable_lora = True

def enable_lora(self):
for _, layer in self.model.named_sublayers():
if any(isinstance(layer, lora_layer) for lora_layer in AVALIABLE_LAYERS):
if any(isinstance(layer, lora_layer) for lora_layer in AVAILABLE_LAYERS):
layer.disable_lora = False

def merge(self):
for _, layer in self.model.named_sublayers():
if any(isinstance(layer, lora_layer) for lora_layer in AVALIABLE_LAYERS):
if any(isinstance(layer, lora_layer) for lora_layer in AVAILABLE_LAYERS):
layer.merge()

def unmerge(self):
for _, layer in self.model.named_sublayers():
if any(isinstance(layer, lora_layer) for lora_layer in AVALIABLE_LAYERS):
if any(isinstance(layer, lora_layer) for lora_layer in AVAILABLE_LAYERS):
layer.unmerge()
45 changes: 23 additions & 22 deletions paddlenlp/transformers/gemma/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@

try:
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
ColumnSequenceParallelLinear,
GatherOp,
RowSequenceParallelLinear,
ScatterOp,
mark_as_sequence_parallel_parameter,
)
Expand All @@ -54,6 +52,8 @@
)
from paddlenlp.transformers.model_utils import PretrainedModel, register_base_model

from .. import linear_utils
from ..linear_utils import Linear
from ..segment_parallel_utils import ReshardLayer
from .configuration import (
GEMMA_PRETRAINED_INIT_CONFIGURATION,
Expand Down Expand Up @@ -422,11 +422,11 @@ def __init__(self, config):
self.tensor_parallel_degree = config.tensor_parallel_degree

if config.sequence_parallel:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear
else:
ColumnParallelLinear = mpu.ColumnParallelLinear
RowParallelLinear = mpu.RowParallelLinear
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear

if config.tensor_parallel_degree > 1:
self.gate_proj = ColumnParallelLinear(
Expand All @@ -448,9 +448,9 @@ def __init__(self, config):
has_bias=False,
)
else:
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False)

def forward(self, x):
# GeGLU
Expand Down Expand Up @@ -509,11 +509,11 @@ def __init__(self, config: GemmaConfig, layerwise_recompute: bool = False):
self.use_fused_rope = False

if config.sequence_parallel:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear
else:
ColumnParallelLinear = mpu.ColumnParallelLinear
RowParallelLinear = mpu.RowParallelLinear
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear

if config.tensor_parallel_degree > 1:
self.q_proj = ColumnParallelLinear(
Expand All @@ -537,29 +537,29 @@ def __init__(self, config: GemmaConfig, layerwise_recompute: bool = False):
gather_output=False,
)
else:
self.k_proj = nn.Linear(
self.k_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)
self.v_proj = nn.Linear(
self.v_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)

else:
self.q_proj = nn.Linear(
self.q_proj = Linear(
self.hidden_size,
self.config.num_attention_heads * self.head_dim,
bias_attr=False,
)
self.k_proj = nn.Linear(
self.k_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)
self.v_proj = nn.Linear(
self.v_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
Expand All @@ -573,7 +573,7 @@ def __init__(self, config: GemmaConfig, layerwise_recompute: bool = False):
input_is_parallel=True,
)
else:
self.o_proj = nn.Linear(
self.o_proj = Linear(
self.config.num_attention_heads * self.head_dim,
self.hidden_size,
bias_attr=False,
Expand Down Expand Up @@ -992,10 +992,11 @@ def _init_weights(self, layer):
nn.Linear,
nn.Embedding,
mpu.VocabParallelEmbedding,
mpu.ColumnParallelLinear,
mpu.RowParallelLinear,
ColumnSequenceParallelLinear,
RowSequenceParallelLinear,
mpu.ColumnParallelLinear,
linear_utils.RowSequenceParallelLinear,
linear_utils.ColumnSequenceParallelLinear,
GemmaLMHead,
),
):
# In the dygraph mode, use the `set_value` to reset the parameter directly,
Expand Down
52 changes: 26 additions & 26 deletions paddlenlp/transformers/gpt/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import numpy as np
import paddle
import paddle.distributed.fleet.meta_parallel as mpu
import paddle.incubate as incubate
import paddle.nn as nn
import paddle.nn.functional as F
Expand All @@ -32,9 +33,7 @@

try:
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
ColumnSequenceParallelLinear,
GatherOp,
RowSequenceParallelLinear,
ScatterOp,
mark_as_sequence_parallel_parameter,
)
Expand All @@ -45,7 +44,8 @@

from ...utils.converter import StateDictNameMapping
from ...utils.log import logger
from .. import PretrainedModel, register_base_model
from .. import PretrainedModel, linear_utils, register_base_model
from ..linear_utils import Linear
from ..model_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
Expand Down Expand Up @@ -210,11 +210,11 @@ def __init__(
self.num_attention_heads = config.num_attention_heads # default, without tensor parallel

if config.sequence_parallel:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear
else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear

if config.tensor_parallel_degree > 1:
assert config.num_attention_heads % config.tensor_parallel_degree == 0
Expand Down Expand Up @@ -262,13 +262,13 @@ def __init__(
)
else:
if self.config.fuse_attention_qkv:
self.qkv_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias_attr=True)
self.qkv_proj = Linear(config.hidden_size, 3 * config.hidden_size, bias_attr=True)
else:
self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias_attr=True)
self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias_attr=True)
self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias_attr=True)
self.q_proj = Linear(config.hidden_size, config.hidden_size, bias_attr=True)
self.k_proj = Linear(config.hidden_size, config.hidden_size, bias_attr=True)
self.v_proj = Linear(config.hidden_size, config.hidden_size, bias_attr=True)

self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias_attr=True)
self.out_proj = Linear(config.hidden_size, config.hidden_size, bias_attr=True)

def _fuse_prepare_qkv(self, query, use_cache=False, past_key_value=None):
if self.config.sequence_parallel:
Expand Down Expand Up @@ -583,11 +583,11 @@ def __init__(self, config: GPTConfig):
self.self_attn = MultiHeadAttention(config=config)

if config.sequence_parallel:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear
else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear

# TODO:config.fuse_attention_ffn @DrownFish19
if config.tensor_parallel_degree > 1:
Expand All @@ -607,8 +607,8 @@ def __init__(self, config: GPTConfig):
fuse_matmul_bias=self.config.use_fused_linear,
)
else:
self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size, bias_attr=True)
self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias_attr=True)
self.linear1 = Linear(config.hidden_size, config.intermediate_size, bias_attr=True)
self.linear2 = Linear(config.intermediate_size, config.hidden_size, bias_attr=True)

self.norm1 = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5)
self.norm2 = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5)
Expand Down Expand Up @@ -980,11 +980,11 @@ def _init_weights(self, layer):
(
nn.Linear,
nn.Embedding,
fleet.meta_parallel.VocabParallelEmbedding,
fleet.meta_parallel.ColumnParallelLinear,
fleet.meta_parallel.RowParallelLinear,
ColumnSequenceParallelLinear,
RowSequenceParallelLinear,
mpu.VocabParallelEmbedding,
mpu.RowParallelLinear,
mpu.ColumnParallelLinear,
linear_utils.RowSequenceParallelLinear,
linear_utils.ColumnSequenceParallelLinear,
),
):
# In the dygraph mode, use the `set_value` to reset the parameter directly,
Expand Down Expand Up @@ -1295,7 +1295,7 @@ def __init__(self, config):
super(GPTPretrainingCriterion, self).__init__()
self.config = config
if config.tensor_parallel_degree > 1 and config.tensor_parallel_output:
self.loss_func = fleet.meta_parallel.ParallelCrossEntropy(ignore_index=config.ignore_index)
self.loss_func = mpu.ParallelCrossEntropy(ignore_index=config.ignore_index)
else:
self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=config.ignore_index)

Expand Down Expand Up @@ -1660,7 +1660,7 @@ def __init__(self, config: GPTConfig):
self.gpt = GPTModel(config) # allow gpt to be config
dropout_p = config.hidden_dropout_prob if config.classifier_dropout is None else config.classifier_dropout
self.dropout = nn.Dropout(dropout_p)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.classifier = Linear(config.hidden_size, config.num_labels)

def forward(
self,
Expand Down Expand Up @@ -1774,7 +1774,7 @@ def __init__(self, config: GPTConfig):
super(GPTForSequenceClassification, self).__init__(config)
self.num_labels = config.num_labels
self.gpt = GPTModel(config)
self.score = nn.Linear(config.hidden_size, config.num_labels, bias_attr=False)
self.score = Linear(config.hidden_size, config.num_labels, bias_attr=False)

def forward(
self,
Expand Down
21 changes: 19 additions & 2 deletions paddlenlp/transformers/linear_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,25 @@
ColumnSequenceParallelLinear = sequence_parallel_utils.ColumnSequenceParallelLinear
RowSequenceParallelLinear = sequence_parallel_utils.RowSequenceParallelLinear
except:
ColumnSequenceParallelLinear = None
RowSequenceParallelLinear = None

class ColumnSequenceParallelLinearPass(object):
"""
A dummy class for ColumnSequenceParallelLinear, used when the actual class
cannot be imported from sequence_parallel_utils.
"""

pass

class RowSequenceParallelLinearPass(object):
"""
A dummy class for RowSequenceParallelLinear, used when the actual class
cannot be imported from sequence_parallel_utils.
"""

pass

ColumnSequenceParallelLinear = ColumnSequenceParallelLinearPass
RowSequenceParallelLinear = RowSequenceParallelLinearPass

if get_env_device() == "npu":
if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None:
Expand Down
Loading