# Source code for text_machina.src.extractors.word_masking

from math import ceil
from random import choice, randint, uniform
from typing import Any, Dict, List

from datasets import Dataset

from ..common import color_log, get_logger
from ..common.exceptions import ExtractorInvalidArgs
from ..config import InputConfig
from ..types import TaskType
from .base import Extractor
from .utils import spacy_pipeline

_logger = get_logger(__name__)


class WordMasking(Extractor):
    """
    Extractor that fills the prompt template with a text with masked word
    spans, so the LLM has to generate all the masked word spans.

    This extractor needs one template placeholder:
        - {masked_text}: will be filled with a text with masked word spans.

    This extractor allows to pass the following arguments in the
    `extractor_args` field from the config:
        - mask_token (str): mask token, e.g., "MASK". Several masks in a text
            will be appended with the index, e.g. "MASK-0"
        - percentage_range (List[float]): range delimiting the percentage of
            word spans to be masked, relative to
            len(words) / max(span_length_range). At least one word span is
            masked for every text with two or more words, as long as the
            sampled percentage is positive.
        - span_length_range (List[int]): range where to sample the length of
            each masked span. Span starts are kept at least
            max(span_length_range) positions apart, so spans never overlap.
    """

    def __init__(self, input_config: InputConfig, task_type: TaskType):
        args: Dict[str, Any] = input_config.extractor_args.get(
            "word_masking", {}
        )
        # Masked texts are kept in the workspace so later stages can
        # access the exact prompts that were built.
        workspace: Dict[str, Any] = {"masked_texts": []}
        super().__init__(input_config, task_type, workspace, args)
        # `Logger.warn` is a deprecated alias of `Logger.warning`.
        _logger.warning(
            color_log(
                f"You are using the `{self.__class__.__name__}` extractor."
                " Consider that few models like GPT-4 can work properly with this"
                " type of generation. Models must be:\n"
                "1) Capable of generating proper JSON.\n"
                "2) Capable enough to generate all the masks.",
                "bold_yellow",
            )
        )

    def check_valid_args(self):
        """Raise `ExtractorInvalidArgs` if any mandatory argument is missing."""
        mandatory_args = ["mask_token", "percentage_range", "span_length_range"]
        for mandatory_arg in mandatory_args:
            if mandatory_arg not in self.args:
                raise ExtractorInvalidArgs(
                    self.__class__.__name__, mandatory_args
                )

    def prepare_human(self, human_texts: List[str]) -> List[str]:
        """Return the human texts unchanged."""
        return human_texts

    def _format_mask_token(
        self, idx: int, add_period: bool = False, add_whitespace: bool = True
    ) -> str:
        """
        Build the mask token for the `idx`-th span, e.g. "MASK-0 ".

        Args:
            idx (int): index of the masked span inside its text.
            add_period (bool): whether to append "." after the index.
            add_whitespace (bool): whether to append a trailing space.

        Returns:
            str: the formatted mask token.
        """
        period = "." if add_period else ""
        whitespace = " " if add_whitespace else ""
        return f"{self.args['mask_token']}-{idx}{period}{whitespace}"

    @staticmethod
    def _sample_span_starts(
        num_words: int, num_spans: int, min_distance: int
    ) -> List[int]:
        """
        Sample up to `num_spans` start positions in `range(num_words)` that
        are pairwise at least `min_distance` apart.

        The loop always terminates: each iteration either appends a start
        (bounded by `num_spans`) or runs out of candidates. Fewer than
        `num_spans` starts are returned when the text is too short to fit
        them all.

        Returns:
            List[int]: the sampled start positions, sorted ascending.
        """
        candidates = list(range(num_words))
        starts: List[int] = []
        while candidates and len(starts) < num_spans:
            start = choice(candidates)
            starts.append(start)
            # Drop every candidate too close to the new start, so that no
            # later span can overlap (or be adjacent within) this one.
            candidates = [
                position
                for position in candidates
                if abs(position - start) >= min_distance
            ]
        return sorted(starts)

    def _mask_words(self, words: List[str]) -> str:
        """
        Replace randomly sampled word spans of `words` with mask tokens.

        Args:
            words (List[str]): tokens of one text, including their trailing
                whitespace, so that joining them rebuilds the text.

        Returns:
            str: the text with the masked spans replaced by
                "<mask_token>-<idx> " tokens. Unmasked words are kept,
                including those after the last span.
        """
        max_span = max(self.args["span_length_range"])
        percentage_masks = uniform(*self.args["percentage_range"])
        spans_to_mask = min(
            ceil((len(words) - 1) / max_span),
            ceil(percentage_masks * (len(words) / max_span)),
        )
        starts = self._sample_span_starts(len(words), spans_to_mask, max_span)
        span_lengths = [
            randint(*self.args["span_length_range"]) for _ in starts
        ]

        masked_words: List[str] = []
        prev_position = 0
        for mask_idx, (position, length) in enumerate(
            zip(starts, span_lengths)
        ):
            masked_words.extend(words[prev_position:position])
            masked_words.append(self._format_mask_token(mask_idx))
            prev_position = position + length
        # Keep the remainder of the text after the last masked span.
        masked_words.extend(words[prev_position:])
        return "".join(masked_words)

    def _extract(self, dataset: Dataset) -> Dict[str, List[str]]:
        """
        Mask random word spans in every text of `dataset`.

        Args:
            dataset (Dataset): dataset whose `dataset_text_column` holds the
                texts to mask.

        Returns:
            Dict[str, List[str]]: {"masked_text": [...]} with one masked text
                per input text. The same list is also stored in
                `self.workspace["masked_texts"]`.
        """
        text_column = self.input_config.dataset_text_column
        docs = spacy_pipeline(
            texts=dataset[text_column], language=self.input_config.language
        )
        masked_texts = [
            self._mask_words([token.text_with_ws for token in doc])
            for doc in docs
        ]
        self.workspace["masked_texts"] = masked_texts
        return {"masked_text": masked_texts}