.. Licensed to the Apache Software Foundation (ASF) under one
   or more contributor license agreements. See the NOTICE file
   distributed with this work for additional information
   regarding copyright ownership. The ASF licenses this file
   to you under the Apache License, Version 2.0 (the
   "License"); you may not use this file except in compliance
   with the License. You may obtain a copy of the License at

..   http://www.apache.org/licenses/LICENSE-2.0

.. Unless required by applicable law or agreed to in writing,
   software distributed under the License is distributed on an
   "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
   KIND, either express or implied. See the License for the
   specific language governing permissions and limitations
   under the License.

.. _fusion-arch:

Operator Fusion
===============

Operator fusion is one of the most impactful optimizations in TVM. Instead of launching one kernel
per operator (e.g., conv2d, bias_add, relu), fusion merges multiple operators into a single kernel,
eliminating intermediate memory allocations and kernel launch overhead.

TVM provides two complementary fusion mechanisms:

- **Automatic fusion** (``FuseOps`` + ``FuseTIR``): groups operators based on their computational
  patterns using a post-dominator analysis algorithm.
- **Pattern-based fusion** (``FuseOpsByPattern``): groups operators that match user-defined
  dataflow patterns, typically for offloading to external backends (cuBLAS, CUTLASS, DNNL, etc.).

Both produce the same output: Relax functions marked with ``Primitive=True`` that are later
lowered to fused TIR kernels or dispatched to external libraries.

Overview
--------

Fusion involves three passes:
.. code-block:: text

    IRModule (after LegalizeOps)
        │
        ▼ AnnotateTIROpPattern   ← label each op (elementwise, reduce, etc.)
    IRModule (annotated)
        │
        ▼ FuseOps                ← group ops into fused Relax functions
    IRModule (with fused functions marked Primitive=True)
        │
        ▼ FuseTIR                ← merge TIR PrimFuncs inside each group
    IRModule (fused TIR kernels)
In the compilation pipeline, these passes appear in the backend-specific ``legalize_passes``
phase. For example, the CUDA pipeline (``python/tvm/relax/backend/cuda/pipeline.py``) runs:

.. code-block:: python

    LegalizeOps()           # lower Relax ops to call_tir
    AnnotateTIROpPattern()  # annotate pattern kinds
    FoldConstant()
    FuseOps()               # group ops
    FuseTIR()               # merge TIR functions

Operator Pattern Classification
-------------------------------

Before fusion, ``AnnotateTIROpPattern`` analyzes each TIR function in the module and assigns
an ``OpPatternKind``. The fusion algorithm uses these pattern kinds to decide which operators
can be fused together.

.. list-table::
   :header-rows: 1
   :widths: 20 10 70

   * - Pattern Kind
     - Value
     - Description
   * - ``kElemWise``
     - 0
     - Elementwise: one-to-one input/output mapping (e.g., ``add``, ``relu``, ``exp``).
   * - ``kBroadcast``
     - 1
     - Broadcasting: output axes map to input axes in order, but some input axes may be
       broadcast (e.g., ``bias_add``). Note: ``transpose`` is **not** broadcast because axes
       are reordered.
   * - ``kInjective``
     - 2
     - Injective: each output element depends on a single input element, but the mapping may
       be non-trivial (e.g., ``reshape``, ``concatenate``, ``transpose``).
   * - ``kCommReduce``
     - 3
     - Commutative reduction: output elements aggregate over input elements
       (e.g., ``sum``, ``max``, ``mean``).
   * - ``kOutEWiseFusable``
     - 4
     - Complex operation whose output can accept elementwise followers, but cannot chain
       with another complex op (e.g., ``conv2d``, ``matmul``, ``dense``).
   * - ``kTuple``
     - 7
     - Tuple node. Can fuse into subsequent injective ops but is treated specially.
   * - ``kOpaque``
     - 8
     - Opaque: cannot be fused (e.g., external function calls, operations with side effects).

These kinds form an ordering: lower values are "simpler" and more fusable. The fusion algorithm
uses ``CombinePattern(lhs, rhs) = max(lhs, rhs)`` when merging patterns along a path.
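
To see these kinds in practice, the sketch below legalizes a tiny module and prints each
PrimFunc's annotation (assuming the pass records the kind in the ``op_pattern`` function
attribute):

.. code-block:: python

    import tvm
    from tvm import relax
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor((10, 20), "float32")) -> R.Tensor((10, 20), "float32"):
            with R.dataflow():
                lv = R.add(x, x)    # elementwise -> expect kind 0
                gv = R.nn.relu(lv)  # elementwise -> expect kind 0
                R.output(gv)
            return gv

    mod = relax.transform.LegalizeOps()(Module)        # Relax ops -> TIR PrimFuncs
    mod = relax.transform.AnnotateTIROpPattern()(mod)  # attach OpPatternKind
    for gvar, func in mod.functions.items():
        if isinstance(func, tvm.tir.PrimFunc):
            print(gvar.name_hint, int(func.attrs["op_pattern"]))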


FuseOps: Automatic Fusion
-------------------------

``FuseOps`` (``src/relax/transform/fuse_ops.cc``) groups bindings in a dataflow block into
new Relax functions. It operates only within ``DataflowBlock``\ s — if your module doesn't have
any, run ``ConvertToDataflow`` first.
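
A minimal sketch of that preparation step (assuming ``mod`` is your ``IRModule``):

.. code-block:: python

    from tvm import relax

    # FuseOps only looks inside dataflow blocks; group plain bindings first
    mod = relax.transform.ConvertToDataflow()(mod)
    mod = relax.transform.FuseOps()(mod)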

Algorithm
~~~~~~~~~

The fusion algorithm addresses diamond-shaped dataflow branches, where a single producer
(e.g., conv2d) has multiple consumers that eventually reconverge:

.. code-block:: text

         conv2d
        /   |   \
       /    |    \
      op    op    op
       \    |    /
        \   |   /
      elemwise add

At the point of ``conv2d``, we don't know if all future paths will merge. The algorithm uses
**post-dominator analysis** to resolve this (a simplified sketch follows the steps below):

1. **Build forward graph**: construct an ``IndexedForwardGraph`` from the dataflow block.
   Each node has an ``OpPatternKind`` and a list of forward edges.

2. **Build post-dominator tree**: compute the immediate post-dominator of each node using
   Least Common Ancestor (LCA) on the DAG. The post-dominator of a node is the closest
   downstream node where **all** future paths converge.

3. **Fuse groups**: for each node in topological order, check if it can be fused with its
   immediate post-dominator:

   - **CheckPath**: verify that all paths from the node to its post-dominator satisfy the
     fusion conditions (pattern compatibility, depth limits, argument limits).
   - **CommitFuse**: mark all intermediate nodes as belonging to the same group using a
     Union-Find data structure.

4. **Create grouped functions**: extract each group into a new ``relax.Function`` with the
   attribute ``Primitive=True``. Replace the original bindings with a call to the grouped
   function.
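
A minimal Python sketch of the step-3 checks, paraphrasing the C++ logic in
``graph_partitioner.h`` rather than reproducing it (the ``graph`` object with its ``kind``
and ``edges`` maps is a hypothetical stand-in, not a real TVM API):

.. code-block:: python

    kElemWise, kBroadcast, kInjective, kCommReduce, kOutEWiseFusable = 0, 1, 2, 3, 4
    kTuple, kOpaque = 7, 8

    parent = {}  # Union-Find: node -> representative of its fusion group

    def find(x):
        root = parent.setdefault(x, x)
        if root != x:
            parent[x] = root = find(root)  # path compression
        return root

    def combine_pattern(lhs, rhs):
        return max(lhs, rhs)  # the "harder" pattern dominates along a path

    def check_path(graph, src, post_dom, allowed):
        """All intermediate nodes on every path src -> post_dom are <= allowed."""
        return all(
            nxt == post_dom
            or (graph.kind[nxt] <= allowed and check_path(graph, nxt, post_dom, allowed))
            for nxt in graph.edges[src]
        )

    def commit_fuse(graph, src, post_dom):
        """Union every node on the paths src -> post_dom into one group."""
        parent[find(src)] = find(post_dom)
        for nxt in graph.edges[src]:
            if nxt != post_dom:
                commit_fuse(graph, nxt, post_dom)

    # e.g., a conv2d (kOutEWiseFusable) fuses with its post-dominator when every
    # intermediate op is broadcast or simpler:
    #   if check_path(graph, conv, dom, kBroadcast):
    #       commit_fuse(graph, conv, dom)
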
Fusion rules
~~~~~~~~~~~~

The key fusion decisions depend on the ``OpPatternKind`` of the source, the path, and the
post-dominator. The algorithm runs in three phases (via ``GraphPartitioner::RunFuse``) so that
higher-complexity ops get a chance to fuse first:

- **Phase 0**: ``kOutEWiseFusable`` ops (e.g., ``conv2d``) can fuse with their elementwise
  post-dominator if all intermediate ops are broadcast or simpler. This enables patterns like
  conv2d + bias_add + relu. Two ``kOutEWiseFusable`` ops cannot fuse together.
- **Phase 1**: ``kInjective`` and ``kTuple`` ops can fuse only when all paths to the
  post-dominator are injective or simpler. This is deferred to phase 1 so that
  ``kOutEWiseFusable`` groups are finalized first.
- **Phase 2**: fuse injective ops into intermediate tuple nodes that have already been absorbed
  by subsequent injective groups.

``kElemWise`` / ``kBroadcast`` ops are processed in **every** phase (not restricted to one):
they can fuse into a post-dominator that is injective or reduction. The sink (final node) may
also be a ``kOutEWiseFusable`` group that was formed in phase 0 — this is how elementwise
producers merge into an existing conv2d fusion group.

Additional constraints:

- **Reduction** (``kCommReduce``) ops never initiate fusion — they act as sinks only. Elementwise
  and broadcast producers can fuse *into* a reduction, but a reduction cannot fuse forward.
- **Opaque** ops are fusion barriers.
- A group cannot exceed ``kMaxFusedOps`` (256) nodes or the maximum function argument count.
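
As a concrete illustration of the phase rules, in simplified pseudocode (op names only, not
runnable code):

.. code-block:: python

    # Phase 0: conv2d (kOutEWiseFusable) absorbs broadcast/elementwise followers
    lv0 = conv2d(x, w)      # kOutEWiseFusable -> starts a group
    lv1 = bias_add(lv0, b)  # kBroadcast       -> joins conv2d's group
    gv = relu(lv1)          # kElemWise        -> joins conv2d's group

    # Two kOutEWiseFusable ops never fuse: a second conv2d starts a new group.
    gv2 = conv2d(gv, w2)
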
Example
~~~~~~~

Consider two elementwise ops (``add``, ``exp``) followed by an injective op (``squeeze``).
The examples below are simplified pseudocode — real TVMScript would reference TIR functions
via ``cls.func_name``:

.. code-block:: python

    # Before FuseOps (simplified)
    @R.function
    def main(x: R.Tensor((10, 20), "float32")):
        with R.dataflow():
            lv0 = R.call_tir(add, (x, const_1), out_sinfo=R.Tensor((10, 20), "float32"))
            lv1 = R.call_tir(exp, (lv0,), out_sinfo=R.Tensor((10, 20), "float32"))
            gv = R.call_tir(squeeze, (lv1,), out_sinfo=R.Tensor((10, 20), "float32"))
            R.output(gv)
        return gv

After ``FuseOps``, all three are grouped into a single function:

.. code-block:: python

    # After FuseOps
    @R.function(private=True)
    def fused_add_exp_squeeze(x, p0):
        R.func_attr({"Primitive": True})
        with R.dataflow():
            lv0 = R.call_tir(add, (x, p0), ...)
            lv1 = R.call_tir(exp, (lv0,), ...)
            gv = R.call_tir(squeeze, (lv1,), ...)
            R.output(gv)
        return gv

    @R.function
    def main(x: R.Tensor((10, 20), "float32")):
        with R.dataflow():
            gv = fused_add_exp_squeeze(x, const_1)
            R.output(gv)
        return gv
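
The grouping can be reproduced by running the passes directly (assuming ``Module`` already
contains the legalized ``add``/``exp``/``squeeze`` calls shown above):

.. code-block:: python

    from tvm import relax

    mod = relax.transform.AnnotateTIROpPattern()(Module)
    mod = relax.transform.FuseOps()(mod)
    print(mod)  # main now calls fused_add_exp_squeeze (Primitive=True)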


FuseTIR: Merging TIR Functions
------------------------------

``FuseTIR`` (``src/relax/transform/fuse_tir.cc``) takes the grouped Relax functions produced by
``FuseOps`` and merges their internal TIR ``PrimFunc``\ s into a single TIR function.

Before ``FuseTIR``, a fused group still contains multiple ``R.call_tir`` calls to separate
TIR functions. ``FuseTIR`` inlines and merges them:

.. code-block:: text

    Before FuseTIR:
        fused_add_exp_squeeze:
            call_tir(add, ...)      → separate TIR PrimFunc
            call_tir(exp, ...)      → separate TIR PrimFunc
            call_tir(squeeze, ...)  → separate TIR PrimFunc

    After FuseTIR:
        fused_add_exp_squeeze:      → single merged TIR PrimFunc

The merged function eliminates intermediate buffers — the output of ``add`` is directly consumed
by ``exp`` without writing to and reading from global memory. This is the core performance benefit
of fusion.
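
A sketch of what the merged kernel might look like (illustrative only: the exact block and
buffer names produced by ``FuseTIR`` will differ, and ``squeeze`` degenerates to a copy on a
``(10, 20)`` input):

.. code-block:: python

    from tvm.script import tir as T

    @T.prim_func
    def fused_add_exp_squeeze(x: T.Buffer((10, 20), "float32"),
                              p0: T.Buffer((10, 20), "float32"),
                              out: T.Buffer((10, 20), "float32")):
        # Intermediates become function-local allocations rather than kernel
        # inputs/outputs; later scheduling (e.g., compute inlining) can remove
        # them from memory entirely.
        add_out = T.alloc_buffer((10, 20), "float32")
        exp_out = T.alloc_buffer((10, 20), "float32")
        for i, j in T.grid(10, 20):
            with T.block("add"):
                vi, vj = T.axis.remap("SS", [i, j])
                add_out[vi, vj] = x[vi, vj] + p0[vi, vj]
        for i, j in T.grid(10, 20):
            with T.block("exp"):
                vi, vj = T.axis.remap("SS", [i, j])
                exp_out[vi, vj] = T.exp(add_out[vi, vj])
        for i, j in T.grid(10, 20):
            with T.block("squeeze"):
                vi, vj = T.axis.remap("SS", [i, j])
                out[vi, vj] = exp_out[vi, vj]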

Internally, ``FuseTIR`` uses a ``SymbolicMatcher`` to align symbolic shape variables across the
TIR functions being merged, ensuring that dimensions are correctly mapped when combining buffer
accesses.


FuseOpsByPattern: Pattern-Based Fusion
--------------------------------------

While ``FuseOps`` makes fusion decisions automatically based on operator patterns,
``FuseOpsByPattern`` lets you specify exactly which operator combinations to fuse using
the Relax :ref:`Dataflow Pattern Language (DPL) <relax-dpl>`.

This is primarily used for **backend-specific dispatch**: identifying operator subgraphs that
should be offloaded to external libraries like cuBLAS, CUTLASS, cuDNN, or DNNL.

FusionPattern
~~~~~~~~~~~~~

A ``FusionPattern`` (``python/tvm/relax/transform/transform.py``) defines what to match:

.. code-block:: python

    from tvm.relax.dpl import wildcard, is_op
    from tvm.relax.transform import FusionPattern

    # Match: matmul(x, w) + bias
    x = wildcard()
    w = wildcard()
    bias = wildcard()
    matmul = is_op("relax.matmul")(x, w)
    out = is_op("relax.add")(matmul, bias)

    pattern = FusionPattern(
        name="cutlass.matmul_bias",
        pattern=out,
        annotation_patterns={"matmul": matmul, "bias": bias},
        check=my_check_function,  # optional validation
    )

Fields:

- ``name``: pattern identifier, typically prefixed with the backend name (e.g.,
  ``"cutlass.matmul_bias"``).
- ``pattern``: a DFPattern describing the subgraph to match. See the
  :ref:`DPL deep dive <relax-dpl>` for the full pattern language.
- ``annotation_patterns``: a mapping of names to sub-patterns within the main pattern. These
  are extracted during matching and made available to the ``check`` function and
  ``attrs_getter``.
- ``check``: an optional ``Callable[[PatternCheckContext], bool]`` that validates whether
  a match should be accepted. Receives the matched expression, annotated sub-expressions,
  variable usages, and binding information.
- ``attrs_getter``: an optional function that extracts attributes (e.g., transpose flags,
  data types) from the matched expressions to annotate the grouped function.

Applying patterns
~~~~~~~~~~~~~~~~~

.. code-block:: python

    from tvm.relax.transform import FuseOpsByPattern

    mod = FuseOpsByPattern(
        patterns=[pattern1, pattern2, ...],  # ordered by priority
        bind_constants=True,
        annotate_codegen=False,
    )(mod)

Key parameters:

- ``patterns``: a list of ``FusionPattern`` objects, ordered by priority. Higher-priority
  patterns come first — if a subgraph matches multiple patterns, the first match wins.
- ``bind_constants``: if ``True``, constants used by the matched subgraph are captured inside
  the grouped function.
- ``annotate_codegen``: if ``True``, wraps each composite function with an outer function
  annotated with ``"Codegen"`` and ``"global_symbol"`` attributes for external backend dispatch.
  The ``"Codegen"`` value is derived from the pattern name prefix (e.g., ``"dnnl"`` from
  ``"dnnl.conv2d_relu"``). The resulting nesting is sketched below.
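
A rough sketch of that nesting, in the same simplified pseudocode style as the earlier
examples (function and variable names are illustrative):

.. code-block:: python

    # With annotate_codegen=True, the outer wrapper carries the backend name...
    @R.function
    def fused_relax_matmul_relax_add_cutlass(x, w, bias):
        R.func_attr({"Codegen": "cutlass"})

        # ...while the inner function carries the Composite pattern name that
        # the external backend dispatches on.
        @R.function
        def inner(x_1, w_1, bias_1):
            R.func_attr({"Composite": "cutlass.matmul_bias"})
            with R.dataflow():
                lv = R.matmul(x_1, w_1)
                gv = R.add(lv, bias_1)
                R.output(gv)
            return gv

        return inner(x, w, bias)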

PatternCheckContext
~~~~~~~~~~~~~~~~~~~

The ``check`` function receives a ``PatternCheckContext`` with:

- ``matched_expr``: the root expression matched by the pattern.
- ``annotated_expr``: a mapping from annotation pattern names to their matched expressions.
- ``matched_bindings``: variable-to-value bindings within the matched subgraph.
- ``var_usages``: a mapping from variable definitions to all their uses in the function.
- ``value_to_bound_var``: reverse mapping from values to the variables they are bound to.

This context enables sophisticated validation logic, such as checking that an intermediate
result is not used outside the fused group, or verifying data type compatibility.
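
For example, a dtype check for the ``cutlass.matmul_bias`` pattern above might look like this
(a sketch: it reads the fields listed above and assumes the matched matmul carries tensor
struct info):

.. code-block:: python

    from tvm.relax.transform import PatternCheckContext

    def my_check_function(context: PatternCheckContext) -> bool:
        # Only offload float16 matmuls to the external backend.
        matmul = context.annotated_expr["matmul"]
        return matmul.struct_info.dtype == "float16"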


How Backends Use Fusion
-----------------------

The default backend pipelines (CUDA, ROCm, CPU, etc.) all include ``FuseOps`` + ``FuseTIR``
in their ``legalize_passes`` phase for automatic fusion. For example, the CUDA pipeline
(``python/tvm/relax/backend/cuda/pipeline.py``) runs::

    LegalizeOps → AnnotateTIROpPattern → FoldConstant → FuseOps → FuseTIR → DLight

For external library dispatch (cuBLAS, CUTLASS, cuDNN, DNNL), ``FuseOpsByPattern`` is used
separately. These passes are **not** included in the default pipeline — users add them explicitly
when building a custom compilation flow. The typical sequence, sketched after this list, is:

1. **Pattern-based dispatch** (``FuseOpsByPattern``): identify subgraphs that should be
   offloaded to external libraries. For example, CUTLASS patterns match
   matmul+bias+activation combinations (``python/tvm/relax/backend/cuda/cutlass.py``).
   Functions marked by patterns are annotated with ``Composite`` and optionally ``Codegen``
   attributes.

2. **Automatic fusion** (``FuseOps`` + ``FuseTIR``): remaining operators that were not
   matched by backend patterns are fused automatically based on their pattern kinds.

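A minimal sketch of such a custom flow (``cutlass_patterns`` stands in for a list of
``FusionPattern`` objects, e.g. the ones defined in
``python/tvm/relax/backend/cuda/cutlass.py``; pattern dispatch runs on high-level Relax ops,
before legalization):

.. code-block:: python

    import tvm
    from tvm import relax

    seq = tvm.transform.Sequential(
        [
            relax.transform.FuseOpsByPattern(cutlass_patterns, annotate_codegen=True),
            relax.transform.LegalizeOps(),
            relax.transform.AnnotateTIROpPattern(),
            relax.transform.FuseOps(),
            relax.transform.FuseTIR(),
        ]
    )
    mod = seq(mod)
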

Source Code Map
---------------

.. list-table::
   :header-rows: 1
   :widths: 50 50

   * - Path
     - Contents
   * - ``src/relax/transform/fuse_ops.cc``
     - FuseOps and FuseOpsByPattern implementation
   * - ``src/relax/analysis/graph_partitioner.h``
     - IndexedForwardGraph, DominatorTree, GraphPartitioner (Union-Find)
   * - ``src/relax/transform/fuse_tir.cc``
     - FuseTIR implementation, SymbolicMatcher
   * - ``include/tvm/relax/op_attr_types.h``
     - ``OpPatternKind`` enum definition
   * - ``python/tvm/relax/transform/transform.py``
     - Python API: FuseOps, FuseTIR, FuseOpsByPattern, FusionPattern
   * - ``python/tvm/relax/dpl/``
     - Dataflow Pattern Language (DFPattern, is_op, wildcard, etc.)
   * - ``python/tvm/relax/backend/cuda/cutlass.py``
     - Example: CUTLASS fusion patterns
   * - ``python/tvm/relax/backend/cuda/cublas.py``
     - Example: cuBLAS fusion patterns