171 changes: 76 additions & 95 deletions paddlenlp/transformers/qwen/modeling_3D_auto.py
@@ -19,9 +19,8 @@

import paddle
import paddle.distributed as dist
import paddle.distributed.fleet.meta_parallel as mpu
import paddle.nn.functional as F
from paddle import Tensor, nn
from paddle import nn
from paddle.distributed import fleet
from paddle.distributed.fleet.utils import recompute
from paddle.utils import try_import
@@ -55,6 +54,15 @@
except:
fused_rotary_position_embedding = None

try:
from paddle.incubate.nn.functional import swiglu
except ImportError:

def swiglu(x, y=None):
if y is None:
x, y = paddle.chunk(x, chunks=2, axis=-1)
return F.silu(x) * y
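# Note (sketch, not part of this file): assuming the fused
# paddle.incubate.nn.functional.swiglu has the same semantics as this fallback,
# both call paths compute F.silu(x) * y, and a single packed tensor is first
# split in half along the last axis:
#
#   x, y = paddle.randn([2, 8]), paddle.randn([2, 8])
#   swiglu(x, y)                            # == F.silu(x) * y
#   swiglu(paddle.concat([x, y], axis=-1))  # same result from the packed input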


def get_mesh(pp_idx=0):
mesh = fleet.auto.get_mesh()
@@ -63,31 +71,6 @@
return mesh


def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True):
is_fleet_init = True
tensor_parallel_degree = 1
try:
hcg = fleet.get_hybrid_communicate_group()
model_parallel_group = hcg.get_model_parallel_group()
tensor_parallel_degree = hcg.get_model_parallel_world_size()
except:
is_fleet_init = False

if is_fleet_init and tensor_parallel_degree > 1 and y.is_distributed:
# if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg'
input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group)
logits = paddle.matmul(input_parallel, y, transpose_y=False)

if tensor_parallel_output:
return logits

return paddle.distributed.collective._c_concat(logits, group=model_parallel_group)

else:
logits = paddle.matmul(x, y, transpose_y=False)
return logits


def get_triangle_upper_mask(x, mask=None):
if mask is not None:
return mask
@@ -148,6 +131,13 @@
global attention_cnt
self.attention_cnt = attention_cnt
attention_cnt += 1
self.c_attn.weight = dist.shard_tensor(
self.c_attn.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)]
)
self.c_attn.bias = dist.shard_tensor(self.c_attn.bias, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(0)])
self.c_proj.weight = dist.shard_tensor(
self.c_proj.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(0)]
)
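# Note (sketch, not part of this file): a reading of the placements above,
# assuming the per-stage mesh returned by get_mesh(ipp) is [dp, mp]. One
# placement is given per mesh axis: Replicate() on the data-parallel axis, while
# on the tensor-parallel axis Shard(1) splits c_attn.weight [hidden, 3*hidden]
# by output columns (column parallel) and Shard(0) splits c_proj.weight by input
# rows (row parallel), with c_attn.bias sharded along its only axis to match the
# column split.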

def _attn(self, query, key, value, attention_mask=None):
# Support the flash attention and normal attention
@@ -230,12 +220,15 @@
# # [bz, sql, hid] ==> [bz, sql, 3*hid]
mixed_x_layer = self.c_attn(hidden_states)
# [bz, sql, 3*hid] ==> [bz, sql, nh, 3*hdim] ==> 3 x [bz, sql, nh, hdim]
target_shape = [0, 0, self.num_heads, 3 * self.head_dim]
mixed_x_layer = paddle.reshape_(mixed_x_layer, target_shape)
query, key, value = paddle.split(mixed_x_layer, num_or_sections=3, axis=-1)

# [bz, sql, hid] ==> [bz, sql, nh, hdim]
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
# query = self._split_heads(query, self.num_heads, self.head_dim)
# key = self._split_heads(key, self.num_heads, self.head_dim)
# value = self._split_heads(value, self.num_heads, self.head_dim)

kv_seq_len = hidden_states.shape[1]
if layer_past:
@@ -310,18 +303,28 @@
def __init__(self, config, ipp=None):
super().__init__()
ff_dim_in = config.intermediate_size // 2
self.fuse_attention_ffn = config.fuse_attention_ffn
self.w1 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
self.w2 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias_attr=not config.no_bias)
self.ipp = ipp
self.w1.weight = dist.shard_tensor(self.w1.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)])
self.w2.weight = dist.shard_tensor(self.w2.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)])
self.c_proj.weight = dist.shard_tensor(
self.c_proj.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(0)]
)

def forward(self, hidden_states):
# up
a1 = self.w1(hidden_states)
# gate
a2 = self.w2(hidden_states)
intermediate_parallel = a1 * F.silu(a2)
# # up
# a1 = self.w1(hidden_states)
# # gate
# a2 = self.w2(hidden_states)
# intermediate_parallel = a1 * F.silu(a2)
# down
if self.fuse_attention_ffn:
intermediate_parallel = swiglu(self.gate_up_fused_proj(hidden_states))
else:
intermediate_parallel = swiglu(self.w2(hidden_states), self.w1(hidden_states))
output = self.c_proj(intermediate_parallel)
return output
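# Note (sketch, not part of this file): the unfused branch reproduces the
# commented-out code above, since swiglu(self.w2(h), self.w1(h)) ==
# F.silu(self.w2(h)) * self.w1(h) (gate through silu, multiplied by up). The
# fused branch assumes a gate_up_fused_proj layer (not defined in this hunk)
# that packs the gate and up projections into one matmul, with swiglu splitting
# its output in half along the last axis.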

@@ -330,9 +333,10 @@
def __init__(self, config, ipp=None, idx=None):
super().__init__()
self.config = config
self.ln_1 = QWenRMSNormAuto(config)
self.ipp = ipp
self.ln_1 = QWenRMSNormAuto(config, self.ipp)
self.attn = QWenAttentionAuto(config, ipp)
self.ln_2 = QWenRMSNormAuto(config)
self.ln_2 = QWenRMSNormAuto(config, self.ipp)
self.mlp = QWenMLPAuto(config, ipp)
self.ipp = ipp
self.idx = idx
@@ -349,7 +353,11 @@
output_attentions=False,
):
layernorm_output = self.ln_1(hidden_states)

attention_mask = (
dist.reshard(attention_mask, get_mesh(self.ipp), [dist.Shard(0), dist.Replicate()])
if attention_mask is not None
else attention_mask
)
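# Note (sketch, not part of this file): the mask is moved onto this block's
# pipeline-stage mesh before attention; Shard(0) splits it along the batch
# dimension to line up with the batch-sharded hidden_states, assuming the
# leading mesh axis is the data-parallel one.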
attn_outputs = self.attn(
layernorm_output,
layer_past=layer_past,
@@ -386,9 +394,6 @@
config_class = QWenConfig
base_model_prefix = "qwen"

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)

@classmethod
def _get_tensor_parallel_mappings(cls, config, is_split=True):

@@ -497,35 +502,6 @@
init_name_mappings(mappings)
return [StateDictNameMapping(*mapping) for mapping in mappings]

def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(
module,
(
nn.Linear,
nn.Embedding,
mpu.ColumnParallelLinear,
mpu.RowParallelLinear,
mpu.VocabParallelEmbedding,
QWenLMHeadAuto,
),
):
module.weight.set_value(
paddle.tensor.normal(mean=0.0, std=self.config.initializer_range, shape=module.weight.shape)
)
if getattr(module, "bias", None) is not None:
module.weight.set_value(paddle.zeros(shape=module.weight.shape, dtype=paddle.get_default_dtype()))

for name, p in module.named_parameters():
if name == "c_proj.weight":
p.set_value(
paddle.tensor.normal(
mean=0.0,
std=self.config.initializer_range / math.sqrt(2 * self.config.num_hidden_layers),
shape=p.shape,
)
)


class QWenModelAuto(QWenPretrainedModelAuto):
def __init__(self, config):
@@ -538,29 +514,32 @@
self.recompute_granularity = config.recompute_granularity

self.wte = nn.Embedding(self.vocab_size, self.embed_dim)

self.wte.weight = dist.shard_tensor(self.wte.weight, get_mesh(), [dist.Replicate(), dist.Shard(0)])
self.drop = nn.Dropout(config.emb_dropout_prob)

def get_layer_ipp(layer_index):
mesh = fleet.auto.get_mesh()
if "pp" not in mesh.dim_names:
return None
else:
pp_degree = mesh.get_dim_size("pp")
layer_per_stage = math.ceil(config.num_hidden_layers / pp_degree)
return layer_index // layer_per_stage

self.h = nn.LayerList(
[
QWenBlockAuto(
config,
get_layer_ipp(i),
self.get_layer_ipp(i),
i,
)
for i in range(config.num_hidden_layers)
]
)
self.ln_f = QWenRMSNormAuto(config)
self.ln_f = QWenRMSNormAuto(config, self.get_last_layer_ipp())

def get_layer_ipp(self, layer_index):
mesh = fleet.auto.get_mesh()
if "pp" not in mesh.dim_names:
return None
else:
pp_degree = mesh.get_dim_size("pp")
layer_per_stage = math.ceil(self.config.num_hidden_layers / pp_degree)
return layer_index // layer_per_stage

def get_last_layer_ipp(self):
return self.get_layer_ipp(self.config.num_hidden_layers - 1)
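# Note (sketch, not part of this file): a worked example of the mapping above.
# With num_hidden_layers=32 and a "pp" mesh dimension of size 4, layer_per_stage
# is ceil(32 / 4) = 8, so layers 0-7 land on stage 0, 8-15 on stage 1, 16-23 on
# stage 2, 24-31 on stage 3, and get_last_layer_ipp() returns 3.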

def get_input_embeddings(self):
return self.wte
@@ -668,7 +647,8 @@

encoder_attention_mask = None
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
with paddle.amp.auto_cast(False):
inputs_embeds = self.wte(input_ids)

hidden_states = inputs_embeds

@@ -681,7 +661,7 @@
neg_inf = paddle.full_like(attention_mask, paddle.finfo(paddle.bfloat16).min, dtype=paddle.bfloat16)
# dtype 4D mask
attention_mask = paddle.where(attention_mask, zero, neg_inf)

attention_mask = dist.shard_tensor(attention_mask, get_mesh(), [dist.Replicate(), dist.Replicate()])
hidden_states = self.drop(hidden_states)
hidden_states = dist.reshard(hidden_states, get_mesh(), [dist.Shard(0), dist.Replicate()])
output_shape = input_shape + [
@@ -718,7 +698,7 @@
attention_mask = dist.reshard(
attention_mask,
get_mesh(block.ipp),
[dist.Shard(0), dist.Replicate()],
[dist.Replicate(), dist.Replicate()],
)
if self.enable_recompute and self.training and has_gradient and self.recompute_granularity == "full":
outputs = self.recompute_training(
@@ -774,15 +754,16 @@


class QWenLMHeadAuto(nn.Layer):
def __init__(self, config: QWenConfig):
def __init__(self, config: QWenConfig, ipp=None):
super(QWenLMHeadAuto, self).__init__()
self.config = config
vocab_size = config.vocab_size

self.ipp = ipp
self.weight = self.create_parameter(
shape=[config.hidden_size, vocab_size],
dtype=paddle.get_default_dtype(),
)
self.weight = dist.shard_tensor(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)])
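# Note (sketch, not part of this file): the [hidden_size, vocab_size] weight is
# placed on the last pipeline stage's mesh, and Shard(1) splits it along the
# vocab dimension, i.e. a vocab-parallel output projection, assuming the second
# mesh axis is the tensor-parallel one.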

def forward(self, hidden_states, tensor_parallel_output=None):
if tensor_parallel_output is None:
@@ -835,7 +816,7 @@
def __init__(self, config):
super().__init__(config)
self.qwen = QWenModelAuto(config)
self.lm_head = QWenLMHeadAuto(config)
self.lm_head = QWenLMHeadAuto(config, self.qwen.get_last_layer_ipp())

def forward(
self,
@@ -898,7 +879,7 @@
self._ntk_alpha_cached = ntk_alpha
seq = paddle.arange(self._seq_len_cached)
with paddle.amp.auto_cast(enable=False):
freqs = paddle.outer(seq.astype(self.inv_freq.dtype), self.inv_freq)
freqs = paddle.outer(seq.astype(paddle.float32), self.inv_freq.astype(paddle.float32))
emb = paddle.concat([freqs, freqs], axis=-1)
self.cos_cached = emb.cos()[None, :, None, :]
self.sin_cached = emb.sin()[None, :, None, :]
@@ -940,7 +921,7 @@


class QWenRMSNormAuto(nn.Layer):
def __init__(self, config):
def __init__(self, config, ipp):
super().__init__()
self.config = config
self.eps = config.layer_norm_epsilon
@@ -949,14 +930,14 @@
dtype=paddle.get_default_dtype(),
default_initializer=nn.initializer.Constant(1.0),
)
self.weight = dist.shard_tensor(self.weight, get_mesh(ipp), [dist.Replicate(), dist.Replicate()])

def _norm(self, x):
return x * paddle.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x):
if self.config.use_fused_rms_norm:
return rms_norm_fused(x, self.weight, self.eps)
with paddle.amp.auto_cast(False):
variance = x.astype("float32").pow(2).mean(-1, keepdim=True)
output = paddle.rsqrt(variance + self.eps) * x

if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
output = paddle.cast(output, self.weight.dtype)
output = self._norm(x.astype(paddle.float32)).astype(x.dtype)
return output * self.weight
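# Note (sketch, not part of this file): the unfused path computes RMSNorm in
# float32 and casts back, i.e. output = x / sqrt(mean(x ** 2, axis=-1) + eps) * weight,
# which should match rms_norm_fused up to numerical precision.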