Commit 481aa6e

[Docs] Add operator fusion architecture documentation (#19394)
as per title
1 parent e22be4e commit 481aa6e

2 files changed

Lines changed: 400 additions & 3 deletions

File tree

docs/arch/fusion.rst

Lines changed: 391 additions & 0 deletions
..  Licensed to the Apache Software Foundation (ASF) under one
    or more contributor license agreements.  See the NOTICE file
    distributed with this work for additional information
    regarding copyright ownership.  The ASF licenses this file
    to you under the Apache License, Version 2.0 (the
    "License"); you may not use this file except in compliance
    with the License.  You may obtain a copy of the License at

..    http://www.apache.org/licenses/LICENSE-2.0

..  Unless required by applicable law or agreed to in writing,
    software distributed under the License is distributed on an
    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    KIND, either express or implied.  See the License for the
    specific language governing permissions and limitations
    under the License.

.. _fusion-arch:

Operator Fusion
===============

Operator fusion is one of the most impactful optimizations in TVM. Instead of launching one kernel
per operator (e.g., conv2d, bias_add, relu), fusion merges multiple operators into a single kernel,
eliminating intermediate memory allocations and kernel launch overhead.

TVM provides two complementary fusion mechanisms:

- **Automatic fusion** (``FuseOps`` + ``FuseTIR``): groups operators based on their computational
  patterns using a post-dominator analysis algorithm.
- **Pattern-based fusion** (``FuseOpsByPattern``): groups operators that match user-defined
  dataflow patterns, typically for offloading to external backends (cuBLAS, CUTLASS, DNNL, etc.).

Both produce the same output: Relax functions marked with ``Primitive=True`` that are later
lowered to fused TIR kernels or dispatched to external libraries.

Overview
--------

Fusion involves three passes:

.. code-block:: text

    IRModule (after LegalizeOps)
         │
         ▼ AnnotateTIROpPattern    ← label each op (elementwise, reduce, etc.)
    IRModule (annotated)
         │
         ▼ FuseOps                 ← group ops into fused Relax functions
    IRModule (with fused functions marked Primitive=True)
         │
         ▼ FuseTIR                 ← merge TIR PrimFuncs inside each group
    IRModule (fused TIR kernels)

In the compilation pipeline, these passes appear in the backend-specific ``legalize_passes``
phase. For example, the CUDA pipeline (``python/tvm/relax/backend/cuda/pipeline.py``) runs:

.. code-block:: python

    LegalizeOps()           # lower Relax ops to call_tir
    AnnotateTIROpPattern()  # annotate pattern kinds
    FoldConstant()
    FuseOps()               # group ops
    FuseTIR()               # merge TIR functions

Operator Pattern Classification
-------------------------------

Before fusion, ``AnnotateTIROpPattern`` analyzes each TIR function in the module and assigns
an ``OpPatternKind``. The fusion algorithm uses these pattern kinds to decide which operators
can be fused together.

.. list-table::
    :header-rows: 1
    :widths: 20 10 70

    * - Pattern Kind
      - Value
      - Description
    * - ``kElemWise``
      - 0
      - Elementwise: one-to-one input/output mapping (e.g., ``add``, ``relu``, ``exp``).
    * - ``kBroadcast``
      - 1
      - Broadcasting: output axes map to input axes in order, but some input axes may be
        broadcast (e.g., ``bias_add``). Note: ``transpose`` is **not** broadcast because axes
        are reordered.
    * - ``kInjective``
      - 2
      - Injective: each output element depends on a single input element, but the mapping may
        be non-trivial (e.g., ``reshape``, ``concatenate``, ``transpose``).
    * - ``kCommReduce``
      - 3
      - Commutative reduction: output elements aggregate over input elements
        (e.g., ``sum``, ``max``, ``mean``).
    * - ``kOutEWiseFusable``
      - 4
      - Complex operation whose output can accept elementwise followers, but cannot chain
        with another complex op (e.g., ``conv2d``, ``matmul``, ``dense``).
    * - ``kTuple``
      - 7
      - Tuple node. Can fuse into subsequent injective ops but is treated specially.
    * - ``kOpaque``
      - 8
      - Opaque: cannot be fused (e.g., external function calls, operations with side effects).

These kinds form an ordering: lower values are "simpler" and more fusable. The fusion algorithm
uses ``CombinePattern(lhs, rhs) = max(lhs, rhs)`` when merging patterns along a path.

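As a minimal sketch (the real enum lives in ``include/tvm/relax/op_attr_types.h``), the kinds
behave like integers under a max-combine:

.. code-block:: python

    from enum import IntEnum

    # Mirror of the pattern-kind values from the table above (sketch only).
    class OpPatternKind(IntEnum):
        kElemWise = 0
        kBroadcast = 1
        kInjective = 2
        kCommReduce = 3
        kOutEWiseFusable = 4
        kTuple = 7
        kOpaque = 8

    def combine_pattern(lhs: OpPatternKind, rhs: OpPatternKind) -> OpPatternKind:
        # The more complex (higher-valued) pattern wins along a path.
        return max(lhs, rhs)

    # An elementwise op followed by a reduction yields a reduction pattern.
    assert combine_pattern(OpPatternKind.kElemWise, OpPatternKind.kCommReduce) == OpPatternKind.kCommReduce
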
FuseOps: Automatic Fusion
-------------------------

``FuseOps`` (``src/relax/transform/fuse_ops.cc``) groups bindings in a dataflow block into
new Relax functions. It operates only within ``DataflowBlock``\ s — if your module doesn't have
any, run ``ConvertToDataflow`` first, as in the sketch below.

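A minimal sketch using the public ``tvm.relax.transform`` APIs (the pass names are real; the
ordering follows the pipeline shown earlier):

.. code-block:: python

    import tvm
    from tvm import relax

    # Ensure dataflow blocks exist before running the fusion passes.
    fusion_seq = tvm.ir.transform.Sequential(
        [
            relax.transform.ConvertToDataflow(),
            relax.transform.LegalizeOps(),
            relax.transform.AnnotateTIROpPattern(),
            relax.transform.FuseOps(),
            relax.transform.FuseTIR(),
        ]
    )
    # mod = fusion_seq(mod)  # apply to your IRModule
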
Algorithm
~~~~~~~~~

The fusion algorithm addresses diamond-shaped dataflow branches, where a single producer
(e.g., conv2d) has multiple consumers that eventually reconverge:

.. code-block:: text

         conv2d
        /   |   \
       /    |    \
      op    op    op
       \    |    /
        \   |   /
       elemwise add

At the point of ``conv2d``, we don't know if all future paths will merge. The algorithm uses
**post-dominator analysis** to resolve this:

1. **Build forward graph**: construct an ``IndexedForwardGraph`` from the dataflow block.
   Each node has an ``OpPatternKind`` and a list of forward edges.

2. **Build post-dominator tree**: compute the immediate post-dominator of each node using
   Least Common Ancestor (LCA) on the DAG (see the sketch after this list). The
   post-dominator of a node is the closest downstream node where **all** future paths converge.

3. **Fuse groups**: for each node in topological order, check if it can be fused with its
   immediate post-dominator:

   - **CheckPath**: verify that all paths from the node to its post-dominator satisfy the
     fusion conditions (pattern compatibility, depth limits, argument limits).
   - **CommitFuse**: mark all intermediate nodes as belonging to the same group using a
     Union-Find data structure.

4. **Create grouped functions**: extract each group into a new ``relax.Function`` with the
   attribute ``Primitive=True``. Replace the original bindings with a call to the grouped
   function.

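The following is a simplified sketch of step 2 (plain Python, not the actual
``graph_partitioner.h`` code), assuming a DAG with a single sink given in topological order:
immediate post-dominators computed by walking LCAs up the growing post-dominator tree.

.. code-block:: python

    def immediate_post_dominators(succ, order):
        """succ: node -> list of successors; order: topological order of the DAG."""
        pos = {node: i for i, node in enumerate(order)}
        ipdom = {}

        def lca(a, b):
            # Walk the topologically earlier node up the post-dominator
            # tree until both paths meet.
            while a != b:
                if pos[a] < pos[b]:
                    a = ipdom[a]
                else:
                    b = ipdom[b]
            return a

        for node in reversed(order):  # sinks first
            succs = succ.get(node, [])
            if not succs:
                ipdom[node] = node  # a sink post-dominates itself
            else:
                d = succs[0]
                for s in succs[1:]:
                    d = lca(d, s)
                ipdom[node] = d
        return ipdom

    # The diamond from the figure: all paths from conv2d reconverge at the
    # elementwise add, so the add is conv2d's immediate post-dominator.
    succ = {"conv2d": ["op1", "op2", "op3"], "op1": ["add"], "op2": ["add"], "op3": ["add"]}
    order = ["conv2d", "op1", "op2", "op3", "add"]
    assert immediate_post_dominators(succ, order)["conv2d"] == "add"
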
Fusion rules
~~~~~~~~~~~~

The key fusion decisions depend on the ``OpPatternKind`` of the source, the path, and the
post-dominator. The algorithm runs in three phases (via ``GraphPartitioner::RunFuse``) so that
higher-complexity ops get a chance to fuse first (a condensed sketch of these rules follows
below):

- **Phase 0**: ``kOutEWiseFusable`` ops (e.g., ``conv2d``) can fuse with their elementwise
  post-dominator if all intermediate ops are broadcast or simpler. This enables patterns like
  conv2d + bias_add + relu. Two ``kOutEWiseFusable`` ops cannot fuse together.
- **Phase 1**: ``kInjective`` and ``kTuple`` ops can fuse only when all paths to the
  post-dominator are injective or simpler. This is deferred to phase 1 so that
  ``kOutEWiseFusable`` groups are finalized first.
- **Phase 2**: fuse injective ops into intermediate tuple nodes that have already been absorbed
  by subsequent injective groups.

``kElemWise`` / ``kBroadcast`` ops are processed in **every** phase (not restricted to one):
they can fuse into a post-dominator that is injective or reduction. The sink (final node) may
also be a ``kOutEWiseFusable`` group that was formed in phase 0 — this is how elementwise
producers merge into an existing conv2d fusion group.

Additional constraints:

- **Reduction** (``kCommReduce``) ops never initiate fusion — they act as sinks only. Elementwise
  and broadcast producers can fuse *into* a reduction, but a reduction cannot fuse forward.
- **Opaque** ops are fusion barriers.
- A group cannot exceed ``kMaxFusedOps`` (256) nodes or the maximum function argument count.

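The per-phase conditions can be restated compactly. This is an illustrative sketch only, not
the actual ``GraphPartitioner`` code; the integer values follow the pattern table above.

.. code-block:: python

    kElemWise, kBroadcast, kInjective, kCommReduce, kOutEWiseFusable = 0, 1, 2, 3, 4

    def path_allows_fusion(src_kind, path_kinds, phase):
        """Whether every intermediate op on a path from src to its
        post-dominator permits fusing the group, per the rules above."""
        if src_kind == kOutEWiseFusable:
            # Phase 0: conv2d-like ops tolerate only broadcast-or-simpler
            # ops on the way to their elementwise post-dominator.
            return phase == 0 and all(k <= kBroadcast for k in path_kinds)
        if src_kind in (kElemWise, kBroadcast):
            # Checked in every phase: the path must stay injective or
            # simpler (the sink itself may be injective or a reduction).
            return all(k <= kInjective for k in path_kinds)
        if src_kind == kInjective:
            # Phase 1: all paths must be injective or simpler.
            return phase == 1 and all(k <= kInjective for k in path_kinds)
        return False  # reductions and opaque ops never initiate fusion

    # conv2d → bias_add (broadcast) → relu (elementwise sink) fuses in phase 0:
    assert path_allows_fusion(kOutEWiseFusable, [kBroadcast], phase=0)
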
Example
~~~~~~~

Consider two elementwise ops (``add``, ``exp``) and one injective op (``squeeze``).
The examples below are simplified pseudocode — real TVMScript would reference TIR functions
via ``cls.func_name``:

.. code-block:: python

    # Before FuseOps (simplified)
    @R.function
    def main(x: R.Tensor((10, 20), "float32")):
        with R.dataflow():
            lv0 = R.call_tir(add, (x, const_1), out_sinfo=R.Tensor((10, 20), "float32"))
            lv1 = R.call_tir(exp, (lv0,), out_sinfo=R.Tensor((10, 20), "float32"))
            gv = R.call_tir(squeeze, (lv1,), out_sinfo=R.Tensor((10, 20), "float32"))
            R.output(gv)
        return gv

After ``FuseOps``, all three are grouped into a single function:

.. code-block:: python

    # After FuseOps
    @R.function(private=True)
    def fused_add_exp_squeeze(x, p0):
        R.func_attr({"Primitive": True})
        with R.dataflow():
            lv0 = R.call_tir(add, (x, p0), ...)
            lv1 = R.call_tir(exp, (lv0,), ...)
            gv = R.call_tir(squeeze, (lv1,), ...)
            R.output(gv)
        return gv

    @R.function
    def main(x: R.Tensor((10, 20), "float32")):
        with R.dataflow():
            gv = fused_add_exp_squeeze(x, const_1)
            R.output(gv)
        return gv

FuseTIR: Merging TIR Functions
------------------------------

``FuseTIR`` (``src/relax/transform/fuse_tir.cc``) takes the grouped Relax functions produced by
``FuseOps`` and merges their internal TIR ``PrimFunc``\ s into a single TIR function.

Before ``FuseTIR``, a fused group still contains multiple ``R.call_tir`` calls to separate
TIR functions. ``FuseTIR`` inlines and merges them:

.. code-block:: text

    Before FuseTIR:
        fused_add_exp_squeeze:
            call_tir(add, ...)      → separate TIR PrimFunc
            call_tir(exp, ...)      → separate TIR PrimFunc
            call_tir(squeeze, ...)  → separate TIR PrimFunc

    After FuseTIR:
        fused_add_exp_squeeze:      → single merged TIR PrimFunc

The merged function eliminates intermediate buffers — the output of ``add`` is directly consumed
by ``exp`` without writing to and reading from global memory. This is the core performance benefit
of fusion.

Internally, ``FuseTIR`` uses a ``SymbolicMatcher`` to align symbolic shape variables across the
TIR functions being merged, ensuring that dimensions are correctly mapped when combining buffer
accesses.

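Putting the two passes together, here is a hedged end-to-end sketch (the module below is an
illustration, not taken from the TVM test suite) that fuses an elementwise ``add`` + ``exp``
chain into one PrimFunc:

.. code-block:: python

    import tvm
    from tvm import relax
    from tvm.script import ir as I
    from tvm.script import relax as R

    @I.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor((10, 20), "float32")) -> R.Tensor((10, 20), "float32"):
            with R.dataflow():
                lv0 = R.add(x, R.const(1.0, "float32"))
                gv = R.exp(lv0)
                R.output(gv)
            return gv

    seq = tvm.ir.transform.Sequential(
        [
            relax.transform.LegalizeOps(),          # Relax ops -> call_tir
            relax.transform.AnnotateTIROpPattern(),
            relax.transform.FuseOps(),              # group add + exp
            relax.transform.FuseTIR(),              # merge into one PrimFunc
        ]
    )
    fused = seq(Module)
    fused.show()  # expect a single fused PrimFunc for add + exp
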
FuseOpsByPattern: Pattern-Based Fusion
--------------------------------------

While ``FuseOps`` makes fusion decisions automatically based on operator patterns,
``FuseOpsByPattern`` lets you specify exactly which operator combinations to fuse using
the Relax :ref:`Dataflow Pattern Language (DPL) <relax-dpl>`.

This is primarily used for **backend-specific dispatch**: identifying operator subgraphs that
should be offloaded to external libraries like cuBLAS, CUTLASS, cuDNN, or DNNL.

FusionPattern
~~~~~~~~~~~~~

A ``FusionPattern`` (``python/tvm/relax/transform/transform.py``) defines what to match:

.. code-block:: python

    from tvm.relax.dpl import wildcard, is_op
    from tvm.relax.transform import FusionPattern

    # Match: matmul(x, w) + bias
    x = wildcard()
    w = wildcard()
    bias = wildcard()
    matmul = is_op("relax.matmul")(x, w)
    out = is_op("relax.add")(matmul, bias)

    pattern = FusionPattern(
        name="cutlass.matmul_bias",
        pattern=out,
        annotation_patterns={"matmul": matmul, "bias": bias},
        check=my_check_function,  # optional validation
    )

Fields:

- ``name``: pattern identifier, typically prefixed with the backend name (e.g.,
  ``"cutlass.matmul_bias"``).
- ``pattern``: a DFPattern describing the subgraph to match. See the
  :ref:`DPL deep dive <relax-dpl>` for the full pattern language.
- ``annotation_patterns``: a mapping of names to sub-patterns within the main pattern. These
  are extracted during matching and made available to the ``check`` function and
  ``attrs_getter``.
- ``check``: an optional ``Callable[[PatternCheckContext], bool]`` that validates whether
  a match should be accepted. Receives the matched expression, annotated sub-expressions,
  variable usages, and binding information.
- ``attrs_getter``: an optional function that extracts attributes (e.g., transpose flags,
  data types) from the matched expressions to annotate the grouped function.

Applying patterns
~~~~~~~~~~~~~~~~~

.. code-block:: python

    from tvm.relax.transform import FuseOpsByPattern

    mod = FuseOpsByPattern(
        patterns=[pattern1, pattern2, ...],  # ordered by priority
        bind_constants=True,
        annotate_codegen=False,
    )(mod)

Key parameters:

- ``patterns``: a list of ``FusionPattern`` objects, ordered by priority. Higher-priority
  patterns come first — if a subgraph matches multiple patterns, the first match wins.
- ``bind_constants``: if ``True``, constants used by the matched subgraph are captured inside
  the grouped function.
- ``annotate_codegen``: if ``True``, wraps each composite function with an outer function
  annotated with ``"Codegen"`` and ``"global_symbol"`` attributes for external backend dispatch.
  The ``"Codegen"`` value is derived from the pattern name prefix (e.g., ``"dnnl"`` from
  ``"dnnl.conv2d_relu"``).

PatternCheckContext
~~~~~~~~~~~~~~~~~~~

The ``check`` function receives a ``PatternCheckContext`` with:

- ``matched_expr``: the root expression matched by the pattern.
- ``annotated_expr``: a mapping from annotation pattern names to their matched expressions.
- ``matched_bindings``: variable-to-value bindings within the matched subgraph.
- ``var_usages``: a mapping from variable definitions to all their uses in the function.
- ``value_to_bound_var``: reverse mapping from values to the variables they are bound to.

This context enables sophisticated validation logic, such as checking that an intermediate
result is not used outside the fused group, or verifying data type compatibility.

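For example, a minimal sketch of a ``check`` callback (``my_check_function`` from the
``FusionPattern`` example above; the float16-only policy is purely illustrative):

.. code-block:: python

    from tvm.relax.transform import PatternCheckContext

    def my_check_function(context: PatternCheckContext) -> bool:
        # "matmul" is the key registered in annotation_patterns.
        matmul = context.annotated_expr["matmul"]
        # Example policy: only offload float16 matmuls to the backend.
        return matmul.struct_info.dtype == "float16"
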
How Backends Use Fusion
-----------------------

The default backend pipelines (CUDA, ROCm, CPU, etc.) all include ``FuseOps`` + ``FuseTIR``
in their ``legalize_passes`` phase for automatic fusion. For example, the CUDA pipeline
(``python/tvm/relax/backend/cuda/pipeline.py``) runs::

    LegalizeOps → AnnotateTIROpPattern → FoldConstant → FuseOps → FuseTIR → DLight

For external library dispatch (cuBLAS, CUTLASS, cuDNN, DNNL), ``FuseOpsByPattern`` is used
separately. These passes are **not** included in the default pipeline — users add them
explicitly when building a custom compilation flow. The typical sequence is (see the sketch
after the list):

1. **Pattern-based dispatch** (``FuseOpsByPattern``): identify subgraphs that should be
   offloaded to external libraries. For example, CUTLASS patterns match
   matmul+bias+activation combinations (``python/tvm/relax/backend/cuda/cutlass.py``).
   Functions marked by patterns are annotated with ``Composite`` and optionally ``Codegen``
   attributes.

2. **Automatic fusion** (``FuseOps`` + ``FuseTIR``): remaining operators that were not
   matched by backend patterns are fused automatically based on their pattern kinds.

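A hedged sketch of such a custom flow (``get_patterns_with_prefix`` is the pattern-registry
helper in ``tvm.relax.backend``; the exact pass ordering here is illustrative):

.. code-block:: python

    from tvm import relax
    from tvm.relax.backend import get_patterns_with_prefix

    def offload_then_fuse(mod):
        # 1. Pattern-based dispatch: mark CUTLASS-eligible subgraphs.
        patterns = get_patterns_with_prefix("cutlass")
        mod = relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True)(mod)
        mod = relax.transform.RunCodegen()(mod)  # hand Codegen functions to the backend
        # 2. Automatic fusion for everything that was not offloaded.
        mod = relax.transform.LegalizeOps()(mod)
        mod = relax.transform.AnnotateTIROpPattern()(mod)
        mod = relax.transform.FuseOps()(mod)
        mod = relax.transform.FuseTIR()(mod)
        return mod
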
Source Code Map
---------------

.. list-table::
    :header-rows: 1
    :widths: 50 50

    * - Path
      - Contents
    * - ``src/relax/transform/fuse_ops.cc``
      - FuseOps and FuseOpsByPattern implementation
    * - ``src/relax/analysis/graph_partitioner.h``
      - IndexedForwardGraph, DominatorTree, GraphPartitioner (Union-Find)
    * - ``src/relax/transform/fuse_tir.cc``
      - FuseTIR implementation, SymbolicMatcher
    * - ``include/tvm/relax/op_attr_types.h``
      - ``OpPatternKind`` enum definition
    * - ``python/tvm/relax/transform/transform.py``
      - Python API: FuseOps, FuseTIR, FuseOpsByPattern, FusionPattern
    * - ``python/tvm/relax/dpl/``
      - Dataflow Pattern Language (DFPattern, is_op, wildcard, etc.)
    * - ``python/tvm/relax/backend/cuda/cutlass.py``
      - Example: CUTLASS fusion patterns
    * - ``python/tvm/relax/backend/cuda/cublas.py``
      - Example: cuBLAS fusion patterns
