Skip to content

Commit c8a4a14

Browse files
committed
attempted fix to xgcm's grid_ufunc.py, but triggers error in dask
1 parent db79ddf commit c8a4a14

1 file changed

Lines changed: 22 additions & 3 deletions

File tree

misc/xgcm_fixes/grid_ufunc.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -939,9 +939,25 @@ def _map_func_over_core_dims(
939939

940940
# Need to transpose the numpy axis arguments to leave core dims at end
941941
# else they won't match up inside mapped_func after xr.apply_ufunc does its transposition
942-
transposed_original_args = [
943-
arg.transpose(..., *in_core_dims[i]) for i, arg in enumerate(original_args)
944-
]
942+
if isinstance(original_args[0],dict):
943+
# transposed_original_args = []
944+
# for i, arg in enumerate(original_args):
945+
# transposed_original_args.append(
946+
# {axis: da.transpose(..., *in_core_dims[i]) for axis,da in arg.items()}
947+
# )
948+
transposed_original_args = [
949+
tuple(arg.values())[0].transpose(..., *in_core_dims[i]) for i, arg in enumerate(original_args)
950+
]
951+
952+
# boundary_width_per_numpy_axis = {
953+
# grid.axes[ax_name]._get_axis_dim_num(tuple(transposed_original_args[0].values())[0]): width
954+
# for ax_name, width in boundary_width_real_axes.items()
955+
# }
956+
957+
else:
958+
transposed_original_args = [
959+
arg.transpose(..., *in_core_dims[i]) for i, arg in enumerate(original_args)
960+
]
945961

946962
boundary_width_per_numpy_axis = {
947963
grid.axes[ax_name]._get_axis_dim_num(transposed_original_args[0]): width
@@ -959,6 +975,9 @@ def _dict_to_numbered_axes(
959975
# Our rechunking means dask.map_overlap needs to be explicitly told what chunks output should have
960976
# But in this case output chunks are the same as input chunks
961977
# (as we disallowed axis positions for which this is not the case)
978+
# if isinstance(transposed_original_args[0],dict):
979+
# original_chunksizes = [tuple(arg.values())[0].variable.chunksizes for arg in transposed_original_args]
980+
# else:
962981
original_chunksizes = [arg.variable.chunksizes for arg in transposed_original_args]
963982
# TODO first argument only because map_overlap can't handle multiple return values (I think)
964983
true_chunksizes = original_chunksizes[0]

0 commit comments

Comments
 (0)