Skip to content

Commit

Permalink
Add PubMedQA dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
Rita Kurban authored and Rita Kurban committed Dec 27, 2024
1 parent 71a2e18 commit cf878b9
Show file tree
Hide file tree
Showing 13 changed files with 397 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/eva/language/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""eva language API."""

try:
from eva.language.data import datasets
except ImportError as e:
msg = (
"eva language requirements are not installed.\n\n"
"Please pip install as follows:\n"
' python -m pip install "kaiko-eva[language]" --upgrade'
)
raise ImportError(str(e) + "\n\n" + msg) from e

__all__ = ["datasets"]
5 changes: 5 additions & 0 deletions src/eva/language/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Language data API."""

from eva.language.data import datasets

__all__ = ["datasets"]
9 changes: 9 additions & 0 deletions src/eva/language/data/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Language Datasets API."""

from eva.language.data.datasets.classification import PubMedQA
from eva.language.data.datasets.language import LanguageDataset

__all__ = [
"PubMedQA",
"LanguageDataset",
]
7 changes: 7 additions & 0 deletions src/eva/language/data/datasets/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Text classification datasets API."""

from eva.language.data.datasets.classification.pubmedqa import PubMedQA

__all__ = [
"PubMedQA",
]
69 changes: 69 additions & 0 deletions src/eva/language/data/datasets/classification/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""Base for text classification datasets."""

import abc
from typing import Any, Callable, Dict, List, Tuple

import torch
from typing_extensions import override

from eva.language.data.datasets.language import LanguageDataset

class LanguageClassification(LanguageDataset[Tuple[str, torch.Tensor]], abc.ABC):
"""Text classification abstract dataset."""

def __init__(self) -> None:
"""Initializes the text classification dataset."""
super().__init__()

@property
def classes(self) -> List[str] | None:
"""Returns list of class names."""
return None

@property
def class_to_idx(self) -> Dict[str, int] | None:
"""Returns class name to index mapping."""
return None

def load_metadata(self, index: int) -> Dict[str, Any] | None:
"""Returns the dataset metadata.
Args:
index: The index of the data sample.
Returns:
The sample metadata.
"""
return None

@abc.abstractmethod
def load_text(self, index: int) -> str:
"""Returns the text content.
Args:
index: The index of the data sample.
Returns:
The text content.
"""
raise NotImplementedError

@abc.abstractmethod
def load_target(self, index: int) -> torch.Tensor:
"""Returns the target label.
Args:
index: The index of the data sample.
Returns:
The target label.
"""
raise NotImplementedError

@override
def __getitem__(self, index: int) -> Tuple[str, torch.Tensor, Dict[str, Any]]:
return (
self.load_text(index),
self.load_target(index),
self.load_metadata(index) or {}
)
71 changes: 71 additions & 0 deletions src/eva/language/data/datasets/classification/pubmedqa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""PubMedQA dataset class."""

from typing import Dict, List, Any

from typing_extensions import override
from datasets import load_dataset

from eva.language.data.datasets.classification import base


class PubMedQA(base.LanguageClassification):
"""Dataset class for PubMedQA question answering task."""

_license: str = "MIT License (https://github.com/pubmedqa/pubmedqa/blob/master/LICENSE)"
"""Dataset license."""

def __init__(
self,
split: str | None = "train+test+validation",
) -> None:
"""Initialize the PubMedQA dataset.
Args:
split: Dataset split to use. If default, entire dataset of 1000 samples is used.
"""
super().__init__()
self._split = split
self.dataset = load_dataset(
"bigbio/pubmed_qa",
name="pubmed_qa_labeled_fold0_source",
split=split
)

@property
@override
def classes(self) -> List[str]:
return ["no", "yes", "maybe"]

@property
@override
def class_to_idx(self) -> Dict[str, int]:
return {"no": 0, "yes": 1, "maybe": 2}

@override
def load_text(self, index: int) -> str:
sample = self.dataset[index]
return f"Question: {sample['QUESTION']}\nContext: {sample['CONTEXTS']}"

@override
def load_target(self, index: int) -> int:
return self.class_to_idx[self.dataset[index]["final_decision"]]

@override
def load_metadata(self, index: int) -> Dict[str, Any]:
sample = self.dataset[index]
return {
"year": sample["YEAR"],
"labels": sample["LABELS"],
"meshes": sample["MESHES"],
"long_answer": sample["LONG_ANSWER"],
"reasoning_required": sample["reasoning_required_pred"],
"reasoning_free": sample["reasoning_free_pred"],
}

@override
def __len__(self) -> int:
return len(self.dataset)

def _print_license(self) -> None:
"""Prints the dataset license."""
print(f"Dataset license: {self._license}")
13 changes: 13 additions & 0 deletions src/eva/language/data/datasets/language.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Vision Dataset base class."""

import abc
from typing import Generic, TypeVar

from eva.core.data.datasets import base

DataSample = TypeVar("DataSample")
"""The data sample type."""

class LanguageDataset(base.MapDataset, abc.ABC, Generic[DataSample]):
"""Base dataset class for text tasks."""
pass
Loading

0 comments on commit cf878b9

Please sign in to comment.