Merge remote-tracking branch 'origin/dev' into dev
This commit is contained in:
commit
ead2e26da1
|
|
@ -183,7 +183,7 @@ async def recreate_vector_store(
|
||||||
set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents.
|
set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
async def output():
|
def output():
|
||||||
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
|
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
|
||||||
if not kb.exists() and not allow_empty_kb:
|
if not kb.exists() and not allow_empty_kb:
|
||||||
yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
|
yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,11 @@ from configs.model_config import (
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
import importlib
|
import importlib
|
||||||
from text_splitter import zh_title_enhance
|
from text_splitter import zh_title_enhance
|
||||||
|
import langchain.document_loaders
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
from pathlib import Path
|
||||||
|
import json
|
||||||
|
from typing import List, Union, Callable, Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
def validate_kb_name(knowledge_base_id: str) -> bool:
|
def validate_kb_name(knowledge_base_id: str) -> bool:
|
||||||
|
|
@ -57,15 +62,88 @@ def load_embeddings(model: str, device: str):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg', '.rst',
|
LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
|
||||||
'.rtf', '.txt', '.xml',
|
"UnstructuredMarkdownLoader": ['.md'],
|
||||||
'.doc', '.docx', '.epub', '.odt', '.pdf',
|
"CustomJSONLoader": [".json"],
|
||||||
'.ppt', '.pptx', '.tsv'], # '.pdf', '.xlsx', '.csv'
|
|
||||||
"CSVLoader": [".csv"],
|
"CSVLoader": [".csv"],
|
||||||
"PyPDFLoader": [".pdf"],
|
"PyPDFLoader": [".pdf"],
|
||||||
|
"UnstructuredFileLoader": ['.eml', '.msg', '.rst',
|
||||||
|
'.rtf', '.txt', '.xml',
|
||||||
|
'.doc', '.docx', '.epub', '.odt',
|
||||||
|
'.ppt', '.pptx', '.tsv'], # '.xlsx'
|
||||||
}
|
}
|
||||||
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
|
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
|
||||||
|
|
||||||
|
|
||||||
|
class CustomJSONLoader(langchain.document_loaders.JSONLoader):
|
||||||
|
'''
|
||||||
|
langchain的JSONLoader需要jq,在win上使用不便,进行替代。
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
file_path: Union[str, Path],
|
||||||
|
content_key: Optional[str] = None,
|
||||||
|
metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
|
||||||
|
text_content: bool = True,
|
||||||
|
json_lines: bool = False,
|
||||||
|
):
|
||||||
|
"""Initialize the JSONLoader.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path (Union[str, Path]): The path to the JSON or JSON Lines file.
|
||||||
|
content_key (str): The key to use to extract the content from the JSON if
|
||||||
|
results to a list of objects (dict).
|
||||||
|
metadata_func (Callable[Dict, Dict]): A function that takes in the JSON
|
||||||
|
object extracted by the jq_schema and the default metadata and returns
|
||||||
|
a dict of the updated metadata.
|
||||||
|
text_content (bool): Boolean flag to indicate whether the content is in
|
||||||
|
string format, default to True.
|
||||||
|
json_lines (bool): Boolean flag to indicate whether the input is in
|
||||||
|
JSON Lines format.
|
||||||
|
"""
|
||||||
|
self.file_path = Path(file_path).resolve()
|
||||||
|
self._content_key = content_key
|
||||||
|
self._metadata_func = metadata_func
|
||||||
|
self._text_content = text_content
|
||||||
|
self._json_lines = json_lines
|
||||||
|
|
||||||
|
# TODO: langchain's JSONLoader.load has a encoding bug, raise gbk encoding error on windows.
|
||||||
|
# This is a workaround for langchain==0.0.266. I have make a pr(#9785) to langchain, it should be deleted after langchain upgraded.
|
||||||
|
def load(self) -> List[Document]:
|
||||||
|
"""Load and return documents from the JSON file."""
|
||||||
|
docs: List[Document] = []
|
||||||
|
if self._json_lines:
|
||||||
|
with self.file_path.open(encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if line:
|
||||||
|
self._parse(line, docs)
|
||||||
|
else:
|
||||||
|
self._parse(self.file_path.read_text(encoding="utf-8"), docs)
|
||||||
|
return docs
|
||||||
|
|
||||||
|
def _parse(self, content: str, docs: List[Document]) -> None:
|
||||||
|
"""Convert given content to documents."""
|
||||||
|
data = json.loads(content)
|
||||||
|
|
||||||
|
# Perform some validation
|
||||||
|
# This is not a perfect validation, but it should catch most cases
|
||||||
|
# and prevent the user from getting a cryptic error later on.
|
||||||
|
if self._content_key is not None:
|
||||||
|
self._validate_content_key(data)
|
||||||
|
|
||||||
|
for i, sample in enumerate(data, len(docs) + 1):
|
||||||
|
metadata = dict(
|
||||||
|
source=str(self.file_path),
|
||||||
|
seq_num=i,
|
||||||
|
)
|
||||||
|
text = self._get_text(sample=sample, metadata=metadata)
|
||||||
|
docs.append(Document(page_content=text, metadata=metadata))
|
||||||
|
|
||||||
|
langchain.document_loaders.CustomJSONLoader = CustomJSONLoader
|
||||||
|
|
||||||
|
|
||||||
def get_LoaderClass(file_extension):
|
def get_LoaderClass(file_extension):
|
||||||
for LoaderClass, extensions in LOADER_DICT.items():
|
for LoaderClass, extensions in LOADER_DICT.items():
|
||||||
if file_extension in extensions:
|
if file_extension in extensions:
|
||||||
|
|
@ -101,6 +179,16 @@ class KnowledgeFile:
|
||||||
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
|
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
|
||||||
if self.document_loader_name == "UnstructuredFileLoader":
|
if self.document_loader_name == "UnstructuredFileLoader":
|
||||||
loader = DocumentLoader(self.filepath, autodetect_encoding=True)
|
loader = DocumentLoader(self.filepath, autodetect_encoding=True)
|
||||||
|
elif self.document_loader_name == "CSVLoader":
|
||||||
|
loader = DocumentLoader(self.filepath, encoding="utf-8")
|
||||||
|
elif self.document_loader_name == "JSONLoader":
|
||||||
|
loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False)
|
||||||
|
elif self.document_loader_name == "CustomJSONLoader":
|
||||||
|
loader = DocumentLoader(self.filepath, text_content=False)
|
||||||
|
elif self.document_loader_name == "UnstructuredMarkdownLoader":
|
||||||
|
loader = DocumentLoader(self.filepath, mode="elements") # TODO: 需要在实践中测试`elements`是否优于`single`
|
||||||
|
elif self.document_loader_name == "UnstructuredHTMLLoader":
|
||||||
|
loader = DocumentLoader(self.filepath, mode="elements") # TODO: 需要在实践中测试`elements`是否优于`single`
|
||||||
else:
|
else:
|
||||||
loader = DocumentLoader(self.filepath)
|
loader = DocumentLoader(self.filepath)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -127,7 +127,7 @@ def knowledge_base_page(api: ApiRequest):
|
||||||
|
|
||||||
# 上传文件
|
# 上传文件
|
||||||
# sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)
|
# sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)
|
||||||
files = st.file_uploader("上传知识文件",
|
files = st.file_uploader("上传知识文件(暂不支持扫描PDF)",
|
||||||
[i for ls in LOADER_DICT.values() for i in ls],
|
[i for ls in LOADER_DICT.values() for i in ls],
|
||||||
accept_multiple_files=True,
|
accept_multiple_files=True,
|
||||||
)
|
)
|
||||||
|
|
@ -244,7 +244,6 @@ def knowledge_base_page(api: ApiRequest):
|
||||||
|
|
||||||
cols = st.columns(3)
|
cols = st.columns(3)
|
||||||
|
|
||||||
# todo: freezed
|
|
||||||
if cols[0].button(
|
if cols[0].button(
|
||||||
"依据源文件重建向量库",
|
"依据源文件重建向量库",
|
||||||
# help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。",
|
# help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。",
|
||||||
|
|
@ -258,7 +257,7 @@ def knowledge_base_page(api: ApiRequest):
|
||||||
if msg := check_error_msg(d):
|
if msg := check_error_msg(d):
|
||||||
st.toast(msg)
|
st.toast(msg)
|
||||||
else:
|
else:
|
||||||
empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
|
empty.progress(d["finished"] / d["total"], d["msg"])
|
||||||
st.experimental_rerun()
|
st.experimental_rerun()
|
||||||
|
|
||||||
if cols[2].button(
|
if cols[2].button(
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ from server.utils import run_async, iter_over_async
|
||||||
from configs.model_config import NLTK_DATA_PATH
|
from configs.model_config import NLTK_DATA_PATH
|
||||||
import nltk
|
import nltk
|
||||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||||
|
from pprint import pprint
|
||||||
|
|
||||||
|
|
||||||
def set_httpx_timeout(timeout=60.0):
|
def set_httpx_timeout(timeout=60.0):
|
||||||
|
|
@ -224,9 +225,17 @@ class ApiRequest:
|
||||||
try:
|
try:
|
||||||
with response as r:
|
with response as r:
|
||||||
for chunk in r.iter_text(None):
|
for chunk in r.iter_text(None):
|
||||||
if as_json and chunk:
|
if not chunk: # fastchat api yield empty bytes on start and end
|
||||||
yield json.loads(chunk)
|
continue
|
||||||
elif chunk.strip():
|
if as_json:
|
||||||
|
try:
|
||||||
|
data = json.loads(chunk)
|
||||||
|
pprint(data, depth=1)
|
||||||
|
yield data
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。")
|
||||||
|
else:
|
||||||
|
print(chunk, end="", flush=True)
|
||||||
yield chunk
|
yield chunk
|
||||||
except httpx.ConnectError as e:
|
except httpx.ConnectError as e:
|
||||||
msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。"
|
msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。"
|
||||||
|
|
@ -274,6 +283,9 @@ class ApiRequest:
|
||||||
return self._fastapi_stream2generator(response)
|
return self._fastapi_stream2generator(response)
|
||||||
else:
|
else:
|
||||||
data = msg.dict(exclude_unset=True, exclude_none=True)
|
data = msg.dict(exclude_unset=True, exclude_none=True)
|
||||||
|
print(f"received input message:")
|
||||||
|
pprint(data)
|
||||||
|
|
||||||
response = self.post(
|
response = self.post(
|
||||||
"/chat/fastchat",
|
"/chat/fastchat",
|
||||||
json=data,
|
json=data,
|
||||||
|
|
@ -300,6 +312,9 @@ class ApiRequest:
|
||||||
"stream": stream,
|
"stream": stream,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
print(f"received input message:")
|
||||||
|
pprint(data)
|
||||||
|
|
||||||
if no_remote_api:
|
if no_remote_api:
|
||||||
from server.chat.chat import chat
|
from server.chat.chat import chat
|
||||||
response = chat(**data)
|
response = chat(**data)
|
||||||
|
|
@ -334,6 +349,9 @@ class ApiRequest:
|
||||||
"local_doc_url": no_remote_api,
|
"local_doc_url": no_remote_api,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
print(f"received input message:")
|
||||||
|
pprint(data)
|
||||||
|
|
||||||
if no_remote_api:
|
if no_remote_api:
|
||||||
from server.chat.knowledge_base_chat import knowledge_base_chat
|
from server.chat.knowledge_base_chat import knowledge_base_chat
|
||||||
response = knowledge_base_chat(**data)
|
response = knowledge_base_chat(**data)
|
||||||
|
|
@ -367,6 +385,9 @@ class ApiRequest:
|
||||||
"stream": stream,
|
"stream": stream,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
print(f"received input message:")
|
||||||
|
pprint(data)
|
||||||
|
|
||||||
if no_remote_api:
|
if no_remote_api:
|
||||||
from server.chat.search_engine_chat import search_engine_chat
|
from server.chat.search_engine_chat import search_engine_chat
|
||||||
response = search_engine_chat(**data)
|
response = search_engine_chat(**data)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue