@@ -37,6 +37,8 @@ class _DpjitPassBuilder(object):
3737 execution.
3838 """
3939
40+ _use_mlir = False
41+
4042 @staticmethod
4143 def define_typed_pipeline (state , name = "dpex_dpjit_typed" ):
4244 """Returns the typed part of the nopython pipeline"""
@@ -55,19 +57,31 @@ def define_typed_pipeline(state, name="dpex_dpjit_typed"):
5557 pm .add_pass (NopythonRewrites , "nopython rewrites" )
5658 pm .add_pass (ParforPass , "convert to parfors" )
5759 pm .add_pass (
58- ParforLegalizeCFDPass , "Legalize parfors for compute follows data"
60+ ParforLegalizeCFDPass ,
61+ "Legalize parfors for compute follows data" ,
5962 )
6063 pm .add_pass (ParforFusionPass , "fuse parfors" )
6164 pm .add_pass (ParforPreLoweringPass , "parfor prelowering" )
6265
6366 pm .finalize ()
6467 return pm
6568
66- @staticmethod
67- def define_nopython_lowering_pipeline (state , name = "dpex_dpjit_lowering" ):
69+ @classmethod
70+ def define_nopython_lowering_pipeline (
71+ cls , state , name = "dpex_dpjit_lowering"
72+ ):
6873 """Returns an nopython mode pipeline based PassManager"""
6974 pm = PassManager (name )
7075
76+ flags = state .flags
77+ if cls ._use_mlir or hasattr (flags , "use_mlir" ) and flags .use_mlir :
78+ from numba_mlir .mlir .passes import MlirReplaceParfors
79+
80+ pm .add_pass (
81+ MlirReplaceParfors ,
82+ "Lower parfor using MLIR pipeline" ,
83+ )
84+
7185 # legalize
7286 pm .add_pass (
7387 NoPythonSupportedFeatureValidation ,
@@ -85,11 +99,11 @@ def define_nopython_lowering_pipeline(state, name="dpex_dpjit_lowering"):
8599 pm .finalize ()
86100 return pm
87101
88- @staticmethod
89- def define_nopython_pipeline (state , name = "dpex_dpjit_nopython" ):
102+ @classmethod
103+ def define_nopython_pipeline (cls , state , name = "dpex_dpjit_nopython" ):
90104 """Returns an nopython mode pipeline based PassManager"""
91105 # compose pipeline from untyped, typed and lowering parts
92- dpb = _DpjitPassBuilder
106+ dpb = cls
93107 pm = PassManager (name )
94108 untyped_passes = DefaultPassBuilder .define_untyped_pipeline (state )
95109 pm .passes .extend (untyped_passes .passes )
@@ -104,17 +118,31 @@ def define_nopython_pipeline(state, name="dpex_dpjit_nopython"):
104118 return pm
105119
106120
121+ class _DpjitPassBuilderMlir (_DpjitPassBuilder ):
122+ _use_mlir = True
123+
124+
107125class DpjitCompiler (CompilerBase ):
108126 """Dpex's compiler pipeline to offload parfor nodes into SYCL kernels."""
109127
128+ _pass_builder = _DpjitPassBuilder
129+
110130 def define_pipelines (self ):
111131 pms = []
112132 self .state .parfor_diagnostics = ExtendedParforDiagnostics ()
113133 self .state .metadata [
114134 "parfor_diagnostics"
115135 ] = self .state .parfor_diagnostics
116136 if not self .state .flags .force_pyobject :
117- pms .append (_DpjitPassBuilder .define_nopython_pipeline (self .state ))
137+ pms .append (self . _pass_builder .define_nopython_pipeline (self .state ))
118138 if self .state .status .can_fallback or self .state .flags .force_pyobject :
119139 raise UnsupportedCompilationModeError ()
120140 return pms
141+
142+
143+ class DpjitCompilerMlir (DpjitCompiler ):
144+ _pass_builder = _DpjitPassBuilderMlir
145+
146+
147+ def get_compiler (use_mlir ):
148+ return DpjitCompilerMlir if use_mlir else DpjitCompiler
0 commit comments