Deploying Vanna with a Custom LLM and Embedding Model: RAG-Based Text-to-SQL Generation
Notes:
- First published: 2024-07-12
- Vanna on GitHub: https://github.com/vanna-ai/vanna
- Official Vanna documentation: https://vanna.ai/
When deploying Vanna, we can choose which LLM and vector database to use, such as the officially supported OpenAI and ChromaDB. There is a catch, however: to eliminate the risk of data leakage, it is safer to run your own LLM service. The official Vanna documentation says a custom LLM can be used, but it does not give a concrete example, so this article provides one for reference.
Subclass VannaBase and implement the interface with your own LLM
Most LLM services, whether third-party or self-hosted, expose an OpenAI-compatible API. So all we need to do is copy Vanna's OpenAI_Chat class and make a few small changes so that it calls our custom model. The code is as follows:
```python
from openai import OpenAI
from vanna.base import VannaBase


class OpenAICompatibleLLM(VannaBase):
    def __init__(self, client=None, config=None):
        VannaBase.__init__(self, config=config)

        # default parameters - can be overridden using config
        self.temperature = 0.5
        self.max_tokens = 500

        if "temperature" in config:
            self.temperature = config["temperature"]
        if "max_tokens" in config:
            self.max_tokens = config["max_tokens"]

        if "api_type" in config:
            raise Exception(
                "Passing api_type is now deprecated. Please pass an OpenAI client instead."
            )
        if "api_version" in config:
            raise Exception(
                "Passing api_version is now deprecated. Please pass an OpenAI client instead."
            )

        # If a ready-made client is supplied, use it directly
        if client is not None:
            self.client = client
            return

        if "api_base" not in config:
            raise Exception("Please pass api_base")
        if "api_key" not in config:
            raise Exception("Please pass api_key")

        self.client = OpenAI(api_key=config["api_key"], base_url=config["api_base"])

    def system_message(self, message: str) -> any:
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        if prompt is None:
            raise Exception("Prompt is None")
        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        # Rough token estimate: ~4 characters per token
        num_tokens = 0
        for message in prompt:
            num_tokens += len(message["content"]) / 4

        # Model selection priority: kwargs > config > fallback defaults
        if kwargs.get("model", None) is not None:
            model = kwargs.get("model", None)
            print(f"Using model {model} for {num_tokens} tokens (approx)")
            response = self.client.chat.completions.create(
                model=model,
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        elif kwargs.get("engine", None) is not None:
            engine = kwargs.get("engine", None)
            print(f"Using model {engine} for {num_tokens} tokens (approx)")
            response = self.client.chat.completions.create(
                engine=engine,
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        elif self.config is not None and "engine" in self.config:
            print(
                f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
            )
            response = self.client.chat.completions.create(
                engine=self.config["engine"],
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        elif self.config is not None and "model" in self.config:
            print(
                f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
            )
            response = self.client.chat.completions.create(
                model=self.config["model"],
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        else:
            # Fallback defaults when no model is specified anywhere
            if num_tokens > 3500:
                model = "kimi"
            else:
                model = "doubao"
            print(f"Using model {model} for {num_tokens} tokens (approx)")
            response = self.client.chat.completions.create(
                model=model,
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )

        # Return the first choice that has text in it (legacy-style responses)
        for choice in response.choices:
            if "text" in choice:
                return choice.text

        # Otherwise return the first chat message's content
        return response.choices[0].message.content
```
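Before wiring this class into Vanna, it can be exercised on its own. Below is a minimal smoke-test sketch; the endpoint, key, and model name are placeholders, not real values:

```python
# Smoke test for OpenAICompatibleLLM; all config values below are placeholders.
llm = OpenAICompatibleLLM(config={
    "api_key": "sk-xxxxxxxxxxxx",
    "api_base": "https://your-llm-service.example.com/v1",
    "model": "your-model-name",
    "temperature": 0.2,
    "max_tokens": 1024,
})

prompt = [
    llm.system_message("You are a helpful SQL assistant."),
    llm.user_message("Write a SQL query that counts the rows in the orders table."),
]
print(llm.submit_prompt(prompt))
```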
Subclass Qdrant_VectorStore and use your own embedding service
```python
from typing import List

import requests
from vanna.qdrant import Qdrant_VectorStore


class CustomQdrant_VectorStore(Qdrant_VectorStore):
    def __init__(self, config={}):
        self.embedding_model_name = config.get("embedding_model_name", "bge-m3")
        self.embedding_api_base = config.get("embedding_api_base", "https://xxxxxxxxxxx.com")
        self.embedding_api_key = config.get("embedding_api_key", "sk-xxxxxxxxxxxxxxx")
        super().__init__(config)

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        def _get_error_string(response: requests.Response) -> str:
            try:
                if response.content:
                    return response.json()["detail"]
            except Exception:
                pass
            try:
                response.raise_for_status()
            except requests.HTTPError as e:
                return str(e)
            return "Unknown error"

        # Call an OpenAI-compatible /v1/embeddings endpoint
        request_body = {
            "model": self.embedding_model_name,
            "input": data,
        }
        request_body.update(kwargs)

        response = requests.post(
            url=f"{self.embedding_api_base}/v1/embeddings",
            json=request_body,
            headers={"Authorization": f"Bearer {self.embedding_api_key}"},
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to create the embeddings, detail: {_get_error_string(response)}"
            )

        result = response.json()
        embeddings = [d["embedding"] for d in result["data"]]
        return embeddings[0]
```
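The generate_embedding override can also be checked directly. Here is a sketch, assuming a reachable Qdrant instance and a placeholder embedding endpoint; the host names and keys are illustrative only:

```python
from qdrant_client import QdrantClient

# All connection details below are placeholders for illustration.
store = CustomQdrant_VectorStore(config={
    "client": QdrantClient(host="localhost", port=6333),
    "embedding_model_name": "bge-m3",
    "embedding_api_base": "https://your-embedding-service.example.com",
    "embedding_api_key": "sk-xxxxxxxxxxxxxxx",
})

vector = store.generate_embedding("How many orders were placed last month?")
print(len(vector))  # the embedding dimensionality reported by your service
```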
Start the service
- Define a `CustomVanna` class that inherits from both `CustomQdrant_VectorStore` and `OpenAICompatibleLLM`
- Instantiate a `CustomVanna`, specifying the parameters of your own LLM service and embedding service
- Connect to the database, e.g. MySQL
- Start the service
```python
from qdrant_client import QdrantClient
from vanna.flask import VannaFlaskApp


class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
    def __init__(self, llm_config=None, vector_store_config=None):
        CustomQdrant_VectorStore.__init__(self, config=vector_store_config)
        OpenAICompatibleLLM.__init__(self, config=llm_config)


vn = CustomVanna(
    vector_store_config={"client": QdrantClient(host="xxxxx", port=6333)},
    llm_config={
        "api_key": "sk-xxxxxxxxxxxx",
        "api_base": "https://xxxxxxxxxxxxxxxxxx/v1",
        "model": "xxxxxxx",
    },
)

vn.connect_to_mysql(host='xxxxx', dbname='xxx', user='xxx', password='xxx', port=3306)

app = VannaFlaskApp(vn)
app.run()
```
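Vanna only generates useful SQL after it has been trained on your schema. Besides the training tab in the Flask UI, you can seed the vector store programmatically with `vn.train` before calling `app.run()` (the server blocks once started). A sketch, where the DDL, documentation, and question/SQL pair are made-up examples:

```python
# Seed the RAG store before serving questions (example values only).
vn.train(ddl="CREATE TABLE orders (id INT PRIMARY KEY, amount DECIMAL(10, 2), created_at DATETIME)")
vn.train(documentation="The orders table stores one row per customer order.")
vn.train(
    question="What is the total order amount?",
    sql="SELECT SUM(amount) FROM orders;",
)

# Ask a question programmatically; the Flask UI offers the same flow in the browser.
vn.ask("How many orders were created today?")
```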