@@ -287,10 +287,10 @@ def _concat_diff_input(arr, axis, prepend, append):
287287 )
288288 if not prepend_shape :
289289 prepend_shape = arr_shape [:axis ] + (1 ,) + arr_shape [axis + 1 :]
290- a_prepend = dpt .broadcast_to (a_prepend , arr_shape )
290+ a_prepend = dpt .broadcast_to (a_prepend , prepend_shape )
291291 if not append_shape :
292292 append_shape = arr_shape [:axis ] + (1 ,) + arr_shape [axis + 1 :]
293- a_append = dpt .broadcast_to (a_append , arr_shape )
293+ a_append = dpt .broadcast_to (a_append , append_shape )
294294 return dpt .concat ((a_prepend , arr , a_append ), axis = axis )
295295 elif prepend is not None :
296296 q1 , x_usm_type = arr .sycl_queue , arr .usm_type
@@ -347,7 +347,7 @@ def _concat_diff_input(arr, axis, prepend, append):
347347 )
348348 if not prepend_shape :
349349 prepend_shape = arr_shape [:axis ] + (1 ,) + arr_shape [axis + 1 :]
350- a_prepend = dpt .broadcast_to (a_prepend , arr_shape )
350+ a_prepend = dpt .broadcast_to (a_prepend , prepend_shape )
351351 return dpt .concat ((a_prepend , arr ), axis = axis )
352352 elif append is not None :
353353 q1 , x_usm_type = arr .sycl_queue , arr .usm_type
@@ -402,7 +402,7 @@ def _concat_diff_input(arr, axis, prepend, append):
402402 )
403403 if not append_shape :
404404 append_shape = arr_shape [:axis ] + (1 ,) + arr_shape [axis + 1 :]
405- a_append = dpt .broadcast_to (a_append , arr_shape )
405+ a_append = dpt .broadcast_to (a_append , append_shape )
406406 return dpt .concat ((arr , a_append ), axis = axis )
407407 else :
408408 arr1 = arr
0 commit comments