Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ classifiers = [
]
requires-python = '>=3.11'
dependencies = [
'easyscience @ git+https://github.com/easyscience/corelib.git@bayesian',
'easyscience @ git+https://github.com/easyscience/corelib.git@bayesian_mp',
# 'easyscience',
'scipp',
'refnx',
Expand Down Expand Up @@ -69,10 +69,10 @@ dev = [
'mkdocstrings-python', # MkDocs: Python docstring support
'pyyaml', # YAML parser
'spdx-headers', # SPDX license header validation
'corner', # Bayesian analysis and plotting
'arviz', # Bayesian analysis and plotting
]

bayesian = ["corner>=2.2", "arviz>=0.18"]

[project.urls]
Documentation = 'https://easyscience.github.io/reflectometry-lib'
'Release Notes' = 'https://github.com/easyscience/reflectometry-lib/releases'
Expand Down
29 changes: 29 additions & 0 deletions src/easyreflectometry/calculators/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,35 @@
"""Init function."""
super().__init__(interface_list=CalculatorBase._calculators)

def __reduce__(self):

Check warning on line 17 in src/easyreflectometry/calculators/factory.py

View check run for this annotation

Codecov / codecov/patch

src/easyreflectometry/calculators/factory.py#L17

Added line #L17 was not covered by tests
"""Serialize the active calculator state for worker processes."""
wrapper = getattr(self(), '_wrapper', None)
if wrapper is None and self.current_interface_name is not None:
raise RuntimeError(

Check warning on line 21 in src/easyreflectometry/calculators/factory.py

View check run for this annotation

Codecov / codecov/patch

src/easyreflectometry/calculators/factory.py#L19-L21

Added lines #L19 - L21 were not covered by tests
f'Cannot pickle CalculatorFactory: active interface '
f"{self.current_interface_name!r} exposes no '_wrapper' attribute. "
'The InterfaceFactoryTemplate API may have changed.'
)
return (

Check warning on line 26 in src/easyreflectometry/calculators/factory.py

View check run for this annotation

Codecov / codecov/patch

src/easyreflectometry/calculators/factory.py#L26

Added line #L26 was not covered by tests
self._state_restore,
(
self.__class__,
self.current_interface_name,
wrapper.__getstate__() if wrapper is not None else None,
),
)

@staticmethod
def _state_restore(cls, interface_str, wrapper_state):

Check warning on line 36 in src/easyreflectometry/calculators/factory.py

View check run for this annotation

Codecov / codecov/patch

src/easyreflectometry/calculators/factory.py#L35-L36

Added lines #L35 - L36 were not covered by tests
"""Restore a calculator factory with its active wrapper state."""
obj = cls()
if interface_str is not None and interface_str in obj.available_interfaces:
obj.switch(interface_str)
wrapper = getattr(obj(), '_wrapper', None)
if wrapper is not None and wrapper_state is not None:
wrapper.__setstate__(wrapper_state)
return obj

Check warning on line 44 in src/easyreflectometry/calculators/factory.py

View check run for this annotation

Codecov / codecov/patch

src/easyreflectometry/calculators/factory.py#L38-L44

Added lines #L38 - L44 were not covered by tests

def reset_storage(self) -> None:
"""Reset storage."""
return self().reset_storage()
Expand Down
12 changes: 12 additions & 0 deletions src/easyreflectometry/calculators/wrapper_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,18 @@
item = getattr(item, key)
return getattr(item, 'value')

def __getstate__(self) -> dict:
return {

Check warning on line 297 in src/easyreflectometry/calculators/wrapper_base.py

View check run for this annotation

Codecov / codecov/patch

src/easyreflectometry/calculators/wrapper_base.py#L296-L297

Added lines #L296 - L297 were not covered by tests
'storage': self.storage,
'resolution_function': self._resolution_function,
'magnetism': self._magnetism,
}

def __setstate__(self, state: dict) -> None:
self.storage = state['storage']
self._resolution_function = state['resolution_function']
self._magnetism = state['magnetism']

Check warning on line 306 in src/easyreflectometry/calculators/wrapper_base.py

View check run for this annotation

Codecov / codecov/patch

src/easyreflectometry/calculators/wrapper_base.py#L303-L306

Added lines #L303 - L306 were not covered by tests

def set_resolution_function(self, resolution_function: ResolutionFunction) -> None:
"""Set the resolution function for the calculator.

Expand Down
42 changes: 28 additions & 14 deletions src/easyreflectometry/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@
seed: int | None = None,
objective: str | None = None,
initializer: str | None = None,
n_workers: int | None = None,
progress_callback=None,
abort_test=None,
) -> dict:
Expand All @@ -383,12 +384,22 @@
:param initializer: DREAM population initializer. One of ``'eps'``,
``'cov'``, ``'lhs'``, or ``'random'``. By default, None (BUMPS
uses ``'eps'``).
:param n_workers: Number of worker processes for parallel DREAM
population evaluation. ``None`` (default) and ``1`` use
sequential evaluation. Values greater than ``1`` enable
multiprocessing; the effective pool size is capped at
``min(n_workers, population)``.
:param progress_callback: Optional callback for progress updates during
sampling. Forwarded to the core MultiFitter.
:param abort_test: Optional callback that returns ``True`` to signal
that sampling should be aborted.
:return: Dictionary with keys ``'draws'``, ``'param_names'``, ``'state'``,
and ``'logp'``.
:raises RuntimeError: If the current minimizer is not a BUMPS instance.
:raises ValueError: If ``n_workers`` is not None and less than 1.
"""
if n_workers is not None and n_workers < 1:
raise ValueError(f'n_workers must be a positive integer or None, got {n_workers}')

Check warning on line 402 in src/easyreflectometry/fitting.py

View check run for this annotation

Codecov / codecov/patch

src/easyreflectometry/fitting.py#L401-L402

Added lines #L401 - L402 were not covered by tests
obj = _validate_objective(objective) if objective is not None else self._objective

refl_nums = [k[3:] for k in data['coords'].keys() if 'Qz' == k[:2]]
Expand Down Expand Up @@ -417,20 +428,23 @@
sampler_kwargs = {}
if initializer is not None:
sampler_kwargs['init'] = initializer
return self.easy_science_multi_fitter.sample(
x=x,
y=y,
weights=dy,
samples=samples,
burn=burn,
thin=thin,
chains=chains,
population=population,
seed=seed,
sampler_kwargs=sampler_kwargs or None,
progress_callback=progress_callback,
abort_test=abort_test,
)
core_sample_kwargs = {

Check warning on line 431 in src/easyreflectometry/fitting.py

View check run for this annotation

Codecov / codecov/patch

src/easyreflectometry/fitting.py#L431

Added line #L431 was not covered by tests
'x': x,
'y': y,
'weights': dy,
'samples': samples,
'burn': burn,
'thin': thin,
'chains': chains,
'population': population,
'seed': seed,
'sampler_kwargs': sampler_kwargs or None,
'progress_callback': progress_callback,
'abort_test': abort_test,
}
if n_workers is not None:
core_sample_kwargs['n_workers'] = n_workers
return self.easy_science_multi_fitter.sample(**core_sample_kwargs)

Check warning on line 447 in src/easyreflectometry/fitting.py

View check run for this annotation

Codecov / codecov/patch

src/easyreflectometry/fitting.py#L445-L447

Added lines #L445 - L447 were not covered by tests

@property
def chi2(self) -> float | None:
Expand Down
48 changes: 48 additions & 0 deletions tests/calculators/test_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# SPDX-FileCopyrightText: 2026 EasyScience contributors <https://github.com/easyscience>
# SPDX-License-Identifier: BSD-3-Clause

"""Tests for CalculatorFactory serialization."""

import pickle # noqa: S403

import numpy as np
from numpy.testing import assert_allclose

from easyreflectometry.calculators import CalculatorFactory
from easyreflectometry.model import Model
from easyreflectometry.model import PercentageFwhm
from easyreflectometry.sample import Layer
from easyreflectometry.sample import Material
from easyreflectometry.sample import Multilayer
from easyreflectometry.sample import Sample


def test_calculator_factory_pickle_preserves_active_wrapper_storage():
"""Pickled calculator factories retain model storage for worker processes."""
si = Material(sld=2.07, isld=0.0, name='Si')
film = Material(sld=2.0, isld=0.0, name='Film')
d2o = Material(sld=6.36, isld=0.0, name='D2O')

sample = Sample(
Multilayer(Layer(material=si, thickness=0.0, roughness=3.0, name='Si')),
Multilayer(Layer(material=film, thickness=250.0, roughness=3.0, name='Film')),
Multilayer(Layer(material=d2o, thickness=0.0, roughness=3.0, name='D2O')),
)
model = Model(
sample=sample,
scale=1.0,
background=1e-6,
resolution_function=PercentageFwhm(0.02),
)
interface = CalculatorFactory()
interface.switch('refnx')
model.interface = interface

restored = pickle.loads(pickle.dumps(interface)) # noqa: S301

assert model.unique_name in restored()._wrapper.storage['model']
q = np.linspace(0.01, 0.3, 10)
assert_allclose(
restored.fit_func(q, model.unique_name),
interface.fit_func(q, model.unique_name),
)
84 changes: 84 additions & 0 deletions tests/test_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,3 +1034,87 @@ def _fake_sample(*, x, y, weights, **kwargs):
fitter.sample(data, samples=100, burn=20, thin=2, objective='hybrid')

assert len(captured['x'][0]) == 10 # all points kept (Mighell-substituted)


_SENTINEL = object()


class TestSampleWorkers:
"""n_workers parameter forwarding in sample()."""

@pytest.fixture
def sample_fitter(self):
model = Model()
model.interface = CalculatorFactory()
fitter = MultiFitter(model)
data = sc.DataGroup({
'coords': {'Qz_0': sc.array(dims=['Qz_0'], values=np.linspace(0.01, 0.3, 10))},
'data': {'R_0': sc.array(dims=['Qz_0'], values=np.ones(10), variances=np.ones(10) * 0.01)},
})
return fitter, data

def _mock_sample(self, fitter, captured):
"""Wire a fake sample() that records n_workers into captured."""

def _fake(*, n_workers=_SENTINEL, **kwargs):
captured['n_workers'] = n_workers
return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None}

