Merge remote-tracking branch 'origin/dev' into dev

This commit is contained in:
zqt 2023-08-27 10:32:02 +08:00
commit ead2e26da1
4 changed files with 119 additions and 11 deletions

View File

@ -183,7 +183,7 @@ async def recreate_vector_store(
set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents.
'''
async def output():
def output():
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
if not kb.exists() and not allow_empty_kb:
yield {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"}

View File

@ -12,6 +12,11 @@ from configs.model_config import (
from functools import lru_cache
import importlib
from text_splitter import zh_title_enhance
import langchain.document_loaders
from langchain.docstore.document import Document
from pathlib import Path
import json
from typing import List, Union, Callable, Dict, Optional
def validate_kb_name(knowledge_base_id: str) -> bool:
@ -57,15 +62,88 @@ def load_embeddings(model: str, device: str):
LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg', '.rst',
'.rtf', '.txt', '.xml',
'.doc', '.docx', '.epub', '.odt', '.pdf',
'.ppt', '.pptx', '.tsv'], # '.pdf', '.xlsx', '.csv'
LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
"UnstructuredMarkdownLoader": ['.md'],
"CustomJSONLoader": [".json"],
"CSVLoader": [".csv"],
"PyPDFLoader": [".pdf"],
"UnstructuredFileLoader": ['.eml', '.msg', '.rst',
'.rtf', '.txt', '.xml',
'.doc', '.docx', '.epub', '.odt',
'.ppt', '.pptx', '.tsv'], # '.xlsx'
}
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
class CustomJSONLoader(langchain.document_loaders.JSONLoader):
'''
langchain的JSONLoader需要jq在win上使用不便进行替代
'''
def __init__(
self,
file_path: Union[str, Path],
content_key: Optional[str] = None,
metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
text_content: bool = True,
json_lines: bool = False,
):
"""Initialize the JSONLoader.
Args:
file_path (Union[str, Path]): The path to the JSON or JSON Lines file.
content_key (str): The key to use to extract the content from the JSON if
results to a list of objects (dict).
metadata_func (Callable[Dict, Dict]): A function that takes in the JSON
object extracted by the jq_schema and the default metadata and returns
a dict of the updated metadata.
text_content (bool): Boolean flag to indicate whether the content is in
string format, default to True.
json_lines (bool): Boolean flag to indicate whether the input is in
JSON Lines format.
"""
self.file_path = Path(file_path).resolve()
self._content_key = content_key
self._metadata_func = metadata_func
self._text_content = text_content
self._json_lines = json_lines
# TODO: langchain's JSONLoader.load has a encoding bug, raise gbk encoding error on windows.
# This is a workaround for langchain==0.0.266. I have make a pr(#9785) to langchain, it should be deleted after langchain upgraded.
def load(self) -> List[Document]:
"""Load and return documents from the JSON file."""
docs: List[Document] = []
if self._json_lines:
with self.file_path.open(encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
self._parse(line, docs)
else:
self._parse(self.file_path.read_text(encoding="utf-8"), docs)
return docs
def _parse(self, content: str, docs: List[Document]) -> None:
"""Convert given content to documents."""
data = json.loads(content)
# Perform some validation
# This is not a perfect validation, but it should catch most cases
# and prevent the user from getting a cryptic error later on.
if self._content_key is not None:
self._validate_content_key(data)
for i, sample in enumerate(data, len(docs) + 1):
metadata = dict(
source=str(self.file_path),
seq_num=i,
)
text = self._get_text(sample=sample, metadata=metadata)
docs.append(Document(page_content=text, metadata=metadata))
langchain.document_loaders.CustomJSONLoader = CustomJSONLoader
def get_LoaderClass(file_extension):
for LoaderClass, extensions in LOADER_DICT.items():
if file_extension in extensions:
@ -101,6 +179,16 @@ class KnowledgeFile:
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
if self.document_loader_name == "UnstructuredFileLoader":
loader = DocumentLoader(self.filepath, autodetect_encoding=True)
elif self.document_loader_name == "CSVLoader":
loader = DocumentLoader(self.filepath, encoding="utf-8")
elif self.document_loader_name == "JSONLoader":
loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False)
elif self.document_loader_name == "CustomJSONLoader":
loader = DocumentLoader(self.filepath, text_content=False)
elif self.document_loader_name == "UnstructuredMarkdownLoader":
loader = DocumentLoader(self.filepath, mode="elements") # TODO: 需要在实践中测试`elements`是否优于`single`
elif self.document_loader_name == "UnstructuredHTMLLoader":
loader = DocumentLoader(self.filepath, mode="elements") # TODO: 需要在实践中测试`elements`是否优于`single`
else:
loader = DocumentLoader(self.filepath)

View File

@ -127,7 +127,7 @@ def knowledge_base_page(api: ApiRequest):
# 上传文件
# sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)
files = st.file_uploader("上传知识文件",
files = st.file_uploader("上传知识文件(暂不支持扫描PDF)",
[i for ls in LOADER_DICT.values() for i in ls],
accept_multiple_files=True,
)
@ -244,7 +244,6 @@ def knowledge_base_page(api: ApiRequest):
cols = st.columns(3)
# todo: freezed
if cols[0].button(
"依据源文件重建向量库",
# help="无需上传文件通过其它方式将文档拷贝到对应知识库content目录下点击本按钮即可重建知识库。",
@ -258,7 +257,7 @@ def knowledge_base_page(api: ApiRequest):
if msg := check_error_msg(d):
st.toast(msg)
else:
empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
empty.progress(d["finished"] / d["total"], d["msg"])
st.experimental_rerun()
if cols[2].button(

View File

@ -25,6 +25,7 @@ from server.utils import run_async, iter_over_async
from configs.model_config import NLTK_DATA_PATH
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
from pprint import pprint
def set_httpx_timeout(timeout=60.0):
@ -224,9 +225,17 @@ class ApiRequest:
try:
with response as r:
for chunk in r.iter_text(None):
if as_json and chunk:
yield json.loads(chunk)
elif chunk.strip():
if not chunk: # fastchat api yield empty bytes on start and end
continue
if as_json:
try:
data = json.loads(chunk)
pprint(data, depth=1)
yield data
except Exception as e:
logger.error(f"接口返回json错误 {chunk}’。错误信息是:{e}")
else:
print(chunk, end="", flush=True)
yield chunk
except httpx.ConnectError as e:
msg = f"无法连接API服务器请确认 api.py 已正常启动。"
@ -274,6 +283,9 @@ class ApiRequest:
return self._fastapi_stream2generator(response)
else:
data = msg.dict(exclude_unset=True, exclude_none=True)
print(f"received input message:")
pprint(data)
response = self.post(
"/chat/fastchat",
json=data,
@ -300,6 +312,9 @@ class ApiRequest:
"stream": stream,
}
print(f"received input message:")
pprint(data)
if no_remote_api:
from server.chat.chat import chat
response = chat(**data)
@ -334,6 +349,9 @@ class ApiRequest:
"local_doc_url": no_remote_api,
}
print(f"received input message:")
pprint(data)
if no_remote_api:
from server.chat.knowledge_base_chat import knowledge_base_chat
response = knowledge_base_chat(**data)
@ -367,6 +385,9 @@ class ApiRequest:
"stream": stream,
}
print(f"received input message:")
pprint(data)
if no_remote_api:
from server.chat.search_engine_chat import search_engine_chat
response = search_engine_chat(**data)