diff --git a/pyproject.toml b/pyproject.toml index 2713a73f..f1c1e8dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ ] requires-python = '>=3.11' dependencies = [ - 'easyscience @ git+https://github.com/easyscience/corelib.git@bayesian', + 'easyscience @ git+https://github.com/easyscience/corelib.git@bayesian_mp', # 'easyscience', 'scipp', 'refnx', @@ -69,10 +69,10 @@ dev = [ 'mkdocstrings-python', # MkDocs: Python docstring support 'pyyaml', # YAML parser 'spdx-headers', # SPDX license header validation + 'corner', # Bayesian analysis and plotting + 'arviz', # Bayesian analysis and plotting ] -bayesian = ["corner>=2.2", "arviz>=0.18"] - [project.urls] Documentation = 'https://easyscience.github.io/reflectometry-lib' 'Release Notes' = 'https://github.com/easyscience/reflectometry-lib/releases' diff --git a/src/easyreflectometry/calculators/factory.py b/src/easyreflectometry/calculators/factory.py index c3e1479c..74ce98de 100644 --- a/src/easyreflectometry/calculators/factory.py +++ b/src/easyreflectometry/calculators/factory.py @@ -14,6 +14,35 @@ def __init__(self): """Init function.""" super().__init__(interface_list=CalculatorBase._calculators) + def __reduce__(self): + """Serialize the active calculator state for worker processes.""" + wrapper = getattr(self(), '_wrapper', None) + if wrapper is None and self.current_interface_name is not None: + raise RuntimeError( + f'Cannot pickle CalculatorFactory: active interface ' + f"{self.current_interface_name!r} exposes no '_wrapper' attribute. " + 'The InterfaceFactoryTemplate API may have changed.' + ) + return ( + self._state_restore, + ( + self.__class__, + self.current_interface_name, + wrapper.__getstate__() if wrapper is not None else None, + ), + ) + + @staticmethod + def _state_restore(cls, interface_str, wrapper_state): + """Restore a calculator factory with its active wrapper state.""" + obj = cls() + if interface_str is not None and interface_str in obj.available_interfaces: + obj.switch(interface_str) + wrapper = getattr(obj(), '_wrapper', None) + if wrapper is not None and wrapper_state is not None: + wrapper.__setstate__(wrapper_state) + return obj + def reset_storage(self) -> None: """Reset storage.""" return self().reset_storage() diff --git a/src/easyreflectometry/calculators/wrapper_base.py b/src/easyreflectometry/calculators/wrapper_base.py index dc53ceca..8b383dca 100644 --- a/src/easyreflectometry/calculators/wrapper_base.py +++ b/src/easyreflectometry/calculators/wrapper_base.py @@ -293,6 +293,18 @@ def get_item_value(self, name: str, key: str) -> float: item = getattr(item, key) return getattr(item, 'value') + def __getstate__(self) -> dict: + return { + 'storage': self.storage, + 'resolution_function': self._resolution_function, + 'magnetism': self._magnetism, + } + + def __setstate__(self, state: dict) -> None: + self.storage = state['storage'] + self._resolution_function = state['resolution_function'] + self._magnetism = state['magnetism'] + def set_resolution_function(self, resolution_function: ResolutionFunction) -> None: """Set the resolution function for the calculator. diff --git a/src/easyreflectometry/fitting.py b/src/easyreflectometry/fitting.py index 3d50fe99..db0c81fe 100644 --- a/src/easyreflectometry/fitting.py +++ b/src/easyreflectometry/fitting.py @@ -364,6 +364,7 @@ def sample( seed: int | None = None, objective: str | None = None, initializer: str | None = None, + n_workers: int | None = None, progress_callback=None, abort_test=None, ) -> dict: @@ -383,12 +384,22 @@ def sample( :param initializer: DREAM population initializer. One of ``'eps'``, ``'cov'``, ``'lhs'``, or ``'random'``. By default, None (BUMPS uses ``'eps'``). + :param n_workers: Number of worker processes for parallel DREAM + population evaluation. ``None`` (default) and ``1`` use + sequential evaluation. Values greater than ``1`` enable + multiprocessing; the effective pool size is capped at + ``min(n_workers, population)``. :param progress_callback: Optional callback for progress updates during sampling. Forwarded to the core MultiFitter. + :param abort_test: Optional callback that returns ``True`` to signal + that sampling should be aborted. :return: Dictionary with keys ``'draws'``, ``'param_names'``, ``'state'``, and ``'logp'``. :raises RuntimeError: If the current minimizer is not a BUMPS instance. + :raises ValueError: If ``n_workers`` is not None and less than 1. """ + if n_workers is not None and n_workers < 1: + raise ValueError(f'n_workers must be a positive integer or None, got {n_workers}') obj = _validate_objective(objective) if objective is not None else self._objective refl_nums = [k[3:] for k in data['coords'].keys() if 'Qz' == k[:2]] @@ -417,20 +428,23 @@ def sample( sampler_kwargs = {} if initializer is not None: sampler_kwargs['init'] = initializer - return self.easy_science_multi_fitter.sample( - x=x, - y=y, - weights=dy, - samples=samples, - burn=burn, - thin=thin, - chains=chains, - population=population, - seed=seed, - sampler_kwargs=sampler_kwargs or None, - progress_callback=progress_callback, - abort_test=abort_test, - ) + core_sample_kwargs = { + 'x': x, + 'y': y, + 'weights': dy, + 'samples': samples, + 'burn': burn, + 'thin': thin, + 'chains': chains, + 'population': population, + 'seed': seed, + 'sampler_kwargs': sampler_kwargs or None, + 'progress_callback': progress_callback, + 'abort_test': abort_test, + } + if n_workers is not None: + core_sample_kwargs['n_workers'] = n_workers + return self.easy_science_multi_fitter.sample(**core_sample_kwargs) @property def chi2(self) -> float | None: diff --git a/tests/calculators/test_factory.py b/tests/calculators/test_factory.py new file mode 100644 index 00000000..3af5e062 --- /dev/null +++ b/tests/calculators/test_factory.py @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests for CalculatorFactory serialization.""" + +import pickle # noqa: S403 + +import numpy as np +from numpy.testing import assert_allclose + +from easyreflectometry.calculators import CalculatorFactory +from easyreflectometry.model import Model +from easyreflectometry.model import PercentageFwhm +from easyreflectometry.sample import Layer +from easyreflectometry.sample import Material +from easyreflectometry.sample import Multilayer +from easyreflectometry.sample import Sample + + +def test_calculator_factory_pickle_preserves_active_wrapper_storage(): + """Pickled calculator factories retain model storage for worker processes.""" + si = Material(sld=2.07, isld=0.0, name='Si') + film = Material(sld=2.0, isld=0.0, name='Film') + d2o = Material(sld=6.36, isld=0.0, name='D2O') + + sample = Sample( + Multilayer(Layer(material=si, thickness=0.0, roughness=3.0, name='Si')), + Multilayer(Layer(material=film, thickness=250.0, roughness=3.0, name='Film')), + Multilayer(Layer(material=d2o, thickness=0.0, roughness=3.0, name='D2O')), + ) + model = Model( + sample=sample, + scale=1.0, + background=1e-6, + resolution_function=PercentageFwhm(0.02), + ) + interface = CalculatorFactory() + interface.switch('refnx') + model.interface = interface + + restored = pickle.loads(pickle.dumps(interface)) # noqa: S301 + + assert model.unique_name in restored()._wrapper.storage['model'] + q = np.linspace(0.01, 0.3, 10) + assert_allclose( + restored.fit_func(q, model.unique_name), + interface.fit_func(q, model.unique_name), + ) diff --git a/tests/test_fitting.py b/tests/test_fitting.py index a860fdb4..28ead1ab 100644 --- a/tests/test_fitting.py +++ b/tests/test_fitting.py @@ -1034,3 +1034,87 @@ def _fake_sample(*, x, y, weights, **kwargs): fitter.sample(data, samples=100, burn=20, thin=2, objective='hybrid') assert len(captured['x'][0]) == 10 # all points kept (Mighell-substituted) + + +_SENTINEL = object() + + +class TestSampleWorkers: + """n_workers parameter forwarding in sample().""" + + @pytest.fixture + def sample_fitter(self): + model = Model() + model.interface = CalculatorFactory() + fitter = MultiFitter(model) + data = sc.DataGroup({ + 'coords': {'Qz_0': sc.array(dims=['Qz_0'], values=np.linspace(0.01, 0.3, 10))}, + 'data': {'R_0': sc.array(dims=['Qz_0'], values=np.ones(10), variances=np.ones(10) * 0.01)}, + }) + return fitter, data + + def _mock_sample(self, fitter, captured): + """Wire a fake sample() that records n_workers into captured.""" + + def _fake(*, n_workers=_SENTINEL, **kwargs): + captured['n_workers'] = n_workers + return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None} + + fitter.easy_science_multi_fitter = MagicMock() + fitter.easy_science_multi_fitter.sample = MagicMock(side_effect=_fake) + + @pytest.mark.parametrize( + 'kwargs,expected', + [ + ({}, _SENTINEL), + ({'n_workers': None}, _SENTINEL), + ({'n_workers': 1}, 1), + ({'n_workers': 2}, 2), + ({'n_workers': 8}, 8), + ], + ) + def test_n_workers_forwarded(self, sample_fitter, kwargs, expected): + """n_workers is forwarded to core only when explicitly set to ≥1.""" + fitter, data = sample_fitter + captured = {} + self._mock_sample(fitter, captured) + fitter.sample(data, samples=100, burn=20, thin=2, **kwargs) + assert captured['n_workers'] == expected + + def test_with_other_params_combined(self, sample_fitter): + """n_workers can be combined with all other sample() parameters.""" + fitter, data = sample_fitter + captured = {} + + def _fake(*, samples, burn, thin, population, seed, n_workers, sampler_kwargs, **kwargs): + captured.update( + samples=samples, + burn=burn, + thin=thin, + population=population, + seed=seed, + n_workers=n_workers, + sampler_kwargs=sampler_kwargs, + ) + return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None} + + fitter.easy_science_multi_fitter = MagicMock() + fitter.easy_science_multi_fitter.sample = MagicMock(side_effect=_fake) + + fitter.sample(data, samples=500, burn=100, thin=5, population=8, seed=42, initializer='cov', n_workers=4) + assert captured == { + 'samples': 500, + 'burn': 100, + 'thin': 5, + 'population': 8, + 'seed': 42, + 'n_workers': 4, + 'sampler_kwargs': {'init': 'cov'}, + } + + @pytest.mark.parametrize('bad', [0, -1, -100]) + def test_invalid_n_workers_raises(self, sample_fitter, bad): + """n_workers < 1 raises ValueError before reaching the core.""" + fitter, data = sample_fitter + with pytest.raises(ValueError, match='n_workers'): + fitter.sample(data, samples=100, burn=20, thin=2, n_workers=bad)