diff --git a/medcat-test-models/.gitignore b/medcat-test-models/.gitignore new file mode 100644 index 000000000..6f7d1ffa4 --- /dev/null +++ b/medcat-test-models/.gitignore @@ -0,0 +1,2 @@ +mct_v1_model_pack/ +mct2_model_pack/ \ No newline at end of file diff --git a/medcat-trainer/webapp/api/api/tests/test_model_pack_addons.py b/medcat-trainer/webapp/api/api/tests/test_model_pack_addons.py new file mode 100644 index 000000000..9b2797e63 --- /dev/null +++ b/medcat-trainer/webapp/api/api/tests/test_model_pack_addons.py @@ -0,0 +1,144 @@ +"""Tests for registering model packs that include addons. + +The Trainer loads all addons via ``CAT.load_addons`` and registers only +MetaCAT ones (as ``MetaCATModel`` rows). Other addon types (e.g. RelCAT) +must load without error but are skipped during registration. These tests +mock ``CAT.load_addons`` so they exercise that behaviour without +downloading or loading real models. + +Scenarios covered: + +- MetaCAT only — one ``MetaCATModel`` row is created. +- RelCAT only — registration succeeds; no ``MetaCATModel`` rows. +- MetaCAT and RelCAT — both addons load; only MetaCAT is registered. +- No addons — CDB and vocab load; no ``MetaCATModel`` rows. +- Multiple MetaCAT addons — each MetaCAT is registered separately. 
+""" + +import os +import tempfile +from contextlib import contextmanager +from unittest.mock import MagicMock, patch + +from django.core.files.base import ContentFile +from django.test import TestCase, override_settings + +from medcat.components.addons.meta_cat.meta_cat import MetaCATAddon +from medcat.components.addons.relation_extraction.rel_cat import RelCATAddon + +from ..models import ModelPack + + +def _make_meta_cat_addon(category_name="Status", model_name="bert"): + addon = MagicMock(spec=MetaCATAddon) + meta_cat = MagicMock() + meta_cat.config.general.category_name = category_name + meta_cat.config.model.model_name = model_name + meta_cat.config.general.category_value2id = {"True": 0, "False": 1} + addon.mc = meta_cat + return addon + + +def _make_rel_cat_addon(): + return MagicMock(spec=RelCATAddon) + + +@override_settings(MEDIA_ROOT=tempfile.mkdtemp()) +class ModelPackAddonRegistrationTests(TestCase): + def _prepare_model_pack(self, name="addon-pack-test"): + """Create a ModelPack with a fake unpacked dir (cdb dir + vocab file).""" + model_pack = ModelPack(name=name) + model_pack.model_pack.save(f"{name}.zip", ContentFile(b"fake"), save=False) + unpacked = model_pack.model_pack.path[: -len(".zip")] + os.makedirs(os.path.join(unpacked, "cdb"), exist_ok=True) + with open(os.path.join(unpacked, "vocab"), "w", encoding="utf-8") as fh: + fh.write("") + return model_pack, unpacked + + @contextmanager + def _register_model_pack(self, model_pack, addons): + with patch("api.models.CAT.attempt_unpack"), \ + patch("api.models.CDB.load"), \ + patch("api.models.Vocab.load"), \ + patch("api.models.CAT.load_addons", return_value=addons) as load_addons: + model_pack.save() + yield load_addons + + def test_register_model_pack_with_meta_cat_only(self): + model_pack, unpacked = self._prepare_model_pack(name="meta-cat-pack") + comps = os.path.join(unpacked, "saved_components") + meta_cat_path = os.path.join(comps, "addon_meta_cat.Status") + addons = [(meta_cat_path, 
_make_meta_cat_addon())] + + with self._register_model_pack(model_pack, addons) as load_addons: + load_addons.assert_called_once_with(unpacked) + + self.assertIsNotNone(model_pack.concept_db) + self.assertIsNotNone(model_pack.vocab) + self.assertEqual(model_pack.meta_cats.count(), 1) + meta_cat_model = model_pack.meta_cats.get() + self.assertEqual(meta_cat_model.name, "Status - bert") + self.assertTrue(meta_cat_model.meta_cat_dir.endswith("addon_meta_cat.Status")) + + def test_register_model_pack_with_rel_cat_only(self): + model_pack, unpacked = self._prepare_model_pack(name="rel-cat-pack") + comps = os.path.join(unpacked, "saved_components") + rel_cat_path = os.path.join(comps, "addon_rel_cat.rel_cat") + addons = [(rel_cat_path, _make_rel_cat_addon())] + + with self._register_model_pack(model_pack, addons) as load_addons: + load_addons.assert_called_once_with(unpacked) + + self.assertIsNotNone(model_pack.concept_db) + self.assertIsNotNone(model_pack.vocab) + self.assertEqual(model_pack.meta_cats.count(), 0) + + def test_register_model_pack_registers_multiple_meta_cats(self): + model_pack, unpacked = self._prepare_model_pack(name="multi-meta-cat-pack") + comps = os.path.join(unpacked, "saved_components") + addons = [ + (os.path.join(comps, "addon_meta_cat.Status"), + _make_meta_cat_addon(category_name="Status", model_name="bert")), + (os.path.join(comps, "addon_meta_cat.Experiencer"), + _make_meta_cat_addon(category_name="Experiencer", model_name="roberta")), + ] + + with self._register_model_pack(model_pack, addons): + pass + + self.assertEqual(model_pack.meta_cats.count(), 2) + self.assertEqual( + set(model_pack.meta_cats.values_list("name", flat=True)), + {"Status - bert", "Experiencer - roberta"}, + ) + + def test_register_model_pack_with_meta_cat_and_rel_cat(self): + model_pack, unpacked = self._prepare_model_pack(name="mixed-addon-pack") + comps = os.path.join(unpacked, "saved_components") + meta_cat_path = os.path.join(comps, "addon_meta_cat.Status") + 
rel_cat_path = os.path.join(comps, "addon_rel_cat.rel_cat") + addons = [ + (meta_cat_path, _make_meta_cat_addon()), + (rel_cat_path, _make_rel_cat_addon()), + ] + + with self._register_model_pack(model_pack, addons) as load_addons: + load_addons.assert_called_once_with(unpacked) + + self.assertIsNotNone(model_pack.concept_db) + self.assertIsNotNone(model_pack.vocab) + # All addons load; only MetaCAT rows are registered. + self.assertEqual(model_pack.meta_cats.count(), 1) + meta_cat_model = model_pack.meta_cats.get() + self.assertEqual(meta_cat_model.name, "Status - bert") + self.assertTrue(meta_cat_model.meta_cat_dir.endswith("addon_meta_cat.Status")) + + def test_register_model_pack_without_addons(self): + model_pack, unpacked = self._prepare_model_pack(name="no-addon-pack") + + with self._register_model_pack(model_pack, []) as load_addons: + load_addons.assert_called_once_with(unpacked) + + self.assertIsNotNone(model_pack.concept_db) + self.assertIsNotNone(model_pack.vocab) + self.assertEqual(model_pack.meta_cats.count(), 0) diff --git a/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py b/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py index 05fb4a185..475eb8291 100644 --- a/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py +++ b/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py @@ -223,7 +223,7 @@ def deserialise_from(cls, folder_path: str, **init_kwargs # load legacy config (assuming it exists) config_path += ".dat" logger.info( - "Was not provide a config when loading a meta cat from '%s'. " + "Was not provided a config when loading a meta cat from '%s'. 
" "Inferring config from file at '%s'", folder_path, config_path) cnf = ConfigMetaCAT.load(config_path) diff --git a/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py b/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py index 9c3d178c9..a960f3aee 100644 --- a/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py +++ b/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py @@ -33,12 +33,16 @@ from medcat.tokenizing.tokenizers import BaseTokenizer, create_tokenizer from medcat.tokenizing.tokens import MutableDocument from medcat.utils.defaults import COMPONENTS_FOLDER +from medcat.utils.defaults import ( + avoid_legacy_conversion, doing_legacy_conversion_message, + LegacyConversionDisabledError) logger = logging.getLogger(__name__) class RelCATAddon(AddonComponent): + DEFAULT_TOKENIZER = 'spacy' addon_type = 'rel_cat' output_key = 'relations' config: ConfigRelCAT @@ -72,7 +76,7 @@ def create_new_component( @classmethod def load_existing(cls, cnf: ConfigRelCAT, base_tokenizer: BaseTokenizer, - cdb: CDB, + cdb: Optional[CDB], load_path: str) -> 'RelCATAddon': """Factory method to load an existing RelCAT addon from disk.""" rc = RelCAT.load(load_path) @@ -91,16 +95,80 @@ def name(self) -> str: # for ManualSerialisable: + @classmethod + def _create_throwaway_tokenizer(cls) -> BaseTokenizer: + """ + Mirrors `MetaCATAddon._create_throwaway_tokenizer` + """ + logger.warning( + "A base tokenizer was not provided during the loading of a " + "RelCAT. The tokenizer is used to register the required data " + "paths for RelCAT to function. Using the default of '%s'. If " + "this is not the tokenizer you will end up using, RelCAT may " + "be unable to recover unless a) the paths are registered " + "explicitly, or b) there are other RelCATs created with the " + "correct tokenizer. 
Do note that this will also create " + "another instance of the tokenizer, though it should be " + "garbage collected soon.", cls.DEFAULT_TOKENIZER + ) + # NOTE: the use of a (mostly) default config here probably won't + # affect anything since the tokenizer itself won't be used + gcnf = Config() + gcnf.general.nlp.provider = cls.DEFAULT_TOKENIZER + return create_tokenizer(cls.DEFAULT_TOKENIZER, gcnf) + @classmethod def deserialise_from(cls, folder_path: str, **init_kwargs ) -> 'RelCATAddon': - # NOTE: model load path sent by kwargs + """Deserialise a RelCAT addon from disk. + + Mirrors `MetaCATAddon.deserialise_from`: when called via the + pipeline, `tokenizer`/`cnf` are supplied; when called standalone + (e.g. `CAT.load_addons`), they are inferred from disk so that + deserialisation works without full pipeline context. + """ + if "config.json" in os.listdir(folder_path): + if not avoid_legacy_conversion(): + doing_legacy_conversion_message( + logger, cls.__name__, folder_path) + from medcat.utils.legacy.convert_rel_cat import ( + get_rel_cat_from_old) + if 'cdb' in init_kwargs: + cdb = init_kwargs['cdb'] + else: + cdb_path = os.path.join(folder_path, "cdb.dat") + if os.path.exists(cdb_path): + cdb = cast(CDB, deserialise(cdb_path)) + else: + cdb = CDB(config=Config()) + if 'tokenizer' in init_kwargs: + tokenizer = init_kwargs['tokenizer'] + else: + tokenizer = cls._create_throwaway_tokenizer() + return get_rel_cat_from_old(cdb, folder_path, tokenizer) + raise LegacyConversionDisabledError(cls.__name__) + if 'cnf' in init_kwargs: + cnf = init_kwargs['cnf'] + else: + config_path = os.path.join(folder_path, "config") + logger.info( + "Was not provided a config when loading a rel cat from '%s'. 
" + "Inferring config from file at '%s'", folder_path, + config_path) + cnf = ConfigRelCAT.load(load_path=folder_path) + if 'model_config' in init_kwargs: + cnf.merge_config(init_kwargs['model_config']) + if 'tokenizer' in init_kwargs: + tokenizer = init_kwargs['tokenizer'] + else: + tokenizer = cls._create_throwaway_tokenizer() + # NOTE: forward a caller-supplied CDB (as before this change); + # standalone loads have none, hence Optional in load_existing. return cls.load_existing( load_path=folder_path, - base_tokenizer=init_kwargs['tokenizer'], - cnf=init_kwargs['cnf'], - cdb=init_kwargs['cdb'], - ) + cnf=cnf, + base_tokenizer=tokenizer, + cdb=init_kwargs.get('cdb')) def get_strategy(self) -> SerialisingStrategy: return SerialisingStrategy.MANUAL diff --git a/medcat-v2/tests/components/addons/relation_extraction/test_rel_cat_in_model_pack.py b/medcat-v2/tests/components/addons/relation_extraction/test_rel_cat_in_model_pack.py index 89cb68c01..a89be6e14 100644 --- a/medcat-v2/tests/components/addons/relation_extraction/test_rel_cat_in_model_pack.py +++ b/medcat-v2/tests/components/addons/relation_extraction/test_rel_cat_in_model_pack.py @@ -84,3 +84,16 @@ def test_can_load_model_pack(self): cat = CAT.load_model_pack(self.model_pack_path) self.assertIsInstance(cat, CAT) self.assert_has_rel_cat(cat) + + def test_can_load_rel_cat_via_load_addons(self): + addons = CAT.load_addons(self.model_pack_path) + self.assertEqual(len(addons), 1) + _, addon = addons[0] + self.assertIsInstance(addon, rel_cat.RelCATAddon) + + def test_can_load_rel_cat_with_addon_cnf(self): + addon = CAT.load_addons( + self.model_pack_path, + addon_config_dict={"rel_cat.rel_cat": {"general": {"device": "cpu"}}}, + )[0][1] + self.assertEqual(addon.config.general.device, "cpu")