Skip to content

Commit 4b1e557

Browse files
authored
[NFC] Use indexes in DeadArgumentElimination (#7868)
Using vectors indexed by function makes this pass 15% faster on large testcases. Previously we were doing a great number of hashtable inserts using function names.
1 parent 7156ba3 commit 4b1e557

File tree

1 file changed

+56
-23
lines changed

1 file changed

+56
-23
lines changed

src/passes/DeadArgumentElimination.cpp

Lines changed: 56 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,11 @@ struct DAE : public Pass {
189189

190190
bool optimize = false;
191191

192+
Index numFunctions;
193+
194+
// Map of function names to indexes. This lets us use indexes below for speed.
195+
std::unordered_map<Name, Index> indexes;
196+
192197
void run(Module* module) override {
193198
DAEFunctionInfoMap infoMap;
194199
// Ensure all entries exist so the parallel threads don't modify the data
@@ -199,6 +204,12 @@ struct DAE : public Pass {
199204
// The null name represents module-level code (not in a function).
200205
infoMap[Name()];
201206

207+
numFunctions = module->functions.size();
208+
209+
for (Index i = 0; i < numFunctions; i++) {
210+
indexes[module->functions[i]->name] = i;
211+
}
212+
202213
// Iterate to convergence.
203214
while (1) {
204215
if (!iteration(module, infoMap)) {
@@ -234,34 +245,36 @@ struct DAE : public Pass {
234245
Call* call;
235246
Function* func;
236247
};
237-
std::map<Name, std::vector<Call*>> allCalls;
238-
std::unordered_set<Name> tailCallees;
239-
std::unordered_set<Name> hasUnseenCalls;
248+
249+
std::vector<std::vector<Call*>> allCalls(numFunctions);
250+
std::vector<bool> tailCallees(numFunctions);
251+
std::vector<bool> hasUnseenCalls(numFunctions);
252+
240253
// Track the function in which relevant expressions exist. When we modify
241254
// those expressions we will need to mark the function's info as stale.
242255
std::unordered_map<Expression*, Name> expressionFuncs;
243256
for (auto& [func, info] : infoMap) {
244257
for (auto& [name, calls] : info.calls) {
245-
auto& allCallsToName = allCalls[name];
258+
auto& allCallsToName = allCalls[indexes[name]];
246259
allCallsToName.insert(allCallsToName.end(), calls.begin(), calls.end());
247260
for (auto* call : calls) {
248261
expressionFuncs[call] = func;
249262
}
250263
}
251264
for (auto& callee : info.tailCallees) {
252-
tailCallees.insert(callee);
265+
tailCallees[indexes[callee]] = true;
253266
}
254267
for (auto& [call, dropp] : info.droppedCalls) {
255268
allDroppedCalls[call] = dropp;
256269
}
257270
for (auto& name : info.hasUnseenCalls) {
258-
hasUnseenCalls.insert(name);
271+
hasUnseenCalls[indexes[name]] = true;
259272
}
260273
}
261274
// Exports are considered unseen calls.
262275
for (auto& curr : module->exports) {
263276
if (curr->kind == ExternalKind::Function) {
264-
hasUnseenCalls.insert(*curr->getInternalName());
277+
hasUnseenCalls[indexes[*curr->getInternalName()]] = true;
265278
}
266279
}
267280

@@ -300,23 +313,32 @@ struct DAE : public Pass {
300313

301314
// We now have a mapping of all call sites for each function, and can look
302315
// for optimization opportunities.
303-
for (auto& [name, calls] : allCalls) {
316+
for (Index index = 0; index < numFunctions; index++) {
317+
auto* func = module->functions[index].get();
318+
if (func->imported()) {
319+
continue;
320+
}
304321
// We can only optimize if we see all the calls and can modify them.
305-
if (hasUnseenCalls.count(name)) {
322+
if (hasUnseenCalls[index]) {
323+
continue;
324+
}
325+
auto& calls = allCalls[index];
326+
if (calls.empty()) {
327+
// Nothing calls this, so it is not worth optimizing.
306328
continue;
307329
}
308-
auto* func = module->getFunction(name);
309330
// Refine argument types before doing anything else. This does not
310331
// affect whether an argument is used or not, it just refines the type
311332
// where possible.
333+
auto name = func->name;
312334
if (refineArgumentTypes(func, calls, module, infoMap[name])) {
313335
worthOptimizing.insert(func);
314336
markStale(func->name);
315337
}
316338
// Refine return types as well.
317339
if (refineReturnTypes(func, calls, module)) {
318340
refinedReturnTypes = true;
319-
markStale(func->name);
341+
markStale(name);
320342
markCallersStale(calls);
321343
}
322344
auto optimizedIndexes =
@@ -337,21 +359,29 @@ struct DAE : public Pass {
337359
ReFinalize().run(getPassRunner(), module);
338360
}
339361
// We now know which parameters are unused, and can potentially remove them.
340-
for (auto& [name, calls] : allCalls) {
341-
if (hasUnseenCalls.count(name)) {
362+
for (Index index = 0; index < numFunctions; index++) {
363+
auto* func = module->functions[index].get();
364+
if (func->imported()) {
365+
continue;
366+
}
367+
if (hasUnseenCalls[index]) {
342368
continue;
343369
}
344-
auto* func = module->getFunction(name);
345370
auto numParams = func->getNumParams();
346371
if (numParams == 0) {
347372
continue;
348373
}
374+
auto& calls = allCalls[index];
375+
if (calls.empty()) {
376+
continue;
377+
}
378+
auto name = func->name;
349379
auto [removedIndexes, outcome] = ParamUtils::removeParameters(
350380
{func}, infoMap[name].unusedParams, calls, {}, module, getPassRunner());
351381
if (!removedIndexes.empty()) {
352382
// Success!
353383
worthOptimizing.insert(func);
354-
markStale(func->name);
384+
markStale(name);
355385
markCallersStale(calls);
356386
}
357387
if (outcome == ParamUtils::RemovalOutcome::Failure) {
@@ -363,25 +393,28 @@ struct DAE : public Pass {
363393
// modified allCalls (we can't modify a call site twice in one iteration,
364394
// once to remove a param, once to drop the return value).
365395
if (worthOptimizing.empty()) {
366-
for (auto& func : module->functions) {
396+
for (Index index = 0; index < numFunctions; index++) {
397+
auto& func = module->functions[index];
398+
if (func->imported()) {
399+
continue;
400+
}
367401
if (func->getResults() == Type::none) {
368402
continue;
369403
}
370-
auto name = func->name;
371-
if (hasUnseenCalls.count(name)) {
404+
if (hasUnseenCalls[index]) {
372405
continue;
373406
}
407+
auto name = func->name;
374408
if (infoMap[name].hasTailCalls) {
375409
continue;
376410
}
377-
if (tailCallees.count(name)) {
411+
if (tailCallees[index]) {
378412
continue;
379413
}
380-
auto iter = allCalls.find(name);
381-
if (iter == allCalls.end()) {
414+
auto& calls = allCalls[index];
415+
if (calls.empty()) {
382416
continue;
383417
}
384-
auto& calls = iter->second;
385418
bool allDropped =
386419
std::all_of(calls.begin(), calls.end(), [&](Call* call) {
387420
return allDroppedCalls.count(call);
@@ -398,7 +431,7 @@ struct DAE : public Pass {
398431
// TODO Removing a drop may also open optimization opportunities in the
399432
// callers.
400433
worthOptimizing.insert(func.get());
401-
markStale(func->name);
434+
markStale(name);
402435
markCallersStale(calls);
403436
}
404437
}

0 commit comments

Comments
 (0)