@@ -91,3 +91,134 @@ def forward(self, outputs, targets, length=None):
91
91
predictions = F .log_softmax (predictions , axis = 1 )
92
92
loss = self .criterion (predictions , targets ) / targets .sum ()
93
93
return loss
94
+
95
+
96
class NCELoss(nn.Layer):
    """Noise Contrastive Estimation (NCE) loss function.

    Noise Contrastive Estimation is an approximation method used to work
    around the huge computational cost of a large softmax layer. The basic
    idea is to convert the prediction problem into a classification problem
    at the training stage. It has been proved that the two criteria converge
    to the same minimal point as long as the noise distribution is close
    enough to the real one.

    NCE bridges the gap between generative models and discriminative models,
    rather than simply speeding up the softmax layer. With NCE, almost
    anything can be turned into a posterior with less effort.

    Refs:
        NCE: http://www.cs.helsinki.fi/u/ahyvarin/papers/Gutmann10AISTATS.pdf
        Thanks: https://github.com/mingen-pan/easy-to-use-NCE-RNN-for-Pytorch/blob/master/nce.py

    Examples:
        Q = Q_from_tokens(output_dim)
        NCELoss(Q)
    """

    def __init__(self, Q, noise_ratio=100, Z_offset=9.5):
        """Store the prior model and the NCE hyper-parameters.

        Args:
            Q (tensor): prior (noise) model, uniform or Gaussian. The size of
                its first dimension is used as the number of classes ``N``.
            noise_ratio (int, optional): number of noise samples drawn per
                row of the batch. Defaults to 100.
            Z_offset (float, optional): constant subtracted from the scores
                before they are exponentiated. Defaults to 9.5.
        """
        super(NCELoss, self).__init__()
        # Exact-type check kept from the original: bool and numpy integer
        # values are rejected as well.
        assert type(noise_ratio) is int
        self.Q = paddle.to_tensor(Q, stop_gradient=False)
        self.N = self.Q.shape[0]
        self.K = noise_ratio
        self.Z_offset = Z_offset

    def forward(self, output, target):
        """Compute the mean NCE loss of a batch.

        Args:
            output (tensor): scores produced by the network; reshaped to
                ``[-1, N]`` before use.
            target (tensor): target class indices; reshaped to ``[-1, 1]``
                and placed in column 0 of the combined index matrix.

        Returns:
            tensor: mean of the per-row losses.
        """
        output = paddle.reshape(output, [-1, self.N])
        B = output.shape[0]
        noise_idx = self.get_noise(B)
        idx = self.get_combined_idx(target, noise_idx)
        P_target, P_noise = self.get_prob(idx, output, sep_target=True)
        Q_target, Q_noise = self.get_Q(idx)
        loss = self.nce_loss(P_target, P_noise, Q_noise, Q_target)
        return loss.mean()

    def get_Q(self, idx, sep_target=True):
        """Look up the prior model values for a matrix of class indices.

        Args:
            idx (tensor): 2-D index matrix (its two dimensions are used to
                reshape the result).
            sep_target (bool, optional): when True, return column 0 and the
                remaining columns separately. Defaults to True.

        Returns:
            tuple[tensor, tensor] | tensor: ``(values[:, 0], values[:, 1:])``
            when ``sep_target`` is True, otherwise the full matrix.
        """
        # The lookup goes through NumPy, so the result is a new tensor built
        # from a NumPy array rather than an op applied to ``self.Q``.
        flat_idx = paddle.reshape(idx, [-1]).numpy()
        prob_model = paddle.to_tensor(self.Q.numpy()[flat_idx])
        prob_model = paddle.reshape(prob_model, [idx.shape[0], idx.shape[1]])
        if sep_target:
            return prob_model[:, 0], prob_model[:, 1:]
        else:
            return prob_model

    def get_prob(self, idx, scores, sep_target=True):
        """Turn the selected network scores into unnormalised probabilities.

        The scores addressed by ``idx`` are shifted by ``-Z_offset`` and
        exponentiated.

        Args:
            idx (tensor): 2-D index matrix selecting one score per entry.
            scores (tensor): 2-D score matrix produced by the network.
            sep_target (bool, optional): when True, return column 0 and the
                remaining columns separately. Defaults to True.

        Returns:
            tuple[tensor, tensor] | tensor: ``(prob[:, 0], prob[:, 1:])``
            when ``sep_target`` is True, otherwise the full matrix.
        """
        scores = self.get_scores(idx, scores)
        scale = paddle.to_tensor([self.Z_offset], dtype='float32')
        scores = paddle.add(scores, -scale)
        prob = paddle.exp(scores)
        if sep_target:
            return prob[:, 0], prob[:, 1:]
        else:
            return prob

    def get_scores(self, idx, scores):
        """Gather, for every row, the scores addressed by ``idx``.

        Args:
            idx (tensor): index matrix whose second dimension ``K`` is the
                number of entries picked per row.
            scores (tensor): 2-D score matrix of shape ``[B, N]``.

        Returns:
            tensor: the selected scores, reshaped to ``[B, K]``.
        """
        B, N = scores.shape
        K = idx.shape[1]
        # Row b of the flattened score vector starts at offset b * N, so
        # adding that offset turns per-row indices into flat indices.
        idx_increment = paddle.to_tensor(
            N * paddle.reshape(paddle.arange(B), [B, 1]) * paddle.ones([1, K]),
            dtype="int64",
            stop_gradient=False)
        new_idx = idx_increment + idx
        new_scores = paddle.index_select(
            paddle.reshape(scores, [-1]), paddle.reshape(new_idx, [-1]))

        return paddle.reshape(new_scores, [B, K])

    def get_noise(self, batch_size, uniform=True):
        """Draw ``K`` noise class indices for every row of the batch.

        Args:
            batch_size (int): number of rows to draw noise for.
            uniform (bool, optional): when True, sample uniformly from
                ``[0, N)``; otherwise sample with replacement using the
                prior model ``Q`` as the probabilities. Defaults to True.

        Returns:
            tensor: int64 index matrix of shape ``[batch_size, K]``.
        """
        if uniform:
            noise = np.random.randint(self.N, size=self.K * batch_size)
        else:
            # NumPy needs array-like probabilities; hand it a real ndarray
            # instead of the paddle tensor so that its dtype-aware
            # sum-to-one tolerance applies.
            noise = np.random.choice(
                self.N, self.K * batch_size, replace=True, p=self.Q.numpy())
        noise = paddle.to_tensor(noise, dtype='int64', stop_gradient=False)
        noise_idx = paddle.reshape(noise, [batch_size, self.K])
        return noise_idx

    def get_combined_idx(self, target_idx, noise_idx):
        """Concatenate target and noise indices along axis 1.

        Args:
            target_idx (tensor): target indices; reshaped to ``[-1, 1]``.
            noise_idx (tensor): 2-D noise index matrix.

        Returns:
            tensor: matrix whose column 0 holds the targets and whose
            remaining columns hold the noise indices.
        """
        target_idx = paddle.reshape(target_idx, [-1, 1])
        return paddle.concat((target_idx, noise_idx), 1)

    def nce_loss(self, prob_model, prob_noise_in_model, prob_noise,
                 prob_target_in_noise):
        """Combine the target term and the noise term of the NCE loss.

        Args:
            prob_model (tensor): model probability of each target.
            prob_noise_in_model (tensor): model probability of each noise
                sample.
            prob_noise (tensor): prior probability of each noise sample.
            prob_target_in_noise (tensor): prior probability of each target.

        Returns:
            tensor: 1-D tensor of losses (not yet averaged).
        """

        def safe_log(tensor):
            """Return log(tensor + EPSILON) to avoid taking log(0)."""
            EPSILON = 1e-10
            return paddle.log(EPSILON + tensor)

        # Log-probability that each target is classified as real data.
        model_loss = safe_log(prob_model /
                              (prob_model + self.K * prob_target_in_noise))
        model_loss = paddle.reshape(model_loss, [-1])

        # Log-probability that each noise sample is classified as noise,
        # summed over the last axis.
        noise_loss = paddle.sum(
            safe_log((self.K * prob_noise) /
                     (prob_noise_in_model + self.K * prob_noise)), -1)
        noise_loss = paddle.reshape(noise_loss, [-1])

        loss = -(model_loss + noise_loss)

        return loss
0 commit comments