diff --git a/src/eva/language/__init__.py b/src/eva/language/__init__.py
new file mode 100644
index 00000000..e7f42944
--- /dev/null
+++ b/src/eva/language/__init__.py
@@ -0,0 +1,13 @@
+"""eva language API."""
+
+try:
+    from eva.language.data import datasets
+except ImportError as e:
+    msg = (
+        "eva language requirements are not installed.\n\n"
+        "Please pip install as follows:\n"
+        '  python -m pip install "kaiko-eva[language]" --upgrade'
+    )
+    raise ImportError(str(e) + "\n\n" + msg) from e
+
+__all__ = ["datasets"]
diff --git a/src/eva/language/data/__init__.py b/src/eva/language/data/__init__.py
new file mode 100644
index 00000000..2b24bc84
--- /dev/null
+++ b/src/eva/language/data/__init__.py
@@ -0,0 +1,5 @@
+"""Language data API."""
+
+from eva.language.data import datasets
+
+__all__ = ["datasets"]
diff --git a/src/eva/language/data/datasets/__init__.py b/src/eva/language/data/datasets/__init__.py
new file mode 100644
index 00000000..171b0204
--- /dev/null
+++ b/src/eva/language/data/datasets/__init__.py
@@ -0,0 +1,9 @@
+"""Language Datasets API."""
+
+from eva.language.data.datasets.classification import PubMedQA
+from eva.language.data.datasets.language import LanguageDataset
+
+__all__ = [
+    "PubMedQA",
+    "LanguageDataset",
+]
diff --git a/src/eva/language/data/datasets/classification/__init__.py b/src/eva/language/data/datasets/classification/__init__.py
new file mode 100644
index 00000000..093019dc
--- /dev/null
+++ b/src/eva/language/data/datasets/classification/__init__.py
@@ -0,0 +1,7 @@
+"""Text classification datasets API."""
+
+from eva.language.data.datasets.classification.pubmedqa import PubMedQA
+
+__all__ = [
+    "PubMedQA",
+]
diff --git a/src/eva/language/data/datasets/classification/base.py b/src/eva/language/data/datasets/classification/base.py
new file mode 100644
index 00000000..ca081ff8
--- /dev/null
+++ b/src/eva/language/data/datasets/classification/base.py
@@ -0,0 +1,69 @@
+"""Base for text classification datasets."""
+
+import abc
+from typing import Any, Dict, List, Tuple
+
+import torch
+from typing_extensions import override
+
+from eva.language.data.datasets.language import LanguageDataset
+
+class LanguageClassification(LanguageDataset[Tuple[str, torch.Tensor, Dict[str, Any]]], abc.ABC):
+    """Text classification abstract dataset."""
+
+    def __init__(self) -> None:
+        """Initializes the text classification dataset."""
+        super().__init__()
+
+    @property
+    def classes(self) -> List[str] | None:
+        """Returns the list of class names."""
+        return None
+
+    @property
+    def class_to_idx(self) -> Dict[str, int] | None:
+        """Returns the class name to index mapping."""
+        return None
+
+    def load_metadata(self, index: int) -> Dict[str, Any] | None:
+        """Returns the dataset metadata.
+
+        Args:
+            index: The index of the data sample.
+
+        Returns:
+            The sample metadata.
+        """
+        return None
+
+    @abc.abstractmethod
+    def load_text(self, index: int) -> str:
+        """Returns the text content.
+
+        Args:
+            index: The index of the data sample.
+
+        Returns:
+            The text content.
+        """
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def load_target(self, index: int) -> torch.Tensor:
+        """Returns the target label.
+
+        Args:
+            index: The index of the data sample.
+
+        Returns:
+            The target label.
+ """ + raise NotImplementedError + + @override + def __getitem__(self, index: int) -> Tuple[str, torch.Tensor, Dict[str, Any]]: + return ( + self.load_text(index), + self.load_target(index), + self.load_metadata(index) or {} + ) diff --git a/src/eva/language/data/datasets/classification/pubmedqa.py b/src/eva/language/data/datasets/classification/pubmedqa.py new file mode 100644 index 00000000..7f185c30 --- /dev/null +++ b/src/eva/language/data/datasets/classification/pubmedqa.py @@ -0,0 +1,71 @@ +"""PubMedQA dataset class.""" + +from typing import Dict, List, Any + +from typing_extensions import override +from datasets import load_dataset + +from eva.language.data.datasets.classification import base + + +class PubMedQA(base.LanguageClassification): + """Dataset class for PubMedQA question answering task.""" + + _license: str = "MIT License (https://github.com/pubmedqa/pubmedqa/blob/master/LICENSE)" + """Dataset license.""" + + def __init__( + self, + split: str | None = "train+test+validation", + ) -> None: + """Initialize the PubMedQA dataset. + + Args: + split: Dataset split to use. If default, entire dataset of 1000 samples is used. + """ + super().__init__() + self._split = split + self.dataset = load_dataset( + "bigbio/pubmed_qa", + name="pubmed_qa_labeled_fold0_source", + split=split + ) + + @property + @override + def classes(self) -> List[str]: + return ["no", "yes", "maybe"] + + @property + @override + def class_to_idx(self) -> Dict[str, int]: + return {"no": 0, "yes": 1, "maybe": 2} + + @override + def load_text(self, index: int) -> str: + sample = self.dataset[index] + return f"Question: {sample['QUESTION']}\nContext: {sample['CONTEXTS']}" + + @override + def load_target(self, index: int) -> int: + return self.class_to_idx[self.dataset[index]["final_decision"]] + + @override + def load_metadata(self, index: int) -> Dict[str, Any]: + sample = self.dataset[index] + return { + "year": sample["YEAR"], + "labels": sample["LABELS"], + "meshes": sample["MESHES"], + "long_answer": sample["LONG_ANSWER"], + "reasoning_required": sample["reasoning_required_pred"], + "reasoning_free": sample["reasoning_free_pred"], + } + + @override + def __len__(self) -> int: + return len(self.dataset) + + def _print_license(self) -> None: + """Prints the dataset license.""" + print(f"Dataset license: {self._license}") diff --git a/src/eva/language/data/datasets/language.py b/src/eva/language/data/datasets/language.py new file mode 100644 index 00000000..7f5688b5 --- /dev/null +++ b/src/eva/language/data/datasets/language.py @@ -0,0 +1,13 @@ +"""Vision Dataset base class.""" + +import abc +from typing import Generic, TypeVar + +from eva.core.data.datasets import base + +DataSample = TypeVar("DataSample") +"""The data sample type.""" + +class LanguageDataset(base.MapDataset, abc.ABC, Generic[DataSample]): + """Base dataset class for text tasks.""" + pass \ No newline at end of file diff --git a/src/eva/language/experiment.ipynb b/src/eva/language/experiment.ipynb new file mode 100644 index 00000000..86f3eea5 --- /dev/null +++ b/src/eva/language/experiment.ipynb @@ -0,0 +1,143 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'eva'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[25], line 
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = load_dataset(\n",
+    "    \"bigbio/pubmed_qa\",\n",
+    "    name=\"pubmed_qa_labeled_fold0_source\",\n",
+    "    split=\"train+test+validation\"\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "! pip install kaiko-eva"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "kaiko",
+   "language": "python",
+   "name": "my_custom_kernel"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/tests/eva/language/__init__.py b/tests/eva/language/__init__.py
new file mode 100644
index 00000000..e17ac3fd
--- /dev/null
+++ b/tests/eva/language/__init__.py
@@ -0,0 +1 @@
+"""EVA language tests."""
diff --git a/tests/eva/language/data/__init__.py b/tests/eva/language/data/__init__.py
new file mode 100644
index 00000000..d7969dda
--- /dev/null
+++ b/tests/eva/language/data/__init__.py
@@ -0,0 +1 @@
+"""Language data tests."""
diff --git a/tests/eva/language/data/datasets/__init__.py b/tests/eva/language/data/datasets/__init__.py
new file mode 100644
index 00000000..f3604a2f
--- /dev/null
+++ b/tests/eva/language/data/datasets/__init__.py
@@ -0,0 +1 @@
+"""Language datasets tests."""
diff --git a/tests/eva/language/data/datasets/classification/__init__.py b/tests/eva/language/data/datasets/classification/__init__.py
new file mode 100644
index 00000000..1f19f351
--- /dev/null
+++ b/tests/eva/language/data/datasets/classification/__init__.py
@@ -0,0 +1 @@
+"""Tests for the text classification datasets."""
diff --git a/tests/eva/language/data/datasets/classification/test_pubmedqa.py b/tests/eva/language/data/datasets/classification/test_pubmedqa.py
new file mode 100644
index 00000000..c76e4efc
--- /dev/null
+++ b/tests/eva/language/data/datasets/classification/test_pubmedqa.py
@@ -0,0 +1,67 @@
+"""PubMedQA dataset tests."""
+
+from typing import Literal
+
+import pytest
+import torch
+
+from eva.language.data import datasets
+
+
+@pytest.mark.parametrize(
+    "split, expected_length",
+    [
+        ("train", 450),
+        ("test", 500),
+        ("validation", 50),
+        ("train+test+validation", 1000),
+    ],
+)
+def test_length(pubmedqa_dataset: datasets.PubMedQA, expected_length: int) -> None:
+    """Tests the length of the dataset."""
+    assert len(pubmedqa_dataset) == expected_length
+
+
+@pytest.mark.parametrize(
+    "split, index",
+    [
+        ("train", 0),
+        ("train", 10),
+        ("test", 0),
+        ("validation", 0),
+        ("train+test+validation", 0),
+    ],
+)
+def test_sample(pubmedqa_dataset: datasets.PubMedQA, index: int) -> None:
+    """Tests the format of a dataset sample."""
+    sample = pubmedqa_dataset[index]
+    assert isinstance(sample, tuple)
+    assert len(sample) == 3
+
+    text, target, metadata = sample
+    assert isinstance(text, str)
+    assert text.startswith("Question: ")
+    assert "Context: " in text
+
+    assert isinstance(target, torch.Tensor)
+    assert int(target) in [0, 1, 2]
+
+    assert isinstance(metadata, dict)
+    required_keys = {"year", "labels", "meshes", "long_answer",
+                     "reasoning_required", "reasoning_free"}
+    assert all(key in metadata for key in required_keys)
+
+
+@pytest.mark.parametrize("split", ["train", "test", "validation", "train+test+validation"])
+def test_classes(pubmedqa_dataset: datasets.PubMedQA) -> None:
+    """Tests the dataset classes."""
+    assert pubmedqa_dataset.classes == ["no", "yes", "maybe"]
+    assert pubmedqa_dataset.class_to_idx == {"no": 0, "yes": 1, "maybe": 2}
+
+
+@pytest.fixture(scope="function")
+def pubmedqa_dataset(
+    split: Literal["train", "test", "validation", "train+test+validation"],
+) -> datasets.PubMedQA:
+    """PubMedQA dataset fixture."""
+    return datasets.PubMedQA(split=split)