Source code for text_machina.src.generators.mixcase

import json
import re
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Type

from datasets import Dataset, concatenate_datasets

from ..common import color_log, get_logger
from ..common.exceptions import DatasetGenerationError
from ..config import Config
from ..extractors import (
    Extractor,
    SentenceGap,
    SentenceMasking,
    WordGap,
    WordMasking,
)
from ..models.types import GENERATION_ERROR
from ..types import DetectionLabels, LabeledSpan, Placeholders
from .base import DatasetGenerator

_logger = get_logger(__name__)


class MixCaseDatasetGenerator(DatasetGenerator):
    """
    Dataset generator for the mixcase task type.

    Packing is delegated to a mixcase packer selected from the type of the
    prompter's extractor: gap-based extractors use `MixCaseGapPacker`,
    masking-based extractors use `MixCaseMaskPacker`, and any other
    extractor falls back to `MixCaseRewritingPacker`.
    """

    def __init__(self, config: Config) -> None:
        super().__init__(config=config)

    def _pack(self, generations: List[str], **kwargs) -> Dataset:
        """
        Packs the generations using the packer matching the extractor.

        Args:
            generations (List[str]): list of generated texts.
            kwargs: forwarded unchanged to the selected packer's `_pack`.

        Returns:
            Dataset: the dataset built by the selected packer.
        """
        extractor = self.prompter.extractor

        # Rewriting is the fallback for every extractor that is neither
        # gap-based nor masking-based.
        packer_cls: Type[MixCasePacker] = MixCaseRewritingPacker
        if isinstance(extractor, (SentenceGap, WordGap)):
            packer_cls = MixCaseGapPacker
        elif isinstance(extractor, (SentenceMasking, WordMasking)):
            packer_cls = MixCaseMaskPacker

        packer = packer_cls(self.config, extractor)
        return packer._pack(generations, **kwargs)
class MixCasePacker(ABC):
    """
    Base class for mixcase packers.

    Concrete packers mix generated and human text according to the
    extractor given at construction time, and label each text at the
    span level.
    """

    def __init__(self, config: Config, extractor: Extractor) -> None:
        self.config = config
        self.extractor = extractor

    @abstractmethod
    def _build_samples(
        self,
        generations: List[str],
    ) -> Tuple[List[str], List[List[Dict]]]:
        """
        Prepares the texts and their span labels from the generations,
        following the strategy of the extractor (gap-, mask- or
        rewriting-based).

        Args:
            generations (List[str]): list of generations.

        Returns:
            Tuple[List[str], List[List[Dict]]]: the texts to add to the
                dataset and, for each text, its list of labeled spans.
        """
        ...

    @abstractmethod
    def _pack(self, generations: List[str], **kwargs) -> Dataset:
        """
        Combines and labels the generated and human texts according to
        the extractor.

        Args:
            generations (List[str]): list of generated texts.
            kwargs: additional keyword arguments.

        Returns:
            Dataset: a dataset including all the texts.
        """
        ...
