class ImageRetriever:
    """
    A superclient to retrieve images similar to a query image or text.

    On construction, every PNG/JPEG image in ``images_folder_path`` that is not
    already present in the embeddings cache is embedded with ``client``. The
    cache file is then (re)written with all known embeddings, keyed by image
    file name.

    Args:
        client: A SwitchAI client that supports text and image embedding.
        images_folder_path: The path to the folder containing the images.
        embeddings_cache_path: The path to the embeddings cache file. If omitted,
            a file named ``embeddings_cache_{provider}_{model_name}.json`` is
            created in the images folder.
        batch_size: The number of images sent to the client per embedding call.
            ``None`` falls back to 32.

    Raises:
        ValueError: If the client does not support text and image embedding, or
            if ``batch_size`` is smaller than 1.
    """

    def __init__(
        self,
        client: SwitchAI,
        images_folder_path: str,
        embeddings_cache_path: Optional[str] = None,
        batch_size: Optional[int] = 32,
    ):
        # Local import so the file's import header does not need to change.
        from contextlib import ExitStack

        if Task.IMAGE_TEXT_TO_EMBEDDING not in client.supported_tasks:
            raise ValueError("ImageRetriever requires a text and image embedding model.")
        if batch_size is None:
            batch_size = 32
        if batch_size < 1:
            # range() would raise for 0 and silently yield nothing for negatives.
            raise ValueError("batch_size must be a positive integer.")

        self.client = client
        images_folder = Path(images_folder_path)

        if embeddings_cache_path is None:
            # A path separator in the provider/model name (e.g. "org/model") would
            # otherwise make the cache file live in a sub-directory that does not
            # exist, and the final write would fail.
            cache_tag = f"{client.provider}_{client.model_name}"
            cache_tag = cache_tag.replace("/", "_").replace("\\", "_")
            cache_file = images_folder / f"embeddings_cache_{cache_tag}.json"
        else:
            cache_file = Path(embeddings_cache_path)

        # Load previously computed embeddings: {image file name: embedding data}.
        self.embeddings = {}
        if cache_file.exists():
            self.embeddings = json.loads(cache_file.read_text(encoding="utf-8"))

        # Collect image files (case-insensitive .png/.jpg/.jpeg) that are not cached yet.
        # Sorting makes batch composition deterministic across runs.
        image_suffixes = {".png", ".jpg", ".jpeg"}
        images_to_embed = [
            path
            for path in sorted(images_folder.glob("*"))
            if path.suffix.lower() in image_suffixes
            and path.is_file()
            and path.name not in self.embeddings
        ]

        for start in range(0, len(images_to_embed), batch_size):
            batch = images_to_embed[start : start + batch_size]
            # Keep the batch's files open while the client consumes them, then
            # close every one of them (even if opening or embedding fails).
            with ExitStack() as stack:
                pil_images = [stack.enter_context(Image.open(path)) for path in batch]
                batch_embeddings = self.client.embed(pil_images).embeddings
            for embedding in batch_embeddings:
                # embedding.index is the position of the image within this batch.
                self.embeddings[batch[embedding.index].name] = embedding.data

        # Persist all known embeddings so later runs only embed new images.
        cache_file.write_text(json.dumps(self.embeddings), encoding="utf-8")
[docs]defretrieve_images(self,query:Union[str,Image.Image],similarity_metric:str="cosine",threshold:float=0.5)->Dict[str,float]:""" Retrieve images similar to the query image or text. Args: query: The query image or text. similarity_metric: The similarity metric to use. Must be 'cosine' or 'euclidean'. threshold: The similarity threshold. Returns: A sorted dictionary containing the image filenames as keys and the similarity scores as values. """ifsimilarity_metric=="cosine":similarity_method=self._cosine_similarityelifsimilarity_metric=="euclidean":similarity_method=self._euclidean_distanceelse:raiseValueError("Similarity metric must be 'cosine' or 'euclidean'.")query_embedding=self.client.embed(query).embeddings[0].dataresults={}forimage_path,image_embeddinginself.embeddings.items():similarity=similarity_method(query_embedding,image_embedding)ifsimilarity>=threshold:results[image_path]=float(similarity)results=dict(sorted(results.items(),key=lambdaitem:item[1],reverse=True))returnresults