Source code for switchai.main_client

import glob
import importlib
import os
from typing import List, Optional, Union, Generator, Type

from PIL.Image import Image
from pydantic import BaseModel

from .base_client import BaseClient
from .types import ChatResponse, TranscriptionResponse, ImageGenerationResponse, EmbeddingResponse
from .utils import Task, contains_image


class SwitchAI(BaseClient):
    """
    The SwitchAI client class.

    Resolves a provider-specific client adapter for the requested provider and
    model, then forwards chat, embedding, transcription and image-generation
    calls to it after checking that the model supports the requested task.

    Args:
        provider: The name of the provider to use (matched case-insensitively
            against the modules in the ``providers`` package).
        model_name: The name of the model to use.
        api_key: The API key to use, if not set it will be read from the
            environment variable. Defaults to None.

    Raises:
        ValueError: If the provider or model is not supported, or if no API key
            can be found for a provider that requires one.
    """

    def __init__(self, provider: str, model_name: str, api_key: Optional[str] = None):
        self.provider = provider.lower()
        self.model_name = model_name
        self.client, self.supported_tasks = self._get_provider_client(api_key)

    @staticmethod
    def _discover_providers() -> List[str]:
        """Return the sorted names of the provider modules in the ``providers`` package.

        A provider is any ``providers/_<name>.py`` file. Dunder modules such as
        ``__init__.py`` also match the ``_*.py`` pattern and are skipped.
        """
        pattern = os.path.join(os.path.dirname(__file__), "providers", "_*.py")
        names = []
        for path in glob.glob(pattern):
            filename = os.path.basename(path)
            if filename.startswith("__"):
                continue
            # Strip the leading "_" and the trailing ".py".
            names.append(filename[1:-3])
        # glob order is filesystem-dependent; sort so error messages are stable.
        return sorted(names)

    @staticmethod
    def _import_provider(name: str):
        """Import and return the module ``switchai.providers._<name>``."""
        return importlib.import_module(f"switchai.providers._{name}")

    def _find_alternative_providers(self, provider_names: List[str]) -> List[str]:
        """Return the other providers whose ``SUPPORTED_MODELS`` contain ``self.model_name``."""
        alternatives = []
        for name in provider_names:
            if name == self.provider:
                continue
            try:
                module = self._import_provider(name)
            except ImportError:
                # A provider that cannot be imported (presumably because its SDK
                # is not installed) cannot be suggested anyway. Skipping it keeps
                # this import failure from masking the "model not supported" error.
                continue
            if self.model_name in module.SUPPORTED_MODELS:
                alternatives.append(name)
        return alternatives

    def _resolve_api_key(self, provider_module, api_key: Optional[str]) -> Optional[str]:
        """Return the API key to hand to the provider adapter.

        The ``ollama`` provider is exempt from the check, so its ``api_key`` is
        returned unchanged (possibly None). For every other provider, an
        explicit ``api_key`` wins. Otherwise the environment variable named by
        the provider module's ``API_KEY_NAMING`` attribute is read.

        Raises:
            ValueError: If no key is passed and the environment variable is unset.
        """
        if self.provider == "ollama":
            return api_key
        if api_key is None:
            api_key = os.environ.get(provider_module.API_KEY_NAMING)
        if api_key is None:
            raise ValueError(
                f"The api_key client option must be set either by passing api_key to the client or by setting the {provider_module.API_KEY_NAMING} environment variable."
            )
        return api_key

    def _get_provider_client(self, api_key: Optional[str]) -> tuple[BaseClient, List[Task]]:
        """Build the provider adapter for ``self.provider`` / ``self.model_name``.

        Args:
            api_key: Explicit API key, or None to fall back to the environment.

        Returns:
            A pair ``(client, supported_tasks)``. ``client`` is an instance of
            ``<Provider>ClientAdapter`` from the provider module.
            ``supported_tasks`` is the model's entry in the module's
            ``SUPPORTED_MODELS`` mapping. The public methods test ``Task``
            members against it with ``in``; it is presumably a list of ``Task``.

        Raises:
            ValueError: If the provider or model is unsupported, or no API key is found.
        """
        provider_names = self._discover_providers()

        if self.provider not in provider_names:
            supported_providers = ", ".join(provider_names)
            raise ValueError(
                f"Provider '{self.provider}' is not supported. Supported providers are: {supported_providers}."
            )

        provider_module = self._import_provider(self.provider)

        if self.model_name not in provider_module.SUPPORTED_MODELS:
            alternative_providers = self._find_alternative_providers(provider_names)
            if alternative_providers:
                alternatives = ", ".join(alternative_providers)
                raise ValueError(
                    f"Model '{self.model_name}' is not supported by provider '{self.provider}'. "
                    f"However, it is supported by: {alternatives}."
                )
            raise ValueError(f"Model '{self.model_name}' is not supported by any provider.")

        supported_tasks = provider_module.SUPPORTED_MODELS[self.model_name]
        api_key = self._resolve_api_key(provider_module, api_key)

        # Adapters follow the naming convention "<Provider>ClientAdapter", e.g. "OpenaiClientAdapter".
        class_name = f"{self.provider.capitalize()}ClientAdapter"
        client_class = getattr(provider_module, class_name)
        return client_class(self.model_name, api_key), supported_tasks

    def chat(
        self,
        messages: List[str | dict | ChatResponse],
        temperature: Optional[float] = 1.0,
        max_tokens: Optional[int] = None,
        tools: Optional[List] = None,
        response_format: Optional[Type[BaseModel]] = None,
        stream: Optional[bool] = False,
    ) -> Union[ChatResponse, Generator[ChatResponse, None, None]]:
        """Send a chat request to the underlying provider adapter.

        Args:
            messages: The conversation, as strings, dicts or previous ChatResponse objects.
            temperature: Sampling temperature forwarded to the adapter.
            max_tokens: Maximum number of tokens, forwarded to the adapter.
            tools: Tool definitions forwarded to the adapter.
            response_format: A pydantic model class forwarded to the adapter.
            stream: Streaming flag forwarded to the adapter.

        Returns:
            Whatever the adapter's ``chat`` returns: a ChatResponse or a
            generator of ChatResponse objects (presumably the latter when
            ``stream`` is True).

        Raises:
            ValueError: If the model supports neither text generation nor
                image-text-to-text, or if the messages contain an image and the
                model lacks image-text-to-text support.
        """
        if Task.TEXT_GENERATION not in self.supported_tasks and Task.IMAGE_TEXT_TO_TEXT not in self.supported_tasks:
            raise ValueError(f"Model '{self.model_name}' is not a chat model.")

        if contains_image(messages) and Task.IMAGE_TEXT_TO_TEXT not in self.supported_tasks:
            raise ValueError(
                f"Your request contains an image, but model '{self.model_name}' does not support the 'vision' capability."
            )

        return self.client.chat(messages, temperature, max_tokens, tools, response_format, stream)

    def embed(self, inputs: Union[str, Image, List[Union[str, Image]]]) -> EmbeddingResponse:
        """Embed text and/or images with the underlying provider adapter.

        Args:
            inputs: A string, a PIL image, or a list mixing the two.

        Returns:
            The adapter's EmbeddingResponse.

        Raises:
            ValueError: If the model is not an embedding model, or if the inputs
                contain an image and the model lacks image-text embedding support.
        """
        if (
            Task.TEXT_TO_EMBEDDING not in self.supported_tasks
            and Task.IMAGE_TEXT_TO_EMBEDDING not in self.supported_tasks
        ):
            raise ValueError(f"Model '{self.model_name}' is not an embedding model.")

        if contains_image(inputs) and Task.IMAGE_TEXT_TO_EMBEDDING not in self.supported_tasks:
            raise ValueError(f"Model '{self.model_name}' does not support image embeddings.")

        return self.client.embed(inputs)

    def transcribe(self, audio_path: str, language: Optional[str] = None) -> TranscriptionResponse:
        """Transcribe an audio file with the underlying provider adapter.

        Args:
            audio_path: Path to the audio file, forwarded to the adapter.
            language: Optional language hint, forwarded to the adapter.

        Returns:
            The adapter's TranscriptionResponse.

        Raises:
            ValueError: If the model does not support audio-to-text.
        """
        if Task.AUDIO_TO_TEXT not in self.supported_tasks:
            raise ValueError(f"Model '{self.model_name}' is not a speech-to-text model.")

        return self.client.transcribe(audio_path, language)

    def generate_image(self, prompt: str, n: int = 1) -> ImageGenerationResponse:
        """Generate images from a text prompt with the underlying provider adapter.

        Args:
            prompt: The text prompt.
            n: Number of images to request, forwarded to the adapter.

        Returns:
            The adapter's ImageGenerationResponse.

        Raises:
            ValueError: If the model does not support text-to-image.
        """
        if Task.TEXT_TO_IMAGE not in self.supported_tasks:
            raise ValueError(f"Model '{self.model_name}' is not an image generation model.")

        return self.client.generate_image(prompt, n)