Source code for text_machina.src.models.inference_server

from typing import Any, Dict

import requests
from requests.adapters import HTTPAdapter, Retry

from ..common.exceptions import InvalidInferenceServer
from ..common.logging import get_logger
from ..config import ModelConfig
from .base import TextGenerationModel
from .types import GENERATION_ERROR

_logger = get_logger(__name__)


ALLOWED_INFERENCE_SERVERS = ["vllm", "trt"]


[docs]class InferenceServerModel(TextGenerationModel): """ Generates completions using models deployed on inference servers like TRT or VLLM. This model assumes the default APIs are being used (e.g., no OpenAI-compatible API in VLLM) """ def __init__(self, model_config: ModelConfig): super().__init__(model_config) if self.model_config.inference_server not in ALLOWED_INFERENCE_SERVERS: raise InvalidInferenceServer( self.model_config.inference_server, ALLOWED_INFERENCE_SERVERS ) self.base_url = self.model_config.base_url self.client = requests.Session() retry_adapter = HTTPAdapter( max_retries=Retry( total=getattr(self.model_config, "max_retries", 10), backoff_factor=getattr(self.model_config, "backoff_factor", 2), status_forcelist=[ code for code in requests.status_codes._codes if code != 200 ], ) ) self.client.mount("http://", retry_adapter) self.client.mount("https://", retry_adapter)
[docs] def prepare_data( self, prompt: str, generation_config: Dict[str, Any] ) -> Dict[str, Any]: """ Prepares the json data of a request to the inference server. Args: prompt (str): a prompt. generation_config (Dict[str, Any]): a generation config. Returns: Dict[str, Any]: json data for a request. """ if self.model_config.inference_server == "vllm": return {"text_input": prompt, "parameters": generation_config} else: return {**{"text_input": prompt}, **generation_config}
[docs] def parse_response(self, response: Dict, prompt_len: int) -> str: """ Get the completion from a response, removing the prompt from it if required. Args: response (Dict): response from the inference server. prompt_len (int): len of the prompt. Returns: str: completion. """ if self.model_config.inference_server == "vllm": return response["text_output"][prompt_len:] else: return response["text_output"][prompt_len:]
[docs] def generate_completion( self, prompt: str, generation_config: Dict, ) -> str: json_data = self.prepare_data(prompt, generation_config) try: response = self.client.post(self.base_url, json=json_data).json() completion = self.parse_response(response, len(prompt)) except Exception as e: _logger.info(f"Unrecoverable exception during the request: {e}") return GENERATION_ERROR return completion