[MLAS] Fix NHWC conv support gating #29127
Signed-off-by: Jonathan Clohessy <Jonathan.Clohessy@arm.com>
Pull request overview
This PR refines the CPU NHWC Conv fast-path gating for Arm KleidiAI by splitting the previous “channels-last supported” predicate into distinct checks for (1) KleidiAI NHWC imatmul Conv and (2) KleidiAI NHWC depthwise/grouped Conv, with depthwise/grouped intentionally disabled for now to avoid routing regressions. This integrates into ORT’s NHWC transformer and CPU Conv kernel selection logic, ensuring only shapes with a known fast KleidiAI kernel are transformed/executed in NHWC.
Changes:
- Introduces explicit MLAS support predicates for KleidiAI NHWC imatmul Conv and KleidiAI NHWC depthwise Conv (currently stubbed to false).
- Updates NHWC transformer + CPU Conv fast-path gating to use the new predicates, preventing depthwise/grouped Conv from being routed to the KleidiAI NHWC path.
- Updates related unit tests and MLAS benchmarks to use the new predicate split.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| onnxruntime/test/optimizer/nhwc_transformer_test.cc | Updates NHWC capability helper to use new KleidiAI imatmul/depthwise predicates. |
| onnxruntime/test/mlas/bench/bench_sconv.cpp | Updates KleidiAI NHWC benchmark gating to accept imatmul or depthwise predicate. |
| onnxruntime/test/contrib_ops/fused_conv_test.cc | Updates NHWC fused-conv support helper to use new predicates. |
| onnxruntime/core/providers/cpu/nn/conv.cc | Refines NHWC fast-path selection to require explicit KleidiAI imatmul/depthwise support. |
| onnxruntime/core/optimizer/nhwc_transformer.cc | Updates NHWC transformer Conv filter to use split KleidiAI support predicates. |
| onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp | Switches KleidiAI override capability check to the imatmul-specific predicate. |
| onnxruntime/core/mlas/lib/convolve.cpp | Adds new KleidiAI imatmul predicate, introduces depthwise predicate stub (disabled), and rewraps legacy symmetric predicate. |
| onnxruntime/core/mlas/inc/mlas.h | Declares the new KleidiAI NHWC support predicate APIs. |
Description
This PR fixes a performance regression introduced by overly broad Arm® KleidiAI™ NHWC Conv support checks.
The NHWC transformer was routing depthwise/grouped convolutions into the KleidiAI NHWC `NhwcFusedConv` path whenever the generic channels-last support predicate returned true. However, many depthwise shapes do not have a dedicated fast KleidiAI kernel and were falling back to the generic NHWC/im2col-style path, which is significantly slower than the existing NCHW execution path.
This change splits the support checks into separate predicates for:
- KleidiAI NHWC imatmul Conv
- KleidiAI NHWC depthwise/grouped Conv
The depthwise predicate is intentionally disabled for now, so unsupported depthwise/grouped convolutions remain on the existing NCHW path. A follow-up PR can enable the depthwise predicate for shapes covered by a dedicated KleidiAI depthwise kernel.
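The gating described above can be sketched roughly as follows. This is an illustrative sketch only, not the code in this PR: `ConvShape`, `HasKleidiAiNhwcImatmulKernel`, `HasKleidiAiNhwcDepthwiseKernel`, and `ShouldRouteToNhwc` are hypothetical names, and the `group == 1` condition is an assumed stand-in for whatever shape checks the real imatmul predicate performs.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical, simplified Conv description; not the real MLAS/ORT types.
struct ConvShape {
    std::size_t group;
    std::size_t input_channels;
    std::size_t output_channels;
};

// Stand-in for the imatmul predicate.
// Assumption: only ungrouped convolutions qualify; the real checks may differ.
constexpr bool HasKleidiAiNhwcImatmulKernel(const ConvShape& s) {
    return s.group == 1;
}

// Stand-in for the depthwise/grouped predicate, stubbed to false as the
// PR description states, so no depthwise shape is routed to NHWC yet.
constexpr bool HasKleidiAiNhwcDepthwiseKernel(const ConvShape&) {
    return false;
}

// Route to the NHWC path only when one of the specific predicates reports
// a known fast kernel; otherwise the node stays on the existing NCHW path.
constexpr bool ShouldRouteToNhwc(const ConvShape& s) {
    return HasKleidiAiNhwcImatmulKernel(s) || HasKleidiAiNhwcDepthwiseKernel(s);
}
```

With this shape of gating, enabling depthwise support later would only require changing the body of the depthwise predicate; the callers that combine the two predicates stay the same.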
Summary
Performance impact
Measured with:
Representative improvements vs the regressed KleidiAI NHWC baseline:
- de_efficientnetlitev3_f32.onnx
- de_efficientnetlitev3_f16.onnx
- deeplabv3_mobilenetv2_f32.onnx
- deeplabv3_mobilenetv2_f16.onnx
- mobilenet_v1_f32.onnx
- mobilenetv1_ssd_f32.onnx
- retinaface_f32.onnx
The fix also restores performance to a level at or above the pre-NHWC-routing reference for several affected models.