From 7057c1f193df6a6f7e2cbbec432046a872401996 Mon Sep 17 00:00:00 2001 From: alhendrickson Date: Mon, 15 Jun 2026 16:05:44 +0000 Subject: [PATCH 1/9] fix(medcat): Relcat addon fix deserialise_from to match metacat with defaults --- medcat-trainer/webapp/api/api/models.py | 34 ++++++++++++++++--- .../addons/relation_extraction/rel_cat.py | 28 +++++++++++---- .../test_rel_cat_in_model_pack.py | 13 +++++++ 3 files changed, 64 insertions(+), 11 deletions(-) diff --git a/medcat-trainer/webapp/api/api/models.py b/medcat-trainer/webapp/api/api/models.py index ec1c673da..27e49cfb2 100644 --- a/medcat-trainer/webapp/api/api/models.py +++ b/medcat-trainer/webapp/api/api/models.py @@ -34,6 +34,35 @@ logger = logging.getLogger(__name__) +def _load_meta_cat_addons(model_pack_path: str) -> list[tuple[str, MetaCATAddon]]: + """Load MetaCAT addons from a model pack, skipping other addon types. + + RelCAT addons require tokenizer/cdb/cnf init kwargs during deserialisation. + They are loaded at inference time via CAT.load_model_pack and do not need + trainer-side registration as MetaCATModel rows. + """ + from medcat.storage.serialisers import deserialise + from medcat.utils.defaults import COMPONENTS_FOLDER + + components_folder = os.path.join(model_pack_path, COMPONENTS_FOLDER) + if not os.path.exists(components_folder): + return [] + + meta_cat_folder_prefix = MetaCATAddon.get_folder_name_for_addon_and_name( + MetaCATAddon.addon_type, '') + addons: list[tuple[str, MetaCATAddon]] = [] + for folder_name in os.listdir(components_folder): + if not folder_name.startswith(meta_cat_folder_prefix): + continue + addon_path = os.path.join(components_folder, folder_name) + if not os.path.isdir(addon_path): + continue + addon = deserialise(addon_path) + if isinstance(addon, MetaCATAddon): + addons.append((addon.full_name, addon)) + return addons + + class ModelPack(models.Model): name = models.TextField(help_text='', unique=True) model_pack = models.FileField(help_text='Model pack zip') @@ -116,10 +145,7 @@ def save(self, *args, skip_load=False, **kwargs): try: metaCATmodels = [] # should raise an error if there already is a MetaCAT model with this definition - addons = CAT.load_addons(unpacked_model_pack_path) - meta_cat_addons = [ - (addon_path, addon) for addon_path, addon in addons - if isinstance(addon, MetaCATAddon)] + meta_cat_addons = _load_meta_cat_addons(unpacked_model_pack_path) for meta_cat_dir, meta_cat_addon in meta_cat_addons: meta_cat = meta_cat_addon.mc mc_model = MetaCATModel() diff --git a/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py b/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py index 9c3d178c9..c57ac1e52 100644 --- a/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py +++ b/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py @@ -94,13 +94,27 @@ def name(self) -> str: @classmethod def deserialise_from(cls, folder_path: str, **init_kwargs ) -> 'RelCATAddon': - # NOTE: model load path sent by kwargs - return cls.load_existing( - load_path=folder_path, - base_tokenizer=init_kwargs['tokenizer'], - cnf=init_kwargs['cnf'], - cdb=init_kwargs['cdb'], - ) + """Deserialise a RelCAT addon from disk. + + Mirrors `MetaCATAddon.deserialise_from`: when called via the + pipeline, `tokenizer`/`cnf` are supplied; when called standalone + (e.g. `CAT.load_addons`), they are inferred from disk so that + deserialisation works without full pipeline context. + """ + rc = RelCAT.load(folder_path) + if 'cnf' in init_kwargs: + cnf = init_kwargs['cnf'] + else: + logger.info( + "Was not provided a config when loading a rel cat from '%s'. " + "Inferring config from the loaded model.", folder_path) + cnf = rc.component.relcat_config + if 'model_config' in init_kwargs: + cnf.merge_config(init_kwargs['model_config']) + if 'tokenizer' in init_kwargs: + rc.base_tokenizer = init_kwargs['tokenizer'] + rc._init_data_paths() + return cls(cnf, rc) def get_strategy(self) -> SerialisingStrategy: return SerialisingStrategy.MANUAL diff --git a/medcat-v2/tests/components/addons/relation_extraction/test_rel_cat_in_model_pack.py b/medcat-v2/tests/components/addons/relation_extraction/test_rel_cat_in_model_pack.py index 89cb68c01..a89be6e14 100644 --- a/medcat-v2/tests/components/addons/relation_extraction/test_rel_cat_in_model_pack.py +++ b/medcat-v2/tests/components/addons/relation_extraction/test_rel_cat_in_model_pack.py @@ -84,3 +84,16 @@ def test_can_load_model_pack(self): cat = CAT.load_model_pack(self.model_pack_path) self.assertIsInstance(cat, CAT) self.assert_has_rel_cat(cat) + + def test_can_load_rel_cat_via_load_addons(self): + addons = CAT.load_addons(self.model_pack_path) + self.assertEqual(len(addons), 1) + _, addon = addons[0] + self.assertIsInstance(addon, rel_cat.RelCATAddon) + + def test_can_load_rel_cat_with_addon_cnf(self): + addon = CAT.load_addons( + self.model_pack_path, + addon_config_dict={"rel_cat.rel_cat": {"general": {"device": "cpu"}}}, + )[0][1] + self.assertEqual(addon.config.general.device, "cpu") From 4c9c657144d1b3a6ad04655b997c9d56fb01a7a0 Mon Sep 17 00:00:00 2001 From: alhendrickson Date: Mon, 15 Jun 2026 16:21:02 +0000 Subject: [PATCH 2/9] fix(medcat): Relcat addon fix deserialise_from to match metacat with defaults - cleaup --- medcat-test-models/.gitignore | 2 + medcat-trainer/webapp/api/api/models.py | 34 +------- .../api/api/tests/test_model_pack_relcat.py | 78 +++++++++++++++++++ medcat-v2/uv.lock | 12 +-- 4 files changed, 88 insertions(+), 38 deletions(-) create mode 100644 medcat-test-models/.gitignore create mode 100644 medcat-trainer/webapp/api/api/tests/test_model_pack_relcat.py diff --git a/medcat-test-models/.gitignore b/medcat-test-models/.gitignore new file mode 100644 index 000000000..6f7d1ffa4 --- /dev/null +++ b/medcat-test-models/.gitignore @@ -0,0 +1,2 @@ +mct_v1_model_pack/ +mct2_model_pack/ \ No newline at end of file diff --git a/medcat-trainer/webapp/api/api/models.py b/medcat-trainer/webapp/api/api/models.py index 27e49cfb2..ec1c673da 100644 --- a/medcat-trainer/webapp/api/api/models.py +++ b/medcat-trainer/webapp/api/api/models.py @@ -34,35 +34,6 @@ logger = logging.getLogger(__name__) -def _load_meta_cat_addons(model_pack_path: str) -> list[tuple[str, MetaCATAddon]]: - """Load MetaCAT addons from a model pack, skipping other addon types. - - RelCAT addons require tokenizer/cdb/cnf init kwargs during deserialisation. - They are loaded at inference time via CAT.load_model_pack and do not need - trainer-side registration as MetaCATModel rows. - """ - from medcat.storage.serialisers import deserialise - from medcat.utils.defaults import COMPONENTS_FOLDER - - components_folder = os.path.join(model_pack_path, COMPONENTS_FOLDER) - if not os.path.exists(components_folder): - return [] - - meta_cat_folder_prefix = MetaCATAddon.get_folder_name_for_addon_and_name( - MetaCATAddon.addon_type, '') - addons: list[tuple[str, MetaCATAddon]] = [] - for folder_name in os.listdir(components_folder): - if not folder_name.startswith(meta_cat_folder_prefix): - continue - addon_path = os.path.join(components_folder, folder_name) - if not os.path.isdir(addon_path): - continue - addon = deserialise(addon_path) - if isinstance(addon, MetaCATAddon): - addons.append((addon.full_name, addon)) - return addons - - class ModelPack(models.Model): name = models.TextField(help_text='', unique=True) model_pack = models.FileField(help_text='Model pack zip') @@ -145,7 +116,10 @@ def save(self, *args, skip_load=False, **kwargs): try: metaCATmodels = [] # should raise an error if there already is a MetaCAT model with this definition - meta_cat_addons = _load_meta_cat_addons(unpacked_model_pack_path) + addons = CAT.load_addons(unpacked_model_pack_path) + meta_cat_addons = [ + (addon_path, addon) for addon_path, addon in addons + if isinstance(addon, MetaCATAddon)] for meta_cat_dir, meta_cat_addon in meta_cat_addons: meta_cat = meta_cat_addon.mc mc_model = MetaCATModel() diff --git a/medcat-trainer/webapp/api/api/tests/test_model_pack_relcat.py b/medcat-trainer/webapp/api/api/tests/test_model_pack_relcat.py new file mode 100644 index 000000000..f7f12e454 --- /dev/null +++ b/medcat-trainer/webapp/api/api/tests/test_model_pack_relcat.py @@ -0,0 +1,78 @@ +"""Tests for registering model packs that include RelCAT addons.""" + +import os +import shutil +import tempfile +import zipfile +from urllib.request import urlretrieve + +from django.core.files.base import ContentFile +from django.test import TestCase, override_settings + +from medcat.storage.serialisers import MANUAL_SERIALISED_TAG, SER_TYPE_FILE + +from ..models import ModelPack + +MODEL_PACK_ZIP_URL = ( + "https://raw.githubusercontent.com/CogStack/cogstack-nlp/" + "051edf6cbd94fa83436fab807aff49d78dd68e59/" + "medcat-service/models/examples/example-medcat-v2-model-pack.zip" +) +REL_CAT_ADDON_CLS = ( + "medcat.components.addons.relation_extraction.rel_cat.RelCATAddon" +) + + +def _add_rel_cat_addon_stub(model_pack_dir: str, addon_name: str = "rel_cat") -> None: + """Add a minimal RelCAT addon folder that triggers manual deserialisation.""" + components_dir = os.path.join(model_pack_dir, "saved_components") + os.makedirs(components_dir, exist_ok=True) + addon_dir = os.path.join(components_dir, f"addon_rel_cat.{addon_name}") + os.makedirs(addon_dir, exist_ok=True) + with open(os.path.join(addon_dir, SER_TYPE_FILE), "w", encoding="utf-8") as f: + f.write(MANUAL_SERIALISED_TAG + REL_CAT_ADDON_CLS) + + +def _build_model_pack_zip_with_relcat(cache_dir: str) -> str: + zip_path = os.path.join(cache_dir, "cached_model_pack.zip") + if not os.path.exists(zip_path): + urlretrieve(MODEL_PACK_ZIP_URL, zip_path) + unpacked = os.path.join(cache_dir, "model_pack") + if os.path.exists(unpacked): + shutil.rmtree(unpacked) + shutil.unpack_archive(zip_path, unpacked) + _add_rel_cat_addon_stub(unpacked) + out_zip = os.path.join(cache_dir, "model_pack_with_relcat.zip") + if os.path.exists(out_zip): + os.remove(out_zip) + with zipfile.ZipFile(out_zip, "w", zipfile.ZIP_DEFLATED) as zf: + for root, _, files in os.walk(unpacked): + for file_name in files: + file_path = os.path.join(root, file_name) + arcname = os.path.relpath(file_path, unpacked) + zf.write(file_path, arcname) + return out_zip + + +@override_settings(MEDIA_ROOT=tempfile.mkdtemp()) +class ModelPackRelCATRegistrationTests(TestCase): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._cache_dir = tempfile.mkdtemp() + cls.model_pack_zip = _build_model_pack_zip_with_relcat(cls._cache_dir) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls._cache_dir, ignore_errors=True) + super().tearDownClass() + + def test_register_model_pack_with_relcat_addon_succeeds(self): + with open(self.model_pack_zip, "rb") as fh: + pack_bytes = fh.read() + model_pack = ModelPack(name="relcat-pack-test") + model_pack.model_pack = ContentFile(pack_bytes, name="relcat-pack-test.zip") + model_pack.save() + + self.assertIsNotNone(model_pack.concept_db) + self.assertIsNotNone(model_pack.vocab) diff --git a/medcat-v2/uv.lock b/medcat-v2/uv.lock index 47159573c..7de59fb11 100644 --- a/medcat-v2/uv.lock +++ b/medcat-v2/uv.lock @@ -849,6 +849,7 @@ wheels = [ name = "medcat" source = { editable = "." } dependencies = [ + { name = "click" }, { name = "dill" }, { name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.13'" }, { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.13'" }, @@ -858,6 +859,7 @@ dependencies = [ { name = "pyyaml" }, { name = "requests" }, { name = "tqdm" }, + { name = "typer" }, { name = "typing-extensions" }, { name = "xxhash" }, ] @@ -899,13 +901,9 @@ spacy = [ { name = "spacy" }, ] -[package.dev-dependencies] -dev = [ - { name = "pooch" }, -] - [package.metadata] requires-dist = [ + { name = "click" }, { name = "datasets", marker = "extra == 'deid'", specifier = ">=2.2.2,<3.0.0" }, { name = "dill" }, { name = "mypy", marker = "extra == 'dev'" }, @@ -935,6 +933,7 @@ requires-dist = [ { name = "transformers", marker = "extra == 'deid'", specifier = ">=4.41.0,<5.0" }, { name = "transformers", marker = "extra == 'meta-cat'", specifier = ">=4.41.0,<5.0" }, { name = "transformers", marker = "extra == 'rel-cat'", specifier = ">=4.41.0,<5.0" }, + { name = "typer", specifier = "!=0.26.*" }, { name = "types-pyyaml", marker = "extra == 'dev'" }, { name = "types-setuptools", marker = "extra == 'dev'" }, { name = "types-tqdm", marker = "extra == 'dev'" }, @@ -943,9 +942,6 @@ requires-dist = [ ] provides-extras = ["dev", "spacy", "meta-cat", "dict-ner", "deid", "rel-cat", "test"] -[package.metadata.requires-dev] -dev = [{ name = "pooch", specifier = ">=1.9.0" }] - [[package]] name = "mpmath" version = "1.3.0" From c4abb832514fb5d9e0c1880f35c5f52b8de509fd Mon Sep 17 00:00:00 2001 From: alhendrickson Date: Mon, 15 Jun 2026 16:24:41 +0000 Subject: [PATCH 3/9] fix(medcat): Relcat addon fix deserialise_from to match metacat with defaults - cleaup --- .../api/api/tests/test_model_pack_relcat.py | 124 +++++++++--------- medcat-v2/uv.lock | 12 +- 2 files changed, 71 insertions(+), 65 deletions(-) diff --git a/medcat-trainer/webapp/api/api/tests/test_model_pack_relcat.py b/medcat-trainer/webapp/api/api/tests/test_model_pack_relcat.py index f7f12e454..1c3edf807 100644 --- a/medcat-trainer/webapp/api/api/tests/test_model_pack_relcat.py +++ b/medcat-trainer/webapp/api/api/tests/test_model_pack_relcat.py @@ -1,78 +1,80 @@ -"""Tests for registering model packs that include RelCAT addons.""" +"""Tests for registering model packs that include RelCAT addons. + +The Trainer's responsibility when registering a model pack is to load its +addons and register only the MetaCAT ones (as ``MetaCATModel`` rows). RelCAT +addons must be tolerated (loaded by ``CAT.load_addons``) but skipped during +registration. These tests mock ``CAT.load_addons`` so they exercise that +filtering logic without downloading or loading real models. +""" import os -import shutil import tempfile -import zipfile -from urllib.request import urlretrieve +from unittest.mock import MagicMock, patch from django.core.files.base import ContentFile from django.test import TestCase, override_settings -from medcat.storage.serialisers import MANUAL_SERIALISED_TAG, SER_TYPE_FILE +from medcat.components.addons.meta_cat.meta_cat import MetaCATAddon +from medcat.components.addons.relation_extraction.rel_cat import RelCATAddon from ..models import ModelPack -MODEL_PACK_ZIP_URL = ( - "https://raw.githubusercontent.com/CogStack/cogstack-nlp/" - "051edf6cbd94fa83436fab807aff49d78dd68e59/" - "medcat-service/models/examples/example-medcat-v2-model-pack.zip" -) -REL_CAT_ADDON_CLS = ( - "medcat.components.addons.relation_extraction.rel_cat.RelCATAddon" -) - - -def _add_rel_cat_addon_stub(model_pack_dir: str, addon_name: str = "rel_cat") -> None: - """Add a minimal RelCAT addon folder that triggers manual deserialisation.""" - components_dir = os.path.join(model_pack_dir, "saved_components") - os.makedirs(components_dir, exist_ok=True) - addon_dir = os.path.join(components_dir, f"addon_rel_cat.{addon_name}") - os.makedirs(addon_dir, exist_ok=True) - with open(os.path.join(addon_dir, SER_TYPE_FILE), "w", encoding="utf-8") as f: - f.write(MANUAL_SERIALISED_TAG + REL_CAT_ADDON_CLS) - - -def _build_model_pack_zip_with_relcat(cache_dir: str) -> str: - zip_path = os.path.join(cache_dir, "cached_model_pack.zip") - if not os.path.exists(zip_path): - urlretrieve(MODEL_PACK_ZIP_URL, zip_path) - unpacked = os.path.join(cache_dir, "model_pack") - if os.path.exists(unpacked): - shutil.rmtree(unpacked) - shutil.unpack_archive(zip_path, unpacked) - _add_rel_cat_addon_stub(unpacked) - out_zip = os.path.join(cache_dir, "model_pack_with_relcat.zip") - if os.path.exists(out_zip): - os.remove(out_zip) - with zipfile.ZipFile(out_zip, "w", zipfile.ZIP_DEFLATED) as zf: - for root, _, files in os.walk(unpacked): - for file_name in files: - file_path = os.path.join(root, file_name) - arcname = os.path.relpath(file_path, unpacked) - zf.write(file_path, arcname) - return out_zip + +def _make_meta_cat_addon(category_name="Status", model_name="bert"): + addon = MagicMock(spec=MetaCATAddon) + meta_cat = MagicMock() + meta_cat.config.general.category_name = category_name + meta_cat.config.model.model_name = model_name + meta_cat.config.general.category_value2id = {"True": 0, "False": 1} + addon.mc = meta_cat + return addon + + +def _make_rel_cat_addon(): + return MagicMock(spec=RelCATAddon) @override_settings(MEDIA_ROOT=tempfile.mkdtemp()) class ModelPackRelCATRegistrationTests(TestCase): - @classmethod - def setUpClass(cls): - super().setUpClass() - cls._cache_dir = tempfile.mkdtemp() - cls.model_pack_zip = _build_model_pack_zip_with_relcat(cls._cache_dir) - - @classmethod - def tearDownClass(cls): - shutil.rmtree(cls._cache_dir, ignore_errors=True) - super().tearDownClass() - - def test_register_model_pack_with_relcat_addon_succeeds(self): - with open(self.model_pack_zip, "rb") as fh: - pack_bytes = fh.read() - model_pack = ModelPack(name="relcat-pack-test") - model_pack.model_pack = ContentFile(pack_bytes, name="relcat-pack-test.zip") - model_pack.save() + def _prepare_model_pack(self, name="relcat-pack-test"): + """Create a ModelPack with a fake unpacked dir (cdb dir + vocab file).""" + model_pack = ModelPack(name=name) + model_pack.model_pack.save(f"{name}.zip", ContentFile(b"fake"), save=False) + unpacked = model_pack.model_pack.path[: -len(".zip")] + os.makedirs(os.path.join(unpacked, "cdb"), exist_ok=True) + with open(os.path.join(unpacked, "vocab"), "w", encoding="utf-8") as fh: + fh.write("") + return model_pack, unpacked + + def test_register_model_pack_with_relcat_addon_skips_relcat(self): + model_pack, unpacked = self._prepare_model_pack() + comps = os.path.join(unpacked, "saved_components") + + addons = [ + (os.path.join(comps, "addon_meta_cat.Status"), _make_meta_cat_addon()), + (os.path.join(comps, "addon_rel_cat.rel_cat"), _make_rel_cat_addon()), + ] + + with patch("api.models.CAT.attempt_unpack"), \ + patch("api.models.CDB.load"), \ + patch("api.models.Vocab.load"), \ + patch("api.models.CAT.load_addons", return_value=addons): + model_pack.save() self.assertIsNotNone(model_pack.concept_db) self.assertIsNotNone(model_pack.vocab) + # RelCAT addon must be filtered out; only the MetaCAT is registered. + self.assertEqual(model_pack.meta_cats.count(), 1) + self.assertEqual(model_pack.meta_cats.first().name, "Status - bert") + + def test_register_model_pack_without_addons(self): + model_pack, unpacked = self._prepare_model_pack(name="no-addon-pack") + + with patch("api.models.CAT.attempt_unpack"), \ + patch("api.models.CDB.load"), \ + patch("api.models.Vocab.load"), \ + patch("api.models.CAT.load_addons", return_value=[]): + model_pack.save() + + self.assertIsNotNone(model_pack.concept_db) + self.assertEqual(model_pack.meta_cats.count(), 0) diff --git a/medcat-v2/uv.lock b/medcat-v2/uv.lock index 7de59fb11..47159573c 100644 --- a/medcat-v2/uv.lock +++ b/medcat-v2/uv.lock @@ -849,7 +849,6 @@ wheels = [ name = "medcat" source = { editable = "." } dependencies = [ - { name = "click" }, { name = "dill" }, { name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.13'" }, { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.13'" }, @@ -859,7 +858,6 @@ dependencies = [ { name = "pyyaml" }, { name = "requests" }, { name = "tqdm" }, - { name = "typer" }, { name = "typing-extensions" }, { name = "xxhash" }, ] @@ -901,9 +899,13 @@ spacy = [ { name = "spacy" }, ] +[package.dev-dependencies] +dev = [ + { name = "pooch" }, +] + [package.metadata] requires-dist = [ - { name = "click" }, { name = "datasets", marker = "extra == 'deid'", specifier = ">=2.2.2,<3.0.0" }, { name = "dill" }, { name = "mypy", marker = "extra == 'dev'" }, @@ -933,7 +935,6 @@ requires-dist = [ { name = "transformers", marker = "extra == 'deid'", specifier = ">=4.41.0,<5.0" }, { name = "transformers", marker = "extra == 'meta-cat'", specifier = ">=4.41.0,<5.0" }, { name = "transformers", marker = "extra == 'rel-cat'", specifier = ">=4.41.0,<5.0" }, - { name = "typer", specifier = "!=0.26.*" }, { name = "types-pyyaml", marker = "extra == 'dev'" }, { name = "types-setuptools", marker = "extra == 'dev'" }, { name = "types-tqdm", marker = "extra == 'dev'" }, @@ -942,6 +943,9 @@ requires-dist = [ ] provides-extras = ["dev", "spacy", "meta-cat", "dict-ner", "deid", "rel-cat", "test"] +[package.metadata.requires-dev] +dev = [{ name = "pooch", specifier = ">=1.9.0" }] + [[package]] name = "mpmath" version = "1.3.0" From b6e163fc1edba4e8c97a70c5d8807ad29787f225 Mon Sep 17 00:00:00 2001 From: alhendrickson Date: Mon, 15 Jun 2026 16:29:58 +0000 Subject: [PATCH 4/9] fix(medcat): Relcat addon fix deserialise_from to match metacat with defaults - cleaup --- .../api/api/tests/test_model_pack_addons.py | 144 ++++++++++++++++++ .../api/api/tests/test_model_pack_relcat.py | 80 ---------- 2 files changed, 144 insertions(+), 80 deletions(-) create mode 100644 medcat-trainer/webapp/api/api/tests/test_model_pack_addons.py delete mode 100644 medcat-trainer/webapp/api/api/tests/test_model_pack_relcat.py diff --git a/medcat-trainer/webapp/api/api/tests/test_model_pack_addons.py b/medcat-trainer/webapp/api/api/tests/test_model_pack_addons.py new file mode 100644 index 000000000..9b2797e63 --- /dev/null +++ b/medcat-trainer/webapp/api/api/tests/test_model_pack_addons.py @@ -0,0 +1,144 @@ +"""Tests for registering model packs that include addons. + +The Trainer loads all addons via ``CAT.load_addons`` and registers only +MetaCAT ones (as ``MetaCATModel`` rows). Other addon types (e.g. RelCAT) +must load without error but are skipped during registration. These tests +mock ``CAT.load_addons`` so they exercise that behaviour without +downloading or loading real models. + +Scenarios covered: + +- MetaCAT only — one ``MetaCATModel`` row is created. +- RelCAT only — registration succeeds; no ``MetaCATModel`` rows. +- MetaCAT and RelCAT — both addons load; only MetaCAT is registered. +- No addons — CDB and vocab load; no ``MetaCATModel`` rows. +- Multiple MetaCAT addons — each MetaCAT is registered separately. +""" + +import os +import tempfile +from contextlib import contextmanager +from unittest.mock import MagicMock, patch + +from django.core.files.base import ContentFile +from django.test import TestCase, override_settings + +from medcat.components.addons.meta_cat.meta_cat import MetaCATAddon +from medcat.components.addons.relation_extraction.rel_cat import RelCATAddon + +from ..models import ModelPack + + +def _make_meta_cat_addon(category_name="Status", model_name="bert"): + addon = MagicMock(spec=MetaCATAddon) + meta_cat = MagicMock() + meta_cat.config.general.category_name = category_name + meta_cat.config.model.model_name = model_name + meta_cat.config.general.category_value2id = {"True": 0, "False": 1} + addon.mc = meta_cat + return addon + + +def _make_rel_cat_addon(): + return MagicMock(spec=RelCATAddon) + + +@override_settings(MEDIA_ROOT=tempfile.mkdtemp()) +class ModelPackAddonRegistrationTests(TestCase): + def _prepare_model_pack(self, name="addon-pack-test"): + """Create a ModelPack with a fake unpacked dir (cdb dir + vocab file).""" + model_pack = ModelPack(name=name) + model_pack.model_pack.save(f"{name}.zip", ContentFile(b"fake"), save=False) + unpacked = model_pack.model_pack.path[: -len(".zip")] + os.makedirs(os.path.join(unpacked, "cdb"), exist_ok=True) + with open(os.path.join(unpacked, "vocab"), "w", encoding="utf-8") as fh: + fh.write("") + return model_pack, unpacked + + @contextmanager + def _register_model_pack(self, model_pack, addons): + with patch("api.models.CAT.attempt_unpack"), \ + patch("api.models.CDB.load"), \ + patch("api.models.Vocab.load"), \ + patch("api.models.CAT.load_addons", return_value=addons) as load_addons: + model_pack.save() + yield load_addons + + def test_register_model_pack_with_meta_cat_only(self): + model_pack, unpacked = self._prepare_model_pack(name="meta-cat-pack") + comps = os.path.join(unpacked, "saved_components") + meta_cat_path = os.path.join(comps, "addon_meta_cat.Status") + addons = [(meta_cat_path, _make_meta_cat_addon())] + + with self._register_model_pack(model_pack, addons) as load_addons: + load_addons.assert_called_once_with(unpacked) + + self.assertIsNotNone(model_pack.concept_db) + self.assertIsNotNone(model_pack.vocab) + self.assertEqual(model_pack.meta_cats.count(), 1) + meta_cat_model = model_pack.meta_cats.get() + self.assertEqual(meta_cat_model.name, "Status - bert") + self.assertTrue(meta_cat_model.meta_cat_dir.endswith("addon_meta_cat.Status")) + + def test_register_model_pack_with_rel_cat_only(self): + model_pack, unpacked = self._prepare_model_pack(name="rel-cat-pack") + comps = os.path.join(unpacked, "saved_components") + rel_cat_path = os.path.join(comps, "addon_rel_cat.rel_cat") + addons = [(rel_cat_path, _make_rel_cat_addon())] + + with self._register_model_pack(model_pack, addons) as load_addons: + load_addons.assert_called_once_with(unpacked) + + self.assertIsNotNone(model_pack.concept_db) + self.assertIsNotNone(model_pack.vocab) + self.assertEqual(model_pack.meta_cats.count(), 0) + + def test_register_model_pack_registers_multiple_meta_cats(self): + model_pack, unpacked = self._prepare_model_pack(name="multi-meta-cat-pack") + comps = os.path.join(unpacked, "saved_components") + addons = [ + (os.path.join(comps, "addon_meta_cat.Status"), + _make_meta_cat_addon(category_name="Status", model_name="bert")), + (os.path.join(comps, "addon_meta_cat.Experiencer"), + _make_meta_cat_addon(category_name="Experiencer", model_name="roberta")), + ] + + with self._register_model_pack(model_pack, addons): + pass + + self.assertEqual(model_pack.meta_cats.count(), 2) + self.assertEqual( + set(model_pack.meta_cats.values_list("name", flat=True)), + {"Status - bert", "Experiencer - roberta"}, + ) + + def test_register_model_pack_with_meta_cat_and_rel_cat(self): + model_pack, unpacked = self._prepare_model_pack(name="mixed-addon-pack") + comps = os.path.join(unpacked, "saved_components") + meta_cat_path = os.path.join(comps, "addon_meta_cat.Status") + rel_cat_path = os.path.join(comps, "addon_rel_cat.rel_cat") + addons = [ + (meta_cat_path, _make_meta_cat_addon()), + (rel_cat_path, _make_rel_cat_addon()), + ] + + with self._register_model_pack(model_pack, addons) as load_addons: + load_addons.assert_called_once_with(unpacked) + + self.assertIsNotNone(model_pack.concept_db) + self.assertIsNotNone(model_pack.vocab) + # All addons load; only MetaCAT rows are registered. + self.assertEqual(model_pack.meta_cats.count(), 1) + meta_cat_model = model_pack.meta_cats.get() + self.assertEqual(meta_cat_model.name, "Status - bert") + self.assertTrue(meta_cat_model.meta_cat_dir.endswith("addon_meta_cat.Status")) + + def test_register_model_pack_without_addons(self): + model_pack, unpacked = self._prepare_model_pack(name="no-addon-pack") + + with self._register_model_pack(model_pack, []) as load_addons: + load_addons.assert_called_once_with(unpacked) + + self.assertIsNotNone(model_pack.concept_db) + self.assertIsNotNone(model_pack.vocab) + self.assertEqual(model_pack.meta_cats.count(), 0) diff --git a/medcat-trainer/webapp/api/api/tests/test_model_pack_relcat.py b/medcat-trainer/webapp/api/api/tests/test_model_pack_relcat.py deleted file mode 100644 index 1c3edf807..000000000 --- a/medcat-trainer/webapp/api/api/tests/test_model_pack_relcat.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Tests for registering model packs that include RelCAT addons. - -The Trainer's responsibility when registering a model pack is to load its -addons and register only the MetaCAT ones (as ``MetaCATModel`` rows). RelCAT -addons must be tolerated (loaded by ``CAT.load_addons``) but skipped during -registration. These tests mock ``CAT.load_addons`` so they exercise that -filtering logic without downloading or loading real models. -""" - -import os -import tempfile -from unittest.mock import MagicMock, patch - -from django.core.files.base import ContentFile -from django.test import TestCase, override_settings - -from medcat.components.addons.meta_cat.meta_cat import MetaCATAddon -from medcat.components.addons.relation_extraction.rel_cat import RelCATAddon - -from ..models import ModelPack - - -def _make_meta_cat_addon(category_name="Status", model_name="bert"): - addon = MagicMock(spec=MetaCATAddon) - meta_cat = MagicMock() - meta_cat.config.general.category_name = category_name - meta_cat.config.model.model_name = model_name - meta_cat.config.general.category_value2id = {"True": 0, "False": 1} - addon.mc = meta_cat - return addon - - -def _make_rel_cat_addon(): - return MagicMock(spec=RelCATAddon) - - -@override_settings(MEDIA_ROOT=tempfile.mkdtemp()) -class ModelPackRelCATRegistrationTests(TestCase): - def _prepare_model_pack(self, name="relcat-pack-test"): - """Create a ModelPack with a fake unpacked dir (cdb dir + vocab file).""" - model_pack = ModelPack(name=name) - model_pack.model_pack.save(f"{name}.zip", ContentFile(b"fake"), save=False) - unpacked = model_pack.model_pack.path[: -len(".zip")] - os.makedirs(os.path.join(unpacked, "cdb"), exist_ok=True) - with open(os.path.join(unpacked, "vocab"), "w", encoding="utf-8") as fh: - fh.write("") - return model_pack, unpacked - - def test_register_model_pack_with_relcat_addon_skips_relcat(self): - model_pack, unpacked = self._prepare_model_pack() - comps = os.path.join(unpacked, "saved_components") - - addons = [ - (os.path.join(comps, "addon_meta_cat.Status"), _make_meta_cat_addon()), - (os.path.join(comps, "addon_rel_cat.rel_cat"), _make_rel_cat_addon()), - ] - - with patch("api.models.CAT.attempt_unpack"), \ - patch("api.models.CDB.load"), \ - patch("api.models.Vocab.load"), \ - patch("api.models.CAT.load_addons", return_value=addons): - model_pack.save() - - self.assertIsNotNone(model_pack.concept_db) - self.assertIsNotNone(model_pack.vocab) - # RelCAT addon must be filtered out; only the MetaCAT is registered. - self.assertEqual(model_pack.meta_cats.count(), 1) - self.assertEqual(model_pack.meta_cats.first().name, "Status - bert") - - def test_register_model_pack_without_addons(self): - model_pack, unpacked = self._prepare_model_pack(name="no-addon-pack") - - with patch("api.models.CAT.attempt_unpack"), \ - patch("api.models.CDB.load"), \ - patch("api.models.Vocab.load"), \ - patch("api.models.CAT.load_addons", return_value=[]): - model_pack.save() - - self.assertIsNotNone(model_pack.concept_db) - self.assertEqual(model_pack.meta_cats.count(), 0) From fe4ea575ebcd0205fcafd5554aeac495c4f087b3 Mon Sep 17 00:00:00 2001 From: alhendrickson Date: Mon, 15 Jun 2026 16:46:12 +0000 Subject: [PATCH 5/9] fix(medcat): Relcat addon fix deserialise_from to match metacat with defaults --- .../addons/relation_extraction/rel_cat.py | 43 ++++++++++++++++--- medcat-v2/uv.lock | 12 ++---- 2 files changed, 40 insertions(+), 15 deletions(-) diff --git a/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py b/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py index c57ac1e52..0e6d099a6 100644 --- a/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py +++ b/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py @@ -39,6 +39,7 @@ class RelCATAddon(AddonComponent): + DEFAULT_TOKENIZER = 'spacy' addon_type = 'rel_cat' output_key = 'relations' config: ConfigRelCAT @@ -91,30 +92,58 @@ def name(self) -> str: # for ManualSerialisable: + @classmethod + def _create_throwaway_tokenizer(cls) -> BaseTokenizer: + """ + Mirrors `MetaCATAddon._create_throwaway_tokenizer` + """ + logger.warning( + "A base tokenizer was not provided during the loading of a " + "RelCAT. The tokenizer is used to register the required data " + "paths for RelCAT to function. Using the default of '%s'.", + cls.DEFAULT_TOKENIZER, + ) + gcnf = Config() + gcnf.general.nlp.provider = 'spacy' + return create_tokenizer(cls.DEFAULT_TOKENIZER, gcnf) + @classmethod def deserialise_from(cls, folder_path: str, **init_kwargs ) -> 'RelCATAddon': """Deserialise a RelCAT addon from disk. Mirrors `MetaCATAddon.deserialise_from`: when called via the - pipeline, `tokenizer`/`cnf` are supplied; when called standalone + pipeline, `tokenizer`/`cnf`/`cdb` are supplied; when called standalone (e.g. `CAT.load_addons`), they are inferred from disk so that deserialisation works without full pipeline context. """ - rc = RelCAT.load(folder_path) if 'cnf' in init_kwargs: cnf = init_kwargs['cnf'] else: logger.info( "Was not provided a config when loading a rel cat from '%s'. " - "Inferring config from the loaded model.", folder_path) - cnf = rc.component.relcat_config + "Inferring config from file at '%s'", folder_path, + folder_path) + cnf = ConfigRelCAT.load(load_path=folder_path) if 'model_config' in init_kwargs: cnf.merge_config(init_kwargs['model_config']) if 'tokenizer' in init_kwargs: - rc.base_tokenizer = init_kwargs['tokenizer'] - rc._init_data_paths() - return cls(cnf, rc) + tokenizer = init_kwargs['tokenizer'] + else: + tokenizer = cls._create_throwaway_tokenizer() + if 'cdb' in init_kwargs: + cdb = init_kwargs['cdb'] + else: + cdb_path = os.path.join(folder_path, "cdb.dat") + if os.path.exists(cdb_path): + cdb = cast(CDB, deserialise(cdb_path)) + else: + cdb = CDB(config=Config()) + return cls.load_existing( + load_path=folder_path, + cnf=cnf, + base_tokenizer=tokenizer, + cdb=cdb) def get_strategy(self) -> SerialisingStrategy: return SerialisingStrategy.MANUAL diff --git a/medcat-v2/uv.lock b/medcat-v2/uv.lock index 47159573c..7de59fb11 100644 --- a/medcat-v2/uv.lock +++ b/medcat-v2/uv.lock @@ -849,6 +849,7 @@ wheels = [ name = "medcat" source = { editable = "." } dependencies = [ + { name = "click" }, { name = "dill" }, { name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.13'" }, { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.13'" }, @@ -858,6 +859,7 @@ dependencies = [ { name = "pyyaml" }, { name = "requests" }, { name = "tqdm" }, + { name = "typer" }, { name = "typing-extensions" }, { name = "xxhash" }, ] @@ -899,13 +901,9 @@ spacy = [ { name = "spacy" }, ] -[package.dev-dependencies] -dev = [ - { name = "pooch" }, -] - [package.metadata] requires-dist = [ + { name = "click" }, { name = "datasets", marker = "extra == 'deid'", specifier = ">=2.2.2,<3.0.0" }, { name = "dill" }, { name = "mypy", marker = "extra == 'dev'" }, @@ -935,6 +933,7 @@ requires-dist = [ { name = "transformers", marker = "extra == 'deid'", specifier = ">=4.41.0,<5.0" }, { name = "transformers", marker = "extra == 'meta-cat'", specifier = ">=4.41.0,<5.0" }, { name = "transformers", marker = "extra == 'rel-cat'", specifier = ">=4.41.0,<5.0" }, + { name = "typer", specifier = "!=0.26.*" }, { name = "types-pyyaml", marker = "extra == 'dev'" }, { name = "types-setuptools", marker = "extra == 'dev'" }, { name = "types-tqdm", marker = "extra == 'dev'" }, @@ -943,9 +942,6 @@ requires-dist = [ ] provides-extras = ["dev", "spacy", "meta-cat", "dict-ner", "deid", "rel-cat", "test"] -[package.metadata.requires-dev] -dev = [{ name = "pooch", specifier = ">=1.9.0" }] - [[package]] name = "mpmath" version = "1.3.0" From 63de4b1096aeb2d19f8b90fe334aac637ec58e81 Mon Sep 17 00:00:00 2001 From: alhendrickson Date: Mon, 15 Jun 2026 16:47:14 +0000 Subject: [PATCH 6/9] fix(medcat): Relcat addon fix deserialise_from to match metacat with defaults --- medcat-v2/uv.lock | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/medcat-v2/uv.lock b/medcat-v2/uv.lock index 7de59fb11..47159573c 100644 --- a/medcat-v2/uv.lock +++ b/medcat-v2/uv.lock @@ -849,7 +849,6 @@ wheels = [ name = "medcat" source = { editable = "." } dependencies = [ - { name = "click" }, { name = "dill" }, { name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.13'" }, { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.13'" }, @@ -859,7 +858,6 @@ dependencies = [ { name = "pyyaml" }, { name = "requests" }, { name = "tqdm" }, - { name = "typer" }, { name = "typing-extensions" }, { name = "xxhash" }, ] @@ -901,9 +899,13 @@ spacy = [ { name = "spacy" }, ] +[package.dev-dependencies] +dev = [ + { name = "pooch" }, +] + [package.metadata] requires-dist = [ - { name = "click" }, { name = "datasets", marker = "extra == 'deid'", specifier = ">=2.2.2,<3.0.0" }, { name = "dill" }, { name = "mypy", marker = "extra == 'dev'" }, @@ -933,7 +935,6 @@ requires-dist = [ { name = "transformers", marker = "extra == 'deid'", specifier = ">=4.41.0,<5.0" }, { name = "transformers", marker = "extra == 'meta-cat'", specifier = ">=4.41.0,<5.0" }, { name = "transformers", marker = "extra == 'rel-cat'", specifier = ">=4.41.0,<5.0" }, - { name = "typer", specifier = "!=0.26.*" }, { name = "types-pyyaml", marker = "extra == 'dev'" }, { name = "types-setuptools", marker = "extra == 'dev'" }, { name = "types-tqdm", marker = "extra == 'dev'" }, @@ -942,6 +943,9 @@ requires-dist = [ ] provides-extras = ["dev", "spacy", "meta-cat", "dict-ner", "deid", "rel-cat", "test"] +[package.metadata.requires-dev] +dev = [{ name = "pooch", specifier = ">=1.9.0" }] + [[package]] name = "mpmath" version = "1.3.0" From f9326d8b387e9af2dd7303fbded645eb5b3aa7d0 Mon Sep 17 00:00:00 2001 From: alhendrickson Date: Tue, 16 Jun 2026 10:07:47 +0000 Subject: [PATCH 7/9] fix(medcat): Relcat addon fix deserialise - make the same as metacataddon --- .../components/addons/meta_cat/meta_cat.py | 2 +- .../addons/relation_extraction/rel_cat.py | 1927 +++++++++-------- 2 files changed, 976 insertions(+), 953 deletions(-) diff --git a/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py b/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py index 05fb4a185..475eb8291 100644 --- a/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py +++ b/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py @@ -223,7 +223,7 @@ def deserialise_from(cls, folder_path: str, **init_kwargs # load legacy config (assuming it exists) config_path += ".dat" logger.info( - "Was not provide a config when loading a meta cat from '%s'. " + "Was not provided a config when loading a meta cat from '%s'. " "Inferring config from file at '%s'", folder_path, config_path) cnf = ConfigMetaCAT.load(config_path) diff --git a/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py b/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py index 0e6d099a6..585b8518e 100644 --- a/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py +++ b/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py @@ -1,952 +1,975 @@ -import json -import logging -import os -import random -from typing import Optional - -from sklearn.utils import compute_class_weight -import torch -import torch.nn as nn - -from tqdm import tqdm -from datetime import date, datetime -from typing import Iterable, Iterator, cast -from torch.utils.data import DataLoader, Sampler -from torch.optim import AdamW -from torch.optim.lr_scheduler import MultiStepLR -import numpy - -from medcat.cdb import CDB -from medcat.vocab import Vocab -from medcat.config.config import Config, ComponentConfig -from medcat.config.config_rel_cat import ConfigRelCAT -from medcat.storage.serialisers import deserialise -from medcat.storage.serialisables import SerialisingStrategy -from medcat.components.addons.addons import AddonComponent -from medcat.components.addons.relation_extraction.base_component import ( - RelExtrBaseComponent) -from medcat.components.addons.meta_cat.ml_utils import set_all_seeds -from medcat.components.addons.relation_extraction.ml_utils import ( - load_results, load_state, save_results, save_state, - split_list_train_test_by_class) -from medcat.components.addons.relation_extraction.rel_dataset import RelData -from medcat.tokenizing.tokenizers import BaseTokenizer, create_tokenizer -from medcat.tokenizing.tokens import MutableDocument -from medcat.utils.defaults import COMPONENTS_FOLDER - - -logger = logging.getLogger(__name__) - - -class RelCATAddon(AddonComponent): - DEFAULT_TOKENIZER = 'spacy' - addon_type = 'rel_cat' - output_key = 'relations' - config: ConfigRelCAT - - def __init__(self, config: ConfigRelCAT, - rel_cat: "RelCAT"): - self.config = config - self._rel_cat = rel_cat - - @classmethod - def create_new(cls, config: ConfigRelCAT, base_tokenizer: BaseTokenizer, - cdb: CDB) -> 'RelCATAddon': - """Factory method to create a new MetaCATAddon instance.""" - return cls(config, - RelCAT(base_tokenizer, cdb, config=config, init_model=True)) - - @classmethod - def create_new_component( - cls, cnf: ComponentConfig, tokenizer: BaseTokenizer, - cdb: CDB, vocab: Vocab, model_load_path: Optional[str] - ) -> 'RelCATAddon': - if not isinstance(cnf, ConfigRelCAT): - raise ValueError(f"Incompatible config: {cnf}") - config = cnf - if model_load_path is not None: - load_path = os.path.join(model_load_path, COMPONENTS_FOLDER, - cls.NAME_PREFIX + cls.addon_type) - return cls.load_existing(config, tokenizer, cdb, load_path) - return cls.create_new(config, tokenizer, cdb) - - @classmethod - def load_existing(cls, cnf: ConfigRelCAT, - base_tokenizer: BaseTokenizer, - cdb: CDB, - load_path: str) -> 'RelCATAddon': - """Factory method to load an existing RelCAT addon from disk.""" - rc = RelCAT.load(load_path) - # set the correct base tokenizer and redo data paths - rc.base_tokenizer = base_tokenizer - rc._init_data_paths() - return cls(cnf, rc) - - def serialise_to(self, folder_path: str) -> None: - os.mkdir(folder_path) - self._rel_cat.save(folder_path) - - @property - def name(self) -> str: - return str(self.addon_type) - - # for ManualSerialisable: - - @classmethod - def _create_throwaway_tokenizer(cls) -> BaseTokenizer: - """ - Mirrors `MetaCATAddon._create_throwaway_tokenizer` - """ - logger.warning( - "A base tokenizer was not provided during the loading of a " - "RelCAT. The tokenizer is used to register the required data " - "paths for RelCAT to function. Using the default of '%s'.", - cls.DEFAULT_TOKENIZER, - ) - gcnf = Config() - gcnf.general.nlp.provider = 'spacy' - return create_tokenizer(cls.DEFAULT_TOKENIZER, gcnf) - - @classmethod - def deserialise_from(cls, folder_path: str, **init_kwargs - ) -> 'RelCATAddon': - """Deserialise a RelCAT addon from disk. - - Mirrors `MetaCATAddon.deserialise_from`: when called via the - pipeline, `tokenizer`/`cnf`/`cdb` are supplied; when called standalone - (e.g. `CAT.load_addons`), they are inferred from disk so that - deserialisation works without full pipeline context. - """ - if 'cnf' in init_kwargs: - cnf = init_kwargs['cnf'] - else: - logger.info( - "Was not provided a config when loading a rel cat from '%s'. " - "Inferring config from file at '%s'", folder_path, - folder_path) - cnf = ConfigRelCAT.load(load_path=folder_path) - if 'model_config' in init_kwargs: - cnf.merge_config(init_kwargs['model_config']) - if 'tokenizer' in init_kwargs: - tokenizer = init_kwargs['tokenizer'] - else: - tokenizer = cls._create_throwaway_tokenizer() - if 'cdb' in init_kwargs: - cdb = init_kwargs['cdb'] - else: - cdb_path = os.path.join(folder_path, "cdb.dat") - if os.path.exists(cdb_path): - cdb = cast(CDB, deserialise(cdb_path)) - else: - cdb = CDB(config=Config()) - return cls.load_existing( - load_path=folder_path, - cnf=cnf, - base_tokenizer=tokenizer, - cdb=cdb) - - def get_strategy(self) -> SerialisingStrategy: - return SerialisingStrategy.MANUAL - - @classmethod - def get_init_attrs(cls) -> list[str]: - return [] - - @classmethod - def ignore_attrs(cls) -> list[str]: - return [] - - @classmethod - def include_properties(cls) -> list[str]: - return [] - - def __call__(self, doc: MutableDocument): - return self._rel_cat(doc) - - -class BalancedBatchSampler(Sampler): - - def __init__(self, dataset, classes, - batch_size, max_samples, max_minority): - self.dataset = dataset - self.classes = classes - self.batch_size = batch_size - self.num_classes = len(classes) - self.indices = list(range(len(dataset))) - - self.max_minority = max_minority - - self.max_samples_per_class = max_samples - - def __len__(self): - return (len(self.indices) + self.batch_size - 1) // self.batch_size - - def __iter__(self): - batch_counter = 0 - indices = self.indices.copy() - while batch_counter != self.__len__(): - batch = [] - - class_counts = {c: 0 for c in self.classes} - while len(batch) < self.batch_size: - - index = random.choice(indices) - # Assuming label is at index 1 - label = self.dataset[index][2].numpy().tolist()[0] - if class_counts[label] < self.max_samples_per_class[label]: - batch.append(index) - class_counts[label] += 1 - if self.max_samples_per_class[label] > self.max_minority: - indices.remove(index) - - yield batch - batch_counter += 1 - - -class RelCAT: - """The RelCAT class used for training 'Relation-Annotation' models, i.e., - annotation of relations between clinical concepts. - - Args: - cdb (CDB): cdb, this is used when creating relation datasets. - - tokenizer (TokenizerWrapperBERT): - The Huggingface tokenizer instance. This can be a pre-trained - tokenzier instance from a BERT-style model. For now, only - BERT models are supported. - - config (ConfigRelCAT): - the configuration for RelCAT. Param descriptions available in - ConfigRelCAT class docs. - - task (str, optional): What task is this model supposed to handle. - Defaults to "train" - init_model (bool, optional): loads default model. Defaults to False. - - """ - addon_type = 'rel_cat' - output_key = 'rel_' - - def __init__(self, base_tokenizer: BaseTokenizer, - cdb: CDB, config: ConfigRelCAT = ConfigRelCAT(), - task: str = "train", init_model: bool = False): - self.base_tokenizer = base_tokenizer - self.component = RelExtrBaseComponent() - self.task: str = task - self.checkpoint_path: str = "./" - - set_all_seeds(config.general.seed) - - if init_model: - self.component = RelExtrBaseComponent( - config=config, task=task, init_model=True) - - self.cdb = cdb - logging.basicConfig( - level=self.component.relcat_config.general.log_level) - logger.setLevel(self.component.relcat_config.general.log_level) - - self.is_cuda_available = torch.cuda.is_available() - self.device = torch.device( - "cuda" if self.is_cuda_available and - self.component.relcat_config.general.device != "cpu" else "cpu") - self._init_data_paths() - - def _init_data_paths(self): - doc_cls = self.base_tokenizer.get_doc_class() - doc_cls.register_addon_path('relations', def_val=[], force=True) - entity_cls = self.base_tokenizer.get_entity_class() - entity_cls.register_addon_path('start', def_val=None, force=True) - entity_cls.register_addon_path('end', def_val=None, force=True) - - def save(self, save_path: str = "./") -> None: - self.component.save(save_path=save_path) - - @classmethod - def load(cls, load_path: str = "./") -> "RelCAT": - - if os.path.exists(os.path.join(load_path, "cdb.dat")): - cdb = cast(CDB, deserialise(os.path.join(load_path, "cdb.dat"))) - else: - cdb = CDB(config=Config()) - logger.info( - "The default CDB file name 'cdb.dat' doesn't exist in the " - "specified path, you will need to load & set " - "a CDB manually via rel_cat.cdb = CDB.load('path') ") - - component = RelExtrBaseComponent.load( - pretrained_model_name_or_path=load_path) - - device = torch.device( - "cuda" if torch.cuda.is_available() and - component.relcat_config.general.device != "cpu" else "cpu") - - rel_cat = RelCAT( - # NOTE: this is a throaway tokenizer just for registrations - create_tokenizer(cdb.config.general.nlp.provider, cdb.config), - cdb=cdb, config=component.relcat_config, task=component.task) - rel_cat.device = device - rel_cat.component = component - - return rel_cat - - def __call__(self, doc: MutableDocument) -> MutableDocument: - doc = next(self.pipe(iter([doc]))) - return doc - - def _create_test_train_datasets(self, data: dict, - split_sets: bool = False): - train_data: dict = {} - test_data: dict = {} - - if split_sets: - rc_cnf = self.component.relcat_config - (train_data["output_relations"], - test_data["output_relations"]) = split_list_train_test_by_class( - data["output_relations"], - test_size=rc_cnf.train.test_size, - shuffle=rc_cnf.train.shuffle_data, - sample_limit=rc_cnf.general.limit_samples_per_class) - - test_data_label_names = [ - rec[4] for rec in test_data["output_relations"]] - - (test_data["nclasses"], test_data["labels2idx"], - test_data["idx2label"]) = RelData.get_labels( - test_data_label_names, self.component.relcat_config) - - for idx in range(len(test_data["output_relations"])): - test_data["output_relations" - ][idx][5] = test_data["labels2idx"][ - test_data["output_relations"][idx][4]] - else: - train_data["output_relations"] = data["output_relations"] - - for k, v in data.items(): - if k != "output_relations": - train_data[k] = [] - test_data[k] = [] - - train_data_label_names = [rec[4] - for rec in train_data["output_relations"]] - - (train_data["nclasses"], train_data["labels2idx"], - train_data["idx2label"]) = RelData.get_labels( - train_data_label_names, self.component.relcat_config) - - for idx in range(len(train_data["output_relations"])): - train_data["output_relations" - ][idx][5] = train_data["labels2idx"][ - train_data["output_relations"][idx][4]] - - return train_data, test_data - - def train(self, export_data_path: str = "", train_csv_path: str = "", - test_csv_path: str = "", checkpoint_path: str = "./"): - - if self.is_cuda_available: - logger.info("Training on device: %s%s", - str(torch.cuda.get_device_name(0)), str(self.device)) - - self.component.model = self.component.model.to(self.device) - - rc_cnf = self.component.relcat_config - - # resize vocab just in case more tokens have been added - self.component.model_config.vocab_size = ( - self.component.tokenizer.get_size()) - - train_rel_data = RelData( - cdb=self.cdb, config=rc_cnf, - tokenizer=self.component.tokenizer) - test_rel_data = RelData( - cdb=self.cdb, config=rc_cnf, - tokenizer=self.component.tokenizer) - - if train_csv_path != "": - if test_csv_path != "": - train_rel_data.dataset, _ = self._create_test_train_datasets( - train_rel_data.create_base_relations_from_csv( - train_csv_path), split_sets=False) - test_rel_data.dataset, _ = self._create_test_train_datasets( - train_rel_data.create_base_relations_from_csv( - test_csv_path), split_sets=False) - else: - (train_rel_data.dataset, - test_rel_data.dataset) = self._create_test_train_datasets( - train_rel_data.create_base_relations_from_csv( - train_csv_path), split_sets=True) - - elif export_data_path != "": - export_data = {} - with open(export_data_path) as f: - export_data = json.load(f) - (train_rel_data.dataset, - test_rel_data.dataset) = self._create_test_train_datasets( - train_rel_data.create_relations_from_export(export_data), - split_sets=True) - else: - raise ValueError( - "NO DATA HAS BEEN PROVIDED (MedCAT Trainer export " - "JSON/CSV/spacy_DOCS)") - - train_dataset_size = len(train_rel_data) - batch_size = ( - train_dataset_size if train_dataset_size < rc_cnf.train.batch_size - else rc_cnf.train.batch_size) - - # to use stratified batching - if rc_cnf.train.stratified_batching: - sampler = BalancedBatchSampler( - train_rel_data, [ - i for i in - range(rc_cnf.train.nclasses)], - batch_size, - rc_cnf.train.batching_samples_per_class, - rc_cnf.train.batching_minority_limit) - - train_dataloader = DataLoader( - train_rel_data, num_workers=0, - collate_fn=self.component.padding_seq, - batch_sampler=sampler, - pin_memory=rc_cnf.general.pin_memory) - else: - train_dataloader = DataLoader( - train_rel_data, batch_size=batch_size, - shuffle=rc_cnf.train.shuffle_data, - num_workers=0, - collate_fn=self.component.padding_seq, - pin_memory=rc_cnf.general.pin_memory) - - test_dataset_size = len(test_rel_data) - test_batch_size = ( - test_dataset_size if - test_dataset_size < rc_cnf.train.batch_size - else rc_cnf.train.batch_size) - test_dataloader = DataLoader( - test_rel_data, - batch_size=test_batch_size, - shuffle=rc_cnf.train.shuffle_data, - num_workers=0, - collate_fn=self.component.padding_seq, - pin_memory=rc_cnf.general.pin_memory) - - if (rc_cnf.train.class_weights is not None and - rc_cnf.train.enable_class_weights): - criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor( - numpy.asarray(rc_cnf.train.class_weights) - ).to(self.device)) - elif rc_cnf.train.enable_class_weights: - all_class_lbl_ids = [ - rec[5] for rec in train_rel_data.dataset["output_relations"]] - rc_cnf.train.class_weights = ( - compute_class_weight(class_weight="balanced", - classes=numpy.unique(all_class_lbl_ids), - y=all_class_lbl_ids).tolist()) - criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor( - rc_cnf.train.class_weights).to( - self.device)) - else: - criterion = nn.CrossEntropyLoss() - - if self.component.optimizer is None: - parameters = filter(lambda p: p.requires_grad, - self.component.model.parameters()) - self.component.optimizer = AdamW( - parameters, lr=self.component.relcat_config.train.lr, - weight_decay=rc_cnf.train.adam_weight_decay, - betas=rc_cnf.train.adam_betas, eps=rc_cnf.train.adam_epsilon) - - if self.component.scheduler is None: - self.component.scheduler = MultiStepLR( - self.component.optimizer, - milestones=rc_cnf.train.multistep_milestones, - gamma=rc_cnf.train.multistep_lr_gamma) - - self.epoch, self.best_f1 = load_state( - self.component.model, self.component.optimizer, - self.component.scheduler, load_best=False, path=checkpoint_path, - relcat_config=rc_cnf) - - logger.info("Starting training process...") - - losses_per_epoch, accuracy_per_epoch, f1_per_epoch = load_results( - path=checkpoint_path) - - if train_rel_data.dataset["nclasses"] > rc_cnf.train.nclasses: - rc_cnf.train.nclasses = train_rel_data.dataset["nclasses"] - self.component.model.relcat_config.train.nclasses = ( - rc_cnf.train.nclasses) - - rc_cnf.general.labels2idx.update(train_rel_data.dataset["labels2idx"]) - rc_cnf.general.idx2labels = { - int(v): k for k, v in rc_cnf.general.labels2idx.items()} - - gradient_acc_steps = ( - rc_cnf.train.gradient_acc_steps) - max_grad_norm = rc_cnf.train.max_grad_norm - - _epochs = self.epoch + rc_cnf.train.nepochs - - for epoch in range(0, _epochs): - epoch_losses, epoch_precision, epoch_f1 = self._train_epoch( - epoch, gradient_acc_steps, max_grad_norm, train_dataset_size, - train_dataloader, test_dataloader, criterion, _epochs, - checkpoint_path) - losses_per_epoch.extend(epoch_losses) - accuracy_per_epoch.extend(epoch_precision) - f1_per_epoch.extend(epoch_f1) - - def _train_epoch(self, epoch: int, - gradient_acc_steps: int, - max_grad_norm: float, - train_dataset_size: int, - train_dataloader: DataLoader, - test_dataloader: DataLoader, - criterion: nn.CrossEntropyLoss, - _epochs: int, - checkpoint_path: str) -> tuple[list, list, list]: - rc_cnf = self.component.relcat_config - start_time = datetime.now().time() - total_loss = 0.0 - - loss_per_batch = [] - accuracy_per_batch = [] - - logger.info( - "Total epochs on this model: %d | currently training " - "epoch %d", _epochs, epoch) - - pbar = tqdm(total=train_dataset_size) - - for i, data in enumerate(train_dataloader, 0): - self.component.model.train() - self.component.model.zero_grad() - - current_batch_size = len(data[0]) - token_ids, e1_e2_start, labels, _, _ = data - - attention_mask = ( - token_ids != self.component.pad_id).float().to(self.device) - - token_type_ids = torch.zeros( - (token_ids.shape[0], token_ids.shape[1])).long().to( - self.device) - - labels = labels.to(self.device) - - model_output, classification_logits = self.component.model( - input_ids=token_ids, - token_type_ids=token_type_ids, - attention_mask=attention_mask, - e1_e2_start=e1_e2_start - ) - - batch_loss = criterion( - classification_logits.view( - -1, rc_cnf.train.nclasses).to(self.device), - labels.squeeze(1)) - - batch_loss.backward() - batch_loss = batch_loss / gradient_acc_steps - - total_loss += batch_loss.item() / current_batch_size - - (batch_acc, _, batch_precision, batch_f1, - _, _, batch_stats_per_label) = self.evaluate_( - classification_logits, labels, ignore_idx=-1) - - loss_per_batch.append(batch_loss / current_batch_size) - accuracy_per_batch.append(batch_acc) - - torch.nn.utils.clip_grad_norm_( - self.component.model.parameters(), max_grad_norm) - - if (i % gradient_acc_steps) == 0: - self.component.optimizer.step() - self.component.scheduler.step() - if ((i + 1) % current_batch_size == 0): - logger.debug( - "[Epoch: %d, loss per batch, accuracy per batch: %.3f," - " %.3f, average total loss %.3f , total loss %.3f]", - epoch, loss_per_batch[-1], accuracy_per_batch[-1], - total_loss / (i + 1), total_loss) - - pbar.update(current_batch_size) - - pbar.close() - - losses_per_epoch = [] - accuracy_per_epoch = [] - f1_per_epoch = [] - if len(loss_per_batch) > 0: - losses_per_epoch.append( - sum(loss_per_batch) / len(loss_per_batch)) - logger.info("Losses at Epoch %d: %.5f" % - (epoch, losses_per_epoch[-1])) - - if len(accuracy_per_batch) > 0: - accuracy_per_epoch.append( - sum(accuracy_per_batch) / len(accuracy_per_batch)) - logger.info("Train accuracy at Epoch %d: %.5f" % - (epoch, accuracy_per_epoch[-1])) - - total_loss = total_loss / (i + 1) - - end_time = datetime.now().time() - - logger.info( - "========================" - " TRAIN SET TEST RESULTS " - "========================") - _ = self.evaluate_results(train_dataloader, self.component.pad_id) - - logger.info( - "========================" - " TEST SET TEST RESULTS " - "========================") - results = self.evaluate_results( - test_dataloader, self.component.pad_id) - - f1_per_epoch.append(results['f1']) - - logger.info("Epoch finished, took %s seconds", - str(datetime.combine(date.today(), end_time) - - datetime.combine(date.today(), start_time))) - - self.epoch += 1 - - if len(f1_per_epoch) > 0 and f1_per_epoch[-1] > self.best_f1: - self.best_f1 = f1_per_epoch[-1] - save_state( - self.component.model, self.component.optimizer, - self.component.scheduler, self.epoch, self.best_f1, - checkpoint_path, model_name=rc_cnf.general.model_name, - task=self.task, is_checkpoint=False) - - if (epoch % 1) == 0: - save_results( - { - "losses_per_epoch": losses_per_epoch, - "accuracy_per_epoch": accuracy_per_epoch, - "f1_per_epoch": f1_per_epoch, - "epoch": epoch - }, file_prefix="train", path=checkpoint_path) - save_state(self.component.model, self.component.optimizer, - self.component.scheduler, self.epoch, self.best_f1, - checkpoint_path, - model_name=rc_cnf.general.model_name, - task=self.task, is_checkpoint=True) - return losses_per_epoch, accuracy_per_epoch, f1_per_epoch - - def evaluate_(self, output_logits, labels, ignore_idx): - # ignore index (padding) when calculating accuracy - idxs = (labels != ignore_idx).squeeze() - labels_ = labels.squeeze()[idxs].to(self.device) - pred_labels = torch.softmax(output_logits, dim=1).max(1)[1] - pred_labels = pred_labels[idxs].to(self.device) - - true_labels = labels_.cpu().numpy().tolist( - ) if labels_.is_cuda else labels_.numpy().tolist() - pred_labels = pred_labels.cpu().numpy().tolist( - ) if pred_labels.is_cuda else pred_labels.numpy().tolist() - - unique_labels = set(true_labels) - - batch_size = len(true_labels) - - stat_per_label = dict() - - total_tp, total_fp, total_tn, total_fn = 0, 0, 0, 0 - acc, micro_recall, micro_precision, micro_f1 = 0, 0, 0, 0 - - for label in unique_labels: - stat_per_label[label] = { - "tp": 0, "fp": 0, "tn": 0, "fn": 0, - "f1": 0.0, "acc": 0.0, "prec": 0.0, "recall": 0.0} - - for true_label_idx in range(len(true_labels)): - if true_labels[true_label_idx] == label: - if pred_labels[true_label_idx] == label: - stat_per_label[label]["tp"] += 1 - total_tp += 1 - if pred_labels[true_label_idx] != label: - stat_per_label[label]["fp"] += 1 - total_fp += 1 - elif (true_labels[true_label_idx] != label and - label == pred_labels[true_label_idx]): - stat_per_label[label]["fn"] += 1 - total_fn += 1 - else: - stat_per_label[label]["tn"] += 1 - total_tn += 1 - - lbl_tp_tn = stat_per_label[label]["tn"] + \ - stat_per_label[label]["tp"] - - lbl_tp_fn = stat_per_label[label]["fn"] + \ - stat_per_label[label]["tp"] - lbl_tp_fn = lbl_tp_fn if lbl_tp_fn > 0.0 else 1.0 - - lbl_tp_fp = stat_per_label[label]["tp"] + \ - stat_per_label[label]["fp"] - lbl_tp_fp = lbl_tp_fp if lbl_tp_fp > 0.0 else 1.0 - - stat_per_label[label]["acc"] = lbl_tp_tn / batch_size - stat_per_label[label]["prec"] = (stat_per_label[label]["tp"] / - lbl_tp_fp) - stat_per_label[label]["recall"] = (stat_per_label[label]["tp"] / - lbl_tp_fn) - - lbl_re_pr = stat_per_label[label]["recall"] + \ - stat_per_label[label]["prec"] - lbl_re_pr = lbl_re_pr if lbl_re_pr > 0.0 else 1.0 - - stat_per_label[label]["f1"] = ( - 2 * (stat_per_label[label]["recall"] * - stat_per_label[label]["prec"])) / lbl_re_pr - - tp_fn = total_fn + total_tp - tp_fn = tp_fn if tp_fn > 0.0 else 1.0 - - tp_fp = total_fp + total_tp - tp_fp = tp_fp if tp_fp > 0.0 else 1.0 - - micro_recall = total_tp / tp_fn - micro_precision = total_tp / tp_fp - - re_pr = micro_recall + micro_precision - re_pr = re_pr if re_pr > 0.0 else 1.0 - micro_f1 = (2 * (micro_recall * micro_precision)) / re_pr - - acc = total_tp / batch_size - - return (acc, micro_recall, micro_precision, micro_f1, - pred_labels, true_labels, stat_per_label) - - def evaluate_results(self, data_loader, pad_id): - logger.info("Evaluating test samples...") - rc_cnf = self.component.relcat_config - if (rc_cnf.train.class_weights is not None and - rc_cnf.train.enable_class_weights): - criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor( - rc_cnf.train.class_weights).to(self.device)) - else: - criterion = nn.CrossEntropyLoss() - - total_loss, total_acc, total_f1, total_recall, total_precision = ( - 0.0, 0.0, 0.0, 0.0, 0.0) - all_batch_stats_per_label = [] - - self.component.model.eval() - - for i, data in enumerate(data_loader): - with torch.no_grad(): - token_ids, e1_e2_start, labels, _, _ = data - attention_mask = (token_ids != pad_id).float().to(self.device) - token_type_ids = torch.zeros( - (*token_ids.shape[:2],)).long().to(self.device) - - labels = labels.to(self.device) - - model_output, pred_classification_logits = ( - self.component.model(token_ids, - token_type_ids=token_type_ids, - attention_mask=attention_mask, - Q=None, - e1_e2_start=e1_e2_start)) - - batch_loss = criterion(pred_classification_logits.view( - -1, rc_cnf.train.nclasses).to(self.device), - labels.squeeze(1)) - total_loss += batch_loss.item() - - (batch_accuracy, batch_recall, batch_precision, batch_f1, - pred_labels, true_labels, batch_stats_per_label) = ( - self.evaluate_(pred_classification_logits, - labels, ignore_idx=-1)) - - all_batch_stats_per_label.append(batch_stats_per_label) - - total_acc += batch_accuracy - total_recall += batch_recall - total_precision += batch_precision - total_f1 += batch_f1 - - final_stats_per_label = {} - - for batch_label_stats in all_batch_stats_per_label: - for label_id, stat_dict in batch_label_stats.items(): - - if label_id not in final_stats_per_label.keys(): - final_stats_per_label[label_id] = stat_dict - else: - for stat, score in stat_dict.items(): - final_stats_per_label[label_id][stat] += score - - for label_id, stat_dict in final_stats_per_label.items(): - for stat_name, value in stat_dict.items(): - final_stats_per_label[label_id][stat_name] = value / (i + 1) - - total_loss = total_loss / (i + 1) - total_acc = total_acc / (i + 1) - total_precision = total_precision / (i + 1) - total_f1 = total_f1 / (i + 1) - total_recall = total_recall / (i + 1) - - results = { - "loss": total_loss, - "accuracy": total_acc, - "precision": total_precision, - "recall": total_recall, - "f1": total_f1 - } - - logger.info("=" * 20 + " Evaluation Results " + "=" * 20) - logger.info(" no. of batches:" + str(i + 1)) - for key in sorted(results.keys()): - logger.info(" %s = %0.3f" % (key, results[key])) - logger.info("-" * 23 + " class stats " + "-" * 23) - for label_id, stat_dict in final_stats_per_label.items(): - logger.info( - "label: %s | f1: %0.3f | prec : %0.3f | acc: %0.3f | " - "recall: %0.3f ", - rc_cnf.general.idx2labels[label_id], - stat_dict["f1"], - stat_dict["prec"], - stat_dict["acc"], - stat_dict["recall"] - ) - logger.info("-" * 59) - logger.info("=" * 59) - - return results - - def pipe(self, stream: Iterable[MutableDocument], *args, **kwargs - ) -> Iterator[MutableDocument]: - rc_cnf = self.component.relcat_config - - predict_rel_dataset = RelData( - cdb=self.cdb, config=rc_cnf, - tokenizer=self.component.tokenizer) - - self.component.model = self.component.model.to(self.device) - - for doc_id, doc in enumerate(stream, 0): - predict_rel_dataset.dataset, _ = self._create_test_train_datasets( - data=predict_rel_dataset.create_base_relations_from_doc( - doc, doc_id=str(doc_id)), - split_sets=False) - - predict_dataloader = DataLoader( - dataset=predict_rel_dataset, shuffle=False, - batch_size=rc_cnf.train.batch_size, - num_workers=0, collate_fn=self.component.padding_seq, - pin_memory=rc_cnf.general.pin_memory) - - total_rel_found = len( - predict_rel_dataset.dataset["output_relations"]) - rel_idx = -1 - - logger.info("total relations for doc: " + str(total_rel_found)) - logger.info("processing...") - - pbar = tqdm(total=total_rel_found) - - for i, data in enumerate(predict_dataloader): - with torch.no_grad(): - token_ids, e1_e2_start, labels, _, _ = data - - attention_mask = ( - token_ids != self.component.pad_id - ).float().to(self.device) - token_type_ids = torch.zeros( - *token_ids.shape[:2]).long().to(self.device) - - (model_output, - pred_classification_logits) = self.component.model( - token_ids, token_type_ids=token_type_ids, - attention_mask=attention_mask, - e1_e2_start=e1_e2_start) - - for i, pred_rel_logits in enumerate( - pred_classification_logits): - rel_idx += 1 - - confidence = torch.softmax( - pred_rel_logits, dim=0).max(0) - predicted_label_id = int(confidence[1].item()) - - relations: list = doc.get_addon_data( # type: ignore - "relations") - out_rels = predict_rel_dataset.dataset[ - "output_relations"][rel_idx] - relations.append( - { - "relation": rc_cnf.general.idx2labels[ - predicted_label_id], - "label_id": predicted_label_id, - "ent1_text": out_rels[2], - "ent2_text": out_rels[3], - "confidence": float("{:.3f}".format( - confidence[0])), - "start_ent1_char_pos": out_rels[18], - "end_ent1_char_pos": out_rels[19], - "start_ent2_char_pos": out_rels[20], - "end_ent2_char_pos": out_rels[21], - "start_entity_id": out_rels[8], - "end_entity_id": out_rels[9], - }) - pbar.update(len(token_ids)) - pbar.close() - - yield doc - - def predict_text_with_anns(self, text: str, annotations: list[dict] - ) -> MutableDocument: - """ Creates spacy doc from text and annotation input. - Predicts using self.__call__ - - Args: - text (str): text - annotations (dict): dict containing the entities from NER - (of your choosing), the format must be the following format: - [ - { - "cui": "202099003", -this is optional - "value": "discoid lateral meniscus", - "start": 294, - "end": 318 - }, - { - "cui": "202099003", - "value": "Discoid lateral meniscus", - "start": 1905, - "end": 1929, - } - ] - - Returns: - Doc: spacy doc with the relations. - """ - # NOTE: This runs not an empty language, but the specified one - base_tokenizer = create_tokenizer( - self.cdb.config.general.nlp.provider, self.cdb.config) - doc = base_tokenizer(text) - - for ann in annotations: - tkn_idx = [] - for ind, word in enumerate(doc): - end_char = word.base.char_index + len(word.base.text) - if end_char <= ann['end'] and end_char > ann['start']: - tkn_idx.append(ind) - entity = base_tokenizer.create_entity( - doc, min(tkn_idx), max(tkn_idx) + 1, label=ann["value"]) - entity.cui = ann["cui"] - entity.set_addon_data('start', ann['start']) - entity.set_addon_data('end', ann['end']) - doc.ner_ents.append(entity) - - doc = self(doc) - - return doc +import json +import logging +import os +import random +from typing import Optional + +from sklearn.utils import compute_class_weight +import torch +import torch.nn as nn + +from tqdm import tqdm +from datetime import date, datetime +from typing import Iterable, Iterator, cast +from torch.utils.data import DataLoader, Sampler +from torch.optim import AdamW +from torch.optim.lr_scheduler import MultiStepLR +import numpy + +from medcat.cdb import CDB +from medcat.vocab import Vocab +from medcat.config.config import Config, ComponentConfig +from medcat.config.config_rel_cat import ConfigRelCAT +from medcat.storage.serialisers import deserialise +from medcat.storage.serialisables import SerialisingStrategy +from medcat.components.addons.addons import AddonComponent +from medcat.components.addons.relation_extraction.base_component import ( + RelExtrBaseComponent) +from medcat.components.addons.meta_cat.ml_utils import set_all_seeds +from medcat.components.addons.relation_extraction.ml_utils import ( + load_results, load_state, save_results, save_state, + split_list_train_test_by_class) +from medcat.components.addons.relation_extraction.rel_dataset import RelData +from medcat.tokenizing.tokenizers import BaseTokenizer, create_tokenizer +from medcat.tokenizing.tokens import MutableDocument +from medcat.utils.defaults import COMPONENTS_FOLDER +from medcat.utils.defaults import ( + avoid_legacy_conversion, doing_legacy_conversion_message, + LegacyConversionDisabledError) + + +logger = logging.getLogger(__name__) + + +class RelCATAddon(AddonComponent): + DEFAULT_TOKENIZER = 'spacy' + addon_type = 'rel_cat' + output_key = 'relations' + config: ConfigRelCAT + + def __init__(self, config: ConfigRelCAT, + rel_cat: "RelCAT"): + self.config = config + self._rel_cat = rel_cat + + @classmethod + def create_new(cls, config: ConfigRelCAT, base_tokenizer: BaseTokenizer, + cdb: CDB) -> 'RelCATAddon': + """Factory method to create a new MetaCATAddon instance.""" + return cls(config, + RelCAT(base_tokenizer, cdb, config=config, init_model=True)) + + @classmethod + def create_new_component( + cls, cnf: ComponentConfig, tokenizer: BaseTokenizer, + cdb: CDB, vocab: Vocab, model_load_path: Optional[str] + ) -> 'RelCATAddon': + if not isinstance(cnf, ConfigRelCAT): + raise ValueError(f"Incompatible config: {cnf}") + config = cnf + if model_load_path is not None: + load_path = os.path.join(model_load_path, COMPONENTS_FOLDER, + cls.NAME_PREFIX + cls.addon_type) + return cls.load_existing(config, tokenizer, cdb, load_path) + return cls.create_new(config, tokenizer, cdb) + + @classmethod + def load_existing(cls, cnf: ConfigRelCAT, + base_tokenizer: BaseTokenizer, + cdb: Optional[CDB], + load_path: str) -> 'RelCATAddon': + """Factory method to load an existing RelCAT addon from disk.""" + rc = RelCAT.load(load_path) + # set the correct base tokenizer and redo data paths + rc.base_tokenizer = base_tokenizer + rc._init_data_paths() + return cls(cnf, rc) + + def serialise_to(self, folder_path: str) -> None: + os.mkdir(folder_path) + self._rel_cat.save(folder_path) + + @property + def name(self) -> str: + return str(self.addon_type) + + # for ManualSerialisable: + + @classmethod + def _create_throwaway_tokenizer(cls) -> BaseTokenizer: + """ + Mirrors `MetaCATAddon._create_throwaway_tokenizer` + """ + logger.warning( + "A base tokenizer was not provided during the loading of a " + "RelCAT. The tokenizer is used to register the required data " + "paths for RelCAT to function. Using the default of '%s'. If " + "this it not the tokenizer you will end up using, RelCAT may " + "be unable to recover unless a) the paths are registered " + "explicitly, or b) there are other RelCATs created with the " + "correct tokenizer. Do note that this will also create " + "another instance of the tokenizer, though it should be " + "garbage collected soon.", cls.DEFAULT_TOKENIZER + ) + # NOTE: the use of a (mostly) default config here probably won't + # affect anything since the tokenizer itself won't be used + gcnf = Config() + gcnf.general.nlp.provider = 'spacy' + return create_tokenizer(cls.DEFAULT_TOKENIZER, gcnf) + + @classmethod + def deserialise_from(cls, folder_path: str, **init_kwargs + ) -> 'RelCATAddon': + """Deserialise a RelCAT addon from disk. + + Mirrors `MetaCATAddon.deserialise_from`: when called via the + pipeline, `tokenizer`/`cnf` are supplied; when called standalone + (e.g. `CAT.load_addons`), they are inferred from disk so that + deserialisation works without full pipeline context. + """ + if "config.json" in os.listdir(folder_path): + if not avoid_legacy_conversion(): + doing_legacy_conversion_message( + logger, cls.__name__, folder_path) + from medcat.utils.legacy.convert_rel_cat import ( + get_rel_cat_from_old) + if 'cdb' in init_kwargs: + cdb = init_kwargs['cdb'] + else: + cdb_path = os.path.join(folder_path, "cdb.dat") + if os.path.exists(cdb_path): + cdb = cast(CDB, deserialise(cdb_path)) + else: + cdb = CDB(config=Config()) + if 'tokenizer' in init_kwargs: + tokenizer = init_kwargs['tokenizer'] + else: + tokenizer = cls._create_throwaway_tokenizer() + return get_rel_cat_from_old(cdb, folder_path, tokenizer) + raise LegacyConversionDisabledError(cls.__name__,) + if 'cnf' in init_kwargs: + cnf = init_kwargs['cnf'] + else: + config_path = os.path.join(folder_path, "config") + logger.info( + "Was not provided a config when loading a rel cat from '%s'. " + "Inferring config from file at '%s'", folder_path, + config_path) + cnf = ConfigRelCAT.load(load_path=folder_path) + if 'model_config' in init_kwargs: + cnf.merge_config(init_kwargs['model_config']) + if 'tokenizer' in init_kwargs: + tokenizer = init_kwargs['tokenizer'] + else: + tokenizer = cls._create_throwaway_tokenizer() + return cls.load_existing( + load_path=folder_path, + cnf=cnf, + base_tokenizer=tokenizer, + cdb=None) + + def get_strategy(self) -> SerialisingStrategy: + return SerialisingStrategy.MANUAL + + @classmethod + def get_init_attrs(cls) -> list[str]: + return [] + + @classmethod + def ignore_attrs(cls) -> list[str]: + return [] + + @classmethod + def include_properties(cls) -> list[str]: + return [] + + def __call__(self, doc: MutableDocument): + return self._rel_cat(doc) + + +class BalancedBatchSampler(Sampler): + + def __init__(self, dataset, classes, + batch_size, max_samples, max_minority): + self.dataset = dataset + self.classes = classes + self.batch_size = batch_size + self.num_classes = len(classes) + self.indices = list(range(len(dataset))) + + self.max_minority = max_minority + + self.max_samples_per_class = max_samples + + def __len__(self): + return (len(self.indices) + self.batch_size - 1) // self.batch_size + + def __iter__(self): + batch_counter = 0 + indices = self.indices.copy() + while batch_counter != self.__len__(): + batch = [] + + class_counts = {c: 0 for c in self.classes} + while len(batch) < self.batch_size: + + index = random.choice(indices) + # Assuming label is at index 1 + label = self.dataset[index][2].numpy().tolist()[0] + if class_counts[label] < self.max_samples_per_class[label]: + batch.append(index) + class_counts[label] += 1 + if self.max_samples_per_class[label] > self.max_minority: + indices.remove(index) + + yield batch + batch_counter += 1 + + +class RelCAT: + """The RelCAT class used for training 'Relation-Annotation' models, i.e., + annotation of relations between clinical concepts. + + Args: + cdb (CDB): cdb, this is used when creating relation datasets. + + tokenizer (TokenizerWrapperBERT): + The Huggingface tokenizer instance. This can be a pre-trained + tokenzier instance from a BERT-style model. For now, only + BERT models are supported. + + config (ConfigRelCAT): + the configuration for RelCAT. Param descriptions available in + ConfigRelCAT class docs. + + task (str, optional): What task is this model supposed to handle. + Defaults to "train" + init_model (bool, optional): loads default model. Defaults to False. + + """ + addon_type = 'rel_cat' + output_key = 'rel_' + + def __init__(self, base_tokenizer: BaseTokenizer, + cdb: CDB, config: ConfigRelCAT = ConfigRelCAT(), + task: str = "train", init_model: bool = False): + self.base_tokenizer = base_tokenizer + self.component = RelExtrBaseComponent() + self.task: str = task + self.checkpoint_path: str = "./" + + set_all_seeds(config.general.seed) + + if init_model: + self.component = RelExtrBaseComponent( + config=config, task=task, init_model=True) + + self.cdb = cdb + logging.basicConfig( + level=self.component.relcat_config.general.log_level) + logger.setLevel(self.component.relcat_config.general.log_level) + + self.is_cuda_available = torch.cuda.is_available() + self.device = torch.device( + "cuda" if self.is_cuda_available and + self.component.relcat_config.general.device != "cpu" else "cpu") + self._init_data_paths() + + def _init_data_paths(self): + doc_cls = self.base_tokenizer.get_doc_class() + doc_cls.register_addon_path('relations', def_val=[], force=True) + entity_cls = self.base_tokenizer.get_entity_class() + entity_cls.register_addon_path('start', def_val=None, force=True) + entity_cls.register_addon_path('end', def_val=None, force=True) + + def save(self, save_path: str = "./") -> None: + self.component.save(save_path=save_path) + + @classmethod + def load(cls, load_path: str = "./") -> "RelCAT": + + if os.path.exists(os.path.join(load_path, "cdb.dat")): + cdb = cast(CDB, deserialise(os.path.join(load_path, "cdb.dat"))) + else: + cdb = CDB(config=Config()) + logger.info( + "The default CDB file name 'cdb.dat' doesn't exist in the " + "specified path, you will need to load & set " + "a CDB manually via rel_cat.cdb = CDB.load('path') ") + + component = RelExtrBaseComponent.load( + pretrained_model_name_or_path=load_path) + + device = torch.device( + "cuda" if torch.cuda.is_available() and + component.relcat_config.general.device != "cpu" else "cpu") + + rel_cat = RelCAT( + # NOTE: this is a throaway tokenizer just for registrations + create_tokenizer(cdb.config.general.nlp.provider, cdb.config), + cdb=cdb, config=component.relcat_config, task=component.task) + rel_cat.device = device + rel_cat.component = component + + return rel_cat + + def __call__(self, doc: MutableDocument) -> MutableDocument: + doc = next(self.pipe(iter([doc]))) + return doc + + def _create_test_train_datasets(self, data: dict, + split_sets: bool = False): + train_data: dict = {} + test_data: dict = {} + + if split_sets: + rc_cnf = self.component.relcat_config + (train_data["output_relations"], + test_data["output_relations"]) = split_list_train_test_by_class( + data["output_relations"], + test_size=rc_cnf.train.test_size, + shuffle=rc_cnf.train.shuffle_data, + sample_limit=rc_cnf.general.limit_samples_per_class) + + test_data_label_names = [ + rec[4] for rec in test_data["output_relations"]] + + (test_data["nclasses"], test_data["labels2idx"], + test_data["idx2label"]) = RelData.get_labels( + test_data_label_names, self.component.relcat_config) + + for idx in range(len(test_data["output_relations"])): + test_data["output_relations" + ][idx][5] = test_data["labels2idx"][ + test_data["output_relations"][idx][4]] + else: + train_data["output_relations"] = data["output_relations"] + + for k, v in data.items(): + if k != "output_relations": + train_data[k] = [] + test_data[k] = [] + + train_data_label_names = [rec[4] + for rec in train_data["output_relations"]] + + (train_data["nclasses"], train_data["labels2idx"], + train_data["idx2label"]) = RelData.get_labels( + train_data_label_names, self.component.relcat_config) + + for idx in range(len(train_data["output_relations"])): + train_data["output_relations" + ][idx][5] = train_data["labels2idx"][ + train_data["output_relations"][idx][4]] + + return train_data, test_data + + def train(self, export_data_path: str = "", train_csv_path: str = "", + test_csv_path: str = "", checkpoint_path: str = "./"): + + if self.is_cuda_available: + logger.info("Training on device: %s%s", + str(torch.cuda.get_device_name(0)), str(self.device)) + + self.component.model = self.component.model.to(self.device) + + rc_cnf = self.component.relcat_config + + # resize vocab just in case more tokens have been added + self.component.model_config.vocab_size = ( + self.component.tokenizer.get_size()) + + train_rel_data = RelData( + cdb=self.cdb, config=rc_cnf, + tokenizer=self.component.tokenizer) + test_rel_data = RelData( + cdb=self.cdb, config=rc_cnf, + tokenizer=self.component.tokenizer) + + if train_csv_path != "": + if test_csv_path != "": + train_rel_data.dataset, _ = self._create_test_train_datasets( + train_rel_data.create_base_relations_from_csv( + train_csv_path), split_sets=False) + test_rel_data.dataset, _ = self._create_test_train_datasets( + train_rel_data.create_base_relations_from_csv( + test_csv_path), split_sets=False) + else: + (train_rel_data.dataset, + test_rel_data.dataset) = self._create_test_train_datasets( + train_rel_data.create_base_relations_from_csv( + train_csv_path), split_sets=True) + + elif export_data_path != "": + export_data = {} + with open(export_data_path) as f: + export_data = json.load(f) + (train_rel_data.dataset, + test_rel_data.dataset) = self._create_test_train_datasets( + train_rel_data.create_relations_from_export(export_data), + split_sets=True) + else: + raise ValueError( + "NO DATA HAS BEEN PROVIDED (MedCAT Trainer export " + "JSON/CSV/spacy_DOCS)") + + train_dataset_size = len(train_rel_data) + batch_size = ( + train_dataset_size if train_dataset_size < rc_cnf.train.batch_size + else rc_cnf.train.batch_size) + + # to use stratified batching + if rc_cnf.train.stratified_batching: + sampler = BalancedBatchSampler( + train_rel_data, [ + i for i in + range(rc_cnf.train.nclasses)], + batch_size, + rc_cnf.train.batching_samples_per_class, + rc_cnf.train.batching_minority_limit) + + train_dataloader = DataLoader( + train_rel_data, num_workers=0, + collate_fn=self.component.padding_seq, + batch_sampler=sampler, + pin_memory=rc_cnf.general.pin_memory) + else: + train_dataloader = DataLoader( + train_rel_data, batch_size=batch_size, + shuffle=rc_cnf.train.shuffle_data, + num_workers=0, + collate_fn=self.component.padding_seq, + pin_memory=rc_cnf.general.pin_memory) + + test_dataset_size = len(test_rel_data) + test_batch_size = ( + test_dataset_size if + test_dataset_size < rc_cnf.train.batch_size + else rc_cnf.train.batch_size) + test_dataloader = DataLoader( + test_rel_data, + batch_size=test_batch_size, + shuffle=rc_cnf.train.shuffle_data, + num_workers=0, + collate_fn=self.component.padding_seq, + pin_memory=rc_cnf.general.pin_memory) + + if (rc_cnf.train.class_weights is not None and + rc_cnf.train.enable_class_weights): + criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor( + numpy.asarray(rc_cnf.train.class_weights) + ).to(self.device)) + elif rc_cnf.train.enable_class_weights: + all_class_lbl_ids = [ + rec[5] for rec in train_rel_data.dataset["output_relations"]] + rc_cnf.train.class_weights = ( + compute_class_weight(class_weight="balanced", + classes=numpy.unique(all_class_lbl_ids), + y=all_class_lbl_ids).tolist()) + criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor( + rc_cnf.train.class_weights).to( + self.device)) + else: + criterion = nn.CrossEntropyLoss() + + if self.component.optimizer is None: + parameters = filter(lambda p: p.requires_grad, + self.component.model.parameters()) + self.component.optimizer = AdamW( + parameters, lr=self.component.relcat_config.train.lr, + weight_decay=rc_cnf.train.adam_weight_decay, + betas=rc_cnf.train.adam_betas, eps=rc_cnf.train.adam_epsilon) + + if self.component.scheduler is None: + self.component.scheduler = MultiStepLR( + self.component.optimizer, + milestones=rc_cnf.train.multistep_milestones, + gamma=rc_cnf.train.multistep_lr_gamma) + + self.epoch, self.best_f1 = load_state( + self.component.model, self.component.optimizer, + self.component.scheduler, load_best=False, path=checkpoint_path, + relcat_config=rc_cnf) + + logger.info("Starting training process...") + + losses_per_epoch, accuracy_per_epoch, f1_per_epoch = load_results( + path=checkpoint_path) + + if train_rel_data.dataset["nclasses"] > rc_cnf.train.nclasses: + rc_cnf.train.nclasses = train_rel_data.dataset["nclasses"] + self.component.model.relcat_config.train.nclasses = ( + rc_cnf.train.nclasses) + + rc_cnf.general.labels2idx.update(train_rel_data.dataset["labels2idx"]) + rc_cnf.general.idx2labels = { + int(v): k for k, v in rc_cnf.general.labels2idx.items()} + + gradient_acc_steps = ( + rc_cnf.train.gradient_acc_steps) + max_grad_norm = rc_cnf.train.max_grad_norm + + _epochs = self.epoch + rc_cnf.train.nepochs + + for epoch in range(0, _epochs): + epoch_losses, epoch_precision, epoch_f1 = self._train_epoch( + epoch, gradient_acc_steps, max_grad_norm, train_dataset_size, + train_dataloader, test_dataloader, criterion, _epochs, + checkpoint_path) + losses_per_epoch.extend(epoch_losses) + accuracy_per_epoch.extend(epoch_precision) + f1_per_epoch.extend(epoch_f1) + + def _train_epoch(self, epoch: int, + gradient_acc_steps: int, + max_grad_norm: float, + train_dataset_size: int, + train_dataloader: DataLoader, + test_dataloader: DataLoader, + criterion: nn.CrossEntropyLoss, + _epochs: int, + checkpoint_path: str) -> tuple[list, list, list]: + rc_cnf = self.component.relcat_config + start_time = datetime.now().time() + total_loss = 0.0 + + loss_per_batch = [] + accuracy_per_batch = [] + + logger.info( + "Total epochs on this model: %d | currently training " + "epoch %d", _epochs, epoch) + + pbar = tqdm(total=train_dataset_size) + + for i, data in enumerate(train_dataloader, 0): + self.component.model.train() + self.component.model.zero_grad() + + current_batch_size = len(data[0]) + token_ids, e1_e2_start, labels, _, _ = data + + attention_mask = ( + token_ids != self.component.pad_id).float().to(self.device) + + token_type_ids = torch.zeros( + (token_ids.shape[0], token_ids.shape[1])).long().to( + self.device) + + labels = labels.to(self.device) + + model_output, classification_logits = self.component.model( + input_ids=token_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + e1_e2_start=e1_e2_start + ) + + batch_loss = criterion( + classification_logits.view( + -1, rc_cnf.train.nclasses).to(self.device), + labels.squeeze(1)) + + batch_loss.backward() + batch_loss = batch_loss / gradient_acc_steps + + total_loss += batch_loss.item() / current_batch_size + + (batch_acc, _, batch_precision, batch_f1, + _, _, batch_stats_per_label) = self.evaluate_( + classification_logits, labels, ignore_idx=-1) + + loss_per_batch.append(batch_loss / current_batch_size) + accuracy_per_batch.append(batch_acc) + + torch.nn.utils.clip_grad_norm_( + self.component.model.parameters(), max_grad_norm) + + if (i % gradient_acc_steps) == 0: + self.component.optimizer.step() + self.component.scheduler.step() + if ((i + 1) % current_batch_size == 0): + logger.debug( + "[Epoch: %d, loss per batch, accuracy per batch: %.3f," + " %.3f, average total loss %.3f , total loss %.3f]", + epoch, loss_per_batch[-1], accuracy_per_batch[-1], + total_loss / (i + 1), total_loss) + + pbar.update(current_batch_size) + + pbar.close() + + losses_per_epoch = [] + accuracy_per_epoch = [] + f1_per_epoch = [] + if len(loss_per_batch) > 0: + losses_per_epoch.append( + sum(loss_per_batch) / len(loss_per_batch)) + logger.info("Losses at Epoch %d: %.5f" % + (epoch, losses_per_epoch[-1])) + + if len(accuracy_per_batch) > 0: + accuracy_per_epoch.append( + sum(accuracy_per_batch) / len(accuracy_per_batch)) + logger.info("Train accuracy at Epoch %d: %.5f" % + (epoch, accuracy_per_epoch[-1])) + + total_loss = total_loss / (i + 1) + + end_time = datetime.now().time() + + logger.info( + "========================" + " TRAIN SET TEST RESULTS " + "========================") + _ = self.evaluate_results(train_dataloader, self.component.pad_id) + + logger.info( + "========================" + " TEST SET TEST RESULTS " + "========================") + results = self.evaluate_results( + test_dataloader, self.component.pad_id) + + f1_per_epoch.append(results['f1']) + + logger.info("Epoch finished, took %s seconds", + str(datetime.combine(date.today(), end_time) + - datetime.combine(date.today(), start_time))) + + self.epoch += 1 + + if len(f1_per_epoch) > 0 and f1_per_epoch[-1] > self.best_f1: + self.best_f1 = f1_per_epoch[-1] + save_state( + self.component.model, self.component.optimizer, + self.component.scheduler, self.epoch, self.best_f1, + checkpoint_path, model_name=rc_cnf.general.model_name, + task=self.task, is_checkpoint=False) + + if (epoch % 1) == 0: + save_results( + { + "losses_per_epoch": losses_per_epoch, + "accuracy_per_epoch": accuracy_per_epoch, + "f1_per_epoch": f1_per_epoch, + "epoch": epoch + }, file_prefix="train", path=checkpoint_path) + save_state(self.component.model, self.component.optimizer, + self.component.scheduler, self.epoch, self.best_f1, + checkpoint_path, + model_name=rc_cnf.general.model_name, + task=self.task, is_checkpoint=True) + return losses_per_epoch, accuracy_per_epoch, f1_per_epoch + + def evaluate_(self, output_logits, labels, ignore_idx): + # ignore index (padding) when calculating accuracy + idxs = (labels != ignore_idx).squeeze() + labels_ = labels.squeeze()[idxs].to(self.device) + pred_labels = torch.softmax(output_logits, dim=1).max(1)[1] + pred_labels = pred_labels[idxs].to(self.device) + + true_labels = labels_.cpu().numpy().tolist( + ) if labels_.is_cuda else labels_.numpy().tolist() + pred_labels = pred_labels.cpu().numpy().tolist( + ) if pred_labels.is_cuda else pred_labels.numpy().tolist() + + unique_labels = set(true_labels) + + batch_size = len(true_labels) + + stat_per_label = dict() + + total_tp, total_fp, total_tn, total_fn = 0, 0, 0, 0 + acc, micro_recall, micro_precision, micro_f1 = 0, 0, 0, 0 + + for label in unique_labels: + stat_per_label[label] = { + "tp": 0, "fp": 0, "tn": 0, "fn": 0, + "f1": 0.0, "acc": 0.0, "prec": 0.0, "recall": 0.0} + + for true_label_idx in range(len(true_labels)): + if true_labels[true_label_idx] == label: + if pred_labels[true_label_idx] == label: + stat_per_label[label]["tp"] += 1 + total_tp += 1 + if pred_labels[true_label_idx] != label: + stat_per_label[label]["fp"] += 1 + total_fp += 1 + elif (true_labels[true_label_idx] != label and + label == pred_labels[true_label_idx]): + stat_per_label[label]["fn"] += 1 + total_fn += 1 + else: + stat_per_label[label]["tn"] += 1 + total_tn += 1 + + lbl_tp_tn = stat_per_label[label]["tn"] + \ + stat_per_label[label]["tp"] + + lbl_tp_fn = stat_per_label[label]["fn"] + \ + stat_per_label[label]["tp"] + lbl_tp_fn = lbl_tp_fn if lbl_tp_fn > 0.0 else 1.0 + + lbl_tp_fp = stat_per_label[label]["tp"] + \ + stat_per_label[label]["fp"] + lbl_tp_fp = lbl_tp_fp if lbl_tp_fp > 0.0 else 1.0 + + stat_per_label[label]["acc"] = lbl_tp_tn / batch_size + stat_per_label[label]["prec"] = (stat_per_label[label]["tp"] / + lbl_tp_fp) + stat_per_label[label]["recall"] = (stat_per_label[label]["tp"] / + lbl_tp_fn) + + lbl_re_pr = stat_per_label[label]["recall"] + \ + stat_per_label[label]["prec"] + lbl_re_pr = lbl_re_pr if lbl_re_pr > 0.0 else 1.0 + + stat_per_label[label]["f1"] = ( + 2 * (stat_per_label[label]["recall"] * + stat_per_label[label]["prec"])) / lbl_re_pr + + tp_fn = total_fn + total_tp + tp_fn = tp_fn if tp_fn > 0.0 else 1.0 + + tp_fp = total_fp + total_tp + tp_fp = tp_fp if tp_fp > 0.0 else 1.0 + + micro_recall = total_tp / tp_fn + micro_precision = total_tp / tp_fp + + re_pr = micro_recall + micro_precision + re_pr = re_pr if re_pr > 0.0 else 1.0 + micro_f1 = (2 * (micro_recall * micro_precision)) / re_pr + + acc = total_tp / batch_size + + return (acc, micro_recall, micro_precision, micro_f1, + pred_labels, true_labels, stat_per_label) + + def evaluate_results(self, data_loader, pad_id): + logger.info("Evaluating test samples...") + rc_cnf = self.component.relcat_config + if (rc_cnf.train.class_weights is not None and + rc_cnf.train.enable_class_weights): + criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor( + rc_cnf.train.class_weights).to(self.device)) + else: + criterion = nn.CrossEntropyLoss() + + total_loss, total_acc, total_f1, total_recall, total_precision = ( + 0.0, 0.0, 0.0, 0.0, 0.0) + all_batch_stats_per_label = [] + + self.component.model.eval() + + for i, data in enumerate(data_loader): + with torch.no_grad(): + token_ids, e1_e2_start, labels, _, _ = data + attention_mask = (token_ids != pad_id).float().to(self.device) + token_type_ids = torch.zeros( + (*token_ids.shape[:2],)).long().to(self.device) + + labels = labels.to(self.device) + + model_output, pred_classification_logits = ( + self.component.model(token_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + Q=None, + e1_e2_start=e1_e2_start)) + + batch_loss = criterion(pred_classification_logits.view( + -1, rc_cnf.train.nclasses).to(self.device), + labels.squeeze(1)) + total_loss += batch_loss.item() + + (batch_accuracy, batch_recall, batch_precision, batch_f1, + pred_labels, true_labels, batch_stats_per_label) = ( + self.evaluate_(pred_classification_logits, + labels, ignore_idx=-1)) + + all_batch_stats_per_label.append(batch_stats_per_label) + + total_acc += batch_accuracy + total_recall += batch_recall + total_precision += batch_precision + total_f1 += batch_f1 + + final_stats_per_label = {} + + for batch_label_stats in all_batch_stats_per_label: + for label_id, stat_dict in batch_label_stats.items(): + + if label_id not in final_stats_per_label.keys(): + final_stats_per_label[label_id] = stat_dict + else: + for stat, score in stat_dict.items(): + final_stats_per_label[label_id][stat] += score + + for label_id, stat_dict in final_stats_per_label.items(): + for stat_name, value in stat_dict.items(): + final_stats_per_label[label_id][stat_name] = value / (i + 1) + + total_loss = total_loss / (i + 1) + total_acc = total_acc / (i + 1) + total_precision = total_precision / (i + 1) + total_f1 = total_f1 / (i + 1) + total_recall = total_recall / (i + 1) + + results = { + "loss": total_loss, + "accuracy": total_acc, + "precision": total_precision, + "recall": total_recall, + "f1": total_f1 + } + + logger.info("=" * 20 + " Evaluation Results " + "=" * 20) + logger.info(" no. of batches:" + str(i + 1)) + for key in sorted(results.keys()): + logger.info(" %s = %0.3f" % (key, results[key])) + logger.info("-" * 23 + " class stats " + "-" * 23) + for label_id, stat_dict in final_stats_per_label.items(): + logger.info( + "label: %s | f1: %0.3f | prec : %0.3f | acc: %0.3f | " + "recall: %0.3f ", + rc_cnf.general.idx2labels[label_id], + stat_dict["f1"], + stat_dict["prec"], + stat_dict["acc"], + stat_dict["recall"] + ) + logger.info("-" * 59) + logger.info("=" * 59) + + return results + + def pipe(self, stream: Iterable[MutableDocument], *args, **kwargs + ) -> Iterator[MutableDocument]: + rc_cnf = self.component.relcat_config + + predict_rel_dataset = RelData( + cdb=self.cdb, config=rc_cnf, + tokenizer=self.component.tokenizer) + + self.component.model = self.component.model.to(self.device) + + for doc_id, doc in enumerate(stream, 0): + predict_rel_dataset.dataset, _ = self._create_test_train_datasets( + data=predict_rel_dataset.create_base_relations_from_doc( + doc, doc_id=str(doc_id)), + split_sets=False) + + predict_dataloader = DataLoader( + dataset=predict_rel_dataset, shuffle=False, + batch_size=rc_cnf.train.batch_size, + num_workers=0, collate_fn=self.component.padding_seq, + pin_memory=rc_cnf.general.pin_memory) + + total_rel_found = len( + predict_rel_dataset.dataset["output_relations"]) + rel_idx = -1 + + logger.info("total relations for doc: " + str(total_rel_found)) + logger.info("processing...") + + pbar = tqdm(total=total_rel_found) + + for i, data in enumerate(predict_dataloader): + with torch.no_grad(): + token_ids, e1_e2_start, labels, _, _ = data + + attention_mask = ( + token_ids != self.component.pad_id + ).float().to(self.device) + token_type_ids = torch.zeros( + *token_ids.shape[:2]).long().to(self.device) + + (model_output, + pred_classification_logits) = self.component.model( + token_ids, token_type_ids=token_type_ids, + attention_mask=attention_mask, + e1_e2_start=e1_e2_start) + + for i, pred_rel_logits in enumerate( + pred_classification_logits): + rel_idx += 1 + + confidence = torch.softmax( + pred_rel_logits, dim=0).max(0) + predicted_label_id = int(confidence[1].item()) + + relations: list = doc.get_addon_data( # type: ignore + "relations") + out_rels = predict_rel_dataset.dataset[ + "output_relations"][rel_idx] + relations.append( + { + "relation": rc_cnf.general.idx2labels[ + predicted_label_id], + "label_id": predicted_label_id, + "ent1_text": out_rels[2], + "ent2_text": out_rels[3], + "confidence": float("{:.3f}".format( + confidence[0])), + "start_ent1_char_pos": out_rels[18], + "end_ent1_char_pos": out_rels[19], + "start_ent2_char_pos": out_rels[20], + "end_ent2_char_pos": out_rels[21], + "start_entity_id": out_rels[8], + "end_entity_id": out_rels[9], + }) + pbar.update(len(token_ids)) + pbar.close() + + yield doc + + def predict_text_with_anns(self, text: str, annotations: list[dict] + ) -> MutableDocument: + """ Creates spacy doc from text and annotation input. + Predicts using self.__call__ + + Args: + text (str): text + annotations (dict): dict containing the entities from NER + (of your choosing), the format must be the following format: + [ + { + "cui": "202099003", -this is optional + "value": "discoid lateral meniscus", + "start": 294, + "end": 318 + }, + { + "cui": "202099003", + "value": "Discoid lateral meniscus", + "start": 1905, + "end": 1929, + } + ] + + Returns: + Doc: spacy doc with the relations. + """ + # NOTE: This runs not an empty language, but the specified one + base_tokenizer = create_tokenizer( + self.cdb.config.general.nlp.provider, self.cdb.config) + doc = base_tokenizer(text) + + for ann in annotations: + tkn_idx = [] + for ind, word in enumerate(doc): + end_char = word.base.char_index + len(word.base.text) + if end_char <= ann['end'] and end_char > ann['start']: + tkn_idx.append(ind) + entity = base_tokenizer.create_entity( + doc, min(tkn_idx), max(tkn_idx) + 1, label=ann["value"]) + entity.cui = ann["cui"] + entity.set_addon_data('start', ann['start']) + entity.set_addon_data('end', ann['end']) + doc.ner_ents.append(entity) + + doc = self(doc) + + return doc From 340ac31afd637b67480c95915191a2dedfd65b9b Mon Sep 17 00:00:00 2001 From: alhendrickson Date: Tue, 16 Jun 2026 10:15:34 +0000 Subject: [PATCH 8/9] Revert "fix(medcat): Relcat addon fix deserialise - make the same as metacataddon" This reverts commit f9326d8b387e9af2dd7303fbded645eb5b3aa7d0. --- .../components/addons/meta_cat/meta_cat.py | 2 +- .../addons/relation_extraction/rel_cat.py | 1927 ++++++++--------- 2 files changed, 953 insertions(+), 976 deletions(-) diff --git a/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py b/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py index 475eb8291..05fb4a185 100644 --- a/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py +++ b/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py @@ -223,7 +223,7 @@ def deserialise_from(cls, folder_path: str, **init_kwargs # load legacy config (assuming it exists) config_path += ".dat" logger.info( - "Was not provided a config when loading a meta cat from '%s'. " + "Was not provide a config when loading a meta cat from '%s'. " "Inferring config from file at '%s'", folder_path, config_path) cnf = ConfigMetaCAT.load(config_path) diff --git a/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py b/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py index 585b8518e..0e6d099a6 100644 --- a/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py +++ b/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py @@ -1,975 +1,952 @@ -import json -import logging -import os -import random -from typing import Optional - -from sklearn.utils import compute_class_weight -import torch -import torch.nn as nn - -from tqdm import tqdm -from datetime import date, datetime -from typing import Iterable, Iterator, cast -from torch.utils.data import DataLoader, Sampler -from torch.optim import AdamW -from torch.optim.lr_scheduler import MultiStepLR -import numpy - -from medcat.cdb import CDB -from medcat.vocab import Vocab -from medcat.config.config import Config, ComponentConfig -from medcat.config.config_rel_cat import ConfigRelCAT -from medcat.storage.serialisers import deserialise -from medcat.storage.serialisables import SerialisingStrategy -from medcat.components.addons.addons import AddonComponent -from medcat.components.addons.relation_extraction.base_component import ( - RelExtrBaseComponent) -from medcat.components.addons.meta_cat.ml_utils import set_all_seeds -from medcat.components.addons.relation_extraction.ml_utils import ( - load_results, load_state, save_results, save_state, - split_list_train_test_by_class) -from medcat.components.addons.relation_extraction.rel_dataset import RelData -from medcat.tokenizing.tokenizers import BaseTokenizer, create_tokenizer -from medcat.tokenizing.tokens import MutableDocument -from medcat.utils.defaults import COMPONENTS_FOLDER -from medcat.utils.defaults import ( - avoid_legacy_conversion, doing_legacy_conversion_message, - LegacyConversionDisabledError) - - -logger = logging.getLogger(__name__) - - -class RelCATAddon(AddonComponent): - DEFAULT_TOKENIZER = 'spacy' - addon_type = 'rel_cat' - output_key = 'relations' - config: ConfigRelCAT - - def __init__(self, config: ConfigRelCAT, - rel_cat: "RelCAT"): - self.config = config - self._rel_cat = rel_cat - - @classmethod - def create_new(cls, config: ConfigRelCAT, base_tokenizer: BaseTokenizer, - cdb: CDB) -> 'RelCATAddon': - """Factory method to create a new MetaCATAddon instance.""" - return cls(config, - RelCAT(base_tokenizer, cdb, config=config, init_model=True)) - - @classmethod - def create_new_component( - cls, cnf: ComponentConfig, tokenizer: BaseTokenizer, - cdb: CDB, vocab: Vocab, model_load_path: Optional[str] - ) -> 'RelCATAddon': - if not isinstance(cnf, ConfigRelCAT): - raise ValueError(f"Incompatible config: {cnf}") - config = cnf - if model_load_path is not None: - load_path = os.path.join(model_load_path, COMPONENTS_FOLDER, - cls.NAME_PREFIX + cls.addon_type) - return cls.load_existing(config, tokenizer, cdb, load_path) - return cls.create_new(config, tokenizer, cdb) - - @classmethod - def load_existing(cls, cnf: ConfigRelCAT, - base_tokenizer: BaseTokenizer, - cdb: Optional[CDB], - load_path: str) -> 'RelCATAddon': - """Factory method to load an existing RelCAT addon from disk.""" - rc = RelCAT.load(load_path) - # set the correct base tokenizer and redo data paths - rc.base_tokenizer = base_tokenizer - rc._init_data_paths() - return cls(cnf, rc) - - def serialise_to(self, folder_path: str) -> None: - os.mkdir(folder_path) - self._rel_cat.save(folder_path) - - @property - def name(self) -> str: - return str(self.addon_type) - - # for ManualSerialisable: - - @classmethod - def _create_throwaway_tokenizer(cls) -> BaseTokenizer: - """ - Mirrors `MetaCATAddon._create_throwaway_tokenizer` - """ - logger.warning( - "A base tokenizer was not provided during the loading of a " - "RelCAT. The tokenizer is used to register the required data " - "paths for RelCAT to function. Using the default of '%s'. If " - "this it not the tokenizer you will end up using, RelCAT may " - "be unable to recover unless a) the paths are registered " - "explicitly, or b) there are other RelCATs created with the " - "correct tokenizer. Do note that this will also create " - "another instance of the tokenizer, though it should be " - "garbage collected soon.", cls.DEFAULT_TOKENIZER - ) - # NOTE: the use of a (mostly) default config here probably won't - # affect anything since the tokenizer itself won't be used - gcnf = Config() - gcnf.general.nlp.provider = 'spacy' - return create_tokenizer(cls.DEFAULT_TOKENIZER, gcnf) - - @classmethod - def deserialise_from(cls, folder_path: str, **init_kwargs - ) -> 'RelCATAddon': - """Deserialise a RelCAT addon from disk. - - Mirrors `MetaCATAddon.deserialise_from`: when called via the - pipeline, `tokenizer`/`cnf` are supplied; when called standalone - (e.g. `CAT.load_addons`), they are inferred from disk so that - deserialisation works without full pipeline context. - """ - if "config.json" in os.listdir(folder_path): - if not avoid_legacy_conversion(): - doing_legacy_conversion_message( - logger, cls.__name__, folder_path) - from medcat.utils.legacy.convert_rel_cat import ( - get_rel_cat_from_old) - if 'cdb' in init_kwargs: - cdb = init_kwargs['cdb'] - else: - cdb_path = os.path.join(folder_path, "cdb.dat") - if os.path.exists(cdb_path): - cdb = cast(CDB, deserialise(cdb_path)) - else: - cdb = CDB(config=Config()) - if 'tokenizer' in init_kwargs: - tokenizer = init_kwargs['tokenizer'] - else: - tokenizer = cls._create_throwaway_tokenizer() - return get_rel_cat_from_old(cdb, folder_path, tokenizer) - raise LegacyConversionDisabledError(cls.__name__,) - if 'cnf' in init_kwargs: - cnf = init_kwargs['cnf'] - else: - config_path = os.path.join(folder_path, "config") - logger.info( - "Was not provided a config when loading a rel cat from '%s'. " - "Inferring config from file at '%s'", folder_path, - config_path) - cnf = ConfigRelCAT.load(load_path=folder_path) - if 'model_config' in init_kwargs: - cnf.merge_config(init_kwargs['model_config']) - if 'tokenizer' in init_kwargs: - tokenizer = init_kwargs['tokenizer'] - else: - tokenizer = cls._create_throwaway_tokenizer() - return cls.load_existing( - load_path=folder_path, - cnf=cnf, - base_tokenizer=tokenizer, - cdb=None) - - def get_strategy(self) -> SerialisingStrategy: - return SerialisingStrategy.MANUAL - - @classmethod - def get_init_attrs(cls) -> list[str]: - return [] - - @classmethod - def ignore_attrs(cls) -> list[str]: - return [] - - @classmethod - def include_properties(cls) -> list[str]: - return [] - - def __call__(self, doc: MutableDocument): - return self._rel_cat(doc) - - -class BalancedBatchSampler(Sampler): - - def __init__(self, dataset, classes, - batch_size, max_samples, max_minority): - self.dataset = dataset - self.classes = classes - self.batch_size = batch_size - self.num_classes = len(classes) - self.indices = list(range(len(dataset))) - - self.max_minority = max_minority - - self.max_samples_per_class = max_samples - - def __len__(self): - return (len(self.indices) + self.batch_size - 1) // self.batch_size - - def __iter__(self): - batch_counter = 0 - indices = self.indices.copy() - while batch_counter != self.__len__(): - batch = [] - - class_counts = {c: 0 for c in self.classes} - while len(batch) < self.batch_size: - - index = random.choice(indices) - # Assuming label is at index 1 - label = self.dataset[index][2].numpy().tolist()[0] - if class_counts[label] < self.max_samples_per_class[label]: - batch.append(index) - class_counts[label] += 1 - if self.max_samples_per_class[label] > self.max_minority: - indices.remove(index) - - yield batch - batch_counter += 1 - - -class RelCAT: - """The RelCAT class used for training 'Relation-Annotation' models, i.e., - annotation of relations between clinical concepts. - - Args: - cdb (CDB): cdb, this is used when creating relation datasets. - - tokenizer (TokenizerWrapperBERT): - The Huggingface tokenizer instance. This can be a pre-trained - tokenzier instance from a BERT-style model. For now, only - BERT models are supported. - - config (ConfigRelCAT): - the configuration for RelCAT. Param descriptions available in - ConfigRelCAT class docs. - - task (str, optional): What task is this model supposed to handle. - Defaults to "train" - init_model (bool, optional): loads default model. Defaults to False. - - """ - addon_type = 'rel_cat' - output_key = 'rel_' - - def __init__(self, base_tokenizer: BaseTokenizer, - cdb: CDB, config: ConfigRelCAT = ConfigRelCAT(), - task: str = "train", init_model: bool = False): - self.base_tokenizer = base_tokenizer - self.component = RelExtrBaseComponent() - self.task: str = task - self.checkpoint_path: str = "./" - - set_all_seeds(config.general.seed) - - if init_model: - self.component = RelExtrBaseComponent( - config=config, task=task, init_model=True) - - self.cdb = cdb - logging.basicConfig( - level=self.component.relcat_config.general.log_level) - logger.setLevel(self.component.relcat_config.general.log_level) - - self.is_cuda_available = torch.cuda.is_available() - self.device = torch.device( - "cuda" if self.is_cuda_available and - self.component.relcat_config.general.device != "cpu" else "cpu") - self._init_data_paths() - - def _init_data_paths(self): - doc_cls = self.base_tokenizer.get_doc_class() - doc_cls.register_addon_path('relations', def_val=[], force=True) - entity_cls = self.base_tokenizer.get_entity_class() - entity_cls.register_addon_path('start', def_val=None, force=True) - entity_cls.register_addon_path('end', def_val=None, force=True) - - def save(self, save_path: str = "./") -> None: - self.component.save(save_path=save_path) - - @classmethod - def load(cls, load_path: str = "./") -> "RelCAT": - - if os.path.exists(os.path.join(load_path, "cdb.dat")): - cdb = cast(CDB, deserialise(os.path.join(load_path, "cdb.dat"))) - else: - cdb = CDB(config=Config()) - logger.info( - "The default CDB file name 'cdb.dat' doesn't exist in the " - "specified path, you will need to load & set " - "a CDB manually via rel_cat.cdb = CDB.load('path') ") - - component = RelExtrBaseComponent.load( - pretrained_model_name_or_path=load_path) - - device = torch.device( - "cuda" if torch.cuda.is_available() and - component.relcat_config.general.device != "cpu" else "cpu") - - rel_cat = RelCAT( - # NOTE: this is a throaway tokenizer just for registrations - create_tokenizer(cdb.config.general.nlp.provider, cdb.config), - cdb=cdb, config=component.relcat_config, task=component.task) - rel_cat.device = device - rel_cat.component = component - - return rel_cat - - def __call__(self, doc: MutableDocument) -> MutableDocument: - doc = next(self.pipe(iter([doc]))) - return doc - - def _create_test_train_datasets(self, data: dict, - split_sets: bool = False): - train_data: dict = {} - test_data: dict = {} - - if split_sets: - rc_cnf = self.component.relcat_config - (train_data["output_relations"], - test_data["output_relations"]) = split_list_train_test_by_class( - data["output_relations"], - test_size=rc_cnf.train.test_size, - shuffle=rc_cnf.train.shuffle_data, - sample_limit=rc_cnf.general.limit_samples_per_class) - - test_data_label_names = [ - rec[4] for rec in test_data["output_relations"]] - - (test_data["nclasses"], test_data["labels2idx"], - test_data["idx2label"]) = RelData.get_labels( - test_data_label_names, self.component.relcat_config) - - for idx in range(len(test_data["output_relations"])): - test_data["output_relations" - ][idx][5] = test_data["labels2idx"][ - test_data["output_relations"][idx][4]] - else: - train_data["output_relations"] = data["output_relations"] - - for k, v in data.items(): - if k != "output_relations": - train_data[k] = [] - test_data[k] = [] - - train_data_label_names = [rec[4] - for rec in train_data["output_relations"]] - - (train_data["nclasses"], train_data["labels2idx"], - train_data["idx2label"]) = RelData.get_labels( - train_data_label_names, self.component.relcat_config) - - for idx in range(len(train_data["output_relations"])): - train_data["output_relations" - ][idx][5] = train_data["labels2idx"][ - train_data["output_relations"][idx][4]] - - return train_data, test_data - - def train(self, export_data_path: str = "", train_csv_path: str = "", - test_csv_path: str = "", checkpoint_path: str = "./"): - - if self.is_cuda_available: - logger.info("Training on device: %s%s", - str(torch.cuda.get_device_name(0)), str(self.device)) - - self.component.model = self.component.model.to(self.device) - - rc_cnf = self.component.relcat_config - - # resize vocab just in case more tokens have been added - self.component.model_config.vocab_size = ( - self.component.tokenizer.get_size()) - - train_rel_data = RelData( - cdb=self.cdb, config=rc_cnf, - tokenizer=self.component.tokenizer) - test_rel_data = RelData( - cdb=self.cdb, config=rc_cnf, - tokenizer=self.component.tokenizer) - - if train_csv_path != "": - if test_csv_path != "": - train_rel_data.dataset, _ = self._create_test_train_datasets( - train_rel_data.create_base_relations_from_csv( - train_csv_path), split_sets=False) - test_rel_data.dataset, _ = self._create_test_train_datasets( - train_rel_data.create_base_relations_from_csv( - test_csv_path), split_sets=False) - else: - (train_rel_data.dataset, - test_rel_data.dataset) = self._create_test_train_datasets( - train_rel_data.create_base_relations_from_csv( - train_csv_path), split_sets=True) - - elif export_data_path != "": - export_data = {} - with open(export_data_path) as f: - export_data = json.load(f) - (train_rel_data.dataset, - test_rel_data.dataset) = self._create_test_train_datasets( - train_rel_data.create_relations_from_export(export_data), - split_sets=True) - else: - raise ValueError( - "NO DATA HAS BEEN PROVIDED (MedCAT Trainer export " - "JSON/CSV/spacy_DOCS)") - - train_dataset_size = len(train_rel_data) - batch_size = ( - train_dataset_size if train_dataset_size < rc_cnf.train.batch_size - else rc_cnf.train.batch_size) - - # to use stratified batching - if rc_cnf.train.stratified_batching: - sampler = BalancedBatchSampler( - train_rel_data, [ - i for i in - range(rc_cnf.train.nclasses)], - batch_size, - rc_cnf.train.batching_samples_per_class, - rc_cnf.train.batching_minority_limit) - - train_dataloader = DataLoader( - train_rel_data, num_workers=0, - collate_fn=self.component.padding_seq, - batch_sampler=sampler, - pin_memory=rc_cnf.general.pin_memory) - else: - train_dataloader = DataLoader( - train_rel_data, batch_size=batch_size, - shuffle=rc_cnf.train.shuffle_data, - num_workers=0, - collate_fn=self.component.padding_seq, - pin_memory=rc_cnf.general.pin_memory) - - test_dataset_size = len(test_rel_data) - test_batch_size = ( - test_dataset_size if - test_dataset_size < rc_cnf.train.batch_size - else rc_cnf.train.batch_size) - test_dataloader = DataLoader( - test_rel_data, - batch_size=test_batch_size, - shuffle=rc_cnf.train.shuffle_data, - num_workers=0, - collate_fn=self.component.padding_seq, - pin_memory=rc_cnf.general.pin_memory) - - if (rc_cnf.train.class_weights is not None and - rc_cnf.train.enable_class_weights): - criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor( - numpy.asarray(rc_cnf.train.class_weights) - ).to(self.device)) - elif rc_cnf.train.enable_class_weights: - all_class_lbl_ids = [ - rec[5] for rec in train_rel_data.dataset["output_relations"]] - rc_cnf.train.class_weights = ( - compute_class_weight(class_weight="balanced", - classes=numpy.unique(all_class_lbl_ids), - y=all_class_lbl_ids).tolist()) - criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor( - rc_cnf.train.class_weights).to( - self.device)) - else: - criterion = nn.CrossEntropyLoss() - - if self.component.optimizer is None: - parameters = filter(lambda p: p.requires_grad, - self.component.model.parameters()) - self.component.optimizer = AdamW( - parameters, lr=self.component.relcat_config.train.lr, - weight_decay=rc_cnf.train.adam_weight_decay, - betas=rc_cnf.train.adam_betas, eps=rc_cnf.train.adam_epsilon) - - if self.component.scheduler is None: - self.component.scheduler = MultiStepLR( - self.component.optimizer, - milestones=rc_cnf.train.multistep_milestones, - gamma=rc_cnf.train.multistep_lr_gamma) - - self.epoch, self.best_f1 = load_state( - self.component.model, self.component.optimizer, - self.component.scheduler, load_best=False, path=checkpoint_path, - relcat_config=rc_cnf) - - logger.info("Starting training process...") - - losses_per_epoch, accuracy_per_epoch, f1_per_epoch = load_results( - path=checkpoint_path) - - if train_rel_data.dataset["nclasses"] > rc_cnf.train.nclasses: - rc_cnf.train.nclasses = train_rel_data.dataset["nclasses"] - self.component.model.relcat_config.train.nclasses = ( - rc_cnf.train.nclasses) - - rc_cnf.general.labels2idx.update(train_rel_data.dataset["labels2idx"]) - rc_cnf.general.idx2labels = { - int(v): k for k, v in rc_cnf.general.labels2idx.items()} - - gradient_acc_steps = ( - rc_cnf.train.gradient_acc_steps) - max_grad_norm = rc_cnf.train.max_grad_norm - - _epochs = self.epoch + rc_cnf.train.nepochs - - for epoch in range(0, _epochs): - epoch_losses, epoch_precision, epoch_f1 = self._train_epoch( - epoch, gradient_acc_steps, max_grad_norm, train_dataset_size, - train_dataloader, test_dataloader, criterion, _epochs, - checkpoint_path) - losses_per_epoch.extend(epoch_losses) - accuracy_per_epoch.extend(epoch_precision) - f1_per_epoch.extend(epoch_f1) - - def _train_epoch(self, epoch: int, - gradient_acc_steps: int, - max_grad_norm: float, - train_dataset_size: int, - train_dataloader: DataLoader, - test_dataloader: DataLoader, - criterion: nn.CrossEntropyLoss, - _epochs: int, - checkpoint_path: str) -> tuple[list, list, list]: - rc_cnf = self.component.relcat_config - start_time = datetime.now().time() - total_loss = 0.0 - - loss_per_batch = [] - accuracy_per_batch = [] - - logger.info( - "Total epochs on this model: %d | currently training " - "epoch %d", _epochs, epoch) - - pbar = tqdm(total=train_dataset_size) - - for i, data in enumerate(train_dataloader, 0): - self.component.model.train() - self.component.model.zero_grad() - - current_batch_size = len(data[0]) - token_ids, e1_e2_start, labels, _, _ = data - - attention_mask = ( - token_ids != self.component.pad_id).float().to(self.device) - - token_type_ids = torch.zeros( - (token_ids.shape[0], token_ids.shape[1])).long().to( - self.device) - - labels = labels.to(self.device) - - model_output, classification_logits = self.component.model( - input_ids=token_ids, - token_type_ids=token_type_ids, - attention_mask=attention_mask, - e1_e2_start=e1_e2_start - ) - - batch_loss = criterion( - classification_logits.view( - -1, rc_cnf.train.nclasses).to(self.device), - labels.squeeze(1)) - - batch_loss.backward() - batch_loss = batch_loss / gradient_acc_steps - - total_loss += batch_loss.item() / current_batch_size - - (batch_acc, _, batch_precision, batch_f1, - _, _, batch_stats_per_label) = self.evaluate_( - classification_logits, labels, ignore_idx=-1) - - loss_per_batch.append(batch_loss / current_batch_size) - accuracy_per_batch.append(batch_acc) - - torch.nn.utils.clip_grad_norm_( - self.component.model.parameters(), max_grad_norm) - - if (i % gradient_acc_steps) == 0: - self.component.optimizer.step() - self.component.scheduler.step() - if ((i + 1) % current_batch_size == 0): - logger.debug( - "[Epoch: %d, loss per batch, accuracy per batch: %.3f," - " %.3f, average total loss %.3f , total loss %.3f]", - epoch, loss_per_batch[-1], accuracy_per_batch[-1], - total_loss / (i + 1), total_loss) - - pbar.update(current_batch_size) - - pbar.close() - - losses_per_epoch = [] - accuracy_per_epoch = [] - f1_per_epoch = [] - if len(loss_per_batch) > 0: - losses_per_epoch.append( - sum(loss_per_batch) / len(loss_per_batch)) - logger.info("Losses at Epoch %d: %.5f" % - (epoch, losses_per_epoch[-1])) - - if len(accuracy_per_batch) > 0: - accuracy_per_epoch.append( - sum(accuracy_per_batch) / len(accuracy_per_batch)) - logger.info("Train accuracy at Epoch %d: %.5f" % - (epoch, accuracy_per_epoch[-1])) - - total_loss = total_loss / (i + 1) - - end_time = datetime.now().time() - - logger.info( - "========================" - " TRAIN SET TEST RESULTS " - "========================") - _ = self.evaluate_results(train_dataloader, self.component.pad_id) - - logger.info( - "========================" - " TEST SET TEST RESULTS " - "========================") - results = self.evaluate_results( - test_dataloader, self.component.pad_id) - - f1_per_epoch.append(results['f1']) - - logger.info("Epoch finished, took %s seconds", - str(datetime.combine(date.today(), end_time) - - datetime.combine(date.today(), start_time))) - - self.epoch += 1 - - if len(f1_per_epoch) > 0 and f1_per_epoch[-1] > self.best_f1: - self.best_f1 = f1_per_epoch[-1] - save_state( - self.component.model, self.component.optimizer, - self.component.scheduler, self.epoch, self.best_f1, - checkpoint_path, model_name=rc_cnf.general.model_name, - task=self.task, is_checkpoint=False) - - if (epoch % 1) == 0: - save_results( - { - "losses_per_epoch": losses_per_epoch, - "accuracy_per_epoch": accuracy_per_epoch, - "f1_per_epoch": f1_per_epoch, - "epoch": epoch - }, file_prefix="train", path=checkpoint_path) - save_state(self.component.model, self.component.optimizer, - self.component.scheduler, self.epoch, self.best_f1, - checkpoint_path, - model_name=rc_cnf.general.model_name, - task=self.task, is_checkpoint=True) - return losses_per_epoch, accuracy_per_epoch, f1_per_epoch - - def evaluate_(self, output_logits, labels, ignore_idx): - # ignore index (padding) when calculating accuracy - idxs = (labels != ignore_idx).squeeze() - labels_ = labels.squeeze()[idxs].to(self.device) - pred_labels = torch.softmax(output_logits, dim=1).max(1)[1] - pred_labels = pred_labels[idxs].to(self.device) - - true_labels = labels_.cpu().numpy().tolist( - ) if labels_.is_cuda else labels_.numpy().tolist() - pred_labels = pred_labels.cpu().numpy().tolist( - ) if pred_labels.is_cuda else pred_labels.numpy().tolist() - - unique_labels = set(true_labels) - - batch_size = len(true_labels) - - stat_per_label = dict() - - total_tp, total_fp, total_tn, total_fn = 0, 0, 0, 0 - acc, micro_recall, micro_precision, micro_f1 = 0, 0, 0, 0 - - for label in unique_labels: - stat_per_label[label] = { - "tp": 0, "fp": 0, "tn": 0, "fn": 0, - "f1": 0.0, "acc": 0.0, "prec": 0.0, "recall": 0.0} - - for true_label_idx in range(len(true_labels)): - if true_labels[true_label_idx] == label: - if pred_labels[true_label_idx] == label: - stat_per_label[label]["tp"] += 1 - total_tp += 1 - if pred_labels[true_label_idx] != label: - stat_per_label[label]["fp"] += 1 - total_fp += 1 - elif (true_labels[true_label_idx] != label and - label == pred_labels[true_label_idx]): - stat_per_label[label]["fn"] += 1 - total_fn += 1 - else: - stat_per_label[label]["tn"] += 1 - total_tn += 1 - - lbl_tp_tn = stat_per_label[label]["tn"] + \ - stat_per_label[label]["tp"] - - lbl_tp_fn = stat_per_label[label]["fn"] + \ - stat_per_label[label]["tp"] - lbl_tp_fn = lbl_tp_fn if lbl_tp_fn > 0.0 else 1.0 - - lbl_tp_fp = stat_per_label[label]["tp"] + \ - stat_per_label[label]["fp"] - lbl_tp_fp = lbl_tp_fp if lbl_tp_fp > 0.0 else 1.0 - - stat_per_label[label]["acc"] = lbl_tp_tn / batch_size - stat_per_label[label]["prec"] = (stat_per_label[label]["tp"] / - lbl_tp_fp) - stat_per_label[label]["recall"] = (stat_per_label[label]["tp"] / - lbl_tp_fn) - - lbl_re_pr = stat_per_label[label]["recall"] + \ - stat_per_label[label]["prec"] - lbl_re_pr = lbl_re_pr if lbl_re_pr > 0.0 else 1.0 - - stat_per_label[label]["f1"] = ( - 2 * (stat_per_label[label]["recall"] * - stat_per_label[label]["prec"])) / lbl_re_pr - - tp_fn = total_fn + total_tp - tp_fn = tp_fn if tp_fn > 0.0 else 1.0 - - tp_fp = total_fp + total_tp - tp_fp = tp_fp if tp_fp > 0.0 else 1.0 - - micro_recall = total_tp / tp_fn - micro_precision = total_tp / tp_fp - - re_pr = micro_recall + micro_precision - re_pr = re_pr if re_pr > 0.0 else 1.0 - micro_f1 = (2 * (micro_recall * micro_precision)) / re_pr - - acc = total_tp / batch_size - - return (acc, micro_recall, micro_precision, micro_f1, - pred_labels, true_labels, stat_per_label) - - def evaluate_results(self, data_loader, pad_id): - logger.info("Evaluating test samples...") - rc_cnf = self.component.relcat_config - if (rc_cnf.train.class_weights is not None and - rc_cnf.train.enable_class_weights): - criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor( - rc_cnf.train.class_weights).to(self.device)) - else: - criterion = nn.CrossEntropyLoss() - - total_loss, total_acc, total_f1, total_recall, total_precision = ( - 0.0, 0.0, 0.0, 0.0, 0.0) - all_batch_stats_per_label = [] - - self.component.model.eval() - - for i, data in enumerate(data_loader): - with torch.no_grad(): - token_ids, e1_e2_start, labels, _, _ = data - attention_mask = (token_ids != pad_id).float().to(self.device) - token_type_ids = torch.zeros( - (*token_ids.shape[:2],)).long().to(self.device) - - labels = labels.to(self.device) - - model_output, pred_classification_logits = ( - self.component.model(token_ids, - token_type_ids=token_type_ids, - attention_mask=attention_mask, - Q=None, - e1_e2_start=e1_e2_start)) - - batch_loss = criterion(pred_classification_logits.view( - -1, rc_cnf.train.nclasses).to(self.device), - labels.squeeze(1)) - total_loss += batch_loss.item() - - (batch_accuracy, batch_recall, batch_precision, batch_f1, - pred_labels, true_labels, batch_stats_per_label) = ( - self.evaluate_(pred_classification_logits, - labels, ignore_idx=-1)) - - all_batch_stats_per_label.append(batch_stats_per_label) - - total_acc += batch_accuracy - total_recall += batch_recall - total_precision += batch_precision - total_f1 += batch_f1 - - final_stats_per_label = {} - - for batch_label_stats in all_batch_stats_per_label: - for label_id, stat_dict in batch_label_stats.items(): - - if label_id not in final_stats_per_label.keys(): - final_stats_per_label[label_id] = stat_dict - else: - for stat, score in stat_dict.items(): - final_stats_per_label[label_id][stat] += score - - for label_id, stat_dict in final_stats_per_label.items(): - for stat_name, value in stat_dict.items(): - final_stats_per_label[label_id][stat_name] = value / (i + 1) - - total_loss = total_loss / (i + 1) - total_acc = total_acc / (i + 1) - total_precision = total_precision / (i + 1) - total_f1 = total_f1 / (i + 1) - total_recall = total_recall / (i + 1) - - results = { - "loss": total_loss, - "accuracy": total_acc, - "precision": total_precision, - "recall": total_recall, - "f1": total_f1 - } - - logger.info("=" * 20 + " Evaluation Results " + "=" * 20) - logger.info(" no. of batches:" + str(i + 1)) - for key in sorted(results.keys()): - logger.info(" %s = %0.3f" % (key, results[key])) - logger.info("-" * 23 + " class stats " + "-" * 23) - for label_id, stat_dict in final_stats_per_label.items(): - logger.info( - "label: %s | f1: %0.3f | prec : %0.3f | acc: %0.3f | " - "recall: %0.3f ", - rc_cnf.general.idx2labels[label_id], - stat_dict["f1"], - stat_dict["prec"], - stat_dict["acc"], - stat_dict["recall"] - ) - logger.info("-" * 59) - logger.info("=" * 59) - - return results - - def pipe(self, stream: Iterable[MutableDocument], *args, **kwargs - ) -> Iterator[MutableDocument]: - rc_cnf = self.component.relcat_config - - predict_rel_dataset = RelData( - cdb=self.cdb, config=rc_cnf, - tokenizer=self.component.tokenizer) - - self.component.model = self.component.model.to(self.device) - - for doc_id, doc in enumerate(stream, 0): - predict_rel_dataset.dataset, _ = self._create_test_train_datasets( - data=predict_rel_dataset.create_base_relations_from_doc( - doc, doc_id=str(doc_id)), - split_sets=False) - - predict_dataloader = DataLoader( - dataset=predict_rel_dataset, shuffle=False, - batch_size=rc_cnf.train.batch_size, - num_workers=0, collate_fn=self.component.padding_seq, - pin_memory=rc_cnf.general.pin_memory) - - total_rel_found = len( - predict_rel_dataset.dataset["output_relations"]) - rel_idx = -1 - - logger.info("total relations for doc: " + str(total_rel_found)) - logger.info("processing...") - - pbar = tqdm(total=total_rel_found) - - for i, data in enumerate(predict_dataloader): - with torch.no_grad(): - token_ids, e1_e2_start, labels, _, _ = data - - attention_mask = ( - token_ids != self.component.pad_id - ).float().to(self.device) - token_type_ids = torch.zeros( - *token_ids.shape[:2]).long().to(self.device) - - (model_output, - pred_classification_logits) = self.component.model( - token_ids, token_type_ids=token_type_ids, - attention_mask=attention_mask, - e1_e2_start=e1_e2_start) - - for i, pred_rel_logits in enumerate( - pred_classification_logits): - rel_idx += 1 - - confidence = torch.softmax( - pred_rel_logits, dim=0).max(0) - predicted_label_id = int(confidence[1].item()) - - relations: list = doc.get_addon_data( # type: ignore - "relations") - out_rels = predict_rel_dataset.dataset[ - "output_relations"][rel_idx] - relations.append( - { - "relation": rc_cnf.general.idx2labels[ - predicted_label_id], - "label_id": predicted_label_id, - "ent1_text": out_rels[2], - "ent2_text": out_rels[3], - "confidence": float("{:.3f}".format( - confidence[0])), - "start_ent1_char_pos": out_rels[18], - "end_ent1_char_pos": out_rels[19], - "start_ent2_char_pos": out_rels[20], - "end_ent2_char_pos": out_rels[21], - "start_entity_id": out_rels[8], - "end_entity_id": out_rels[9], - }) - pbar.update(len(token_ids)) - pbar.close() - - yield doc - - def predict_text_with_anns(self, text: str, annotations: list[dict] - ) -> MutableDocument: - """ Creates spacy doc from text and annotation input. - Predicts using self.__call__ - - Args: - text (str): text - annotations (dict): dict containing the entities from NER - (of your choosing), the format must be the following format: - [ - { - "cui": "202099003", -this is optional - "value": "discoid lateral meniscus", - "start": 294, - "end": 318 - }, - { - "cui": "202099003", - "value": "Discoid lateral meniscus", - "start": 1905, - "end": 1929, - } - ] - - Returns: - Doc: spacy doc with the relations. - """ - # NOTE: This runs not an empty language, but the specified one - base_tokenizer = create_tokenizer( - self.cdb.config.general.nlp.provider, self.cdb.config) - doc = base_tokenizer(text) - - for ann in annotations: - tkn_idx = [] - for ind, word in enumerate(doc): - end_char = word.base.char_index + len(word.base.text) - if end_char <= ann['end'] and end_char > ann['start']: - tkn_idx.append(ind) - entity = base_tokenizer.create_entity( - doc, min(tkn_idx), max(tkn_idx) + 1, label=ann["value"]) - entity.cui = ann["cui"] - entity.set_addon_data('start', ann['start']) - entity.set_addon_data('end', ann['end']) - doc.ner_ents.append(entity) - - doc = self(doc) - - return doc +import json +import logging +import os +import random +from typing import Optional + +from sklearn.utils import compute_class_weight +import torch +import torch.nn as nn + +from tqdm import tqdm +from datetime import date, datetime +from typing import Iterable, Iterator, cast +from torch.utils.data import DataLoader, Sampler +from torch.optim import AdamW +from torch.optim.lr_scheduler import MultiStepLR +import numpy + +from medcat.cdb import CDB +from medcat.vocab import Vocab +from medcat.config.config import Config, ComponentConfig +from medcat.config.config_rel_cat import ConfigRelCAT +from medcat.storage.serialisers import deserialise +from medcat.storage.serialisables import SerialisingStrategy +from medcat.components.addons.addons import AddonComponent +from medcat.components.addons.relation_extraction.base_component import ( + RelExtrBaseComponent) +from medcat.components.addons.meta_cat.ml_utils import set_all_seeds +from medcat.components.addons.relation_extraction.ml_utils import ( + load_results, load_state, save_results, save_state, + split_list_train_test_by_class) +from medcat.components.addons.relation_extraction.rel_dataset import RelData +from medcat.tokenizing.tokenizers import BaseTokenizer, create_tokenizer +from medcat.tokenizing.tokens import MutableDocument +from medcat.utils.defaults import COMPONENTS_FOLDER + + +logger = logging.getLogger(__name__) + + +class RelCATAddon(AddonComponent): + DEFAULT_TOKENIZER = 'spacy' + addon_type = 'rel_cat' + output_key = 'relations' + config: ConfigRelCAT + + def __init__(self, config: ConfigRelCAT, + rel_cat: "RelCAT"): + self.config = config + self._rel_cat = rel_cat + + @classmethod + def create_new(cls, config: ConfigRelCAT, base_tokenizer: BaseTokenizer, + cdb: CDB) -> 'RelCATAddon': + """Factory method to create a new MetaCATAddon instance.""" + return cls(config, + RelCAT(base_tokenizer, cdb, config=config, init_model=True)) + + @classmethod + def create_new_component( + cls, cnf: ComponentConfig, tokenizer: BaseTokenizer, + cdb: CDB, vocab: Vocab, model_load_path: Optional[str] + ) -> 'RelCATAddon': + if not isinstance(cnf, ConfigRelCAT): + raise ValueError(f"Incompatible config: {cnf}") + config = cnf + if model_load_path is not None: + load_path = os.path.join(model_load_path, COMPONENTS_FOLDER, + cls.NAME_PREFIX + cls.addon_type) + return cls.load_existing(config, tokenizer, cdb, load_path) + return cls.create_new(config, tokenizer, cdb) + + @classmethod + def load_existing(cls, cnf: ConfigRelCAT, + base_tokenizer: BaseTokenizer, + cdb: CDB, + load_path: str) -> 'RelCATAddon': + """Factory method to load an existing RelCAT addon from disk.""" + rc = RelCAT.load(load_path) + # set the correct base tokenizer and redo data paths + rc.base_tokenizer = base_tokenizer + rc._init_data_paths() + return cls(cnf, rc) + + def serialise_to(self, folder_path: str) -> None: + os.mkdir(folder_path) + self._rel_cat.save(folder_path) + + @property + def name(self) -> str: + return str(self.addon_type) + + # for ManualSerialisable: + + @classmethod + def _create_throwaway_tokenizer(cls) -> BaseTokenizer: + """ + Mirrors `MetaCATAddon._create_throwaway_tokenizer` + """ + logger.warning( + "A base tokenizer was not provided during the loading of a " + "RelCAT. The tokenizer is used to register the required data " + "paths for RelCAT to function. Using the default of '%s'.", + cls.DEFAULT_TOKENIZER, + ) + gcnf = Config() + gcnf.general.nlp.provider = 'spacy' + return create_tokenizer(cls.DEFAULT_TOKENIZER, gcnf) + + @classmethod + def deserialise_from(cls, folder_path: str, **init_kwargs + ) -> 'RelCATAddon': + """Deserialise a RelCAT addon from disk. + + Mirrors `MetaCATAddon.deserialise_from`: when called via the + pipeline, `tokenizer`/`cnf`/`cdb` are supplied; when called standalone + (e.g. `CAT.load_addons`), they are inferred from disk so that + deserialisation works without full pipeline context. + """ + if 'cnf' in init_kwargs: + cnf = init_kwargs['cnf'] + else: + logger.info( + "Was not provided a config when loading a rel cat from '%s'. " + "Inferring config from file at '%s'", folder_path, + folder_path) + cnf = ConfigRelCAT.load(load_path=folder_path) + if 'model_config' in init_kwargs: + cnf.merge_config(init_kwargs['model_config']) + if 'tokenizer' in init_kwargs: + tokenizer = init_kwargs['tokenizer'] + else: + tokenizer = cls._create_throwaway_tokenizer() + if 'cdb' in init_kwargs: + cdb = init_kwargs['cdb'] + else: + cdb_path = os.path.join(folder_path, "cdb.dat") + if os.path.exists(cdb_path): + cdb = cast(CDB, deserialise(cdb_path)) + else: + cdb = CDB(config=Config()) + return cls.load_existing( + load_path=folder_path, + cnf=cnf, + base_tokenizer=tokenizer, + cdb=cdb) + + def get_strategy(self) -> SerialisingStrategy: + return SerialisingStrategy.MANUAL + + @classmethod + def get_init_attrs(cls) -> list[str]: + return [] + + @classmethod + def ignore_attrs(cls) -> list[str]: + return [] + + @classmethod + def include_properties(cls) -> list[str]: + return [] + + def __call__(self, doc: MutableDocument): + return self._rel_cat(doc) + + +class BalancedBatchSampler(Sampler): + + def __init__(self, dataset, classes, + batch_size, max_samples, max_minority): + self.dataset = dataset + self.classes = classes + self.batch_size = batch_size + self.num_classes = len(classes) + self.indices = list(range(len(dataset))) + + self.max_minority = max_minority + + self.max_samples_per_class = max_samples + + def __len__(self): + return (len(self.indices) + self.batch_size - 1) // self.batch_size + + def __iter__(self): + batch_counter = 0 + indices = self.indices.copy() + while batch_counter != self.__len__(): + batch = [] + + class_counts = {c: 0 for c in self.classes} + while len(batch) < self.batch_size: + + index = random.choice(indices) + # Assuming label is at index 1 + label = self.dataset[index][2].numpy().tolist()[0] + if class_counts[label] < self.max_samples_per_class[label]: + batch.append(index) + class_counts[label] += 1 + if self.max_samples_per_class[label] > self.max_minority: + indices.remove(index) + + yield batch + batch_counter += 1 + + +class RelCAT: + """The RelCAT class used for training 'Relation-Annotation' models, i.e., + annotation of relations between clinical concepts. + + Args: + cdb (CDB): cdb, this is used when creating relation datasets. + + tokenizer (TokenizerWrapperBERT): + The Huggingface tokenizer instance. This can be a pre-trained + tokenzier instance from a BERT-style model. For now, only + BERT models are supported. + + config (ConfigRelCAT): + the configuration for RelCAT. Param descriptions available in + ConfigRelCAT class docs. + + task (str, optional): What task is this model supposed to handle. + Defaults to "train" + init_model (bool, optional): loads default model. Defaults to False. + + """ + addon_type = 'rel_cat' + output_key = 'rel_' + + def __init__(self, base_tokenizer: BaseTokenizer, + cdb: CDB, config: ConfigRelCAT = ConfigRelCAT(), + task: str = "train", init_model: bool = False): + self.base_tokenizer = base_tokenizer + self.component = RelExtrBaseComponent() + self.task: str = task + self.checkpoint_path: str = "./" + + set_all_seeds(config.general.seed) + + if init_model: + self.component = RelExtrBaseComponent( + config=config, task=task, init_model=True) + + self.cdb = cdb + logging.basicConfig( + level=self.component.relcat_config.general.log_level) + logger.setLevel(self.component.relcat_config.general.log_level) + + self.is_cuda_available = torch.cuda.is_available() + self.device = torch.device( + "cuda" if self.is_cuda_available and + self.component.relcat_config.general.device != "cpu" else "cpu") + self._init_data_paths() + + def _init_data_paths(self): + doc_cls = self.base_tokenizer.get_doc_class() + doc_cls.register_addon_path('relations', def_val=[], force=True) + entity_cls = self.base_tokenizer.get_entity_class() + entity_cls.register_addon_path('start', def_val=None, force=True) + entity_cls.register_addon_path('end', def_val=None, force=True) + + def save(self, save_path: str = "./") -> None: + self.component.save(save_path=save_path) + + @classmethod + def load(cls, load_path: str = "./") -> "RelCAT": + + if os.path.exists(os.path.join(load_path, "cdb.dat")): + cdb = cast(CDB, deserialise(os.path.join(load_path, "cdb.dat"))) + else: + cdb = CDB(config=Config()) + logger.info( + "The default CDB file name 'cdb.dat' doesn't exist in the " + "specified path, you will need to load & set " + "a CDB manually via rel_cat.cdb = CDB.load('path') ") + + component = RelExtrBaseComponent.load( + pretrained_model_name_or_path=load_path) + + device = torch.device( + "cuda" if torch.cuda.is_available() and + component.relcat_config.general.device != "cpu" else "cpu") + + rel_cat = RelCAT( + # NOTE: this is a throaway tokenizer just for registrations + create_tokenizer(cdb.config.general.nlp.provider, cdb.config), + cdb=cdb, config=component.relcat_config, task=component.task) + rel_cat.device = device + rel_cat.component = component + + return rel_cat + + def __call__(self, doc: MutableDocument) -> MutableDocument: + doc = next(self.pipe(iter([doc]))) + return doc + + def _create_test_train_datasets(self, data: dict, + split_sets: bool = False): + train_data: dict = {} + test_data: dict = {} + + if split_sets: + rc_cnf = self.component.relcat_config + (train_data["output_relations"], + test_data["output_relations"]) = split_list_train_test_by_class( + data["output_relations"], + test_size=rc_cnf.train.test_size, + shuffle=rc_cnf.train.shuffle_data, + sample_limit=rc_cnf.general.limit_samples_per_class) + + test_data_label_names = [ + rec[4] for rec in test_data["output_relations"]] + + (test_data["nclasses"], test_data["labels2idx"], + test_data["idx2label"]) = RelData.get_labels( + test_data_label_names, self.component.relcat_config) + + for idx in range(len(test_data["output_relations"])): + test_data["output_relations" + ][idx][5] = test_data["labels2idx"][ + test_data["output_relations"][idx][4]] + else: + train_data["output_relations"] = data["output_relations"] + + for k, v in data.items(): + if k != "output_relations": + train_data[k] = [] + test_data[k] = [] + + train_data_label_names = [rec[4] + for rec in train_data["output_relations"]] + + (train_data["nclasses"], train_data["labels2idx"], + train_data["idx2label"]) = RelData.get_labels( + train_data_label_names, self.component.relcat_config) + + for idx in range(len(train_data["output_relations"])): + train_data["output_relations" + ][idx][5] = train_data["labels2idx"][ + train_data["output_relations"][idx][4]] + + return train_data, test_data + + def train(self, export_data_path: str = "", train_csv_path: str = "", + test_csv_path: str = "", checkpoint_path: str = "./"): + + if self.is_cuda_available: + logger.info("Training on device: %s%s", + str(torch.cuda.get_device_name(0)), str(self.device)) + + self.component.model = self.component.model.to(self.device) + + rc_cnf = self.component.relcat_config + + # resize vocab just in case more tokens have been added + self.component.model_config.vocab_size = ( + self.component.tokenizer.get_size()) + + train_rel_data = RelData( + cdb=self.cdb, config=rc_cnf, + tokenizer=self.component.tokenizer) + test_rel_data = RelData( + cdb=self.cdb, config=rc_cnf, + tokenizer=self.component.tokenizer) + + if train_csv_path != "": + if test_csv_path != "": + train_rel_data.dataset, _ = self._create_test_train_datasets( + train_rel_data.create_base_relations_from_csv( + train_csv_path), split_sets=False) + test_rel_data.dataset, _ = self._create_test_train_datasets( + train_rel_data.create_base_relations_from_csv( + test_csv_path), split_sets=False) + else: + (train_rel_data.dataset, + test_rel_data.dataset) = self._create_test_train_datasets( + train_rel_data.create_base_relations_from_csv( + train_csv_path), split_sets=True) + + elif export_data_path != "": + export_data = {} + with open(export_data_path) as f: + export_data = json.load(f) + (train_rel_data.dataset, + test_rel_data.dataset) = self._create_test_train_datasets( + train_rel_data.create_relations_from_export(export_data), + split_sets=True) + else: + raise ValueError( + "NO DATA HAS BEEN PROVIDED (MedCAT Trainer export " + "JSON/CSV/spacy_DOCS)") + + train_dataset_size = len(train_rel_data) + batch_size = ( + train_dataset_size if train_dataset_size < rc_cnf.train.batch_size + else rc_cnf.train.batch_size) + + # to use stratified batching + if rc_cnf.train.stratified_batching: + sampler = BalancedBatchSampler( + train_rel_data, [ + i for i in + range(rc_cnf.train.nclasses)], + batch_size, + rc_cnf.train.batching_samples_per_class, + rc_cnf.train.batching_minority_limit) + + train_dataloader = DataLoader( + train_rel_data, num_workers=0, + collate_fn=self.component.padding_seq, + batch_sampler=sampler, + pin_memory=rc_cnf.general.pin_memory) + else: + train_dataloader = DataLoader( + train_rel_data, batch_size=batch_size, + shuffle=rc_cnf.train.shuffle_data, + num_workers=0, + collate_fn=self.component.padding_seq, + pin_memory=rc_cnf.general.pin_memory) + + test_dataset_size = len(test_rel_data) + test_batch_size = ( + test_dataset_size if + test_dataset_size < rc_cnf.train.batch_size + else rc_cnf.train.batch_size) + test_dataloader = DataLoader( + test_rel_data, + batch_size=test_batch_size, + shuffle=rc_cnf.train.shuffle_data, + num_workers=0, + collate_fn=self.component.padding_seq, + pin_memory=rc_cnf.general.pin_memory) + + if (rc_cnf.train.class_weights is not None and + rc_cnf.train.enable_class_weights): + criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor( + numpy.asarray(rc_cnf.train.class_weights) + ).to(self.device)) + elif rc_cnf.train.enable_class_weights: + all_class_lbl_ids = [ + rec[5] for rec in train_rel_data.dataset["output_relations"]] + rc_cnf.train.class_weights = ( + compute_class_weight(class_weight="balanced", + classes=numpy.unique(all_class_lbl_ids), + y=all_class_lbl_ids).tolist()) + criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor( + rc_cnf.train.class_weights).to( + self.device)) + else: + criterion = nn.CrossEntropyLoss() + + if self.component.optimizer is None: + parameters = filter(lambda p: p.requires_grad, + self.component.model.parameters()) + self.component.optimizer = AdamW( + parameters, lr=self.component.relcat_config.train.lr, + weight_decay=rc_cnf.train.adam_weight_decay, + betas=rc_cnf.train.adam_betas, eps=rc_cnf.train.adam_epsilon) + + if self.component.scheduler is None: + self.component.scheduler = MultiStepLR( + self.component.optimizer, + milestones=rc_cnf.train.multistep_milestones, + gamma=rc_cnf.train.multistep_lr_gamma) + + self.epoch, self.best_f1 = load_state( + self.component.model, self.component.optimizer, + self.component.scheduler, load_best=False, path=checkpoint_path, + relcat_config=rc_cnf) + + logger.info("Starting training process...") + + losses_per_epoch, accuracy_per_epoch, f1_per_epoch = load_results( + path=checkpoint_path) + + if train_rel_data.dataset["nclasses"] > rc_cnf.train.nclasses: + rc_cnf.train.nclasses = train_rel_data.dataset["nclasses"] + self.component.model.relcat_config.train.nclasses = ( + rc_cnf.train.nclasses) + + rc_cnf.general.labels2idx.update(train_rel_data.dataset["labels2idx"]) + rc_cnf.general.idx2labels = { + int(v): k for k, v in rc_cnf.general.labels2idx.items()} + + gradient_acc_steps = ( + rc_cnf.train.gradient_acc_steps) + max_grad_norm = rc_cnf.train.max_grad_norm + + _epochs = self.epoch + rc_cnf.train.nepochs + + for epoch in range(0, _epochs): + epoch_losses, epoch_precision, epoch_f1 = self._train_epoch( + epoch, gradient_acc_steps, max_grad_norm, train_dataset_size, + train_dataloader, test_dataloader, criterion, _epochs, + checkpoint_path) + losses_per_epoch.extend(epoch_losses) + accuracy_per_epoch.extend(epoch_precision) + f1_per_epoch.extend(epoch_f1) + + def _train_epoch(self, epoch: int, + gradient_acc_steps: int, + max_grad_norm: float, + train_dataset_size: int, + train_dataloader: DataLoader, + test_dataloader: DataLoader, + criterion: nn.CrossEntropyLoss, + _epochs: int, + checkpoint_path: str) -> tuple[list, list, list]: + rc_cnf = self.component.relcat_config + start_time = datetime.now().time() + total_loss = 0.0 + + loss_per_batch = [] + accuracy_per_batch = [] + + logger.info( + "Total epochs on this model: %d | currently training " + "epoch %d", _epochs, epoch) + + pbar = tqdm(total=train_dataset_size) + + for i, data in enumerate(train_dataloader, 0): + self.component.model.train() + self.component.model.zero_grad() + + current_batch_size = len(data[0]) + token_ids, e1_e2_start, labels, _, _ = data + + attention_mask = ( + token_ids != self.component.pad_id).float().to(self.device) + + token_type_ids = torch.zeros( + (token_ids.shape[0], token_ids.shape[1])).long().to( + self.device) + + labels = labels.to(self.device) + + model_output, classification_logits = self.component.model( + input_ids=token_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + e1_e2_start=e1_e2_start + ) + + batch_loss = criterion( + classification_logits.view( + -1, rc_cnf.train.nclasses).to(self.device), + labels.squeeze(1)) + + batch_loss.backward() + batch_loss = batch_loss / gradient_acc_steps + + total_loss += batch_loss.item() / current_batch_size + + (batch_acc, _, batch_precision, batch_f1, + _, _, batch_stats_per_label) = self.evaluate_( + classification_logits, labels, ignore_idx=-1) + + loss_per_batch.append(batch_loss / current_batch_size) + accuracy_per_batch.append(batch_acc) + + torch.nn.utils.clip_grad_norm_( + self.component.model.parameters(), max_grad_norm) + + if (i % gradient_acc_steps) == 0: + self.component.optimizer.step() + self.component.scheduler.step() + if ((i + 1) % current_batch_size == 0): + logger.debug( + "[Epoch: %d, loss per batch, accuracy per batch: %.3f," + " %.3f, average total loss %.3f , total loss %.3f]", + epoch, loss_per_batch[-1], accuracy_per_batch[-1], + total_loss / (i + 1), total_loss) + + pbar.update(current_batch_size) + + pbar.close() + + losses_per_epoch = [] + accuracy_per_epoch = [] + f1_per_epoch = [] + if len(loss_per_batch) > 0: + losses_per_epoch.append( + sum(loss_per_batch) / len(loss_per_batch)) + logger.info("Losses at Epoch %d: %.5f" % + (epoch, losses_per_epoch[-1])) + + if len(accuracy_per_batch) > 0: + accuracy_per_epoch.append( + sum(accuracy_per_batch) / len(accuracy_per_batch)) + logger.info("Train accuracy at Epoch %d: %.5f" % + (epoch, accuracy_per_epoch[-1])) + + total_loss = total_loss / (i + 1) + + end_time = datetime.now().time() + + logger.info( + "========================" + " TRAIN SET TEST RESULTS " + "========================") + _ = self.evaluate_results(train_dataloader, self.component.pad_id) + + logger.info( + "========================" + " TEST SET TEST RESULTS " + "========================") + results = self.evaluate_results( + test_dataloader, self.component.pad_id) + + f1_per_epoch.append(results['f1']) + + logger.info("Epoch finished, took %s seconds", + str(datetime.combine(date.today(), end_time) + - datetime.combine(date.today(), start_time))) + + self.epoch += 1 + + if len(f1_per_epoch) > 0 and f1_per_epoch[-1] > self.best_f1: + self.best_f1 = f1_per_epoch[-1] + save_state( + self.component.model, self.component.optimizer, + self.component.scheduler, self.epoch, self.best_f1, + checkpoint_path, model_name=rc_cnf.general.model_name, + task=self.task, is_checkpoint=False) + + if (epoch % 1) == 0: + save_results( + { + "losses_per_epoch": losses_per_epoch, + "accuracy_per_epoch": accuracy_per_epoch, + "f1_per_epoch": f1_per_epoch, + "epoch": epoch + }, file_prefix="train", path=checkpoint_path) + save_state(self.component.model, self.component.optimizer, + self.component.scheduler, self.epoch, self.best_f1, + checkpoint_path, + model_name=rc_cnf.general.model_name, + task=self.task, is_checkpoint=True) + return losses_per_epoch, accuracy_per_epoch, f1_per_epoch + + def evaluate_(self, output_logits, labels, ignore_idx): + # ignore index (padding) when calculating accuracy + idxs = (labels != ignore_idx).squeeze() + labels_ = labels.squeeze()[idxs].to(self.device) + pred_labels = torch.softmax(output_logits, dim=1).max(1)[1] + pred_labels = pred_labels[idxs].to(self.device) + + true_labels = labels_.cpu().numpy().tolist( + ) if labels_.is_cuda else labels_.numpy().tolist() + pred_labels = pred_labels.cpu().numpy().tolist( + ) if pred_labels.is_cuda else pred_labels.numpy().tolist() + + unique_labels = set(true_labels) + + batch_size = len(true_labels) + + stat_per_label = dict() + + total_tp, total_fp, total_tn, total_fn = 0, 0, 0, 0 + acc, micro_recall, micro_precision, micro_f1 = 0, 0, 0, 0 + + for label in unique_labels: + stat_per_label[label] = { + "tp": 0, "fp": 0, "tn": 0, "fn": 0, + "f1": 0.0, "acc": 0.0, "prec": 0.0, "recall": 0.0} + + for true_label_idx in range(len(true_labels)): + if true_labels[true_label_idx] == label: + if pred_labels[true_label_idx] == label: + stat_per_label[label]["tp"] += 1 + total_tp += 1 + if pred_labels[true_label_idx] != label: + stat_per_label[label]["fp"] += 1 + total_fp += 1 + elif (true_labels[true_label_idx] != label and + label == pred_labels[true_label_idx]): + stat_per_label[label]["fn"] += 1 + total_fn += 1 + else: + stat_per_label[label]["tn"] += 1 + total_tn += 1 + + lbl_tp_tn = stat_per_label[label]["tn"] + \ + stat_per_label[label]["tp"] + + lbl_tp_fn = stat_per_label[label]["fn"] + \ + stat_per_label[label]["tp"] + lbl_tp_fn = lbl_tp_fn if lbl_tp_fn > 0.0 else 1.0 + + lbl_tp_fp = stat_per_label[label]["tp"] + \ + stat_per_label[label]["fp"] + lbl_tp_fp = lbl_tp_fp if lbl_tp_fp > 0.0 else 1.0 + + stat_per_label[label]["acc"] = lbl_tp_tn / batch_size + stat_per_label[label]["prec"] = (stat_per_label[label]["tp"] / + lbl_tp_fp) + stat_per_label[label]["recall"] = (stat_per_label[label]["tp"] / + lbl_tp_fn) + + lbl_re_pr = stat_per_label[label]["recall"] + \ + stat_per_label[label]["prec"] + lbl_re_pr = lbl_re_pr if lbl_re_pr > 0.0 else 1.0 + + stat_per_label[label]["f1"] = ( + 2 * (stat_per_label[label]["recall"] * + stat_per_label[label]["prec"])) / lbl_re_pr + + tp_fn = total_fn + total_tp + tp_fn = tp_fn if tp_fn > 0.0 else 1.0 + + tp_fp = total_fp + total_tp + tp_fp = tp_fp if tp_fp > 0.0 else 1.0 + + micro_recall = total_tp / tp_fn + micro_precision = total_tp / tp_fp + + re_pr = micro_recall + micro_precision + re_pr = re_pr if re_pr > 0.0 else 1.0 + micro_f1 = (2 * (micro_recall * micro_precision)) / re_pr + + acc = total_tp / batch_size + + return (acc, micro_recall, micro_precision, micro_f1, + pred_labels, true_labels, stat_per_label) + + def evaluate_results(self, data_loader, pad_id): + logger.info("Evaluating test samples...") + rc_cnf = self.component.relcat_config + if (rc_cnf.train.class_weights is not None and + rc_cnf.train.enable_class_weights): + criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor( + rc_cnf.train.class_weights).to(self.device)) + else: + criterion = nn.CrossEntropyLoss() + + total_loss, total_acc, total_f1, total_recall, total_precision = ( + 0.0, 0.0, 0.0, 0.0, 0.0) + all_batch_stats_per_label = [] + + self.component.model.eval() + + for i, data in enumerate(data_loader): + with torch.no_grad(): + token_ids, e1_e2_start, labels, _, _ = data + attention_mask = (token_ids != pad_id).float().to(self.device) + token_type_ids = torch.zeros( + (*token_ids.shape[:2],)).long().to(self.device) + + labels = labels.to(self.device) + + model_output, pred_classification_logits = ( + self.component.model(token_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + Q=None, + e1_e2_start=e1_e2_start)) + + batch_loss = criterion(pred_classification_logits.view( + -1, rc_cnf.train.nclasses).to(self.device), + labels.squeeze(1)) + total_loss += batch_loss.item() + + (batch_accuracy, batch_recall, batch_precision, batch_f1, + pred_labels, true_labels, batch_stats_per_label) = ( + self.evaluate_(pred_classification_logits, + labels, ignore_idx=-1)) + + all_batch_stats_per_label.append(batch_stats_per_label) + + total_acc += batch_accuracy + total_recall += batch_recall + total_precision += batch_precision + total_f1 += batch_f1 + + final_stats_per_label = {} + + for batch_label_stats in all_batch_stats_per_label: + for label_id, stat_dict in batch_label_stats.items(): + + if label_id not in final_stats_per_label.keys(): + final_stats_per_label[label_id] = stat_dict + else: + for stat, score in stat_dict.items(): + final_stats_per_label[label_id][stat] += score + + for label_id, stat_dict in final_stats_per_label.items(): + for stat_name, value in stat_dict.items(): + final_stats_per_label[label_id][stat_name] = value / (i + 1) + + total_loss = total_loss / (i + 1) + total_acc = total_acc / (i + 1) + total_precision = total_precision / (i + 1) + total_f1 = total_f1 / (i + 1) + total_recall = total_recall / (i + 1) + + results = { + "loss": total_loss, + "accuracy": total_acc, + "precision": total_precision, + "recall": total_recall, + "f1": total_f1 + } + + logger.info("=" * 20 + " Evaluation Results " + "=" * 20) + logger.info(" no. of batches:" + str(i + 1)) + for key in sorted(results.keys()): + logger.info(" %s = %0.3f" % (key, results[key])) + logger.info("-" * 23 + " class stats " + "-" * 23) + for label_id, stat_dict in final_stats_per_label.items(): + logger.info( + "label: %s | f1: %0.3f | prec : %0.3f | acc: %0.3f | " + "recall: %0.3f ", + rc_cnf.general.idx2labels[label_id], + stat_dict["f1"], + stat_dict["prec"], + stat_dict["acc"], + stat_dict["recall"] + ) + logger.info("-" * 59) + logger.info("=" * 59) + + return results + + def pipe(self, stream: Iterable[MutableDocument], *args, **kwargs + ) -> Iterator[MutableDocument]: + rc_cnf = self.component.relcat_config + + predict_rel_dataset = RelData( + cdb=self.cdb, config=rc_cnf, + tokenizer=self.component.tokenizer) + + self.component.model = self.component.model.to(self.device) + + for doc_id, doc in enumerate(stream, 0): + predict_rel_dataset.dataset, _ = self._create_test_train_datasets( + data=predict_rel_dataset.create_base_relations_from_doc( + doc, doc_id=str(doc_id)), + split_sets=False) + + predict_dataloader = DataLoader( + dataset=predict_rel_dataset, shuffle=False, + batch_size=rc_cnf.train.batch_size, + num_workers=0, collate_fn=self.component.padding_seq, + pin_memory=rc_cnf.general.pin_memory) + + total_rel_found = len( + predict_rel_dataset.dataset["output_relations"]) + rel_idx = -1 + + logger.info("total relations for doc: " + str(total_rel_found)) + logger.info("processing...") + + pbar = tqdm(total=total_rel_found) + + for i, data in enumerate(predict_dataloader): + with torch.no_grad(): + token_ids, e1_e2_start, labels, _, _ = data + + attention_mask = ( + token_ids != self.component.pad_id + ).float().to(self.device) + token_type_ids = torch.zeros( + *token_ids.shape[:2]).long().to(self.device) + + (model_output, + pred_classification_logits) = self.component.model( + token_ids, token_type_ids=token_type_ids, + attention_mask=attention_mask, + e1_e2_start=e1_e2_start) + + for i, pred_rel_logits in enumerate( + pred_classification_logits): + rel_idx += 1 + + confidence = torch.softmax( + pred_rel_logits, dim=0).max(0) + predicted_label_id = int(confidence[1].item()) + + relations: list = doc.get_addon_data( # type: ignore + "relations") + out_rels = predict_rel_dataset.dataset[ + "output_relations"][rel_idx] + relations.append( + { + "relation": rc_cnf.general.idx2labels[ + predicted_label_id], + "label_id": predicted_label_id, + "ent1_text": out_rels[2], + "ent2_text": out_rels[3], + "confidence": float("{:.3f}".format( + confidence[0])), + "start_ent1_char_pos": out_rels[18], + "end_ent1_char_pos": out_rels[19], + "start_ent2_char_pos": out_rels[20], + "end_ent2_char_pos": out_rels[21], + "start_entity_id": out_rels[8], + "end_entity_id": out_rels[9], + }) + pbar.update(len(token_ids)) + pbar.close() + + yield doc + + def predict_text_with_anns(self, text: str, annotations: list[dict] + ) -> MutableDocument: + """ Creates spacy doc from text and annotation input. + Predicts using self.__call__ + + Args: + text (str): text + annotations (dict): dict containing the entities from NER + (of your choosing), the format must be the following format: + [ + { + "cui": "202099003", -this is optional + "value": "discoid lateral meniscus", + "start": 294, + "end": 318 + }, + { + "cui": "202099003", + "value": "Discoid lateral meniscus", + "start": 1905, + "end": 1929, + } + ] + + Returns: + Doc: spacy doc with the relations. + """ + # NOTE: This runs not an empty language, but the specified one + base_tokenizer = create_tokenizer( + self.cdb.config.general.nlp.provider, self.cdb.config) + doc = base_tokenizer(text) + + for ann in annotations: + tkn_idx = [] + for ind, word in enumerate(doc): + end_char = word.base.char_index + len(word.base.text) + if end_char <= ann['end'] and end_char > ann['start']: + tkn_idx.append(ind) + entity = base_tokenizer.create_entity( + doc, min(tkn_idx), max(tkn_idx) + 1, label=ann["value"]) + entity.cui = ann["cui"] + entity.set_addon_data('start', ann['start']) + entity.set_addon_data('end', ann['end']) + doc.ner_ents.append(entity) + + doc = self(doc) + + return doc From 1a5c9a98860be17ea91e170731e7f569dd159c42 Mon Sep 17 00:00:00 2001 From: alhendrickson Date: Tue, 16 Jun 2026 10:20:10 +0000 Subject: [PATCH 9/9] fix(medcat): Relcat addon fix deserialise - make the same as metacataddon --- .../components/addons/meta_cat/meta_cat.py | 2 +- .../addons/relation_extraction/rel_cat.py | 51 ++++++++++++++----- 2 files changed, 38 insertions(+), 15 deletions(-) diff --git a/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py b/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py index 05fb4a185..475eb8291 100644 --- a/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py +++ b/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py @@ -223,7 +223,7 @@ def deserialise_from(cls, folder_path: str, **init_kwargs # load legacy config (assuming it exists) config_path += ".dat" logger.info( - "Was not provide a config when loading a meta cat from '%s'. " + "Was not provided a config when loading a meta cat from '%s'. " "Inferring config from file at '%s'", folder_path, config_path) cnf = ConfigMetaCAT.load(config_path) diff --git a/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py b/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py index 0e6d099a6..a960f3aee 100644 --- a/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py +++ b/medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py @@ -33,6 +33,9 @@ from medcat.tokenizing.tokenizers import BaseTokenizer, create_tokenizer from medcat.tokenizing.tokens import MutableDocument from medcat.utils.defaults import COMPONENTS_FOLDER +from medcat.utils.defaults import ( + avoid_legacy_conversion, doing_legacy_conversion_message, + LegacyConversionDisabledError) logger = logging.getLogger(__name__) @@ -73,7 +76,7 @@ def create_new_component( @classmethod def load_existing(cls, cnf: ConfigRelCAT, base_tokenizer: BaseTokenizer, - cdb: CDB, + cdb: Optional[CDB], load_path: str) -> 'RelCATAddon': """Factory method to load an existing RelCAT addon from disk.""" rc = RelCAT.load(load_path) @@ -100,9 +103,16 @@ def _create_throwaway_tokenizer(cls) -> BaseTokenizer: logger.warning( "A base tokenizer was not provided during the loading of a " "RelCAT. The tokenizer is used to register the required data " - "paths for RelCAT to function. Using the default of '%s'.", - cls.DEFAULT_TOKENIZER, + "paths for RelCAT to function. Using the default of '%s'. If " + "this it not the tokenizer you will end up using, RelCAT may " + "be unable to recover unless a) the paths are registered " + "explicitly, or b) there are other RelCATs created with the " + "correct tokenizer. Do note that this will also create " + "another instance of the tokenizer, though it should be " + "garbage collected soon.", cls.DEFAULT_TOKENIZER ) + # NOTE: the use of a (mostly) default config here probably won't + # affect anything since the tokenizer itself won't be used gcnf = Config() gcnf.general.nlp.provider = 'spacy' return create_tokenizer(cls.DEFAULT_TOKENIZER, gcnf) @@ -113,17 +123,38 @@ def deserialise_from(cls, folder_path: str, **init_kwargs """Deserialise a RelCAT addon from disk. Mirrors `MetaCATAddon.deserialise_from`: when called via the - pipeline, `tokenizer`/`cnf`/`cdb` are supplied; when called standalone + pipeline, `tokenizer`/`cnf` are supplied; when called standalone (e.g. `CAT.load_addons`), they are inferred from disk so that deserialisation works without full pipeline context. """ + if "config.json" in os.listdir(folder_path): + if not avoid_legacy_conversion(): + doing_legacy_conversion_message( + logger, cls.__name__, folder_path) + from medcat.utils.legacy.convert_rel_cat import ( + get_rel_cat_from_old) + if 'cdb' in init_kwargs: + cdb = init_kwargs['cdb'] + else: + cdb_path = os.path.join(folder_path, "cdb.dat") + if os.path.exists(cdb_path): + cdb = cast(CDB, deserialise(cdb_path)) + else: + cdb = CDB(config=Config()) + if 'tokenizer' in init_kwargs: + tokenizer = init_kwargs['tokenizer'] + else: + tokenizer = cls._create_throwaway_tokenizer() + return get_rel_cat_from_old(cdb, folder_path, tokenizer) + raise LegacyConversionDisabledError(cls.__name__,) if 'cnf' in init_kwargs: cnf = init_kwargs['cnf'] else: + config_path = os.path.join(folder_path, "config") logger.info( "Was not provided a config when loading a rel cat from '%s'. " "Inferring config from file at '%s'", folder_path, - folder_path) + config_path) cnf = ConfigRelCAT.load(load_path=folder_path) if 'model_config' in init_kwargs: cnf.merge_config(init_kwargs['model_config']) @@ -131,19 +162,11 @@ def deserialise_from(cls, folder_path: str, **init_kwargs tokenizer = init_kwargs['tokenizer'] else: tokenizer = cls._create_throwaway_tokenizer() - if 'cdb' in init_kwargs: - cdb = init_kwargs['cdb'] - else: - cdb_path = os.path.join(folder_path, "cdb.dat") - if os.path.exists(cdb_path): - cdb = cast(CDB, deserialise(cdb_path)) - else: - cdb = CDB(config=Config()) return cls.load_existing( load_path=folder_path, cnf=cnf, base_tokenizer=tokenizer, - cdb=cdb) + cdb=None) def get_strategy(self) -> SerialisingStrategy: return SerialisingStrategy.MANUAL