Source code for text_machina.src.models.openai

import os
from typing import Dict

from openai import OpenAI

from ..common.logging import get_logger
from ..common.utils import get_instantiation_args
from ..config import ModelConfig
from .base import TextGenerationModel
from .types import GENERATION_ERROR, CompletionType

_logger = get_logger(__name__)


class OpenAIModel(TextGenerationModel):
    """
    Text generation model backed by the OpenAI API.

    The API key is read from the `OPENAI_API_KEY` environment variable
    when the model is constructed, so it must be defined beforehand
    (`OPENAI_API_KEY=<key>`).
    """

    def __init__(self, model_config: ModelConfig):
        """
        Build the OpenAI client used for every request.

        Args:
            model_config: configuration of the model. Its dumped fields
                are filtered through `get_instantiation_args` and passed
                as extra keyword arguments to the `OpenAI` constructor.

        Raises:
            KeyError: if `OPENAI_API_KEY` is not set in the environment.
        """
        super().__init__(model_config)

        # Read the key first so a missing variable fails before any
        # other client argument is computed.
        api_key = os.environ["OPENAI_API_KEY"]
        # presumably `self.model_config` is set by the base class
        # constructor; verify against TextGenerationModel.
        client_kwargs = get_instantiation_args(
            OpenAI.__init__, self.model_config.model_dump()
        )
        self.client = OpenAI(api_key=api_key, **client_kwargs)

    def generate_completion(
        self,
        prompt: str,
        generation_config: Dict,
    ) -> str:
        """
        Generate a completion for a prompt.

        The chat endpoint is used when the configured `api_type` is
        `CompletionType.CHAT`; otherwise the plain completion endpoint
        is used.

        Args:
            prompt: text sent to the model.
            generation_config: keyword arguments forwarded unchanged to
                the request.

        Returns:
            The generated text, or `GENERATION_ERROR` if the request
            raised any exception.
        """
        if self.model_config.api_type == CompletionType.CHAT:
            request = self._chat_request
        else:
            request = self._completion_request

        try:
            return request(prompt, generation_config)
        except Exception as exc:
            # Best effort: any failure is logged and reported through the
            # sentinel value instead of being propagated to the caller.
            _logger.info(f"Unrecoverable exception during the request: {exc}")
            return GENERATION_ERROR

    def _chat_request(self, prompt: str, generation_config: Dict) -> str:
        """
        Send the prompt as a single user message to the chat endpoint.

        Returns the content of the first choice's message.
        """
        # NOTE(review): the OpenAI client types `message.content` as
        # optional - confirm callers tolerate a None result.
        messages = [{"role": "user", "content": prompt}]
        response = self.client.chat.completions.create(
            messages=messages,
            model=self.model_config.model_name,
            **generation_config,
        )
        first_choice = response.choices[0]
        return first_choice.message.content

    def _completion_request(self, prompt: str, generation_config: Dict) -> str:
        """
        Send the prompt to the plain completion endpoint.

        Returns the text of the first choice.
        """
        response = self.client.completions.create(
            prompt=prompt,
            model=self.model_config.model_name,
            **generation_config,
        )
        first_choice = response.choices[0]
        return first_choice.text