import os
from typing import Dict, Any, Tuple, Optional, List
from dataclasses import dataclass
from mistralai import Mistral
import logging
from codestral_ros2_gen import logger_main
# Logger named "<logger_main>.<last dotted component of this module's __name__>".
logger = logging.getLogger("{}.{}".format(logger_main, __name__.rpartition(".")[2]))
@dataclass
class ModelUsage:
    """
    Token counts reported for a single model call.

    Attributes:
        prompt_tokens (int): Tokens consumed by the prompt.
        completion_tokens (int): Tokens produced in the completion.
        total_tokens (int): Total number of tokens used.
    """

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

    def __str__(self):
        # A header line followed by tab-aligned "label:<tabs>value" rows.
        rows = (
            "Model usage (tokens):",
            f"prompt:\t\t{self.prompt_tokens}",
            f"completion:\t{self.completion_tokens}",
            f"total:\t\t{self.total_tokens}",
        )
        return "\n".join(rows)
class MistralClient:
    """
    Client for interacting with Mistral AI API.

    Attributes:
        DEFAULT_CONFIG (Dict[str, Any]): Fallback model settings, used when the
            ``"model"`` config omits ``type`` or ``parameters.temperature``.
        api_key (str): API key for authenticating with the Mistral AI API.
        client (Mistral): Mistral client instance.
        config (Dict[str, Any]): The ``"model"`` section of the configuration
            passed to ``__init__`` (``{}`` when none was given).
    """

    DEFAULT_CONFIG = {
        "model": {"type": "codestral-latest", "parameters": {"temperature": 0.2}}
    }

    def __init__(
        self, api_key: Optional[str] = None, config: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize the MistralClient.

        Args:
            api_key (Optional[str]): API key for authenticating with the Mistral AI
                API. Takes precedence over the environment variable and config.
            config (Optional[Dict[str, Any]]): Configuration; only its ``"model"``
                section is used.

        Raises:
            RuntimeError: If no API key can be found (see ``_get_api_key``).
        """
        # ``config["model"]`` may exist but be null; treat any falsy value as {}.
        model_config = (config or {}).get("model") or {}
        self.api_key = self._get_api_key(api_key, model_config)
        self.client = Mistral(api_key=self.api_key)
        self.config = model_config

    def _get_api_key(self, api_key: Optional[str], config: Dict[str, Any]) -> str:
        """
        Get API key from provided sources in order of precedence.

        Precedence: the explicit ``api_key`` argument, then the
        ``MISTRAL_API_KEY`` environment variable, then ``config["api_key"]``
        (ignored while it still holds the template value ``"YOUR_API_KEY_HERE"``).

        Args:
            api_key (Optional[str]): API key provided directly.
            config (Dict[str, Any]): Model configuration dictionary.

        Returns:
            str: API key.

        Raises:
            RuntimeError: If API key is not found in any of the provided sources.
        """
        msg = "Create Mistral client with provided API key"
        if api_key:
            logger.info(msg)
            return api_key
        env_key = os.getenv("MISTRAL_API_KEY")
        if env_key:
            logger.info("%s from the environment variable", msg)
            return env_key
        config_key = (config or {}).get("api_key")
        if config_key and config_key != "YOUR_API_KEY_HERE":
            logger.info("%s from config", msg)
            return config_key
        raise RuntimeError(
            "Mistral API key not found. Please provide it through one of:\n"
            "1. Direct api_key parameter in MistralClient initialization\n"
            "2. MISTRAL_API_KEY environment variable\n"
            "3. 'api_key' field in the model config\n"
            "Current config template value 'YOUR_API_KEY_HERE' is not valid."
        )

    def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        model_type: Optional[str] = None,
        temperature: Optional[float] = None,
    ) -> Tuple[str, "ModelUsage"]:
        """
        Get completion from the model.

        Args:
            prompt (str): Main prompt text.
            system_prompt (Optional[str]): Optional override for the system
                prompt; falls back to ``config["system_prompt"]`` when empty/None.
            model_type (Optional[str]): Optional override for model type; falls
                back to ``config["type"]``, then ``DEFAULT_CONFIG``.
            temperature (Optional[float]): Optional override for temperature.
                Only ``None`` selects the configured value, so ``0.0`` is honored.

        Returns:
            Tuple[str, ModelUsage]: Tuple of (generated_text, usage_stats).

        Raises:
            ValueError: If the prompt is None, empty, or whitespace-only.
            ConnectionError: If a ConnectionError is raised during the API call.
            RuntimeError: For any other failure, including a response without
                choices or with empty content.
        """
        # Validate before logging: the debug line calls prompt.strip(), which
        # would raise AttributeError (not ValueError) for a None prompt.
        if not prompt or not prompt.strip():
            raise ValueError("Empty prompt provided")
        logger.info("Start generating completion from Mistral AI")
        logger.debug("Prompt:\n<<<\n%s\n<<<", prompt.strip())
        try:
            messages = self._prepare_messages(
                prompt, system_prompt or self.config.get("system_prompt")
            )
            logger.debug("Prepared messages:\n%s\n", messages)
            defaults = self.DEFAULT_CONFIG["model"]
            model = model_type or self.config.get("type", defaults["type"])
            if temperature is None:
                # Test for None rather than using `or`, which would discard 0.0.
                params = self.config.get("parameters") or {}
                temperature = params.get(
                    "temperature", defaults["parameters"]["temperature"]
                )
            response = self.client.chat.complete(
                model=model,
                messages=messages,
                temperature=temperature,
            )
            if not getattr(response, "choices", None):
                raise RuntimeError("Invalid response format from API")
            completion = response.choices[0].message.content
            if not completion:
                raise RuntimeError("Empty completion response from API")
            logger.info("Completion generated successfully")
            logger.debug("Response:\n>>>\n%s\n>>>", completion)
            # Tolerate a missing usage object or null counters by counting them as 0.
            usage_info = getattr(response, "usage", None)
            prompt_tokens = max(0, getattr(usage_info, "prompt_tokens", 0) or 0)
            completion_tokens = max(
                0, getattr(usage_info, "completion_tokens", 0) or 0
            )
            usage = ModelUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            )
            logger.debug("%s", usage)
            return completion, usage
        except ConnectionError as e:
            raise ConnectionError(f"Failed to connect to Mistral API: {e}") from e
        except Exception as e:
            raise RuntimeError(f"API error: {e}") from e

    def _prepare_messages(
        self, prompt: str, system_prompt: Optional[str] = None
    ) -> List[Dict[str, str]]:
        """
        Prepare messages list for the API call.

        Args:
            prompt (str): Main prompt text, sent as the ``"user"`` message.
            system_prompt (Optional[str]): Optional system prompt; prepended as a
                ``"system"`` message only when truthy.

        Returns:
            List[Dict[str, str]]: List of messages for the API call.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        return messages