Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions medcat-test-models/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
mct_v1_model_pack/
mct2_model_pack/
144 changes: 144 additions & 0 deletions medcat-trainer/webapp/api/api/tests/test_model_pack_addons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""Tests for registering model packs that include addons.

The Trainer loads all addons via ``CAT.load_addons`` and registers only
MetaCAT ones (as ``MetaCATModel`` rows). Other addon types (e.g. RelCAT)
must load without error but are skipped during registration. These tests
mock ``CAT.load_addons`` so they exercise that behaviour without
downloading or loading real models.

Scenarios covered:

- MetaCAT only — one ``MetaCATModel`` row is created.
- RelCAT only — registration succeeds; no ``MetaCATModel`` rows.
- MetaCAT and RelCAT — both addons load; only MetaCAT is registered.
- No addons — CDB and vocab load; no ``MetaCATModel`` rows.
- Multiple MetaCAT addons — each MetaCAT is registered separately.
"""

import os
import tempfile
from contextlib import contextmanager
from unittest.mock import MagicMock, patch

from django.core.files.base import ContentFile
from django.test import TestCase, override_settings

from medcat.components.addons.meta_cat.meta_cat import MetaCATAddon
from medcat.components.addons.relation_extraction.rel_cat import RelCATAddon

from ..models import ModelPack


def _make_meta_cat_addon(category_name="Status", model_name="bert"):
addon = MagicMock(spec=MetaCATAddon)
meta_cat = MagicMock()
meta_cat.config.general.category_name = category_name
meta_cat.config.model.model_name = model_name
meta_cat.config.general.category_value2id = {"True": 0, "False": 1}
addon.mc = meta_cat
return addon


def _make_rel_cat_addon():
return MagicMock(spec=RelCATAddon)


@override_settings(MEDIA_ROOT=tempfile.mkdtemp())
class ModelPackAddonRegistrationTests(TestCase):
def _prepare_model_pack(self, name="addon-pack-test"):
"""Create a ModelPack with a fake unpacked dir (cdb dir + vocab file)."""
model_pack = ModelPack(name=name)
model_pack.model_pack.save(f"{name}.zip", ContentFile(b"fake"), save=False)
unpacked = model_pack.model_pack.path[: -len(".zip")]
os.makedirs(os.path.join(unpacked, "cdb"), exist_ok=True)
with open(os.path.join(unpacked, "vocab"), "w", encoding="utf-8") as fh:
fh.write("")
return model_pack, unpacked

@contextmanager
def _register_model_pack(self, model_pack, addons):
with patch("api.models.CAT.attempt_unpack"), \
patch("api.models.CDB.load"), \
patch("api.models.Vocab.load"), \
patch("api.models.CAT.load_addons", return_value=addons) as load_addons:
model_pack.save()
yield load_addons

def test_register_model_pack_with_meta_cat_only(self):
model_pack, unpacked = self._prepare_model_pack(name="meta-cat-pack")
comps = os.path.join(unpacked, "saved_components")
meta_cat_path = os.path.join(comps, "addon_meta_cat.Status")
addons = [(meta_cat_path, _make_meta_cat_addon())]

with self._register_model_pack(model_pack, addons) as load_addons:
load_addons.assert_called_once_with(unpacked)

self.assertIsNotNone(model_pack.concept_db)
self.assertIsNotNone(model_pack.vocab)
self.assertEqual(model_pack.meta_cats.count(), 1)
meta_cat_model = model_pack.meta_cats.get()
self.assertEqual(meta_cat_model.name, "Status - bert")
self.assertTrue(meta_cat_model.meta_cat_dir.endswith("addon_meta_cat.Status"))

def test_register_model_pack_with_rel_cat_only(self):
model_pack, unpacked = self._prepare_model_pack(name="rel-cat-pack")
comps = os.path.join(unpacked, "saved_components")
rel_cat_path = os.path.join(comps, "addon_rel_cat.rel_cat")
addons = [(rel_cat_path, _make_rel_cat_addon())]

with self._register_model_pack(model_pack, addons) as load_addons:
load_addons.assert_called_once_with(unpacked)

self.assertIsNotNone(model_pack.concept_db)
self.assertIsNotNone(model_pack.vocab)
self.assertEqual(model_pack.meta_cats.count(), 0)

def test_register_model_pack_registers_multiple_meta_cats(self):
model_pack, unpacked = self._prepare_model_pack(name="multi-meta-cat-pack")
comps = os.path.join(unpacked, "saved_components")
addons = [
(os.path.join(comps, "addon_meta_cat.Status"),
_make_meta_cat_addon(category_name="Status", model_name="bert")),
(os.path.join(comps, "addon_meta_cat.Experiencer"),
_make_meta_cat_addon(category_name="Experiencer", model_name="roberta")),
]

with self._register_model_pack(model_pack, addons):
pass

self.assertEqual(model_pack.meta_cats.count(), 2)
self.assertEqual(
set(model_pack.meta_cats.values_list("name", flat=True)),
{"Status - bert", "Experiencer - roberta"},
)

def test_register_model_pack_with_meta_cat_and_rel_cat(self):
model_pack, unpacked = self._prepare_model_pack(name="mixed-addon-pack")
comps = os.path.join(unpacked, "saved_components")
meta_cat_path = os.path.join(comps, "addon_meta_cat.Status")
rel_cat_path = os.path.join(comps, "addon_rel_cat.rel_cat")
addons = [
(meta_cat_path, _make_meta_cat_addon()),
(rel_cat_path, _make_rel_cat_addon()),
]

with self._register_model_pack(model_pack, addons) as load_addons:
load_addons.assert_called_once_with(unpacked)

self.assertIsNotNone(model_pack.concept_db)
self.assertIsNotNone(model_pack.vocab)
# All addons load; only MetaCAT rows are registered.
self.assertEqual(model_pack.meta_cats.count(), 1)
meta_cat_model = model_pack.meta_cats.get()
self.assertEqual(meta_cat_model.name, "Status - bert")
self.assertTrue(meta_cat_model.meta_cat_dir.endswith("addon_meta_cat.Status"))

def test_register_model_pack_without_addons(self):
model_pack, unpacked = self._prepare_model_pack(name="no-addon-pack")

with self._register_model_pack(model_pack, []) as load_addons:
load_addons.assert_called_once_with(unpacked)

self.assertIsNotNone(model_pack.concept_db)
self.assertIsNotNone(model_pack.vocab)
self.assertEqual(model_pack.meta_cats.count(), 0)
53 changes: 48 additions & 5 deletions medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@


class RelCATAddon(AddonComponent):
DEFAULT_TOKENIZER = 'spacy'
addon_type = 'rel_cat'
output_key = 'relations'
config: ConfigRelCAT
Expand Down Expand Up @@ -91,16 +92,58 @@ def name(self) -> str:

# for ManualSerialisable:

@classmethod
def _create_throwaway_tokenizer(cls) -> BaseTokenizer:
"""
Mirrors `MetaCATAddon._create_throwaway_tokenizer`
"""
logger.warning(
"A base tokenizer was not provided during the loading of a "
"RelCAT. The tokenizer is used to register the required data "
"paths for RelCAT to function. Using the default of '%s'.",
cls.DEFAULT_TOKENIZER,
)
gcnf = Config()
gcnf.general.nlp.provider = 'spacy'
return create_tokenizer(cls.DEFAULT_TOKENIZER, gcnf)

@classmethod
def deserialise_from(cls, folder_path: str, **init_kwargs
) -> 'RelCATAddon':
# NOTE: model load path sent by kwargs
"""Deserialise a RelCAT addon from disk.

