merge master
This commit is contained in:
commit
fbdc62d910
|
|
@ -178,6 +178,6 @@ Web UI 可以实现如下功能:
|
||||||
- [ ] 实现调用 API 的 Web UI Demo
|
- [ ] 实现调用 API 的 Web UI Demo
|
||||||
|
|
||||||
## 项目交流群
|
## 项目交流群
|
||||||

|

|
||||||
|
|
||||||
🎉 langchain-ChatGLM 项目交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
|
🎉 langchain-ChatGLM 项目交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
|
||||||
|
|
|
||||||
|
|
@ -12,43 +12,33 @@ from utils import torch_gc
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from pypinyin import lazy_pinyin
|
from pypinyin import lazy_pinyin
|
||||||
|
|
||||||
|
|
||||||
DEVICE_ = EMBEDDING_DEVICE
|
DEVICE_ = EMBEDDING_DEVICE
|
||||||
DEVICE_ID = "0" if torch.cuda.is_available() else None
|
DEVICE_ID = "0" if torch.cuda.is_available() else None
|
||||||
DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_
|
DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_
|
||||||
|
|
||||||
|
|
||||||
def load_file(filepath):
|
def load_file(filepath, sentence_size=SENTENCE_SIZE):
|
||||||
if filepath.lower().endswith(".md"):
|
if filepath.lower().endswith(".md"):
|
||||||
loader = UnstructuredFileLoader(filepath, mode="elements")
|
loader = UnstructuredFileLoader(filepath, mode="elements")
|
||||||
docs = loader.load()
|
docs = loader.load()
|
||||||
elif filepath.lower().endswith(".pdf"):
|
elif filepath.lower().endswith(".pdf"):
|
||||||
loader = UnstructuredFileLoader(filepath)
|
loader = UnstructuredFileLoader(filepath)
|
||||||
textsplitter = ChineseTextSplitter(pdf=True)
|
textsplitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size)
|
||||||
docs = loader.load_and_split(textsplitter)
|
docs = loader.load_and_split(textsplitter)
|
||||||
else:
|
else:
|
||||||
loader = UnstructuredFileLoader(filepath, mode="elements")
|
loader = UnstructuredFileLoader(filepath, mode="elements")
|
||||||
textsplitter = ChineseTextSplitter(pdf=False)
|
textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
|
||||||
docs = loader.load_and_split(text_splitter=textsplitter)
|
docs = loader.load_and_split(text_splitter=textsplitter)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
|
||||||
def generate_prompt(related_docs: List[str],
|
def generate_prompt(related_docs: List[str], query: str,
|
||||||
query: str,
|
|
||||||
prompt_template=PROMPT_TEMPLATE) -> str:
|
prompt_template=PROMPT_TEMPLATE) -> str:
|
||||||
context = "\n".join([doc.page_content for doc in related_docs])
|
context = "\n".join([doc.page_content for doc in related_docs])
|
||||||
prompt = prompt_template.replace("{question}", query).replace("{context}", context)
|
prompt = prompt_template.replace("{question}", query).replace("{context}", context)
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
def get_docs_with_score(docs_with_score):
|
|
||||||
docs = []
|
|
||||||
for doc, score in docs_with_score:
|
|
||||||
doc.metadata["score"] = score
|
|
||||||
docs.append(doc)
|
|
||||||
return docs
|
|
||||||
|
|
||||||
|
|
||||||
def seperate_list(ls: List[int]) -> List[List[int]]:
|
def seperate_list(ls: List[int]) -> List[List[int]]:
|
||||||
lists = []
|
lists = []
|
||||||
ls1 = [ls[0]]
|
ls1 = [ls[0]]
|
||||||
|
|
@ -63,18 +53,24 @@ def seperate_list(ls: List[int]) -> List[List[int]]:
|
||||||
|
|
||||||
|
|
||||||
def similarity_search_with_score_by_vector(
|
def similarity_search_with_score_by_vector(
|
||||||
self, embedding: List[float], k: int = 4,
|
self, embedding: List[float], k: int = 4
|
||||||
) -> List[Tuple[Document, float]]:
|
) -> List[Tuple[Document, float]]:
|
||||||
scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
|
scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
|
||||||
docs = []
|
docs = []
|
||||||
id_set = set()
|
id_set = set()
|
||||||
store_len = len(self.index_to_docstore_id)
|
store_len = len(self.index_to_docstore_id)
|
||||||
for j, i in enumerate(indices[0]):
|
for j, i in enumerate(indices[0]):
|
||||||
if i == -1:
|
if i == -1 or 0 < self.score_threshold < scores[0][j]:
|
||||||
# This happens when not enough docs are returned.
|
# This happens when not enough docs are returned.
|
||||||
continue
|
continue
|
||||||
_id = self.index_to_docstore_id[i]
|
_id = self.index_to_docstore_id[i]
|
||||||
doc = self.docstore.search(_id)
|
doc = self.docstore.search(_id)
|
||||||
|
if not self.chunk_conent:
|
||||||
|
if not isinstance(doc, Document):
|
||||||
|
raise ValueError(f"Could not find document for id {_id}, got {doc}")
|
||||||
|
doc.metadata["score"] = int(scores[0][j])
|
||||||
|
docs.append(doc)
|
||||||
|
continue
|
||||||
id_set.add(i)
|
id_set.add(i)
|
||||||
docs_len = len(doc.page_content)
|
docs_len = len(doc.page_content)
|
||||||
for k in range(1, max(i, store_len - i)):
|
for k in range(1, max(i, store_len - i)):
|
||||||
|
|
@ -91,6 +87,10 @@ def similarity_search_with_score_by_vector(
|
||||||
id_set.add(l)
|
id_set.add(l)
|
||||||
if break_flag:
|
if break_flag:
|
||||||
break
|
break
|
||||||
|
if not self.chunk_conent:
|
||||||
|
return docs
|
||||||
|
if len(id_set) == 0 and self.score_threshold > 0:
|
||||||
|
return []
|
||||||
id_list = sorted(list(id_set))
|
id_list = sorted(list(id_set))
|
||||||
id_lists = seperate_list(id_list)
|
id_lists = seperate_list(id_list)
|
||||||
for id_seq in id_lists:
|
for id_seq in id_lists:
|
||||||
|
|
@ -105,7 +105,8 @@ def similarity_search_with_score_by_vector(
|
||||||
if not isinstance(doc, Document):
|
if not isinstance(doc, Document):
|
||||||
raise ValueError(f"Could not find document for id {_id}, got {doc}")
|
raise ValueError(f"Could not find document for id {_id}, got {doc}")
|
||||||
doc_score = min([scores[0][id] for id in [indices[0].tolist().index(i) for i in id_seq if i in indices[0]]])
|
doc_score = min([scores[0][id] for id in [indices[0].tolist().index(i) for i in id_seq if i in indices[0]]])
|
||||||
docs.append((doc, doc_score))
|
doc.metadata["score"] = int(doc_score)
|
||||||
|
docs.append(doc)
|
||||||
torch_gc()
|
torch_gc()
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
|
@ -115,6 +116,8 @@ class LocalDocQA:
|
||||||
embeddings: object = None
|
embeddings: object = None
|
||||||
top_k: int = VECTOR_SEARCH_TOP_K
|
top_k: int = VECTOR_SEARCH_TOP_K
|
||||||
chunk_size: int = CHUNK_SIZE
|
chunk_size: int = CHUNK_SIZE
|
||||||
|
chunk_conent: bool = True
|
||||||
|
score_threshold: int = VECTOR_SEARCH_SCORE_THRESHOLD
|
||||||
|
|
||||||
def init_cfg(self,
|
def init_cfg(self,
|
||||||
embedding_model: str = EMBEDDING_MODEL,
|
embedding_model: str = EMBEDDING_MODEL,
|
||||||
|
|
@ -137,7 +140,8 @@ class LocalDocQA:
|
||||||
|
|
||||||
def init_knowledge_vector_store(self,
|
def init_knowledge_vector_store(self,
|
||||||
filepath: str or List[str],
|
filepath: str or List[str],
|
||||||
vs_path: str or os.PathLike = None):
|
vs_path: str or os.PathLike = None,
|
||||||
|
sentence_size=SENTENCE_SIZE):
|
||||||
loaded_files = []
|
loaded_files = []
|
||||||
failed_files = []
|
failed_files = []
|
||||||
if isinstance(filepath, str):
|
if isinstance(filepath, str):
|
||||||
|
|
@ -147,40 +151,41 @@ class LocalDocQA:
|
||||||
elif os.path.isfile(filepath):
|
elif os.path.isfile(filepath):
|
||||||
file = os.path.split(filepath)[-1]
|
file = os.path.split(filepath)[-1]
|
||||||
try:
|
try:
|
||||||
docs = load_file(filepath)
|
docs = load_file(filepath, sentence_size)
|
||||||
print(f"{file} 已成功加载")
|
logger.info(f"{file} 已成功加载")
|
||||||
loaded_files.append(filepath)
|
loaded_files.append(filepath)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
logger.error(e)
|
||||||
print(f"{file} 未能成功加载")
|
logger.info(f"{file} 未能成功加载")
|
||||||
return None
|
return None
|
||||||
elif os.path.isdir(filepath):
|
elif os.path.isdir(filepath):
|
||||||
docs = []
|
docs = []
|
||||||
for file in tqdm(os.listdir(filepath), desc="加载文件"):
|
for file in tqdm(os.listdir(filepath), desc="加载文件"):
|
||||||
fullfilepath = os.path.join(filepath, file)
|
fullfilepath = os.path.join(filepath, file)
|
||||||
try:
|
try:
|
||||||
docs += load_file(fullfilepath)
|
docs += load_file(fullfilepath, sentence_size)
|
||||||
loaded_files.append(fullfilepath)
|
loaded_files.append(fullfilepath)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
failed_files.append(file)
|
failed_files.append(file)
|
||||||
|
|
||||||
if len(failed_files) > 0:
|
if len(failed_files) > 0:
|
||||||
print("以下文件未能成功加载:")
|
logger.info("以下文件未能成功加载:")
|
||||||
for file in failed_files:
|
for file in failed_files:
|
||||||
print(file, end="\n")
|
logger.info(file, end="\n")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
docs = []
|
docs = []
|
||||||
for file in filepath:
|
for file in filepath:
|
||||||
try:
|
try:
|
||||||
docs += load_file(file)
|
docs += load_file(file)
|
||||||
print(f"{file} 已成功加载")
|
logger.info(f"{file} 已成功加载")
|
||||||
loaded_files.append(file)
|
loaded_files.append(file)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
logger.error(e)
|
||||||
print(f"{file} 未能成功加载")
|
logger.info(f"{file} 未能成功加载")
|
||||||
if len(docs) > 0:
|
if len(docs) > 0:
|
||||||
print("文件加载完毕,正在生成向量库")
|
logger.info("文件加载完毕,正在生成向量库")
|
||||||
if vs_path and os.path.isdir(vs_path):
|
if vs_path and os.path.isdir(vs_path):
|
||||||
vector_store = FAISS.load_local(vs_path, self.embeddings)
|
vector_store = FAISS.load_local(vs_path, self.embeddings)
|
||||||
vector_store.add_documents(docs)
|
vector_store.add_documents(docs)
|
||||||
|
|
@ -189,38 +194,46 @@ class LocalDocQA:
|
||||||
if not vs_path:
|
if not vs_path:
|
||||||
vs_path = os.path.join(VS_ROOT_PATH,
|
vs_path = os.path.join(VS_ROOT_PATH,
|
||||||
f"""{"".join(lazy_pinyin(os.path.splitext(file)[0]))}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""")
|
f"""{"".join(lazy_pinyin(os.path.splitext(file)[0]))}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""")
|
||||||
vector_store = FAISS.from_documents(docs, self.embeddings)
|
vector_store = FAISS.from_documents(docs, self.embeddings) # docs 为Document列表
|
||||||
torch_gc()
|
torch_gc()
|
||||||
|
|
||||||
vector_store.save_local(vs_path)
|
vector_store.save_local(vs_path)
|
||||||
return vs_path, loaded_files
|
return vs_path, loaded_files
|
||||||
else:
|
else:
|
||||||
print("文件均未成功加载,请检查依赖包或替换为其他文件再次上传。")
|
logger.info("文件均未成功加载,请检查依赖包或替换为其他文件再次上传。")
|
||||||
return None, loaded_files
|
return None, loaded_files
|
||||||
|
|
||||||
def get_knowledge_based_answer(self,
|
def one_knowledge_add(self, vs_path, one_title, one_conent, one_content_segmentation, sentence_size):
|
||||||
query,
|
try:
|
||||||
vs_path,
|
if not vs_path or not one_title or not one_conent:
|
||||||
chat_history=[],
|
logger.info("知识库添加错误,请确认知识库名字、标题、内容是否正确!")
|
||||||
streaming: bool = STREAMING):
|
return None, [one_title]
|
||||||
|
docs = [Document(page_content=one_conent+"\n", metadata={"source": one_title})]
|
||||||
|
if not one_content_segmentation:
|
||||||
|
text_splitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
|
||||||
|
docs = text_splitter.split_documents(docs)
|
||||||
|
if os.path.isdir(vs_path):
|
||||||
|
vector_store = FAISS.load_local(vs_path, self.embeddings)
|
||||||
|
vector_store.add_documents(docs)
|
||||||
|
else:
|
||||||
|
vector_store = FAISS.from_documents(docs, self.embeddings) ##docs 为Document列表
|
||||||
|
torch_gc()
|
||||||
|
vector_store.save_local(vs_path)
|
||||||
|
return vs_path, [one_title]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
return None, [one_title]
|
||||||
|
|
||||||
|
def get_knowledge_based_answer(self, query, vs_path, chat_history=[], streaming: bool = STREAMING):
|
||||||
vector_store = FAISS.load_local(vs_path, self.embeddings)
|
vector_store = FAISS.load_local(vs_path, self.embeddings)
|
||||||
FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
|
FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
|
||||||
vector_store.chunk_size = self.chunk_size
|
vector_store.chunk_size = self.chunk_size
|
||||||
related_docs_with_score = vector_store.similarity_search_with_score(query,
|
vector_store.chunk_conent = self.chunk_conent
|
||||||
k=self.top_k)
|
vector_store.score_threshold = self.score_threshold
|
||||||
related_docs = get_docs_with_score(related_docs_with_score)
|
related_docs_with_score = vector_store.similarity_search_with_score(query, k=self.top_k)
|
||||||
torch_gc()
|
torch_gc()
|
||||||
prompt = generate_prompt(related_docs, query)
|
prompt = generate_prompt(related_docs_with_score, query)
|
||||||
|
|
||||||
# if streaming:
|
|
||||||
# for result, history in self.llm._stream_call(prompt=prompt,
|
|
||||||
# history=chat_history):
|
|
||||||
# history[-1][0] = query
|
|
||||||
# response = {"query": query,
|
|
||||||
# "result": result,
|
|
||||||
# "source_documents": related_docs}
|
|
||||||
# yield response, history
|
|
||||||
# else:
|
|
||||||
for result, history in self.llm._call(prompt=prompt,
|
for result, history in self.llm._call(prompt=prompt,
|
||||||
history=chat_history,
|
history=chat_history,
|
||||||
streaming=streaming):
|
streaming=streaming):
|
||||||
|
|
@ -228,10 +241,35 @@ class LocalDocQA:
|
||||||
history[-1][0] = query
|
history[-1][0] = query
|
||||||
response = {"query": query,
|
response = {"query": query,
|
||||||
"result": result,
|
"result": result,
|
||||||
"source_documents": related_docs}
|
"source_documents": related_docs_with_score}
|
||||||
yield response, history
|
yield response, history
|
||||||
torch_gc()
|
torch_gc()
|
||||||
|
|
||||||
|
# query 查询内容
|
||||||
|
# vs_path 知识库路径
|
||||||
|
# chunk_conent 是否启用上下文关联
|
||||||
|
# score_threshold 搜索匹配score阈值
|
||||||
|
# vector_search_top_k 搜索知识库内容条数,默认搜索5条结果
|
||||||
|
# chunk_sizes 匹配单段内容的连接上下文长度
|
||||||
|
def get_knowledge_based_conent_test(self, query, vs_path, chunk_conent,
|
||||||
|
score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD,
|
||||||
|
vector_search_top_k=VECTOR_SEARCH_TOP_K, chunk_size=CHUNK_SIZE):
|
||||||
|
vector_store = FAISS.load_local(vs_path, self.embeddings)
|
||||||
|
FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
|
||||||
|
vector_store.chunk_conent = chunk_conent
|
||||||
|
vector_store.score_threshold = score_threshold
|
||||||
|
vector_store.chunk_size = chunk_size
|
||||||
|
related_docs_with_score = vector_store.similarity_search_with_score(query, k=vector_search_top_k)
|
||||||
|
if not related_docs_with_score:
|
||||||
|
response = {"query": query,
|
||||||
|
"source_documents": []}
|
||||||
|
return response, ""
|
||||||
|
torch_gc()
|
||||||
|
prompt = "\n".join([doc.page_content for doc in related_docs_with_score])
|
||||||
|
response = {"query": query,
|
||||||
|
"source_documents": related_docs_with_score}
|
||||||
|
return response, prompt
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
local_doc_qa = LocalDocQA()
|
local_doc_qa = LocalDocQA()
|
||||||
|
|
@ -243,11 +281,11 @@ if __name__ == "__main__":
|
||||||
vs_path=vs_path,
|
vs_path=vs_path,
|
||||||
chat_history=[],
|
chat_history=[],
|
||||||
streaming=True):
|
streaming=True):
|
||||||
print(resp["result"][last_print_len:], end="", flush=True)
|
logger.info(resp["result"][last_print_len:], end="", flush=True)
|
||||||
last_print_len = len(resp["result"])
|
last_print_len = len(resp["result"])
|
||||||
source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
|
source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
|
||||||
# f"""相关度:{doc.metadata['score']}\n\n"""
|
# f"""相关度:{doc.metadata['score']}\n\n"""
|
||||||
for inum, doc in
|
for inum, doc in
|
||||||
enumerate(resp["source_documents"])]
|
enumerate(resp["source_documents"])]
|
||||||
print("\n\n" + "\n\n".join(source_text))
|
logger.info("\n\n" + "\n\n".join(source_text))
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ if __name__ == "__main__":
|
||||||
chat_history=history,
|
chat_history=history,
|
||||||
streaming=STREAMING):
|
streaming=STREAMING):
|
||||||
if STREAMING:
|
if STREAMING:
|
||||||
logger.info(resp["result"][last_print_len:], end="", flush=True)
|
logger.info(resp["result"][last_print_len:])
|
||||||
last_print_len = len(resp["result"])
|
last_print_len = len(resp["result"])
|
||||||
else:
|
else:
|
||||||
logger.info(resp["result"])
|
logger.info(resp["result"])
|
||||||
|
|
|
||||||
|
|
@ -69,6 +69,9 @@ LLM_HISTORY_LEN = 3
|
||||||
# return top-k text chunk from vector store
|
# return top-k text chunk from vector store
|
||||||
VECTOR_SEARCH_TOP_K = 5
|
VECTOR_SEARCH_TOP_K = 5
|
||||||
|
|
||||||
|
# 如果为0,则不生效,经测试小于500值的结果更精准
|
||||||
|
VECTOR_SEARCH_SCORE_THRESHOLD = 0
|
||||||
|
|
||||||
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
|
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
|
||||||
|
|
||||||
FLAG_USER_NAME = uuid.uuid4().hex
|
FLAG_USER_NAME = uuid.uuid4().hex
|
||||||
|
|
|
||||||
Binary file not shown.
|
Before Width: | Height: | Size: 270 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 268 KiB |
|
|
@ -5,9 +5,10 @@ from configs.model_config import SENTENCE_SIZE
|
||||||
|
|
||||||
|
|
||||||
class ChineseTextSplitter(CharacterTextSplitter):
|
class ChineseTextSplitter(CharacterTextSplitter):
|
||||||
def __init__(self, pdf: bool = False, **kwargs):
|
def __init__(self, pdf: bool = False, sentence_size: int = None, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.pdf = pdf
|
self.pdf = pdf
|
||||||
|
self.sentence_size = sentence_size
|
||||||
|
|
||||||
def split_text1(self, text: str) -> List[str]:
|
def split_text1(self, text: str) -> List[str]:
|
||||||
if self.pdf:
|
if self.pdf:
|
||||||
|
|
@ -23,7 +24,7 @@ class ChineseTextSplitter(CharacterTextSplitter):
|
||||||
sent_list.append(ele)
|
sent_list.append(ele)
|
||||||
return sent_list
|
return sent_list
|
||||||
|
|
||||||
def split_text(self, text: str) -> List[str]:
|
def split_text(self, text: str) -> List[str]: ##此处需要进一步优化逻辑
|
||||||
if self.pdf:
|
if self.pdf:
|
||||||
text = re.sub(r"\n{3,}", r"\n", text)
|
text = re.sub(r"\n{3,}", r"\n", text)
|
||||||
text = re.sub('\s', " ", text)
|
text = re.sub('\s', " ", text)
|
||||||
|
|
@ -38,15 +39,15 @@ class ChineseTextSplitter(CharacterTextSplitter):
|
||||||
# 很多规则中会考虑分号;,但是这里我把它忽略不计,破折号、英文双引号等同样忽略,需要的再做些简单调整即可。
|
# 很多规则中会考虑分号;,但是这里我把它忽略不计,破折号、英文双引号等同样忽略,需要的再做些简单调整即可。
|
||||||
ls = [i for i in text.split("\n") if i]
|
ls = [i for i in text.split("\n") if i]
|
||||||
for ele in ls:
|
for ele in ls:
|
||||||
if len(ele) > SENTENCE_SIZE:
|
if len(ele) > self.sentence_size:
|
||||||
ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r'\1\n\2', ele)
|
ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r'\1\n\2', ele)
|
||||||
ele1_ls = ele1.split("\n")
|
ele1_ls = ele1.split("\n")
|
||||||
for ele_ele1 in ele1_ls:
|
for ele_ele1 in ele1_ls:
|
||||||
if len(ele_ele1) > SENTENCE_SIZE:
|
if len(ele_ele1) > self.sentence_size:
|
||||||
ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1)
|
ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1)
|
||||||
ele2_ls = ele_ele2.split("\n")
|
ele2_ls = ele_ele2.split("\n")
|
||||||
for ele_ele2 in ele2_ls:
|
for ele_ele2 in ele2_ls:
|
||||||
if len(ele_ele2) > SENTENCE_SIZE:
|
if len(ele_ele2) > self.sentence_size:
|
||||||
ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2)
|
ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2)
|
||||||
ele2_id = ele2_ls.index(ele_ele2)
|
ele2_id = ele2_ls.index(ele_ele2)
|
||||||
ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n") if i] + ele2_ls[
|
ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n") if i] + ele2_ls[
|
||||||
|
|
|
||||||
228
webui.py
228
webui.py
|
|
@ -7,6 +7,7 @@ import nltk
|
||||||
|
|
||||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||||
|
|
||||||
|
|
||||||
def get_vs_list():
|
def get_vs_list():
|
||||||
lst_default = ["新建知识库"]
|
lst_default = ["新建知识库"]
|
||||||
if not os.path.exists(VS_ROOT_PATH):
|
if not os.path.exists(VS_ROOT_PATH):
|
||||||
|
|
@ -28,14 +29,13 @@ local_doc_qa = LocalDocQA()
|
||||||
|
|
||||||
flag_csv_logger = gr.CSVLogger()
|
flag_csv_logger = gr.CSVLogger()
|
||||||
|
|
||||||
def get_answer(query, vs_path, history, mode,
|
|
||||||
streaming: bool = STREAMING):
|
def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD,
|
||||||
if mode == "知识库问答" and vs_path:
|
vector_search_top_k=VECTOR_SEARCH_TOP_K, chunk_conent: bool = True,
|
||||||
|
chunk_size=CHUNK_SIZE, streaming: bool = STREAMING):
|
||||||
|
if mode == "知识库问答" and os.path.exists(vs_path):
|
||||||
for resp, history in local_doc_qa.get_knowledge_based_answer(
|
for resp, history in local_doc_qa.get_knowledge_based_answer(
|
||||||
query=query,
|
query=query, vs_path=vs_path, chat_history=history, streaming=streaming):
|
||||||
vs_path=vs_path,
|
|
||||||
chat_history=history,
|
|
||||||
streaming=streaming):
|
|
||||||
source = "\n\n"
|
source = "\n\n"
|
||||||
source += "".join(
|
source += "".join(
|
||||||
[f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
|
[f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
|
||||||
|
|
@ -45,15 +45,34 @@ def get_answer(query, vs_path, history, mode,
|
||||||
enumerate(resp["source_documents"])])
|
enumerate(resp["source_documents"])])
|
||||||
history[-1][-1] += source
|
history[-1][-1] += source
|
||||||
yield history, ""
|
yield history, ""
|
||||||
|
elif mode == "知识库测试" and os.path.exists(vs_path):
|
||||||
|
resp, prompt = local_doc_qa.get_knowledge_based_conent_test(query=query, vs_path=vs_path,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
vector_search_top_k=vector_search_top_k,
|
||||||
|
chunk_conent=chunk_conent,
|
||||||
|
chunk_size=chunk_size)
|
||||||
|
if not resp["source_documents"]:
|
||||||
|
yield history + [[query,
|
||||||
|
"根据您的设定,没有匹配到任何内容,请确认您设置的score阈值是否过小或其他参数是否正确!"]], ""
|
||||||
else:
|
else:
|
||||||
for resp, history in local_doc_qa.llm._call(query, history,
|
source = "".join(
|
||||||
streaming=streaming):
|
[
|
||||||
|
f"""<details> <summary>[score值]:{doc.metadata["score"]} - ({i + 1})[出处]: {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
|
||||||
|
f"""{doc.page_content}\n"""
|
||||||
|
f"""</details>"""
|
||||||
|
for i, doc in
|
||||||
|
enumerate(resp["source_documents"])])
|
||||||
|
history.append([query, prompt + source])
|
||||||
|
yield history, ""
|
||||||
|
else:
|
||||||
|
for resp, history in local_doc_qa.llm._call(query, history, streaming=streaming):
|
||||||
history[-1][-1] = resp + (
|
history[-1][-1] = resp + (
|
||||||
"\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
|
"\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
|
||||||
yield history, ""
|
yield history, ""
|
||||||
logger.info(f"flagging: username={FLAG_USER_NAME},query={query},vs_path={vs_path},mode={mode},history={history}")
|
logger.info(f"flagging: username={FLAG_USER_NAME},query={query},vs_path={vs_path},mode={mode},history={history}")
|
||||||
flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME)
|
flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME)
|
||||||
|
|
||||||
|
|
||||||
def init_model():
|
def init_model():
|
||||||
try:
|
try:
|
||||||
local_doc_qa.init_cfg()
|
local_doc_qa.init_cfg()
|
||||||
|
|
@ -89,19 +108,23 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, us
|
||||||
return history + [[None, model_status]]
|
return history + [[None, model_status]]
|
||||||
|
|
||||||
|
|
||||||
def get_vector_store(vs_id, files, history):
|
def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
|
||||||
vs_path = os.path.join(VS_ROOT_PATH, vs_id)
|
vs_path = os.path.join(VS_ROOT_PATH, vs_id)
|
||||||
filelist = []
|
filelist = []
|
||||||
if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, vs_id)):
|
if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, vs_id)):
|
||||||
os.makedirs(os.path.join(UPLOAD_ROOT_PATH, vs_id))
|
os.makedirs(os.path.join(UPLOAD_ROOT_PATH, vs_id))
|
||||||
|
if local_doc_qa.llm and local_doc_qa.embeddings:
|
||||||
|
if isinstance(files, list) and one_conent is None:
|
||||||
for file in files:
|
for file in files:
|
||||||
filename = os.path.split(file.name)[-1]
|
filename = os.path.split(file.name)[-1]
|
||||||
shutil.move(file.name, os.path.join(UPLOAD_ROOT_PATH, vs_id, filename))
|
shutil.move(file.name, os.path.join(UPLOAD_ROOT_PATH, vs_id, filename))
|
||||||
filelist.append(os.path.join(UPLOAD_ROOT_PATH, vs_id, filename))
|
filelist.append(os.path.join(UPLOAD_ROOT_PATH, vs_id, filename))
|
||||||
if local_doc_qa.llm and local_doc_qa.embeddings:
|
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, vs_path, sentence_size)
|
||||||
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, vs_path)
|
else:
|
||||||
|
vs_path, loaded_files = local_doc_qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation,
|
||||||
|
sentence_size)
|
||||||
if len(loaded_files):
|
if len(loaded_files):
|
||||||
file_status = f"已上传 {'、'.join([os.path.split(i)[-1] for i in loaded_files])} 至知识库,并已加载知识库,请开始提问"
|
file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files])} 内容至知识库,并已加载知识库,请开始提问"
|
||||||
else:
|
else:
|
||||||
file_status = "文件未成功加载,请重新上传文件"
|
file_status = "文件未成功加载,请重新上传文件"
|
||||||
else:
|
else:
|
||||||
|
|
@ -111,7 +134,6 @@ def get_vector_store(vs_id, files, history):
|
||||||
return vs_path, None, history + [[None, file_status]]
|
return vs_path, None, history + [[None, file_status]]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def change_vs_name_input(vs_id, history):
|
def change_vs_name_input(vs_id, history):
|
||||||
if vs_id == "新建知识库":
|
if vs_id == "新建知识库":
|
||||||
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history
|
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history
|
||||||
|
|
@ -122,22 +144,42 @@ def change_vs_name_input(vs_id, history):
|
||||||
[None, file_status]]
|
[None, file_status]]
|
||||||
|
|
||||||
|
|
||||||
def change_mode(mode):
|
def change_mode(mode, history):
|
||||||
if mode == "知识库问答":
|
if mode == "知识库问答":
|
||||||
return gr.update(visible=True)
|
return gr.update(visible=True), gr.update(visible=False), history + [[None,
|
||||||
|
"【注意】:现在是知识库问答模式,您输入的任何查询都将进行知识库查询,然后会自动整理知识库关联内容进入模型查询!!!"]]
|
||||||
|
elif mode == "知识库测试":
|
||||||
|
return gr.update(visible=True), gr.update(visible=True), [[None,
|
||||||
|
"【注意】:现在是知识库测试模式,您输入的任何查询都将进行知识库查询,并仅输出知识库匹配出的内容及相似度分值和及输入的文本源路径,查询的内容并不会进入模型查询!!!如果单条内容入库,内容如未分段,则内容越多越会稀释各查询内容与之关联的score阈值。单条内容长度在100-150左右较为合理。"]]
|
||||||
else:
|
else:
|
||||||
return gr.update(visible=False)
|
return gr.update(visible=False), gr.update(visible=False), history
|
||||||
|
|
||||||
|
|
||||||
|
def change_chunk_conent(mode, label_conent, history):
|
||||||
|
conent = ""
|
||||||
|
if "chunk_conent" in label_conent:
|
||||||
|
conent = "搜索结果上下文关联"
|
||||||
|
elif "one_content_segmentation" in label_conent: # 这里没用上,可以先留着
|
||||||
|
conent = "内容分段入库"
|
||||||
|
|
||||||
|
if mode:
|
||||||
|
return gr.update(visible=True), history + [[None, f"【已开启{conent}】"]]
|
||||||
|
else:
|
||||||
|
return gr.update(visible=False), history + [[None, f"[已关闭{conent}]"]]
|
||||||
|
|
||||||
|
|
||||||
def add_vs_name(vs_name, vs_list, chatbot):
|
def add_vs_name(vs_name, vs_list, chatbot):
|
||||||
if vs_name in vs_list:
|
if vs_name in vs_list:
|
||||||
vs_status = "与已有知识库名称冲突,请重新选择其他名称后提交"
|
vs_status = "与已有知识库名称冲突,请重新选择其他名称后提交"
|
||||||
chatbot = chatbot + [[None, vs_status]]
|
chatbot = chatbot + [[None, vs_status]]
|
||||||
return gr.update(visible=True), vs_list,gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), chatbot
|
return gr.update(visible=True), vs_list, gr.update(visible=True), gr.update(visible=True), gr.update(
|
||||||
|
visible=False), chatbot
|
||||||
else:
|
else:
|
||||||
vs_status = f"""已新增知识库"{vs_name}",将在上传文件并载入成功后进行存储。请在开始对话前,先完成文件上传。 """
|
vs_status = f"""已新增知识库"{vs_name}",将在上传文件并载入成功后进行存储。请在开始对话前,先完成文件上传。 """
|
||||||
chatbot = chatbot + [[None, vs_status]]
|
chatbot = chatbot + [[None, vs_status]]
|
||||||
return gr.update(visible=True, choices= [vs_name] + vs_list, value=vs_name), [vs_name]+vs_list, gr.update(visible=False), gr.update(visible=False), gr.update(visible=True),chatbot
|
return gr.update(visible=True, choices=[vs_name] + vs_list, value=vs_name), [vs_name] + vs_list, gr.update(
|
||||||
|
visible=False), gr.update(visible=False), gr.update(visible=True), chatbot
|
||||||
|
|
||||||
|
|
||||||
block_css = """.importantButton {
|
block_css = """.importantButton {
|
||||||
background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
|
background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
|
||||||
|
|
@ -175,17 +217,18 @@ with gr.Blocks(css=block_css) as demo:
|
||||||
with gr.Column(scale=10):
|
with gr.Column(scale=10):
|
||||||
chatbot = gr.Chatbot([[None, init_message], [None, model_status.value]],
|
chatbot = gr.Chatbot([[None, init_message], [None, model_status.value]],
|
||||||
elem_id="chat-box",
|
elem_id="chat-box",
|
||||||
show_label=False).style(height=750)
|
show_label=False).style(height=650)
|
||||||
query = gr.Textbox(show_label=False,
|
query = gr.Textbox(show_label=False,
|
||||||
placeholder="请输入提问内容,按回车进行提交").style(container=False)
|
placeholder="请输入提问内容,按回车进行提交").style(container=False)
|
||||||
with gr.Column(scale=5):
|
with gr.Column(scale=5):
|
||||||
mode = gr.Radio(["LLM 对话", "知识库问答"],
|
mode = gr.Radio(["LLM 对话", "知识库问答"],
|
||||||
label="请选择使用模式",
|
label="请选择使用模式",
|
||||||
value="知识库问答", )
|
value="知识库问答", )
|
||||||
|
knowledge_set = gr.Accordion("知识库设定", visible=False)
|
||||||
vs_setting = gr.Accordion("配置知识库")
|
vs_setting = gr.Accordion("配置知识库")
|
||||||
mode.change(fn=change_mode,
|
mode.change(fn=change_mode,
|
||||||
inputs=mode,
|
inputs=[mode, chatbot],
|
||||||
outputs=vs_setting)
|
outputs=[vs_setting, knowledge_set, chatbot])
|
||||||
with vs_setting:
|
with vs_setting:
|
||||||
select_vs = gr.Dropdown(vs_list.value,
|
select_vs = gr.Dropdown(vs_list.value,
|
||||||
label="请选择要加载的知识库",
|
label="请选择要加载的知识库",
|
||||||
|
|
@ -195,12 +238,96 @@ with gr.Blocks(css=block_css) as demo:
|
||||||
vs_name = gr.Textbox(label="请输入新建知识库名称",
|
vs_name = gr.Textbox(label="请输入新建知识库名称",
|
||||||
lines=1,
|
lines=1,
|
||||||
interactive=True,
|
interactive=True,
|
||||||
visible=True if default_path=="" else False)
|
visible=True)
|
||||||
vs_add = gr.Button(value="添加至知识库选项", visible=True if default_path=="" else False)
|
vs_add = gr.Button(value="添加至知识库选项", visible=True)
|
||||||
file2vs = gr.Column(visible=False if default_path=="" else True)
|
file2vs = gr.Column(visible=False)
|
||||||
with file2vs:
|
with file2vs:
|
||||||
# load_vs = gr.Button("加载知识库")
|
# load_vs = gr.Button("加载知识库")
|
||||||
gr.Markdown("向知识库中添加文件")
|
gr.Markdown("向知识库中添加文件")
|
||||||
|
sentence_size = gr.Number(value=SENTENCE_SIZE, precision=0,
|
||||||
|
label="文本入库分句长度限制",
|
||||||
|
interactive=True, visible=True)
|
||||||
|
with gr.Tab("上传文件"):
|
||||||
|
files = gr.File(label="添加文件",
|
||||||
|
file_types=['.txt', '.md', '.docx', '.pdf'],
|
||||||
|
file_count="multiple",
|
||||||
|
show_label=False)
|
||||||
|
load_file_button = gr.Button("上传文件并加载知识库")
|
||||||
|
with gr.Tab("上传文件夹"):
|
||||||
|
folder_files = gr.File(label="添加文件",
|
||||||
|
# file_types=['.txt', '.md', '.docx', '.pdf'],
|
||||||
|
file_count="directory",
|
||||||
|
show_label=False)
|
||||||
|
load_folder_button = gr.Button("上传文件夹并加载知识库")
|
||||||
|
vs_add.click(fn=add_vs_name,
|
||||||
|
inputs=[vs_name, vs_list, chatbot],
|
||||||
|
outputs=[select_vs, vs_list, vs_name, vs_add, file2vs, chatbot])
|
||||||
|
select_vs.change(fn=change_vs_name_input,
|
||||||
|
inputs=[select_vs, chatbot],
|
||||||
|
outputs=[vs_name, vs_add, file2vs, vs_path, chatbot])
|
||||||
|
load_file_button.click(get_vector_store,
|
||||||
|
show_progress=True,
|
||||||
|
inputs=[select_vs, files, sentence_size, chatbot, vs_setting, file2vs],
|
||||||
|
outputs=[vs_path, files, chatbot], )
|
||||||
|
load_folder_button.click(get_vector_store,
|
||||||
|
show_progress=True,
|
||||||
|
inputs=[select_vs, folder_files, sentence_size, chatbot, vs_setting,
|
||||||
|
file2vs],
|
||||||
|
outputs=[vs_path, folder_files, chatbot], )
|
||||||
|
flag_csv_logger.setup([query, vs_path, chatbot, mode], "flagged")
|
||||||
|
query.submit(get_answer,
|
||||||
|
[query, vs_path, chatbot, mode],
|
||||||
|
[chatbot, query])
|
||||||
|
with gr.Tab("知识库测试"):
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=10):
|
||||||
|
chatbot = gr.Chatbot([[None,
|
||||||
|
"【注意】:现在是知识库测试模式,您输入的任何查询都将进行知识库查询,并仅输出知识库匹配出的内容及相似度分值和及输入的文本源路径,查询的内容并不会进入模型查询!!!如果单条内容入库,内容如未分段,则内容越多越会稀释各查询内容与之关联的score阈值。单条内容长度在100-150左右较为合理。"]],
|
||||||
|
elem_id="chat-box",
|
||||||
|
show_label=False).style(height=750)
|
||||||
|
query = gr.Textbox(show_label=False,
|
||||||
|
placeholder="请输入提问内容,按回车进行提交").style(container=False)
|
||||||
|
with gr.Column(scale=5):
|
||||||
|
mode = gr.Radio(["知识库问答", "知识库测试"],
|
||||||
|
label="请选择使用模式",
|
||||||
|
value="知识库测试", )
|
||||||
|
knowledge_set = gr.Accordion("知识库设定", visible=True)
|
||||||
|
vs_setting = gr.Accordion("配置知识库", visible=True)
|
||||||
|
mode.change(fn=change_mode,
|
||||||
|
inputs=[mode, chatbot],
|
||||||
|
outputs=[vs_setting, knowledge_set, chatbot])
|
||||||
|
with knowledge_set:
|
||||||
|
score_threshold = gr.Number(value=VECTOR_SEARCH_SCORE_THRESHOLD,
|
||||||
|
label="score阈值,分值越低匹配度越高",
|
||||||
|
precision=0, interactive=True)
|
||||||
|
vector_search_top_k = gr.Number(value=VECTOR_SEARCH_TOP_K, precision=0,
|
||||||
|
label="获取知识库内容条数", interactive=True)
|
||||||
|
chunk_conent = gr.Checkbox(value=False,
|
||||||
|
label="是否启用上下文关联",
|
||||||
|
interactive=True)
|
||||||
|
chunk_sizes = gr.Number(value=CHUNK_SIZE, precision=0,
|
||||||
|
label="匹配单段内容的连接上下文长度",
|
||||||
|
interactive=True, visible=False)
|
||||||
|
chunk_conent.change(fn=change_chunk_conent,
|
||||||
|
inputs=[chunk_conent, gr.Textbox(value="chunk_conent", visible=False), chatbot],
|
||||||
|
outputs=[chunk_sizes, chatbot])
|
||||||
|
with vs_setting:
|
||||||
|
select_vs = gr.Dropdown(vs_list.value,
|
||||||
|
label="请选择要加载的知识库",
|
||||||
|
interactive=True,
|
||||||
|
value=vs_list.value[0] if len(vs_list.value) > 0 else None)
|
||||||
|
vs_name = gr.Textbox(label="请输入新建知识库名称",
|
||||||
|
lines=1,
|
||||||
|
interactive=True,
|
||||||
|
visible=True)
|
||||||
|
vs_add = gr.Button(value="添加至知识库选项", visible=True)
|
||||||
|
file2vs = gr.Column(visible=False)
|
||||||
|
with file2vs:
|
||||||
|
# load_vs = gr.Button("加载知识库")
|
||||||
|
gr.Markdown("向知识库中添加单条内容或文件")
|
||||||
|
sentence_size = gr.Number(value=SENTENCE_SIZE, precision=0,
|
||||||
|
label="文本入库分句长度限制",
|
||||||
|
interactive=True, visible=True)
|
||||||
with gr.Tab("上传文件"):
|
with gr.Tab("上传文件"):
|
||||||
files = gr.File(label="添加文件",
|
files = gr.File(label="添加文件",
|
||||||
file_types=['.txt', '.md', '.docx', '.pdf'],
|
file_types=['.txt', '.md', '.docx', '.pdf'],
|
||||||
|
|
@ -212,38 +339,46 @@ with gr.Blocks(css=block_css) as demo:
|
||||||
folder_files = gr.File(label="添加文件",
|
folder_files = gr.File(label="添加文件",
|
||||||
# file_types=['.txt', '.md', '.docx', '.pdf'],
|
# file_types=['.txt', '.md', '.docx', '.pdf'],
|
||||||
file_count="directory",
|
file_count="directory",
|
||||||
show_label=False
|
show_label=False)
|
||||||
)
|
|
||||||
load_folder_button = gr.Button("上传文件夹并加载知识库")
|
load_folder_button = gr.Button("上传文件夹并加载知识库")
|
||||||
# load_vs.click(fn=)
|
with gr.Tab("添加单条内容"):
|
||||||
|
one_title = gr.Textbox(label="标题", placeholder="请输入要添加单条段落的标题", lines=1)
|
||||||
|
one_conent = gr.Textbox(label="内容", placeholder="请输入要添加单条段落的内容", lines=5)
|
||||||
|
one_content_segmentation = gr.Checkbox(value=True, label="禁止内容分句入库",
|
||||||
|
interactive=True)
|
||||||
|
load_conent_button = gr.Button("添加内容并加载知识库")
|
||||||
|
# 将上传的文件保存到content文件夹下,并更新下拉框
|
||||||
vs_add.click(fn=add_vs_name,
|
vs_add.click(fn=add_vs_name,
|
||||||
inputs=[vs_name, vs_list, chatbot],
|
inputs=[vs_name, vs_list, chatbot],
|
||||||
outputs=[select_vs, vs_list, vs_name, vs_add, file2vs, chatbot])
|
outputs=[select_vs, vs_list, vs_name, vs_add, file2vs, chatbot])
|
||||||
select_vs.change(fn=change_vs_name_input,
|
select_vs.change(fn=change_vs_name_input,
|
||||||
inputs=[select_vs, chatbot],
|
inputs=[select_vs, chatbot],
|
||||||
outputs=[vs_name, vs_add, file2vs, vs_path, chatbot])
|
outputs=[vs_name, vs_add, file2vs, vs_path, chatbot])
|
||||||
# 将上传的文件保存到content文件夹下,并更新下拉框
|
|
||||||
load_file_button.click(get_vector_store,
|
load_file_button.click(get_vector_store,
|
||||||
show_progress=True,
|
show_progress=True,
|
||||||
inputs=[select_vs, files, chatbot],
|
inputs=[select_vs, files, sentence_size, chatbot, vs_setting, file2vs],
|
||||||
outputs=[vs_path, files, chatbot],
|
outputs=[vs_path, files, chatbot], )
|
||||||
)
|
|
||||||
load_folder_button.click(get_vector_store,
|
load_folder_button.click(get_vector_store,
|
||||||
show_progress=True,
|
show_progress=True,
|
||||||
inputs=[select_vs, folder_files, chatbot],
|
inputs=[select_vs, folder_files, sentence_size, chatbot, vs_setting,
|
||||||
outputs=[vs_path, folder_files, chatbot],
|
file2vs],
|
||||||
)
|
outputs=[vs_path, folder_files, chatbot], )
|
||||||
|
load_conent_button.click(get_vector_store,
|
||||||
|
show_progress=True,
|
||||||
|
inputs=[select_vs, one_title, sentence_size, chatbot,
|
||||||
|
one_conent, one_content_segmentation],
|
||||||
|
outputs=[vs_path, files, chatbot], )
|
||||||
flag_csv_logger.setup([query, vs_path, chatbot, mode], "flagged")
|
flag_csv_logger.setup([query, vs_path, chatbot, mode], "flagged")
|
||||||
query.submit(get_answer,
|
query.submit(get_answer,
|
||||||
[query, vs_path, chatbot, mode],
|
[query, vs_path, chatbot, mode, score_threshold, vector_search_top_k, chunk_conent,
|
||||||
|
chunk_sizes],
|
||||||
[chatbot, query])
|
[chatbot, query])
|
||||||
with gr.Tab("模型配置"):
|
with gr.Tab("模型配置"):
|
||||||
llm_model = gr.Radio(llm_model_dict_list,
|
llm_model = gr.Radio(llm_model_dict_list,
|
||||||
label="LLM 模型",
|
label="LLM 模型",
|
||||||
value=LLM_MODEL,
|
value=LLM_MODEL,
|
||||||
interactive=True)
|
interactive=True)
|
||||||
llm_history_len = gr.Slider(0,
|
llm_history_len = gr.Slider(0, 10,
|
||||||
10,
|
|
||||||
value=LLM_HISTORY_LEN,
|
value=LLM_HISTORY_LEN,
|
||||||
step=1,
|
step=1,
|
||||||
label="LLM 对话轮数",
|
label="LLM 对话轮数",
|
||||||
|
|
@ -258,19 +393,12 @@ with gr.Blocks(css=block_css) as demo:
|
||||||
label="Embedding 模型",
|
label="Embedding 模型",
|
||||||
value=EMBEDDING_MODEL,
|
value=EMBEDDING_MODEL,
|
||||||
interactive=True)
|
interactive=True)
|
||||||
top_k = gr.Slider(1,
|
top_k = gr.Slider(1, 20, value=VECTOR_SEARCH_TOP_K, step=1,
|
||||||
20,
|
label="向量匹配 top k", interactive=True)
|
||||||
value=VECTOR_SEARCH_TOP_K,
|
|
||||||
step=1,
|
|
||||||
label="向量匹配 top k",
|
|
||||||
interactive=True)
|
|
||||||
load_model_button = gr.Button("重新加载模型")
|
load_model_button = gr.Button("重新加载模型")
|
||||||
load_model_button.click(reinit_model,
|
load_model_button.click(reinit_model, show_progress=True,
|
||||||
show_progress=True,
|
inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora,
|
||||||
inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k,
|
top_k, chatbot], outputs=chatbot)
|
||||||
chatbot],
|
|
||||||
outputs=chatbot
|
|
||||||
)
|
|
||||||
|
|
||||||
(demo
|
(demo
|
||||||
.queue(concurrency_count=3)
|
.queue(concurrency_count=3)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue