| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | """BotClient class for interacting with bot models.""" |
| |
|
import argparse
import json
import logging
import os
import traceback
from collections.abc import Iterator

import jieba
import requests
from openai import OpenAI
| |
|
| |
|
| | class BotClient: |
| | """Client for interacting with various AI models.""" |
| |
|
| | def __init__(self, args: argparse.Namespace): |
| | """ |
| | Initializes the BotClient instance by configuring essential parameters from command line arguments |
| | including retry limits, character constraints, model endpoints and API credentials while setting up |
| | default values for missing arguments to ensure robust operation. |
| | |
| | Args: |
| | args (argparse.Namespace): Command line arguments containing configuration parameters. |
| | Uses getattr() to safely retrieve values with fallback defaults. |
| | """ |
| | self.logger = logging.getLogger(__name__) |
| |
|
| | self.max_retry_num = getattr(args, "max_retry_num", 3) |
| | self.max_char = getattr(args, "max_char", 8000) |
| |
|
| | self.model_map = getattr(args, "model_map", {}) |
| | self.api_key = os.environ.get("API_KEY") |
| |
|
| | self.embedding_service_url = getattr( |
| | args, "embedding_service_url", "embedding_service_url" |
| | ) |
| | self.embedding_model = getattr(args, "embedding_model", "embedding_model") |
| |
|
| | self.web_search_service_url = getattr( |
| | args, "web_search_service_url", "web_search_service_url" |
| | ) |
| | self.max_search_results_num = getattr(args, "max_search_results_num", 15) |
| |
|
| | self.qianfan_api_key = os.environ.get("API_KEY") |
| |
|
| | def call_back(self, host_url: str, req_data: dict) -> dict: |
| | """ |
| | Executes an HTTP request to the specified endpoint using the OpenAI client, handles the response |
| | conversion to a compatible dictionary format, and manages any exceptions that may occur during |
| | the request process while logging errors appropriately. |
| | |
| | Args: |
| | host_url (str): The URL to send the request to. |
| | req_data (dict): The data to send in the request body. |
| | |
| | Returns: |
| | dict: Parsed JSON response from the server. Returns empty dict |
| | if request fails or response is invalid. |
| | """ |
| | try: |
| | client = OpenAI(base_url=host_url, api_key=self.api_key) |
| | response = client.chat.completions.create(**req_data) |
| |
|
| | |
| | return response.model_dump() |
| |
|
| | except Exception as e: |
| | self.logger.error(f"Stream request failed: {e}") |
| | raise |
| |
|
| | def call_back_stream(self, host_url: str, req_data: dict) -> dict: |
| | """ |
| | Makes a streaming HTTP request to the specified host URL using the OpenAI client and yields response chunks |
| | in real-time while handling any exceptions that may occur during the streaming process. |
| | |
| | Args: |
| | host_url (str): The URL to send the request to. |
| | req_data (dict): The data to send in the request body. |
| | |
| | Returns: |
| | generator: Generator that yields parsed JSON responses from the server. |
| | """ |
| | try: |
| | client = OpenAI(base_url=host_url, api_key=self.api_key) |
| | response = client.chat.completions.create( |
| | **req_data, |
| | stream=True, |
| | ) |
| | for chunk in response: |
| | if not chunk.choices: |
| | continue |
| |
|
| | |
| | yield chunk.model_dump() |
| |
|
| | except Exception as e: |
| | self.logger.error(f"Stream request failed: {e}") |
| | raise |
| |
|
| | def process( |
| | self, |
| | model_name: str, |
| | req_data: dict, |
| | max_tokens: int = 2048, |
| | temperature: float = 1.0, |
| | top_p: float = 0.7, |
| | ) -> dict: |
| | """ |
| | Handles chat completion requests by mapping the model name to its endpoint, preparing request parameters |
| | including token limits and sampling settings, truncating messages to fit character limits, making API calls |
| | with built-in retry mechanism, and logging the full request/response cycle for debugging purposes. |
| | |
| | Args: |
| | model_name (str): Name of the model, used to look up the model URL from model_map. |
| | req_data (dict): Dictionary containing request data, including information to be processed. |
| | max_tokens (int): Maximum number of tokens to generate. |
| | temperature (float): Sampling temperature to control the diversity of generated text. |
| | top_p (float): Cumulative probability threshold to control the diversity of generated text. |
| | |
| | Returns: |
| | dict: Dictionary containing the model's processing results. |
| | """ |
| | model_url = self.model_map[model_name] |
| |
|
| | req_data["model"] = model_name |
| | req_data["max_tokens"] = max_tokens |
| | req_data["temperature"] = temperature |
| | req_data["top_p"] = top_p |
| | req_data["messages"] = self.truncate_messages(req_data["messages"]) |
| | for _ in range(self.max_retry_num): |
| | try: |
| | self.logger.info(f"[MODEL] {model_url}") |
| | self.logger.info("[req_data]====>") |
| | self.logger.info(json.dumps(req_data, ensure_ascii=False)) |
| | res = self.call_back(model_url, req_data) |
| | self.logger.info("model response") |
| | self.logger.info(res) |
| | self.logger.info("-" * 30) |
| | except Exception as e: |
| | self.logger.info(e) |
| | self.logger.info(traceback.format_exc()) |
| | res = {} |
| | if len(res) != 0 and "error" not in res: |
| | break |
| |
|
| | return res |
| |
|
| | def process_stream( |
| | self, |
| | model_name: str, |
| | req_data: dict, |
| | max_tokens: int = 2048, |
| | temperature: float = 1.0, |
| | top_p: float = 0.7, |
| | ) -> dict: |
| | """ |
| | Processes streaming requests by mapping the model name to its endpoint, configuring request parameters, |
| | implementing a retry mechanism with logging, and streaming back response chunks in real-time while |
| | handling any errors that may occur during the streaming session. |
| | |
| | Args: |
| | model_name (str): Name of the model, used to look up the model URL from model_map. |
| | req_data (dict): Dictionary containing request data, including information to be processed. |
| | max_tokens (int): Maximum number of tokens to generate. |
| | temperature (float): Sampling temperature to control the diversity of generated text. |
| | top_p (float): Cumulative probability threshold to control the diversity of generated text. |
| | |
| | Yields: |
| | dict: Dictionary containing the model's processing results. |
| | """ |
| | model_url = self.model_map[model_name] |
| | req_data["model"] = model_name |
| | req_data["max_tokens"] = max_tokens |
| | req_data["temperature"] = temperature |
| | req_data["top_p"] = top_p |
| | req_data["messages"] = self.truncate_messages(req_data["messages"]) |
| |
|
| | last_error = None |
| | for _ in range(self.max_retry_num): |
| | try: |
| | self.logger.info(f"[MODEL] {model_url}") |
| | self.logger.info("[req_data]====>") |
| | self.logger.info(json.dumps(req_data, ensure_ascii=False)) |
| |
|
| | yield from self.call_back_stream(model_url, req_data) |
| | return |
| |
|
| | except Exception as e: |
| | last_error = e |
| | self.logger.error( |
| | f"Stream request failed (attempt {_ + 1}/{self.max_retry_num}): {e}" |
| | ) |
| |
|
| | self.logger.error("All retry attempts failed for stream request") |
| | yield {"error": str(last_error)} |
| |
|
| | def cut_chinese_english(self, text: str) -> list: |
| | """ |
| | Segments mixed Chinese and English text into individual components using Jieba for Chinese words |
| | while preserving English words as whole units, with special handling for Unicode character ranges |
| | to distinguish between the two languages. |
| | |
| | Args: |
| | text (str): Input string to be segmented. |
| | |
| | Returns: |
| | list: A list of segments, where each segment is either a letter or a word. |
| | """ |
| | words = jieba.lcut(text) |
| | en_ch_words = [] |
| |
|
| | for word in words: |
| | if word.isalpha() and not any( |
| | "\u4e00" <= char <= "\u9fff" for char in word |
| | ): |
| | en_ch_words.append(word) |
| | else: |
| | en_ch_words.extend(list(word)) |
| | return en_ch_words |
| |
|
| | def truncate_messages(self, messages: list[dict]) -> list: |
| | """ |
| | Truncates conversation messages to fit within the maximum character limit (self.max_char) |
| | by intelligently removing content while preserving message structure. The truncation follows |
| | a prioritized order: historical messages first, then system message, and finally the last message. |
| | |
| | Args: |
| | messages (list[dict]): List of messages to be truncated. |
| | |
| | Returns: |
| | list[dict]: Modified list of messages after truncation. |
| | """ |
| | if not messages: |
| | return messages |
| |
|
| | processed = [] |
| | total_units = 0 |
| |
|
| | for msg in messages: |
| | |
| | if isinstance(msg["content"], str): |
| | text_content = msg["content"] |
| | elif isinstance(msg["content"], list): |
| | text_content = msg["content"][1]["text"] |
| | else: |
| | text_content = "" |
| |
|
| | |
| | units = self.cut_chinese_english(text_content) |
| | unit_count = len(units) |
| |
|
| | processed.append( |
| | { |
| | "role": msg["role"], |
| | "original_content": msg["content"], |
| | "text_content": text_content, |
| | "units": units, |
| | "unit_count": unit_count, |
| | } |
| | ) |
| | total_units += unit_count |
| |
|
| | if total_units <= self.max_char: |
| | return messages |
| |
|
| | |
| | to_remove = total_units - self.max_char |
| |
|
| | |
| | for i in range(len(processed) - 1, 1): |
| | if to_remove <= 0: |
| | break |
| |
|
| | |
| | if processed[i]["unit_count"] <= to_remove: |
| | processed[i]["text_content"] = "" |
| | to_remove -= processed[i]["unit_count"] |
| | if isinstance(processed[i]["original_content"], str): |
| | processed[i]["original_content"] = "" |
| | elif isinstance(processed[i]["original_content"], list): |
| | processed[i]["original_content"][1]["text"] = "" |
| | else: |
| | kept_units = processed[i]["units"][:-to_remove] |
| | new_text = "".join(kept_units) |
| | processed[i]["text_content"] = new_text |
| | if isinstance(processed[i]["original_content"], str): |
| | processed[i]["original_content"] = new_text |
| | elif isinstance(processed[i]["original_content"], list): |
| | processed[i]["original_content"][1]["text"] = new_text |
| | to_remove = 0 |
| |
|
| | |
| | if to_remove > 0: |
| | system_msg = processed[0] |
| | if system_msg["unit_count"] <= to_remove: |
| | processed[0]["text_content"] = "" |
| | to_remove -= system_msg["unit_count"] |
| | if isinstance(processed[0]["original_content"], str): |
| | processed[0]["original_content"] = "" |
| | elif isinstance(processed[0]["original_content"], list): |
| | processed[0]["original_content"][1]["text"] = "" |
| | else: |
| | kept_units = system_msg["units"][:-to_remove] |
| | new_text = "".join(kept_units) |
| | processed[0]["text_content"] = new_text |
| | if isinstance(processed[0]["original_content"], str): |
| | processed[0]["original_content"] = new_text |
| | elif isinstance(processed[0]["original_content"], list): |
| | processed[0]["original_content"][1]["text"] = new_text |
| | to_remove = 0 |
| |
|
| | |
| | if to_remove > 0 and len(processed) > 1: |
| | last_msg = processed[-1] |
| | if last_msg["unit_count"] > to_remove: |
| | kept_units = last_msg["units"][:-to_remove] |
| | new_text = "".join(kept_units) |
| | last_msg["text_content"] = new_text |
| | if isinstance(last_msg["original_content"], str): |
| | last_msg["original_content"] = new_text |
| | elif isinstance(last_msg["original_content"], list): |
| | last_msg["original_content"][1]["text"] = new_text |
| | else: |
| | last_msg["text_content"] = "" |
| | if isinstance(last_msg["original_content"], str): |
| | last_msg["original_content"] = "" |
| | elif isinstance(last_msg["original_content"], list): |
| | last_msg["original_content"][1]["text"] = "" |
| |
|
| | result = [] |
| | for msg in processed: |
| | if msg["text_content"]: |
| | result.append({"role": msg["role"], "content": msg["original_content"]}) |
| |
|
| | return result |
| |
|
| | def embed_fn(self, text: str) -> list: |
| | """ |
| | Generate an embedding for the given text using the QianFan API. |
| | |
| | Args: |
| | text (str): The input text to be embedded. |
| | |
| | Returns: |
| | list: A list of floats representing the embedding. |
| | """ |
| | client = OpenAI( |
| | base_url=self.embedding_service_url, api_key=self.qianfan_api_key |
| | ) |
| | response = client.embeddings.create(input=[text], model=self.embedding_model) |
| | return response.data[0].embedding |
| |
|
| | def get_web_search_res(self, query_list: list) -> list: |
| | """ |
| | Send a request to the AI Search service using the provided API key and service URL. |
| | |
| | Args: |
| | query_list (list): List of queries to send to the AI Search service. |
| | |
| | Returns: |
| | list: List of responses from the AI Search service. |
| | """ |
| | headers = { |
| | "Authorization": "Bearer " + self.qianfan_api_key, |
| | "Content-Type": "application/json", |
| | } |
| |
|
| | results = [] |
| | top_k = self.max_search_results_num // len(query_list) |
| | for query in query_list: |
| | payload = { |
| | "messages": [{"role": "user", "content": query}], |
| | "resource_type_filter": [{"type": "web", "top_k": top_k}], |
| | } |
| | response = requests.post( |
| | self.web_search_service_url, headers=headers, json=payload |
| | ) |
| |
|
| | if response.status_code == 200: |
| | response = response.json() |
| | self.logger.info(response) |
| | results.append(response["references"]) |
| | else: |
| | self.logger.info(f"请求失败,状态码: {response.status_code}") |
| | self.logger.info(response.text) |
| | return results |
| |
|