diff --git a/.gitignore b/.gitignore
index fe1385d..73ebd21 100644
--- a/.gitignore
+++ b/.gitignore
@@ -175,3 +175,6 @@ embedding/*
pyrightconfig.json
loader/tmp_files
flagged/*
+ptuning-v2/*.json
+ptuning-v2/*.bin
+
diff --git a/README.md b/README.md
index 7369db3..5ee0ade 100644
--- a/README.md
+++ b/README.md
@@ -255,7 +255,7 @@ Web UI 可以实现如下功能:
- [x] VUE 前端
## 项目交流群
-
+
🎉 langchain-ChatGLM 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
diff --git a/README_en.md b/README_en.md
index 79edb91..249962d 100644
--- a/README_en.md
+++ b/README_en.md
@@ -87,6 +87,10 @@ $ conda create -p /your_path/env_name python=3.8
# Activate the environment
$ source activate /your_path/env_name
+# Or create the environment by name instead of by path; in the commands below, replace /your_path/env_name with env_name
+$ conda create -n env_name python=3.8
+$ conda activate env_name # Activate the environment
+
# Deactivate the environment
$ source deactivate /your_path/env_name
diff --git a/api.py b/api.py
index d6e48d6..452313e 100644
--- a/api.py
+++ b/api.py
@@ -9,8 +9,9 @@ import asyncio
import nltk
import pydantic
import uvicorn
-from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket
+from fastapi import Body, Request, FastAPI, File, Form, Query, UploadFile, WebSocket
from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing_extensions import Annotated
from starlette.responses import RedirectResponse
@@ -55,7 +56,7 @@ class ListDocsResponse(BaseResponse):
class ChatMessage(BaseModel):
question: str = pydantic.Field(..., description="Question text")
response: str = pydantic.Field(..., description="Response text")
- history: List[List[str]] = pydantic.Field(..., description="History text")
+ history: List[List[Optional[str]]] = pydantic.Field(..., description="History text")
source_documents: List[str] = pydantic.Field(
..., description="List of source documents and their scores"
)
@@ -303,7 +304,8 @@ async def update_doc(
async def local_doc_chat(
knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
question: str = Body(..., description="Question", example="工伤保险是什么?"),
- history: List[List[str]] = Body(
+    streaming: bool = Body(False, description="Whether to stream the output; defaults to false, and some models may not support it."),
+ history: List[List[Optional[str]]] = Body(
[],
description="History of previous questions and answers",
example=[
@@ -324,27 +326,39 @@ async def local_doc_chat(
source_documents=[],
)
else:
- for resp, history in local_doc_qa.get_knowledge_based_answer(
- query=question, vs_path=vs_path, chat_history=history, streaming=True
- ):
- pass
- source_documents = [
- f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
- f"""相关度:{doc.metadata['score']}\n\n"""
- for inum, doc in enumerate(resp["source_documents"])
- ]
+        if streaming:
+            def generate_answer():
+ last_print_len = 0
+ for resp, next_history in local_doc_qa.get_knowledge_based_answer(
+ query=question, vs_path=vs_path, chat_history=history, streaming=True
+ ):
+ yield resp["result"][last_print_len:]
+ last_print_len=len(resp["result"])
- return ChatMessage(
- question=question,
- response=resp["result"],
- history=history,
- source_documents=source_documents,
- )
+ return StreamingResponse(generate_answer())
+ else:
+ for resp, history in local_doc_qa.get_knowledge_based_answer(
+ query=question, vs_path=vs_path, chat_history=history, streaming=True
+ ):
+ pass
+
+ source_documents = [
+ f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+ f"""相关度:{doc.metadata['score']}\n\n"""
+ for inum, doc in enumerate(resp["source_documents"])
+ ]
+
+ return ChatMessage(
+ question=question,
+ response=resp["result"],
+ history=history,
+ source_documents=source_documents,
+ )
async def bing_search_chat(
question: str = Body(..., description="Question", example="工伤保险是什么?"),
- history: Optional[List[List[str]]] = Body(
+ history: Optional[List[List[Optional[str]]]] = Body(
[],
description="History of previous questions and answers",
example=[
@@ -374,7 +388,8 @@ async def bing_search_chat(
async def chat(
question: str = Body(..., description="Question", example="工伤保险是什么?"),
- history: List[List[str]] = Body(
+    streaming: bool = Body(False, description="Whether to stream the output; defaults to false, and some models may not support it."),
+ history: List[List[Optional[str]]] = Body(
[],
description="History of previous questions and answers",
example=[
@@ -385,6 +400,30 @@ async def chat(
],
),
):
+    if streaming:
+        def generate_answer():
+ last_print_len = 0
+ answer_result_stream_result = local_doc_qa.llm_model_chain(
+ {"prompt": question, "history": history, "streaming": True})
+ for answer_result in answer_result_stream_result['answer_result_stream']:
+ yield answer_result.llm_output["answer"][last_print_len:]
+ last_print_len = len(answer_result.llm_output["answer"])
+
+ return StreamingResponse(generate_answer())
+ else:
+ answer_result_stream_result = local_doc_qa.llm_model_chain(
+ {"prompt": question, "history": history, "streaming": True})
+ for answer_result in answer_result_stream_result['answer_result_stream']:
+ resp = answer_result.llm_output["answer"]
+ history = answer_result.history
+ pass
+
+ return ChatMessage(
+ question=question,
+ response=resp,
+ history=history,
+ source_documents=[],
+ )
answer_result_stream_result = local_doc_qa.llm_model_chain(
{"prompt": question, "history": history, "streaming": True})
@@ -392,7 +431,6 @@ async def chat(
resp = answer_result.llm_output["answer"]
history = answer_result.history
pass
-
return ChatMessage(
question=question,
response=resp,
@@ -545,7 +583,7 @@ if __name__ == "__main__":
parser.add_argument("--ssl_keyfile", type=str)
parser.add_argument("--ssl_certfile", type=str)
# 初始化消息
- args = None
+
args = parser.parse_args()
args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
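
The `streaming` branches added above return a bare `StreamingResponse`, so the answer arrives as incremental text chunks instead of the usual `ChatMessage` JSON. A minimal client-side sketch for consuming it with `requests`; the URL (host, port and the `/chat` route) is an assumption about the deployment and is not shown in this diff:

```python
import requests

# Hypothetical URL: adjust host, port and route to your api.py deployment.
url = "http://localhost:7861/chat"
payload = {
    "question": "工伤保险是什么?",
    "history": [],
    "streaming": True,  # ask the server for incremental output
}

with requests.post(url, json=payload, stream=True) as resp:
    resp.raise_for_status()
    resp.encoding = "utf-8"  # the StreamingResponse sets no charset, so decode explicitly
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)
print()
```

With `streaming` left at its default `False`, the same endpoint still returns the regular `ChatMessage` body.
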
diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
index 2f7d8da..6085dfc 100644
--- a/chains/local_doc_qa.py
+++ b/chains/local_doc_qa.py
@@ -8,7 +8,6 @@ from typing import List
from utils import torch_gc
from tqdm import tqdm
from pypinyin import lazy_pinyin
-from loader import UnstructuredPaddleImageLoader, UnstructuredPaddlePDFLoader
from models.base import (BaseAnswer,
AnswerResult)
from models.loader.args import parser
@@ -59,6 +58,7 @@ def tree(filepath, ignore_dir_names=None, ignore_file_names=None):
def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_TITLE_ENHANCE):
+
if filepath.lower().endswith(".md"):
loader = UnstructuredFileLoader(filepath, mode="elements")
docs = loader.load()
@@ -67,10 +67,14 @@ def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_T
textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
docs = loader.load_and_split(textsplitter)
elif filepath.lower().endswith(".pdf"):
+        # For now the paddle-based loader is imported lazily, so protobuf 4.x can be used as long as no pdf/image knowledge files are uploaded
+ from loader import UnstructuredPaddlePDFLoader
loader = UnstructuredPaddlePDFLoader(filepath)
textsplitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size)
docs = loader.load_and_split(textsplitter)
elif filepath.lower().endswith(".jpg") or filepath.lower().endswith(".png"):
+        # For now the paddle-based loader is imported lazily, so protobuf 4.x can be used as long as no pdf/image knowledge files are uploaded
+ from loader import UnstructuredPaddleImageLoader
loader = UnstructuredPaddleImageLoader(filepath, mode="elements")
textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
docs = loader.load_and_split(text_splitter=textsplitter)
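
The deferred imports above keep paddle (and its protobuf 3.x pin) out of the import path until a PDF or image is actually processed. A generic sketch of the same lazy-import pattern with an explicit error message; the loader module path mirrors this repo, but the wrapper function itself is illustrative:

```python
def load_pdf_docs(filepath: str):
    # Import the heavy optional dependency only when it is really needed,
    # so environments running protobuf 4.x still work for non-PDF files.
    try:
        from loader import UnstructuredPaddlePDFLoader
    except ImportError as exc:
        raise RuntimeError(
            "PDF parsing needs the paddle-based loaders; "
            "install paddlepaddle/paddleocr or skip PDF knowledge files."
        ) from exc
    return UnstructuredPaddlePDFLoader(filepath).load()
```
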
diff --git a/configs/model_config.py b/configs/model_config.py
index c6fb964..a95d5bf 100644
--- a/configs/model_config.py
+++ b/configs/model_config.py
@@ -31,7 +31,7 @@ EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backe
# llm_model_dict 处理了loader的一些预设行为,如加载位置,模型名称,模型处理器实例
# 在以下字典中修改属性值,以指定本地 LLM 模型存储位置
# 如将 "chatglm-6b" 的 "local_model_path" 由 None 修改为 "User/Downloads/chatglm-6b"
-# 此处请写绝对路径
+# Use an absolute path here, and the path must contain the model name from the repo id, because FastChat matches models by name
llm_model_dict = {
"chatglm-6b-int4-qe": {
"name": "chatglm-6b-int4-qe",
@@ -154,6 +154,15 @@ llm_model_dict = {
"provides": "FastChatOpenAILLMChain", # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain"
"api_base_url": "http://localhost:8000/v1", # "name"修改为fastchat服务中的"api_base_url"
"api_key": "EMPTY"
+ },
+    # Models served through fastchat should follow the format below
+ "fastchat-chatglm-6b-int4": {
+ "name": "chatglm-6b-int4", # "name"修改为fastchat服务中的"model_name"
+ "pretrained_model_name": "chatglm-6b-int4",
+ "local_model_path": None,
+ "provides": "FastChatOpenAILLMChain", # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain"
+ "api_base_url": "http://localhost:8001/v1", # "name"修改为fastchat服务中的"api_base_url"
+ "api_key": "EMPTY"
},
"fastchat-chatglm2-6b": {
"name": "chatglm2-6b", # "name"修改为fastchat服务中的"model_name"
@@ -175,11 +184,13 @@ llm_model_dict = {
# 调用chatgpt时如果报出: urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443):
# Max retries exceeded with url: /v1/chat/completions
# 则需要将urllib3版本修改为1.25.11
+    # If urllib3.exceptions.MaxRetryError: HTTPSConnectionPool is still raised, change https to http
+    # see https://zhuanlan.zhihu.com/p/350015032
# 如果报出:raise NewConnectionError(
# urllib3.exceptions.NewConnectionError: :
# Failed to establish a new connection: [WinError 10060]
- # 则是因为内地和香港的IP都被OPENAI封了,需要挂切换为日本、新加坡等地
+ # 则是因为内地和香港的IP都被OPENAI封了,需要切换为日本、新加坡等地
"openai-chatgpt-3.5": {
"name": "gpt-3.5-turbo",
"pretrained_model_name": "gpt-3.5-turbo",
@@ -200,20 +211,17 @@ BF16 = False
# 本地lora存放的位置
LORA_DIR = "loras/"
-# LLM lora path,默认为空,如果有请直接指定文件夹路径
-LLM_LORA_PATH = ""
-USE_LORA = True if LLM_LORA_PATH else False
+# Name(s) of the LoRA to apply; use a list if there is more than one
+
+LORA_NAME = ""
+USE_LORA = True if LORA_NAME else False
# LLM streaming reponse
STREAMING = True
# Use p-tuning-v2 PrefixEncoder
USE_PTUNING_V2 = False
-PTUNING_DIR='./ptuing-v2'
-<<<<<<< HEAD
-=======
-
->>>>>>> f68d347c25b4bdd07f293c65a6e44a673a11f614
+PTUNING_DIR='./ptuning-v2'
# LLM running device
LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
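
The new `fastchat-chatglm-6b-int4` entry expects a FastChat OpenAI-compatible server exposing a model named `chatglm-6b-int4` at `http://localhost:8001/v1`. A quick sanity check of that endpoint, independent of this project, assuming the pre-1.0 `openai` package used in this era of the repo:

```python
import openai

openai.api_base = "http://localhost:8001/v1"  # must match "api_base_url" in the entry above
openai.api_key = "EMPTY"                      # FastChat ignores the key, but one must be set

resp = openai.ChatCompletion.create(
    model="chatglm-6b-int4",                  # must match "name" in the entry above
    messages=[{"role": "user", "content": "你好"}],
)
print(resp["choices"][0]["message"]["content"])
```

If this call succeeds, the `FastChatOpenAILLMChain` provider should be able to reach the same server.
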
diff --git a/docs/FAQ.md b/docs/FAQ.md
index 2162671..ccc0f25 100644
--- a/docs/FAQ.md
+++ b/docs/FAQ.md
@@ -180,9 +180,9 @@ Q14 调用api中的 `bing_search_chat`接口时,报出 `Failed to establish a
---
-Q15 加载chatglm-6b-int8或chatglm-6b-int4抛出`RuntimeError: Only Tensors of floating point andcomplex dtype can require gradients`
+Q15 加载chatglm-6b-int8或chatglm-6b-int4抛出 `RuntimeError: Only Tensors of floating point and complex dtype can require gradients`
-疑为chatglm的quantization的问题或torch版本差异问题,针对已经变为Parameter的torch.zeros矩阵也执行Parameter操作,从而抛出`RuntimeError: Only Tensors of floating point andcomplex dtype can require gradients`。解决办法是在chatglm-项目的原始文件中的quantization.py文件374行改为:
+疑为chatglm的quantization的问题或torch版本差异问题,针对已经变为Parameter的torch.zeros矩阵也执行Parameter操作,从而抛出 `RuntimeError: Only Tensors of floating point and complex dtype can require gradients`。解决办法是在chatglm项目的原始文件中的quantization.py文件374行改为:
```
try:
@@ -191,6 +191,8 @@ Q15 加载chatglm-6b-int8或chatglm-6b-int4抛出`RuntimeError: Only Tensors of
pass
```
- 注:虽然模型可以顺利加载但在cpu上仍存在推理失败的可能:即针对每个问题,模型一直输出gugugugu。
+ 如果上述方式不起作用,则在.cache/huggingface/modules/目录下针对chatglm项目的原始文件中的quantization.py文件执行上述操作,若软链接不止一个,按照错误提示选择正确的路径。
- 因此,最好不要试图用cpu加载量化模型,原因可能是目前python主流量化包的量化操作是在gpu上执行的,会天然地存在gap。
\ No newline at end of file
+注:虽然模型可以顺利加载但在cpu上仍存在推理失败的可能:即针对每个问题,模型一直输出gugugugu。
+
+ 因此,最好不要试图用cpu加载量化模型,原因可能是目前python主流量化包的量化操作是在gpu上执行的,会天然地存在gap。
diff --git a/docs/INSTALL.md b/docs/INSTALL.md
index 6602973..cd3c51d 100644
--- a/docs/INSTALL.md
+++ b/docs/INSTALL.md
@@ -12,6 +12,12 @@ $ conda create -p /your_path/env_name python=3.8
# 激活环境
$ source activate /your_path/env_name
+
+# 或者不指定路径,直接按环境名创建;注意将以下命令中的 /your_path/env_name 替换为 env_name
+$ conda create -n env_name python=3.8
+$ conda activate env_name # Activate the environment
+
+# 升级 pip
$ pip3 install --upgrade pip
# 关闭环境
@@ -49,7 +55,7 @@ $ python loader/image_loader.py
## llama-cpp模型调用的说明
-1. 首先从huggingface hub中下载对应的模型,如https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/的[ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin),建议使用huggingface_hub库的snapshot_download下载。
+1. 首先从huggingface hub中下载对应的模型,如 [https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/) 的 [ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin),建议使用huggingface_hub库的snapshot_download下载。
2. 将下载的模型重命名。通过huggingface_hub下载的模型会被重命名为随机序列,因此需要重命名为原始文件名,如[ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin)。
3. 基于下载模型的ggml的加载时间,推测对应的llama-cpp版本,下载对应的llama-cpp-python库的wheel文件,实测[ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin)与llama-cpp-python库兼容,然后手动安装wheel文件。
4. 将下载的模型信息写入configs/model_config.py文件里 `llm_model_dict`中,注意保证参数的兼容性,一些参数组合可能会报错.
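
Steps 1–2 above can be scripted with `huggingface_hub`; a minimal sketch using `hf_hub_download` (a single-file alternative to the `snapshot_download` suggested above), with the destination path chosen purely for illustration:

```python
import shutil
from huggingface_hub import hf_hub_download

# Step 1: fetch the quantized ggml file; huggingface_hub stores it under a cache path.
cached = hf_hub_download(
    repo_id="vicuna/ggml-vicuna-13b-1.1",
    filename="ggml-vic13b-q5_1.bin",
)

# Step 2: restore the original filename so llm_model_dict / llama-cpp can find it.
shutil.copy(cached, "models/ggml-vic13b-q5_1.bin")
print("saved to models/ggml-vic13b-q5_1.bin")
```
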
diff --git a/img/qr_code_43.jpg b/img/qr_code_43.jpg
deleted file mode 100644
index 8bbdcbd..0000000
Binary files a/img/qr_code_43.jpg and /dev/null differ
diff --git a/img/qr_code_45.jpg b/img/qr_code_45.jpg
new file mode 100644
index 0000000..ad253c8
Binary files /dev/null and b/img/qr_code_45.jpg differ
diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py
index 81878ce..0d19ee6 100644
--- a/models/chatglm_llm.py
+++ b/models/chatglm_llm.py
@@ -2,14 +2,14 @@ from abc import ABC
from langchain.chains.base import Chain
from typing import Any, Dict, List, Optional, Generator
from langchain.callbacks.manager import CallbackManagerForChainRun
-from transformers.generation.logits_process import LogitsProcessor
-from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
+# from transformers.generation.logits_process import LogitsProcessor
+# from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
AnswerResult,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
-import torch
+# import torch
import transformers
diff --git a/models/fastchat_openai_llm.py b/models/fastchat_openai_llm.py
index d0972f7..217910a 100644
--- a/models/fastchat_openai_llm.py
+++ b/models/fastchat_openai_llm.py
@@ -245,7 +245,7 @@ if __name__ == "__main__":
chain = FastChatOpenAILLMChain()
- chain.set_api_key("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")
+ chain.set_api_key("EMPTY")
# chain.set_api_base_url("https://api.openai.com/v1")
# chain.call_model_name("gpt-3.5-turbo")
diff --git a/models/loader/args.py b/models/loader/args.py
index 59668b0..02f4926 100644
--- a/models/loader/args.py
+++ b/models/loader/args.py
@@ -42,9 +42,10 @@ parser.add_argument('--no-remote-model', action='store_true', help='remote in th
'model to add the ` '
'--no-remote-model`')
parser.add_argument('--model-name', type=str, default=LLM_MODEL, help='Name of the model to load by default.')
-parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
+parser.add_argument('--use-lora', action='store_true', default=USE_LORA, help='use LoRA weights or not')
+parser.add_argument('--lora', type=str, default=LORA_NAME, help='Name of the LoRA to apply to the model by default.')
parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras")
-parser.add_argument('--use-ptuning-v2',type=str,default=False,help="whether use ptuning-v2 checkpoint")
+parser.add_argument('--use-ptuning-v2', action='store_true', default=USE_PTUNING_V2, help="whether use ptuning-v2 checkpoint")
parser.add_argument("--ptuning-dir",type=str,default=PTUNING_DIR,help="the dir of ptuning-v2 checkpoint")
# Accelerate/transformers
parser.add_argument('--load-in-8bit', action='store_true', default=LOAD_IN_8BIT,
diff --git a/models/loader/loader.py b/models/loader/loader.py
index ae4b408..c5c80fd 100644
--- a/models/loader/loader.py
+++ b/models/loader/loader.py
@@ -148,7 +148,7 @@ class LoaderCheckPoint:
trust_remote_code=True).half()
# 可传入device_map自定义每张卡的部署情况
if self.device_map is None:
- if 'chatglm' in self.model_name.lower():
+            if 'chatglm' in self.model_name.lower() and 'chatglm2' not in self.model_name.lower():
self.device_map = self.chatglm_auto_configure_device_map(num_gpus)
elif 'moss' in self.model_name.lower():
self.device_map = self.moss_auto_configure_device_map(num_gpus, checkpoint)
@@ -164,13 +164,6 @@ class LoaderCheckPoint:
dtype=torch.float16 if not self.load_in_8bit else torch.int8,
max_memory=max_memory,
no_split_module_classes=model._no_split_modules)
- # 对于chaglm和moss意外的模型应使用自动指定,而非调用chatglm的配置方式
- # 其他模型定义的层类几乎不可能与chatglm和moss一致,使用chatglm_auto_configure_device_map
- # 百分百会报错,使用infer_auto_device_map虽然可能导致负载不均衡,但至少不会报错
- # 实测在bloom模型上如此
- # self.device_map = infer_auto_device_map(model,
- # dtype=torch.int8,
- # no_split_module_classes=model._no_split_modules)
model = dispatch_model(model, device_map=self.device_map)
else:
@@ -184,6 +177,13 @@ class LoaderCheckPoint:
)
elif self.is_llamacpp:
+
+            # Running a llama-cpp model (e.g. a quantized vicuna-13b) requires the llama-cpp-python package.
+            # In practice `pip install` often fails; download a wheel manually from
+            # https://github.com/abetlen/llama-cpp-python/releases/ instead.
+            # Note that ggml formats from different periods are NOT compatible, so the required
+            # llama-cpp-python version differs per model and has to be found by testing;
+            # ggml-vicuna-13b-1.1 is known to work with llama-cpp-python 0.1.63.
+            # This project controls model loading fairly strictly, so compatibility with llama-cpp-python
+            # is limited and many parameter settings cannot be used; avoid llama-cpp unless necessary.
try:
from llama_cpp import Llama
@@ -448,7 +448,7 @@ class LoaderCheckPoint:
if self.use_ptuning_v2:
try:
- prefix_encoder_file = open(Path(f'{self.ptuning_dir}/config.json'), 'r')
+ prefix_encoder_file = open(Path(f'{os.path.abspath(self.ptuning_dir)}/config.json'), 'r')
prefix_encoder_config = json.loads(prefix_encoder_file.read())
prefix_encoder_file.close()
self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
@@ -464,13 +464,14 @@ class LoaderCheckPoint:
if self.use_ptuning_v2:
try:
- prefix_state_dict = torch.load(Path(f'{self.ptuning_dir}/pytorch_model.bin'))
+ prefix_state_dict = torch.load(Path(f'{os.path.abspath(self.ptuning_dir)}/pytorch_model.bin'))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
if k.startswith("transformer.prefix_encoder."):
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
self.model.transformer.prefix_encoder.float()
+ print("加载ptuning检查点成功!")
except Exception as e:
print(e)
print("加载PrefixEncoder模型参数失败")
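
As the loader code above shows, a p-tuning-v2 checkpoint directory must contain a `config.json` with `pre_seq_len` and a `pytorch_model.bin` whose keys start with `transformer.prefix_encoder.`. A small standalone sketch for validating `PTUNING_DIR` before starting the service; the helper name and default path are illustrative:

```python
import json
import os
import torch

def check_ptuning_dir(ptuning_dir: str = "./ptuning-v2") -> None:
    ptuning_dir = os.path.abspath(ptuning_dir)

    # config.json must carry the prefix length used at training time.
    with open(os.path.join(ptuning_dir, "config.json"), "r") as f:
        config = json.load(f)
    assert "pre_seq_len" in config, "config.json is missing pre_seq_len"

    # pytorch_model.bin must contain the prefix-encoder weights the loader strips and loads.
    state_dict = torch.load(os.path.join(ptuning_dir, "pytorch_model.bin"), map_location="cpu")
    prefix_keys = [k for k in state_dict if k.startswith("transformer.prefix_encoder.")]
    assert prefix_keys, "no transformer.prefix_encoder.* weights found"

    print(f"ptuning-v2 checkpoint looks usable: pre_seq_len={config['pre_seq_len']}, "
          f"{len(prefix_keys)} prefix tensors")

if __name__ == "__main__":
    check_ptuning_dir()
```
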
diff --git a/requirements.txt b/requirements.txt
index 7f97f67..4c981b2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,13 +23,6 @@ openai
#accelerate~=0.18.0
#peft~=0.3.0
#bitsandbytes; platform_system != "Windows"
-
-# 要调用llama-cpp模型,如vicuma-13b量化模型需要安装llama-cpp-python库
-# but!!! 实测pip install 不好使,需要手动从ttps://github.com/abetlen/llama-cpp-python/releases/下载
-# 而且注意不同时期的ggml格式并不!兼!容!!!因此需要安装的llama-cpp-python版本也不一致,需要手动测试才能确定
-# 实测ggml-vicuna-13b-1.1在llama-cpp-python 0.1.63上可正常兼容
-# 不过!!!本项目模型加载的方式控制的比较严格,与llama-cpp-python的兼容性较差,很多参数设定不能使用,
-# 建议如非必要还是不要使用llama-cpp
torch~=2.0.0
pydantic~=1.10.7
starlette~=0.26.1
diff --git a/webui.py b/webui.py
index b7ff9fd..c7e7880 100644
--- a/webui.py
+++ b/webui.py
@@ -104,6 +104,7 @@ def init_model():
args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM()
+ llm_model_ins.history_len = LLM_HISTORY_LEN
try:
local_doc_qa.init_cfg(llm_model=llm_model_ins)
answer_result_stream_result = local_doc_qa.llm_model_chain(
diff --git a/webui_st.py b/webui_st.py
index 541851f..8309210 100644
--- a/webui_st.py
+++ b/webui_st.py
@@ -1,6 +1,8 @@
import streamlit as st
from streamlit_chatbox import st_chatbox
import tempfile
+from pathlib import Path
+
###### 从webui借用的代码 #####
###### 做了少量修改 #####
import os
@@ -23,6 +25,7 @@ def get_vs_list():
if not os.path.exists(KB_ROOT_PATH):
return lst_default
lst = os.listdir(KB_ROOT_PATH)
+ lst = [x for x in lst if os.path.isdir(os.path.join(KB_ROOT_PATH, x))]
if not lst:
return lst_default
lst.sort()
@@ -100,23 +103,23 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
- vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
- filelist = []
- if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")):
- os.makedirs(os.path.join(KB_ROOT_PATH, vs_id, "content"))
+ vs_path = Path(KB_ROOT_PATH) / vs_id / "vector_store"
+ con_path = Path(KB_ROOT_PATH) / vs_id / "content"
+ con_path.mkdir(parents=True, exist_ok=True)
+
qa = st.session_state.local_doc_qa
if qa.llm_model_chain and qa.embeddings:
+ filelist = []
if isinstance(files, list):
for file in files:
filename = os.path.split(file.name)[-1]
- shutil.move(file.name, os.path.join(
- KB_ROOT_PATH, vs_id, "content", filename))
- filelist.append(os.path.join(
- KB_ROOT_PATH, vs_id, "content", filename))
+ target = con_path / filename
+ shutil.move(file.name, target)
+ filelist.append(str(target))
vs_path, loaded_files = qa.init_knowledge_vector_store(
- filelist, vs_path, sentence_size)
+ filelist, str(vs_path), sentence_size)
else:
- vs_path, loaded_files = qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation,
+ vs_path, loaded_files = qa.one_knowledge_add(str(vs_path), files, one_conent, one_content_segmentation,
sentence_size)
if len(loaded_files):
file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问"
@@ -283,20 +286,24 @@ with st.sidebar:
def on_new_kb():
name = st.session_state.kb_name
- if name in vs_list:
- st.error(f'名为“{name}”的知识库已存在。')
+ if not name:
+        st.sidebar.error('新建知识库名称不能为空!')
+ elif name in vs_list:
+ st.sidebar.error(f'名为“{name}”的知识库已存在。')
else:
- vs_list.append(name)
st.session_state.vs_path = name
+ st.session_state.kb_name = ''
+ new_kb_dir = os.path.join(KB_ROOT_PATH, name)
+ if not os.path.exists(new_kb_dir):
+ os.makedirs(new_kb_dir)
+ st.sidebar.success(f'名为“{name}”的知识库创建成功,您可以开始添加文件。')
def on_vs_change():
chat_box.robot_say(f'已加载知识库: {st.session_state.vs_path}')
with st.expander('知识库配置', True):
cols = st.columns([12, 10])
kb_name = cols[0].text_input(
- '新知识库名称', placeholder='新知识库名称', label_visibility='collapsed')
- if 'kb_name' not in st.session_state:
- st.session_state.kb_name = kb_name
+ '新知识库名称', placeholder='新知识库名称', label_visibility='collapsed', key='kb_name')
cols[1].button('新建知识库', on_click=on_new_kb)
index = 0
try:
@@ -317,7 +324,8 @@ with st.sidebar:
sentence_size = st.slider('文本入库分句长度限制', 1, 1000, SENTENCE_SIZE)
files = st.file_uploader('上传知识文件',
['docx', 'txt', 'md', 'csv', 'xlsx', 'pdf'],
- accept_multiple_files=True)
+ accept_multiple_files=True,
+ )
if st.button('添加文件到知识库'):
temp_dir = tempfile.mkdtemp()
file_list = []
@@ -326,8 +334,8 @@ with st.sidebar:
with open(file, 'wb') as fp:
fp.write(f.getvalue())
file_list.append(TempFile(file))
- _, _, history = get_vector_store(
- vs_path, file_list, sentence_size, [], None, None)
+ _, _, history = get_vector_store(
+ vs_path, file_list, sentence_size, [], None, None)
st.session_state.files = []