`rfcs/APIs/20200322_api_design_for_AdaptiveLogSoftmaxWithLoss.md` (141 changes: 85 additions & 56 deletions)

|API Name | paddle.nn.AdaptiveLogSoftmaxWithLoss |
|---|------------------------------------|
|Author | liethann |
|Date Submitted | 2023-10-20 |
|Version | V1.0 |
|Target PaddlePaddle Version | develop |
|File Name | 20200322_api_design_for_AdaptiveLogSoftmaxWithLoss.md |

> **Reviewer (Contributor):** Please also update the date in the filename.
Paddle needs to extend its API surface with a new AdaptiveLogSoftmaxWithLoss API,
exposed as `paddle.nn.AdaptiveLogSoftmaxWithLoss` and `paddle.nn.functional.adaptive_log_softmax_with_loss`,
to provide a fast approximation of softmax.

## 2. Feature Goals

> **Reviewer (Contributor):** Please write the feature goals: state exactly what this API will do, and add the computation formulas.

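A sketch of what this section could state, inferred from the rest of this RFC and from Grave et al., "Efficient softmax approximation for GPUs" (an assumption, not the authors' final wording): the vocabulary is partitioned into a frequent-word shortlist $V_h$ and tail clusters $V_1, \dots, V_K$, and the API computes each sample's target log-probability along with the mean negative log-likelihood loss:

```latex
% The probability of a word w factorizes through its cluster assignment:
\log p(w \mid x) =
\begin{cases}
  \log p_{\mathrm{head}}(w \mid x), & w \in V_h \\
  \log p_{\mathrm{head}}(C_k \mid x) + \log p_{\mathrm{tail}_k}(w \mid x), & w \in V_k
\end{cases}
\qquad
\mathcal{L} = -\frac{1}{N} \sum_{n=1}^{N} \log p(w_n \mid x_n)
```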
## 3. Significance

In natural language processing, when the vocabulary is very large, the embedding table takes up most of the model's parameters.
For example, in machine translation the vocabulary size is roughly 2^17; with an embedding dimension of 1024, the embedding alone already amounts to 2^27, about 134 million parameters, …
Efficient softmax approximation as described in …

2. Training
```python
def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput:
    # input_: [batch_size, in_features] with a 1-D target,
    # or an unbatched 1-D input with a 0-D target.
    targ_dim = target_.dim()

    if targ_dim == 1:
        if input_.size(0) != target_.size(0):
            raise RuntimeError('Input and target should have the same size '
                               'in the batch dimension.')
        if input_.dim() != 2:
            raise RuntimeError('1D target tensor expects 2D input tensors, '
                               'but found inputs with size', input_.size())
    elif targ_dim == 0:
        if input_.dim() != 1:
            raise RuntimeError('0D target tensor expects 1D input tensors, '
                               'but found inputs with size', input_.size())
    else:
        raise RuntimeError('0D or 1D target tensor expected, '
                           'multi-target not supported')

    is_batched = targ_dim > 0
    input = input_ if is_batched else input_.unsqueeze(0)
    target = target_ if is_batched else target_.unsqueeze(0)

    # Count the rows handled across all clusters; must equal batch_size.
    used_rows = 0
    batch_size = target.size(0)

    # Log-probability at each sample's target position.
    output = input.new_zeros(batch_size)
    # For each sample, the head-level class it maps to.
    gather_inds = target.new_empty(batch_size)

    cutoff_values = [0] + self.cutoffs
    for i in range(len(cutoff_values) - 1):
        low_idx = cutoff_values[i]
        high_idx = cutoff_values[i + 1]

        # Indices of the samples whose target falls in the current cluster.
        target_mask = (target >= low_idx) & (target < high_idx)
        row_indices = target_mask.nonzero().squeeze()

        # A cluster with no samples contributes no loss.
        if row_indices.numel() == 0:
            continue

        if i == 0:
            # Frequent (head) words: just record the target here; the head
            # prediction itself is computed later via self.head.
            gather_inds.index_copy_(0, row_indices, target[target_mask])
        else:
            # Infrequent (tail) words: position of the target inside its cluster.
            relative_target = target[target_mask] - low_idx
            # Inputs belonging to this cluster.
            input_subset = input.index_select(0, row_indices)
            # Linear projection -> [batch_size_i, cluster_size_i]
            cluster_output = self.tail[i - 1](input_subset)
            # Head-level class index corresponding to this cluster.
            cluster_index = self.shortlist_size + i - 1
            gather_inds.index_fill_(0, row_indices, cluster_index)
            # Within-cluster log-probabilities, picked at the target positions
            # and written back to the corresponding batch rows.
            cluster_logprob = log_softmax(cluster_output, dim=1)
            local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1))
            output.index_copy_(0, row_indices, local_logprob.squeeze(1))

        used_rows += row_indices.numel()

    if used_rows != batch_size:
        raise RuntimeError(f"Target values should be in [0, {self.n_classes - 1}], "
                           f"but values in range [{target.min().item()}, {target.max().item()}] "
                           "were found. ")

    # Every sample, frequent or not, needs the head projection,
    # so it is computed once here for the whole batch.
    head_output = self.head(input)
    head_logprob = log_softmax(head_output, dim=1)
    # Head samples need only the head log-prob; tail samples add the head
    # log-prob of their cluster class to the within-cluster log-prob
    # accumulated above, hence `output +=`.
    output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze()
    loss = (-output).mean()

    if not is_batched:
        output = output.squeeze(0)

    # Return a namedtuple of (output, loss).
    return _ASMoutput(output, loss)
```
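
For orientation, a minimal usage sketch of the PyTorch module whose `forward` is quoted above; the shapes and hyper-parameters are illustrative assumptions:

```python
import torch
import torch.nn as nn

# Illustrative sizes only: 64-dim hidden states, a 1000-word vocabulary
# split into a 100-word shortlist and two tail clusters.
asm = nn.AdaptiveLogSoftmaxWithLoss(in_features=64, n_classes=1000,
                                    cutoffs=[100, 500], div_value=4.0)
hidden = torch.randn(32, 64)              # [batch_size, in_features]
target = torch.randint(0, 1000, (32,))    # class indices in [0, n_classes)
out = asm(hidden, target)                 # namedtuple with .output and .loss
out.loss.backward()
```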


3. Prediction
```python
def predict(self, input: Tensor) -> Tensor:
    r""" This is equivalent to `self.log_prob(input).argmax(dim=1)`,
    but is more efficient in some cases.

    Args:
        input (Tensor): a minibatch of examples

    Returns:
        output (Tensor): a class with the highest probability for each example

    Shape:
        - Input: :math:`(N, \texttt{in\_features})`
        - Output: :math:`(N)`
    """
    # Head-level projection and prediction.
    head_output = self.head(input)
    output = torch.argmax(head_output, dim=1)
    # Samples whose predicted head class points into a tail cluster.
    not_in_shortlist = (output >= self.shortlist_size)
    all_in_shortlist = not (not_in_shortlist.any())

    if all_in_shortlist:
        # Every prediction is a frequent word: the head output is final.
        return output

    elif not_in_shortlist.all():
        # Every prediction is an infrequent word: compute the full
        # distribution and take the argmax.
        log_prob = self._get_full_log_prob(input, head_output)
        return torch.argmax(log_prob, dim=1)

    else:
        # Mixed batch: recompute only the rows that fell into tail clusters.
        log_prob = self._get_full_log_prob(input[not_in_shortlist],
                                           head_output[not_in_shortlist])
        output[not_in_shortlist] = torch.argmax(log_prob, dim=1)
        return output


# Combine the head log-prob of each cluster with the within-cluster log-probs.
def _get_full_log_prob(self, input, head_output):
    """ Given input tensor, and output of `self.head`,
    compute the log of the full distribution """
    # The body below the docstring was collapsed in the diff view; it is
    # restored here from the upstream PyTorch implementation.
    out = input.new_empty((head_output.size(0), self.n_classes))
    head_logprob = log_softmax(head_output, dim=1)

    out[:, :self.shortlist_size] = head_logprob[:, :self.shortlist_size]

    for i, (start_idx, stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])):
        cluster_output = self.tail[i](input)
        cluster_logprob = log_softmax(cluster_output, dim=1)
        output_logprob = cluster_logprob + head_logprob[:, self.shortlist_size + i].unsqueeze(1)

        out[:, start_idx:stop_idx] = output_logprob

    return out


def log_prob(self, input: Tensor) -> Tensor:
    r""" Computes log probabilities for all :math:`\texttt{n\_classes}`

    Args:
        input (Tensor): a minibatch of examples

    Returns:
        log-probabilities for each class :math:`c`
        in range :math:`0 <= c < \texttt{n\_classes}`, where :math:`\texttt{n\_classes}` is a
        parameter passed to ``AdaptiveLogSoftmaxWithLoss`` constructor.

    Shape:
        - Input: :math:`(N, \texttt{in\_features})`
        - Output: :math:`(N, \texttt{n\_classes})`
    """
    head_output = self.head(input)
    return self._get_full_log_prob(input, head_output)
```
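
A small consistency check implied by the docstring above, continuing the `asm` and `hidden` tensors from the earlier training sketch:

```python
# predict() must agree with the argmax over the full distribution.
full = asm.log_prob(hidden)          # [batch_size, n_classes]
pred = asm.predict(hidden)           # [batch_size]
assert torch.equal(pred, full.argmax(dim=1))
```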


# 4. Comparative Analysis

Apart from PyTorch, analyzed above, no other mainstream framework provides this API.

# 5. Design and Implementation

## Naming and Parameter Design

- Functional API: `paddle.nn.functional.adaptive_log_softmax_with_loss(input, target, head_weight, head_bias, tail_weights, cutoffs, shortlist_size)`, used for the training computation
> **Reviewer (Contributor):** Please document the meaning of each parameter.

- Layer API: `paddle.nn.AdaptiveLogSoftmaxWithLoss(in_features, n_classes, cutoffs, div_value=4.0, head_bias=False, name=None)`, with two main methods (a signature sketch with parameter meanings follows this list):
  - forward(self, input, target), used for training; returns `output` and `loss`
  - predict(self, input), used for inference
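
In response to the reviewer's request above, a hypothetical sketch of the proposed layer signature annotating what each parameter means; the semantics are assumed to mirror the PyTorch module analyzed above and are not a final docstring:

```python
import paddle

layer = paddle.nn.AdaptiveLogSoftmaxWithLoss(  # proposed API, not yet in paddle
    in_features=64,      # size of each input sample (hidden dimension)
    n_classes=1000,      # total number of classes (vocabulary size)
    cutoffs=[100, 500],  # ascending boundaries splitting shortlist / tail clusters
    div_value=4.0,       # factor shrinking each successive tail cluster's projection
    head_bias=False,     # whether the head linear layer adds a bias term
)
```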

## Operator Design

Implemented by composing existing APIs; no new standalone operator is designed.

## API Implementation

The computation logic follows the PyTorch implementation, recomposed and wrapped on top of Paddle APIs:
- Functional API: `paddle.nn.functional.adaptive_log_softmax_with_loss(input, target, head_weight, head_bias, tail_weights, cutoffs, shortlist_size)`, implemented by composing existing APIs (see the helper sketch after this list)

- Layer API: `paddle.nn.AdaptiveLogSoftmaxWithLoss(in_features, n_classes, cutoffs, div_value=4.0, head_bias=False, name=None)`, with two main methods:
  - forward(self, input, target), used for training; returns `output` and `loss`
  - predict(self, input), used for inference; it shares weights with forward but its computation differs, so it is implemented separately, again by composing existing APIs
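
As one concrete example of that recomposition (an illustrative sketch, not the final implementation), the `torch.gather` step that picks each sample's head log-probability can be written with `paddle.take_along_axis`, the replacement named in the schedule section below; `head_logprob_at` is a hypothetical helper name:

```python
import paddle
import paddle.nn.functional as F

def head_logprob_at(head_output, gather_inds):
    # Pick, for every sample, the head log-probability of the class stored
    # in gather_inds: the paddle counterpart of the torch.gather call in
    # the forward() walkthrough above.
    head_logprob = F.log_softmax(head_output, axis=-1)
    picked = paddle.take_along_axis(head_logprob,
                                    gather_inds.unsqueeze(1), axis=1)
    return picked.squeeze(1)
```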

# 6. Testing and Acceptance Criteria

The test cases considered are as follows: …


# 7. Feasibility Analysis and Schedule

The plan is composed mainly from existing Paddle APIs. Among its dependencies, the `index_fill` code has not yet been merged, so `index_fill` is implemented temporarily by following the pending `index_fill` PR and can be swapped out once that PR lands (a stand-in sketch follows); `paddle.gather` differs from `torch.gather`, so `paddle.take_along_axis` is used instead. Development is complete.
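
A minimal stand-in sketch for the not-yet-merged `index_fill`, composed from `paddle.scatter` on the assumption that only axis-0 filling with a scalar is needed (as in the `forward` walkthrough); this illustrates the workaround and is not the pending PR's code:

```python
import paddle

def index_fill(x, index, value):
    # Overwrite the rows of `x` selected by `index` (axis 0) with a constant.
    # paddle.scatter returns a new tensor, unlike torch's in-place index_fill_.
    updates = paddle.full([index.shape[0]] + list(x.shape[1:]), value, x.dtype)
    return paddle.scatter(x, index, updates, overwrite=True)

# e.g. replacing `gather_inds.index_fill_(0, row_indices, cluster_index)`:
# gather_inds = index_fill(gather_inds, row_indices, cluster_index)
```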

# 8. Impact

This is a standalone new API and has no impact on other modules.

# Appendix and References