diff --git a/Dockerfile b/Dockerfile
index 5ea240ed..beaba037 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -2,7 +2,7 @@ FROM python:3.10-slim-bookworm
 LABEL org.opencontainers.image.authors="grp-natlibfi-annif@helsinki.fi"
 SHELL ["/bin/bash", "-c"]
 
-ARG optional_dependencies="voikko fasttext nn omikuji yake spacy stwfsa"
+ARG optional_dependencies="voikko fasttext nn omikuji yake spacy stwfsa pecos"
 ARG POETRY_VIRTUALENVS_CREATE=false
 
 # Install system dependencies needed at runtime:
@@ -37,6 +37,10 @@ RUN if [[ $optional_dependencies =~ "spacy" ]]; then \
     python -m spacy download $model; \
     done; \
     fi
+RUN if [[ $optional_dependencies =~ "pecos" ]]; then \
+    mkdir /.cache -m a=rwx; \
+    fi
+
 # Second round of installation with the actual code:
 COPY annif /Annif/annif
diff --git a/annif/backend/__init__.py b/annif/backend/__init__.py
index 7be1264b..80056c15 100644
--- a/annif/backend/__init__.py
+++ b/annif/backend/__init__.py
@@ -89,6 +89,17 @@ def _tfidf() -> Type[AnnifBackend]:
     return tfidf.TFIDFBackend
 
 
+def _xtransformer() -> Type[AnnifBackend]:
+    try:
+        from . import xtransformer
+
+        return xtransformer.XTransformerBackend
+    except ImportError:
+        raise ValueError(
+            "XTransformer not available, not enabling XTransformer backend"
+        )
+
+
 def _yake() -> Type[AnnifBackend]:
     try:
         from . import yake
@@ -111,6 +122,7 @@ def _yake() -> Type[AnnifBackend]:
     "stwfsa": _stwfsa,
     "svc": _svc,
     "tfidf": _tfidf,
+    "xtransformer": _xtransformer,
     "yake": _yake,
 }
diff --git a/annif/backend/fasttext.py b/annif/backend/fasttext.py
index e102b02b..e98437aa 100644
--- a/annif/backend/fasttext.py
+++ b/annif/backend/fasttext.py
@@ -124,11 +124,7 @@ def _create_model(self, params: dict[str, Any], jobs: int) -> None:
         self.info("creating fastText model")
         trainpath = os.path.join(self.datadir, self.TRAIN_FILE)
         modelpath = os.path.join(self.datadir, self.MODEL_FILE)
-        params = {
-            param: self.FASTTEXT_PARAMS[param](val)
-            for param, val in params.items()
-            if param in self.FASTTEXT_PARAMS
-        }
+        params = annif.util.apply_param_parse_config(self.FASTTEXT_PARAMS, params)
         if jobs != 0:  # jobs set by user to non-default value
             params["thread"] = jobs
         self.debug("Model parameters: {}".format(params))
diff --git a/annif/backend/omikuji.py b/annif/backend/omikuji.py
index 89d8f0ea..6e77994b 100644
--- a/annif/backend/omikuji.py
+++ b/annif/backend/omikuji.py
@@ -103,9 +103,7 @@ def _create_model(self, params: dict[str, Any], jobs: int) -> None:
         hyper_param.collapse_every_n_layers = int(params["collapse_every_n_layers"])
 
         self._model = omikuji.Model.train_on_data(train_path, hyper_param, jobs or None)
-        if os.path.exists(model_path):
-            shutil.rmtree(model_path)
-        self._model.save(os.path.join(self.datadir, self.MODEL_FILE))
+        annif.util.atomic_save_folder(self._model, model_path)
 
     def _train(
         self,
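Reviewer note (commentary, not part of the patch): the fasttext hunk above and the stwfsa hunk below both replace a hand-rolled dict comprehension with the new `annif.util.apply_param_parse_config` helper introduced later in this diff. A minimal sketch of the helper's behavior; the parser map and raw values here are made up for illustration:

```python
from annif.util import apply_param_parse_config, boolean

# illustrative parser map and raw (string-valued) params, not from the diff
PARSERS = {"min_df": int, "use_gpu": boolean}
raw = {"min_df": "2", "use_gpu": None, "unknown": "x"}

parsed = apply_param_parse_config(PARSERS, raw)
assert parsed == {"min_df": 2}  # unknown keys and None values are dropped
```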
diff --git a/annif/backend/stwfsa.py b/annif/backend/stwfsa.py
index fdc962b1..ec9b756c 100644
--- a/annif/backend/stwfsa.py
+++ b/annif/backend/stwfsa.py
@@ -7,7 +7,7 @@
 from annif.exception import NotInitializedException, NotSupportedException
 from annif.suggestion import SubjectSuggestion
-from annif.util import atomic_save, boolean
+from annif.util import apply_param_parse_config, atomic_save, boolean
 
 from . import backend
@@ -106,11 +106,7 @@ def _train(
         jobs: int = 0,
     ) -> None:
         X, y = self._load_data(corpus)
-        new_params = {
-            key: self.STWFSA_PARAMETERS[key](val)
-            for key, val in params.items()
-            if key in self.STWFSA_PARAMETERS
-        }
+        new_params = apply_param_parse_config(self.STWFSA_PARAMETERS, params)
         p = StwfsapyPredictor(
             graph=self.project.vocab.as_graph(),
             langs=frozenset([params["language"]]),
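A note on the new backend that follows (commentary, not part of the patch): its `_create_train_files` method encodes each document's subject set as a one-row sparse indicator matrix and stacks those rows into the training label matrix. A self-contained sketch of that construction, with made-up subject ids and vocabulary size:

```python
import numpy as np
import scipy.sparse as sp

subject_ids = [3, 17]  # positions of one document's subjects (made up)
n_subjects = 130       # size of the subject vocabulary (made up)

# one row with a 1.0 in each assigned subject's column
row = sp.csr_matrix(
    (np.ones(len(subject_ids)), (np.zeros(len(subject_ids)), subject_ids)),
    shape=(1, n_subjects),
    dtype=np.float32,
)
labels = sp.vstack([row], format="csr")  # one such row per training document
```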
"bootstrap_model": "linear", + } + + def _initialize_model(self): + if self._model is None: + path = osp.join(self.datadir, self.model_folder) + self.debug("loading model from {}".format(path)) + if osp.exists(path): + self._model = XTransformer.load(path) + else: + raise NotInitializedException( + "model {} not found".format(path), backend_id=self.backend_id + ) + + def initialize(self, parallel: bool = False) -> None: + self.initialize_vectorizer() + self._initialize_model() + + def default_params(self): + params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy() + params.update(self.DEFAULT_PARAMETERS) + return params + + def _create_train_files(self, veccorpus, corpus): + self.info("creating train file") + Xs = [] + ys = [] + txt_pth = osp.join(self.datadir, self.train_txt_file) + with open(txt_pth, "w", encoding="utf-8") as txt_file: + for doc, vector in zip(corpus.documents, veccorpus): + subject_set = doc.subject_set + if not (subject_set and doc.text): + continue # noqa + print(" ".join(doc.text.split()), file=txt_file) + Xs.append(sp.csr_matrix(vector, dtype=np.float32).sorted_indices()) + ys.append( + sp.csr_matrix( + ( + np.ones(len(subject_set)), + (np.zeros(len(subject_set)), [s for s in subject_set]), + ), + shape=(1, len(self.project.subjects)), + dtype=np.float32, + ).sorted_indices() + ) + atomic_save( + sp.vstack(Xs, format="csr"), + self.datadir, + self.train_X_file, + method=lambda mtrx, target: sp.save_npz(target, mtrx, compressed=True), + ) + atomic_save( + sp.vstack(ys, format="csr"), + self.datadir, + self.train_y_file, + method=lambda mtrx, target: sp.save_npz(target, mtrx, compressed=True), + ) + + def _create_model(self, params, jobs): + train_txts = Preprocessor.load_data_from_file( + osp.join(self.datadir, self.train_txt_file), + label_text_path=None, + text_pos=0, + )["corpus"] + train_X = sp.load_npz(osp.join(self.datadir, self.train_X_file)) + train_y = sp.load_npz(osp.join(self.datadir, self.train_y_file)) + model_path = osp.join(self.datadir, self.model_folder) + new_params = apply_param_parse_config(self.PARAM_CONFIG, self.params) + new_params["only_topk"] = new_params.pop("limit") + train_params = XTransformer.TrainParams.from_dict( + new_params, recursive=True + ).to_dict() + pred_params = XTransformer.PredParams.from_dict( + new_params, recursive=True + ).to_dict() + + self.info("Start training") + # enable progress + matcher.LOGGER.setLevel(logging.DEBUG) + matcher.LOGGER.addHandler(logging.StreamHandler(stream=sys.stdout)) + model.LOGGER.setLevel(logging.DEBUG) + model.LOGGER.addHandler(logging.StreamHandler(stream=sys.stdout)) + self._model = XTransformer.train( + MLProblemWithText(train_txts, train_y, X_feat=train_X), + clustering=None, + val_prob=None, + train_params=train_params, + pred_params=pred_params, + beam_size=int(params["beam_size"]), + steps_scale=None, + label_feat=None, + ) + atomic_save_folder(self._model, model_path) + + def _train( + self, + corpus: DocumentCorpus, + params: dict[str, Any], + jobs: int = 0, + ) -> None: + if corpus == "cached": + self.info("Reusing cached training data from previous run.") + else: + if corpus.is_empty(): + raise NotSupportedException("Cannot t project with no documents") + input = (doc.text for doc in corpus.documents) + vecparams = { + "min_df": int(params["min_df"]), + "tokenizer": self.project.analyzer.tokenize_words, + "ngram_range": (1, int(params["ngram"])), + } + veccorpus = self.create_vectorizer(input, vecparams) + self._create_train_files(veccorpus, corpus) + self._create_model(params, 
+
+    def _suggest_batch(
+        self, texts: list[str], params: dict[str, Any]
+    ) -> SuggestionBatch:
+        vector = self.vectorizer.transform(texts)
+        if vector.nnz == 0:  # All zero vector, empty result
+            return list()
+        new_params = apply_param_parse_config(self.PARAM_CONFIG, params)
+        prediction = self._model.predict(
+            texts,
+            X_feat=vector.sorted_indices(),
+            batch_size=new_params["batch_size"],
+            use_gpu=True,
+            only_top_k=new_params["limit"],
+            post_processor=new_params["post_processor"],
+        )
+        current_batchsize = prediction.get_shape()[0]
+        batch_result = []
+        for i in range(current_batchsize):
+            results = []
+            row = prediction.getrow(i)
+            for idx, score in zip(row.indices, row.data):
+                results.append(SubjectSuggestion(subject_id=idx, score=score))
+            batch_result.append(results)
+        return SuggestionBatch.from_sequence(batch_result, self.project.subjects)
diff --git a/annif/util.py b/annif/util.py
index b03c63ec..b816d24d 100644
--- a/annif/util.py
+++ b/annif/util.py
@@ -7,6 +7,7 @@
 import os
 import os.path
 import tempfile
+from shutil import rmtree
 from typing import Any, Callable
 
 from annif import logger
@@ -33,7 +34,8 @@ def atomic_save(
     """Save the given object (which must have a .save() method, unless the
     method parameter is given) into the given directory with the given
     filename, using a temporary file and renaming the temporary file to the
-    final name."""
+    final name. The .save() method or the function provided in the method argument
+    will be called with the path to the temporary file."""
 
     prefix, suffix = os.path.splitext(filename)
     prefix = "tmp-" + prefix
@@ -50,6 +52,31 @@
         os.rename(fn, newname)
 
 
+def atomic_save_folder(obj, dirname, method=None):
+    """Save the given object (which must have a .save() method, unless the
+    method parameter is given) into the given directory,
+    using a temporary directory and renaming the temporary directory to the
+    final name. The .save() method or the function provided in the method argument
+    will be called with the path to the temporary directory."""
+
+    tldir = os.path.dirname(dirname.rstrip("/"))
+    os.makedirs(dirname, exist_ok=True)
+    tempdir = tempfile.TemporaryDirectory(dir=tldir)
+    temp_dir_name = tempdir.name
+    target_pth = dirname
+    logger.debug("saving %s to temporary file %s", str(obj)[:90], temp_dir_name)
+    if method is not None:
+        method(obj, temp_dir_name)
+    else:
+        obj.save(temp_dir_name)
+    for fn in glob.glob(temp_dir_name + "*"):
+        newname = fn.replace(temp_dir_name, target_pth)
+        logger.debug("renaming temporary file %s to %s", fn, newname)
+        if os.path.isdir(newname):
+            rmtree(newname)
+        os.replace(fn, newname)
+
+
 def cleanup_uri(uri: str) -> str:
     """remove angle brackets from a URI, if any"""
     if uri.startswith("<") and uri.endswith(">"):
@@ -93,6 +120,15 @@ def parse_args(param_string: str) -> tuple[list, dict]:
     return posargs, kwargs
 
 
+def apply_param_parse_config(configs, params):
+    """Applies a parsing configuration to a parameter dict."""
+    return {
+        param: configs[param](val)
+        for param, val in params.items()
+        if param in configs and val is not None
+    }
+
+
 def boolean(val: Any) -> bool:
     """Convert the given value to a boolean True/False value, if it isn't already.
     True values are '1', 'yes', 'true', and 'on' (case insensitive), everything
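For clarity (commentary, not part of the patch): `atomic_save_folder` writes into a sibling temporary directory and only then renames it over the target, so an interrupted save never leaves a half-written model folder behind. A minimal sketch of using it with a `method=` callback; the folder and file names here are illustrative:

```python
import os
import annif.util

def save_model(obj, path):
    # stand-in for a model's multi-file save; writes one file into the folder
    with open(os.path.join(path, "weights.bin"), "wb") as f:
        f.write(b"\x00" * 4)

# obj can be anything when method= is given; only the callback touches it
annif.util.atomic_save_folder(object(), "data/toy-model", method=save_model)
# "data/toy-model" now contains weights.bin; the temporary directory
# was renamed into place rather than written to directly
```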
diff --git a/pyproject.toml b/pyproject.toml
index 970fd250..ff31e10d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -57,6 +57,7 @@
 omikuji = { version = "0.5.*", optional = true }
 yake = { version = "0.4.8", optional = true }
 spacy = { version = "3.7.*", optional = true }
 stwfsapy = { version = "0.4.*", optional = true, python = "<3.12" }
+libpecos = { version = "1.*", optional = true }
 
 [tool.poetry.dev-dependencies]
 py = "*"
@@ -78,6 +79,7 @@
 omikuji = ["omikuji"]
 yake = ["yake"]
 spacy = ["spacy"]
 stwfsa = ["stwfsapy"]
+pecos = ["libpecos"]
 
 [tool.poetry.scripts]
 annif = "annif.cli:cli"
diff --git a/tests/test_backend.py b/tests/test_backend.py
index b7e583d1..12b31773 100644
--- a/tests/test_backend.py
+++ b/tests/test_backend.py
@@ -95,6 +95,16 @@ def test_get_backend_yake_not_installed():
     assert "YAKE not available" in str(excinfo.value)
 
 
+@pytest.mark.skipif(
+    importlib.util.find_spec("pecos") is not None,
+    reason="test requires that PECOS is NOT installed",
+)
+def test_get_backend_xtransformer_not_installed():
+    with pytest.raises(ValueError) as excinfo:
+        annif.backend.get_backend("xtransformer")
+    assert "XTransformer not available" in str(excinfo.value)
+
+
 @pytest.mark.skipif(
     importlib.util.find_spec("stwfsapy") is not None,
     reason="test requires that STWFSA is NOT installed",
diff --git a/tests/test_backend_xtransformer.py b/tests/test_backend_xtransformer.py
new file mode 100644
index 00000000..f3e7af76
--- /dev/null
+++ b/tests/test_backend_xtransformer.py
@@ -0,0 +1,250 @@
+"""Unit tests for the XTransformer backend in Annif"""
+
+import os.path as osp
+from os import mknod
+from unittest.mock import MagicMock, patch
+
+import pytest
+from scipy.sparse import csr_matrix, load_npz
+
+import annif.backend
+import annif.corpus
+from annif.exception import NotInitializedException, NotSupportedException
+
+pytest.importorskip("annif.backend.xtransformer")
+XTransformer = annif.backend.xtransformer.XTransformer
+
+
+@pytest.fixture
+def mocked_xtransformer(datadir, project):
+    model_mock = MagicMock()
+    model_mock.save.side_effect = lambda x: mknod(osp.join(x, "test"))
+
+    return patch.object(
+        annif.backend.xtransformer.XTransformer, "train", return_value=model_mock
+    )
+
+
+def test_xtransformer_default_params(project):
+    backend_type = annif.backend.get_backend("xtransformer")
+    xtransformer = backend_type(
+        backend_id="xtransformer", config_params={}, project=project
+    )
+    expected = {
+        "min_df": 1,
+        "ngram": 1,
+        "fix_clustering": False,
+        "nr_splits": 16,
+        "min_codes": None,
+        "max_leaf_size": 100,
+        "imbalanced_ratio": 0.0,
+        "imbalanced_depth": 100,
+        "max_match_clusters": 32768,
+        "do_fine_tune": True,
+        "model_shortcut": "distilbert-base-multilingual-uncased",
+        "beam_size": 20,
+        "limit": 100,
+        "post_processor": "sigmoid",
+        "negative_sampling": "tfn",
+        "ensemble_method": "transformer-only",
+        "threshold": 0.1,
+        "loss_function": "squared-hinge",
+        "truncate_length": 128,
+        "hidden_droput_prob": 0.1,
+        "batch_size": 32,
+        "gradient_accumulation_steps": 1,
+        "learning_rate": 1e-4,
+        "weight_decay": 0.0,
+        "adam_epsilon": 1e-8,
+        "num_train_epochs": 1,
+        "max_steps": 0,
+        "lr_schedule": "linear",
+        "warmup_steps": 0,
+        "logging_steps": 100,
+        "save_steps": 1000,
+        "max_active_matching_labels": None,
+        "max_num_labels_in_gpu": 65536,
+        "use_gpu": True,
+        "bootstrap_model": "linear",
+    }
+    actual = xtransformer.params
+    assert len(actual) == len(expected)
+    for param, val in expected.items():
+        assert param in actual and actual[param] == val
+
+
+def test_xtransformer_suggest_no_vectorizer(project):
+    backend_type = annif.backend.get_backend("xtransformer")
+    xtransformer = backend_type(
+        backend_id="xtransformer", config_params={}, project=project
+    )
+    with pytest.raises(NotInitializedException):
+        xtransformer.suggest("example text")
+
+
+def test_xtransformer_create_train_files(tmpdir, project, datadir):
+    tmpfile = tmpdir.join("document.tsv")
+    tmpfile.write(
+        "nonexistent\thttp://example.com/nonexistent\n"
+        + "arkeologia\thttp://www.yso.fi/onto/yso/p1265\n"
+        + "...\thttp://example.com/none"
+    )
+    corpus = annif.corpus.DocumentFile(str(tmpfile), project.subjects)
+    backend_type = annif.backend.get_backend("xtransformer")
+    xtransformer = backend_type(
+        backend_id="xtransformer", config_params={}, project=project
+    )
+    input = (doc.text for doc in corpus.documents)
+    veccorpus = xtransformer.create_vectorizer(input, {})
+    xtransformer._create_train_files(veccorpus, corpus)
+    assert datadir.join("xtransformer-train-X.npz").exists()
+    assert datadir.join("xtransformer-train-y.npz").exists()
+    assert datadir.join("xtransformer-train-raw.txt").exists()
+    traindata = datadir.join("xtransformer-train-raw.txt").read().splitlines()
+    assert len(traindata) == 1
+    train_features = load_npz(str(datadir.join("xtransformer-train-X.npz")))
+    assert train_features.shape[0] == 1
+    train_labels = load_npz(str(datadir.join("xtransformer-train-y.npz")))
+    assert train_labels.shape[0] == 1
+
+
+def test_xtransformer_train(datadir, document_corpus, project, mocked_xtransformer):
+    backend_type = annif.backend.get_backend("xtransformer")
+    xtransformer = backend_type(
+        backend_id="xtransformer", config_params={}, project=project
+    )
+
+    with mocked_xtransformer as train_mock:
+        xtransformer.train(document_corpus)
+
+    train_mock.assert_called_once()
+    first_arg = train_mock.call_args.args[0]
+    kwargs = train_mock.call_args.kwargs
+    assert len(first_arg.X_text) == 6402
+    assert first_arg.X_feat.shape == (6402, 12479)
+    assert first_arg.Y.shape == (6402, 130)
+    expected_pred_params = XTransformer.PredParams.from_dict(
+        {
+            "beam_size": 20,
+            "only_topk": 100,
+            "post_processor": "sigmoid",
+            "truncate_length": 128,
+        },
+        recursive=True,
+    ).to_dict()
+
+    expected_train_params = XTransformer.TrainParams.from_dict(
+        {
+            "do_fine_tune": True,
+            "only_encoder": False,
+            "fix_clustering": False,
+            "max_match_clusters": 32768,
+            "nr_splits": 16,
+            "max_leaf_size": 100,
+            "imbalanced_ratio": 0.0,
+            "imbalanced_depth": 100,
+            "model_shortcut": "distilbert-base-multilingual-uncased",
+            "post_processor": "sigmoid",
+            "negative_sampling": "tfn",
+            "ensemble_method": "transformer-only",
+            "threshold": 0.1,
+            "loss_function": "squared-hinge",
+            "truncate_length": 128,
+            "hidden_droput_prob": 0.1,
+            "batch_size": 32,
+            "gradient_accumulation_steps": 1,
+            "learning_rate": 1e-4,
+            "weight_decay": 0.0,
+            "adam_epsilon": 1e-8,
+            "num_train_epochs": 1,
+            "max_steps": 0,
+            "lr_schedule": "linear",
+            "warmup_steps": 0,
+            "logging_steps": 100,
+            "save_steps": 1000,
+            "max_active_matching_labels": None,
+            "max_num_labels_in_gpu": 65536,
+            "use_gpu": True,
+            "bootstrap_model": "linear",
+        },
+        recursive=True,
+    ).to_dict()
+
+    assert kwargs == {
+        "clustering": None,
+        "val_prob": None,
+        "steps_scale": None,
+        "label_feat": None,
+        "beam_size": 20,
+        "pred_params": expected_pred_params,
"train_params": expected_train_params, + } + xtransformer._model.save.assert_called_once() + assert datadir.join("xtransformer-model").check() + + +def test_xtransformer_train_cached(mocked_xtransformer, datadir, project): + backend_type = annif.backend.get_backend("xtransformer") + xtransformer = backend_type( + backend_id="xtransfomer", config_params={}, project=project + ) + xtransformer._create_train_files = MagicMock() + xtransformer._create_model = MagicMock() + with mocked_xtransformer: + xtransformer.train("cached") + xtransformer._create_train_files.assert_not_called() + xtransformer._create_model.assert_called_once() + + +def test_xtransfomer_train_no_documents(datadir, project, empty_corpus): + backend_type = annif.backend.get_backend("xtransformer") + xtransformer = backend_type( + backend_id="xtransfomer", config_params={}, project=project + ) + with pytest.raises(NotSupportedException): + xtransformer.train(empty_corpus) + + +def test_xtransformer_suggest(project): + backend_type = annif.backend.get_backend("xtransformer") + xtransformer = backend_type( + backend_id="xtransfomer", config_params={}, project=project + ) + xtransformer._model = MagicMock() + xtransformer._model.predict.return_value = csr_matrix([0, 0.2, 0, 0, 0, 0.5, 0]) + results = xtransformer.suggest( + [ + """Arkeologiaa sanotaan joskus myös + muinaistutkimukseksi tai muinaistieteeksi. Se on humanistinen tiede + tai oikeammin joukko tieteitä, jotka tutkivat ihmisen menneisyyttä. + Tutkimusta tehdään analysoimalla muinaisjäännöksiä eli niitä jälkiä, + joita ihmisten toiminta on jättänyt maaperään tai vesistöjen + pohjaan.""" + ] + )[0] + xtransformer._model.predict.assert_called_once() + + ship_finds = project.subjects.by_uri("https://www.yso.fi/onto/yso/p8869") + assert ship_finds in [result.subject_id for result in results] + + +def test_xtransformer_suggest_no_input(project, datadir): + backend_type = annif.backend.get_backend("xtransformer") + xtransformer = backend_type( + backend_id="xtransfomer", config_params={"limit": 5}, project=project + ) + xtransformer._model = MagicMock() + results = xtransformer.suggest(["j"]) + assert len(results) == 0 + + +def test_xtransformer_suggest_no_model(datadir, project): + backend_type = annif.backend.get_backend("xtransformer") + xtransformer = backend_type( + backend_id="xtransfomer", config_params={}, project=project + ) + datadir.remove() + with pytest.raises(NotInitializedException): + xtransformer.suggest("example text") diff --git a/tests/test_util.py b/tests/test_util.py index 18333311..afce3a11 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,5 +1,8 @@ """Unit tests for Annif utility functions""" +import os.path as osp +from unittest.mock import MagicMock + import annif.util @@ -17,3 +20,64 @@ def test_metric_code(): for input, output in zip(inputs, outputs): assert annif.util.metric_code(input) == output + + +def test_apply_parse_param_config(): + fun0 = MagicMock() + fun0.return_value = 23 + fun1 = MagicMock() + fun1.return_value = "ret" + configs = {"a": fun0, "c": fun1} + params = {"a": 0, "b": 23, "c": None} + ret = annif.util.apply_param_parse_config(configs, params) + assert ret == {"a": 23} + fun0.assert_called_once_with(0) + fun1.assert_not_called() + + +def _save(obj, pth): + with open(pth, "w") as f: + print("test file content", file=f) + + +def test_atomic_save_method(tmpdir): + fname = "tst_file_method.txt" + annif.util.atomic_save(None, tmpdir.strpath, fname, method=_save) + f_pth = tmpdir.join(fname) + assert f_pth.exists() + 
+    with f_pth.open() as f:
+        assert f.readlines() == ["test file content\n"]
+
+
+def test_atomic_save(tmpdir):
+    fname = "tst_file_obj.txt"
+    to_save = MagicMock()
+    to_save.save.side_effect = lambda pth: _save(None, pth)
+    annif.util.atomic_save(to_save, tmpdir.strpath, fname)
+    f_pth = tmpdir.join(fname)
+    assert f_pth.exists()
+    with f_pth.open() as f:
+        assert f.readlines() == ["test file content\n"]
+    to_save.save.assert_called_once()
+    temp_pth = to_save.save.call_args.args[0]
+    assert temp_pth != f_pth.strpath
+
+
+def test_atomic_save_folder(tmpdir):
+    folder_name = "test_save"
+    fname_0 = "tst_file_0"
+    fname_1 = "tst_file_1"
+
+    def save_folder(obj, pth):
+        _save(None, osp.join(pth, fname_0))
+        _save(None, osp.join(pth, fname_1))
+
+    folder_path = tmpdir.join(folder_name)
+    annif.util.atomic_save_folder(None, folder_path.strpath, method=save_folder)
+    assert folder_path.exists()
+    for f_name in [fname_0, fname_1]:
+        f_pth = folder_path.join(f_name)
+        assert f_pth.exists()
+        with f_pth.open() as f:
+            assert f.readlines() == ["test file content\n"]
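Closing note (commentary, not part of the patch): `_suggest_batch` decodes predictions using only standard scipy CSR accessors, which is why the plain `csr_matrix` mock in `test_xtransformer_suggest` exercises the same code path as a real pecos prediction. A self-contained sketch of that decoding, with a made-up 2x4 score matrix:

```python
import scipy.sparse as sp

# made-up scores: 2 documents over 4 subjects
pred = sp.csr_matrix([[0.0, 0.2, 0.0, 0.5], [0.9, 0.0, 0.0, 0.0]])

for i in range(pred.get_shape()[0]):
    row = pred.getrow(i)
    # column indices are subject ids, data holds the corresponding scores
    suggestions = list(zip(row.indices.tolist(), row.data.tolist()))
    print(i, suggestions)  # e.g. 0 [(1, 0.2), (3, 0.5)]
```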