Source code for text_machina.src.extractors.base
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from datasets import Dataset
from ..common.exceptions import ExtractorEmptyColumns
from ..config import InputConfig
from ..types import TaskType
from .utils import clean_inputs
[docs]class Extractor(ABC):
"""
Base class for an extractor.
"""
def __init__(
self,
input_config: InputConfig,
task_type: TaskType,
workspace: Dict[str, Any] = {},
args: Dict[str, Any] = {},
):
self.input_config = input_config
self.task_type = task_type
self.workspace: Dict[str, Any] = workspace
self.args: Dict[str, Any] = args
self.check_valid_args()
[docs] def check_valid_args(self) -> None:
"""
Checks if the arguments passed to the extractor are valid.
Raises:
ExtractorInvalidArgs: if the arguments are invalid.
"""
...
@abstractmethod
def _extract(self, dataset: Dataset) -> Dict[str, List[str]]:
"""
Returns the prompt inputs for each sample in a dataset.
This method must be overridden in each new extractor.
Example:
input = Dataset({"text": ["hi Jose", "hi Areg"], "label": [0, 1]})
output = {"entities": ["Jose", "Areg"], "interject": ["hi", "hi"]}
Args:
dataset (Dataset): A dataset to extract inputs from.
Returns:
Dict[str, List[str]]: A dictionary mapping each template
key to a list of prompt inputs
(one input per template key and example).
"""
...
[docs] def prepare_human(self, human_texts: List[str]) -> List[str]:
"""
Prepares the human texts. Some extractors could need to modify
human texts according to the extractions, e.g., remove prefixes
from texts to ensure that generations and human texts are
continuations of the same prefix.
Args:
human_texts (List[str]): list of human texts.
Returns:
List[str]: prepared human texts.
"""
return human_texts
[docs] def extract(self, dataset: Dataset) -> Dict[str, List[str]]:
"""
Calls _extract and cleans the extracted inputs.
Args:
dataset (Dataset): A dataset to extract inputs from.
Returns:
Dict[str, List[str]]: A dictionary mapping each template
key to a list of prompt inputs
(one input per template key and example).
Raises:
ExtractorEmptyColumns: if any field of the prompt_inputs is empty.
"""
prompt_inputs = self._extract(dataset)
prompt_inputs = {
column: clean_inputs(prompt_inputs[column])
for column in prompt_inputs
}
for field in prompt_inputs:
if len(prompt_inputs[field]) == 0:
raise ExtractorEmptyColumns(self.__class__.__name__, field)
return prompt_inputs