# Source code for text_machina.src.extractors.word_gap

from math import ceil
from random import randint
from typing import Any, Dict, List, Tuple

from datasets import Dataset

from ..common.exceptions import ExtractorInvalidArgs
from ..config import InputConfig
from ..types import TaskType
from .base import Extractor
from .utils import spacy_pipeline


class WordGap(Extractor):
    """
    Extractor that fills the prompt template with a boundary of two
    word spans (left-side and right-side of a sampled word), and with
    the number of words the LLM has to generate in between the boundary
    word spans.

    This extractor needs two template placeholders:
        - {n}: will be filled with the number of words to generate
            between the boundary words.
        - {boundaries}: will be filled with the boundary words separated
            by the gap token, e.g., "words1 ____ words2".

    This extractor allows to pass the following arguments in the
    `extractor_args` field from the config (all of them are mandatory,
    see `check_valid_args`):
        - gap_token (str): gap token, e.g., "____"
        - max_percentage_boundaries (float): max percentage of boundaries
            to sample from a text. In a text of N words, there will be
            N-1 possible boundaries of two word spans.
        - max_word_span (int): max number of words to be generated
            between the boundary words.
        - range_boundary_size (List[int, int]): inclusive range where to
            sample the length of the word spans in the boundaries. Both
            values are passed to `random.randint`, so they must be
            integers.
    """

    def __init__(self, input_config: InputConfig, task_type: TaskType):
        """
        Read the `word_gap` extractor arguments from the config and set
        up the per-text workspace filled later by `_extract`.

        Args:
            input_config (InputConfig): the input config.
            task_type (TaskType): the task type.
        """
        args: Dict[str, Any] = input_config.extractor_args.get("word_gap", {})
        # One entry per processed text is appended to each list.
        workspace: Dict[str, Any] = {
            "positions": [],
            "human_spans": [],
            "num_boundaries": [],
        }
        super().__init__(input_config, task_type, workspace, args)

    def check_valid_args(self) -> None:
        """
        Check that every mandatory extractor argument has been provided.

        Raises:
            ExtractorInvalidArgs: if any mandatory argument is missing
                from `self.args`.
        """
        mandatory_args = [
            "gap_token",
            "max_percentage_boundaries",
            "max_word_span",
            "range_boundary_size",
        ]
        if any(arg not in self.args for arg in mandatory_args):
            raise ExtractorInvalidArgs(
                self.__class__.__name__, mandatory_args
            )

    def prepare_human(self, human_texts: List[str]) -> List[str]:
        """
        Rebuild the human texts from the words stored by `_extract`.

        Args:
            human_texts (List[str]): unused; the texts are rebuilt from
                `self.workspace["human_spans"]` instead.

        Returns:
            List[str]: one string per processed text, obtained by
                concatenating its stored words.
        """
        return [
            "".join(doc_words) for doc_words in self.workspace["human_spans"]
        ]

    def _format_boundary(self, pair: Tuple[str, str]) -> str:
        """
        Join a (left span, right span) pair around the gap token.

        Args:
            pair (Tuple[str, str]): left and right word spans.

        Returns:
            str: "<left> <gap_token> <right>".
        """
        return f"{pair[0]} {self.args['gap_token']} {pair[1]}"

    def _extract(self, dataset: Dataset) -> Dict[str, List[str]]:
        """
        Sample boundaries from every text of the dataset.

        For each text, the following is appended to the workspace:
            - "human_spans": the list of words of the text.
            - "positions": the word indexes selected as boundaries.
            - "num_boundaries": how many boundaries were selected.

        Args:
            dataset (Dataset): dataset holding the texts in the column
                named by `input_config.dataset_text_column`.

        Returns:
            Dict[str, List[str]]: the values for the template
                placeholders, with one entry per sampled boundary
                (across all texts):
                - "n": number of words to generate, as strings.
                - "boundaries": the formatted boundaries.
        """
        text_column = self.input_config.dataset_text_column
        texts = spacy_pipeline(
            texts=dataset[text_column], language=self.input_config.language
        )

        boundaries: List[str] = []
        n_generated_words: List[int] = []
        for text in texts:
            # Words keep their trailing whitespace so that `prepare_human`
            # can rebuild the text with a plain concatenation.
            words = [word.text_with_ws for word in text]
            self.workspace["human_spans"].append(words)
            positions: List[int] = []
            self.workspace["positions"].append(positions)

            # If the text has more than one word, it can be used
            # to interleave generations w/ human words.
            last_idx = len(words) - 1
            max_boundaries_to_select = ceil(
                last_idx * self.args["max_percentage_boundaries"]
            )

            accum = 0
            idx = 0
            # The loop does not run for texts with zero or one word.
            while idx < last_idx:
                # NOTE: the order of the `randint` calls is kept stable
                # so that seeded runs sample the same boundaries.
                is_boundary = randint(0, 1)
                boundary_size = 0
                if is_boundary and accum < max_boundaries_to_select:
                    n_words = randint(1, self.args["max_word_span"])
                    boundary_size = randint(
                        *self.args["range_boundary_size"]
                    )
                    # Left span: up to `boundary_size` + 1 words ending
                    # at `idx`; right span: up to `boundary_size` words
                    # starting right after `idx`.
                    left_span = "".join(
                        words[max(0, idx - boundary_size) : idx + 1]
                    )
                    right_span = "".join(
                        words[idx + 1 : idx + 1 + boundary_size]
                    )
                    boundaries.append(
                        self._format_boundary((left_span, right_span))
                    )
                    positions.append(idx)
                    n_generated_words.append(n_words)
                    accum += 1

                # Move far away to avoid coherence conflicts
                # between boundary generations
                idx += 1 + boundary_size

            # At this point:
            # (1) No boundaries have been selected:
            #     -> consider all the text as human
            # (2) >=1 boundary have been selected:
            #     -> the text could be interleaved (gen, human)
            # (3) The text has only one word:
            #     -> consider all the text as human
            self.workspace["num_boundaries"].append(accum)

        return {
            "n": list(map(str, n_generated_words)),
            "boundaries": boundaries,
        }