diff --git a/.gitignore b/.gitignore index fe1385d..73ebd21 100644 --- a/.gitignore +++ b/.gitignore @@ -175,3 +175,6 @@ embedding/* pyrightconfig.json loader/tmp_files flagged/* +ptuning-v2/*.json +ptuning-v2/*.bin + diff --git a/README.md b/README.md index 7369db3..5ee0ade 100644 --- a/README.md +++ b/README.md @@ -255,7 +255,7 @@ Web UI 可以实现如下功能: - [x] VUE 前端 ## 项目交流群 -二维码 +二维码 🎉 langchain-ChatGLM 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。 diff --git a/README_en.md b/README_en.md index 79edb91..249962d 100644 --- a/README_en.md +++ b/README_en.md @@ -87,6 +87,10 @@ $ conda create -p /your_path/env_name python=3.8 # Activate the environment $ source activate /your_path/env_name +# or, do not specify an env path, note that /your_path/env_name is to be replaced with env_name below +$ conda create -n env_name python=3.8 +$ conda activate env_name # Activate the environment + # Deactivate the environment $ source deactivate /your_path/env_name diff --git a/api.py b/api.py index d6e48d6..452313e 100644 --- a/api.py +++ b/api.py @@ -9,8 +9,9 @@ import asyncio import nltk import pydantic import uvicorn -from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket +from fastapi import Body, Request, FastAPI, File, Form, Query, UploadFile, WebSocket from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse from pydantic import BaseModel from typing_extensions import Annotated from starlette.responses import RedirectResponse @@ -55,7 +56,7 @@ class ListDocsResponse(BaseResponse): class ChatMessage(BaseModel): question: str = pydantic.Field(..., description="Question text") response: str = pydantic.Field(..., description="Response text") - history: List[List[str]] = pydantic.Field(..., description="History text") + history: List[List[Optional[str]]] = pydantic.Field(..., description="History text") source_documents: List[str] = pydantic.Field( ..., description="List of source documents and their scores" ) @@ -303,7 +304,8 @@ async def update_doc( async def local_doc_chat( knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"), question: str = Body(..., description="Question", example="工伤保险是什么?"), - history: List[List[str]] = Body( + streaming: bool = Body(False, description="是否开启流式输出,默认false,有些模型可能不支持。"), + history: List[List[Optional[str]]] = Body( [], description="History of previous questions and answers", example=[ @@ -324,27 +326,39 @@ async def local_doc_chat( source_documents=[], ) else: - for resp, history in local_doc_qa.get_knowledge_based_answer( - query=question, vs_path=vs_path, chat_history=history, streaming=True - ): - pass - source_documents = [ - f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n""" - f"""相关度:{doc.metadata['score']}\n\n""" - for inum, doc in enumerate(resp["source_documents"]) - ] + if (streaming): + def generate_answer (): + last_print_len = 0 + for resp, next_history in local_doc_qa.get_knowledge_based_answer( + query=question, vs_path=vs_path, chat_history=history, streaming=True + ): + yield resp["result"][last_print_len:] + last_print_len=len(resp["result"]) - return ChatMessage( - question=question, - response=resp["result"], - history=history, - source_documents=source_documents, - ) + return StreamingResponse(generate_answer()) + else: + for resp, history in local_doc_qa.get_knowledge_based_answer( + query=question, vs_path=vs_path, chat_history=history, streaming=True + ): + pass + + source_documents = [ + f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n""" + f"""相关度:{doc.metadata['score']}\n\n""" + for inum, doc in enumerate(resp["source_documents"]) + ] + + return ChatMessage( + question=question, + response=resp["result"], + history=history, + source_documents=source_documents, + ) async def bing_search_chat( question: str = Body(..., description="Question", example="工伤保险是什么?"), - history: Optional[List[List[str]]] = Body( + history: Optional[List[List[Optional[str]]]] = Body( [], description="History of previous questions and answers", example=[ @@ -374,7 +388,8 @@ async def bing_search_chat( async def chat( question: str = Body(..., description="Question", example="工伤保险是什么?"), - history: List[List[str]] = Body( + streaming: bool = Body(False, description="是否开启流式输出,默认false,有些模型可能不支持。"), + history: List[List[Optional[str]]] = Body( [], description="History of previous questions and answers", example=[ @@ -385,6 +400,30 @@ async def chat( ], ), ): + if (streaming): + def generate_answer (): + last_print_len = 0 + answer_result_stream_result = local_doc_qa.llm_model_chain( + {"prompt": question, "history": history, "streaming": True}) + for answer_result in answer_result_stream_result['answer_result_stream']: + yield answer_result.llm_output["answer"][last_print_len:] + last_print_len = len(answer_result.llm_output["answer"]) + + return StreamingResponse(generate_answer()) + else: + answer_result_stream_result = local_doc_qa.llm_model_chain( + {"prompt": question, "history": history, "streaming": True}) + for answer_result in answer_result_stream_result['answer_result_stream']: + resp = answer_result.llm_output["answer"] + history = answer_result.history + pass + + return ChatMessage( + question=question, + response=resp, + history=history, + source_documents=[], + ) answer_result_stream_result = local_doc_qa.llm_model_chain( {"prompt": question, "history": history, "streaming": True}) @@ -392,7 +431,6 @@ async def chat( resp = answer_result.llm_output["answer"] history = answer_result.history pass - return ChatMessage( question=question, response=resp, @@ -545,7 +583,7 @@ if __name__ == "__main__": parser.add_argument("--ssl_keyfile", type=str) parser.add_argument("--ssl_certfile", type=str) # 初始化消息 - args = None + args = parser.parse_args() args_dict = vars(args) shared.loaderCheckPoint = LoaderCheckPoint(args_dict) diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py index 2f7d8da..6085dfc 100644 --- a/chains/local_doc_qa.py +++ b/chains/local_doc_qa.py @@ -8,7 +8,6 @@ from typing import List from utils import torch_gc from tqdm import tqdm from pypinyin import lazy_pinyin -from loader import UnstructuredPaddleImageLoader, UnstructuredPaddlePDFLoader from models.base import (BaseAnswer, AnswerResult) from models.loader.args import parser @@ -59,6 +58,7 @@ def tree(filepath, ignore_dir_names=None, ignore_file_names=None): def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_TITLE_ENHANCE): + if filepath.lower().endswith(".md"): loader = UnstructuredFileLoader(filepath, mode="elements") docs = loader.load() @@ -67,10 +67,14 @@ def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_T textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size) docs = loader.load_and_split(textsplitter) elif filepath.lower().endswith(".pdf"): + # 暂且将paddle相关的loader改为动态加载,可以在不上传pdf/image知识文件的前提下使用protobuf=4.x + from loader import UnstructuredPaddlePDFLoader loader = UnstructuredPaddlePDFLoader(filepath) textsplitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size) docs = loader.load_and_split(textsplitter) elif filepath.lower().endswith(".jpg") or filepath.lower().endswith(".png"): + # 暂且将paddle相关的loader改为动态加载,可以在不上传pdf/image知识文件的前提下使用protobuf=4.x + from loader import UnstructuredPaddleImageLoader loader = UnstructuredPaddleImageLoader(filepath, mode="elements") textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size) docs = loader.load_and_split(text_splitter=textsplitter) diff --git a/configs/model_config.py b/configs/model_config.py index c6fb964..a95d5bf 100644 --- a/configs/model_config.py +++ b/configs/model_config.py @@ -31,7 +31,7 @@ EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backe # llm_model_dict 处理了loader的一些预设行为,如加载位置,模型名称,模型处理器实例 # 在以下字典中修改属性值,以指定本地 LLM 模型存储位置 # 如将 "chatglm-6b" 的 "local_model_path" 由 None 修改为 "User/Downloads/chatglm-6b" -# 此处请写绝对路径 +# 此处请写绝对路径,且路径中必须包含repo-id的模型名称,因为FastChat是以模型名匹配的 llm_model_dict = { "chatglm-6b-int4-qe": { "name": "chatglm-6b-int4-qe", @@ -154,6 +154,15 @@ llm_model_dict = { "provides": "FastChatOpenAILLMChain", # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain" "api_base_url": "http://localhost:8000/v1", # "name"修改为fastchat服务中的"api_base_url" "api_key": "EMPTY" + }, + # 通过 fastchat 调用的模型请参考如下格式 + "fastchat-chatglm-6b-int4": { + "name": "chatglm-6b-int4", # "name"修改为fastchat服务中的"model_name" + "pretrained_model_name": "chatglm-6b-int4", + "local_model_path": None, + "provides": "FastChatOpenAILLMChain", # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain" + "api_base_url": "http://localhost:8001/v1", # "name"修改为fastchat服务中的"api_base_url" + "api_key": "EMPTY" }, "fastchat-chatglm2-6b": { "name": "chatglm2-6b", # "name"修改为fastchat服务中的"model_name" @@ -175,11 +184,13 @@ llm_model_dict = { # 调用chatgpt时如果报出: urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443): # Max retries exceeded with url: /v1/chat/completions # 则需要将urllib3版本修改为1.25.11 + # 如果依然报urllib3.exceptions.MaxRetryError: HTTPSConnectionPool,则将https改为http + # 参考https://zhuanlan.zhihu.com/p/350015032 # 如果报出:raise NewConnectionError( # urllib3.exceptions.NewConnectionError: : # Failed to establish a new connection: [WinError 10060] - # 则是因为内地和香港的IP都被OPENAI封了,需要挂切换为日本、新加坡等地 + # 则是因为内地和香港的IP都被OPENAI封了,需要切换为日本、新加坡等地 "openai-chatgpt-3.5": { "name": "gpt-3.5-turbo", "pretrained_model_name": "gpt-3.5-turbo", @@ -200,20 +211,17 @@ BF16 = False # 本地lora存放的位置 LORA_DIR = "loras/" -# LLM lora path,默认为空,如果有请直接指定文件夹路径 -LLM_LORA_PATH = "" -USE_LORA = True if LLM_LORA_PATH else False +# LORA的名称,如有请指定为列表 + +LORA_NAME = "" +USE_LORA = True if LORA_NAME else False # LLM streaming reponse STREAMING = True # Use p-tuning-v2 PrefixEncoder USE_PTUNING_V2 = False -PTUNING_DIR='./ptuing-v2' -<<<<<<< HEAD -======= - ->>>>>>> f68d347c25b4bdd07f293c65a6e44a673a11f614 +PTUNING_DIR='./ptuning-v2' # LLM running device LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" diff --git a/docs/FAQ.md b/docs/FAQ.md index 2162671..ccc0f25 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -180,9 +180,9 @@ Q14 调用api中的 `bing_search_chat`接口时,报出 `Failed to establish a --- -Q15 加载chatglm-6b-int8或chatglm-6b-int4抛出`RuntimeError: Only Tensors of floating point andcomplex dtype can require gradients` +Q15 加载chatglm-6b-int8或chatglm-6b-int4抛出 `RuntimeError: Only Tensors of floating point andcomplex dtype can require gradients` -疑为chatglm的quantization的问题或torch版本差异问题,针对已经变为Parameter的torch.zeros矩阵也执行Parameter操作,从而抛出`RuntimeError: Only Tensors of floating point andcomplex dtype can require gradients`。解决办法是在chatglm-项目的原始文件中的quantization.py文件374行改为: +疑为chatglm的quantization的问题或torch版本差异问题,针对已经变为Parameter的torch.zeros矩阵也执行Parameter操作,从而抛出 `RuntimeError: Only Tensors of floating point andcomplex dtype can require gradients`。解决办法是在chatglm-项目的原始文件中的quantization.py文件374行改为: ``` try: @@ -191,6 +191,8 @@ Q15 加载chatglm-6b-int8或chatglm-6b-int4抛出`RuntimeError: Only Tensors of pass ``` - 注:虽然模型可以顺利加载但在cpu上仍存在推理失败的可能:即针对每个问题,模型一直输出gugugugu。 + 如果上述方式不起作用,则在.cache/hugggingface/modules/目录下针对chatglm项目的原始文件中的quantization.py文件执行上述操作,若软链接不止一个,按照错误提示选择正确的路径。 - 因此,最好不要试图用cpu加载量化模型,原因可能是目前python主流量化包的量化操作是在gpu上执行的,会天然地存在gap。 \ No newline at end of file +注:虽然模型可以顺利加载但在cpu上仍存在推理失败的可能:即针对每个问题,模型一直输出gugugugu。 + + 因此,最好不要试图用cpu加载量化模型,原因可能是目前python主流量化包的量化操作是在gpu上执行的,会天然地存在gap。 diff --git a/docs/INSTALL.md b/docs/INSTALL.md index 6602973..cd3c51d 100644 --- a/docs/INSTALL.md +++ b/docs/INSTALL.md @@ -12,6 +12,12 @@ $ conda create -p /your_path/env_name python=3.8 # 激活环境 $ source activate /your_path/env_name + +# 或,conda安装,不指定路径, 注意以下,都将/your_path/env_name替换为env_name +$ conda create -n env_name python=3.8 +$ conda activate env_name # Activate the environment + +# 更新py库 $ pip3 install --upgrade pip # 关闭环境 @@ -49,7 +55,7 @@ $ python loader/image_loader.py ## llama-cpp模型调用的说明 -1. 首先从huggingface hub中下载对应的模型,如https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/的[ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin),建议使用huggingface_hub库的snapshot_download下载。 +1. 首先从huggingface hub中下载对应的模型,如 [https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/) 的 [ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin),建议使用huggingface_hub库的snapshot_download下载。 2. 将下载的模型重命名。通过huggingface_hub下载的模型会被重命名为随机序列,因此需要重命名为原始文件名,如[ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin)。 3. 基于下载模型的ggml的加载时间,推测对应的llama-cpp版本,下载对应的llama-cpp-python库的wheel文件,实测[ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin)与llama-cpp-python库兼容,然后手动安装wheel文件。 4. 将下载的模型信息写入configs/model_config.py文件里 `llm_model_dict`中,注意保证参数的兼容性,一些参数组合可能会报错. diff --git a/img/qr_code_43.jpg b/img/qr_code_43.jpg deleted file mode 100644 index 8bbdcbd..0000000 Binary files a/img/qr_code_43.jpg and /dev/null differ diff --git a/img/qr_code_45.jpg b/img/qr_code_45.jpg new file mode 100644 index 0000000..ad253c8 Binary files /dev/null and b/img/qr_code_45.jpg differ diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py index 81878ce..0d19ee6 100644 --- a/models/chatglm_llm.py +++ b/models/chatglm_llm.py @@ -2,14 +2,14 @@ from abc import ABC from langchain.chains.base import Chain from typing import Any, Dict, List, Optional, Generator from langchain.callbacks.manager import CallbackManagerForChainRun -from transformers.generation.logits_process import LogitsProcessor -from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList +# from transformers.generation.logits_process import LogitsProcessor +# from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList from models.loader import LoaderCheckPoint from models.base import (BaseAnswer, AnswerResult, AnswerResultStream, AnswerResultQueueSentinelTokenListenerQueue) -import torch +# import torch import transformers diff --git a/models/fastchat_openai_llm.py b/models/fastchat_openai_llm.py index d0972f7..217910a 100644 --- a/models/fastchat_openai_llm.py +++ b/models/fastchat_openai_llm.py @@ -245,7 +245,7 @@ if __name__ == "__main__": chain = FastChatOpenAILLMChain() - chain.set_api_key("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") + chain.set_api_key("EMPTY") # chain.set_api_base_url("https://api.openai.com/v1") # chain.call_model_name("gpt-3.5-turbo") diff --git a/models/loader/args.py b/models/loader/args.py index 59668b0..02f4926 100644 --- a/models/loader/args.py +++ b/models/loader/args.py @@ -42,9 +42,10 @@ parser.add_argument('--no-remote-model', action='store_true', help='remote in th 'model to add the ` ' '--no-remote-model`') parser.add_argument('--model-name', type=str, default=LLM_MODEL, help='Name of the model to load by default.') -parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.') +parser.add_argument("--use-lora",type=bool,default=USE_LORA,help="use lora or not") +parser.add_argument('--lora', type=str, default=LORA_NAME,help='Name of the LoRA to apply to the model by default.') parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras") -parser.add_argument('--use-ptuning-v2',type=str,default=False,help="whether use ptuning-v2 checkpoint") +parser.add_argument('--use-ptuning-v2',action=USE_PTUNING_V2,help="whether use ptuning-v2 checkpoint") parser.add_argument("--ptuning-dir",type=str,default=PTUNING_DIR,help="the dir of ptuning-v2 checkpoint") # Accelerate/transformers parser.add_argument('--load-in-8bit', action='store_true', default=LOAD_IN_8BIT, diff --git a/models/loader/loader.py b/models/loader/loader.py index ae4b408..c5c80fd 100644 --- a/models/loader/loader.py +++ b/models/loader/loader.py @@ -148,7 +148,7 @@ class LoaderCheckPoint: trust_remote_code=True).half() # 可传入device_map自定义每张卡的部署情况 if self.device_map is None: - if 'chatglm' in self.model_name.lower(): + if 'chatglm' in self.model_name.lower() and not "chatglm2" in self.model_name.lower(): self.device_map = self.chatglm_auto_configure_device_map(num_gpus) elif 'moss' in self.model_name.lower(): self.device_map = self.moss_auto_configure_device_map(num_gpus, checkpoint) @@ -164,13 +164,6 @@ class LoaderCheckPoint: dtype=torch.float16 if not self.load_in_8bit else torch.int8, max_memory=max_memory, no_split_module_classes=model._no_split_modules) - # 对于chaglm和moss意外的模型应使用自动指定,而非调用chatglm的配置方式 - # 其他模型定义的层类几乎不可能与chatglm和moss一致,使用chatglm_auto_configure_device_map - # 百分百会报错,使用infer_auto_device_map虽然可能导致负载不均衡,但至少不会报错 - # 实测在bloom模型上如此 - # self.device_map = infer_auto_device_map(model, - # dtype=torch.int8, - # no_split_module_classes=model._no_split_modules) model = dispatch_model(model, device_map=self.device_map) else: @@ -184,6 +177,13 @@ class LoaderCheckPoint: ) elif self.is_llamacpp: + + # 要调用llama-cpp模型,如vicuma-13b量化模型需要安装llama-cpp-python库 + # but!!! 实测pip install 不好使,需要手动从ttps://github.com/abetlen/llama-cpp-python/releases/下载 + # 而且注意不同时期的ggml格式并不!兼!容!!!因此需要安装的llama-cpp-python版本也不一致,需要手动测试才能确定 + # 实测ggml-vicuna-13b-1.1在llama-cpp-python 0.1.63上可正常兼容 + # 不过!!!本项目模型加载的方式控制的比较严格,与llama-cpp-python的兼容性较差,很多参数设定不能使用, + # 建议如非必要还是不要使用llama-cpp try: from llama_cpp import Llama @@ -448,7 +448,7 @@ class LoaderCheckPoint: if self.use_ptuning_v2: try: - prefix_encoder_file = open(Path(f'{self.ptuning_dir}/config.json'), 'r') + prefix_encoder_file = open(Path(f'{os.path.abspath(self.ptuning_dir)}/config.json'), 'r') prefix_encoder_config = json.loads(prefix_encoder_file.read()) prefix_encoder_file.close() self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len'] @@ -464,13 +464,14 @@ class LoaderCheckPoint: if self.use_ptuning_v2: try: - prefix_state_dict = torch.load(Path(f'{self.ptuning_dir}/pytorch_model.bin')) + prefix_state_dict = torch.load(Path(f'{os.path.abspath(self.ptuning_dir)}/pytorch_model.bin')) new_prefix_state_dict = {} for k, v in prefix_state_dict.items(): if k.startswith("transformer.prefix_encoder."): new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) self.model.transformer.prefix_encoder.float() + print("加载ptuning检查点成功!") except Exception as e: print(e) print("加载PrefixEncoder模型参数失败") diff --git a/requirements.txt b/requirements.txt index 7f97f67..4c981b2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,13 +23,6 @@ openai #accelerate~=0.18.0 #peft~=0.3.0 #bitsandbytes; platform_system != "Windows" - -# 要调用llama-cpp模型,如vicuma-13b量化模型需要安装llama-cpp-python库 -# but!!! 实测pip install 不好使,需要手动从ttps://github.com/abetlen/llama-cpp-python/releases/下载 -# 而且注意不同时期的ggml格式并不!兼!容!!!因此需要安装的llama-cpp-python版本也不一致,需要手动测试才能确定 -# 实测ggml-vicuna-13b-1.1在llama-cpp-python 0.1.63上可正常兼容 -# 不过!!!本项目模型加载的方式控制的比较严格,与llama-cpp-python的兼容性较差,很多参数设定不能使用, -# 建议如非必要还是不要使用llama-cpp torch~=2.0.0 pydantic~=1.10.7 starlette~=0.26.1 diff --git a/webui.py b/webui.py index b7ff9fd..c7e7880 100644 --- a/webui.py +++ b/webui.py @@ -104,6 +104,7 @@ def init_model(): args_dict = vars(args) shared.loaderCheckPoint = LoaderCheckPoint(args_dict) llm_model_ins = shared.loaderLLM() + llm_model_ins.history_len = LLM_HISTORY_LEN try: local_doc_qa.init_cfg(llm_model=llm_model_ins) answer_result_stream_result = local_doc_qa.llm_model_chain( diff --git a/webui_st.py b/webui_st.py index 541851f..8309210 100644 --- a/webui_st.py +++ b/webui_st.py @@ -1,6 +1,8 @@ import streamlit as st from streamlit_chatbox import st_chatbox import tempfile +from pathlib import Path + ###### 从webui借用的代码 ##### ###### 做了少量修改 ##### import os @@ -23,6 +25,7 @@ def get_vs_list(): if not os.path.exists(KB_ROOT_PATH): return lst_default lst = os.listdir(KB_ROOT_PATH) + lst = [x for x in lst if os.path.isdir(os.path.join(KB_ROOT_PATH, x))] if not lst: return lst_default lst.sort() @@ -100,23 +103,23 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation): - vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store") - filelist = [] - if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")): - os.makedirs(os.path.join(KB_ROOT_PATH, vs_id, "content")) + vs_path = Path(KB_ROOT_PATH) / vs_id / "vector_store" + con_path = Path(KB_ROOT_PATH) / vs_id / "content" + con_path.mkdir(parents=True, exist_ok=True) + qa = st.session_state.local_doc_qa if qa.llm_model_chain and qa.embeddings: + filelist = [] if isinstance(files, list): for file in files: filename = os.path.split(file.name)[-1] - shutil.move(file.name, os.path.join( - KB_ROOT_PATH, vs_id, "content", filename)) - filelist.append(os.path.join( - KB_ROOT_PATH, vs_id, "content", filename)) + target = con_path / filename + shutil.move(file.name, target) + filelist.append(str(target)) vs_path, loaded_files = qa.init_knowledge_vector_store( - filelist, vs_path, sentence_size) + filelist, str(vs_path), sentence_size) else: - vs_path, loaded_files = qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation, + vs_path, loaded_files = qa.one_knowledge_add(str(vs_path), files, one_conent, one_content_segmentation, sentence_size) if len(loaded_files): file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问" @@ -283,20 +286,24 @@ with st.sidebar: def on_new_kb(): name = st.session_state.kb_name - if name in vs_list: - st.error(f'名为“{name}”的知识库已存在。') + if not name: + st.sidebar.error(f'新建知识库名称不能为空!') + elif name in vs_list: + st.sidebar.error(f'名为“{name}”的知识库已存在。') else: - vs_list.append(name) st.session_state.vs_path = name + st.session_state.kb_name = '' + new_kb_dir = os.path.join(KB_ROOT_PATH, name) + if not os.path.exists(new_kb_dir): + os.makedirs(new_kb_dir) + st.sidebar.success(f'名为“{name}”的知识库创建成功,您可以开始添加文件。') def on_vs_change(): chat_box.robot_say(f'已加载知识库: {st.session_state.vs_path}') with st.expander('知识库配置', True): cols = st.columns([12, 10]) kb_name = cols[0].text_input( - '新知识库名称', placeholder='新知识库名称', label_visibility='collapsed') - if 'kb_name' not in st.session_state: - st.session_state.kb_name = kb_name + '新知识库名称', placeholder='新知识库名称', label_visibility='collapsed', key='kb_name') cols[1].button('新建知识库', on_click=on_new_kb) index = 0 try: @@ -317,7 +324,8 @@ with st.sidebar: sentence_size = st.slider('文本入库分句长度限制', 1, 1000, SENTENCE_SIZE) files = st.file_uploader('上传知识文件', ['docx', 'txt', 'md', 'csv', 'xlsx', 'pdf'], - accept_multiple_files=True) + accept_multiple_files=True, + ) if st.button('添加文件到知识库'): temp_dir = tempfile.mkdtemp() file_list = [] @@ -326,8 +334,8 @@ with st.sidebar: with open(file, 'wb') as fp: fp.write(f.getvalue()) file_list.append(TempFile(file)) - _, _, history = get_vector_store( - vs_path, file_list, sentence_size, [], None, None) + _, _, history = get_vector_store( + vs_path, file_list, sentence_size, [], None, None) st.session_state.files = []