From 2c146b6ba213fbed5bcf3da2995c49551117916e Mon Sep 17 00:00:00 2001 From: Thomas Lively Date: Tue, 2 Dec 2025 22:24:17 -0800 Subject: [PATCH] [WIP] Rewrite DAE to use a fixed point analysis DAE can be slow because it performs several rounds of interleaved analysis and optimization. On top of this, the analysis it performs is not as precise as it could be because it never removes parameters from referenced functions and it cannot optimize unused parameters or results that are forwarded through recursive cycles. Start improving both the performance and the power of DAE by creating a new pass, called DAE2 for now. DAE2 performs a single parallel walk of the module to collect information with which it performs a fixed point analysis to find unused parameters, then does a single parallel walk of the module to optimize based on this analysis. --- src/passes/CMakeLists.txt | 1 + src/passes/DeadArgumentElimination2.cpp | 342 ++++++++++++++++++++++++ src/passes/pass.cpp | 1 + src/passes/passes.h | 1 + 4 files changed, 345 insertions(+) create mode 100644 src/passes/DeadArgumentElimination2.cpp diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt index e9818681a56..d14644cc53c 100644 --- a/src/passes/CMakeLists.txt +++ b/src/passes/CMakeLists.txt @@ -28,6 +28,7 @@ set(passes_SOURCES ConstHoisting.cpp DataFlowOpts.cpp DeadArgumentElimination.cpp + DeadArgumentElimination2.cpp DeadCodeElimination.cpp DeAlign.cpp DebugLocationPropagation.cpp diff --git a/src/passes/DeadArgumentElimination2.cpp b/src/passes/DeadArgumentElimination2.cpp new file mode 100644 index 00000000000..d4f4fc8476e --- /dev/null +++ b/src/passes/DeadArgumentElimination2.cpp @@ -0,0 +1,342 @@ +/* + * Copyright 2025 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// As a POC, only do the backward analyis to find unused parameters, including +// those that appear to be used because they are forwarded on to another call +// but are then unused by that call. +// +// To match and exceed the power of DAE, we will need to extend this backward +// analysis to find unused results as well, and also add a forward analysis that +// propagates constants and types through parameters and results. + +#include +#include +#include + +#include "analysis/lattices/bool.h" +#include "ir/local-graph.h" +#include "pass.h" +#include "support/index.h" +#include "support/utilities.h" +#include "wasm-traversal.h" +#include "wasm.h" + +namespace wasm { + +namespace { + +// Analysis lattice: top/true = used, bot/false = unused. +using Used = analysis::Bool; + +// Function index and parameter index. +using ParamLoc = std::pair; + +// A set of (source, destination) index pairs for parameters of a caller +// function being forwarded as arguments to a called function. +using ForwardedParamSet = std::unordered_set>; + +struct FunctionInfo { + // Analysis results. + // TODO: Fix Bool to wrap its element in a struct so we can store it directly + // in a vector without getting the bool overload. + std::vector> paramUsages; + + // Map callee function names to their forwarded params for direct calls. + std::unordered_map directForwardedParams; + + // Map callee types to their forwarded params for indirect calls. + std::unordered_map indirectForwardedParams; + + // For each parameter of this function, the list of parameters in direct + // callers that will become used if the parameter in this function turns out + // to be used. Computed by reversing the directForwardedParams graph. + std::vector> callerParams; + + // Whether we need to additionally propagate param usage to indirect callers + // of this function's type. Atomic because it can be set when visiting other + // functions in parallel. + std::atomic referenced = false; +}; + +struct GraphBuilder : public WalkerPass> { + // Analysis lattice. + const Used& used; + + // The function info graph is stored as vectors accessed by function index. + // Map function names to their indices. + const std::unordered_map& funcIndices; + + // Vector of analysis info representing the analysis graph we are building. + // This is populated safely in parallel because the visitor for each function + // only modifies the entry for that function. + std::vector& funcInfos; + + // The index of the function we are currently walking. + Index index = -1; + + // A use of a parameter local does not necessarily imply the use of the + // parameter value. We use a local graph to check where parameter values may + // be used. + std::optional localGraph; + + GraphBuilder(const Used& used, + const std::unordered_map& funcIndices, + std::vector& funcInfos) + : used(used), funcIndices(funcIndices), funcInfos(funcInfos) {} + + bool isFunctionParallel() override { return true; } + bool modifiesBinaryenIR() override { return false; } + + std::unique_ptr create() override { + return std::make_unique(used, funcIndices, funcInfos); + } + + void runOnFunction(Module* wasm, Function* func) override { + assert(index == Index(-1)); + index = funcIndices.at(func->name); + assert(index < funcInfos.size()); + if (func->imported()) { + // We must assume imported functions use all their parameters. + auto& usages = funcInfos[index].paramUsages; + assert(usages.empty()); + usages.insert(usages.end(), func->getNumParams(), used.getTop()); + } else { + localGraph.emplace(func); + using Super = WalkerPass>; + Super::runOnFunction(wasm, func); + } + } + + void visitRefFunc(RefFunc* curr) { + funcInfos[funcIndices.at(curr->func)].referenced = true; + } + + Index getArgIndex(const ExpressionList& operands, Expression* arg) { + for (Index i = 0; i < operands.size(); ++i) { + if (operands[i] == arg) { + return i; + } + } + WASM_UNREACHABLE("expected arg"); + } + + void handleDirectForwardedParam(LocalGet* curr, Call* call) { + auto argIndex = getArgIndex(call->operands, curr); + auto& forwarded = funcInfos[index].directForwardedParams[call->target]; + forwarded.insert({curr->index, argIndex}); + } + + void handleIndirectForwardedParam(LocalGet* curr, + const ExpressionList& operands, + HeapType type) { + auto argIndex = getArgIndex(operands, curr); + auto& forwarded = funcInfos[index].indirectForwardedParams[type]; + forwarded.insert({curr->index, argIndex}); + } + + void visitLocalGet(LocalGet* curr) { + if (curr->index >= getFunction()->getNumParams()) { + // Not a parameter. + return; + } + + const auto& sets = localGraph->getSets(curr); + bool usesParam = std::any_of( + sets.begin(), sets.end(), [](LocalSet* set) { return set == nullptr; }); + + if (!usesParam) { + // The original parameter value does not reach here. + return; + } + + auto* parent = getParent(); + if (auto* call = parent->dynCast()) { + handleDirectForwardedParam(curr, call); + } else if (auto* call = parent->dynCast()) { + handleIndirectForwardedParam(curr, call->operands, call->heapType); + } else if (auto* call = parent->dynCast()) { + if (!call->target->type.isSignature()) { + // The call will never happen, so we don't need to consider it. + return; + } + auto heapType = call->target->type.getHeapType(); + handleIndirectForwardedParam(curr, call->operands, heapType); + } else { + // The parameter value is used by something other than a call. We could + // check whether the user is a drop, but for simplicity we assume that + // Vacuum would have already removed such patterns. + funcInfos[index].paramUsages[curr->index] = used.getTop(); + } + } +}; + +struct DAE2 : public Pass { + // Analysis lattice. + Used used; + + // Map function name to index. + std::unordered_map funcIndices; + + // The intermediate and final analysis results by function index. + std::vector funcInfos; + + // For each parameter in each indirectly called type, the set of forwarded + // params in the callers that need to be marked used if a param of a callee of + // that type is used. + std::unordered_map>> + indirectCallerParams; + + Module* wasm = nullptr; + + void run(Module* wasm) override { + this->wasm = wasm; + for (auto& func : wasm->functions) { + funcIndices.insert({func->name, funcIndices.size()}); + } + analyzeModule(wasm); + prepareAnalysis(); + computeFixedPoint(); + optimize(); + } + + void analyzeModule(Module* wasm) { + funcInfos.resize(wasm->functions.size()); + + // Analyze functions to find forwarded and used parameters as well as + // function references. + GraphBuilder builder(used, funcIndices, funcInfos); + builder.run(getPassRunner(), wasm); + + // Find additional function references at the module level. + builder.walkModuleCode(wasm); + + // Mark parameters of exported functions as used. + for (auto& export_ : wasm->exports) { + if (export_->kind == ExternalKind::Function) { + auto name = *export_->getInternalName(); + auto& usages = funcInfos[funcIndices.at(name)].paramUsages; + std::fill(usages.begin(), usages.end(), used.getTop()); + } + } + + // TODO: Find function types that escape the module beyond exported + // functions (or just use all public function types as a conservative + // approximation) and mark parameters of referenced funtions of those types + // as used. + } + + void prepareAnalysis() { + // Compute the reverse graph used by the fixed point analysis from the + // forward graph we have built. + for (Index i = 0; i < funcInfos.size(); ++i) { + funcInfos[i].callerParams.resize(funcInfos[i].paramUsages.size()); + } + for (Index callerIndex = 0; callerIndex < funcInfos.size(); ++callerIndex) { + for (auto& [callee, forwarded] : + funcInfos[callerIndex].directForwardedParams) { + auto& callerParams = funcInfos[funcIndices.at(callee)].callerParams; + for (auto& [srcParam, destParam] : forwarded) { + callerParams[destParam].push_back({callerIndex, srcParam}); + } + } + for (auto& [calleeType, forwarded] : + funcInfos[callerIndex].indirectForwardedParams) { + auto& callerParams = indirectCallerParams[calleeType]; + callerParams.resize(funcInfos[callerIndex].paramUsages.size()); + for (auto& [srcParam, destParam] : forwarded) { + callerParams[destParam].push_back({callerIndex, srcParam}); + } + } + } + } + + bool join(ParamLoc loc, const Used::Element& other) { + auto& elem = std::get<0>(funcInfos[loc.first].paramUsages[loc.second]); + return used.join(elem, other); + } + + void computeFixedPoint() { + // List of params from which we may need to propagate usage information. + // Initialized with all params we have observed to be used in the IR. + std::vector work; + for (Index i = 0; i < funcInfos.size(); ++i) { + for (Index j = 0; j < funcInfos[i].paramUsages.size(); ++j) { + work.push_back({i, j}); + } + } + while (!work.empty()) { + auto [calleeIndex, calleeParamIndex] = work.back(); + work.pop_back(); + + const auto& elem = + std::get<0>(funcInfos[calleeIndex].paramUsages[calleeParamIndex]); + assert(elem && "unexpected unused param"); + + // Propagate back to forwarded params of direct callers. + auto& callerParams = + funcInfos[calleeIndex].callerParams[calleeParamIndex]; + for (auto param : callerParams) { + if (join(param, elem)) { + work.push_back(param); + } + } + + if (!funcInfos[calleeIndex].referenced) { + // Non-referenced functions can only be called directly. + continue; + } + + // Propagate usage back to forwarded params of the indirect callers of all + // supertypes of this function's type. + for (std::optional type = + wasm->functions[calleeIndex]->type.getHeapType(); + type; + type = type->getDeclaredSuperType()) { + auto it = indirectCallerParams.find(*type); + if (it == indirectCallerParams.end()) { + continue; + } + auto& callerParams = it->second[calleeParamIndex]; + for (auto param : callerParams) { + if (join(param, elem)) { + work.push_back(param); + } + } + } + + // TODO: Propagate usage to all functions of any type in the type tree of + // this function's type to keep subtyping valid. + } + } + + void optimize() { + struct Optimizer : public WalkerPass> { + // TODO: Visit functions in parallel, replacing unused parameters with + // locals. Direct calls should look at their target to determine which + // operands to remove (being sure to preserve side effects using + // ChildLocalizer). Indirect calls need to look at the analysis results + // for the target type (TODO: materialize this, possibly just for the root + // type for each type tree) to determine what operands to remove. + }; + Optimizer{}.run(getPassRunner(), wasm); + } +}; + +} // anonymous namespace + +Pass* createDAE2Pass() { return new DAE2(); } + +} // namespace wasm diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index 56a07ca2d44..16cb93b9c72 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -105,6 +105,7 @@ void PassRegistry::registerPasses() { "removes arguments to calls in an lto-like manner, and " "optimizes where we removed", createDAEOptimizingPass); + registerPass("dae2", "Experimental reimplementation of DAE", createDAE2Pass); registerPass("abstract-type-refining", "refine and merge abstract (never-created) types", createAbstractTypeRefiningPass); diff --git a/src/passes/passes.h b/src/passes/passes.h index e8223e0bac8..d27189ec8d4 100644 --- a/src/passes/passes.h +++ b/src/passes/passes.h @@ -35,6 +35,7 @@ Pass* createConstantFieldPropagationPass(); Pass* createConstantFieldPropagationRefTestPass(); Pass* createDAEPass(); Pass* createDAEOptimizingPass(); +Pass* createDAE2Pass(); Pass* createDataFlowOptsPass(); Pass* createDeadCodeEliminationPass(); Pass* createDeInstrumentBranchHintsPass();