Source code for text_machina.src.models.bedrock

import json
import os
from typing import Any, Dict

import boto3
from botocore import exceptions as boto_exceptions
from botocore.config import Config as BotoConfig

from ..common.logging import get_logger
from ..common.utils import get_instantiation_args
from ..config import ModelConfig
from .base import TextGenerationModel
from .types import GENERATION_ERROR

_logger = get_logger(__name__)


class BedrockModel(TextGenerationModel):
    """
    Generates completions using AWS Bedrock models.

    Requires the definition of the `AWS_ACCESS_KEY_ID=<key>` and
    `AWS_SECRET_ACCESS_KEY=<key>` environment variables.
    """

    # Bedrock model providers whose request/response schemas are supported.
    # The provider is the first dot-separated component of the model id.
    _SUPPORTED_PROVIDERS = frozenset(
        {"ai21", "amazon", "anthropic", "cohere", "meta", "mistral"}
    )

    # Generation configs express the maximum number of generated tokens as
    # `maxTokenCount` (the Amazon key, passed through unchanged inside
    # `textGenerationConfig`). For every other provider that key is renamed
    # to the one listed here. Length constrainers work directly on the
    # providers themselves, so the key name is overwritten per provider.
    _MAX_TOKENS_KEYS: Dict[str, str] = {
        "ai21": "maxTokens",
        "anthropic": "max_tokens_to_sample",
        "cohere": "max_tokens",
        "meta": "max_gen_len",
        "mistral": "max_tokens",
    }

    def __init__(self, model_config: ModelConfig):
        """
        Builds the Bedrock runtime client.

        Args:
            model_config (ModelConfig): the model config. Its fields that
                match botocore `Config` options are used to configure the
                client; `model_name` is used as the Bedrock model id.

        Raises:
            KeyError: if `AWS_ACCESS_KEY_ID` or `AWS_SECRET_ACCESS_KEY` are
                not defined in the environment.
        """
        super().__init__(model_config)
        # Only forward model-config fields that are valid botocore `Config`
        # options (e.g. retries, timeouts); everything else is ignored here.
        client_config = BotoConfig(
            **get_instantiation_args(
                BotoConfig.__init__,
                self.model_config.model_dump(),
                accepted_params=list(BotoConfig.OPTION_DEFAULTS.keys()),
            )
        )
        self.client = boto3.client(
            service_name="bedrock-runtime",
            aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
            aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
            config=client_config,
        )

    def generate_completion(
        self,
        prompt: str,
        generation_config: Dict[str, Any],
    ) -> str:
        """
        Generates a completion for a prompt with the configured Bedrock model.

        Args:
            prompt (str): the prompt to use for generating text.
            generation_config (Dict[str, Any]): the generation config. It is
                not modified.

        Returns:
            str: the generated completion, or `GENERATION_ERROR` if the
                Bedrock request fails with a botocore `ClientError`.

        Raises:
            ValueError: if the model's provider is not supported.
        """
        request_body = self.get_request_body(prompt, generation_config)
        try:
            response = self.client.invoke_model(
                body=request_body,
                modelId=self.model_config.model_name,
                accept="application/json",
                contentType="application/json",
            )
        except boto_exceptions.ClientError as ce:
            # Service-side errors (throttling, validation, access denied...)
            # are reported as a generation error instead of aborting.
            error_msg = ce.response.get("Error", {}).get("Code", "Unknown")
            _logger.info(
                f"Unrecoverable exception during the request: {error_msg}"
            )
            return GENERATION_ERROR
        response_body = json.loads(response["body"].read())
        return self.get_completion_from_response_body(response_body)

    def get_request_body(
        self, prompt: str, generation_config: Dict[str, Any]
    ) -> str:
        """
        Prepares the request body for a request to a bedrock model.

        Considers the different parameters that each model provider accepts.
        A `maxTokenCount` entry in the generation config is renamed to the
        key expected by the provider (Amazon keeps it as is).

        Args:
            prompt (str): the prompt to use for generating text.
            generation_config (Dict[str, Any]): the generation config. It is
                not modified; a shallow copy is adapted for the provider.

        Returns:
            str: the JSON-serialized, provider-specific request body.

        Raises:
            ValueError: if the model's provider is not supported.
        """
        bedrock_provider = self._get_bedrock_provider()
        # Work on a copy so the caller's config (which may be reused across
        # requests) is not mutated by the provider-specific key renaming.
        config = dict(generation_config)

        if bedrock_provider == "amazon":
            request_body = {
                "inputText": prompt,
                "textGenerationConfig": config,
            }
        else:
            max_tokens_key = self._MAX_TOKENS_KEYS[bedrock_provider]
            if "maxTokenCount" in config:
                config[max_tokens_key] = config.pop("maxTokenCount")
            request_body = {"prompt": prompt, **config}

        return json.dumps(request_body)

    def get_completion_from_response_body(self, response_body: Dict) -> str:
        """
        Obtains the completions from a response body returned by a bedrock
        model.

        Considers the different API schemas that each model provider uses.

        Args:
            response_body (Dict): the body returned by models in bedrock.

        Returns:
            str: the completion of the model extracted from the body.

        Raises:
            ValueError: if the model's provider is not supported.
        """
        bedrock_provider = self._get_bedrock_provider()
        if bedrock_provider == "amazon":
            return response_body["results"][0]["outputText"]
        if bedrock_provider == "ai21":
            return response_body["completions"][0]["data"]["text"]
        if bedrock_provider == "anthropic":
            return response_body["completion"]
        if bedrock_provider == "cohere":
            return response_body["generations"][0]["text"]
        if bedrock_provider == "meta":
            return response_body["generation"]
        # Only "mistral" remains after validation in _get_bedrock_provider.
        return response_body["outputs"][0]["text"]

    def _get_bedrock_provider(self) -> str:
        """
        Returns the provider of the configured model, i.e. the first
        dot-separated component of `model_config.model_name`.

        Raises:
            ValueError: if the provider is not supported. An explicit raise
                is used instead of `assert`, which is stripped under `-O`.
        """
        bedrock_provider = self.model_config.model_name.split(".")[0]
        if bedrock_provider not in self._SUPPORTED_PROVIDERS:
            raise ValueError(
                f"Unsupported Bedrock provider '{bedrock_provider}' for model "
                f"'{self.model_config.model_name}'. Supported providers: "
                f"{sorted(self._SUPPORTED_PROVIDERS)}."
            )
        return bedrock_provider