class Classifier:
    """
    A superclient that extends a chat SwitchAI client to support classification tasks.

    It can be used to classify text or images.

    Args:
        client: A SwitchAI client initialized with a chat model.
        classes: The classes to classify the data into. Must contain at least one class.
        task_description: A description of the classification task. When non-empty, it is
            sent to the model as a system message before each item.
        multi_label: Whether the classifier should support multi-label classification or
            single-label classification.

    Raises:
        ValueError: If the client supports neither text generation nor image-text-to-text,
            or if ``classes`` is empty.
    """

    def __init__(
        self,
        client: SwitchAI,
        classes: List[str],
        task_description: Union[str, None] = None,
        multi_label: bool = False,
    ):
        if (
            Task.TEXT_GENERATION not in client.supported_tasks
            and Task.IMAGE_TEXT_TO_TEXT not in client.supported_tasks
        ):
            raise ValueError("The classifier client only supports chat models.")
        if not classes:
            # An empty enum would produce a response model that no answer can satisfy.
            raise ValueError("At least one class must be provided.")

        self.client = client
        self.task_description = task_description

        # str-valued Enum whose member names and values are the class labels.
        ClassesType = Enum("ClassesType", {class_: class_ for class_ in classes}, type=str)

        # Single-label: exactly one class; multi-label: a list of classes.
        class_name_type = List[ClassesType] if multi_label else ClassesType

        # Pydantic model handed to ``client.chat`` as ``response_format``; presumably the
        # client uses it to constrain the model's structured output (confirm in SwitchAI).
        self.ClassificationResult = type(
            "ClassificationResult",
            (BaseModel,),
            {
                "__annotations__": {"class_name": class_name_type},
                "class_name": Field(...),
                "__module__": __name__,
            },
        )

    def classify(
        self, data: Union[str, Image, List[Union[str, Image]]]
    ) -> Union[str, List[str], List[List[str]]]:
        """
        Classifies the given data.

        Args:
            data: The data to classify: a text string, an image, or a list of them.
                A list is classified item by item, one chat request per item.

        Returns:
            The classification result(s). For a single item this is the class name
            (single-label) or a list of class names (multi-label). For a list input it is
            a list containing one such result per item, in input order.

        Raises:
            ValueError: If an item has an unsupported type or a response is malformed.
        """
        if isinstance(data, list):
            return [self._classify_single(item) for item in data]
        return self._classify_single(data)

    def _classify_single(self, data: Union[str, Image]) -> Union[str, List[str]]:
        """Classifies one text or image item with a single chat request."""
        messages = self._create_messages(data)
        response = self.client.chat(messages=messages, response_format=self.ClassificationResult)
        return self._parse_response(response)

    def _create_messages(self, data: Union[str, Image]) -> List[dict]:
        """
        Builds the chat messages for one item.

        A system message carrying the task description is prepended when one was given.

        Raises:
            ValueError: If ``data`` is neither a string nor an image.
        """
        messages = []
        if self.task_description:
            messages.append(
                {
                    "role": "system",
                    "content": f"Your task is to classify data.\nTask description: {self.task_description}",
                }
            )

        if isinstance(data, str):
            messages.append({"role": "user", "content": data})
        elif isinstance(data, Image):
            messages.append({"role": "user", "content": [{"type": "image", "image": data}]})
        else:
            raise ValueError("Unsupported data type for classification")
        return messages

    def _parse_response(self, response) -> Union[str, List[str]]:
        """
        Extracts ``class_name`` from the JSON content of a chat response.

        Args:
            response: The object returned by ``client.chat``; its ``message.content`` is
                expected to be a JSON object containing a ``class_name`` key.

        Returns:
            The class name (single-label) or list of class names (multi-label).

        Raises:
            ValueError: If the response does not have the expected structure.
        """
        try:
            return json.loads(response.message.content)["class_name"]
        # TypeError covers ``content`` being None and JSON that decodes to a non-object
        # (string/array); AttributeError covers a response without ``message.content``.
        except (AttributeError, KeyError, IndexError, TypeError, json.JSONDecodeError) as e:
            raise ValueError("Invalid response format") from e