Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion devito/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,34 @@ def T(self):
"""
return self.transpose()

def _mpi_advanced_1d_target(self, glb_idx, axis):
"""
Return a raw local view for MPI advanced-indexing communication.

The returned view is indexed by ``glb_idx`` on all dimensions except
``axis``, which is replaced by ``slice(None)`` so the helper code in
``devito.data.utils`` can pack or unpack the requested global integer
entries itself. ``target_axis`` is the position of ``axis`` in the
returned view after scalar-indexed dimensions have been dropped.

This hook is kept on ``Data`` because only the subclass can bypass its
own ``__getitem__`` and obtain a plain ndarray view.
"""
target_idx = list(glb_idx)
target_idx[axis] = slice(None)
loc_idx = self._index_glb_to_loc(tuple(target_idx))
target_axis = sum(not is_integer(i) for i in glb_idx[:axis])
return super().__getitem__(loc_idx).view(np.ndarray), target_axis

@_check_idx
def __getitem__(self, glb_idx, comm_type, gather_rank=None):
advanced = mpi_advanced_1d_index(self, glb_idx)
if advanced is not None:
# Global integer indices may refer to data owned by any rank.
return mpi_advanced_1d_get(
self, *advanced, target_getter=self._mpi_advanced_1d_target
)

loc_idx = self._index_glb_to_loc(glb_idx)
is_gather = isinstance(gather_rank, int)
if is_gather and comm_type is gather:
Expand Down Expand Up @@ -383,6 +409,17 @@ def __getitem__(self, glb_idx, comm_type, gather_rank=None):

@_check_idx
def __setitem__(self, glb_idx, val, comm_type):
advanced = mpi_advanced_1d_index(self, glb_idx)
if advanced is not None:
# ``val`` is rank-local and ordered by the caller's global integer
# indices; route entries to their owner ranks.
glb_idx, axis, indices, decomposition = advanced
mpi_advanced_1d_set(
self, glb_idx, val, axis, indices, decomposition,
target_getter=self._mpi_advanced_1d_target
)
return

loc_idx = self._index_glb_to_loc(glb_idx)

if loc_idx is NONLOCAL:
Expand Down Expand Up @@ -461,7 +498,9 @@ def __setitem__(self, glb_idx, val, comm_type):
raise ValueError(f"Cannot insert obj of type `{type(val)}` into a Data")

def _normalize_index(self, idx):
if isinstance(idx, np.ndarray):
if isinstance(idx, np.ndarray) or (
isinstance(idx, list) and index_contains_integer_sequence(idx, self.ndim)
):
# Advanced indexing mode
return (idx,)
else:
Expand Down
Loading
Loading