From ba268dc8e94423a44bd210a7a667c0a8d9251182 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 30 Apr 2026 08:38:40 -0700 Subject: [PATCH 1/4] Remove optimize. Add time prints --- .../frequency_domain/simulation.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index b98f3755d1..c85ff29dfe 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -7,7 +7,7 @@ from ...simulation import getJtJdiag, Jvec, Jtvec, Jmatrix import numpy as np import scipy.sparse as sp - +from time import time from dask import array, compute, delayed from simpeg.dask.utils import get_parallel_blocks from simpeg.electromagnetics.natural_source.sources import PlanewaveXYPrimary @@ -182,7 +182,9 @@ def fields(self, m=None, return_Ainv=False): A = self.getA(freq) rhs = self.getRHS(freq) Ainv_solve = self.solver(sp.csr_matrix(A), **self.solver_opts) + u = Ainv_solve * rhs + sources = self.survey.get_sources_by_frequency(freq) f[sources, self._solutionType] = u Ainv[freq] = Ainv_solve @@ -212,7 +214,7 @@ def compute_J(self, m, f=None): m_size = m.size compute_row_size = np.ceil(self.max_chunk_size / (A_i.A.shape[0] * 32.0 * 1e-6)) blocks = get_parallel_blocks( - self.survey.source_list, compute_row_size, optimize=True + self.survey.source_list, compute_row_size, optimize=False ) if self.store_sensitivities == "disk": @@ -238,7 +240,7 @@ def compute_J(self, m, f=None): fields_array = f[:, self._solutionType] blocks_receiver_derivs = [] - + ct = time() if client: fields_array = client.scatter(f[:, self._solutionType], workers=worker) fields = client.scatter(f, workers=worker) @@ -272,13 +274,13 @@ def compute_J(self, m, f=None): block, ) ) - + print(f"Derivatives time: {time() - ct}") # Dask process for all derivatives if client: blocks_receiver_derivs = client.gather(blocks_receiver_derivs) else: blocks_receiver_derivs = compute(blocks_receiver_derivs)[0] - + ct = time() for block_derivs_chunks, addresses_chunks in zip( blocks_receiver_derivs, blocks, strict=True ): @@ -293,7 +295,7 @@ def compute_J(self, m, f=None): client, worker, ) - + print(f"Solve time: {time() - ct}") for A in Ainv.values(): A.clean() From 5d630c7b7c4c7b507b383346a09a830aed815941 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 30 Apr 2026 11:32:45 -0700 Subject: [PATCH 2/4] Re-block the compute_rows for less small tasks --- .../frequency_domain/simulation.py | 119 ++++++++++-------- simpeg/dask/utils.py | 3 +- 2 files changed, 70 insertions(+), 52 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index c85ff29dfe..07b6f3edab 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -1,7 +1,7 @@ import gc import os import shutil - +from multiprocessing import cpu_count from ....electromagnetics.frequency_domain.simulation import BaseFDEMSimulation as Sim from ....utils import Zero from ...simulation import getJtJdiag, Jvec, Jtvec, Jmatrix @@ -53,55 +53,63 @@ def receiver_derivs(survey, mesh, fields, blocks): def compute_rows( - simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address, Jmatrix + simulation, + Ainv_deriv_u, + count, + blocks_receiver_derivs, + deriv_m, + fields, + addresses, + Jmatrix, + indices, ): 
""" Evaluate the sensitivities for the block or data """ + for ind in indices: + if Ainv_deriv_u.ndim == 1: + deriv_columns = Ainv_deriv_u[:, np.newaxis] + else: + deriv_indices = np.arange( + count, count + blocks_receiver_derivs[ind].shape[1] + ) + deriv_columns = Ainv_deriv_u[:, deriv_indices] - if Ainv_deriv_u.ndim == 1: - deriv_columns = Ainv_deriv_u[:, np.newaxis] - else: - deriv_columns = Ainv_deriv_u[:, deriv_indices] - - n_receivers = address[1][2] - source = simulation.survey.source_list[address[0][0]] + source = simulation.survey.source_list[addresses[ind][0][0]] - if isinstance(source, PlanewaveXYPrimary): - source_fields = fields - n_cols = 2 - else: - source_fields = fields[:, address[0][0]] - n_cols = 1 - - n_cols *= n_receivers + if isinstance(source, PlanewaveXYPrimary): + source_fields = fields + else: + source_fields = fields[:, addresses[ind][0][0]] - dA_dmT = simulation.getADeriv( - source.frequency, - source_fields, - deriv_columns, - adjoint=True, - ) + dA_dmT = simulation.getADeriv( + source.frequency, + source_fields, + deriv_columns, + adjoint=True, + ) - dRHS_dmT = simulation.getRHSDeriv( - source.frequency, - source, - deriv_columns, - adjoint=True, - ) + dRHS_dmT = simulation.getRHSDeriv( + source.frequency, + source, + deriv_columns, + adjoint=True, + ) - du_dmT = -dA_dmT - if not isinstance(dRHS_dmT, Zero): - du_dmT += dRHS_dmT - if not isinstance(deriv_m, Zero): - du_dmT += deriv_m + du_dmT = -dA_dmT + if not isinstance(dRHS_dmT, Zero): + du_dmT += dRHS_dmT + if not isinstance(deriv_m, Zero): + du_dmT += deriv_m - values = np.array(du_dmT, dtype=complex).reshape((du_dmT.shape[0], -1)).real.T + values = np.array(du_dmT, dtype=complex).reshape((du_dmT.shape[0], -1)).real.T - if isinstance(Jmatrix, zarr.Array): - Jmatrix.set_orthogonal_selection((address[1][1], slice(None)), values) - else: - Jmatrix[address[1][1], :] = values + if isinstance(Jmatrix, zarr.Array): + Jmatrix.set_orthogonal_selection( + (addresses[ind][1][1], slice(None)), values + ) + else: + Jmatrix[addresses[ind][1][1], :] = values return None @@ -210,6 +218,11 @@ def compute_J(self, m, f=None): client, worker = self._get_client_worker() + if client: + n_threads = self.n_threads(client=client, worker=worker) + else: + n_threads = cpu_count() + A_i = list(Ainv.values())[0] m_size = m.size compute_row_size = np.ceil(self.max_chunk_size / (A_i.A.shape[0] * 32.0 * 1e-6)) @@ -274,12 +287,13 @@ def compute_J(self, m, f=None): block, ) ) - print(f"Derivatives time: {time() - ct}") + # Dask process for all derivatives if client: blocks_receiver_derivs = client.gather(blocks_receiver_derivs) else: blocks_receiver_derivs = compute(blocks_receiver_derivs)[0] + print(f"Derivatives time: {time() - ct}") ct = time() for block_derivs_chunks, addresses_chunks in zip( blocks_receiver_derivs, blocks, strict=True @@ -293,7 +307,8 @@ def compute_J(self, m, f=None): fields_array, addresses_chunks, client, - worker, + n_threads, + worker=worker, ) print(f"Solve time: {time() - ct}") for A in Ainv.values(): @@ -320,6 +335,7 @@ def parallel_block_compute( fields_array, addresses, client, + n_threads, worker=None, ): m_size = m.size @@ -334,42 +350,45 @@ def parallel_block_compute( count = 0 block_delayed = [] - for address, dfduT in zip(addresses, blocks_receiver_derivs): - n_cols = dfduT.shape[1] - n_rows = address[1][2] - + block_indices = np.array_split(np.arange(len(addresses)), n_threads) + for indices in block_indices: if client: block_delayed.append( client.submit( compute_rows, simulation, ATinvdf_duT, - 
np.arange(count, count + n_cols), + count, + blocks_receiver_derivs, Zero(), fields_array, - address, + addresses, Jmatrix, + indices, workers=worker, ) ) else: + n_rows = np.sum([addresses[ind][1][2] for ind in indices]) delayed_eval = delayed(compute_rows) block_delayed.append( array.from_delayed( delayed_eval( simulation, ATinvdf_duT, - np.arange(count, count + n_cols), + count, + blocks_receiver_derivs, Zero(), fields_array, - address, + addresses, Jmatrix, + indices, ), dtype=np.float32, shape=(n_rows, m_size), ) ) - count += n_cols + count += np.sum([blocks_receiver_derivs[ind].shape[1] for ind in indices]) if client: return client.gather(block_delayed) diff --git a/simpeg/dask/utils.py b/simpeg/dask/utils.py index d7100b604a..8281179ea2 100644 --- a/simpeg/dask/utils.py +++ b/simpeg/dask/utils.py @@ -1,5 +1,4 @@ import numpy as np -from multiprocessing import cpu_count def compute_chunk_sizes(M, N, target_chunk_size): @@ -81,7 +80,7 @@ def get_parallel_blocks( for block in blocks: flatten_blocks += block - chunks = np.array_split(np.arange(len(flatten_blocks)), cpu_count()) + chunks = np.array_split(np.arange(len(flatten_blocks)), thread_count) return [ [flatten_blocks[i] for i in chunk] for chunk in chunks if len(chunk) > 0 ] From d2319d4ea23551f53b5cece9ee327aa40227846b Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 30 Apr 2026 15:59:25 -0700 Subject: [PATCH 3/4] Skip inner loop --- .../frequency_domain/simulation.py | 95 ++++++++++--------- 1 file changed, 49 insertions(+), 46 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index 07b6f3edab..a2fc63391a 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -7,7 +7,7 @@ from ...simulation import getJtJdiag, Jvec, Jtvec, Jmatrix import numpy as np import scipy.sparse as sp -from time import time + from dask import array, compute, delayed from simpeg.dask.utils import get_parallel_blocks from simpeg.electromagnetics.natural_source.sources import PlanewaveXYPrimary @@ -56,7 +56,7 @@ def compute_rows( simulation, Ainv_deriv_u, count, - blocks_receiver_derivs, + block_shapes, deriv_m, fields, addresses, @@ -66,50 +66,50 @@ def compute_rows( """ Evaluate the sensitivities for the block or data """ - for ind in indices: - if Ainv_deriv_u.ndim == 1: - deriv_columns = Ainv_deriv_u[:, np.newaxis] - else: - deriv_indices = np.arange( - count, count + blocks_receiver_derivs[ind].shape[1] - ) - deriv_columns = Ainv_deriv_u[:, deriv_indices] + # for ind, shape in zip(indices, block_shapes, strict=True): + inds = np.hstack([addresses[ind][1][0] for ind in indices]) + shape = np.sum(block_shapes) + if Ainv_deriv_u.ndim == 1: + deriv_columns = Ainv_deriv_u[:, np.newaxis] + else: + deriv_indices = np.arange(count, count + shape) + deriv_columns = Ainv_deriv_u[:, deriv_indices] - source = simulation.survey.source_list[addresses[ind][0][0]] + source = simulation.survey.source_list[addresses[indices[0]][0][0]] - if isinstance(source, PlanewaveXYPrimary): - source_fields = fields - else: - source_fields = fields[:, addresses[ind][0][0]] + if isinstance(source, PlanewaveXYPrimary): + source_fields = fields + else: + source_fields = fields[:, inds] - dA_dmT = simulation.getADeriv( - source.frequency, - source_fields, - deriv_columns, - adjoint=True, - ) + dA_dmT = simulation.getADeriv( + source.frequency, + source_fields, + deriv_columns, + adjoint=True, + ) + 
dRHS_dmT = simulation.getRHSDeriv( + source.frequency, + source, + deriv_columns, + adjoint=True, + ) - dRHS_dmT = simulation.getRHSDeriv( - source.frequency, - source, - deriv_columns, - adjoint=True, - ) + du_dmT = -dA_dmT + if not isinstance(dRHS_dmT, Zero): + du_dmT += dRHS_dmT + if not isinstance(deriv_m, Zero): + du_dmT += deriv_m - du_dmT = -dA_dmT - if not isinstance(dRHS_dmT, Zero): - du_dmT += dRHS_dmT - if not isinstance(deriv_m, Zero): - du_dmT += deriv_m + values = np.array(du_dmT, dtype=complex).reshape((du_dmT.shape[0], -1)).real.T - values = np.array(du_dmT, dtype=complex).reshape((du_dmT.shape[0], -1)).real.T + inds = np.hstack([addresses[ind][1][1] for ind in indices]) + if isinstance(Jmatrix, zarr.Array): + Jmatrix.set_orthogonal_selection((inds, slice(None)), values) + else: + Jmatrix[inds, :] = values - if isinstance(Jmatrix, zarr.Array): - Jmatrix.set_orthogonal_selection( - (addresses[ind][1][1], slice(None)), values - ) - else: - Jmatrix[addresses[ind][1][1], :] = values + count += shape return None @@ -253,7 +253,7 @@ def compute_J(self, m, f=None): fields_array = f[:, self._solutionType] blocks_receiver_derivs = [] - ct = time() + if client: fields_array = client.scatter(f[:, self._solutionType], workers=worker) fields = client.scatter(f, workers=worker) @@ -293,8 +293,7 @@ def compute_J(self, m, f=None): blocks_receiver_derivs = client.gather(blocks_receiver_derivs) else: blocks_receiver_derivs = compute(blocks_receiver_derivs)[0] - print(f"Derivatives time: {time() - ct}") - ct = time() + for block_derivs_chunks, addresses_chunks in zip( blocks_receiver_derivs, blocks, strict=True ): @@ -310,7 +309,7 @@ def compute_J(self, m, f=None): n_threads, worker=worker, ) - print(f"Solve time: {time() - ct}") + for A in Ainv.values(): A.clean() @@ -352,6 +351,10 @@ def parallel_block_compute( block_delayed = [] block_indices = np.array_split(np.arange(len(addresses)), n_threads) for indices in block_indices: + if len(indices) == 0: + continue + + block_shapes = [blocks_receiver_derivs[ind].shape[1] for ind in indices] if client: block_delayed.append( client.submit( @@ -359,7 +362,7 @@ def parallel_block_compute( simulation, ATinvdf_duT, count, - blocks_receiver_derivs, + block_shapes, Zero(), fields_array, addresses, @@ -377,7 +380,7 @@ def parallel_block_compute( simulation, ATinvdf_duT, count, - blocks_receiver_derivs, + block_shapes, Zero(), fields_array, addresses, @@ -388,7 +391,7 @@ def parallel_block_compute( shape=(n_rows, m_size), ) ) - count += np.sum([blocks_receiver_derivs[ind].shape[1] for ind in indices]) + count += np.sum(block_shapes) if client: return client.gather(block_delayed) From c6f29cb8165354cd4e9d2f1bfe6c9677eed545d0 Mon Sep 17 00:00:00 2001 From: domfournier Date: Thu, 30 Apr 2026 16:27:48 -0700 Subject: [PATCH 4/4] Revert "Skip inner loop" This reverts commit d2319d4ea23551f53b5cece9ee327aa40227846b. 
--- .../frequency_domain/simulation.py | 87 +++++++++---------- 1 file changed, 41 insertions(+), 46 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index a2fc63391a..82c4f493b2 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -56,7 +56,7 @@ def compute_rows( simulation, Ainv_deriv_u, count, - block_shapes, + blocks_receiver_derivs, deriv_m, fields, addresses, @@ -66,50 +66,50 @@ def compute_rows( """ Evaluate the sensitivities for the block or data """ - # for ind, shape in zip(indices, block_shapes, strict=True): - inds = np.hstack([addresses[ind][1][0] for ind in indices]) - shape = np.sum(block_shapes) - if Ainv_deriv_u.ndim == 1: - deriv_columns = Ainv_deriv_u[:, np.newaxis] - else: - deriv_indices = np.arange(count, count + shape) - deriv_columns = Ainv_deriv_u[:, deriv_indices] + for ind in indices: + if Ainv_deriv_u.ndim == 1: + deriv_columns = Ainv_deriv_u[:, np.newaxis] + else: + deriv_indices = np.arange( + count, count + blocks_receiver_derivs[ind].shape[1] + ) + deriv_columns = Ainv_deriv_u[:, deriv_indices] - source = simulation.survey.source_list[addresses[indices[0]][0][0]] + source = simulation.survey.source_list[addresses[ind][0][0]] - if isinstance(source, PlanewaveXYPrimary): - source_fields = fields - else: - source_fields = fields[:, inds] + if isinstance(source, PlanewaveXYPrimary): + source_fields = fields + else: + source_fields = fields[:, addresses[ind][0][0]] - dA_dmT = simulation.getADeriv( - source.frequency, - source_fields, - deriv_columns, - adjoint=True, - ) - dRHS_dmT = simulation.getRHSDeriv( - source.frequency, - source, - deriv_columns, - adjoint=True, - ) + dA_dmT = simulation.getADeriv( + source.frequency, + source_fields, + deriv_columns, + adjoint=True, + ) - du_dmT = -dA_dmT - if not isinstance(dRHS_dmT, Zero): - du_dmT += dRHS_dmT - if not isinstance(deriv_m, Zero): - du_dmT += deriv_m + dRHS_dmT = simulation.getRHSDeriv( + source.frequency, + source, + deriv_columns, + adjoint=True, + ) - values = np.array(du_dmT, dtype=complex).reshape((du_dmT.shape[0], -1)).real.T + du_dmT = -dA_dmT + if not isinstance(dRHS_dmT, Zero): + du_dmT += dRHS_dmT + if not isinstance(deriv_m, Zero): + du_dmT += deriv_m - inds = np.hstack([addresses[ind][1][1] for ind in indices]) - if isinstance(Jmatrix, zarr.Array): - Jmatrix.set_orthogonal_selection((inds, slice(None)), values) - else: - Jmatrix[inds, :] = values + values = np.array(du_dmT, dtype=complex).reshape((du_dmT.shape[0], -1)).real.T - count += shape + if isinstance(Jmatrix, zarr.Array): + Jmatrix.set_orthogonal_selection( + (addresses[ind][1][1], slice(None)), values + ) + else: + Jmatrix[addresses[ind][1][1], :] = values return None @@ -309,7 +309,6 @@ def compute_J(self, m, f=None): n_threads, worker=worker, ) - for A in Ainv.values(): A.clean() @@ -351,10 +350,6 @@ def parallel_block_compute( block_delayed = [] block_indices = np.array_split(np.arange(len(addresses)), n_threads) for indices in block_indices: - if len(indices) == 0: - continue - - block_shapes = [blocks_receiver_derivs[ind].shape[1] for ind in indices] if client: block_delayed.append( client.submit( @@ -362,7 +357,7 @@ def parallel_block_compute( simulation, ATinvdf_duT, count, - block_shapes, + blocks_receiver_derivs, Zero(), fields_array, addresses, @@ -380,7 +375,7 @@ def parallel_block_compute( simulation, ATinvdf_duT, count, - block_shapes, + 
blocks_receiver_derivs, Zero(), fields_array, addresses, @@ -391,7 +386,7 @@ def parallel_block_compute( shape=(n_rows, m_size), ) ) - count += np.sum(block_shapes) + count += np.sum([blocks_receiver_derivs[ind].shape[1] for ind in indices]) if client: return client.gather(block_delayed)
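
Taken together (and after the PATCH 4 revert of PATCH 3), the series leaves parallel_block_compute splitting the receiver-derivative block indices across the available threads with np.array_split and submitting one compute_rows task per chunk, rather than one task per block. Below is a minimal, self-contained sketch of that re-blocking pattern only; process_chunk, the toy derivs arrays, and n_threads here are illustrative stand-ins, not the real compute_rows signature or SimPEG API.

import numpy as np
from dask import compute, delayed


def process_chunk(derivs, indices):
    # One task per chunk: walk the receiver blocks assigned to this thread.
    # Stand-in for compute_rows; the real code assembles sensitivity rows here.
    return [derivs[ind].sum(axis=1) for ind in indices]


# Toy receiver-derivative blocks (one small array per source/receiver block).
rng = np.random.default_rng(0)
derivs = [rng.random((4, w)) for w in rng.integers(1, 4, size=12)]
n_threads = 3

# The re-blocking step: fewer, larger tasks instead of one task per block.
block_indices = np.array_split(np.arange(len(derivs)), n_threads)

tasks = [
    delayed(process_chunk)(derivs, indices)
    for indices in block_indices
    if len(indices) > 0  # array_split can return empty chunks
]
results = compute(*tasks)

In the patches themselves, each chunk additionally carries a running column offset (count) into the stacked ATinvdf_duT array so every block selects its own derivative columns, and the per-chunk result is written directly into Jmatrix (in memory or zarr) rather than returned as above.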