from fastapi import Body
from fastapi.responses import StreamingResponse
from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
                                  embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE)
from server.chat.utils import wrap_done
from langchain.chat_models import ChatOpenAI
from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts import PromptTemplate
from langchain.vectorstores import FAISS
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from server.knowledge_base.utils import get_vs_path


def knowledge_base_chat(query: str = Body(..., description="User input", example="Hello"),
                        knowledge_base_name: str = Body(..., description="Knowledge base name", example="samples"),
                        ):
    async def knowledge_base_chat_iterator(query: str,
                                           knowledge_base_name: str,
                                           ) -> AsyncIterable[str]:
        # Callback handler that exposes the LLM's streamed tokens as an async iterator.
        callback = AsyncIteratorCallbackHandler()
        model = ChatOpenAI(
            streaming=True,
            verbose=True,
            callbacks=[callback],
            openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
            openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
            model_name=LLM_MODEL,
        )

        # Load the FAISS index for the requested knowledge base and retrieve
        # the top-4 most similar documents to build the context.
        vs_path = get_vs_path(knowledge_base_name)
        embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[EMBEDDING_MODEL],
                                           model_kwargs={'device': EMBEDDING_DEVICE})
        search_index = FAISS.load_local(vs_path, embeddings)
        docs = search_index.similarity_search(query, k=4)
        context = "\n".join([doc.page_content for doc in docs])

        prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
        chain = LLMChain(prompt=prompt, llm=model)

        # Begin a task that runs in the background.
        task = asyncio.create_task(wrap_done(
            chain.acall({"context": context, "question": query}),
            callback.done),
        )

        # Use server-sent events to stream the response token by token.
        async for token in callback.aiter():
            yield token

        await task

    return StreamingResponse(knowledge_base_chat_iterator(query, knowledge_base_name),
                             media_type="text/event-stream")
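

# Usage sketch (illustrative; not part of the original module). Assuming this
# handler is registered on a FastAPI app and served with uvicorn, a client can
# consume the streamed tokens as shown below. The route path, host, port, and
# the use of httpx as the client are assumptions made for this example.
#
#     from fastapi import FastAPI
#     app = FastAPI()
#     app.post("/chat/knowledge_base_chat")(knowledge_base_chat)
#     # run with e.g.: uvicorn <your_module>:app --port 8000
#
#     import httpx
#     with httpx.stream("POST", "http://127.0.0.1:8000/chat/knowledge_base_chat",
#                       json={"query": "Hello", "knowledge_base_name": "samples"}) as resp:
#         for token in resp.iter_text():
#             print(token, end="", flush=True)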