diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index b98f3755d1..82c4f493b2 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -1,7 +1,7 @@ import gc import os import shutil - +from multiprocessing import cpu_count from ....electromagnetics.frequency_domain.simulation import BaseFDEMSimulation as Sim from ....utils import Zero from ...simulation import getJtJdiag, Jvec, Jtvec, Jmatrix @@ -53,55 +53,63 @@ def receiver_derivs(survey, mesh, fields, blocks): def compute_rows( - simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address, Jmatrix + simulation, + Ainv_deriv_u, + count, + blocks_receiver_derivs, + deriv_m, + fields, + addresses, + Jmatrix, + indices, ): """ Evaluate the sensitivities for the block or data """ + for ind in indices: + if Ainv_deriv_u.ndim == 1: + deriv_columns = Ainv_deriv_u[:, np.newaxis] + else: + deriv_indices = np.arange( + count, count + blocks_receiver_derivs[ind].shape[1] + ) + deriv_columns = Ainv_deriv_u[:, deriv_indices] - if Ainv_deriv_u.ndim == 1: - deriv_columns = Ainv_deriv_u[:, np.newaxis] - else: - deriv_columns = Ainv_deriv_u[:, deriv_indices] - - n_receivers = address[1][2] - source = simulation.survey.source_list[address[0][0]] + source = simulation.survey.source_list[addresses[ind][0][0]] - if isinstance(source, PlanewaveXYPrimary): - source_fields = fields - n_cols = 2 - else: - source_fields = fields[:, address[0][0]] - n_cols = 1 - - n_cols *= n_receivers + if isinstance(source, PlanewaveXYPrimary): + source_fields = fields + else: + source_fields = fields[:, addresses[ind][0][0]] - dA_dmT = simulation.getADeriv( - source.frequency, - source_fields, - deriv_columns, - adjoint=True, - ) + dA_dmT = simulation.getADeriv( + source.frequency, + source_fields, + deriv_columns, + adjoint=True, + ) - dRHS_dmT = simulation.getRHSDeriv( - source.frequency, - source, - deriv_columns, - adjoint=True, - ) + dRHS_dmT = simulation.getRHSDeriv( + source.frequency, + source, + deriv_columns, + adjoint=True, + ) - du_dmT = -dA_dmT - if not isinstance(dRHS_dmT, Zero): - du_dmT += dRHS_dmT - if not isinstance(deriv_m, Zero): - du_dmT += deriv_m + du_dmT = -dA_dmT + if not isinstance(dRHS_dmT, Zero): + du_dmT += dRHS_dmT + if not isinstance(deriv_m, Zero): + du_dmT += deriv_m - values = np.array(du_dmT, dtype=complex).reshape((du_dmT.shape[0], -1)).real.T + values = np.array(du_dmT, dtype=complex).reshape((du_dmT.shape[0], -1)).real.T - if isinstance(Jmatrix, zarr.Array): - Jmatrix.set_orthogonal_selection((address[1][1], slice(None)), values) - else: - Jmatrix[address[1][1], :] = values + if isinstance(Jmatrix, zarr.Array): + Jmatrix.set_orthogonal_selection( + (addresses[ind][1][1], slice(None)), values + ) + else: + Jmatrix[addresses[ind][1][1], :] = values return None @@ -182,7 +190,9 @@ def fields(self, m=None, return_Ainv=False): A = self.getA(freq) rhs = self.getRHS(freq) Ainv_solve = self.solver(sp.csr_matrix(A), **self.solver_opts) + u = Ainv_solve * rhs + sources = self.survey.get_sources_by_frequency(freq) f[sources, self._solutionType] = u Ainv[freq] = Ainv_solve @@ -208,11 +218,16 @@ def compute_J(self, m, f=None): client, worker = self._get_client_worker() + if client: + n_threads = self.n_threads(client=client, worker=worker) + else: + n_threads = cpu_count() + A_i = list(Ainv.values())[0] m_size = m.size compute_row_size = np.ceil(self.max_chunk_size / (A_i.A.shape[0] * 32.0 * 1e-6)) blocks = get_parallel_blocks( - self.survey.source_list, compute_row_size, optimize=True + self.survey.source_list, compute_row_size, optimize=False ) if self.store_sensitivities == "disk": @@ -291,9 +306,9 @@ def compute_J(self, m, f=None): fields_array, addresses_chunks, client, - worker, + n_threads, + worker=worker, ) - for A in Ainv.values(): A.clean() @@ -318,6 +333,7 @@ def parallel_block_compute( fields_array, addresses, client, + n_threads, worker=None, ): m_size = m.size @@ -332,42 +348,45 @@ def parallel_block_compute( count = 0 block_delayed = [] - for address, dfduT in zip(addresses, blocks_receiver_derivs): - n_cols = dfduT.shape[1] - n_rows = address[1][2] - + block_indices = np.array_split(np.arange(len(addresses)), n_threads) + for indices in block_indices: if client: block_delayed.append( client.submit( compute_rows, simulation, ATinvdf_duT, - np.arange(count, count + n_cols), + count, + blocks_receiver_derivs, Zero(), fields_array, - address, + addresses, Jmatrix, + indices, workers=worker, ) ) else: + n_rows = np.sum([addresses[ind][1][2] for ind in indices]) delayed_eval = delayed(compute_rows) block_delayed.append( array.from_delayed( delayed_eval( simulation, ATinvdf_duT, - np.arange(count, count + n_cols), + count, + blocks_receiver_derivs, Zero(), fields_array, - address, + addresses, Jmatrix, + indices, ), dtype=np.float32, shape=(n_rows, m_size), ) ) - count += n_cols + count += np.sum([blocks_receiver_derivs[ind].shape[1] for ind in indices]) if client: return client.gather(block_delayed) diff --git a/simpeg/dask/utils.py b/simpeg/dask/utils.py index d7100b604a..8281179ea2 100644 --- a/simpeg/dask/utils.py +++ b/simpeg/dask/utils.py @@ -1,5 +1,4 @@ import numpy as np -from multiprocessing import cpu_count def compute_chunk_sizes(M, N, target_chunk_size): @@ -81,7 +80,7 @@ def get_parallel_blocks( for block in blocks: flatten_blocks += block - chunks = np.array_split(np.arange(len(flatten_blocks)), cpu_count()) + chunks = np.array_split(np.arange(len(flatten_blocks)), thread_count) return [ [flatten_blocks[i] for i in chunk] for chunk in chunks if len(chunk) > 0 ]