# Two-field record: ``kernel_name`` and ``kernel_bitcode``.
_KernelModule = namedtuple("_KernelModule", ["kernel_name", "kernel_bitcode"])

# Carries every field of ``CompileResult`` (in the same order) followed by one
# extra trailing field, ``kernel_device_ir_module``.
_KernelCompileResult = namedtuple(
    "_KernelCompileResult",
    CompileResult._fields + ("kernel_device_ir_module",),
)
3534
3635
@@ -96,11 +95,11 @@ def _compile_to_spirv(
9695 )
9796
def compile(self, args, return_type):
    """Compile for ``args`` and ``return_type`` and return the result.

    Delegates to ``self._compile_cached``, which hands back a
    ``(status, payload)`` pair.

    Returns:
        The payload when ``status`` is truthy.

    Raises:
        The payload itself when ``status`` is falsy; it is therefore
        expected to be an exception instance in that case.
    """
    status, kcres = self._compile_cached(args, return_type)
    if status:
        return kcres

    raise kcres
105104 def _compile_cached (
106105 self , args , return_type : types .Type
@@ -137,34 +136,45 @@ def _compile_cached(
137136 """
138137 key = tuple (args ), return_type
139138 try :
140- return _KernelCompileResult ( False , self ._failed_cache [key ], None )
139+ return False , self ._failed_cache [key ]
141140 except KeyError :
142141 pass
143142
144143 try :
145- kernel_cres : CompileResult = self ._compile_core (args , return_type )
144+ cres : CompileResult = self ._compile_core (args , return_type )
146145
147- kernel_library = kernel_cres .library
148- kernel_fndesc = kernel_cres .fndesc
149- kernel_targetctx = kernel_cres .target_context
150-
151- kernel_module = self ._compile_to_spirv (
152- kernel_library , kernel_fndesc , kernel_targetctx
146+ kernel_device_ir_module = self ._compile_to_spirv (
147+ cres .library , cres .fndesc , cres .target_context
153148 )
154149
150+ kcres_attrs = []
151+
152+ for cres_field in cres ._fields :
153+ cres_attr = getattr (cres , cres_field )
154+ if cres_field == "entry_point" :
155+ if cres_attr is not None :
156+ raise AssertionError (
157+ "Compiled kernel and device_func should be "
158+ "compiled with compile_cfunc option turned off"
159+ )
160+ cres_attr = cres .fndesc .qualname
161+ kcres_attrs .append (cres_attr )
162+
163+ kcres_attrs .append (kernel_device_ir_module )
164+
155165 if config .DUMP_KERNEL_LLVM :
156166 with open (
157- kernel_cres .fndesc .llvm_func_name + ".ll" ,
167+ cres .fndesc .llvm_func_name + ".ll" ,
158168 "w" ,
159169 encoding = "UTF-8" ,
160170 ) as f :
161- f .write (kernel_cres .library .final_module )
171+ f .write (cres .library .final_module )
162172
163173 except errors .TypingError as e :
164174 self ._failed_cache [key ] = e
165- return _KernelCompileResult ( False , e , None )
175+ return False , e
166176
167- return _KernelCompileResult ( True , kernel_cres , kernel_module )
177+ return True , _KernelCompileResult ( * kcres_attrs )
168178
169179
170180class KernelDispatcher (Dispatcher ):
@@ -234,7 +244,14 @@ def typeof_pyval(self, val):
234244
def add_overload(self, cres):
    """Register ``cres`` as the overload for its argument types.

    The entry in ``self.overloads`` is keyed by the tuple of
    ``cres.signature.args`` and holds the whole compile result rather than
    only its ``entry_point``, so the result's other attributes stay
    reachable through the overload table.
    """
    args = tuple(cres.signature.args)
    self.overloads[args] = cres
248+
def get_overload_device_ir(self, sig):
    """
    Return the compiled device bitcode for the given signature.

    ``sig`` is normalized with ``sigutils.normalize_signature``; only the
    argument part is used (the second element of the normalized pair is
    ignored). The tuple of those arguments indexes ``self.overloads``, and
    the stored entry's ``kernel_device_ir_module`` attribute is returned.

    NOTE(review): indexing ``self.overloads`` directly presumably raises
    ``KeyError`` when no overload was registered for ``sig`` -- confirm the
    container type against the base class.
    """
    args, _ = sigutils.normalize_signature(sig)
    return self.overloads[tuple(args)].kernel_device_ir_module
238255
239256 def compile (self , sig ) -> _KernelCompileResult :
240257 disp = self ._get_dispatcher_for_current_target ()
@@ -274,7 +291,7 @@ def cb_llvm(dur):
274291 # Don't recompile if signature already exists
275292 existing = self .overloads .get (tuple (args ))
276293 if existing is not None :
277- return existing
294+ return existing . entry_point
278295
279296 # TODO: Enable caching
280297 # Add code to enable on disk caching of a binary spirv kernel.
@@ -298,7 +315,11 @@ def folded(args, kws):
298315 )[1 ]
299316
300317 raise e .bind_fold_arguments (folded )
301- self .add_overload (kcres .cres_or_error )
318+ self .add_overload (kcres )
319+
320+ kcres .target_context .insert_user_function (
321+ kcres .entry_point , kcres .fndesc , [kcres .library ]
322+ )
302323
303324 # TODO: enable caching of kernel_module
304325 # https://github.com/IntelPython/numba-dpex/issues/1197
0 commit comments