class MixCaseRewritingPacker(MixCasePacker):
    """
    Packer for mixcase task type when using rewriting-based extractors.
    """

    def __init__(self, config: Config, extractor: Extractor) -> None:
        super().__init__(config=config, extractor=extractor)

    @staticmethod
    def _append_span(
        spans: List[Dict], start: int, end: int, label: str
    ) -> None:
        """
        Appends the [`start`, `end`) span with `label` to `spans`.

        Empty spans are skipped. A span that directly continues the last
        span with the same label is merged into it, so consecutive
        sentences sharing a label produce a single span.

        Args:
            spans (List[Dict]): labeled spans built so far (modified in place).
            start (int): start offset (inclusive).
            end (int): end offset (exclusive).
            label (str): label of the span.
        """
        if end <= start:
            return
        if spans and spans[-1]["label"] == label and spans[-1]["end"] == start:
            start = spans.pop()["start"]
        spans.append(LabeledSpan(start=start, end=end, label=label).model_dump())

    def _build_samples(
        self,
        generations: List[str],
    ) -> Tuple[List[str], List[List[Dict]]]:
        """
        Interleaves generated and human spans to build mixcase samples.

        The `start` and `end` of the labels follow the Python [`start`, `end`)
        convention, i.e., including the `start` element and excluding `end`.
        For instance, given the text: "I like Apolo. I don't like Athenea",
        being the first sentence human-written and the second one generated,
        the labels will be:

        [
            {"start": 0, "end": 14, "label": "human"},
            {"start": 14, "end": 34, "label": "generated"},
        ]

        Samples without rewritten positions are labeled as fully human.

        Args:
            generations (List[str]): list of generations

        Returns:
            Tuple[List[str], List[List[Dict]]]: the interleaved texts and
                the list of labels of each text.
        """
        texts, labels = [], []
        prev_sample = 0
        for sentences, positions in zip(
            self.extractor.workspace["human_spans"],
            self.extractor.workspace["positions"],
        ):
            n_positions = len(positions)
            rewrittens = generations[prev_sample : prev_sample + n_positions]
            prev_sample += n_positions

            # Replace the sentences at the (ascending) positions with their
            # rewrites. A copy is used so the extractor's workspace is kept.
            sentences_with_rewrites = list(sentences)
            for idx, position in enumerate(sorted(positions)):
                sentences_with_rewrites[position] = rewrittens[idx]

            # Label every sentence and merge consecutive same-label spans.
            generated_positions = set(positions)
            sample_labels: List[Dict] = []
            offset = 0
            for position, sentence in enumerate(sentences_with_rewrites):
                label = (
                    DetectionLabels.GENERATED.value
                    if position in generated_positions
                    else DetectionLabels.HUMAN.value
                )
                self._append_span(
                    sample_labels, offset, offset + len(sentence), label
                )
                offset += len(sentence)

            texts.append("".join(sentences_with_rewrites))
            labels.append(sample_labels)

        return texts, labels

    def _pack(self, generations: List[str], **kwargs) -> Dataset:
        """
        Combines and labels the generated and human texts when using
        rewriting-based extractors (`sentence_rewriting`)

        Args:
            generations (List[str]): list of generated texts.
            kwargs: additional keyword arguments. `prompted_dataset` is
                required.

        Returns:
            Dataset: a shuffled dataset including the fully human texts and
                the mixed texts.

        Raises:
            DatasetGenerationError: if `prompted_dataset` is not provided.
        """
        prompted_dataset = kwargs.get("prompted_dataset", None)
        if prompted_dataset is None:
            raise DatasetGenerationError(f"prompted_dataset not found: {self}")

        model_name = self.config.model.model_name
        domain = self.config.input.domain
        extractor_name = self.config.input.extractor

        texts, labels = self._build_samples(generations)

        # Each mixed sample keeps the prompts used for its rewrites, which
        # are consumed sequentially (one per rewritten position).
        prev_sample = 0
        mixed_samples = []
        for idx, (text, sample_labels) in enumerate(zip(texts, labels)):
            n_positions = len(self.extractor.workspace["positions"][idx])
            prompt = prompted_dataset.prompted_texts[
                prev_sample : prev_sample + n_positions
            ] or [Placeholders.NO_PROMPT.value]
            mixed_samples.append(
                {
                    "prompt": prompt,
                    "text": text,
                    "label": sample_labels,
                    "model": model_name,
                    "domain": domain,
                    "extractor": extractor_name,
                }
            )
            prev_sample += n_positions

        mixed_dataset = Dataset.from_list(mixed_samples)
        human_dataset = Dataset.from_list(
            [
                {
                    "prompt": [Placeholders.NO_PROMPT.value],
                    "text": text,
                    "label": [
                        LabeledSpan(
                            start=0,
                            end=len(text),
                            label=DetectionLabels.HUMAN.value,
                        ).model_dump()
                    ],
                    "model": DetectionLabels.HUMAN.value,
                    "domain": domain,
                    "extractor": Placeholders.NO_EXTRACTOR.value,
                }
                for text in prompted_dataset.human_texts
            ]
        )

        dataset = concatenate_datasets([human_dataset, mixed_dataset])
        dataset = dataset.shuffle()
        return dataset
