Merge remote-tracking branch 'origin/dev' into dev

commit ead2e26da1
@@ -183,7 +183,7 @@ async def recreate_vector_store(
     set allow_empty_kb to True to apply this to empty knowledge bases as well, i.e. those not recorded in info.db or containing no documents.
     '''

-    async def output():
+    def output():
         kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
         if not kb.exists() and not allow_empty_kb:
             yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
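With allow_empty_kb=True, recreation proceeds even when the knowledge base is missing from info.db or holds no documents; with False, the stream opens with the 404 message above. A minimal client-side sketch of consuming this streaming endpoint; the port 7861 and the path /knowledge_base/recreate_vector_store are assumptions, not confirmed by this hunk:

    import httpx

    def recreate(kb_name: str, allow_empty_kb: bool = True):
        # Each streamed chunk is a JSON progress message like {"code": ..., "msg": ...}.
        payload = {"knowledge_base_name": kb_name, "allow_empty_kb": allow_empty_kb}
        url = "http://127.0.0.1:7861/knowledge_base/recreate_vector_store"  # assumed path
        with httpx.stream("POST", url, json=payload, timeout=None) as r:
            for chunk in r.iter_text():
                if chunk.strip():
                    print(chunk)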
@@ -12,6 +12,11 @@ from configs.model_config import (
 from functools import lru_cache
+import importlib
+from text_splitter import zh_title_enhance
+import langchain.document_loaders
+from langchain.docstore.document import Document
+from pathlib import Path
 import json
 from typing import List, Union, Callable, Dict, Optional


 def validate_kb_name(knowledge_base_id: str) -> bool:
@@ -57,15 +62,88 @@ def load_embeddings(model: str, device: str):


-LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg', '.rst',
-                                          '.rtf', '.txt', '.xml',
-                                          '.doc', '.docx', '.epub', '.odt', '.pdf',
-                                          '.ppt', '.pptx', '.tsv'],  # '.pdf', '.xlsx', '.csv'
+LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
+               "UnstructuredMarkdownLoader": ['.md'],
+               "CustomJSONLoader": [".json"],
+               "CSVLoader": [".csv"],
+               "PyPDFLoader": [".pdf"],
+               "UnstructuredFileLoader": ['.eml', '.msg', '.rst',
+                                          '.rtf', '.txt', '.xml',
+                                          '.doc', '.docx', '.epub', '.odt',
+                                          '.ppt', '.pptx', '.tsv'],  # '.xlsx'
+               }
 SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
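LOADER_DICT now maps loader class names to the extensions they own, so picking a loader for a file is a reverse lookup (get_LoaderClass below does exactly this). A self-contained sketch of the lookup, reusing names from this hunk; the helper loader_name_for is illustrative, not part of the diff:

    from pathlib import Path

    LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
                   "UnstructuredMarkdownLoader": ['.md'],
                   "CustomJSONLoader": [".json"],
                   "CSVLoader": [".csv"],
                   "PyPDFLoader": [".pdf"]}
    SUPPORTED_EXTS = [ext for exts in LOADER_DICT.values() for ext in exts]

    def loader_name_for(filename: str):
        # Reverse lookup: file extension -> registered loader class name.
        ext = Path(filename).suffix.lower()
        for name, exts in LOADER_DICT.items():
            if ext in exts:
                return name

    assert loader_name_for("notes.md") == "UnstructuredMarkdownLoader"
    assert ".pdf" in SUPPORTED_EXTS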
+class CustomJSONLoader(langchain.document_loaders.JSONLoader):
+    '''
+    A replacement for langchain's JSONLoader, which depends on jq and is inconvenient to use on Windows.
+    '''
+
+    def __init__(
+        self,
+        file_path: Union[str, Path],
+        content_key: Optional[str] = None,
+        metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
+        text_content: bool = True,
+        json_lines: bool = False,
+    ):
+        """Initialize the JSONLoader.
+
+        Args:
+            file_path (Union[str, Path]): The path to the JSON or JSON Lines file.
+            content_key (str): The key used to extract the content from the JSON
+                when the result is a list of objects (dict).
+            metadata_func (Callable[Dict, Dict]): A function that takes in the JSON
+                object extracted by the jq_schema and the default metadata and returns
+                a dict of the updated metadata.
+            text_content (bool): Whether the content is in string format;
+                defaults to True.
+            json_lines (bool): Whether the input is in JSON Lines format.
+        """
+        self.file_path = Path(file_path).resolve()
+        self._content_key = content_key
+        self._metadata_func = metadata_func
+        self._text_content = text_content
+        self._json_lines = json_lines
+
+    # TODO: langchain's JSONLoader.load has an encoding bug that raises a gbk encoding error on Windows.
+    # This is a workaround for langchain==0.0.266. I have opened a PR (#9785) against langchain; it should be deleted once langchain is upgraded.
+    def load(self) -> List[Document]:
+        """Load and return documents from the JSON file."""
+        docs: List[Document] = []
+        if self._json_lines:
+            with self.file_path.open(encoding="utf-8") as f:
+                for line in f:
+                    line = line.strip()
+                    if line:
+                        self._parse(line, docs)
+        else:
+            self._parse(self.file_path.read_text(encoding="utf-8"), docs)
+        return docs
+
+    def _parse(self, content: str, docs: List[Document]) -> None:
+        """Convert given content to documents."""
+        data = json.loads(content)
+
+        # Perform some validation
+        # This is not a perfect validation, but it should catch most cases
+        # and prevent the user from getting a cryptic error later on.
+        if self._content_key is not None:
+            self._validate_content_key(data)
+
+        for i, sample in enumerate(data, len(docs) + 1):
+            metadata = dict(
+                source=str(self.file_path),
+                seq_num=i,
+            )
+            text = self._get_text(sample=sample, metadata=metadata)
+            docs.append(Document(page_content=text, metadata=metadata))
+
+langchain.document_loaders.CustomJSONLoader = CustomJSONLoader
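A quick usage sketch for the new loader; the path and file contents below are made up for illustration:

    # Assuming ./docs/sample.json holds a UTF-8 JSON array such as
    # [{"title": "a"}, {"title": "b"}], which maps to one Document per element.
    loader = CustomJSONLoader("./docs/sample.json", text_content=False)
    for doc in loader.load():
        print(doc.metadata["seq_num"], doc.page_content)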


 def get_LoaderClass(file_extension):
     for LoaderClass, extensions in LOADER_DICT.items():
         if file_extension in extensions:
@@ -101,6 +179,16 @@ class KnowledgeFile:
         DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
         if self.document_loader_name == "UnstructuredFileLoader":
             loader = DocumentLoader(self.filepath, autodetect_encoding=True)
+        elif self.document_loader_name == "CSVLoader":
+            loader = DocumentLoader(self.filepath, encoding="utf-8")
+        elif self.document_loader_name == "JSONLoader":
+            loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False)
+        elif self.document_loader_name == "CustomJSONLoader":
+            loader = DocumentLoader(self.filepath, text_content=False)
+        elif self.document_loader_name == "UnstructuredMarkdownLoader":
+            loader = DocumentLoader(self.filepath, mode="elements")  # TODO: verify in practice whether `elements` beats `single`
+        elif self.document_loader_name == "UnstructuredHTMLLoader":
+            loader = DocumentLoader(self.filepath, mode="elements")  # TODO: verify in practice whether `elements` beats `single`
         else:
             loader = DocumentLoader(self.filepath)

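This dispatch works because loaders are resolved by name from langchain.document_loaders via getattr, which is also why CustomJSONLoader is monkey-patched onto that module above. A condensed sketch of the pattern under that assumption; make_loader is a standalone stand-in for the KnowledgeFile method, not code from the diff:

    import importlib

    def make_loader(document_loader_name: str, filepath: str):
        document_loaders_module = importlib.import_module("langchain.document_loaders")
        # Resolve the loader class by name; the monkey-patched CustomJSONLoader
        # is found through the same lookup as the built-in loaders.
        DocumentLoader = getattr(document_loaders_module, document_loader_name)
        if document_loader_name == "CSVLoader":
            return DocumentLoader(filepath, encoding="utf-8")
        elif document_loader_name == "CustomJSONLoader":
            return DocumentLoader(filepath, text_content=False)
        return DocumentLoader(filepath)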
@@ -127,7 +127,7 @@ def knowledge_base_page(api: ApiRequest):

     # upload files
     # sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)
-    files = st.file_uploader("上传知识文件",
+    files = st.file_uploader("上传知识文件(暂不支持扫描PDF)",
                              [i for ls in LOADER_DICT.values() for i in ls],
                              accept_multiple_files=True,
                              )
@@ -244,7 +244,6 @@ def knowledge_base_page(api: ApiRequest):

         cols = st.columns(3)

-        # todo: freezed
         if cols[0].button(
                 "依据源文件重建向量库",
                 # help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。",
@@ -258,7 +257,7 @@ def knowledge_base_page(api: ApiRequest):
                 if msg := check_error_msg(d):
                     st.toast(msg)
                 else:
-                    empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
+                    empty.progress(d["finished"] / d["total"], d["msg"])
             st.experimental_rerun()

         if cols[2].button(
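The changed line takes the progress caption straight from the server message (d["msg"]) instead of formatting it client-side. A minimal sketch of the placeholder-plus-progress pattern used here, with made-up progress data in place of the API stream:

    import time
    import streamlit as st

    empty = st.empty()  # placeholder that the progress bar repeatedly overwrites
    total = 5
    for finished in range(1, total + 1):
        msg = f"processed document {finished}/{total}"  # server-provided in the diff
        empty.progress(finished / total, msg)
        time.sleep(0.1)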
@@ -25,6 +25,7 @@ from server.utils import run_async, iter_over_async
 from configs.model_config import NLTK_DATA_PATH
 import nltk
 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
+from pprint import pprint


 def set_httpx_timeout(timeout=60.0):
@@ -224,9 +225,17 @@ class ApiRequest:
         try:
             with response as r:
                 for chunk in r.iter_text(None):
-                    if as_json and chunk:
-                        yield json.loads(chunk)
-                    elif chunk.strip():
+                    if not chunk:  # the fastchat api yields empty bytes at start and end
+                        continue
+                    if as_json:
+                        try:
+                            data = json.loads(chunk)
+                            pprint(data, depth=1)
+                            yield data
+                        except Exception as e:
+                            logger.error(f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。")
+                    else:
                         print(chunk, end="", flush=True)
                         yield chunk
         except httpx.ConnectError as e:
             msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。"
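The rewritten loop skips fastchat's empty keep-alive chunks and logs malformed JSON instead of letting json.loads abort the whole stream. The same guard pattern in isolation; iter_json_chunks is an illustrative helper, not part of the diff:

    import json

    def iter_json_chunks(chunks):
        # Skip empty chunks; tolerate malformed JSON rather than aborting the stream.
        for chunk in chunks:
            if not chunk:
                continue
            try:
                yield json.loads(chunk)
            except ValueError:
                pass  # the hunk logs via logger.error and keeps streaming

    print(list(iter_json_chunks(['', '{"a": 1}', 'not-json', '{"b": 2}'])))
    # -> [{'a': 1}, {'b': 2}]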
@@ -274,6 +283,9 @@ class ApiRequest:
             return self._fastapi_stream2generator(response)
         else:
             data = msg.dict(exclude_unset=True, exclude_none=True)
+            print(f"received input message:")
+            pprint(data)
+
             response = self.post(
                 "/chat/fastchat",
                 json=data,
@@ -300,6 +312,9 @@ class ApiRequest:
             "stream": stream,
         }

+        print(f"received input message:")
+        pprint(data)
+
         if no_remote_api:
             from server.chat.chat import chat
             response = chat(**data)
@@ -334,6 +349,9 @@ class ApiRequest:
             "local_doc_url": no_remote_api,
         }

+        print(f"received input message:")
+        pprint(data)
+
         if no_remote_api:
             from server.chat.knowledge_base_chat import knowledge_base_chat
             response = knowledge_base_chat(**data)
@@ -367,6 +385,9 @@ class ApiRequest:
             "stream": stream,
         }

+        print(f"received input message:")
+        pprint(data)
+
         if no_remote_api:
             from server.chat.search_engine_chat import search_engine_chat
             response = search_engine_chat(**data)