# Source code for text_machina.src.extractors.sentence_masking

from math import ceil
from random import choice, sample, uniform
from typing import Any, Dict, List

from datasets import Dataset

from ..common import color_log, get_logger
from ..common.exceptions import ExtractorInvalidArgs
from ..config import InputConfig
from ..types import TaskType
from .base import Extractor
from .utils import spacy_pipeline

_logger = get_logger(__name__)


class SentenceMasking(Extractor):
    """
    Extractor that fills the prompt template with a text with masked
    sentences; the LLM has to generate all the masked sentences.

    This extractor needs one template placeholder:
        - {masked_text}: will be filled with a text with masked sentences.

    This extractor allows passing the following arguments in the
    `extractor_args` field from the config:
        - mask_token (str): mask token, e.g., "MASK". Several masks in a
            text will be appended with the index, e.g. "MASK-0".
        - percentage_range (List[float]): range delimiting the percentage
            of sentences to be masked. At least one sentence will always
            be masked (texts with a single sentence get an extra mask
            sentence instead, so the original sentence is kept).
    """

    def __init__(self, input_config: InputConfig, task_type: TaskType):
        args: Dict[str, Any] = input_config.extractor_args.get(
            "sentence_masking", {}
        )
        workspace: Dict[str, Any] = {"masked_texts": []}
        super().__init__(input_config, task_type, workspace, args)
        # `Logger.warn` is a deprecated alias; use `Logger.warning`.
        _logger.warning(
            color_log(
                f"You are using the `{self.__class__.__name__}` extractor."
                " Consider that few models like GPT-4 can work properly with this"
                " type of generation. Models must be:\n"
                "1) Capable of generating proper JSON.\n"
                "2) Capable enough to generate all the masks.",
                "bold_yellow",
            )
        )

    def check_valid_args(self):
        """
        Check that all the mandatory extractor arguments were provided.

        Raises:
            ExtractorInvalidArgs: if `mask_token` or `percentage_range`
                is missing from `self.args`.
        """
        mandatory_args = ["mask_token", "percentage_range"]
        for mandatory_arg in mandatory_args:
            if mandatory_arg not in self.args:
                raise ExtractorInvalidArgs(
                    self.__class__.__name__, mandatory_args
                )

    def prepare_human(self, human_texts: List[str]) -> List[str]:
        """
        Human texts are returned unchanged.

        Args:
            human_texts (List[str]): the human texts.

        Returns:
            List[str]: the same list that was received.
        """
        return human_texts

    def _format_mask_token(
        self, idx: int, add_period: bool = False, add_whitespace: bool = True
    ) -> str:
        """
        Build the indexed mask token, e.g. "MASK-0 ".

        Args:
            idx (int): index appended to the mask token.
            add_period (bool): whether to append a "." after the index.
            add_whitespace (bool): whether to append a trailing space.

        Returns:
            str: the formatted mask token.
        """
        period = "." if add_period else ""
        whitespace = " " if add_whitespace else ""
        return f"{self.args['mask_token']}-{idx}{period}{whitespace}"

    def _mask_sentences(self, sentences: List[str]) -> str:
        """
        Mask the sentences of a single text.

        Args:
            sentences (List[str]): sentences of the text, each one
                including its trailing whitespace (spaCy `text_with_ws`).

        Returns:
            str: the stripped text with masks numbered from 0 in order
                of appearance.
        """
        # Work on a copy so the caller's list is never mutated.
        sentences = list(sentences)

        # Empty text: there are no sentences to mask, and the generic
        # path would call `sample` with a negative size (ValueError).
        # Ask for a single masked sentence instead.
        if not sentences:
            return self._format_mask_token(0).strip()

        # In case of 1 sentence, avoid masking it,
        # just add a mask sentence to right or left randomly.
        if len(sentences) == 1:
            mask_position = choice([0, 1])
            # The last sentence of a text usually has no trailing
            # whitespace; without a separator the appended mask would be
            # glued to it (e.g. "Hello world.MASK-0").
            if mask_position == 1 and not sentences[0][-1:].isspace():
                sentences[0] += " "
            sentences.insert(mask_position, self._format_mask_token(0))
            return "".join(sentences).strip()

        percentage_masks = uniform(*self.args["percentage_range"])
        # At least one mask (as documented, even if the sampled
        # percentage is 0), and never all the sentences.
        sents_to_mask = max(
            1,
            min(len(sentences) - 1, ceil(percentage_masks * len(sentences))),
        )
        positions = sorted(sample(range(len(sentences)), sents_to_mask))
        for mask_idx, sent_idx in enumerate(positions):
            sentences[sent_idx] = self._format_mask_token(mask_idx)
        return "".join(sentences).strip()

    def _extract(self, dataset: Dataset) -> Dict[str, List[str]]:
        """
        Build the masked texts for every text in the dataset.

        Texts are split into sentences with `spacy_pipeline` and masked
        with `_mask_sentences`. The result is also stored in the
        workspace under "masked_texts".

        Args:
            dataset (Dataset): dataset containing the text column given
                by `input_config.dataset_text_column`.

        Returns:
            Dict[str, List[str]]: {"masked_text": masked texts}, one per
                input text, to fill the `{masked_text}` placeholder.
        """
        text_column = self.input_config.dataset_text_column
        texts = spacy_pipeline(
            texts=dataset[text_column], language=self.input_config.language
        )
        masked_texts = [
            self._mask_sentences([sent.text_with_ws for sent in text.sents])
            for text in texts
        ]
        self.workspace["masked_texts"] = masked_texts
        return {"masked_text": masked_texts}