class MixCaseMaskPacker(MixCasePacker):
    """
    Packer for mixcase task type when using mask-based extractors.
    """

    def __init__(self, config: Config, extractor: Extractor) -> None:
        super().__init__(config=config, extractor=extractor)
        # The mask token is escaped so that tokens containing regex
        # metacharacters are matched literally.
        self.mask_regex = re.compile(
            rf"({re.escape(self.extractor.args['mask_token'])}-\d+)"
        )

    @staticmethod
    def _append_span(
        spans: List[Dict], start: int, end: int, label: str
    ) -> None:
        """
        Appends the [`start`, `end`) span with `label` to `spans`.

        Empty spans are skipped. A span that directly continues the last
        span with the same label is merged into it.

        Args:
            spans (List[Dict]): labeled spans built so far (modified in place).
            start (int): start offset (inclusive).
            end (int): end offset (exclusive).
            label (str): label of the span.
        """
        if end <= start:
            return
        if spans and spans[-1]["label"] == label and spans[-1]["end"] == start:
            start = spans.pop()["start"]
        spans.append(LabeledSpan(start=start, end=end, label=label).model_dump())

    def _build_error_sample(
        self, reason: str = "The completion was not a valid JSON."
    ) -> Tuple[str, List[Dict]]:
        """
        Helper to build an error sample when required, e.g., invalid JSON,
        missing masks, etc; that logs the cause of the resulting error
        (LLM uncapable always).

        Args:
            reason (str): description of the cause, included in the log.

        Returns:
            Tuple[str, List[Dict]]: the text and labels of an error sample.
        """
        _logger.info(
            color_log(
                f"{reason}"
                " This error is related to the LLM capabilities,"
                " please, use another LLM if there are many"
                " errors of this type."
                " The text will be considered a generation error "
                f"`{GENERATION_ERROR}`",
                "bold_yellow",
            )
        )
        return GENERATION_ERROR, [
            LabeledSpan(
                start=0,
                end=len(GENERATION_ERROR),
                label=DetectionLabels.GENERATED.value,
            ).model_dump()
        ]

    def _build_sample(
        self, masked_text: str, parsed_generation: Dict[str, str]
    ) -> Tuple[str, List[Dict]]:
        """
        Builds a sample by reconstructing the masks in a text.

        The text is rebuilt in a single left-to-right pass over the masks
        found in `masked_text`. Each mask is replaced literally by its
        completion, so backslashes or group references in LLM output are
        not interpreted. Keys in `parsed_generation` that do not appear in
        the text are ignored. Labels follow the [`start`, `end`)
        convention. Consecutive spans with the same label are merged, and
        empty spans are skipped.

        Args:
            masked_text (str): a text with masks to be replaced.
            parsed_generation: (Dict[str, str]): dictionary mapping masks to
                texts, e.g. {"MASK-0": <text>, ...}

        Returns:
            Tuple[str, List[Dict]]: the text and labels of the sample, or an
                error sample if a mask has no string completion.
        """
        pieces: List[str] = []
        sample_labels: List[Dict] = []
        cursor = 0  # position in `masked_text`
        offset = 0  # length of the reconstructed text so far
        for match in self.mask_regex.finditer(masked_text):
            mask_token = match.group(1)
            completion = parsed_generation.get(mask_token)
            if not isinstance(completion, str):
                return self._build_error_sample(
                    f"The completion for `{mask_token}` is missing"
                    " or is not a string."
                )
            # Ensure the completion does not append another mask token
            completion = self.mask_regex.sub("", completion)

            human_chunk = masked_text[cursor : match.start()]
            pieces.append(human_chunk)
            self._append_span(
                sample_labels,
                offset,
                offset + len(human_chunk),
                DetectionLabels.HUMAN.value,
            )
            offset += len(human_chunk)

            pieces.append(completion)
            self._append_span(
                sample_labels,
                offset,
                offset + len(completion),
                DetectionLabels.GENERATED.value,
            )
            offset += len(completion)
            cursor = match.end()

        # Human text after the last mask (or the whole text without masks).
        tail = masked_text[cursor:]
        pieces.append(tail)
        self._append_span(
            sample_labels, offset, offset + len(tail), DetectionLabels.HUMAN.value
        )
        return "".join(pieces), sample_labels

    def _build_samples(
        self,
        generations: List[str],
    ) -> Tuple[List[str], List[List[Dict]]]:
        """
        Reconstructs masked texts using completions to build mixcase samples.

        The `start` and `end` of the labels follow the Python [`start`, `end`)
        convention, i.e., including the `start` element and excluding `end`.
        For instance, given the text: "I like Apolo. I don't like Athenea",
        being the first sentence human-written and the second one generated,
        the labels will be:

        [
            {"start": 0, "end": 14, "label": "human"},
            {"start": 14, "end": 34, "label": "generated"},
        ]

        Generations that are not a valid JSON object, or that lack
        completions for some masks of the text, become error samples.

        Args:
            generations (List[str]): list of generations

        Returns:
            Tuple[List[str], List[List[Dict]]]: the interleaved texts and
                the list of labels of each text.
        """
        texts, labels = [], []
        for generation, masked_text in zip(
            generations, self.extractor.workspace["masked_texts"]
        ):
            sample: Tuple[str, List[Dict]]
            try:
                # Keep only the outermost {...} region of the generation.
                parsed_generation = json.loads(
                    generation[
                        generation.find("{") : generation.rfind("}") + 1
                    ].strip()
                )
            except json.decoder.JSONDecodeError:
                # If the LLM didn't generate a valid JSON output,
                # the sample will be considered as an error.
                sample = self._build_error_sample()
            else:
                masks_in_text = set(self.mask_regex.findall(masked_text))
                if not isinstance(parsed_generation, dict):
                    sample = self._build_error_sample(
                        "The completion was not a JSON object."
                    )
                elif not masks_in_text.issubset(parsed_generation):
                    # Every mask of the text needs a completion.
                    sample = self._build_error_sample(
                        "The completion does not contain all the masks"
                        " of the text."
                    )
                else:
                    sample = self._build_sample(masked_text, parsed_generation)

            sample_text, sample_labels = sample
            texts.append(sample_text)
            labels.append(sample_labels)

        return texts, labels

    def _pack(self, generations: List[str], **kwargs) -> Dataset:
        """
        Combines and labels the generated and human texts when using
        mask-based extractors (`sentence_masking`, `word_masking`, etc.)

        Args:
            generations (List[str]): list of generated texts.
            kwargs: additional keyword arguments. `prompted_dataset` is
                required.

        Returns:
            Dataset: a shuffled dataset including the fully human texts and
                the mixed texts.

        Raises:
            DatasetGenerationError: if `prompted_dataset` is not provided.
        """
        prompted_dataset = kwargs.get("prompted_dataset", None)
        if prompted_dataset is None:
            raise DatasetGenerationError(f"prompted_dataset not found: {self}")

        model_name = self.config.model.model_name
        domain = self.config.input.domain
        extractor_name = self.config.input.extractor

        texts, labels = self._build_samples(generations)

        generated_dataset = Dataset.from_list(
            [
                {
                    "prompt": prompt,
                    "text": text,
                    "label": label,
                    "model": model_name,
                    "domain": domain,
                    "extractor": extractor_name,
                }
                for prompt, text, label in zip(
                    prompted_dataset.prompted_texts, texts, labels
                )
            ]
        )
        human_dataset = Dataset.from_list(
            [
                {
                    "prompt": Placeholders.NO_PROMPT.value,
                    "text": text,
                    "label": [
                        LabeledSpan(
                            start=0,
                            end=len(text),
                            label=DetectionLabels.HUMAN.value,
                        ).model_dump()
                    ],
                    "model": DetectionLabels.HUMAN.value,
                    "domain": domain,
                    "extractor": Placeholders.NO_EXTRACTOR.value,
                }
                for text in prompted_dataset.human_texts
            ]
        )

        dataset = concatenate_datasets([human_dataset, generated_dataset])
        dataset = dataset.shuffle()
        return dataset
