@@ -91,3 +91,137 @@ def forward(self, outputs, targets, length=None):
         predictions = F.log_softmax(predictions, axis=1)
         loss = self.criterion(predictions, targets) / targets.sum()
         return loss
+
+
+class NCELoss(nn.Layer):
+    """Noise Contrastive Estimation loss function.
+
+    Noise Contrastive Estimation (NCE) is an approximation method used to
+    work around the huge computational cost of a large softmax layer.
+    The basic idea is to convert the prediction problem into a binary
+    classification problem at training time. It has been shown that the two
+    criteria converge to the same minimum as long as the noise distribution
+    is close enough to the real one.
+
+    NCE bridges the gap between generative models and discriminative models,
+    rather than simply speeding up the softmax layer.
+    With NCE, almost any scoring model can be turned into a posterior with
+    relatively little effort.
+
+    Refs:
+        NCE: http://www.cs.helsinki.fi/u/ahyvarin/papers/Gutmann10AISTATS.pdf
+        Thanks: https://github.com/mingen-pan/easy-to-use-NCE-RNN-for-Pytorch/blob/master/nce.py
+
+    Examples:
+        Q = Q_from_tokens(output_dim)
+        NCELoss(Q)
+    """
+
+    def __init__(self, Q, noise_ratio=100, Z_offset=9.5):
+        """Noise Contrastive Estimation loss function.
+
+        Args:
+            Q (tensor): prior model, uniform or gaussian
+            noise_ratio (int, optional): number of noise samples per target. Defaults to 100.
+            Z_offset (float, optional): scale for post-processing the score. Defaults to 9.5.
+        """
+        super(NCELoss, self).__init__()
+        assert type(noise_ratio) is int
+        self.Q = paddle.to_tensor(Q, stop_gradient=False)
+        self.N = self.Q.shape[0]
+        self.K = noise_ratio
+        self.Z_offset = Z_offset
+
+    def forward(self, output, target):
+        """Forward inference.
+
+        Args:
+            output (tensor): the model output, which is the input of the loss function
+            target (tensor): the ground-truth target ids
+        """
+        output = paddle.reshape(output, [-1, self.N])
+        B = output.shape[0]
+        noise_idx = self.get_noise(B)
+        idx = self.get_combined_idx(target, noise_idx)
+        P_target, P_noise = self.get_prob(idx, output, sep_target=True)
+        Q_target, Q_noise = self.get_Q(idx)
+        loss = self.nce_loss(P_target, P_noise, Q_noise, Q_target)
+        return loss.mean()
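+    # Shape walkthrough for forward (conventions inferred from the code):
+    #   output: [B, N] unnormalized scores over the N output classes
+    #   idx:    [B, 1 + K], column 0 is the target id, columns 1..K noise ids
+    #   P_*/Q_*: model and prior probabilities gathered at those ids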
+
+    def get_Q(self, idx, sep_target=True):
+        """Get prior probabilities for a batch of indices.
+        """
+        prob_model = paddle.to_tensor(
+            self.Q.numpy()[paddle.reshape(idx, [-1]).numpy()])
+        prob_model = paddle.reshape(prob_model, [idx.shape[0], idx.shape[1]])
+        if sep_target:
+            return prob_model[:, 0], prob_model[:, 1:]
+        else:
+            return prob_model
+
+    def get_prob(self, idx, scores, sep_target=True):
+        """Post-process the scores of the posterior model (nn output) for a batch.
+        """
+        scores = self.get_scores(idx, scores)
+        scale = paddle.to_tensor([self.Z_offset], dtype='float32')
+        scores = paddle.add(scores, -scale)
+        prob = paddle.exp(scores)
+        if sep_target:
+            return prob[:, 0], prob[:, 1:]
+        else:
+            return prob
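+    # Note: exp(score - Z_offset) treats the partition function as a fixed
+    # constant exp(Z_offset) instead of summing over all N classes. NCE is
+    # known to tolerate such self-normalization, which is where the savings
+    # over a full softmax come from.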
+
+    def get_scores(self, idx, scores):
+        """Gather the scores of the posterior model (nn output) at the given indices.
+        """
+        B, N = scores.shape
+        K = idx.shape[1]
+        idx_increment = paddle.to_tensor(
+            N * paddle.reshape(paddle.arange(B), [B, 1]) * paddle.ones([1, K]),
+            dtype="int64",
+            stop_gradient=False)
+        new_idx = idx_increment + idx
+        new_scores = paddle.index_select(
+            paddle.reshape(scores, [-1]), paddle.reshape(new_idx, [-1]))
+
+        return paddle.reshape(new_scores, [B, K])
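+    # Flat-gather trick: scores is viewed as a vector of length B * N, and the
+    # ids of row b are offset by b * N, so a single index_select fetches all
+    # [B, K] entries. E.g. with N = 5, id 3 in row 1 maps to 1 * 5 + 3 = 8.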
186
+
187
+ def get_noise (self , batch_size , uniform = True ):
188
+ """Select noise sample
189
+ """
190
+ if uniform :
191
+ noise = np .random .randint (self .N , size = self .K * batch_size )
192
+ else :
193
+ noise = np .random .choice (
194
+ self .N , self .K * batch_size , replace = True , p = self .Q .data )
195
+ noise = paddle .to_tensor (noise , dtype = 'int64' , stop_gradient = False )
196
+ noise_idx = paddle .reshape (noise , [batch_size , self .K ])
197
+ return noise_idx
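+    # With the default uniform=True, noise ids are drawn uniformly, which is
+    # consistent with the loss only when the prior Q is itself uniform; pass
+    # uniform=False to sample from Q directly.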
198
+
199
+ def get_combined_idx (self , target_idx , noise_idx ):
200
+ """Combined target and noise
201
+ """
202
+ target_idx = paddle .reshape (target_idx , [- 1 , 1 ])
203
+ return paddle .concat ((target_idx , noise_idx ), 1 )
204
+
205
+ def nce_loss (self , prob_model , prob_noise_in_model , prob_noise ,
206
+ prob_target_in_noise ):
207
+ """Combined the loss of target and noise
208
+ """
209
+
210
+ def safe_log (tensor ):
211
+ """Safe log
212
+ """
213
+ EPSILON = 1e-10
214
+ return paddle .log (EPSILON + tensor )
215
+
216
+ model_loss = safe_log (prob_model /
217
+ (prob_model + self .K * prob_target_in_noise ))
218
+ model_loss = paddle .reshape (model_loss , [- 1 ])
219
+
220
+ noise_loss = paddle .sum (
221
+ safe_log ((self .K * prob_noise ) /
222
+ (prob_noise_in_model + self .K * prob_noise )), - 1 )
223
+ noise_loss = paddle .reshape (noise_loss , [- 1 ])
224
+
225
+ loss = - (model_loss + noise_loss )
226
+
227
+ return loss
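+
+
+# Minimal usage sketch (hypothetical shapes; Q_from_tokens is the prior
+# builder referenced in the class docstring):
+#     Q = Q_from_tokens(vocab_size)            # [N] prior over output classes
+#     criterion = NCELoss(Q, noise_ratio=100)
+#     scores = model(inputs)                   # [B, N] unnormalized scores
+#     loss = criterion(scores, targets)        # targets: [B] int64 ids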