@@ -933,6 +933,49 @@ def get_pde_as_diff_op(self):
933933 return laplacian (w )
934934
935935
936+ class GaussTransformKernel (ExpressionKernel ):
937+ init_arg_names = ("dim" ,)
938+
939+ def __init__ (self , dim , gauss_delta_name = "delta" ):
940+ r = pymbolic_real_norm_2 (make_sym_vector ("d" , dim ))
941+ delta = SpatialConstant (gauss_delta_name )
942+ expr = var ("exp" )(- r ** 2 / delta )
943+ scaling = 1
944+
945+ super ().__init__ (
946+ dim ,
947+ expression = expr ,
948+ global_scaling_const = scaling ,
949+ is_complex_valued = False )
950+
951+ self .gauss_delta_name = gauss_delta_name
952+
953+ def __getinitargs__ (self ):
954+ return (self .dim , self .gauss_delta_name )
955+
956+ def update_persistent_hash (self , key_hash , key_builder ):
957+ key_hash .update (type (self ).__name__ .encode ("utf8" ))
958+ key_builder .rec (key_hash , (self .dim , self .gauss_delta_name ))
959+
960+ def __repr__ (self ):
961+ return f"GaussKnl{ self .dim } D"
962+
963+ def get_args (self ):
964+ return [
965+ KernelArgument (
966+ loopy_arg = lp .ValueArg (self .gauss_delta_name , np .float64 ),
967+ )]
968+
969+ mapper_method = "map_gauss_transform_kernel"
970+
971+ def get_derivative_taker (self , dvec , rscale , sac ):
972+ """Return a :class:`sumpy.derivative_taker.ExprDerivativeTaker` instance
973+ that supports taking derivatives of the base kernel with respect to dvec.
974+ """
975+ from sumpy .derivative_taker import RadialDerivativeTaker
976+ return RadialDerivativeTaker (self .get_expression (dvec ), dvec , rscale ,
977+ sac )
978+
936979# }}}
937980
938981
@@ -1350,6 +1393,7 @@ def map_expression_kernel(self, kernel):
13501393 map_elasticity_kernel = map_expression_kernel
13511394 map_line_of_compression_kernel = map_expression_kernel
13521395 map_stresslet_kernel = map_expression_kernel
1396+ map_gauss_transform_kernel = map_expression_kernel
13531397
13541398 def map_axis_target_derivative (self , kernel ):
13551399 return type (kernel )(kernel .axis , self .rec (kernel .inner_kernel ))
@@ -1406,6 +1450,7 @@ def map_expression_kernel(self, kernel):
14061450 map_yukawa_kernel = map_expression_kernel
14071451 map_line_of_compression_kernel = map_expression_kernel
14081452 map_stresslet_kernel = map_expression_kernel
1453+ map_gauss_transform_kernel = map_expression_kernel
14091454
14101455 def map_axis_target_derivative (self , kernel ):
14111456 return 1 + self .rec (kernel .inner_kernel )
0 commit comments