diff --git a/src/pyfmi/fmi1.pyx b/src/pyfmi/fmi1.pyx index 63744646..6e179ab9 100644 --- a/src/pyfmi/fmi1.pyx +++ b/src/pyfmi/fmi1.pyx @@ -418,6 +418,22 @@ cdef class FMUModelBase(FMI_BASE.ModelBase): self._log = [] + def __reduce__(self): + _log_file_name = self._get_log_file_name() + fmu_path = self._fmu_full_path + if isinstance(fmu_path, bytes): + fmu_path = pyfmi_util.decode(fmu_path) + return ( + self.__class__, + (fmu_path, _log_file_name, self._loaded_with_log_level, + None, bool(self._allocated_dll), bool(self._allow_unzipped_fmu)), + {'cache': self.cache} if self.cache else {} + ) + + def __setstate__(self, state): + if 'cache' in state: + self.cache = state['cache'] + def _setup_log_state(self, log_level): if log_level >= FMIL.jm_log_level_nothing and log_level <= FMIL.jm_log_level_all: self._enable_logging = log_level != FMIL.jm_log_level_nothing diff --git a/src/pyfmi/fmi2.pyx b/src/pyfmi/fmi2.pyx index 4014b734..71eec7de 100644 --- a/src/pyfmi/fmi2.pyx +++ b/src/pyfmi/fmi2.pyx @@ -681,6 +681,22 @@ cdef class FMUModelBase2(FMI_BASE.ModelBase): self._log = [] + def __reduce__(self): + _log_file_name = self._get_log_file_name() + fmu_path = self._fmu_full_path + if isinstance(fmu_path, bytes): + fmu_path = pyfmi_util.decode(fmu_path) + return ( + self.__class__, + (fmu_path, _log_file_name, self._loaded_with_log_level, + None, bool(self._allocated_dll), bool(self._allow_unzipped_fmu)), + {'cache': self.cache} if self.cache else {} + ) + + def __setstate__(self, state): + if 'cache' in state: + self.cache = state['cache'] + def _setup_log_state(self, log_level): if log_level >= FMIL.jm_log_level_nothing and log_level <= FMIL.jm_log_level_all: self._enable_logging = log_level != FMIL.jm_log_level_nothing diff --git a/src/pyfmi/fmi3.pyx b/src/pyfmi/fmi3.pyx index db9c71ce..f6415626 100644 --- a/src/pyfmi/fmi3.pyx +++ b/src/pyfmi/fmi3.pyx @@ -477,7 +477,23 @@ cdef class FMUModelBase3(FMI_BASE.ModelBase): self._event_info_nominals_of_continuous_states_changed = FMIL3.fmi3_false self._event_info_values_of_continuous_states_changed = FMIL3.fmi3_true self._event_info_next_event_time_defined = FMIL3.fmi3_false - self._event_info_next_event_time = 0.0 + self._event_info_next_event_time = 0.0 + + def __reduce__(self): + _log_file_name = self._get_log_file_name() + fmu_path = self._fmu_full_path + if isinstance(fmu_path, bytes): + fmu_path = pyfmi_util.decode(fmu_path) + return ( + self.__class__, + (fmu_path, _log_file_name, self._loaded_with_log_level, + None, bool(self._allocated_dll), bool(self._allow_unzipped_fmu)), + {'cache': self.cache} if self.cache else {} + ) + + def __setstate__(self, state): + if 'cache' in state: + self.cache = state['cache'] def _setup_log_state(self, log_level): if isinstance(log_level, int) and (log_level >= FMIL.jm_log_level_nothing and log_level <= FMIL.jm_log_level_all): diff --git a/src/pyfmi/fmi_base.pyx b/src/pyfmi/fmi_base.pyx index b254062d..27072990 100644 --- a/src/pyfmi/fmi_base.pyx +++ b/src/pyfmi/fmi_base.pyx @@ -84,6 +84,11 @@ cdef class ModelBase: self._max_log_size_msg_sent = False self._log_handler = LogHandlerDefault(self._max_log_size) + def _get_log_file_name(self): + if self._fmu_log_name != NULL: + return pyfmi_util.decode(self._fmu_log_name) + return None + def _set_log_stream(self, stream): """ Function that sets the class property 'log_stream' and does error handling. """ if not hasattr(stream, 'write'): diff --git a/tests/test_fmi2.py b/tests/test_fmi2.py index 1264bba2..4b25605d 100644 --- a/tests/test_fmi2.py +++ b/tests/test_fmi2.py @@ -105,6 +105,22 @@ def test_erroneous_ncp(self): with pytest.raises(FMUException): model.simulate(options=opts) + def test_pickle(self): + import pickle + + fmu = FMUModelCS2( + os.path.join(file_path, "files", "FMUs", "XML", "CS2.0", "CoupledClutches.fmu"), + _connect_dll=False + ) + log_name = fmu.get_log_filename() + fmu.cache['test_key'] = 'test_value' + + data = pickle.dumps(fmu) + fmu2 = pickle.loads(data) + + assert fmu2.get_log_filename() == log_name + assert fmu2.cache['test_key'] == 'test_value' + class Test_Downsample: """Tests for the 'result_downsampling_factor' option for CS FMUs.""" def _verify_downsample_result(self, ref_traj, test_traj, ncp, factor): @@ -1115,6 +1131,22 @@ def test_get_variable_description(self): model = FMUModelME2(FMU_PATHS.ME2.coupled_clutches, _connect_dll=False) assert model.get_variable_description("J1.phi") == "Absolute rotation angle of component" + def test_pickle(self): + import pickle + + fmu = FMUModelME2( + os.path.join(file_path, "files", "FMUs", "XML", "ME2.0", "CoupledClutches.fmu"), + _connect_dll=False + ) + log_name = fmu.get_log_filename() + fmu.cache['test_key'] = 'test_value' + + data = pickle.dumps(fmu) + fmu2 = pickle.loads(data) + + assert fmu2.get_log_filename() == log_name + assert fmu2.cache['test_key'] == 'test_value' + @uses_test_fmus @pytest.mark.parametrize("fmu_path", [