Skip to content

Commit 6b9a4e1

Browse files
committed
Recursive case somewhat working, correct math primitive in ForwardADSignalTransform .
1 parent 2ad1ed0 commit 6b9a4e1

File tree

2 files changed

+111
-95
lines changed

2 files changed

+111
-95
lines changed

compiler/transform/forwardADSignalTransform.cpp

Lines changed: 98 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -21,43 +21,41 @@
2121

2222
#include "forwardADSignalTransform.hh"
2323
#include <string> // For std::string and std::string::find
24+
#include "description.hh"
2425
#include "global.hh"
2526
#include "list.hh"
2627
#include "ppsig.hh"
2728
#include "signalVisitor.hh" // For DependencyVisitor
2829
#include "xtended.hh"
2930

3031
/**
31-
* @brief Parses the label of a UI control to check for the [autodiff:false] tag.
32+
* @brief Parses the label of a UI control to check for the [autodiff:xxx] tag.
3233
*
3334
* Faust UI element labels can contain metadata in the form of `[key:value]`.
3435
* This function checks for the specific tag that disables automatic differentiation
3536
* for a given parameter.
3637
*
37-
* @param label_path_tree The 'path' Tree associated with a UI element like hslider.
38+
* @param label The 'path' Tree associated with a UI element like hslider.
3839
* This is expected to be a list, where the head is the local label.
39-
* @return true if the "[autodiff:false]" tag is found in the label string, false otherwise.
40+
* @return false if the "[autodiff:false]" tag is found in the label string, true otherwise.
4041
*/
41-
static bool hasAutodiffFalseTag(Tree label_path_tree)
42+
static bool hasAutodiff(Tree label)
4243
{
43-
if (!label_path_tree || isNil(label_path_tree)) {
44-
return false;
45-
}
46-
try {
47-
// The label path is a list of symbols. The local label is the first element.
48-
// We convert its symbol node to a string to inspect it.
49-
const char* label_str = tree2str(hd(label_path_tree));
50-
if (label_str) {
51-
// Perform a simple substring search for the metadata tag.
52-
return std::string(label_str).find("[autodiff:false]") != std::string::npos;
44+
std::map<std::string, std::set<std::string>> metadata;
45+
std::string simplifiedLabel;
46+
47+
extractMetadata(tree2str(hd(label)), simplifiedLabel, metadata);
48+
49+
// Look for [autodiff:false]
50+
for (const auto& i : metadata) {
51+
if (i.first == "autodiff") {
52+
const std::set<std::string>& values = i.second;
53+
for (const auto& j : values) {
54+
return (j == "true");
55+
}
5356
}
54-
} catch (const faustexception& e) {
55-
// tree2str can throw an exception if the tree is not a symbol,
56-
// which can happen with complex/computed labels. In that case,
57-
// we assume no tag is present.
58-
return false;
5957
}
60-
return false;
58+
return true;
6159
}
6260

6361
/**
@@ -82,13 +80,10 @@ struct ADDependencyVisitor : public SignalVisitor {
8280
if (isSigVSlider(sig, path, c, x, y, z) || isSigHSlider(sig, path, c, x, y, z) ||
8381
isSigNumEntry(sig, path, c, x, y, z)) {
8482
// Check the label for the [autodiff:false] tag.
85-
if (!hasAutodiffFalseTag(path)) {
86-
// If the tag is not present, add this control to the set of
87-
// parameters that will be differentiated.
83+
if (hasAutodiff(path)) {
8884
fControls.insert(sig);
8985
}
9086
// If the tag is present, we do nothing, effectively excluding it.
91-
9287
} else {
9388
// For all other signal types, continue the traversal to visit children.
9489
SignalVisitor::visit(sig);
@@ -100,7 +95,7 @@ struct ADDependencyVisitor : public SignalVisitor {
10095
* @brief Extracts the first element (head) of each pair in a list of pairs.
10196
* * Given a list of dual numbers `((p1, t1), (p2, t2), ...)` this function
10297
* returns a new list containing only the primal components: `(p1, p2, ...)`.
103-
* * @param list_of_pairs A Tree representing a list of pairs.
98+
* @param list_of_pairs A Tree representing a list of pairs.
10499
* @return A new Tree representing the list of first elements.
105100
*/
106101
static Tree mapHd(Tree list_of_pairs)
@@ -122,7 +117,7 @@ static Tree mapHd(Tree list_of_pairs)
122117
* @brief Extracts the second element (head of the tail) of each pair in a list of pairs.
123118
* * Given a list of dual numbers `((p1, t1), (p2, t2), ...)` this function
124119
* returns a new list containing only the tangent components: `(t1, t2, ...)`.
125-
* * @param list_of_pairs A Tree representing a list of pairs.
120+
* @param list_of_pairs A Tree representing a list of pairs.
126121
* @return A new Tree representing the list of second elements.
127122
*/
128123
static Tree mapHdTl(Tree list_of_pairs)
@@ -166,13 +161,15 @@ Tree ForwardADSignalTransform::transformation(Tree sig)
166161
// Handle constants: The derivative of any constant is zero.
167162
if (isSigReal(sig, &r_val)) {
168163
// Return the dual number (constant_real, 0.0)
169-
return cons(sig, cons(sigReal(0.0), gGlobal->nil));
170-
} else if (isSigInt(sig, &i_val) || isSigInt64(sig, &i64_val)) {
164+
return dual(sig, sigReal(0.0));
165+
}
166+
167+
if (isSigInt(sig, &i_val) || isSigInt64(sig, &i64_val)) {
171168
// Return the dual number (constant_int, 0)
172169
// Note: While derivatives are floats, we create a typed integer zero
173170
// which will be cast to float automatically by arithmetic operations.
174171
// A more robust implementation might cast to float immediately.
175-
return cons(sig, cons(sigInt(0), gGlobal->nil));
172+
return dual(sig, sigInt(0));
176173
}
177174

178175
// Handle differentiable UI controls (sliders, nentries).
@@ -182,13 +179,13 @@ Tree ForwardADSignalTransform::transformation(Tree sig)
182179
// its derivative is 1. Otherwise, it's treated as a constant, so its derivative is 0.
183180
Tree tangent = (sig == fDiffControl) ? sigReal(1.0) : sigReal(0.0);
184181
// Return the dual number (control, 1.0 or 0.0)
185-
return cons(sig, cons(tangent, gGlobal->nil));
182+
return dual(sig, tangent);
186183
}
187184

188185
// Handle non-differentiable UI elements like buttons. Their derivative is 0.
189186
if (isSigButton(sig, label) || isSigCheckbox(sig, label)) {
190187
// Return the dual number (button, 0.0)
191-
return cons(sig, cons(sigReal(0.0), gGlobal->nil));
188+
return dual(sig, sigReal(0.0));
192189
}
193190

194191
// Math primitives
@@ -202,33 +199,34 @@ Tree ForwardADSignalTransform::transformation(Tree sig)
202199
if (ext == gGlobal->gPowPrim || ext == gGlobal->gFmodPrim ||
203200
ext == gGlobal->gRemainderPrim || ext == gGlobal->gMaxPrim ||
204201
ext == gGlobal->gMinPrim) {
205-
206-
// Derivative of these primitives require f, g, f' and g'.
207202
Tree dual_x = self(sig->branch(0)); // This will be (primal_x, tangent_x)
208203
Tree dual_y = self(sig->branch(1)); // This will be (primal_y, tangent_y)
209204

210-
Tree primal_x = hd(dual_x);
211-
Tree tangent_x = hd(tl(dual_x));
205+
Tree primal_x = primal(dual_x);
206+
Tree tangent_x = tangent(dual_x);
212207

213-
Tree primal_y = hd(dual_y);
214-
Tree tangent_y = hd(tl(dual_y));
208+
Tree primal_y = primal(dual_y);
209+
Tree tangent_y = tangent(dual_y);
215210

216211
std::vector<Tree> primal_args;
217212
primal_args.push_back(primal_x);
218213
primal_args.push_back(primal_y);
219214
Tree new_primal = ext->computeSigOutput(primal_args);
220215

216+
// Derivative of these primitives require f, g, f' and g'.
221217
std::vector<Tree> tangent_args;
218+
tangent_args.push_back(primal_x);
219+
tangent_args.push_back(primal_y);
222220
tangent_args.push_back(tangent_x);
223221
tangent_args.push_back(tangent_y);
224222
Tree new_tangent = ext->diff(tangent_args);
225223

226-
return cons(new_primal, cons(new_tangent, gGlobal->nil));
224+
return dual(new_primal, new_tangent);
227225
} else {
228226
// chain rule for unary function: f(g(x))' = f'(g(x)) * g'(x)
229227
Tree dual_x = self(sig->branch(0));
230-
Tree primal_x = hd(dual_x);
231-
Tree tangent_x = hd(tl(dual_x));
228+
Tree primal_x = primal(dual_x);
229+
Tree tangent_x = tangent(dual_x);
232230

233231
std::vector<Tree> primal_args;
234232
primal_args.push_back(primal_x);
@@ -238,7 +236,7 @@ Tree ForwardADSignalTransform::transformation(Tree sig)
238236
tangent_args.push_back(primal_x);
239237
Tree new_tangent = sigMul(ext->diff(tangent_args), tangent_x);
240238

241-
return cons(new_primal, cons(new_tangent, gGlobal->nil));
239+
return dual(new_primal, new_tangent);
242240
}
243241
}
244242

