-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Rita Kurban
authored and
Rita Kurban
committed
Dec 27, 2024
1 parent
71a2e18
commit cf878b9
Showing
13 changed files
with
397 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
"""eva language API.""" | ||
|
||
try: | ||
from eva.language.data import datasets | ||
except ImportError as e: | ||
msg = ( | ||
"eva language requirements are not installed.\n\n" | ||
"Please pip install as follows:\n" | ||
' python -m pip install "kaiko-eva[language]" --upgrade' | ||
) | ||
raise ImportError(str(e) + "\n\n" + msg) from e | ||
|
||
__all__ = ["datasets"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"""Language data API.""" | ||
|
||
from eva.language.data import datasets | ||
|
||
__all__ = ["datasets"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
"""Language Datasets API.""" | ||
|
||
from eva.language.data.datasets.classification import PubMedQA | ||
from eva.language.data.datasets.language import LanguageDataset | ||
|
||
__all__ = [ | ||
"PubMedQA", | ||
"LanguageDataset", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
"""Text classification datasets API.""" | ||
|
||
from eva.language.data.datasets.classification.pubmedqa import PubMedQA | ||
|
||
__all__ = [ | ||
"PubMedQA", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
"""Base for text classification datasets.""" | ||
|
||
import abc | ||
from typing import Any, Callable, Dict, List, Tuple | ||
|
||
import torch | ||
from typing_extensions import override | ||
|
||
from eva.language.data.datasets.language import LanguageDataset | ||
|
||
class LanguageClassification(LanguageDataset[Tuple[str, torch.Tensor]], abc.ABC): | ||
"""Text classification abstract dataset.""" | ||
|
||
def __init__(self) -> None: | ||
"""Initializes the text classification dataset.""" | ||
super().__init__() | ||
|
||
@property | ||
def classes(self) -> List[str] | None: | ||
"""Returns list of class names.""" | ||
return None | ||
|
||
@property | ||
def class_to_idx(self) -> Dict[str, int] | None: | ||
"""Returns class name to index mapping.""" | ||
return None | ||
|
||
def load_metadata(self, index: int) -> Dict[str, Any] | None: | ||
"""Returns the dataset metadata. | ||
Args: | ||
index: The index of the data sample. | ||
Returns: | ||
The sample metadata. | ||
""" | ||
return None | ||
|
||
@abc.abstractmethod | ||
def load_text(self, index: int) -> str: | ||
"""Returns the text content. | ||
Args: | ||
index: The index of the data sample. | ||
Returns: | ||
The text content. | ||
""" | ||
raise NotImplementedError | ||
|
||
@abc.abstractmethod | ||
def load_target(self, index: int) -> torch.Tensor: | ||
"""Returns the target label. | ||
Args: | ||
index: The index of the data sample. | ||
Returns: | ||
The target label. | ||
""" | ||
raise NotImplementedError | ||
|
||
@override | ||
def __getitem__(self, index: int) -> Tuple[str, torch.Tensor, Dict[str, Any]]: | ||
return ( | ||
self.load_text(index), | ||
self.load_target(index), | ||
self.load_metadata(index) or {} | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
"""PubMedQA dataset class.""" | ||
|
||
from typing import Dict, List, Any | ||
|
||
from typing_extensions import override | ||
from datasets import load_dataset | ||
|
||
from eva.language.data.datasets.classification import base | ||
|
||
|
||
class PubMedQA(base.LanguageClassification): | ||
"""Dataset class for PubMedQA question answering task.""" | ||
|
||
_license: str = "MIT License (https://github.com/pubmedqa/pubmedqa/blob/master/LICENSE)" | ||
"""Dataset license.""" | ||
|
||
def __init__( | ||
self, | ||
split: str | None = "train+test+validation", | ||
) -> None: | ||
"""Initialize the PubMedQA dataset. | ||
Args: | ||
split: Dataset split to use. If default, entire dataset of 1000 samples is used. | ||
""" | ||
super().__init__() | ||
self._split = split | ||
self.dataset = load_dataset( | ||
"bigbio/pubmed_qa", | ||
name="pubmed_qa_labeled_fold0_source", | ||
split=split | ||
) | ||
|
||
@property | ||
@override | ||
def classes(self) -> List[str]: | ||
return ["no", "yes", "maybe"] | ||
|
||
@property | ||
@override | ||
def class_to_idx(self) -> Dict[str, int]: | ||
return {"no": 0, "yes": 1, "maybe": 2} | ||
|
||
@override | ||
def load_text(self, index: int) -> str: | ||
sample = self.dataset[index] | ||
return f"Question: {sample['QUESTION']}\nContext: {sample['CONTEXTS']}" | ||
|
||
@override | ||
def load_target(self, index: int) -> int: | ||
return self.class_to_idx[self.dataset[index]["final_decision"]] | ||
|
||
@override | ||
def load_metadata(self, index: int) -> Dict[str, Any]: | ||
sample = self.dataset[index] | ||
return { | ||
"year": sample["YEAR"], | ||
"labels": sample["LABELS"], | ||
"meshes": sample["MESHES"], | ||
"long_answer": sample["LONG_ANSWER"], | ||
"reasoning_required": sample["reasoning_required_pred"], | ||
"reasoning_free": sample["reasoning_free_pred"], | ||
} | ||
|
||
@override | ||
def __len__(self) -> int: | ||
return len(self.dataset) | ||
|
||
def _print_license(self) -> None: | ||
"""Prints the dataset license.""" | ||
print(f"Dataset license: {self._license}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
"""Vision Dataset base class.""" | ||
|
||
import abc | ||
from typing import Generic, TypeVar | ||
|
||
from eva.core.data.datasets import base | ||
|
||
DataSample = TypeVar("DataSample") | ||
"""The data sample type.""" | ||
|
||
class LanguageDataset(base.MapDataset, abc.ABC, Generic[DataSample]): | ||
"""Base dataset class for text tasks.""" | ||
pass |
Oops, something went wrong.