@@ -146,36 +146,188 @@ class dpnp_max_c_kernel;
146146template <typename _DataType>
147147void dpnp_max_c (void * array1_in, void * result1, const size_t * shape, size_t ndim, const size_t * axis, size_t naxis)
148148{
149- __attribute__ ((unused)) void * tmp = (void *)(axis + naxis);
149+ if (naxis == 0 )
150+ {
151+ __attribute__ ((unused)) void * tmp = (void *)(axis + naxis);
150152
151- _DataType* array_1 = reinterpret_cast <_DataType*>(array1_in);
152- _DataType* result = reinterpret_cast <_DataType*>(result1);
153+ _DataType* array_1 = reinterpret_cast <_DataType*>(array1_in);
154+ _DataType* result = reinterpret_cast <_DataType*>(result1);
153155
154- size_t size = 1 ;
155- for (size_t i = 0 ; i < ndim; ++i)
156- {
157- size *= shape[i];
158- }
156+ size_t size = 1 ;
157+ for (size_t i = 0 ; i < ndim; ++i)
158+ {
159+ size *= shape[i];
160+ }
159161
160- if constexpr (std::is_same<_DataType, double >::value || std::is_same<_DataType, float >::value)
161- {
162- // Required initializing the result before call the function
163- result[0 ] = array_1[0 ];
162+ if constexpr (std::is_same<_DataType, double >::value || std::is_same<_DataType, float >::value)
163+ {
164+ // Required initializing the result before call the function
165+ result[0 ] = array_1[0 ];
164166
165- auto dataset = mkl_stats::make_dataset<mkl_stats::layout::row_major>(1 , size, array_1);
167+ auto dataset = mkl_stats::make_dataset<mkl_stats::layout::row_major>(1 , size, array_1);
166168
167- cl::sycl::event event = mkl_stats::max (DPNP_QUEUE, dataset, result);
169+ cl::sycl::event event = mkl_stats::max (DPNP_QUEUE, dataset, result);
168170
169- event.wait ();
171+ event.wait ();
172+ }
173+ else
174+ {
175+ auto policy = oneapi::dpl::execution::make_device_policy<class dpnp_max_c_kernel <_DataType>>(DPNP_QUEUE);
176+
177+ _DataType* res = std::max_element (policy, array_1, array_1 + size);
178+ policy.queue ().wait ();
179+
180+ result[0 ] = *res;
181+ }
170182 }
171183 else
172184 {
173- auto policy = oneapi::dpl::execution::make_device_policy<class dpnp_max_c_kernel <_DataType>>(DPNP_QUEUE);
185+ _DataType* array_1 = reinterpret_cast <_DataType*>(array1_in);
186+ _DataType* result = reinterpret_cast <_DataType*>(result1);
187+
188+ size_t res_ndim = ndim - naxis;
189+ size_t res_shape[res_ndim];
190+ int ind = 0 ;
191+ for (size_t i = 0 ; i < ndim; i++)
192+ {
193+ bool found = false ;
194+ for (size_t j = 0 ; j < naxis; j++)
195+ {
196+ if (axis[j] == i)
197+ {
198+ found = true ;
199+ break ;
200+ }
201+ }
202+ if (!found)
203+ {
204+ res_shape[ind] = shape[i];
205+ ind++;
206+ }
207+ }
208+
209+ size_t size_input = 1 ;
210+ for (size_t i = 0 ; i < ndim; ++i)
211+ {
212+ size_input *= shape[i];
213+ }
214+
215+ size_t input_shape_offsets[ndim];
216+ size_t acc = 1 ;
217+ for (size_t i = ndim - 1 ; i > 0 ; --i)
218+ {
219+ input_shape_offsets[i] = acc;
220+ acc *= shape[i];
221+ }
222+ input_shape_offsets[0 ] = acc;
223+
224+ size_t output_shape_offsets[res_ndim];
225+ acc = 1 ;
226+ if (res_ndim > 0 )
227+ {
228+ for (size_t i = res_ndim - 1 ; i > 0 ; --i)
229+ {
230+ output_shape_offsets[i] = acc;
231+ acc *= res_shape[i];
232+ }
233+ }
234+ output_shape_offsets[0 ] = acc;
235+
236+ size_t size_result = 1 ;
237+ for (size_t i = 0 ; i < res_ndim; ++i)
238+ {
239+ size_result *= res_shape[i];
240+ }
174241
175- _DataType* res = std::max_element (policy, array_1, array_1 + size);
176- policy.queue ().wait ();
242+ // init result array
243+ for (size_t result_idx = 0 ; result_idx < size_result; ++result_idx)
244+ {
245+ size_t xyz[res_ndim];
246+ size_t remainder = result_idx;
247+ for (size_t i = 0 ; i < res_ndim; ++i)
248+ {
249+ xyz[i] = remainder / output_shape_offsets[i];
250+ remainder = remainder - xyz[i] * output_shape_offsets[i];
251+ }
252+
253+ size_t source_axis[ndim];
254+ size_t result_axis_idx = 0 ;
255+ for (size_t idx = 0 ; idx < ndim; ++idx)
256+ {
257+ bool found = false ;
258+ for (size_t i = 0 ; i < naxis; ++i)
259+ {
260+ if (axis[i] == idx)
261+ {
262+ found = true ;
263+ break ;
264+ }
265+ }
266+ if (found)
267+ {
268+ source_axis[idx] = 0 ;
269+ }
270+ else
271+ {
272+ source_axis[idx] = xyz[result_axis_idx];
273+ result_axis_idx++;
274+ }
275+ }
276+
277+ size_t source_idx = 0 ;
278+ for (size_t i = 0 ; i < ndim; ++i)
279+ {
280+ source_idx += input_shape_offsets[i] * source_axis[i];
281+ }
282+
283+ result[result_idx] = array_1[source_idx];
284+ }
177285
178- result[0 ] = *res;
286+ for (size_t source_idx = 0 ; source_idx < size_input; ++source_idx)
287+ {
288+ // reconstruct x,y,z from linear source_idx
289+ size_t xyz[ndim];
290+ size_t remainder = source_idx;
291+ for (size_t i = 0 ; i < ndim; ++i)
292+ {
293+ xyz[i] = remainder / input_shape_offsets[i];
294+ remainder = remainder - xyz[i] * input_shape_offsets[i];
295+ }
296+
297+ // extract result axis
298+ size_t result_axis[res_ndim];
299+ size_t result_idx = 0 ;
300+ for (size_t idx = 0 ; idx < ndim; ++idx)
301+ {
302+ // try to find current idx in axis array
303+ bool found = false ;
304+ for (size_t i = 0 ; i < naxis; ++i)
305+ {
306+ if (axis[i] == idx)
307+ {
308+ found = true ;
309+ break ;
310+ }
311+ }
312+ if (!found)
313+ {
314+ result_axis[result_idx] = xyz[idx];
315+ result_idx++;
316+ }
317+ }
318+
319+ // Construct result offset
320+ size_t result_offset = 0 ;
321+ for (size_t i = 0 ; i < res_ndim; ++i)
322+ {
323+ result_offset += output_shape_offsets[i] * result_axis[i];
324+ }
325+
326+ if (result[result_offset] < array_1[source_idx])
327+ {
328+ result[result_offset] = array_1[source_idx];
329+ }
330+ }
179331 }
180332
181333 return ;
0 commit comments