Source code for text_machina.src.config
from itertools import chain
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, Type
from pydantic import (
BaseModel,
ConfigDict,
Field,
ValidationInfo,
field_validator,
)
from yaml import full_load, safe_load
from .types import TaskType
class InputConfig(BaseModel):
    """
    Wrapper for the input_config field.

    Holds the dataset, template and extractor settings of a configuration
    and validates that the language, the extractor name and the list of
    combined extractors are acceptable.
    """

    quantity: int = Field(gt=0, description="Number of samples to generate.")
    domain: str = Field(description="Domain of a dataset.")
    dataset: str = Field(description="Name (HF Hub) or path to the dataset.")
    dataset_text_column: str = Field(
        description="Name of column in the dataset containing the text."
    )
    dataset_params: Dict[str, Any] = Field(
        description="Arguments to load the dataset."
    )
    template: str = Field(description="Template for the generations.")
    extractor: str = Field(description="Extractor name.")
    extractors_list: List[str] = Field(
        default=[],
        description=(
            "List of extractors to be used with the `combined` extractor."
        ),
        # Run `not_empty_list_in_combined` even when the list is omitted.
        validate_default=True,
    )
    random_sample_human: bool = Field(
        default=False,
        description=(
            "Whether to randomly sample human texts or use"
            " the same ones used to generate MGT"
        ),
    )
    max_input_tokens: int = Field(
        default=256,
        gt=0,
        description=(
            "Maximum token length to be distributed across the"
            " prompt inputs extracted with the extractors."
        ),
    )
    extractor_args: Dict[str, Dict[str, Any]] = Field(
        default={}, description="Extractors-specific arguments."
    )
    language: str = Field(
        default="en",
        description="Language of the dataset used.",
        validate_default=True,
    )

    @field_validator("language")
    @classmethod
    def language_must_be_iso639(cls, language: str) -> str:
        """
        Checks that the language is a two-letter code known to pycountry
        or the special value `multilingual`.

        Args:
            language (str): the language to validate.

        Returns:
            str: the same language, unchanged.

        Raises:
            InvalidLanguage: if the language is not allowed.
        """
        import pycountry

        allowed_languages = {
            lang.alpha_2
            for lang in pycountry.languages
            # Not every pycountry entry exposes a two-letter code.
            if hasattr(lang, "alpha_2")
        }
        allowed_languages.add("multilingual")
        if language not in allowed_languages:
            from .common import InvalidLanguage

            raise InvalidLanguage()
        return language

    @field_validator("extractor")
    @classmethod
    def extractor_must_exist(cls, extractor: str) -> str:
        """
        Checks that the extractor name is a key of `EXTRACTORS`.

        Args:
            extractor (str): the extractor name to validate.

        Returns:
            str: the same extractor name, unchanged.

        Raises:
            InvalidExtractor: if the extractor is not registered.
        """
        from .extractors import EXTRACTORS

        if extractor not in EXTRACTORS:
            from .common import InvalidExtractor

            raise InvalidExtractor(extractor)
        return extractor

    @field_validator("extractors_list")
    @classmethod
    def not_empty_list_in_combined(
        cls, extractors_list: List[str], info: ValidationInfo
    ) -> List[str]:
        """
        Checks that `extractors_list` is not empty when the `combined`
        extractor is selected.

        Args:
            extractors_list (List[str]): the list of extractors to validate.
            info (ValidationInfo): validation context holding the
                previously validated fields.

        Returns:
            List[str]: the same list, unchanged.

        Raises:
            CombinedEmptyExtractors: if the extractor is `combined` and
                the list is empty.
        """
        # `extractor` is missing from `info.data` when its own validation
        # failed; `.get` avoids masking that error with a KeyError.
        if info.data.get("extractor") == "combined" and not extractors_list:
            from .common import CombinedEmptyExtractors

            raise CombinedEmptyExtractors()
        return extractors_list
class ModelConfig(BaseModel):
    """
    Wrapper for the model_config field.

    Extra, provider-specific arguments are accepted and kept on the
    instance (`extra="allow"`).
    """

    provider: str = Field(description="Provider of text generation models.")
    model_name: str = Field(description="Name of a text generation model.")
    threads: int = Field(
        default=8,
        gt=0,
        description="Number of threads to use in `generate_completions`",
    )
    api_type: Literal["CHAT", "COMPLETION"] = Field(
        default="COMPLETION",
        description=(
            "API type for providers that allows chat and completion endpoints."
            " This arg must be `CHAT` or `COMPLETION` and must be according to"
            " the model used:\n"
            "- `CHAT`: for chat completion endpoints.\n"
            "- `COMPLETION` for traditional completion endpoints.\n"
            "For instance, GPT-4 in OpenAI can only be used with `CHAT`."
        ),
    )

    # Allow extra args and avoid protected naming conflicts: an empty tuple
    # disables the protected namespaces so `model_name` is a valid field.
    model_config = ConfigDict(extra="allow", protected_namespaces=())

    @field_validator("provider")
    @classmethod
    def provider_must_exist(cls, provider: str) -> str:
        """
        Checks that the provider is a key of `MODELS`.

        Args:
            provider (str): the provider name to validate.

        Returns:
            str: the same provider name, unchanged.

        Raises:
            InvalidProvider: if the provider is not registered.
        """
        from .common import InvalidProvider
        from .models import MODELS

        if provider not in MODELS:
            raise InvalidProvider(provider)
        return provider