Mirrors `MetaCATAddon.deserialise_from`: when called via the
pipeline, `tokenizer`/`cnf`/`cdb` are supplied; when called standalone
(e.g. `CAT.load_addons`), they are inferred from disk so that
deserialisation works without full pipeline context.
"""
if 'cnf' in init_kwargs:
cnf = init_kwargs['cnf']
else:
logger.info(
"Was not provided a config when loading a rel cat from '%s'. "
"Inferring config from file at '%s'", folder_path,
folder_path)
cnf = ConfigRelCAT.load(load_path=folder_path)
if 'model_config' in init_kwargs:
cnf.merge_config(init_kwargs['model_config'])
if 'tokenizer' in init_kwargs:
tokenizer = init_kwargs['tokenizer']
else:
tokenizer = cls._create_throwaway_tokenizer()
if 'cdb' in init_kwargs:
cdb = init_kwargs['cdb']
else:
cdb_path = os.path.join(folder_path, "cdb.dat")
if os.path.exists(cdb_path):
cdb = cast(CDB, deserialise(cdb_path))
else:
cdb = CDB(config=Config())
return cls.load_existing(
load_path=folder_path,
base_tokenizer=init_kwargs['tokenizer'],
cnf=init_kwargs['cnf'],
cdb=init_kwargs['cdb'],
)
cnf=cnf,
base_tokenizer=tokenizer,
cdb=cdb)

def get_strategy(self) -> SerialisingStrategy:
return SerialisingStrategy.MANUAL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,16 @@ def test_can_load_model_pack(self):
cat = CAT.load_model_pack(self.model_pack_path)
self.assertIsInstance(cat, CAT)
self.assert_has_rel_cat(cat)

def test_can_load_rel_cat_via_load_addons(self):
addons = CAT.load_addons(self.model_pack_path)
self.assertEqual(len(addons), 1)
_, addon = addons[0]
self.assertIsInstance(addon, rel_cat.RelCATAddon)

def test_can_load_rel_cat_with_addon_cnf(self):
addon = CAT.load_addons(
self.model_pack_path,
addon_config_dict={"rel_cat.rel_cat": {"general": {"device": "cpu"}}},
)[0][1]
self.assertEqual(addon.config.general.device, "cpu")
Loading