# Source code for text_machina.src.models.cohere
import os
from typing import Dict
from cohere import Client
from ..common.logging import get_logger
from ..common.utils import get_instantiation_args
from ..config import ModelConfig
from .base import TextGenerationModel
from .types import GENERATION_ERROR
# Module-level logger; CohereModel uses it to report failed generation requests.
_logger = get_logger(__name__)
class CohereModel(TextGenerationModel):
    """
    Generates completions using Cohere models.

    Requires the definition of the `COHERE_API_KEY=<key>` environment variable.
    """

    def __init__(self, model_config: ModelConfig):
        """
        Build the Cohere ``Client`` used to serve generation requests.

        Args:
            model_config (ModelConfig): the model configuration. Its dumped
                fields are passed through ``get_instantiation_args`` together
                with ``Client.__init__`` and the result is forwarded to the
                client (presumably only the entries ``Client.__init__``
                accepts -- verify against ``get_instantiation_args``).

        Raises:
            KeyError: if the ``COHERE_API_KEY`` environment variable is unset.
        """
        super().__init__(model_config)
        try:
            api_key = os.environ["COHERE_API_KEY"]
        except KeyError as err:
            # Keep the KeyError type for existing callers, but say what is
            # actually missing instead of a bare "'COHERE_API_KEY'".
            raise KeyError(
                "The COHERE_API_KEY environment variable must be set to use "
                "CohereModel."
            ) from err
        # NOTE(review): if the dumped config ever contains an ``api_key``
        # entry that survives get_instantiation_args, this call would fail
        # with a duplicate-keyword TypeError -- confirm ModelConfig fields.
        self.client = Client(
            api_key=api_key,
            **get_instantiation_args(
                Client.__init__, self.model_config.model_dump()
            ),
        )

    def generate_completion(
        self,
        prompt: str,
        generation_config: Dict,
    ) -> str:
        """
        Generate a completion for ``prompt`` with the configured Cohere model.

        Args:
            prompt (str): the prompt to complete.
            generation_config (Dict): extra keyword arguments forwarded
                verbatim to ``Client.generate``.

        Returns:
            str: the text of the first returned generation, or
            ``GENERATION_ERROR`` if the request or reading its result fails.
        """
        try:
            response = self.client.generate(
                model=self.model_config.model_name,
                prompt=prompt,
                **generation_config,
            )
            # Indexing stays inside the try so an empty generations list is
            # reported as a generation error rather than raised to the caller.
            completion = response.generations[0].text
        except Exception as e:
            # Deliberate best-effort boundary: any failure is logged and
            # signalled through the GENERATION_ERROR sentinel.
            _logger.warning("Unrecoverable exception during the request: %s", e)
            return GENERATION_ERROR
        return completion