fitter.easy_science_multi_fitter = MagicMock()
fitter.easy_science_multi_fitter.sample = MagicMock(side_effect=_fake)

@pytest.mark.parametrize(
'kwargs,expected',
[
({}, _SENTINEL),
({'n_workers': None}, _SENTINEL),
({'n_workers': 1}, 1),
({'n_workers': 2}, 2),
({'n_workers': 8}, 8),
],
)
def test_n_workers_forwarded(self, sample_fitter, kwargs, expected):
"""n_workers is forwarded to core only when explicitly set to ≥1."""
fitter, data = sample_fitter
captured = {}
self._mock_sample(fitter, captured)
fitter.sample(data, samples=100, burn=20, thin=2, **kwargs)
assert captured['n_workers'] == expected

def test_with_other_params_combined(self, sample_fitter):
"""n_workers can be combined with all other sample() parameters."""
fitter, data = sample_fitter
captured = {}

def _fake(*, samples, burn, thin, population, seed, n_workers, sampler_kwargs, **kwargs):
captured.update(
samples=samples,
burn=burn,
thin=thin,
population=population,
seed=seed,
n_workers=n_workers,
sampler_kwargs=sampler_kwargs,
)
return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None}

fitter.easy_science_multi_fitter = MagicMock()
fitter.easy_science_multi_fitter.sample = MagicMock(side_effect=_fake)

fitter.sample(data, samples=500, burn=100, thin=5, population=8, seed=42, initializer='cov', n_workers=4)
assert captured == {
'samples': 500,
'burn': 100,
'thin': 5,
'population': 8,
'seed': 42,
'n_workers': 4,
'sampler_kwargs': {'init': 'cov'},
}

@pytest.mark.parametrize('bad', [0, -1, -100])
def test_invalid_n_workers_raises(self, sample_fitter, bad):
"""n_workers < 1 raises ValueError before reaching the core."""
fitter, data = sample_fitter
with pytest.raises(ValueError, match='n_workers'):
fitter.sample(data, samples=100, burn=20, thin=2, n_workers=bad)
Loading