import asyncio
import json
from typing import AsyncIterable, List, Optional

from fastapi import Body
from fastapi.responses import StreamingResponse
from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import ChatPromptTemplate

from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
                                  VECTOR_SEARCH_TOP_K)
from server.chat.utils import History, wrap_done
from server.knowledge_base.kb_service.base import KBService, KBServiceFactory
from server.utils import BaseResponse


def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
                        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                        top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                        history: List[History] = Body([],
                                                      description="历史对话",
                                                      examples=[[
                                                          {"role": "user",
                                                           "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                                          {"role": "assistant",
                                                           "content": "虎头虎脑"}]]
                                                      ),
                        ):
    """Answer ``query`` using documents retrieved from a knowledge base.

    Args:
        query: The user's question.
        knowledge_base_name: Name used to look up the knowledge base service.
        top_k: Passed to ``kb.search_docs`` as the number of documents to
            retrieve.
        history: Earlier conversation turns; plain dicts are converted to
            ``History`` objects.

    Returns:
        A ``BaseResponse`` with code 404 when the knowledge base lookup
        returns ``None``; otherwise a ``StreamingResponse`` (media type
        ``text/event-stream``) whose chunks are JSON strings with the keys
        ``"answer"`` (one streamed token) and ``"docs"`` (the formatted
        source documents, repeated in every chunk).
    """
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    # Request bodies may deliver history entries as raw dicts.
    history = [History(**h) if isinstance(h, dict) else h for h in history]

    async def knowledge_base_chat_iterator(query: str,
                                           kb: KBService,
                                           top_k: int,
                                           history: Optional[List[History]],
                                           ) -> AsyncIterable[str]:
        """Yield one JSON chunk per token produced by the LLM chain."""
        callback = AsyncIteratorCallbackHandler()
        model = ChatOpenAI(
            streaming=True,
            verbose=True,
            callbacks=[callback],
            openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
            openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
            model_name=LLM_MODEL
        )

        docs = kb.search_docs(query, top_k)
        context = "\n".join([doc.page_content for doc in docs])

        # Build the source list BEFORE scheduling the LLM call, so a document
        # without a "source" entry fails here rather than after a background
        # task has already been created.
        source_documents = [
            f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
            for inum, doc in enumerate(docs)
        ]

        chat_prompt = ChatPromptTemplate.from_messages(
            [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])

        chain = LLMChain(prompt=chat_prompt, llm=model)

        # Begin a task that runs in the background.
        # NOTE(review): wrap_done presumably sets callback.done once the chain
        # call finishes so that callback.aiter() terminates -- confirm in
        # server.chat.utils.
        task = asyncio.create_task(wrap_done(
            chain.acall({"context": context, "question": query}),
            callback.done),
        )

        try:
            async for token in callback.aiter():
                # Use server-sent-events to stream the response
                yield json.dumps({"answer": token,
                                  "docs": source_documents},
                                 ensure_ascii=False)
            await task
        finally:
            # If the consumer stops early (generator closed or cancelled), the
            # loop above never reaches `await task`; cancel the LLM call
            # instead of leaving it running unattended.
            if not task.done():
                task.cancel()

    return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history),
                             media_type="text/event-stream")