Source code for text_machina.src.models.vertex
import os
from typing import Dict
from google.cloud import aiplatform
from google.oauth2 import service_account
from vertexai.preview.language_models import ChatModel as VertexChatModel
from vertexai.preview.language_models import (
TextGenerationModel as VertexTextGenerationModel,
)
from ..common.logging import get_logger
from ..common.utils import get_instantiation_args
from ..config import ModelConfig
from .base import TextGenerationModel
from .types import GENERATION_ERROR, CompletionType
# Module-level logger; used below to report failed generation requests.
_logger = get_logger(__name__)
class VertexModel(TextGenerationModel):
    """
    Generates completions using VertexAI models.

    Requires the definition of the `VERTEX_AI_CREDENTIALS_FILE=<path>`
    environment variable.
    """

    def __init__(self, model_config: ModelConfig):
        """
        Authenticates against Vertex AI and loads the configured model.

        Args:
            model_config (ModelConfig): the model configuration. Its dumped
                fields are passed through `get_instantiation_args` and
                forwarded to `aiplatform.init`.

        Raises:
            KeyError: if `VERTEX_AI_CREDENTIALS_FILE` is not defined.
        """
        super().__init__(model_config)

        # Fail with an actionable message instead of a bare
        # KeyError('VERTEX_AI_CREDENTIALS_FILE'). The exception type is
        # kept so callers catching KeyError behave as before.
        credentials_file = os.environ.get("VERTEX_AI_CREDENTIALS_FILE")
        if credentials_file is None:
            raise KeyError(
                "VERTEX_AI_CREDENTIALS_FILE environment variable is not "
                "defined; set it to the path of a service account "
                "credentials file to use VertexModel."
            )

        credentials = service_account.Credentials.from_service_account_file(
            credentials_file
        )
        # NOTE(review): `get_instantiation_args` presumably keeps only the
        # config fields accepted by `aiplatform.init` -- confirm in
        # common.utils.
        aiplatform.init(
            credentials=credentials,
            **get_instantiation_args(
                aiplatform.init, self.model_config.model_dump()
            ),
        )
        self.model = self._get_model()

    def generate_completion(
        self,
        prompt: str,
        generation_config: Dict,
    ) -> str:
        """
        Generates a completion for a prompt.

        The request goes through the chat or the text-completion endpoint
        depending on `model_config.api_type`.

        Args:
            prompt (str): the prompt to complete.
            generation_config (Dict): keyword arguments forwarded verbatim
                to the underlying request method.

        Returns:
            str: the generated text, or `GENERATION_ERROR` if the request
                raised any exception.
        """
        completion_fn = (
            self._chat_request
            if self.model_config.api_type == CompletionType.CHAT
            else self._completion_request
        )
        try:
            completion = completion_fn(prompt, generation_config)
        except Exception as e:
            # Deliberate best-effort boundary: a failed request is logged
            # and reported through the GENERATION_ERROR sentinel instead
            # of being propagated to the caller.
            _logger.info(f"Unrecoverable exception during the request: {e}")
            return GENERATION_ERROR
        return completion

    def _get_model(self):
        """
        Loads the pretrained Vertex model named in the configuration.

        Returns:
            A `VertexChatModel` when `api_type` is `CompletionType.CHAT`,
            otherwise a `VertexTextGenerationModel`.
        """
        model_class = (
            VertexChatModel
            if self.model_config.api_type == CompletionType.CHAT
            else VertexTextGenerationModel
        )
        return model_class.from_pretrained(self.model_config.model_name)

    def _chat_request(self, prompt: str, generation_config: Dict) -> str:
        """
        Sends the prompt as the first message of a new chat session.

        Args:
            prompt (str): the message to send.
            generation_config (Dict): keyword arguments for `send_message`.

        Returns:
            str: the `text` attribute of the response.
        """
        # A new session is started on every call, so no history is
        # carried over between prompts.
        return (
            self.model.start_chat()
            .send_message(prompt, **generation_config)
            .text
        )

    def _completion_request(self, prompt: str, generation_config: Dict) -> str:
        """
        Requests a plain text completion for the prompt.

        Args:
            prompt (str): the prompt to complete.
            generation_config (Dict): keyword arguments for `predict`.

        Returns:
            str: the `text` attribute of the response.
        """
        return self.model.predict(prompt, **generation_config).text