Source code for text_machina.src.models.base
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from typing import Dict, List
from tqdm import tqdm
from ..config import ModelConfig
[docs]class TextGenerationModel(ABC):
"""
Base class for LLMs.
"""
def __init__(self, model_config: ModelConfig):
self.model_config = deepcopy(model_config)
[docs] @abstractmethod
def generate_completion(
self,
prompt: str,
generation_config: Dict,
) -> str:
"""
Generates a completion for a `prompt` by decoding a model
parameterized by `generation_config`. This method has to be
overwritten to implement the completion code.
Args:
prompts (str): prompt to generate completions for.
generation_config (Dict): Dictionary containing the generation parameters.
Returns:
str: Generated completion or `.types.GENERATION_ERROR` if there was some error.
"""
...
[docs] def generate_completions(
self,
prompts: List[str],
generation_config: Dict,
) -> List[str]:
"""Generates a completion for each prompt in a list of `prompts`.
Args:
prompts (List[str]): List of prompts to generate completions for.
generation_config (Dict): Dictionary containing the generation parameters.
Returns:
List[str]: List of generated completions.
"""
completions, responses = [], []
with ThreadPoolExecutor(
max_workers=min(self.model_config.threads, len(prompts))
) as thread_pool:
for prompt in prompts:
responses.append(
thread_pool.submit(
self.generate_completion, prompt, generation_config
)
)
# Wait completions
completions = [response.result() for response in tqdm(responses)]
return completions