Source code for text_machina.cli_utils

from logging import FileHandler
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import pandas as pd
import petname
from datasets import Dataset

from .src.common.logging import get_logger
from .src.common.utils import get_cache_path
from .src.config import Config, parse_metrics_config
from .src.data import (
    concatenate,
    domain_model_counts,
    errors_per_model,
    get_save_path,
    serialize_dataset,
)
from .src.exploration import get_explorer
from .src.generators import get_generator
from .src.metrics import run_metrics as _run_metrics
from .src.models.types import GENERATION_ERROR
from .src.postprocessing import filter_by_language, postprocess
from .src.types import TaskType

# Module-level logger used by the functions in this module.
_logger = get_logger(__name__)


def generate_from_config(
    config: Config,
    save_dir: Path,
    run_name: str,
) -> Path:
    """
    Generates a dataset for a single `config` and serializes it to disk.

    The generator obtained from `get_generator(config)` produces the raw
    dataset; the number of generation errors in it is logged, rows are then
    filtered by `config.input.language`, and the filtered dataset is
    serialized under `save_dir`.

    Args:
        config (Config): a configuration.
        save_dir (Path): root dir where to save the generated dataset.
        run_name (str): name of this run.

    Returns:
        Path: path where the generated dataset was saved.
    """
    raw_dataset = get_generator(config).generate()

    # Errors are counted on the raw output, before language filtering.
    error_total = count_errors(raw_dataset)
    _logger.info(f"{error_total} errors found in the generated dataset.")

    filtered = filter_by_language(raw_dataset, config.input.language)
    return serialize_dataset(filtered, config, save_dir, run_name)
def _generate(
    config_path: Path,
    save_dir: Path,
    run_name: str,
    task_type: TaskType,
) -> None:
    """
    Runs the generation pipeline end to end and persists its statistics.

    Every statistics table returned by `generate_dataset` is written to
    `save_dir / "statistics"`, once as Markdown (`<name>.md`) and once as
    JSON (`<name>.json`).

    Args:
        config_path (Path): path containing YAML config files.
        save_dir (Path): root dir where to save the generated dataset.
        run_name (str): name of this run.
        task_type (TaskType): the type of task.
    """
    loaded_configs = Config.load_configs(config_path, task_type)
    _, stats_by_name = generate_dataset(loaded_configs, save_dir, run_name)

    out_dir = save_dir / "statistics"
    out_dir.mkdir(parents=True, exist_ok=True)

    for table_name, table in stats_by_name.items():
        table.to_markdown(out_dir / f"{table_name}.md")
        table.to_json(out_dir / f"{table_name}.json")
def generate_dataset(
    configs: List[Config],
    save_dir: Path,
    run_name: str,
) -> Tuple[Dataset, Dict[str, pd.DataFrame]]:
    """
    Builds the final dataset for a list of configs and computes statistics.

    For each config, a dataset already saved in the cache for this
    `run_name` is reused; otherwise one is generated into the cache. The
    per-config datasets are concatenated, statistics and the error count
    are computed on the concatenation, and the result is postprocessed
    (using the task type of the first config) and saved to `save_dir`.

    Args:
        configs (List[Config]): list of configs to use for generation.
        save_dir (Path): root dir where to save the generated dataset.
        run_name (str): name of this run.

    Returns:
        Tuple[Dataset, Dict[str, pd.DataFrame]]: a tuple
            (postprocessed dataset, statistics dict).
    """
    cache_path = get_cache_path()

    def _path_for(config: Config) -> Path:
        # Reuse a cached dataset for this run when one exists; a falsy
        # result from `get_save_path` means it must be generated.
        existing = get_save_path(
            config, cache_path, run_name, check_exists=True
        )
        return existing or generate_from_config(config, cache_path, run_name)

    paths = [_path_for(config) for config in configs]

    merged = concatenate(paths, save_dir)
    statistics = compute_statistics(merged)
    # Errors are counted on the concatenated data, before postprocessing.
    total_errors = count_errors(merged)

    processed = postprocess(merged, configs[0].task_type)
    processed.save_to_disk(save_dir.as_posix())

    _logger.info(
        f"A total of {total_errors} errors have been found"
        " in the generation process."
    )
    _logger.info(f"Your dataset has been generated at {str(save_dir)}")
    return processed, statistics
def compute_statistics(dataset: Dataset) -> Dict[str, pd.DataFrame]:
    """
    Computes summary statistics for a generated dataset.

    Args:
        dataset (Dataset): the dataset to summarize.

    Returns:
        Dict[str, pd.DataFrame]: a mapping with two tables:
            "domain_model" (from `domain_model_counts`) and
            "errors" (from `errors_per_model`).
    """
    return {
        "domain_model": domain_model_counts(dataset),
        "errors": errors_per_model(dataset),
    }
def generate_run_name() -> str:
    """
    Creates a name for a run using the `petname` package.

    Returns:
        str: name of the run.
    """
    run_name = petname.generate()
    return run_name
def log_final_message(run_name: str) -> None:
    """
    Logs the closing messages of a TextMachina run.

    Reports the name under which the run was registered, so that a run that
    terminated due to errors can be resumed, and, when a `FileHandler` is
    attached to the module logger, the path of its logfile.

    Args:
        run_name (str): name of the run.
    """
    _logger.info(
        f"This run has been registered with name: '{run_name}'."
        " If the run terminated due to errors, you can use this"
        " name to continue from where the process left off."
    )
    # Look up the first FileHandler, if any; indexing `[0]` into an empty
    # list would raise IndexError at the very end of an otherwise good run.
    file_handler = next(
        (h for h in _logger.handlers if isinstance(h, FileHandler)), None
    )
    if file_handler is not None:
        _logger.info(f"Logfile saved in {file_handler.baseFilename}")
def count_errors(dataset: Dataset) -> int:
    """
    Counts the number of generation errors in the dataset.

    A generation error is a row whose text is exactly `GENERATION_ERROR`.

    Args:
        dataset (Dataset): a dataset with a "text" column.

    Returns:
        int: number of errors in the dataset (0 for an empty dataset).
    """
    # Count matches directly instead of summing `x in x` booleans, which
    # only worked because a string always contains itself.
    return sum(1 for text in dataset["text"] if text == GENERATION_ERROR)
def _explore(
    config_path: Path,
    metrics_path: Optional[Path],
    save_dir: Path,
    run_name: str,
    task_type: TaskType,
    interactive: bool,
    max_generations: int,
) -> None:
    """
    Carries out the exploration workflow.

    Steps:
        - generate a small dataset (capped by `max_generations` per config)
        - when a metrics config is given and lists metrics, compute them
        - optionally step interactively through the dataset

    Args:
        config_path (Path): path containing YAML config files for generation.
        metrics_path (Optional[Path]): path to YAML config file of metrics.
        save_dir (Path): root dir where to save the generated dataset.
        run_name (str): name of this run.
        task_type (TaskType): the type of task to be explored.
        interactive (bool): whether to step interactively through the
            dataset.
        max_generations (int): the maximum number of texts to generate for
            each config. Ignored if the dataset has already been generated.
    """
    configs = Config.load_configs(
        config_path, task_type=task_type, max_generations=max_generations
    )
    dataset, _ = generate_dataset(configs, save_dir, run_name)

    # Generate metrics only when a metrics config lists at least one metric.
    if metrics_path:
        metrics, metric_args = parse_metrics_config(metrics_path)
        if metrics:
            _run_metrics(
                dataset,
                task_type,
                save_dir,
                metrics,
                {} if metric_args is None else metric_args,
            )

    # Show examples in console.
    if interactive:
        get_explorer(task_type, dataset).step()