class MixCaseGapPacker(MixCasePacker):
    """
    Packer for mixcase task type when using gap-based extractors.
    """

    def __init__(self, config: Config, extractor: Extractor) -> None:
        super().__init__(config=config, extractor=extractor)

    def _build_samples(
        self,
        generations: List[str],
    ) -> Tuple[List[str], List[List[Dict]]]:
        """
        Interleaves generated and human spans to build mixcase samples.

        The `start` and `end` of the labels follow the Python [`start`, `end`)
        convention, i.e., including the `start` element and excluding `end`.
        For instance, given the text: "I like Apolo. I don't like Athenea",
        being the first sentence human-written and the second one generated,
        the labels will be:

        [
            {"start": 0, "end": 14, "label": "human"},
            {"start": 14, "end": 34, "label": "generated"},
        ]

        The extractor's workspace is read but never modified.

        Args:
            generations (List[str]): list of generations

        Returns:
            Tuple[List[str], List[List[Dict]]]: the interleaved texts and
                the list of labels of each text.
        """
        texts, labels = [], []
        prev_sample = 0
        for idx, sample_boundaries in enumerate(
            self.extractor.workspace["num_boundaries"]
        ):
            # Text w/o sampled boundaries
            if sample_boundaries == 0:
                text = "".join(self.extractor.workspace["human_spans"][idx])
                sample_labels = [
                    LabeledSpan(
                        start=0,
                        end=len(text),
                        label=DetectionLabels.HUMAN.value,
                    ).model_dump()
                ]
            # Text w/ sampled boundaries
            else:
                # Add a whitespace after generations
                # to be concatenated with the suffix
                sample_generations = [
                    f"{generation} "
                    for generation in generations[
                        prev_sample : prev_sample + sample_boundaries
                    ]
                ]
                # Work on a copy: inserting into the workspace list itself
                # would corrupt the extractor state for later uses.
                sample_spans = list(self.extractor.workspace["human_spans"][idx])
                sample_positions = self.extractor.workspace["positions"][idx]
                sample_labels = []
                added = 0
                prev_label_pos = 0
                # Interleave the generations in the positions determined
                # by the extractor, and compute the labeled spans.
                for i, position in enumerate(sample_positions):
                    # Insert generation `i` after the human span at
                    # `position`, shifted by the generations already added.
                    insert_at = position + 1 + added
                    sample_spans.insert(insert_at, sample_generations[i])
                    # The generated span starts just after the prefix
                    # that ends at the insertion point.
                    gen_start = len("".join(sample_spans[:insert_at]))
                    gen_end = gen_start + len(sample_generations[i])
                    # The human span starts at the end of the previous
                    # generated span (or at 0 for the first one).
                    sample_labels.append(
                        LabeledSpan(
                            start=prev_label_pos,
                            end=gen_start,
                            label=DetectionLabels.HUMAN.value,
                        ).model_dump()
                    )
                    sample_labels.append(
                        LabeledSpan(
                            start=gen_start,
                            end=gen_end,
                            label=DetectionLabels.GENERATED.value,
                        ).model_dump()
                    )
                    added += 1
                    prev_label_pos = gen_end

                text = "".join(sample_spans)
                # Fill the labels if they do not cover all the text yet.
                # The last span is always human by construction.
                if int(sample_labels[-1]["end"]) < len(text):
                    sample_labels.append(
                        LabeledSpan(
                            start=sample_labels[-1]["end"],
                            end=len(text),
                            label=DetectionLabels.HUMAN.value,
                        ).model_dump()
                    )

            texts.append(text)
            labels.append(sample_labels)
            prev_sample += sample_boundaries

        return texts, labels

    def _pack(self, generations: List[str], **kwargs) -> Dataset:
        """
        Combines and labels the generated and human texts when using
        gap-based extractors (`sentence_gap`, `word_gap`, etc.)

        Args:
            generations (List[str]): list of generated texts.
            kwargs: additional keyword arguments. `prompted_dataset` is
                required.

        Returns:
            Dataset: a shuffled dataset including the fully human texts and
                the mixed texts.

        Raises:
            DatasetGenerationError: if `prompted_dataset` is not provided.
        """
        prompted_dataset = kwargs.get("prompted_dataset", None)
        if prompted_dataset is None:
            raise DatasetGenerationError(f"prompted_dataset not found: {self}")

        model_name = self.config.model.model_name
        domain = self.config.input.domain
        extractor_name = self.config.input.extractor

        texts, labels = self._build_samples(generations)

        # Prompts are consumed sequentially, one per sampled boundary.
        prev_sample = 0
        mixed_samples = []
        for idx, (text, sample_labels) in enumerate(zip(texts, labels)):
            sample_boundaries = self.extractor.workspace["num_boundaries"][idx]
            prompt = prompted_dataset.prompted_texts[
                prev_sample : prev_sample + sample_boundaries
            ] or [Placeholders.NO_PROMPT.value]
            mixed_samples.append(
                {
                    "prompt": prompt,
                    "text": text,
                    "label": sample_labels,
                    "model": model_name,
                    "domain": domain,
                    "extractor": extractor_name,
                }
            )
            prev_sample += sample_boundaries

        mixed_dataset = Dataset.from_list(mixed_samples)
        human_dataset = Dataset.from_list(
            [
                {
                    "prompt": [Placeholders.NO_PROMPT.value],
                    "text": text,
                    "label": [
                        LabeledSpan(
                            start=0,
                            end=len(text),
                            label=DetectionLabels.HUMAN.value,
                        ).model_dump()
                    ],
                    "model": DetectionLabels.HUMAN.value,
                    "domain": domain,
                    "extractor": Placeholders.NO_EXTRACTOR.value,
                }
                for text in prompted_dataset.human_texts
            ]
        )

        dataset = concatenate_datasets([human_dataset, mixed_dataset])
        dataset = dataset.shuffle()
        return dataset