21
21
22
22
#include " forwardADSignalTransform.hh"
23
23
#include < string> // For std::string and std::string::find
24
+ #include " description.hh"
24
25
#include " global.hh"
25
26
#include " list.hh"
26
27
#include " ppsig.hh"
27
28
#include " signalVisitor.hh" // For DependencyVisitor
28
29
#include " xtended.hh"
29
30
30
31
/* *
31
- * @brief Parses the label of a UI control to check for the [autodiff:false ] tag.
32
+ * @brief Parses the label of a UI control to check for the [autodiff:xxx ] tag.
32
33
*
33
34
* Faust UI element labels can contain metadata in the form of `[key:value]`.
34
35
* This function checks for the specific tag that disables automatic differentiation
35
36
* for a given parameter.
36
37
*
37
- * @param label_path_tree The 'path' Tree associated with a UI element like hslider.
38
+ * @param label The 'path' Tree associated with a UI element like hslider.
38
39
* This is expected to be a list, where the head is the local label.
39
- * @return true if the "[autodiff:false]" tag is found in the label string, false otherwise.
40
+ * @return false if the "[autodiff:false]" tag is found in the label string, true otherwise.
40
41
*/
41
- static bool hasAutodiffFalseTag (Tree label_path_tree )
42
+ static bool hasAutodiff (Tree label )
42
43
{
43
- if (!label_path_tree || isNil (label_path_tree)) {
44
- return false ;
45
- }
46
- try {
47
- // The label path is a list of symbols. The local label is the first element.
48
- // We convert its symbol node to a string to inspect it.
49
- const char * label_str = tree2str (hd (label_path_tree));
50
- if (label_str) {
51
- // Perform a simple substring search for the metadata tag.
52
- return std::string (label_str).find (" [autodiff:false]" ) != std::string::npos;
44
+ std::map<std::string, std::set<std::string>> metadata;
45
+ std::string simplifiedLabel;
46
+
47
+ extractMetadata (tree2str (hd (label)), simplifiedLabel, metadata);
48
+
49
+ // Look for [autodiff:false]
50
+ for (const auto & i : metadata) {
51
+ if (i.first == " autodiff" ) {
52
+ const std::set<std::string>& values = i.second ;
53
+ for (const auto & j : values) {
54
+ return (j == " true" );
55
+ }
53
56
}
54
- } catch (const faustexception& e) {
55
- // tree2str can throw an exception if the tree is not a symbol,
56
- // which can happen with complex/computed labels. In that case,
57
- // we assume no tag is present.
58
- return false ;
59
57
}
60
- return false ;
58
+ return true ;
61
59
}
62
60
63
61
/* *
@@ -82,13 +80,10 @@ struct ADDependencyVisitor : public SignalVisitor {
82
80
if (isSigVSlider (sig, path, c, x, y, z) || isSigHSlider (sig, path, c, x, y, z) ||
83
81
isSigNumEntry (sig, path, c, x, y, z)) {
84
82
// Check the label for the [autodiff:false] tag.
85
- if (!hasAutodiffFalseTag (path)) {
86
- // If the tag is not present, add this control to the set of
87
- // parameters that will be differentiated.
83
+ if (hasAutodiff (path)) {
88
84
fControls .insert (sig);
89
85
}
90
86
// If the tag is present, we do nothing, effectively excluding it.
91
-
92
87
} else {
93
88
// For all other signal types, continue the traversal to visit children.
94
89
SignalVisitor::visit (sig);
@@ -100,7 +95,7 @@ struct ADDependencyVisitor : public SignalVisitor {
100
95
* @brief Extracts the first element (head) of each pair in a list of pairs.
101
96
* * Given a list of dual numbers `((p1, t1), (p2, t2), ...)` this function
102
97
* returns a new list containing only the primal components: `(p1, p2, ...)`.
103
- * * @param list_of_pairs A Tree representing a list of pairs.
98
+ * @param list_of_pairs A Tree representing a list of pairs.
104
99
* @return A new Tree representing the list of first elements.
105
100
*/
106
101
static Tree mapHd (Tree list_of_pairs)
@@ -122,7 +117,7 @@ static Tree mapHd(Tree list_of_pairs)
122
117
* @brief Extracts the second element (head of the tail) of each pair in a list of pairs.
123
118
* * Given a list of dual numbers `((p1, t1), (p2, t2), ...)` this function
124
119
* returns a new list containing only the tangent components: `(t1, t2, ...)`.
125
- * * @param list_of_pairs A Tree representing a list of pairs.
120
+ * @param list_of_pairs A Tree representing a list of pairs.
126
121
* @return A new Tree representing the list of second elements.
127
122
*/
128
123
static Tree mapHdTl (Tree list_of_pairs)
@@ -166,13 +161,15 @@ Tree ForwardADSignalTransform::transformation(Tree sig)
166
161
// Handle constants: The derivative of any constant is zero.
167
162
if (isSigReal (sig, &r_val)) {
168
163
// Return the dual number (constant_real, 0.0)
169
- return cons (sig, cons (sigReal (0.0 ), gGlobal ->nil ));
170
- } else if (isSigInt (sig, &i_val) || isSigInt64 (sig, &i64_val)) {
164
+ return dual (sig, sigReal (0.0 ));
165
+ }
166
+
167
+ if (isSigInt (sig, &i_val) || isSigInt64 (sig, &i64_val)) {
171
168
// Return the dual number (constant_int, 0)
172
169
// Note: While derivatives are floats, we create a typed integer zero
173
170
// which will be cast to float automatically by arithmetic operations.
174
171
// A more robust implementation might cast to float immediately.
175
- return cons (sig, cons ( sigInt (0 ), gGlobal -> nil ));
172
+ return dual (sig, sigInt (0 ));
176
173
}
177
174
178
175
// Handle differentiable UI controls (sliders, nentries).
@@ -182,13 +179,13 @@ Tree ForwardADSignalTransform::transformation(Tree sig)
182
179
// its derivative is 1. Otherwise, it's treated as a constant, so its derivative is 0.
183
180
Tree tangent = (sig == fDiffControl ) ? sigReal (1.0 ) : sigReal (0.0 );
184
181
// Return the dual number (control, 1.0 or 0.0)
185
- return cons (sig, cons ( tangent, gGlobal -> nil ) );
182
+ return dual (sig, tangent);
186
183
}
187
184
188
185
// Handle non-differentiable UI elements like buttons. Their derivative is 0.
189
186
if (isSigButton (sig, label) || isSigCheckbox (sig, label)) {
190
187
// Return the dual number (button, 0.0)
191
- return cons (sig, cons ( sigReal (0.0 ), gGlobal -> nil ));
188
+ return dual (sig, sigReal (0.0 ));
192
189
}
193
190
194
191
// Math primitives
@@ -202,33 +199,34 @@ Tree ForwardADSignalTransform::transformation(Tree sig)
202
199
if (ext == gGlobal ->gPowPrim || ext == gGlobal ->gFmodPrim ||
203
200
ext == gGlobal ->gRemainderPrim || ext == gGlobal ->gMaxPrim ||
204
201
ext == gGlobal ->gMinPrim ) {
205
-
206
- // Derivative of these primitives require f, g, f' and g'.
207
202
Tree dual_x = self (sig->branch (0 )); // This will be (primal_x, tangent_x)
208
203
Tree dual_y = self (sig->branch (1 )); // This will be (primal_y, tangent_y)
209
204
210
- Tree primal_x = hd (dual_x);
211
- Tree tangent_x = hd ( tl ( dual_x) );
205
+ Tree primal_x = primal (dual_x);
206
+ Tree tangent_x = tangent ( dual_x);
212
207
213
- Tree primal_y = hd (dual_y);
214
- Tree tangent_y = hd ( tl ( dual_y) );
208
+ Tree primal_y = primal (dual_y);
209
+ Tree tangent_y = tangent ( dual_y);
215
210
216
211
std::vector<Tree> primal_args;
217
212
primal_args.push_back (primal_x);
218
213
primal_args.push_back (primal_y);
219
214
Tree new_primal = ext->computeSigOutput (primal_args);
220
215
216
+ // Derivative of these primitives require f, g, f' and g'.
221
217
std::vector<Tree> tangent_args;
218
+ tangent_args.push_back (primal_x);
219
+ tangent_args.push_back (primal_y);
222
220
tangent_args.push_back (tangent_x);
223
221
tangent_args.push_back (tangent_y);
224
222
Tree new_tangent = ext->diff (tangent_args);
225
223
226
- return cons (new_primal, cons ( new_tangent, gGlobal -> nil ) );
224
+ return dual (new_primal, new_tangent);
227
225
} else {
228
226
// chain rule for unary function: f(g(x))' = f'(g(x)) * g'(x)
229
227
Tree dual_x = self (sig->branch (0 ));
230
- Tree primal_x = hd (dual_x);
231
- Tree tangent_x = hd ( tl ( dual_x) );
228
+ Tree primal_x = primal (dual_x);
229
+ Tree tangent_x = tangent ( dual_x);
232
230
233
231
std::vector<Tree> primal_args;
234
232
primal_args.push_back (primal_x);
@@ -238,7 +236,7 @@ Tree ForwardADSignalTransform::transformation(Tree sig)
238
236
tangent_args.push_back (primal_x);
239
237
Tree new_tangent = sigMul (ext->diff (tangent_args), tangent_x);
240
238
241
- return cons (new_primal, cons ( new_tangent, gGlobal -> nil ) );
239
+ return dual (new_primal, new_tangent);
242
240
}
243
241
}
244
242
@@ -249,10 +247,10 @@ Tree ForwardADSignalTransform::transformation(Tree sig)
249
247
Tree dual_y = self (y); // This will be (primal_y, tangent_y)
250
248
251
249
// Deconstruct the dual number lists to get their components.
252
- Tree primal_x = hd (dual_x);
253
- Tree tangent_x = hd ( tl ( dual_x) );
254
- Tree primal_y = hd (dual_y);
255
- Tree tangent_y = hd ( tl ( dual_y) );
250
+ Tree primal_x = primal (dual_x);
251
+ Tree tangent_x = tangent ( dual_x);
252
+ Tree primal_y = primal (dual_y);
253
+ Tree tangent_y = tangent ( dual_y);
256
254
257
255
// The new primal is the same operation applied to the children's primals.
258
256
Tree new_primal = sigBinOp (opt_op, primal_x, primal_y);
@@ -285,32 +283,37 @@ Tree ForwardADSignalTransform::transformation(Tree sig)
285
283
break ;
286
284
}
287
285
// Return the newly constructed dual number signal.
288
- return cons (new_primal, cons ( new_tangent, gGlobal -> nil ) );
286
+ return dual (new_primal, new_tangent);
289
287
}
290
288
291
289
// Handle one-sample delay. The derivative of a delay is the delay of the derivative.
292
290
if (isSigDelay1 (sig, u_tree)) {
293
291
Tree dual_u = self (u_tree);
294
- Tree primal_u = hd (dual_u);
295
- Tree tangent_u = hd ( tl ( dual_u) );
292
+ Tree primal_u = primal (dual_u);
293
+ Tree tangent_u = tangent ( dual_u);
296
294
// Apply delay to both the primal and tangent signals.
297
295
Tree new_primal = sigDelay1 (primal_u);
298
296
Tree new_tangent = sigDelay1 (tangent_u);
299
- return cons (new_primal, cons ( new_tangent, gGlobal -> nil ) );
297
+ return dual (new_primal, new_tangent);
300
298
}
301
299
302
300
// Handle variable-length delay.
303
301
// For this to be mathematically meaningful, the underlying implementation of the
304
302
// variable delay must be interpolating and differentiable.
305
303
if (isSigDelay (sig, u_tree, d_tree)) {
304
+ Node n = d_tree->node ();
305
+ if (isZero (n)) {
306
+ return self (u_tree);
307
+ }
308
+
306
309
// Recursively get dual signals for the input signal 'u' and delay time 'd'.
307
310
Tree dual_u = self (u_tree);
308
311
Tree dual_d = self (d_tree);
309
312
310
- Tree primal_u = hd (dual_u);
311
- Tree tangent_u = hd ( tl ( dual_u) ); // This is u'
312
- Tree primal_d = hd (dual_d);
313
- Tree tangent_d = hd ( tl ( dual_d) ); // This is d'
313
+ Tree primal_u = primal (dual_u);
314
+ Tree tangent_u = tangent ( dual_u); // This is u'
315
+ Tree primal_d = primal (dual_d);
316
+ Tree tangent_d = tangent ( dual_d); // This is d'
314
317
315
318
// The new primal is the variable delay applied to the primal signal.
316
319
Tree new_primal = sigDelay (primal_u, primal_d);
@@ -335,79 +338,79 @@ Tree ForwardADSignalTransform::transformation(Tree sig)
335
338
// The final tangent is the sum of the two terms.
336
339
Tree new_tangent = sigAdd (term1, term2);
337
340
338
- return cons (new_primal, cons ( new_tangent, gGlobal -> nil ) );
341
+ return dual (new_primal, new_tangent);
339
342
}
340
343
341
344
// Handle conditional selection (multiplexer).
342
345
Tree sel, tx, ty;
343
346
if (isSigSelect2 (sig, sel, tx, ty)) {
344
- Tree dual_sel = self (sel);
345
- Tree dual_x = self (tx);
346
- Tree dual_y = self (ty);
347
-
348
- Tree primal_sel = hd (dual_sel);
349
- Tree new_primal = sigSelect2 (primal_sel, hd (dual_x), hd (dual_y));
350
- Tree new_tangent = sigSelect2 (primal_sel, hd (tl (dual_x)), hd (tl (dual_y)));
351
-
352
- return cons (new_primal, cons (new_tangent, gGlobal ->nil ));
347
+ Tree dual_sel = self (sel);
348
+ Tree dual_x = self (tx);
349
+ Tree dual_y = self (ty);
350
+ Tree primal_sel = primal (dual_sel);
351
+ Tree new_primal = sigSelect2 (primal_sel, primal (dual_x), primal (dual_y));
352
+ Tree new_tangent = sigSelect2 (primal_sel, tangent (dual_x), tangent (dual_y));
353
+ return dual (new_primal, new_tangent);
353
354
}
354
355
355
356
// Handle prefix operator (one-sample initialization).
356
357
if (isSigPrefix (sig, x, y)) {
357
- Tree dual_x = self (x);
358
- Tree dual_y = self (y);
359
-
360
- Tree new_primal = sigPrefix (hd (dual_x), hd (dual_y));
361
- Tree new_tangent = sigPrefix (hd (tl (dual_x)), hd (tl (dual_y)));
362
-
363
- return cons (new_primal, cons (new_tangent, gGlobal ->nil ));
358
+ Tree dual_x = self (x);
359
+ Tree dual_y = self (y);
360
+ Tree new_primal = sigPrefix (primal (dual_x), primal (dual_y));
361
+ Tree new_tangent = sigPrefix (tangent (dual_x), tangent (dual_y));
362
+ return dual (new_primal, new_tangent);
364
363
}
365
364
366
365
// Handle type casting.
367
366
if (isSigFloatCast (sig, x)) {
368
367
Tree dual_x = self (x);
369
- Tree new_primal = sigFloatCast (hd (dual_x));
370
- Tree new_tangent = sigFloatCast (hd ( tl ( dual_x) ));
371
- return cons (new_primal, cons ( new_tangent, gGlobal -> nil ) );
368
+ Tree new_primal = sigFloatCast (primal (dual_x));
369
+ Tree new_tangent = sigFloatCast (tangent ( dual_x));
370
+ return dual (new_primal, new_tangent);
372
371
}
373
372
374
373
if (isSigIntCast (sig, x)) {
375
374
Tree dual_x = self (x);
376
- Tree new_primal = sigIntCast (hd (dual_x));
375
+ Tree new_primal = sigIntCast (primal (dual_x));
377
376
Tree new_tangent = sigReal (0.0 );
378
- return cons (new_primal, cons ( new_tangent, gGlobal -> nil ) );
377
+ return dual (new_primal, new_tangent);
379
378
}
380
379
381
380
// Handle recursion (rec/proj).
382
381
// Based on the rule (rec(u))' = rec(u'), we differentiate 'through' the recursion.
383
- Tree rec_block;
384
- if (isRec (sig, x, y)) {
385
- /*
386
- if (getUserData(sig)) {
387
- return getUserData(sig);
388
- }
389
- */
382
+ Tree rec_block, var, body;
383
+ if (isRec (sig, var, body)) {
384
+ if (isNil (body)) {
385
+ // we are already visiting this recursive group
386
+ return dual (sig, sig);
387
+ } else {
388
+ // first visit
389
+ rec (var, gGlobal ->nil ); // to avoid infinite recursions
390
390
391
- Tree dual_list = mapself (y);
392
- Tree primal_expressions = mapHd (dual_list);
393
- Tree tangent_expressions = mapHdTl (dual_list);
391
+ // compute the recursion
392
+ Tree dual_list = mapselfRec (body);
393
+ Tree primal_expressions = mapHd (dual_list);
394
+ Tree tangent_expressions = mapHdTl (dual_list);
394
395
395
- Tree primal_rec = rec (x, primal_expressions);
396
- Tree tangent_rec = rec (x, tangent_expressions );
396
+ // new symbol for differentiated rec
397
+ Tree dvar ( tree ( unique ( " DW " )) );
397
398
398
- Tree result = cons (primal_rec, cons (tangent_rec, gGlobal ->nil ));
399
- // setUserData(sig, result);
400
- return result;
399
+ Tree primal_rec = rec (var, primal_expressions);
400
+ Tree tangent_rec = rec (dvar, tangent_expressions);
401
+ return dual (primal_rec, tangent_rec);
402
+ }
401
403
}
402
404
403
405
int rec_idx;
404
406
if (isProj (sig, &rec_idx, rec_block)) {
405
407
Tree dual_rec_block = self (rec_block);
406
- Tree primal_rec = hd (dual_rec_block);
407
- Tree tangent_rec = hd ( tl ( dual_rec_block) );
408
+ Tree primal_rec = primal (dual_rec_block);
409
+ Tree tangent_rec = tangent ( dual_rec_block);
408
410
Tree new_primal = sigProj (rec_idx, primal_rec);
409
411
Tree new_tangent = sigProj (rec_idx, tangent_rec);
410
- return cons (new_primal, cons (new_tangent, gGlobal ->nil ));
412
+
413
+ return dual (new_primal, new_tangent);
411
414
}
412
415
413
416
// Handle output nodes.
@@ -416,7 +419,7 @@ Tree ForwardADSignalTransform::transformation(Tree sig)
416
419
}
417
420
418
421
// Fallback for any unhandled signal types. Treat them as constants.
419
- return cons (sig, cons ( sigReal (0.0 ), gGlobal -> nil ));
422
+ return dual (sig, sigReal (0.0 ));
420
423
}
421
424
422
425
/* *
@@ -452,7 +455,7 @@ siglist generateADSignals(const siglist& outputs_list_aux)
452
455
Tree dual_signal = ad_transform.self (out_sig);
453
456
454
457
// The result is a dual number (primal, tangent). We only need the tangent part.
455
- Tree derivative_signal = hd ( tl ( dual_signal) );
458
+ Tree derivative_signal = tangent ( dual_signal);
456
459
457
460
// Add the new derivative signal to our master list of signals.
458
461
all_signals = cons (derivative_signal, all_signals);
0 commit comments