@@ -249,10 +247,10 @@ Tree ForwardADSignalTransform::transformation(Tree sig)
249247
Tree dual_y = self(y); // This will be (primal_y, tangent_y)
250248

251249
// Deconstruct the dual number lists to get their components.
252-
Tree primal_x = hd(dual_x);
253-
Tree tangent_x = hd(tl(dual_x));
254-
Tree primal_y = hd(dual_y);
255-
Tree tangent_y = hd(tl(dual_y));
250+
Tree primal_x = primal(dual_x);
251+
Tree tangent_x = tangent(dual_x);
252+
Tree primal_y = primal(dual_y);
253+
Tree tangent_y = tangent(dual_y);
256254

257255
// The new primal is the same operation applied to the children's primals.
258256
Tree new_primal = sigBinOp(opt_op, primal_x, primal_y);
@@ -285,32 +283,37 @@ Tree ForwardADSignalTransform::transformation(Tree sig)
285283
break;
286284
}
287285
// Return the newly constructed dual number signal.
288-
return cons(new_primal, cons(new_tangent, gGlobal->nil));
286+
return dual(new_primal, new_tangent);
289287
}
290288

291289
// Handle one-sample delay. The derivative of a delay is the delay of the derivative.
292290
if (isSigDelay1(sig, u_tree)) {
293291
Tree dual_u = self(u_tree);
294-
Tree primal_u = hd(dual_u);
295-
Tree tangent_u = hd(tl(dual_u));
292+
Tree primal_u = primal(dual_u);
293+
Tree tangent_u = tangent(dual_u);
296294
// Apply delay to both the primal and tangent signals.
297295
Tree new_primal = sigDelay1(primal_u);
298296
Tree new_tangent = sigDelay1(tangent_u);
299-
return cons(new_primal, cons(new_tangent, gGlobal->nil));
297+
return dual(new_primal, new_tangent);
300298
}
301299

302300
// Handle variable-length delay.
303301
// For this to be mathematically meaningful, the underlying implementation of the
304302
// variable delay must be interpolating and differentiable.
305303
if (isSigDelay(sig, u_tree, d_tree)) {
304+
Node n = d_tree->node();
305+
if (isZero(n)) {
306+
return self(u_tree);
307+
}
308+
306309
// Recursively get dual signals for the input signal 'u' and delay time 'd'.
307310
Tree dual_u = self(u_tree);
308311
Tree dual_d = self(d_tree);
309312

310-
Tree primal_u = hd(dual_u);
311-
Tree tangent_u = hd(tl(dual_u)); // This is u'
312-
Tree primal_d = hd(dual_d);
313-
Tree tangent_d = hd(tl(dual_d)); // This is d'
313+
Tree primal_u = primal(dual_u);
314+
Tree tangent_u = tangent(dual_u); // This is u'
315+
Tree primal_d = primal(dual_d);
316+
Tree tangent_d = tangent(dual_d); // This is d'
314317

315318
// The new primal is the variable delay applied to the primal signal.
316319
Tree new_primal = sigDelay(primal_u, primal_d);
@@ -335,79 +338,79 @@ Tree ForwardADSignalTransform::transformation(Tree sig)
335338
// The final tangent is the sum of the two terms.
336339
Tree new_tangent = sigAdd(term1, term2);
337340

338-
return cons(new_primal, cons(new_tangent, gGlobal->nil));
341+
return dual(new_primal, new_tangent);
339342
}
340343

341344
// Handle conditional selection (multiplexer).
342345
Tree sel, tx, ty;
343346
if (isSigSelect2(sig, sel, tx, ty)) {
344-
Tree dual_sel = self(sel);
345-
Tree dual_x = self(tx);
346-
Tree dual_y = self(ty);
347-
348-
Tree primal_sel = hd(dual_sel);
349-
Tree new_primal = sigSelect2(primal_sel, hd(dual_x), hd(dual_y));
350-
Tree new_tangent = sigSelect2(primal_sel, hd(tl(dual_x)), hd(tl(dual_y)));
351-
352-
return cons(new_primal, cons(new_tangent, gGlobal->nil));
347+
Tree dual_sel = self(sel);
348+
Tree dual_x = self(tx);
349+
Tree dual_y = self(ty);
350+
Tree primal_sel = primal(dual_sel);
351+
Tree new_primal = sigSelect2(primal_sel, primal(dual_x), primal(dual_y));
352+
Tree new_tangent = sigSelect2(primal_sel, tangent(dual_x), tangent(dual_y));
353+
return dual(new_primal, new_tangent);
353354
}
354355

