Merge branch 'master' into branch-0.2.9

wvivi2023 2024-01-17 11:05:10 +08:00
commit c3321a6a34
18 changed files with 148 additions and 132 deletions

View File

@@ -42,7 +42,7 @@
 🚩 本项目未涉及微调、训练过程,但可利用微调或训练对本项目效果进行优化。
-🌐 [AutoDL 镜像](https://www.codewithgpu.com/i/chatchat-space/Langchain-Chatchat/Langchain-Chatchat) 中 `v11` 版本所使用代码已更新至本项目 `v0.2.7` 版本。
+🌐 [AutoDL 镜像](https://www.codewithgpu.com/i/chatchat-space/Langchain-Chatchat/Langchain-Chatchat) 中 `v13` 版本所使用代码已更新至本项目 `v0.2.9` 版本。
 🐳 [Docker 镜像](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.6) 已经更新到 ```0.2.7``` 版本。
@@ -67,10 +67,10 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
 ### 1. 环境配置
-+ 首先,确保你的机器安装了 Python 3.8 - 3.10
++ 首先,确保你的机器安装了 Python 3.8 - 3.11
 ```
 $ python --version
-Python 3.10.12
+Python 3.11.7
 ```
 接着,创建一个虚拟环境,并在虚拟环境内安装项目的依赖
 ```shell
@@ -88,6 +88,7 @@ $ pip install -r requirements_webui.txt
 # 默认依赖包括基本运行环境FAISS向量库。如果要使用 milvus/pg_vector 等向量库,请将 requirements.txt 中相应依赖取消注释再安装。
 ```
+请注意:LangChain-Chatchat `0.2.x` 系列是针对 Langchain `0.0.x` 系列版本的,如果你使用的是 Langchain `0.1.x` 系列版本,需要降级。
 ### 2 模型下载
 如需在本地或离线环境下运行本项目,需要首先将项目所需的模型下载至本地,通常开源 LLM 与 Embedding 模型可以从 [HuggingFace](https://huggingface.co/models) 下载。
@@ -148,7 +149,7 @@ $ python startup.py -a
 [![Telegram](https://img.shields.io/badge/Telegram-2CA5E0?style=for-the-badge&logo=telegram&logoColor=white "langchain-chatglm")](https://t.me/+RjliQ3jnJ1YyN2E9)
 ### 项目交流群
-<img src="img/qr_code_83.jpg" alt="二维码" width="300" />
+<img src="img/qr_code_86.jpg" alt="二维码" width="300" />
 🎉 Langchain-Chatchat 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
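The compatibility warning added in this hunk can also be enforced at startup. A minimal sketch of such a guard, assuming `langchain.__version__` is exposed (it is in the 0.0.x series) and that the `packaging` library is installed; neither the check nor the helper is part of this commit:

```python
import langchain
from packaging.version import Version

# Langchain-Chatchat 0.2.x targets the Langchain 0.0.x series; refuse to run otherwise.
if Version(langchain.__version__) >= Version("0.1.0"):
    raise RuntimeError(
        "Langchain 0.1.x detected; downgrade, e.g. pip install 'langchain==0.0.354'"
    )
```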

View File

@@ -55,10 +55,10 @@ The main process analysis from the aspect of document process:
 🚩 The training or fine-tuning are not involved in the project, but still, one always can improve performance by do
 these.
-🌐 [AutoDL image](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.5) is supported, and in v9 the codes are update
-to v0.2.5.
+🌐 [AutoDL image](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.5) is supported, and in v13 the codes are update
+to v0.2.9.
-🐳 [Docker image](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.5)
+🐳 [Docker image](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.7)
 ## Pain Points Addressed
@@ -98,6 +98,7 @@ $ pip install -r requirements_webui.txt
 # 默认依赖包括基本运行环境FAISS向量库。如果要使用 milvus/pg_vector 等向量库,请将 requirements.txt 中相应依赖取消注释再安装。
 ```
+Please note that the LangChain-Chachat `0.2.x` series is for the Langchain `0.0.x` series version. If you are using the Langchain `0.1.x` series version, you need to downgrade.
 ### Model Download

Binary file not shown (image deleted, Before: 266 KiB)

Binary file not shown (image deleted, Before: 218 KiB)

BIN img/qr_code_85.jpg (new file, binary not shown, After: 272 KiB)

BIN img/qr_code_86.jpg (new file, binary not shown, After: 195 KiB)

View File

@@ -6,7 +6,6 @@ from configs.model_config import NLTK_DATA_PATH, EMBEDDING_MODEL
 import nltk
 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 from datetime import datetime
-import sys
 if __name__ == "__main__":
@@ -50,11 +49,11 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "-i",
-        "--increament",
+        "--increment",
         action="store_true",
         help=('''
         update vector store for files exist in local folder and not exist in database.
-        use this option if you want to create vectors increamentally.
+        use this option if you want to create vectors incrementally.
         '''
         )
     )
@@ -100,7 +99,7 @@ if __name__ == "__main__":
     if args.clear_tables:
         reset_tables()
-        print("database talbes reseted")
+        print("database tables reset")
     if args.recreate_vs:
         create_tables()
@@ -110,8 +109,8 @@ if __name__ == "__main__":
         import_from_db(args.import_db)
     elif args.update_in_db:
         folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model)
-    elif args.increament:
-        folder2db(kb_names=args.kb_name, mode="increament", embed_model=args.embed_model)
+    elif args.increment:
+        folder2db(kb_names=args.kb_name, mode="increment", embed_model=args.embed_model)
     elif args.prune_db:
         prune_db_docs(args.kb_name)
     elif args.prune_folder:
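The rename matters beyond spelling: argparse derives the attribute name from the long option, so `args.increament` ceases to exist once the flag is corrected, which is why the dispatch below the parser changes in the same hunk. A minimal, self-contained sketch of the corrected option (help text shortened here):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "-i",
    "--increment",      # previously the misspelled "--increament"
    action="store_true",
    help="vectorize only local files that are not yet in the database",
)

args = parser.parse_args(["-i"])
assert args.increment   # the attribute follows the corrected long-option name
```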

View File

@@ -1,6 +1,5 @@
 # API requirements
-# On Windows system, install the cuda version manually from https://pytorch.org/
 torch~=2.1.2
 torchvision~=0.16.2
 torchaudio~=2.1.2
@@ -8,30 +7,30 @@ xformers==0.0.23.post1
 transformers==4.36.2
 sentence_transformers==2.2.2
-langchain==0.0.352
+langchain==0.0.354
 langchain-experimental==0.0.47
 pydantic==1.10.13
 fschat==0.2.34
-openai~=1.6.0
+openai~=1.7.1
-fastapi>=0.105
+fastapi~=0.108.0
-sse_starlette
+sse_starlette==1.8.2
 nltk>=3.8.1
 uvicorn>=0.24.0.post1
-starlette~=0.27.0
+starlette~=0.32.0
 unstructured[all-docs]==0.11.0
 python-magic-bin; sys_platform == 'win32'
 SQLAlchemy==2.0.19
 faiss-cpu~=1.7.4 # `conda install faiss-gpu -c conda-forge` if you want to accelerate with gpus
-accelerate==0.24.1
+accelerate~=0.24.1
 spacy~=3.7.2
 PyMuPDF~=1.23.8
 rapidocr_onnxruntime==1.3.8
-requests>=2.31.0
+requests~=2.31.0
-pathlib>=1.0.1
+pathlib~=1.0.1
-pytest>=7.4.3
+pytest~=7.4.3
-numexpr>=2.8.6 # max version for py38
+numexpr~=2.8.6 # max version for py38
-strsimpy>=0.2.1
+strsimpy~=0.2.1
-markdownify>=0.11.6
+markdownify~=0.11.6
 tiktoken~=0.5.2
 tqdm>=4.66.1
 websockets>=12.0
@@ -46,15 +45,14 @@ llama-index
 # optional document loaders
 # rapidocr_paddle[gpu]>=1.3.0.post5 # gpu accelleration for ocr of pdf and image files
-jq>=1.6.0 # for .json and .jsonl files. suggest `conda install jq` on windows
+jq==1.6.0 # for .json and .jsonl files. suggest `conda install jq` on windows
-# html2text # for .enex files
 beautifulsoup4~=4.12.2 # for .mhtml files
 pysrt~=1.1.2
 # Online api libs dependencies
-zhipuai>=1.0.7, <=2.0.0 # zhipu
+zhipuai==1.0.7 # zhipu
-dashscope>=1.13.6 # qwen
+dashscope==1.13.6 # qwen
 # volcengine>=1.0.119 # fangzhou
 # uncomment libs if you want to use corresponding vector store
@@ -64,14 +62,14 @@ dashscope>=1.13.6 # qwen
 # Agent and Search Tools
-arxiv>=2.0.0
+arxiv~=2.1.0
-youtube-search>=2.1.2
+youtube-search~=2.1.2
-duckduckgo-search>=3.9.9
+duckduckgo-search~=3.9.9
-metaphor-python>=0.1.23
+metaphor-python~=0.1.23
 # WebUI requirements
-streamlit~=1.29.0 # do remember to add streamlit to environment variables if you use windows
+streamlit~=1.29.0
 streamlit-option-menu>=0.3.6
 streamlit-chatbox==1.1.11
 streamlit-modal>=0.1.0
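Many pins above move from `>=` to `~=`. For reference, `~=` is PEP 440's compatible-release operator: `~=2.31.0` means `>=2.31.0, ==2.31.*`, so patch releases are accepted but the next minor version is not. A small check using the `packaging` library (not a dependency touched by this commit, used here only for illustration):

```python
from packaging.specifiers import SpecifierSet

spec = SpecifierSet("~=2.31.0")   # equivalent to ">=2.31.0, ==2.31.*"
print("2.31.5" in spec)           # True  - patch upgrades still satisfy the pin
print("2.32.0" in spec)           # False - the next minor release is excluded
```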

View File

@@ -1,6 +1,3 @@
-# API requirements
-# On Windows system, install the cuda version manually from https://pytorch.org/
 torch~=2.1.2
 torchvision~=0.16.2
 torchaudio~=2.1.2
@@ -8,52 +5,52 @@ xformers==0.0.23.post1
 transformers==4.36.2
 sentence_transformers==2.2.2
-langchain==0.0.352
+langchain==0.0.354
 langchain-experimental==0.0.47
 pydantic==1.10.13
 fschat==0.2.34
-openai~=1.6.0
+openai~=1.7.1
-fastapi>=0.105
+fastapi~=0.108.0
-sse_starlette
+sse_starlette==1.8.2
 nltk>=3.8.1
 uvicorn>=0.24.0.post1
-starlette~=0.27.0
+starlette~=0.32.0
 unstructured[all-docs]==0.11.0
 python-magic-bin; sys_platform == 'win32'
 SQLAlchemy==2.0.19
 faiss-cpu~=1.7.4 # `conda install faiss-gpu -c conda-forge` if you want to accelerate with gpus
-accelerate==0.24.1
+accelerate~=0.24.1
 spacy~=3.7.2
 PyMuPDF~=1.23.8
-rapidocr_onnxruntime~=1.3.8
+rapidocr_onnxruntime==1.3.8
-requests>=2.31.0
+requests~=2.31.0
-pathlib>=1.0.1
+pathlib~=1.0.1
-pytest>=7.4.3
+pytest~=7.4.3
-numexpr>=2.8.6 # max version for py38
+numexpr~=2.8.6 # max version for py38
-strsimpy>=0.2.1
+strsimpy~=0.2.1
-markdownify>=0.11.6
+markdownify~=0.11.6
 tiktoken~=0.5.2
 tqdm>=4.66.1
 websockets>=12.0
-numpy~=1.26.2
+numpy~=1.24.4
-pandas~=2.1.4
+pandas~=2.0.3
 einops>=0.7.0
 transformers_stream_generator==0.0.4
 vllm==0.2.6; sys_platform == "linux"
-httpx[brotli,http2,socks]~=0.25.2
+httpx[brotli,http2,socks]==0.25.2
-llama-index
 # optional document loaders
-rapidocr_paddle[gpu]>=1.3.0.post5 # gpu accelleration for ocr of pdf and image files
+# rapidocr_paddle[gpu]>=1.3.0.post5 # gpu accelleration for ocr of pdf and image files
-jq>=1.6.0 # for .json and .jsonl files. suggest `conda install jq` on windows
+jq==1.6.0 # for .json and .jsonl files. suggest `conda install jq` on windows
-# html2text # for .enex files
 beautifulsoup4~=4.12.2 # for .mhtml files
 pysrt~=1.1.2
 # Online api libs dependencies
-zhipuai>=1.0.7, <=2.0.0 # zhipu
+zhipuai==1.0.7 # zhipu
-dashscope>=1.13.6 # qwen
+dashscope==1.13.6 # qwen
 # volcengine>=1.0.119 # fangzhou
 # uncomment libs if you want to use corresponding vector store

View File

@@ -1,60 +1,44 @@
 # API requirements
-# On Windows system, install the cuda version manually from https://pytorch.org/
-# torch~=2.1.2
-# torchvision~=0.16.2
-# torchaudio~=2.1.2
-# xformers==0.0.23.post1
-# transformers==4.36.2
-# sentence_transformers==2.2.2
-langchain==0.0.352
+langchain==0.0.354
 langchain-experimental==0.0.47
 pydantic==1.10.13
 fschat==0.2.34
-openai~=1.6.0
+openai~=1.7.1
-fastapi>=0.105
+fastapi~=0.108.0
-sse_starlette
+sse_starlette==1.8.2
 nltk>=3.8.1
 uvicorn>=0.24.0.post1
-starlette~=0.27.0
+starlette~=0.32.0
-unstructured[docx,csv]==0.11.0 # add pdf if need
+unstructured[all-docs]==0.11.0
 python-magic-bin; sys_platform == 'win32'
 SQLAlchemy==2.0.19
-faiss-cpu~=1.7.4 # `conda install faiss-gpu -c conda-forge` if you want to accelerate with gpus
+faiss-cpu~=1.7.4
-# accelerate==0.24.1
-# spacy~=3.7.2
-# PyMuPDF~=1.23.8
-# rapidocr_onnxruntime~=1.3.8
-requests>=2.31.0
-pathlib>=1.0.1
-pytest>=7.4.3
-numexpr>=2.8.6 # max version for py38
-strsimpy>=0.2.1
-markdownify>=0.11.6
-# tiktoken~=0.5.2
+requests~=2.31.0
+pathlib~=1.0.1
+pytest~=7.4.3
+numexpr~=2.8.6 # max version for py38
+strsimpy~=0.2.1
+markdownify~=0.11.6
+tiktoken~=0.5.2
 tqdm>=4.66.1
 websockets>=12.0
-numpy~=1.26.2
+numpy~=1.24.4
-pandas~=2.1.4
+pandas~=2.0.3
-# einops>=0.7.0
+einops>=0.7.0
-# transformers_stream_generator==0.0.4
+transformers_stream_generator==0.0.4
-# vllm==0.2.6; sys_platform == "linux"
+vllm==0.2.6; sys_platform == "linux"
-httpx[brotli,http2,socks]~=0.25.2
+httpx[brotli,http2,socks]==0.25.2
-requests
-pathlib
-pytest
-# optional document loaders
-rapidocr_paddle[gpu]>=1.3.0.post5 # gpu accelleration for ocr of pdf and image files
-jq>=1.6.0 # for .json and .jsonl files. suggest `conda install jq` on windows
-# html2text # for .enex files
-beautifulsoup4~=4.12.2 # for .mhtml files
-pysrt~=1.1.2
 # Online api libs dependencies
-zhipuai>=1.0.7, <=2.0.0 # zhipu
+zhipuai==1.0.7
-dashscope>=1.13.6 # qwen
+dashscope==1.13.6
-# volcengine>=1.0.119 # fangzhou
+# volcengine>=1.0.119
 # uncomment libs if you want to use corresponding vector store
 # pymilvus>=2.3.4
@@ -63,16 +47,17 @@ dashscope>=1.13.6 # qwen
 # Agent and Search Tools
-arxiv>=2.0.0
+arxiv~=2.1.0
-youtube-search>=2.1.2
+youtube-search~=2.1.2
-duckduckgo-search>=3.9.9
+duckduckgo-search~=3.9.9
-metaphor-python>=0.1.23
+metaphor-python~=0.1.23
 # WebUI requirements
-streamlit~=1.29.0 # do remember to add streamlit to environment variables if you use windows
+streamlit>=1.29.0
 streamlit-option-menu>=0.3.6
-streamlit-chatbox==1.1.11
+streamlit-antd-components>=0.3.0
+streamlit-chatbox>=1.1.11
 streamlit-modal>=0.1.0
 streamlit-aggrid>=0.3.4.post3
 httpx[brotli,http2,socks]>=0.25.2

View File

@@ -1,10 +1,12 @@
 # WebUI requirements
-streamlit~=1.29.0 # do remember to add streamlit to environment variables if you use windows
+streamlit>=1.29.0
 streamlit-option-menu>=0.3.6
-streamlit-chatbox==1.1.11
+streamlit-antd-components>=0.3.0
+streamlit-chatbox>=1.1.11
 streamlit-modal>=0.1.0
 streamlit-aggrid>=0.3.4.post3
 httpx[brotli,http2,socks]>=0.25.2
 watchdog>=3.0.0
 docx2txt

View File

@@ -53,7 +53,7 @@ class MilvusKBService(KBService):
         self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                              collection_name=self.kb_name,
                              connection_args=kbs_config.get("milvus"),
-                             index_params=kbs_config.ge("milvus_kwargs")["index_params"],
+                             index_params=kbs_config.get("milvus_kwargs")["index_params"],
                              search_params=kbs_config.get("milvus_kwargs")["search_params"]
                              )
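The one-character fix matters because dicts have a `get` method but no `ge`, so the old line raised as soon as a Milvus knowledge base was initialized. A minimal reproduction, with a hypothetical stand-in for the project's `kbs_config`:

```python
# Hypothetical stand-in for the kbs_config dict from the project's configs:
kbs_config = {"milvus_kwargs": {"index_params": {}, "search_params": {}}}

print(kbs_config.get("milvus_kwargs")["index_params"])  # works: dict.get is a real method

try:
    kbs_config.ge("milvus_kwargs")  # the typo this hunk fixes
except AttributeError as e:
    print(e)  # 'dict' object has no attribute 'ge'
```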

View File

@@ -84,7 +84,7 @@ def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]:
 def folder2db(
     kb_names: List[str],
-    mode: Literal["recreate_vs", "update_in_db", "increament"],
+    mode: Literal["recreate_vs", "update_in_db", "increment"],
     vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
     embed_model: str = EMBEDDING_MODEL,
     chunk_size: int = CHUNK_SIZE,
@@ -97,7 +97,7 @@ def folder2db(
     recreate_vs: recreate all vector store and fill info to database using existed files in local folder
     fill_info_only(disabled): do not create vector store, fill info to db using existed files only
     update_in_db: update vector store and database info using local files that existed in database only
-    increament: create vector store and database info for local files that not existed in database only
+    increment: create vector store and database info for local files that not existed in database only
     """
     def files2vs(kb_name: str, kb_files: List[KnowledgeFile]):
@@ -142,7 +142,7 @@ def folder2db(
             files2vs(kb_name, kb_files)
             kb.save_vector_store()
         # 对比本地目录与数据库中的文件列表,进行增量向量化
-        elif mode == "increament":
+        elif mode == "increment":
             db_files = kb.list_files()
             folder_files = list_files_from_folder(kb_name)
             files = list(set(folder_files) - set(db_files))
@@ -150,7 +150,7 @@ def folder2db(
             files2vs(kb_name, kb_files)
             kb.save_vector_store()
         else:
-            print(f"unspported migrate mode: {mode}")
+            print(f"unsupported migrate mode: {mode}")
 def prune_db_docs(kb_names: List[str]):
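With the mode renamed, a minimal usage sketch follows; the import path and knowledge-base name mirror the test file later in this commit, so treat them as illustrative rather than canonical:

```python
from server.knowledge_base.migrate import folder2db

# Vectorize only files present in the local KB folder but absent from the database:
folder2db(kb_names=["test_kb_for_migrate"], mode="increment")
```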

View File

@@ -272,7 +272,7 @@ def make_text_splitter(
         print(e)
         text_splitter_module = importlib.import_module('langchain.text_splitter')
         TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
-        text_splitter = TextSplitter(chunk_size=250, chunk_overlap=50)
+        text_splitter = TextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
     return text_splitter
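Before this fix, the exception fallback silently discarded the caller's `chunk_size` and `chunk_overlap` and always split at 250/50. A quick way to see the corrected behavior, assuming the Langchain 0.0.x splitter API:

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

text = "A sentence about vector stores. " * 200

# The fallback now honors whatever the caller asked for:
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
chunks = splitter.split_text(text)
print(len(chunks), max(len(c) for c in chunks))  # chunks respect the 500-char budget
```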

View File

@@ -73,7 +73,7 @@ class MiniMaxWorker(ApiModelWorker):
         with response as r:
             text = ""
             for e in r.iter_text():
-                if not e.startswith("data: "): # 真是优秀的返回
+                if not e.startswith("data: "):
                     data = {
                         "error_code": 500,
                         "text": f"minimax返回错误的结果:{e}",

View File

@@ -10,10 +10,11 @@ from configs import (LLM_MODELS, LLM_DEVICE, EMBEDDING_DEVICE,
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from langchain.chat_models import ChatOpenAI
-from langchain.llms import OpenAI, AzureOpenAI, Anthropic
+from langchain.llms import OpenAI
 import httpx
 from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union, Tuple
 import logging
+import torch
 async def wrap_done(fn: Awaitable, event: asyncio.Event):
@@ -58,6 +59,7 @@ def get_ChatOpenAI(
     )
     return model
+
 def get_OpenAI(
     model_name: str,
     temperature: float,
@@ -151,7 +153,6 @@ class ChatMessage(BaseModel):
 def torch_gc():
     try:
-        import torch
         if torch.cuda.is_available():
             # with torch.cuda.device(DEVICE):
             torch.cuda.empty_cache()
@@ -499,29 +500,58 @@ def set_httpx_config(
 # 自动检查torch可用的设备。分布式部署时不运行LLM的机器上可以不装torch
+def is_mps_available():
+    return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+def is_cuda_available():
+    return torch.cuda.is_available()
 def detect_device() -> Literal["cuda", "mps", "cpu"]:
     try:
-        import torch
         if torch.cuda.is_available():
             return "cuda"
-        if torch.backends.mps.is_available():
+        if is_mps_available():
             return "mps"
     except:
         pass
     return "cpu"
-def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
+def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
     device = device or LLM_DEVICE
-    if device not in ["cuda", "mps", "cpu"]:
+    if device not in ["cuda", "mps", "cpu", "xpu"]:
+        logging.warning(f"device not in ['cuda', 'mps', 'cpu','xpu'], device = {device}")
         device = detect_device()
+    elif device == 'cuda' and not is_cuda_available() and is_mps_available():
+        logging.warning("cuda is not available, fallback to mps")
+        return "mps"
+    if device == 'mps' and not is_mps_available() and is_cuda_available():
+        logging.warning("mps is not available, fallback to cuda")
+        return "cuda"
+    # auto detect device if not specified
+    if device not in ["cuda", "mps", "cpu", "xpu"]:
+        return detect_device()
     return device
-def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
+def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
-    device = device or EMBEDDING_DEVICE
+    device = device or LLM_DEVICE
     if device not in ["cuda", "mps", "cpu"]:
+        logging.warning(f"device not in ['cuda', 'mps', 'cpu','xpu'], device = {device}")
         device = detect_device()
+    elif device == 'cuda' and not is_cuda_available() and is_mps_available():
+        logging.warning("cuda is not available, fallback to mps")
+        return "mps"
+    if device == 'mps' and not is_mps_available() and is_cuda_available():
+        logging.warning("mps is not available, fallback to cuda")
+        return "cuda"
+    # auto detect device if not specified
+    if device not in ["cuda", "mps", "cpu"]:
+        return detect_device()
     return device
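The new fallback logic reads as: warn and auto-detect on an unknown device string, and cross-fall between cuda and mps when the requested backend is missing. A condensed sketch of that decision order, assuming torch is importable; this simplifies the two functions above into one and is not part of the commit:

```python
import logging
import torch

def is_mps_available() -> bool:
    # Guarded: older torch builds have no torch.backends.mps attribute at all.
    return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()

def resolve_device(requested: str) -> str:
    known = ("cuda", "mps", "cpu", "xpu")
    if requested not in known:
        logging.warning(f"unknown device {requested!r}, auto-detecting")
        return "cuda" if torch.cuda.is_available() else ("mps" if is_mps_available() else "cpu")
    if requested == "cuda" and not torch.cuda.is_available() and is_mps_available():
        logging.warning("cuda is not available, fallback to mps")
        return "mps"
    if requested == "mps" and not is_mps_available() and torch.cuda.is_available():
        logging.warning("mps is not available, fallback to cuda")
        return "cuda"
    return requested

print(resolve_device("auto"))  # unknown string, so prints the auto-detected device
```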

View File

@@ -14,8 +14,7 @@ from server.knowledge_base.migrate import folder2db, prune_db_docs, prune_folder
 # setup test knowledge base
 kb_name = "test_kb_for_migrate"
 test_files = {
-    "faq.md": str(root_path / "docs" / "faq.md"),
-    "install.md": str(root_path / "docs" / "install.md"),
+    "readme.md": str(root_path / "readme.md"),
 }
@@ -56,13 +55,13 @@ def test_recreate_vs():
     assert doc.metadata["source"] == name
-def test_increament():
+def test_increment():
     kb = KBServiceFactory.get_service_by_name(kb_name)
     kb.clear_vs()
     assert kb.list_files() == []
     assert kb.list_docs() == []
-    folder2db([kb_name], "increament")
+    folder2db([kb_name], "increment")
     files = kb.list_files()
     print(files)

View File

@@ -136,6 +136,8 @@ class ApiRequest:
                 try:
                     if chunk.startswith("data: "):
                         data = json.loads(chunk[6:-2])
+                    elif chunk.startswith(":"): # skip sse comment line
+                        continue
                     else:
                         data = json.loads(chunk)
                     yield data
@@ -170,6 +172,8 @@ class ApiRequest:
                 try:
                     if chunk.startswith("data: "):
                         data = json.loads(chunk[6:-2])
+                    elif chunk.startswith(":"): # skip sse comment line
+                        continue
                    else:
                         data = json.loads(chunk)
                     yield data
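Per the SSE specification, lines starting with `:` are comments (commonly used as keep-alives) and carry no payload; feeding them to `json.loads` raised in both streaming helpers. A standalone sketch of the corrected parsing rule, with hypothetical sample chunks:

```python
import json

# Hypothetical chunks as the client might receive them from the streaming endpoint:
chunks = [
    'data: {"text": "hello"}\n\n',   # normal SSE event
    ': keep-alive\n\n',              # SSE comment line, no payload
    '{"text": "plain json"}',        # some endpoints return bare JSON
]

for chunk in chunks:
    if chunk.startswith("data: "):
        data = json.loads(chunk[6:-2])  # drop the "data: " prefix and trailing "\n\n"
    elif chunk.startswith(":"):         # skip sse comment line
        continue
    else:
        data = json.loads(chunk)
    print(data["text"])
```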