@@ -49,11 +49,48 @@ std::string _default_device_fp_type(const sycl::device &d)
4949 }
5050}
5151
52+ int get_numpy_major_version ()
53+ {
54+ namespace py = pybind11;
55+
56+ py::module_ numpy = py::module_::import (" numpy" );
57+ py::str version_string = numpy.attr (" __version__" );
58+ py::module_ numpy_lib = py::module_::import (" numpy.lib" );
59+
60+ py::object numpy_version = numpy_lib.attr (" NumpyVersion" )(version_string);
61+ int major_version = numpy_version.attr (" major" ).cast <int >();
62+
63+ return major_version;
64+ }
65+
5266std::string _default_device_int_type (const sycl::device &)
5367{
54- return " l" ; // code for numpy.dtype('long') to be consistent
55- // with NumPy's default integer type across
56- // platforms.
68+ const int np_ver = get_numpy_major_version ();
69+
70+ if (np_ver >= 2 ) {
71+ return " i8" ;
72+ }
73+ else {
74+ // code for numpy.dtype('long') to be consistent
75+ // with NumPy's default integer type across
76+ // platforms.
77+ return " l" ;
78+ }
79+ }
80+
81+ std::string _default_device_uint_type (const sycl::device &)
82+ {
83+ const int np_ver = get_numpy_major_version ();
84+
85+ if (np_ver >= 2 ) {
86+ return " u8" ;
87+ }
88+ else {
89+ // code for numpy.dtype('long') to be consistent
90+ // with NumPy's default integer type across
91+ // platforms.
92+ return " L" ;
93+ }
5794}
5895
5996std::string _default_device_complex_type (const sycl::device &d)
@@ -108,6 +145,12 @@ std::string default_device_int_type(const py::object &arg)
108145 return _default_device_int_type (d);
109146}
110147
148+ std::string default_device_uint_type (const py::object &arg)
149+ {
150+ const sycl::device &d = _extract_device (arg);
151+ return _default_device_uint_type (d);
152+ }
153+
111154std::string default_device_bool_type (const py::object &arg)
112155{
113156 const sycl::device &d = _extract_device (arg);
0 commit comments