355356
// Handle prefix operator (one-sample initialization).
356357
if (isSigPrefix(sig, x, y)) {
357-
Tree dual_x = self(x);
358-
Tree dual_y = self(y);
359-
360-
Tree new_primal = sigPrefix(hd(dual_x), hd(dual_y));
361-
Tree new_tangent = sigPrefix(hd(tl(dual_x)), hd(tl(dual_y)));
362-
363-
return cons(new_primal, cons(new_tangent, gGlobal->nil));
358+
Tree dual_x = self(x);
359+
Tree dual_y = self(y);
360+
Tree new_primal = sigPrefix(primal(dual_x), primal(dual_y));
361+
Tree new_tangent = sigPrefix(tangent(dual_x), tangent(dual_y));
362+
return dual(new_primal, new_tangent);
364363
}
365364

366365
// Handle type casting.
367366
if (isSigFloatCast(sig, x)) {
368367
Tree dual_x = self(x);
369-
Tree new_primal = sigFloatCast(hd(dual_x));
370-
Tree new_tangent = sigFloatCast(hd(tl(dual_x)));
371-
return cons(new_primal, cons(new_tangent, gGlobal->nil));
368+
Tree new_primal = sigFloatCast(primal(dual_x));
369+
Tree new_tangent = sigFloatCast(tangent(dual_x));
370+
return dual(new_primal, new_tangent);
372371
}
373372

374373
if (isSigIntCast(sig, x)) {
375374
Tree dual_x = self(x);
376-
Tree new_primal = sigIntCast(hd(dual_x));
375+
Tree new_primal = sigIntCast(primal(dual_x));
377376
Tree new_tangent = sigReal(0.0);
378-
return cons(new_primal, cons(new_tangent, gGlobal->nil));
377+
return dual(new_primal, new_tangent);
379378
}
380379

381380
// Handle recursion (rec/proj).
382381
// Based on the rule (rec(u))' = rec(u'), we differentiate 'through' the recursion.
383-
Tree rec_block;
384-
if (isRec(sig, x, y)) {
385-
/*
386-
if (getUserData(sig)) {
387-
return getUserData(sig);
388-
}
389-
*/
382+
Tree rec_block, var, body;
383+
if (isRec(sig, var, body)) {
384+
if (isNil(body)) {
385+
// we are already visiting this recursive group
386+
return dual(sig, sig);
387+
} else {
388+
// first visit
389+
rec(var, gGlobal->nil); // to avoid infinite recursions
390390

391-
Tree dual_list = mapself(y);
392-
Tree primal_expressions = mapHd(dual_list);
393-
Tree tangent_expressions = mapHdTl(dual_list);
391+
// compute the recursion
392+
Tree dual_list = mapselfRec(body);
393+
Tree primal_expressions = mapHd(dual_list);
394+
Tree tangent_expressions = mapHdTl(dual_list);
394395

395-
Tree primal_rec = rec(x, primal_expressions);
396-
Tree tangent_rec = rec(x, tangent_expressions);
396+
// new symbol for differentiated rec
397+
Tree dvar(tree(unique("DW")));
397398

398-
Tree result = cons(primal_rec, cons(tangent_rec, gGlobal->nil));
399-
// setUserData(sig, result);
400-
return result;
399+
Tree primal_rec = rec(var, primal_expressions);
400+
Tree tangent_rec = rec(dvar, tangent_expressions);
401+
return dual(primal_rec, tangent_rec);
402+
}
401403
}
402404

403405
int rec_idx;
404406
if (isProj(sig, &rec_idx, rec_block)) {
405407
Tree dual_rec_block = self(rec_block);
406-
Tree primal_rec = hd(dual_rec_block);
407-
Tree tangent_rec = hd(tl(dual_rec_block));
408+
Tree primal_rec = primal(dual_rec_block);
409+
Tree tangent_rec = tangent(dual_rec_block);
408410
Tree new_primal = sigProj(rec_idx, primal_rec);
409411
Tree new_tangent = sigProj(rec_idx, tangent_rec);
410-
return cons(new_primal, cons(new_tangent, gGlobal->nil));
412+
413+
return dual(new_primal, new_tangent);
411414
}
412415

413416
// Handle output nodes.
@@ -416,7 +419,7 @@ Tree ForwardADSignalTransform::transformation(Tree sig)
416419
}
417420

418421
// Fallback for any unhandled signal types. Treat them as constants.
419-
return cons(sig, cons(sigReal(0.0), gGlobal->nil));
422+
return dual(sig, sigReal(0.0));
420423
}
421424

422425
/**
@@ -452,7 +455,7 @@ siglist generateADSignals(const siglist& outputs_list_aux)
452455
Tree dual_signal = ad_transform.self(out_sig);
453456

454457
// The result is a dual number (primal, tangent). We only need the tangent part.
455-
Tree derivative_signal = hd(tl(dual_signal));
458+
Tree derivative_signal = tangent(dual_signal);
456459

457460
// Add the new derivative signal to our master list of signals.
458461
all_signals = cons(derivative_signal, all_signals);

0 commit comments

Comments
 (0)