Skip to content

Commit fb1618e

Browse files
authored
Disable im2row transform by default, add a flag to enable it (pytorch#18693)
Differential Revision: D97356070 Pull Request resolved: pytorch#18693
1 parent 6020c29 commit fb1618e

4 files changed

Lines changed: 49 additions & 5 deletions

File tree

backends/cadence/aot/compiler.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from torch.export.exported_program import ExportedProgram
4444
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e
4545

46+
from .pass_utils import EdgePassesConfig
4647
from .passes import apply_exir_ops_passes, apply_torch_ops_passes
4748
from .utils import print_ops_info
4849

@@ -355,12 +356,15 @@ def _lower_ep_to_cadence(
355356
program: ExportedProgram,
356357
dump_graphs: bool = False,
357358
opt_level: int = 1,
359+
edge_passes_config: Optional[EdgePassesConfig] = None,
358360
) -> EdgeProgramManager:
359361
"""
360362
Lower an existing ExportedProgram to edge IR and apply frontend optimization passes.
361363
"""
362364
edge_prog_manager = _lower_ep_to_edge(program, dump_graphs=dump_graphs)
363-
cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
365+
cadence_prog_manager = apply_exir_ops_passes(
366+
opt_level, edge_prog_manager, edge_passes_config
367+
)
364368
return cadence_prog_manager
365369

366370

@@ -369,9 +373,12 @@ def export_to_cadence(
369373
inputs: tuple[object, ...],
370374
dump_graphs: bool = False,
371375
opt_level: int = 1,
376+
edge_passes_config: Optional[EdgePassesConfig] = None,
372377
) -> EdgeProgramManager:
373378
edge_prog_manager = export_to_edge(model, inputs, dump_graphs=dump_graphs)
374-
cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
379+
cadence_prog_manager = apply_exir_ops_passes(
380+
opt_level, edge_prog_manager, edge_passes_config
381+
)
375382
return cadence_prog_manager
376383

377384

@@ -380,6 +387,7 @@ def quantize_and_export_to_cadence(
380387
inputs: tuple[object, ...],
381388
dump_graphs: bool = False,
382389
opt_level: int = 1,
390+
edge_passes_config: Optional[EdgePassesConfig] = None,
383391
) -> EdgeProgramManager:
384392
"""
385393
Trace, quantize, lower a model/inputs pair to edge IR and apply frontend
@@ -391,6 +399,7 @@ def quantize_and_export_to_cadence(
391399
quantized_model,
392400
opt_level=opt_level,
393401
dump_graphs=dump_graphs,
402+
edge_passes_config=edge_passes_config,
394403
)
395404

396405

backends/cadence/aot/pass_utils.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-strict
88

9+
import dataclasses
910
from abc import abstractmethod
1011
from dataclasses import dataclass
1112
from typing import Callable, List, Optional, override, Set, Type, TypeVar, Union
@@ -28,11 +29,20 @@ def allow_lifetime_and_storage_overlap(opt_level: int) -> bool:
2829
return opt_level >= 2
2930

3031

32+
# A dataclass that bundles feature flags for edge passes.
33+
# When adding a new flag, add a matching bool field to both this class and
34+
# CadencePassAttribute; the pass filter will pick it up automatically.
35+
@dataclass(frozen=True)
36+
class EdgePassesConfig:
37+
use_im2row_transform: bool = False
38+
39+
3140
# A dataclass that stores the attributes of an ExportPass.
3241
@dataclass(frozen=True)
3342
class CadencePassAttribute:
3443
opt_level: Optional[int] = None
3544
debug_pass: bool = False
45+
use_im2row_transform: bool = False
3646

3747

3848
# A dictionary that maps an ExportPass to its attributes.
@@ -58,17 +68,38 @@ def get_all_available_cadence_passes() -> Set[Type[PassBase]]:
5868
return set(ALL_CADENCE_PASSES.keys())
5969

6070

71+
def _check_feature_flags(
72+
pass_attribute: CadencePassAttribute,
73+
config: EdgePassesConfig,
74+
) -> bool:
75+
"""Check all feature flags: a pass is included only if every feature it
76+
requires is enabled in the config. Iterates over EdgePassesConfig fields
77+
so new flags are handled automatically."""
78+
for field in dataclasses.fields(EdgePassesConfig):
79+
if getattr(pass_attribute, field.name, False) and not getattr(
80+
config, field.name
81+
):
82+
return False
83+
return True
84+
85+
6186
# Create a new filter to filter out relevant passes from all passes.
6287
def create_cadence_pass_filter(
63-
opt_level: int, debug: bool = False
88+
opt_level: int,
89+
debug: bool = False,
90+
edge_passes_config: Optional[EdgePassesConfig] = None,
6491
) -> Callable[[Type[PassBase]], bool]:
92+
if edge_passes_config is None:
93+
edge_passes_config = EdgePassesConfig()
94+
6595
def _filter(p: Type[PassBase]) -> bool:
6696
pass_attribute = get_cadence_pass_attribute(p)
6797
return (
6898
pass_attribute is not None
6999
and pass_attribute.opt_level is not None
70100
and pass_attribute.opt_level <= opt_level
71101
and (not pass_attribute.debug_pass or debug)
102+
and _check_feature_flags(pass_attribute, edge_passes_config)
72103
)
73104

74105
return _filter

backends/cadence/aot/passes.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from executorch.backends.cadence.aot.pass_utils import (
2020
CadencePassAttribute,
2121
create_cadence_pass_filter,
22+
EdgePassesConfig,
2223
register_cadence_pass,
2324
)
2425

@@ -101,9 +102,12 @@ def get_passes_in_default_order() -> list[Type[ExportPass]]:
101102
def apply_exir_ops_passes(
102103
opt_level: int,
103104
edge_prog_manager: EdgeProgramManager,
105+
edge_passes_config: Optional[EdgePassesConfig] = None,
104106
) -> EdgeProgramManager:
105107
passes = get_passes_in_default_order()
106-
pass_filter = create_cadence_pass_filter(opt_level)
108+
pass_filter = create_cadence_pass_filter(
109+
opt_level, edge_passes_config=edge_passes_config
110+
)
107111
cadence_passes = [
108112
(
109113
lambda graph_module, filtered_pass=filtered_pass: filtered_pass()(

backends/cadence/aot/replace_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1327,7 +1327,7 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
13271327
return True
13281328

13291329

1330-
@register_cadence_pass(CadencePassAttribute(opt_level=2))
1330+
@register_cadence_pass(CadencePassAttribute(opt_level=2, use_im2row_transform=True))
13311331
class ReplaceConvWithIm2RowAndLinear(RemoveOrReplacePassInterface):
13321332
"""
13331333
Replace convolution where groups=1 with im2row followed by a linear op.

0 commit comments

Comments
 (0)