diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index ae027c1..8a40058 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -183,7 +183,7 @@ async def recreate_vector_store( set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents. ''' - async def output(): + def output(): kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model) if not kb.exists() and not allow_empty_kb: yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"} diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 34f2083..f3205af 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -12,6 +12,11 @@ from configs.model_config import ( from functools import lru_cache import importlib from text_splitter import zh_title_enhance +import langchain.document_loaders +from langchain.docstore.document import Document +from pathlib import Path +import json +from typing import List, Union, Callable, Dict, Optional def validate_kb_name(knowledge_base_id: str) -> bool: @@ -57,15 +62,88 @@ def load_embeddings(model: str, device: str): -LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg', '.rst', - '.rtf', '.txt', '.xml', - '.doc', '.docx', '.epub', '.odt', '.pdf', - '.ppt', '.pptx', '.tsv'], # '.pdf', '.xlsx', '.csv' +LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'], + "UnstructuredMarkdownLoader": ['.md'], + "CustomJSONLoader": [".json"], "CSVLoader": [".csv"], "PyPDFLoader": [".pdf"], + "UnstructuredFileLoader": ['.eml', '.msg', '.rst', + '.rtf', '.txt', '.xml', + '.doc', '.docx', '.epub', '.odt', + '.ppt', '.pptx', '.tsv'], # '.xlsx' } SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist] + +class CustomJSONLoader(langchain.document_loaders.JSONLoader): + ''' + langchain的JSONLoader需要jq,在win上使用不便,进行替代。 + ''' + + def __init__( + self, + file_path: Union[str, Path], + content_key: Optional[str] = None, + metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None, + text_content: bool = True, + json_lines: bool = False, + ): + """Initialize the JSONLoader. + + Args: + file_path (Union[str, Path]): The path to the JSON or JSON Lines file. + content_key (str): The key to use to extract the content from the JSON if + results to a list of objects (dict). + metadata_func (Callable[Dict, Dict]): A function that takes in the JSON + object extracted by the jq_schema and the default metadata and returns + a dict of the updated metadata. + text_content (bool): Boolean flag to indicate whether the content is in + string format, default to True. + json_lines (bool): Boolean flag to indicate whether the input is in + JSON Lines format. + """ + self.file_path = Path(file_path).resolve() + self._content_key = content_key + self._metadata_func = metadata_func + self._text_content = text_content + self._json_lines = json_lines + + # TODO: langchain's JSONLoader.load has a encoding bug, raise gbk encoding error on windows. + # This is a workaround for langchain==0.0.266. I have make a pr(#9785) to langchain, it should be deleted after langchain upgraded. + def load(self) -> List[Document]: + """Load and return documents from the JSON file.""" + docs: List[Document] = [] + if self._json_lines: + with self.file_path.open(encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + self._parse(line, docs) + else: + self._parse(self.file_path.read_text(encoding="utf-8"), docs) + return docs + + def _parse(self, content: str, docs: List[Document]) -> None: + """Convert given content to documents.""" + data = json.loads(content) + + # Perform some validation + # This is not a perfect validation, but it should catch most cases + # and prevent the user from getting a cryptic error later on. + if self._content_key is not None: + self._validate_content_key(data) + + for i, sample in enumerate(data, len(docs) + 1): + metadata = dict( + source=str(self.file_path), + seq_num=i, + ) + text = self._get_text(sample=sample, metadata=metadata) + docs.append(Document(page_content=text, metadata=metadata)) + +langchain.document_loaders.CustomJSONLoader = CustomJSONLoader + + def get_LoaderClass(file_extension): for LoaderClass, extensions in LOADER_DICT.items(): if file_extension in extensions: @@ -101,6 +179,16 @@ class KnowledgeFile: DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader") if self.document_loader_name == "UnstructuredFileLoader": loader = DocumentLoader(self.filepath, autodetect_encoding=True) + elif self.document_loader_name == "CSVLoader": + loader = DocumentLoader(self.filepath, encoding="utf-8") + elif self.document_loader_name == "JSONLoader": + loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False) + elif self.document_loader_name == "CustomJSONLoader": + loader = DocumentLoader(self.filepath, text_content=False) + elif self.document_loader_name == "UnstructuredMarkdownLoader": + loader = DocumentLoader(self.filepath, mode="elements") # TODO: 需要在实践中测试`elements`是否优于`single` + elif self.document_loader_name == "UnstructuredHTMLLoader": + loader = DocumentLoader(self.filepath, mode="elements") # TODO: 需要在实践中测试`elements`是否优于`single` else: loader = DocumentLoader(self.filepath) diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index 4351e95..7b6b35e 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -127,7 +127,7 @@ def knowledge_base_page(api: ApiRequest): # 上传文件 # sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True) - files = st.file_uploader("上传知识文件", + files = st.file_uploader("上传知识文件(暂不支持扫描PDF)", [i for ls in LOADER_DICT.values() for i in ls], accept_multiple_files=True, ) @@ -244,7 +244,6 @@ def knowledge_base_page(api: ApiRequest): cols = st.columns(3) - # todo: freezed if cols[0].button( "依据源文件重建向量库", # help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。", @@ -258,7 +257,7 @@ def knowledge_base_page(api: ApiRequest): if msg := check_error_msg(d): st.toast(msg) else: - empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}") + empty.progress(d["finished"] / d["total"], d["msg"]) st.experimental_rerun() if cols[2].button( diff --git a/webui_pages/utils.py b/webui_pages/utils.py index c666d45..042cb6b 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -25,6 +25,7 @@ from server.utils import run_async, iter_over_async from configs.model_config import NLTK_DATA_PATH import nltk nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path +from pprint import pprint def set_httpx_timeout(timeout=60.0): @@ -224,9 +225,17 @@ class ApiRequest: try: with response as r: for chunk in r.iter_text(None): - if as_json and chunk: - yield json.loads(chunk) - elif chunk.strip(): + if not chunk: # fastchat api yield empty bytes on start and end + continue + if as_json: + try: + data = json.loads(chunk) + pprint(data, depth=1) + yield data + except Exception as e: + logger.error(f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。") + else: + print(chunk, end="", flush=True) yield chunk except httpx.ConnectError as e: msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。" @@ -274,6 +283,9 @@ class ApiRequest: return self._fastapi_stream2generator(response) else: data = msg.dict(exclude_unset=True, exclude_none=True) + print(f"received input message:") + pprint(data) + response = self.post( "/chat/fastchat", json=data, @@ -300,6 +312,9 @@ class ApiRequest: "stream": stream, } + print(f"received input message:") + pprint(data) + if no_remote_api: from server.chat.chat import chat response = chat(**data) @@ -334,6 +349,9 @@ class ApiRequest: "local_doc_url": no_remote_api, } + print(f"received input message:") + pprint(data) + if no_remote_api: from server.chat.knowledge_base_chat import knowledge_base_chat response = knowledge_base_chat(**data) @@ -367,6 +385,9 @@ class ApiRequest: "stream": stream, } + print(f"received input message:") + pprint(data) + if no_remote_api: from server.chat.search_engine_chat import search_engine_chat response = search_engine_chat(**data)