parent
a59767711b
commit
9ba2120129
Binary file not shown.

@@ -10,7 +10,7 @@ from langchain.prompts.chat import ChatPromptTemplate
from typing import List
from server.chat.utils import History
from server.utils import get_prompt_template

from langchain.prompts import PromptTemplate

async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
               history: List[History] = Body([],

@@ -41,11 +41,34 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
            callbacks=[callback],
        )

        prompt_template = get_prompt_template("llm_chat", prompt_name)
        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages(
            [i.to_msg_template() for i in history] + [input_msg])
        chain = LLMChain(prompt=chat_prompt, llm=model)
        # augment_prompt_template = get_prompt_template("data_augment", "default")
        # input_msg = History(role="user", content=augment_prompt_template).to_msg_template(False)
        # chat_prompt = ChatPromptTemplate.from_messages(
        #     [i.to_msg_template() for i in history] + [input_msg])
        # chain = LLMChain(prompt=chat_prompt, llm=model)
        # print(f"knowledge_base_chat_iterator,prompt_template:{chat_prompt}")

        # # Begin a task that runs in the background.
        # task = asyncio.create_task(wrap_done(
        #     chain.acall({ "question": query}),
        #     callback.done),
        # )


        prompt = ChatPromptTemplate.from_template(" 你是一个非常聪明的语义转换专家总能找到同一个语义不同的表达方式,请简洁生成一个与三单引号里的原句子语气语调完全一致,并且语义最相似的新句子,注意不是回答三单引号里的原句子,同时新句子直接使用简体中文给出,而不用重三单引号里的原句子。如果无法给出满足条件的新句子,直接给出三单引号里的原句子,而不是给出三单引号里的原句子的答案 '''{{input}}''' ")

        # prompt_template = get_prompt_template("llm_chat", prompt_name)
        # input_msg = History(role="user", content=prompt_template).to_msg_template(False)
        # chat_prompt = ChatPromptTemplate.from_messages(
        #     [i.to_msg_template() for i in history] + [input_msg])
        chain = LLMChain(prompt=prompt, llm=model)

        name = "John"
        age = 30
        text = "你是一个非常聪明的语义转换专家总能找到同一个语义不同的表达方式,请简洁生成一个与三单引号里的原句子语气语调完全一致,并且语义最相似的新句子,注意不是回答三单引号里的原句子,同时新句子直接使用简体中文给出,而不用重三单引号里的原句子。如果无法给出满足条件的新句子,直接给出三单引号里的原句子,而不是给出三单引号里的原句子的答案\n ''' "
        text = text + query
        text = text + " ''' "
        result = ", ".join(["Name: ", name, ", Age: ", str(age)])

        # Begin a task that runs in the background.
        task = asyncio.create_task(wrap_done(
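
The hunk above replaces the configured llm_chat template with a hard-coded paraphrase instruction, and separately rebuilds the same instruction by plain string concatenation into text. For reference, a minimal standalone sketch of the template variant, assuming the legacy (pre-LCEL) LangChain API and a hypothetical model name:

from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import ChatPromptTemplate

# Hypothetical model setup; the handler above obtains its model via get_ChatOpenAI(...).
model = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.7)

# One {input} slot for the user query. Note that the diff writes {{input}},
# which from_template treats as a literal "{input}" rather than a variable.
prompt = ChatPromptTemplate.from_template(
    "请简洁生成一个与三单引号里的原句子语义最相似的新句子:'''{input}'''")
chain = LLMChain(prompt=prompt, llm=model)
print(chain.run(input="恼羞成怒"))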

@@ -14,12 +14,12 @@ import json
import os
from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs

from langchain.prompts import PromptTemplate

async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
                              knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                              top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                              score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
                              score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=2),
                              history: List[History] = Body([],
                                                            description="历史对话",
                                                            examples=[[

@@ -39,17 +39,30 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    history = [History.from_data(h) for h in history]
    #weiwei
    print(f"server/chat/knowledge_base_chat function, history:{history}")

    async def knowledge_base_chat_iterator(query: str,
                                           top_k: int,
                                           history: Optional[List[History]],
                                           model_name: str = LLM_MODEL,
                                           prompt_name: str = prompt_name,
                                           ) -> AsyncIterable[str]:
        #weiwei
        print(f"knowledge_base_chat_iterator,query:{query},model_name:{model_name},prompt_name:{prompt_name}")

        model1 = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[],
        )

        # augment_prompt_template = get_prompt_template("data_augment", "default")
        # input_msg1 = History(role="user", content=augment_prompt_template).to_msg_template(False)
        # chat_prompt1 = ChatPromptTemplate.from_messages(
        #     [i.to_msg_template() for i in history] + [input_msg1])
        # chain1 = LLMChain(prompt=chat_prompt1, llm=model1)
        # print(f"knowledge_base_chat_iterator,prompt_template:{chat_prompt1}")
        # result = chain1._call({ "question": query})
        # print(f"chain1._call, result:{result}")

        callback = AsyncIteratorCallbackHandler()
        model = get_ChatOpenAI(

@@ -58,23 +71,36 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
            max_tokens=max_tokens,
            callbacks=[callback],
        )
        docs = search_docs(query, knowledge_base_name, top_k, score_threshold)

        #augment_prompt_template = get_prompt_template("data_augment", "default")
        #input_msg = History(role="user", content=augment_prompt_template).to_msg_template(False)
        #chat_prompt = ChatPromptTemplate.from_messages(
        #    [i.to_msg_template() for i in history] + [input_msg])
        #chain = LLMChain(prompt=chat_prompt, llm=model)
        #print(f"knowledge_base_chat_iterator,prompt_template:{chat_prompt}")
        ##chain = LLMChain(prompt=PromptTemplate.from_template(augment_prompt_template), llm=model)
        ##print(f"knowledge_base_chat_iterator,prompt_template:{augment_prompt_template}")
        #task = asyncio.create_task(wrap_done(
        #    chain.acall({ "question": query}),
        #    callback.done),
        #)
        #prompt_template = "请找出和{question}最相似的一句话"
        #llm_chain = LLMChain(prompt=PromptTemplate.from_template(prompt_template), llm=model)
        #result = llm_chain(query)
        #print(f"请找出和question 最相似的一句话:{result}")

        docs = search_docs(query, knowledge_base_name, top_k, score_threshold, model1)
        context = "\n".join([doc.page_content for doc in docs])
        #weiwei
        print(f"knowledge_base_chat_iterator,search docs context:{context}")

        #print(f"knowledge_base_chat_iterator,search docs context:{context}")

        prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
        input_msg = History(role="user", content=prompt_template).to_msg_template(False)

        #weiwei
        print(f"knowledge_base_chat_iterator,input_msg:{input_msg}")

        chat_prompt = ChatPromptTemplate.from_messages(
            [i.to_msg_template() for i in history] + [input_msg])

        #weiwei
        print(f"knowledge_base_chat_iterator,chat_prompt:{chat_prompt}")

        #print(f"knowledge_base_chat_iterator,chat_prompt:{chat_prompt}")
        chain = LLMChain(prompt=chat_prompt, llm=model)

        # Begin a task that runs in the background.

@@ -82,10 +108,8 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
            chain.acall({"context": context, "question": query}),
            callback.done),
        )

        #weiwei
        print(f"task call end")

        print(f"task call end")
        source_documents = []
        for inum, doc in enumerate(docs):
            filename = os.path.split(doc.metadata["source"])[-1]

@@ -93,6 +117,8 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
            url = f"/knowledge_base/download_doc?" + parameters
            text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
            source_documents.append(text)

        print(f"knowledge_base_chat_iterator, stream:{stream}")
        if stream:
            async for token in callback.aiter():
                # Use server-sent-events to stream the response
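
Both chat handlers stream tokens with the same pattern: an AsyncIteratorCallbackHandler is attached to the model, the chain call runs as a background task, and the response is drained through callback.aiter(). A minimal sketch of that pattern; wrap_done here is a simplified stand-in for the project's helper of the same name:

import asyncio
from typing import AsyncIterable, Awaitable

from langchain.callbacks import AsyncIteratorCallbackHandler

async def wrap_done(fn: Awaitable, event: asyncio.Event):
    # Await the chain call, then signal the consumer that no more tokens will come.
    try:
        await fn
    finally:
        event.set()

async def stream_answer(chain, query: str, callback: AsyncIteratorCallbackHandler) -> AsyncIterable[str]:
    # Run the chain in the background; tokens arrive through the callback as they are generated.
    task = asyncio.create_task(wrap_done(chain.acall({"question": query}), callback.done))
    async for token in callback.aiter():
        yield token  # e.g. forwarded to the client as server-sent events
    await task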

@@ -15,7 +15,13 @@ from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_file_repository import get_file_detail
from typing import List, Dict
from langchain.docstore.document import Document

from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from server.chat.utils import History
from server.utils import BaseResponse, get_prompt_template
from langchain.prompts.chat import ChatPromptTemplate
from langchain.callbacks import AsyncIteratorCallbackHandler

class DocumentWithScore(Document):
    score: float = None
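
DocumentWithScore extends Document with the retrieval score, mirroring the (Document, score) tuples the KB service returns; per the Body descriptions above, a lower score means a closer match for these distance-style scores. A small illustration with invented values:

from langchain.docstore.document import Document

doc = Document(page_content="安全准入实施要求……", metadata={"source": "samples/rules.txt"})
hit = DocumentWithScore(**doc.dict(), score=0.31)  # score is the vector-store distance
print(hit.page_content, hit.score)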

@@ -25,32 +31,82 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
                knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
                model: ChatOpenAI = Body(..., description="大语言模型"),
                ) -> List[DocumentWithScore]:
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return []
    # query = "根据国网安徽信通公司安全准入实施要求," + query

    #history = {}
    #augment_prompt_template = get_prompt_template("data_augment", "default")
    #input_msg1 = History(role="user", content=augment_prompt_template).to_msg_template(False)
    #chat_prompt1 = ChatPromptTemplate.from_messages(
    #    [i.to_msg_template() for i in history] + [input_msg1])
    #chain1 = LLMChain(prompt=chat_prompt1, llm=model)
    #print(f"knowledge_base_chat_iterator,prompt_template:{chat_prompt1}")
    #result = chain1._call({ "question": query})
    #query1 = result["text"]
    #
    #print(f"相似的问法:{query1}")
    #docs = kb.search_docs(query, top_k, score_threshold)
    #print(f"{query}的相似文档块有{docs}")
    #data = []
    #if query1 != query:
    #    docs1 = kb.search_docs(query1, top_k, score_threshold)
    #    print(f"{query1}的相似文档块有{docs1}")
    #    rerank_docs = rerank(docs1,docs,top_k)
    #    print(f"精排后的相似文档块有{rerank_docs}")
    #    data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in rerank_docs]
    #else:
    #    data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]

    #print(f"chain1._call, result:{result},similiarit text:{query1}")

    pre_doc = kb.search_docs(query, 1, None)
    print(f"len(pre_doc):{len(pre_doc)}")
    if len(pre_doc) > 0:
        print(f"search_docs, len(pre_doc):{len(pre_doc)}")
        print(f"search_docs, pre_doc:{pre_doc}")
        filepath = pre_doc[0][0].metadata['source']
        file_name = os.path.basename(filepath)
        file_name, file_extension = os.path.splitext(file_name)
        query = "根据" + file_name + "," + query

        print(f"search_docs, query:{query}")
    print(f"search_docs, query:{query}")
    docs = kb.search_docs(query, top_k, score_threshold)
    print(f"search_docs, docs:{docs}")
    if len(pre_doc) > 0:
        if docs is not None:
            docs.append(pre_doc[0])
        else:
            docs = pre_doc

    data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]

    return data
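
The rewritten search_docs is a query-rewriting heuristic: fetch the single best hit with no threshold, prefix the query with that hit's file name as "根据<文件名>,", search again with the rewritten query, and keep the original top hit in the candidate set. A condensed sketch of the same flow, where kb stands for any KB service exposing the project's search_docs(query, top_k, score_threshold):

import os
from typing import List, Tuple

from langchain.docstore.document import Document

def title_augmented_search(kb, query: str, top_k: int, score_threshold: float) -> List[Tuple[Document, float]]:
    pre_doc = kb.search_docs(query, 1, None)  # best single hit, unfiltered
    if pre_doc:
        file_name, _ = os.path.splitext(os.path.basename(pre_doc[0][0].metadata["source"]))
        query = "根据" + file_name + "," + query  # prefix the query with the source title
    docs = kb.search_docs(query, top_k, score_threshold) or []
    if pre_doc:
        docs.append(pre_doc[0])  # keep the original top hit in the candidate set
    return docs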

def rerank(query_docs: List[Document] = Body(..., description="源query查询相似文档", examples=[]),
           augment_docs: List[Document] = Body(..., description="增强query查询相似文档", examples=[]),
           top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
           ) -> List[Document]:
    all_documents = query_docs + augment_docs
    # unique_documents_dict = {doc[0].page_content: doc for doc in all_documents}

    unique_documents_dict = {}
    for doc in all_documents:
        if doc[0].page_content not in unique_documents_dict or doc[1] < unique_documents_dict[doc[0].page_content][1]:
            unique_documents_dict[doc[0].page_content] = doc

    # deduplicated list of documents
    unique_documents = list(unique_documents_dict.values())

    sorted_documents = sorted(unique_documents, key=lambda doc: doc[1], reverse=False)
    min_documents = sorted_documents[:top_k]
    # print the results
    for doc in min_documents:
        print(f"{doc[0].page_content},{doc[1]}")

    return min_documents


def list_files(
    knowledge_base_name: str
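
rerank merges the hit lists for the original and the augmented query, de-duplicates on page_content keeping the lower (better) distance, and returns the top_k entries in ascending score order. A usage example with invented documents and scores:

from langchain.docstore.document import Document

a = [(Document(page_content="第1章 总则"), 0.42), (Document(page_content="第2章 范围"), 0.55)]
b = [(Document(page_content="第1章 总则"), 0.38), (Document(page_content="附录A"), 0.61)]

for doc, score in rerank(query_docs=a, augment_docs=b, top_k=2):
    print(score, doc.page_content)
# keeps the 0.38 copy of "第1章 总则" and drops the 0.42 duplicate,
# then returns the two lowest-distance entries: 0.38 and 0.55.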

@@ -158,6 +214,8 @@ def upload_docs(files: List[UploadFile] = File(..., description="上传文件,
    failed_files = {}
    file_names = list(docs.keys())

    print(f"upload_docs, file_names:{file_names}")

    # first save the uploaded files to disk
    for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
        filename = result["data"]["file_name"]

@@ -167,7 +225,9 @@
        if filename not in file_names:
            file_names.append(filename)


    # vectorize the saved files
    print(f"upload_docs, to_vector_store:{to_vector_store}")
    if to_vector_store:
        result = update_docs(
            knowledge_base_name=knowledge_base_name,

@@ -334,7 +334,7 @@ class KnowledgeFile:
        with open(outputfile, 'w') as file:
            for doc in docs:
                print(f"**********切分段{i}:{doc}")
                file.write(f"分段{i}")
                file.write(f"\n**********切分段{i}")
                file.write(doc.page_content)
                i = i+1
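
The debug dump above threads a manual counter i through the loop; an equivalent, slightly tidier form using enumerate (same outputfile and docs as in the method) would be:

with open(outputfile, 'w') as file:
    for i, doc in enumerate(docs, start=1):
        file.write(f"\n**********切分段{i}")
        file.write(doc.page_content)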

@@ -54,12 +54,12 @@ class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
        # Get appropriate separator to use
        separator = separators[-1]
        new_separators = [SPLIT_SEPARATOE]
        #text = re.sub(r'(\n+[a-zA-Z1-9]+\s*(\.\s*[a-zA-Z1-9]+\s*)+)', r"\n\n\n\n\n\n\n\n\n\n\1", text)  # chunk on chapter/section numbering
        text = re.sub(r'(\n+[a-zA-Z1-9]+\s*\.\s*[a-zA-Z1-9]+\s*(?!\.|[a-zA-Z1-9]))', r"\n\n\n\n\n\n\n\n\n\n\1", text)  # chunk on section numbers such as 1.2
        text = re.sub(r'(\n+表\s*[A-Za-z0-9]+(\.[A-Za-z0-9]+)+)', r"\n\n\n\n\n\n\n\n\n\n\1", text)  # chunk on table labels such as 表 A.4.a
        text = re.sub(r'(\n+表\s*[A-Za-z0-9]+(\.[A-Za-z0-9]+)+\s+)', r"\n\n\n\n\n\n\n\n\n\n\1", text)  # chunk on table labels such as 表 A.4.a
        text = re.sub(r'(\n+第\s*\S+\s*条\s+)', r"\n\n\n\n\n\n\n\n\n\n\1", text)  # chunk on article headings (第…条)
        text = re.sub(r'(\n+第\s*\S+\s*章\s+)', r"\n\n\n\n\n\n\n\n\n\n\1", text)  # chunk on chapter headings (第…章)
        text = re.sub(r'(\n+(一、|二、|三、|四、|五、|六、|七、|八、|九、|十、|十一、|十二、|十三、|十四、|十五、|十六、|十七、|十八、|十九、|二十、))', r"\n\n\n\n\n\n\n\n\n\n\1", text)  # chunk on Chinese ordinal list markers (一、 二、 …)
        text = re.sub(r'(\s+[a-zA-Z1-9]+\s*\.\s*[a-zA-Z1-9]+\s+)', r"\n\n\n\n\n\n\n\n\n\n\1", text)  # then chunk on numbers like 1.2 delimited by whitespace
        text = text.rstrip()  # drop any trailing newlines at the end of the segment
        for i, _s in enumerate(separators):
            _separator = _s if self._is_separator_regex else re.escape(_s)

@@ -88,7 +88,7 @@ class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
                if not new_separators:
                    final_chunks.append(s)
                else:
                    text = re.sub(r'(\n+[a-zA-Z1-9]+\s*\.\s*[a-zA-Z1-9]+\s*(\.\s*[a-zA-Z1-9]+\s*)+)', r"\n\n\n\n\n\n\n\n\n\n\1", s)  # then chunk on numbers like 1.2.3
                    text = re.sub(r'(\s+[a-zA-Z1-9]+\s*\.\s*[a-zA-Z1-9]+\s*\.\s*[a-zA-Z1-9]+\s+)', r"\n\n\n\n\n\n\n\n\n\n\1", s)  # then chunk on numbers like 1.2.3 delimited by whitespace
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
                if _good_splits:
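
Each re.sub above plants a run of ten newlines in front of a structural heading, and that run then acts as the splitter's highest-priority separator. A small demonstration with the section-number pattern and an invented sample text, assuming SPLIT_SEPARATOE is that ten-newline marker:

import re

SPLIT_SEPARATOE = "\n\n\n\n\n\n\n\n\n\n"  # assumed value; spelling follows the constant in the diff

text = "总则\n1.1 目的\n为了规范管理……\n1.2 范围\n本办法适用于……"
marked = re.sub(r'(\n+[a-zA-Z1-9]+\s*\.\s*[a-zA-Z1-9]+\s*(?!\.|[a-zA-Z1-9]))',
                SPLIT_SEPARATOE + r"\1", text)
chunks = [c for c in marked.split(SPLIT_SEPARATOE) if c]
print(chunks)  # ['总则', '\n1.1 目的\n为了规范管理……', '\n1.2 范围\n本办法适用于……']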
Binary file not shown.

@@ -173,7 +173,7 @@ def dialogue_page(api: ApiRequest):
        kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)

        ## BGE models can score above 1
        score_threshold = st.slider("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
        score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)

    elif dialogue_mode == "搜索引擎问答":
        search_engine_list = api.list_search_engines()
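
The slider change matches the relaxed le=2 bound on the API side: BGE-style scores can exceed 1 (per the comment above), so the UI maximum moves from 1.0 to 2.0. A minimal Streamlit sketch, with SCORE_THRESHOLD assumed to default to 0.5:

import streamlit as st

SCORE_THRESHOLD = 0.5  # assumed default; the project imports this from its configs

# Same widget as the diff: min 0.0, max 2.0, step 0.01.
score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
st.write(f"当前阈值:{score_threshold}")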