enhance
This commit is contained in:
wvivi2023 2023-11-23 12:38:31 +08:00
parent a59767711b
commit 9ba2120129
9 changed files with 142 additions and 33 deletions

BIN
.DS_Store vendored

Binary file not shown.

BIN
server/.DS_Store vendored

Binary file not shown.

View File

@ -10,7 +10,7 @@ from langchain.prompts.chat import ChatPromptTemplate
from typing import List
from server.chat.utils import History
from server.utils import get_prompt_template
from langchain.prompts import PromptTemplate
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
history: List[History] = Body([],
@ -41,11 +41,34 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
callbacks=[callback],
)
prompt_template = get_prompt_template("llm_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
chain = LLMChain(prompt=chat_prompt, llm=model)
# augment_prompt_template = get_prompt_template("data_augment", "default")
# input_msg = History(role="user", content=augment_prompt_template).to_msg_template(False)
# chat_prompt = ChatPromptTemplate.from_messages(
# [i.to_msg_template() for i in history] + [input_msg])
# chain = LLMChain(prompt=chat_prompt, llm=model)
# print(f"knowledge_base_chat_iterator,prompt_template:{chat_prompt}")
# # Begin a task that runs in the background.
# task = asyncio.create_task(wrap_done(
# chain.acall({ "question": query}),
# callback.done),
# )
prompt = ChatPromptTemplate.from_template(" 你是一个非常聪明的语义转换专家总能找到同一个语义不同的表达方式,请简洁生成一个与三单引号里的原句子语气语调完全一致,并且语义最相似的新句子,注意不是回答三单引号里的原句子,同时新句子直接使用简体中文给出,而不用重三单引号里的原句子。如果无法给出满足条件的新句子,直接给出三单引号里的原句子,而不是给出三单引号里的原句子的答案 '''{{input}}''' ")
# prompt_template = get_prompt_template("llm_chat", prompt_name)
# input_msg = History(role="user", content=prompt_template).to_msg_template(False)
# chat_prompt = ChatPromptTemplate.from_messages(
# [i.to_msg_template() for i in history] + [input_msg])
chain = LLMChain(prompt=prompt, llm=model)
name = "John"
age = 30
text = "你是一个非常聪明的语义转换专家总能找到同一个语义不同的表达方式,请简洁生成一个与三单引号里的原句子语气语调完全一致,并且语义最相似的新句子,注意不是回答三单引号里的原句子,同时新句子直接使用简体中文给出,而不用重三单引号里的原句子。如果无法给出满足条件的新句子,直接给出三单引号里的原句子,而不是给出三单引号里的原句子的答案\n ''' "
text = text + query
text = text + " ''' "
result = ", ".join(["Name: ", name, ", Age: ", str(age)])
# Begin a task that runs in the background.
task = asyncio.create_task(wrap_done(

View File

@ -14,12 +14,12 @@ import json
import os
from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs
from langchain.prompts import PromptTemplate
async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右", ge=0, le=1),
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右", ge=0, le=2),
history: List[History] = Body([],
description="历史对话",
examples=[[
@ -39,17 +39,30 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
history = [History.from_data(h) for h in history]
#weiwei
print(f"server/chat/knowledge_base_chat function, history:{history}")
async def knowledge_base_chat_iterator(query: str,
top_k: int,
history: Optional[List[History]],
model_name: str = LLM_MODEL,
prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
#weiwei
print(f"knowledge_base_chat_iterator,query:{query},model_name:{model_name},prompt_name:{prompt_name}")
model1 = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
callbacks=[],
)
# augment_prompt_template = get_prompt_template("data_augment", "default")
# input_msg1 = History(role="user", content=augment_prompt_template).to_msg_template(False)
# chat_prompt1 = ChatPromptTemplate.from_messages(
# [i.to_msg_template() for i in history] + [input_msg1])
# chain1 = LLMChain(prompt=chat_prompt1, llm=model1)
# print(f"knowledge_base_chat_iterator,prompt_template:{chat_prompt1}")
# result = chain1._call({ "question": query})
# print(f"chain1._call, result:{result}")
callback = AsyncIteratorCallbackHandler()
model = get_ChatOpenAI(
@ -58,23 +71,36 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
max_tokens=max_tokens,
callbacks=[callback],
)
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
#augment_prompt_template = get_prompt_template("data_augment", "default")
#input_msg = History(role="user", content=augment_prompt_template).to_msg_template(False)
#chat_prompt = ChatPromptTemplate.from_messages(
# [i.to_msg_template() for i in history] + [input_msg])
#chain = LLMChain(prompt=chat_prompt, llm=model)
#print(f"knowledge_base_chat_iterator,prompt_template:{chat_prompt}")
##chain = LLMChain(prompt=PromptTemplate.from_template(augment_prompt_template), llm=model)
##print(f"knowledge_base_chat_iterator,prompt_template:{augment_prompt_template}")
#task = asyncio.create_task(wrap_done(
# chain.acall({ "question": query}),
# callback.done),
#)
#prompt_template = "请找出和{question}最相似的一句话"
#llm_chain = LLMChain(prompt=PromptTemplate.from_template(prompt_template), llm=model)
#result = llm_chain(query)
#print(f"请找出和question 最相似的一句话:{result}")
docs = search_docs(query, knowledge_base_name, top_k, score_threshold, model1)
context = "\n".join([doc.page_content for doc in docs])
#weiwei
print(f"knowledge_base_chat_iterator,search docs context:{context}")
#print(f"knowledge_base_chat_iterator,search docs context:{context}")
prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
#weiwei
print(f"knowledge_base_chat_iterator,input_msg:{input_msg}")
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
#weiwei
print(f"knowledge_base_chat_iterator,chat_prompt:{chat_prompt}")
#print(f"knowledge_base_chat_iterator,chat_prompt:{chat_prompt}")
chain = LLMChain(prompt=chat_prompt, llm=model)
# Begin a task that runs in the background.
@ -82,10 +108,8 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
chain.acall({"context": context, "question": query}),
callback.done),
)
#weiwei
print(f"task call end")
print(f"task call end")
source_documents = []
for inum, doc in enumerate(docs):
filename = os.path.split(doc.metadata["source"])[-1]
@ -93,6 +117,8 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
url = f"/knowledge_base/download_doc?" + parameters
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
source_documents.append(text)
print(f"knowledge_base_chat_iterator, stream:{stream}")
if stream:
async for token in callback.aiter():
# Use server-sent-events to stream the response

View File

@ -15,7 +15,13 @@ from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_file_repository import get_file_detail
from typing import List, Dict
from langchain.docstore.document import Document
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from server.chat.utils import History
from server.utils import BaseResponse, get_prompt_template
from langchain.prompts.chat import ChatPromptTemplate
from langchain.callbacks import AsyncIteratorCallbackHandler
class DocumentWithScore(Document):
    """A ``Document`` extended with an extra ``score`` field."""

    # Defaults to None when no score is passed to the constructor.
    score: float = None
@ -25,32 +31,82 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右", ge=0, le=1),
model:ChatOpenAI = Body(...,description="大语言模型"),
) -> List[DocumentWithScore]:
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return []
# query = "根据国网安徽信通公司安全准入实施要求," + query
#history = {}
#augment_prompt_template = get_prompt_template("data_augment", "default")
#input_msg1 = History(role="user", content=augment_prompt_template).to_msg_template(False)
#chat_prompt1 = ChatPromptTemplate.from_messages(
# [i.to_msg_template() for i in history] + [input_msg1])
#chain1 = LLMChain(prompt=chat_prompt1, llm=model)
#print(f"knowledge_base_chat_iterator,prompt_template:{chat_prompt1}")
#result = chain1._call({ "question": query})
#query1 = result["text"]
#
#print(f"相似的问法:{query1}")
#docs = kb.search_docs(query, top_k, score_threshold)
#print(f"{query}的相似文档块有{docs}")
#data = []
#if query1 != query:
# docs1 = kb.search_docs(query1, top_k, score_threshold)
# print(f"{query1}的相似文档块有{docs1}")
# rerank_docs = rerank(docs1,docs,top_k)
# print(f"精排后的相似文档块有{rerank_docs}")
# data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in rerank_docs]
#else:
# data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
#print(f"chain1._call, result:{result},similiarit text:{query1}")
pre_doc = kb.search_docs(query, 1, None)
print(f"len(pre_doc):{len(pre_doc)}")
if len(pre_doc) > 0:
print(f"search_docs, len(pre_doc):{len(pre_doc)}")
print(f"search_docs, pre_doc:{pre_doc}")
filpath = pre_doc[0][0].metadata['source']
file_name = os.path.basename(filpath)
file_name, file_extension = os.path.splitext(file_name)
query = "根据" +file_name + ""+ query
print(f"search_docs, query:{query}")
print(f"search_docs, query:{query}")
docs = kb.search_docs(query, top_k, score_threshold)
print(f"search_docs, docs:{docs}")
if len(pre_doc) > 0:
if docs is not None:
docs.append(pre_doc[0])
else:
docs = pre_doc[0]
data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
return data
def rerank(query_docs:List[Document] = Body(...,description="源query查询相似文档", examples=[]),
           augment_docs:List[Document] = Body(...,description="增强query查询相似文档", examples=[]),
           top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
           )-> List[Document]:
    """Merge two search results, de-duplicate them and keep the best ``top_k``.

    Each element of ``query_docs`` / ``augment_docs`` is indexed as a pair:
    ``item[0]`` is an object exposing ``page_content`` and ``item[1]`` is its
    score.  Entries are sorted ascending by score, i.e. a smaller score ranks
    first.

    Args:
        query_docs: Scored results retrieved for the original query.
        augment_docs: Scored results retrieved for the augmented query.
        top_k: Maximum number of entries to return.
            NOTE(review): the declared defaults are ``Body(...)`` markers, so
            direct (non-FastAPI) callers presumably must pass every argument
            explicitly -- confirm against callers.

    Returns:
        At most ``top_k`` ``(document, score)`` pairs, unique by
        ``page_content``, ordered from lowest to highest score.
    """
    # Treat a missing result list as empty instead of failing on ``None + list``.
    all_documents = list(query_docs or []) + list(augment_docs or [])

    # De-duplicate on the text content, keeping the lowest-scored occurrence.
    # The strict ``<`` means the first occurrence wins on a tie.
    best_by_content = {}
    for doc in all_documents:
        content = doc[0].page_content
        current = best_by_content.get(content)
        if current is None or doc[1] < current[1]:
            best_by_content[content] = doc

    # ``sorted`` is stable, so equal scores keep their first-seen order.
    sorted_documents = sorted(best_by_content.values(), key=lambda doc: doc[1])
    min_documents = sorted_documents[:top_k]

    # Debug output of the final selection; the score is now interpolated
    # (previously the literal text "doc[1]" was printed).
    for doc in min_documents:
        print(f"{doc[0].page_content},{doc[1]}")
    return min_documents
def list_files(
knowledge_base_name: str
@ -158,6 +214,8 @@ def upload_docs(files: List[UploadFile] = File(..., description="上传文件,
failed_files = {}
file_names = list(docs.keys())
print(f"upload_docs, file_names:{file_names}")
# 先将上传的文件保存到磁盘
for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
filename = result["data"]["file_name"]
@ -167,7 +225,9 @@ def upload_docs(files: List[UploadFile] = File(..., description="上传文件,
if filename not in file_names:
file_names.append(filename)
# 对保存的文件进行向量化
print(f"upload_docs, to_vector_store:{to_vector_store}")
if to_vector_store:
result = update_docs(
knowledge_base_name=knowledge_base_name,

View File

@ -334,7 +334,7 @@ class KnowledgeFile:
with open(outputfile, 'w') as file:
for doc in docs:
print(f"**********切分段{i}{doc}")
file.write(f"分段{i}")
file.write(f"\n**********切分段{i}")
file.write(doc.page_content)
i = i+1

View File

@ -54,12 +54,12 @@ class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
# Get appropriate separator to use
separator = separators[-1]
new_separators = [SPLIT_SEPARATOE]
#text = re.sub(r'(\n+[a-zA-Z1-9]+\s*(\.\s*[a-zA-Z1-9]+\s*)+)', r"\n\n\n\n\n\n\n\n\n\n\1", text) # 通过章和节来分块
text = re.sub(r'(\n+[a-zA-Z1-9]+\s*\.\s*[a-zA-Z1-9]+\s*(?!\.|[a-zA-Z1-9]))', r"\n\n\n\n\n\n\n\n\n\n\1", text) # 通过1.2这样的章和节来分块
text = re.sub(r'(\n+表\s*[A-Za-z0-9]+(\.[A-Za-z0-9]+)+)', r"\n\n\n\n\n\n\n\n\n\n\1", text) # 通过表 A.4.a
text = re.sub(r'(\n+表\s*[A-Za-z0-9]+(\.[A-Za-z0-9]+)+\s+)', r"\n\n\n\n\n\n\n\n\n\n\1", text) # 通过表 A.4.a
text = re.sub(r'(\n+第\s*\S+\s*条\s+)', r"\n\n\n\n\n\n\n\n\n\n\1", text) # 通过第 条
text = re.sub(r'(\n+第\s*\S+\s*章\s+)', r"\n\n\n\n\n\n\n\n\n\n\1", text) # 通过第 条
text = re.sub(r'(\n+(一、|二、|三、|四、|五、|六、|七、|八、|九、|十、|十一、|十二、|十三、|十四、|十五、|十六、|十七、|十八、|十九、|二十、))', r"\n\n\n\n\n\n\n\n\n\n\1", text) # 通过第 条
text = re.sub(r'(\s+[a-zA-Z1-9]+\s*\.\s*[a-zA-Z1-9]+\s+)', r"\n\n\n\n\n\n\n\n\n\n\1", text) # 再通过 1.2 来分块
text = text.rstrip() # 段尾如果有多余的\n就去掉它
for i, _s in enumerate(separators):
_separator = _s if self._is_separator_regex else re.escape(_s)
@ -88,7 +88,7 @@ class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
if not new_separators:
final_chunks.append(s)
else:
text = re.sub(r'(\n+[a-zA-Z1-9]+\s*\.\s*[a-zA-Z1-9]+\s*(\.\s*[a-zA-Z1-9]+\s*)+)', r"\n\n\n\n\n\n\n\n\n\n\1", s) # 再通过1.2.3来分块
text = re.sub(r'(\s+[a-zA-Z1-9]+\s*\.\s*[a-zA-Z1-9]+\s*\.\s*[a-zA-Z1-9]+\s+)', r"\n\n\n\n\n\n\n\n\n\n\1", s) # 再通过 1.2.3 来分块
other_info = self._split_text(s, new_separators)
final_chunks.extend(other_info)
if _good_splits:

BIN
webui_pages/.DS_Store vendored Normal file

Binary file not shown.

View File

@ -173,7 +173,7 @@ def dialogue_page(api: ApiRequest):
kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
## Bge 模型会超过1
score_threshold = st.slider("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
elif dialogue_mode == "搜索引擎问答":
search_engine_list = api.list_search_engines()