class Config(BaseModel):
    """
    Wrapper for the config.

    Groups the input, model and generation sections of a YAML config file
    together with the task type and the path the config was loaded from.
    """

    path: Optional[Path] = None
    task_type: TaskType
    input: InputConfig
    model: ModelConfig
    generation: Dict[str, Any]

    # Avoid protected naming conflicts (the `model` field would otherwise
    # clash with the protected namespaces).
    model_config = ConfigDict(protected_namespaces=())

    @classmethod
    def load_config(
        cls: Type["Config"],
        path: Path,
        task_type: TaskType,
        max_generations: Optional[int] = None,
    ) -> "Config":
        """
        Loads a single config from a YAML file.

        Args:
            path (Path): path to the YAML file.
            task_type (TaskType): task type to attach to the config.
            max_generations (Optional[int]): if given (and non-zero), upper
                bound applied to the `quantity` of the input config.

        Returns:
            Config: the loaded config.
        """
        # Close the file deterministically instead of leaking the handle.
        with path.open("r") as config_file:
            config = safe_load(config_file)

        if max_generations:
            config["input_config"]["quantity"] = min(
                max_generations, config["input_config"]["quantity"]
            )

        return cls(
            path=path,
            task_type=task_type,
            input=InputConfig(**config["input_config"]),
            model=ModelConfig(**config["model_config"]),
            generation=config["generation_config"],
        )

    @classmethod
    def load_configs(
        cls: Type["Config"],
        path: Path,
        task_type: TaskType,
        max_generations: Optional[int] = None,
    ) -> List["Config"]:
        """
        Loads every config found at a path.

        Args:
            path (Path): a `.yml`/`.yaml` file, or a directory that is
                searched recursively for such files.
            task_type (TaskType): task type to attach to each config.
            max_generations (Optional[int]): forwarded to `load_config`.

        Returns:
            List[Config]: the loaded configs.

        Raises:
            ValueError: if no YAML file is found at the path.
        """
        if path.suffix in {".yml", ".yaml"}:
            config_paths = [path]
        else:
            # Materialize the search results: a lazy `chain` object is
            # always truthy, which would make the emptiness check below
            # unreachable.
            config_paths = list(
                chain(path.rglob("*.yml"), path.rglob("*.yaml"))
            )

        if not config_paths:
            raise ValueError(
                "The provided path does not contain any YML files", path
            )

        return [
            cls.load_config(
                path=p, task_type=task_type, max_generations=max_generations
            )
            for p in config_paths
        ]

    def safe_model_name(self) -> str:
        """
        Returns the last `/`-separated part of the model name with
        underscores replaced by hyphens.
        """
        return self.model.model_name.split("/")[-1].replace("_", "-")

    def safe_dataset_name(self) -> str:
        """
        Returns the last `/`-separated part of the dataset name with
        underscores replaced by hyphens.
        """
        return self.input.dataset.split("/")[-1].replace("_", "-")

    def safe_domain_name(self) -> str:
        """
        Returns the domain with slashes and underscores replaced by hyphens.
        """
        return self.input.domain.replace("/", "-").replace("_", "-")
def parse_metrics_config(path: Path) -> Tuple[List[str], Dict]:
    """
    Parses a metrics config.

    Args:
        path (Path): the metric config path to parse.

    Returns:
        Tuple[List[str], Dict]: a tuple of structure (list of metric names,
            args), where args is the rest of the config once the
            `metrics_to_run` entry has been removed.

    Raises:
        KeyError: if the config has no `metrics_to_run` entry.
    """
    # NOTE(review): `full_load` is less restrictive than `safe_load`;
    # consider `safe_load` if metrics configs can come from untrusted
    # sources and do not need Python-specific YAML tags.
    with path.open("r") as config_file:
        config = full_load(config_file)

    # Split the metric names from the remaining per-metric arguments.
    metrics = config.pop("metrics_to_run")
    return metrics, config