@@ -189,6 +189,11 @@ struct DAE : public Pass {
189
189
190
190
bool optimize = false ;
191
191
192
+ Index numFunctions;
193
+
194
+ // Map of function names to indexes. This lets us use indexes below for speed.
195
+ std::unordered_map<Name, Index> indexes;
196
+
192
197
void run (Module* module ) override {
193
198
DAEFunctionInfoMap infoMap;
194
199
// Ensure all entries exist so the parallel threads don't modify the data
@@ -199,6 +204,12 @@ struct DAE : public Pass {
199
204
// The null name represents module-level code (not in a function).
200
205
infoMap[Name ()];
201
206
207
+ numFunctions = module ->functions .size ();
208
+
209
+ for (Index i = 0 ; i < numFunctions; i++) {
210
+ indexes[module ->functions [i]->name ] = i;
211
+ }
212
+
202
213
// Iterate to convergence.
203
214
while (1 ) {
204
215
if (!iteration (module , infoMap)) {
@@ -234,34 +245,36 @@ struct DAE : public Pass {
234
245
Call* call;
235
246
Function* func;
236
247
};
237
- std::map<Name, std::vector<Call*>> allCalls;
238
- std::unordered_set<Name> tailCallees;
239
- std::unordered_set<Name> hasUnseenCalls;
248
+
249
+ std::vector<std::vector<Call*>> allCalls (numFunctions);
250
+ std::vector<bool > tailCallees (numFunctions);
251
+ std::vector<bool > hasUnseenCalls (numFunctions);
252
+
240
253
// Track the function in which relevant expressions exist. When we modify
241
254
// those expressions we will need to mark the function's info as stale.
242
255
std::unordered_map<Expression*, Name> expressionFuncs;
243
256
for (auto & [func, info] : infoMap) {
244
257
for (auto & [name, calls] : info.calls ) {
245
- auto & allCallsToName = allCalls[name];
258
+ auto & allCallsToName = allCalls[indexes[ name] ];
246
259
allCallsToName.insert (allCallsToName.end (), calls.begin (), calls.end ());
247
260
for (auto * call : calls) {
248
261
expressionFuncs[call] = func;
249
262
}
250
263
}
251
264
for (auto & callee : info.tailCallees ) {
252
- tailCallees. insert ( callee) ;
265
+ tailCallees[indexes[ callee]] = true ;
253
266
}
254
267
for (auto & [call, dropp] : info.droppedCalls ) {
255
268
allDroppedCalls[call] = dropp;
256
269
}
257
270
for (auto & name : info.hasUnseenCalls ) {
258
- hasUnseenCalls. insert ( name) ;
271
+ hasUnseenCalls[indexes[ name]] = true ;
259
272
}
260
273
}
261
274
// Exports are considered unseen calls.
262
275
for (auto & curr : module ->exports ) {
263
276
if (curr->kind == ExternalKind::Function) {
264
- hasUnseenCalls. insert ( *curr->getInternalName ()) ;
277
+ hasUnseenCalls[indexes[ *curr->getInternalName ()]] = true ;
265
278
}
266
279
}
267
280
@@ -300,23 +313,32 @@ struct DAE : public Pass {
300
313
301
314
// We now have a mapping of all call sites for each function, and can look
302
315
// for optimization opportunities.
303
- for (auto & [name, calls] : allCalls) {
316
+ for (Index index = 0 ; index < numFunctions; index++) {
317
+ auto * func = module ->functions [index].get ();
318
+ if (func->imported ()) {
319
+ continue ;
320
+ }
304
321
// We can only optimize if we see all the calls and can modify them.
305
- if (hasUnseenCalls.count (name)) {
322
+ if (hasUnseenCalls[index]) {
323
+ continue ;
324
+ }
325
+ auto & calls = allCalls[index];
326
+ if (calls.empty ()) {
327
+ // Nothing calls this, so it is not worth optimizing.
306
328
continue ;
307
329
}
308
- auto * func = module ->getFunction (name);
309
330
// Refine argument types before doing anything else. This does not
310
331
// affect whether an argument is used or not, it just refines the type
311
332
// where possible.
333
+ auto name = func->name ;
312
334
if (refineArgumentTypes (func, calls, module , infoMap[name])) {
313
335
worthOptimizing.insert (func);
314
336
markStale (func->name );
315
337
}
316
338
// Refine return types as well.
317
339
if (refineReturnTypes (func, calls, module )) {
318
340
refinedReturnTypes = true ;
319
- markStale (func-> name );
341
+ markStale (name);
320
342
markCallersStale (calls);
321
343
}
322
344
auto optimizedIndexes =
@@ -337,21 +359,29 @@ struct DAE : public Pass {
337
359
ReFinalize ().run (getPassRunner (), module );
338
360
}
339
361
// We now know which parameters are unused, and can potentially remove them.
340
- for (auto & [name, calls] : allCalls) {
341
- if (hasUnseenCalls.count (name)) {
362
+ for (Index index = 0 ; index < numFunctions; index++) {
363
+ auto * func = module ->functions [index].get ();
364
+ if (func->imported ()) {
365
+ continue ;
366
+ }
367
+ if (hasUnseenCalls[index]) {
342
368
continue ;
343
369
}
344
- auto * func = module ->getFunction (name);
345
370
auto numParams = func->getNumParams ();
346
371
if (numParams == 0 ) {
347
372
continue ;
348
373
}
374
+ auto & calls = allCalls[index];
375
+ if (calls.empty ()) {
376
+ continue ;
377
+ }
378
+ auto name = func->name ;
349
379
auto [removedIndexes, outcome] = ParamUtils::removeParameters (
350
380
{func}, infoMap[name].unusedParams , calls, {}, module , getPassRunner ());
351
381
if (!removedIndexes.empty ()) {
352
382
// Success!
353
383
worthOptimizing.insert (func);
354
- markStale (func-> name );
384
+ markStale (name);
355
385
markCallersStale (calls);
356
386
}
357
387
if (outcome == ParamUtils::RemovalOutcome::Failure) {
@@ -363,25 +393,28 @@ struct DAE : public Pass {
363
393
// modified allCalls (we can't modify a call site twice in one iteration,
364
394
// once to remove a param, once to drop the return value).
365
395
if (worthOptimizing.empty ()) {
366
- for (auto & func : module ->functions ) {
396
+ for (Index index = 0 ; index < numFunctions; index++) {
397
+ auto & func = module ->functions [index];
398
+ if (func->imported ()) {
399
+ continue ;
400
+ }
367
401
if (func->getResults () == Type::none) {
368
402
continue ;
369
403
}
370
- auto name = func->name ;
371
- if (hasUnseenCalls.count (name)) {
404
+ if (hasUnseenCalls[index]) {
372
405
continue ;
373
406
}
407
+ auto name = func->name ;
374
408
if (infoMap[name].hasTailCalls ) {
375
409
continue ;
376
410
}
377
- if (tailCallees. count (name) ) {
411
+ if (tailCallees[index] ) {
378
412
continue ;
379
413
}
380
- auto iter = allCalls. find (name) ;
381
- if (iter == allCalls. end ()) {
414
+ auto & calls = allCalls[index] ;
415
+ if (calls. empty ()) {
382
416
continue ;
383
417
}
384
- auto & calls = iter->second ;
385
418
bool allDropped =
386
419
std::all_of (calls.begin (), calls.end (), [&](Call* call) {
387
420
return allDroppedCalls.count (call);
@@ -398,7 +431,7 @@ struct DAE : public Pass {
398
431
// TODO Removing a drop may also open optimization opportunities in the
399
432
// callers.
400
433
worthOptimizing.insert (func.get ());
401
- markStale (func-> name );
434
+ markStale (name);
402
435
markCallersStale (calls);
403
436
}
404
437
}
0 commit comments