Source code for text_machina.src.generators.base

from abc import ABC, abstractmethod
from typing import Dict, List, Tuple

from datasets import Dataset

from ..common.logging import get_logger
from ..config import Config
from ..constrainers import get_length_constrainer
from ..data import PromptedDatasetBuilder
from ..models import get_model

# Module-level logger named after this module; used by DatasetGenerator to
# report the example input and the generation arguments.
_logger = get_logger(__name__)


class DatasetGenerator(ABC):
    """
    Base class for dataset generators.

    Subclasses implement `_pack` to turn the raw model generations into a
    labeled dataset for their task. `generate` drives the whole pipeline:
    building the prompted inputs, running length-constrained generation,
    packing the result, and attaching config metadata columns.
    """

    def __init__(self, config: Config) -> None:
        """
        Initializes the generator from a config.

        Args:
            config (Config): the run configuration. `config.model` is used
                to instantiate the model and the whole config is handed to
                the prompted dataset builder.
        """
        self.config = config
        self.model = get_model(self.config.model)
        self.prompter = PromptedDatasetBuilder(self.config)

    def generate(self) -> Dataset:
        """
        Generates a labeled dataset based on the provided config.

        Returns:
            Dataset: the dataset built by `_pack`, with the `config_path`
                and `language` columns added by `add_config_info`.
        """
        generations, kwargs = self._generate()
        dataset = self._pack(generations, **kwargs)
        dataset = self.add_config_info(dataset)
        return dataset

    def _generate(self) -> Tuple[List[str], Dict]:
        """
        Generates texts based on the provided config.

        Builds the prompted inputs, derives a length-constrained generation
        config from the human texts, and asks the model for one completion
        per prompt.

        Returns:
            Tuple[List[str], Dict]: a tuple of the generated texts and
                additional keyword arguments for `_pack` (currently the
                prompted dataset under the key "prompted_dataset").
        """
        # prepare inputs
        prompted_dataset = self.prompter.build()
        if len(prompted_dataset.prompted_texts) > 0:
            _logger.info(
                "This is how one input looks like: "
                f"{prompted_dataset.prompted_texts[0]}"
            )
        else:
            # Don't index into an empty dataset just to log an example;
            # surface the situation instead of crashing with IndexError.
            _logger.warning(
                "The prompted dataset is empty: there are no inputs to generate from."
            )

        # instantiate length constrainer from the human-written reference texts
        length_constrainer = get_length_constrainer(
            texts=prompted_dataset.human_texts,
            model_name=self.config.model.model_name,
            provider=self.config.model.provider,
        )

        # constrain generation config
        generation_config = length_constrainer.constrain(self.config.generation)
        _logger.info(
            f"Generating completions with args:\n"
            f"Model: {self.config.model}.\n"
            f"Args: {generation_config}"
        )

        # run generator
        generations = self.model.generate_completions(
            prompts=prompted_dataset.prompted_texts,
            generation_config=generation_config,
        )
        return generations, {"prompted_dataset": prompted_dataset}

    @abstractmethod
    def _pack(self, generations: List[str], **kwargs) -> Dataset:
        """
        Builds a dataset by packing the texts according to the task.

        Args:
            generations (List[str]): the generated texts.
            **kwargs: additional keyword arguments, as returned by
                `_generate`.

        Returns:
            Dataset: the final labeled dataset.
        """
        ...

    def add_config_info(self, dataset: Dataset) -> Dataset:
        """
        Adds config information to the dataset.

        Every row gets the stringified `config.path` in a `config_path`
        column and the stringified `config.input.language` in a `language`
        column.

        NOTE(review): presumably fails if `dataset` already has a
        `config_path` or `language` column — confirm against
        `datasets.Dataset.add_column`.

        Args:
            dataset (Dataset): the dataset to add config information to.

        Returns:
            Dataset: the dataset with config information added.
        """
        # Adding columns does not change the row count, so compute it once.
        n_rows = len(dataset)
        dataset = dataset.add_column(
            "config_path",
            [str(self.config.path)] * n_rows,
        )
        dataset = dataset.add_column(
            "language",
            [str(self.config.input.language)] * n_rows,
        )
        return dataset