Source code for text_machina.src.constrainers.base
from abc import ABC, abstractmethod
from typing import Any, Dict
class Constrainer(ABC):
    """
    Abstract base for constrainers.

    A constrainer inspects a dataset, infers some property of it, and turns
    that inference into generation parameters. A typical example is a
    length constrainer, which estimates text length from the data and
    derives a maximum and/or minimum number of tokens from it.

    Subclasses must implement `estimate` and `get_constraints`.
    """

    def __init__(self) -> None:
        """No state is initialised by the base class."""

    @abstractmethod
    def get_constraints(self) -> Dict[str, Any]:
        """
        Build the generation parameters implied by this constrainer.

        Example:
            output: {"max_tokens": 137, "min_tokens": 32}

        Returns:
            Dict[str, Any]: parameter names mapped to the values that
                should constrain generation.
        """

    @abstractmethod
    def estimate(self) -> Any:
        """
        Estimate the value(s) a subclass uses to build its constraints.

        The base class never calls this itself. Subclasses decide when to
        call it (presumably from `get_constraints`; verify per subclass).

        Returns:
            Any: subclass-defined estimate.
        """

    def constrain(self, generation_config: Dict[str, Any]) -> Dict[str, Any]:
        """
        Merge this constrainer's constraints into a generation config.

        Args:
            generation_config (Dict[str, Any]): the generation config to
                constrain. It is not modified.

        Returns:
            Dict[str, Any]: a new dict with the constraints and every entry
                of `generation_config`. When a key appears in both, the
                value from `generation_config` is kept, so explicitly
                configured parameters override the inferred constraints.
        """
        inferred = self.get_constraints()
        # Unpack the constraints first so explicit config entries win on key clashes.
        return {**inferred, **generation_config}