diff --git a/README.md b/README.md
index 39c230f..1e9602f 100644
--- a/README.md
+++ b/README.md
@@ -226,6 +226,10 @@ Web UI 可以实现如下功能:
- [x] [THUDM/chatglm-6b-int4-qe](https://huggingface.co/THUDM/chatglm-6b-int4-qe)
- [x] [ClueAI/ChatYuan-large-v2](https://huggingface.co/ClueAI/ChatYuan-large-v2)
- [x] [fnlp/moss-moon-003-sft](https://huggingface.co/fnlp/moss-moon-003-sft)
+ - [x] [bigscience/bloomz-7b1](https://huggingface.co/bigscience/bloomz-7b1)
+ - [x] [bigscience/bloom-3b](https://huggingface.co/bigscience/bloom-3b)
+ - [x] [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B)
+ - [x] [lmsys/vicuna-13b-delta-v1.1](https://huggingface.co/lmsys/vicuna-13b-delta-v1.1)
- [x] 支持通过调用 [fastchat](https://github.com/lm-sys/FastChat) api 调用 llm
- [x] 增加更多 Embedding 模型支持
- [x] [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh)
@@ -251,7 +255,7 @@ Web UI 可以实现如下功能:
- [x] VUE 前端
## 项目交流群
-
+
🎉 langchain-ChatGLM 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
diff --git a/api.py b/api.py
index 70dccc8..d6e48d6 100644
--- a/api.py
+++ b/api.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
import argparse
import json
import os
diff --git a/chains/modules/embeddings.py b/chains/modules/embeddings.py
deleted file mode 100644
index 3abeddf..0000000
--- a/chains/modules/embeddings.py
+++ /dev/null
@@ -1,34 +0,0 @@
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-
-from typing import Any, List
-
-
-class MyEmbeddings(HuggingFaceEmbeddings):
- def __init__(self, **kwargs: Any):
- super().__init__(**kwargs)
-
- def embed_documents(self, texts: List[str]) -> List[List[float]]:
- """Compute doc embeddings using a HuggingFace transformer model.
-
- Args:
- texts: The list of texts to embed.
-
- Returns:
- List of embeddings, one for each text.
- """
- texts = list(map(lambda x: x.replace("\n", " "), texts))
- embeddings = self.client.encode(texts, normalize_embeddings=True)
- return embeddings.tolist()
-
- def embed_query(self, text: str) -> List[float]:
- """Compute query embeddings using a HuggingFace transformer model.
-
- Args:
- text: The text to embed.
-
- Returns:
- Embeddings for the text.
- """
- text = text.replace("\n", " ")
- embedding = self.client.encode(text, normalize_embeddings=True)
- return embedding.tolist()
diff --git a/chains/modules/vectorstores.py b/chains/modules/vectorstores.py
deleted file mode 100644
index da89775..0000000
--- a/chains/modules/vectorstores.py
+++ /dev/null
@@ -1,121 +0,0 @@
-from langchain.vectorstores import FAISS
-from typing import Any, Callable, List, Optional, Tuple, Dict
-from langchain.docstore.document import Document
-from langchain.docstore.base import Docstore
-
-from langchain.vectorstores.utils import maximal_marginal_relevance
-from langchain.embeddings.base import Embeddings
-import uuid
-from langchain.docstore.in_memory import InMemoryDocstore
-
-import numpy as np
-
-def dependable_faiss_import() -> Any:
- """Import faiss if available, otherwise raise error."""
- try:
- import faiss
- except ImportError:
- raise ValueError(
- "Could not import faiss python package. "
- "Please install it with `pip install faiss` "
- "or `pip install faiss-cpu` (depending on Python version)."
- )
- return faiss
-
-class FAISSVS(FAISS):
- def __init__(self,
- embedding_function: Callable[..., Any],
- index: Any,
- docstore: Docstore,
- index_to_docstore_id: Dict[int, str]):
- super().__init__(embedding_function, index, docstore, index_to_docstore_id)
-
- def max_marginal_relevance_search_by_vector(
- self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any
- ) -> List[Tuple[Document, float]]:
- """Return docs selected using the maximal marginal relevance.
-
- Maximal marginal relevance optimizes for similarity to query AND diversity
- among selected documents.
-
- Args:
- embedding: Embedding to look up documents similar to.
- k: Number of Documents to return. Defaults to 4.
- fetch_k: Number of Documents to fetch to pass to MMR algorithm.
-
- Returns:
- List of Documents with scores selected by maximal marginal relevance.
- """
- scores, indices = self.index.search(np.array([embedding], dtype=np.float32), fetch_k)
- # -1 happens when not enough docs are returned.
- embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
- mmr_selected = maximal_marginal_relevance(
- np.array([embedding], dtype=np.float32), embeddings, k=k
- )
- selected_indices = [indices[0][i] for i in mmr_selected]
- selected_scores = [scores[0][i] for i in mmr_selected]
- docs = []
- for i, score in zip(selected_indices, selected_scores):
- if i == -1:
- # This happens when not enough docs are returned.
- continue
- _id = self.index_to_docstore_id[i]
- doc = self.docstore.search(_id)
- if not isinstance(doc, Document):
- raise ValueError(f"Could not find document for id {_id}, got {doc}")
- docs.append((doc, score))
- return docs
-
- def max_marginal_relevance_search(
- self,
- query: str,
- k: int = 4,
- fetch_k: int = 20,
- **kwargs: Any,
- ) -> List[Tuple[Document, float]]:
- """Return docs selected using the maximal marginal relevance.
-
- Maximal marginal relevance optimizes for similarity to query AND diversity
- among selected documents.
-
- Args:
- query: Text to look up documents similar to.
- k: Number of Documents to return. Defaults to 4.
- fetch_k: Number of Documents to fetch to pass to MMR algorithm.
-
- Returns:
- List of Documents with scores selected by maximal marginal relevance.
- """
- embedding = self.embedding_function(query)
- docs = self.max_marginal_relevance_search_by_vector(embedding, k, fetch_k)
- return docs
-
- @classmethod
- def __from(
- cls,
- texts: List[str],
- embeddings: List[List[float]],
- embedding: Embeddings,
- metadatas: Optional[List[dict]] = None,
- **kwargs: Any,
- ) -> FAISS:
- faiss = dependable_faiss_import()
- index = faiss.IndexFlatIP(len(embeddings[0]))
- index.add(np.array(embeddings, dtype=np.float32))
-
- # # my code, for speeding up search
- # quantizer = faiss.IndexFlatL2(len(embeddings[0]))
- # index = faiss.IndexIVFFlat(quantizer, len(embeddings[0]), 100)
- # index.train(np.array(embeddings, dtype=np.float32))
- # index.add(np.array(embeddings, dtype=np.float32))
-
- documents = []
- for i, text in enumerate(texts):
- metadata = metadatas[i] if metadatas else {}
- documents.append(Document(page_content=text, metadata=metadata))
- index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))}
- docstore = InMemoryDocstore(
- {index_to_id[i]: doc for i, doc in enumerate(documents)}
- )
- return cls(embedding.embed_query, index, docstore, index_to_id)
-
diff --git a/configs/model_config.py b/configs/model_config.py
index 0b3c7fb..d95e37e 100644
--- a/configs/model_config.py
+++ b/configs/model_config.py
@@ -246,8 +246,8 @@ LLM_HISTORY_LEN = 3
# 知识库检索时返回的匹配内容条数
VECTOR_SEARCH_TOP_K = 5
-# 知识检索内容相关度 Score, 数值范围约为0-1100,如果为0,则不生效,经测试设置为小于500时,匹配结果更精准
-VECTOR_SEARCH_SCORE_THRESHOLD = 390
+# 知识检索内容相关度 Score, 数值范围约为0-1100,如果为0,则不生效,建议设置为500左右,经测试设置为小于500时,匹配结果更精准
+VECTOR_SEARCH_SCORE_THRESHOLD = 500
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
diff --git a/img/qr_code_42.jpg b/img/qr_code_42.jpg
deleted file mode 100644
index 146b873..0000000
Binary files a/img/qr_code_42.jpg and /dev/null differ
diff --git a/img/qr_code_43.jpg b/img/qr_code_43.jpg
new file mode 100644
index 0000000..8bbdcbd
Binary files /dev/null and b/img/qr_code_43.jpg differ
diff --git a/img/qr_code_44.jpg b/img/qr_code_44.jpg
new file mode 100644
index 0000000..58ccd6c
Binary files /dev/null and b/img/qr_code_44.jpg differ
diff --git a/models/base/base.py b/models/base/base.py
index 1b65b21..c6674c9 100644
--- a/models/base/base.py
+++ b/models/base/base.py
@@ -6,6 +6,7 @@ from queue import Queue
from threading import Thread
from langchain.callbacks.manager import CallbackManagerForChainRun
from models.loader import LoaderCheckPoint
+from pydantic import BaseModel
import torch
import transformers
@@ -23,13 +24,12 @@ class ListenerToken:
self._scores = _scores
-class AnswerResult:
+class AnswerResult(BaseModel):
"""
消息实体
"""
history: List[List[str]] = []
llm_output: Optional[dict] = None
- listenerToken: ListenerToken = None
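+ # note: ListenerToken is a plain class, not a pydantic model, so keeping such a field on a BaseModel would require Config.arbitrary_types_allowed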
class AnswerResultStream:
@@ -167,8 +167,6 @@ class BaseAnswer(ABC):
with generate_with_streaming(inputs=inputs, run_manager=run_manager) as generator:
for answerResult in generator:
- if answerResult.listenerToken:
- output = answerResult.listenerToken.input_ids
yield answerResult
@abstractmethod
diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py
index c45cf3b..81878ce 100644
--- a/models/chatglm_llm.py
+++ b/models/chatglm_llm.py
@@ -94,8 +94,6 @@ class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": stream_resp}
- if listenerQueue.listenerQueue.__len__() > 0:
- answer_result.listenerToken = listenerQueue.listenerQueue.pop()
generate_with_callback(answer_result)
self.checkPoint.clear_torch_cache()
else:
@@ -114,8 +112,6 @@ class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": response}
- if listenerQueue.listenerQueue.__len__() > 0:
- answer_result.listenerToken = listenerQueue.listenerQueue.pop()
generate_with_callback(answer_result)
diff --git a/models/fastchat_openai_llm.py b/models/fastchat_openai_llm.py
index 4a7f342..d0972f7 100644
--- a/models/fastchat_openai_llm.py
+++ b/models/fastchat_openai_llm.py
@@ -1,6 +1,11 @@
from abc import ABC
from langchain.chains.base import Chain
-from typing import Any, Dict, List, Optional, Generator, Collection
+from typing import (
+ Any, Dict, List, Optional, Generator, Collection, Set,
+ Callable,
+ Tuple,
+ Union)
+
from models.loader import LoaderCheckPoint
from langchain.callbacks.manager import CallbackManagerForChainRun
from models.base import (BaseAnswer,
@@ -8,9 +13,26 @@ from models.base import (BaseAnswer,
AnswerResult,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
+from tenacity import (
+ before_sleep_log,
+ retry,
+ retry_if_exception_type,
+ stop_after_attempt,
+ wait_exponential,
+)
+from pydantic import Extra, Field, root_validator
+
+import openai
+import logging
import torch
import transformers
+logger = logging.getLogger(__name__)
+
def _build_message_template() -> Dict[str, str]:
"""
@@ -25,15 +47,26 @@ def _build_message_template() -> Dict[str, str]:
# 将历史对话数组转换为文本格式
def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str, str]]:
build_messages: Collection[Dict[str, str]] = []
- for i, (old_query, response) in enumerate(history):
- user_build_message = _build_message_template()
- user_build_message['role'] = 'user'
- user_build_message['content'] = old_query
- system_build_message = _build_message_template()
- system_build_message['role'] = 'system'
- system_build_message['content'] = response
- build_messages.append(user_build_message)
- build_messages.append(system_build_message)
+
+ system_build_message = _build_message_template()
+ system_build_message['role'] = 'system'
+ system_build_message['content'] = "You are a helpful assistant."
+ build_messages.append(system_build_message)
+ if history:
+ for i, (user, assistant) in enumerate(history):
+ if user:
+
+ user_build_message = _build_message_template()
+ user_build_message['role'] = 'user'
+ user_build_message['content'] = user
+ build_messages.append(user_build_message)
+
+ if not assistant:
+ raise RuntimeError("历史数据结构不正确")
+ system_build_message = _build_message_template()
+ system_build_message['role'] = 'assistant'
+ system_build_message['content'] = assistant
+ build_messages.append(system_build_message)
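+ # resulting shape (illustrative): [{"role": "system", ...}, {"role": "user", ...}, {"role": "assistant", ...}, ..., {"role": "user", "content": query}]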
user_build_message = _build_message_template()
user_build_message['role'] = 'user'
@@ -43,6 +76,9 @@ def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str,
class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
+ client: Any
+ max_retries: int = 6
+ """Maximum number of retries for a failed request to the OpenAI-compatible API."""
api_base_url: str = "http://localhost:8000/v1"
model_name: str = "chatglm-6b"
max_token: int = 10000
@@ -108,6 +144,35 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
def call_model_name(self, model_name):
self.model_name = model_name
+ def _create_retry_decorator(self) -> Callable[[Any], Any]:
+ min_seconds = 1
+ max_seconds = 60
+ # exponential backoff: wait 2^x seconds between retries,
+ # starting at min_seconds (1s) and capped at max_seconds (60s)
+ return retry(
+ reraise=True,
+ stop=stop_after_attempt(self.max_retries),
+ wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+ retry=(
+ retry_if_exception_type(openai.error.Timeout)
+ | retry_if_exception_type(openai.error.APIError)
+ | retry_if_exception_type(openai.error.APIConnectionError)
+ | retry_if_exception_type(openai.error.RateLimitError)
+ | retry_if_exception_type(openai.error.ServiceUnavailableError)
+ ),
+ before_sleep=before_sleep_log(logger, logging.WARNING),
+ )
+
+ def completion_with_retry(self, **kwargs: Any) -> Any:
+ """Use tenacity to retry the completion call."""
+ retry_decorator = self._create_retry_decorator()
+
+ @retry_decorator
+ def _completion_with_retry(**kwargs: Any) -> Any:
+ return self.client.create(**kwargs)
+
+ return _completion_with_retry(**kwargs)
+
def _call(
self,
inputs: Dict[str, Any],
@@ -121,32 +186,74 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
run_manager: Optional[CallbackManagerForChainRun] = None,
generate_with_callback: AnswerResultStream = None) -> None:
- history = inputs[self.history_key]
- streaming = inputs[self.streaming_key]
+ history = inputs.get(self.history_key, [])
+ streaming = inputs.get(self.streaming_key, False)
prompt = inputs[self.prompt_key]
+ stop = inputs.get("stop", "stop")
print(f"__call:{prompt}")
try:
- import openai
# Not support yet
# openai.api_key = "EMPTY"
openai.api_key = self.api_key
openai.api_base = self.api_base_url
- except ImportError:
+ self.client = openai.ChatCompletion
+ except AttributeError:
raise ValueError(
- "Could not import openai python package. "
- "Please install it with `pip install openai`."
+ "`openai` has no `ChatCompletion` attribute, this is likely "
+ "due to an old version of the openai package. Try upgrading it "
+ "with `pip install --upgrade openai`."
)
- # create a chat completion
- completion = openai.ChatCompletion.create(
- model=self.model_name,
- messages=build_message_list(prompt,history=history)
- )
- print(f"response:{completion.choices[0].message.content}")
- print(f"+++++++++++++++++++++++++++++++++++")
+ msg = build_message_list(prompt, history=history)
- history += [[prompt, completion.choices[0].message.content]]
- answer_result = AnswerResult()
- answer_result.history = history
- answer_result.llm_output = {"answer": completion.choices[0].message.content}
- generate_with_callback(answer_result)
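+ # streaming mode emits one AnswerResult per delta chunk received from the API;
+ # non-streaming mode emits a single AnswerResult once the full completion arrives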
+ if streaming:
+ params = {"stream": streaming,
+ "model": self.model_name,
+ "stop": stop}
+ out_str = ""
+ for stream_resp in self.completion_with_retry(
+ messages=msg,
+ **params
+ ):
+ role = stream_resp["choices"][0]["delta"].get("role", "")
+ token = stream_resp["choices"][0]["delta"].get("content", "")
+ out_str += token
+ history[-1] = [prompt, out_str]
+ answer_result = AnswerResult()
+ answer_result.history = history
+ answer_result.llm_output = {"answer": out_str}
+ generate_with_callback(answer_result)
+ else:
+
+ params = {"stream": streaming,
+ "model": self.model_name,
+ "stop": stop}
+ response = self.completion_with_retry(
+ messages=msg,
+ **params
+ )
+ role = response["choices"][0]["message"].get("role", "")
+ content = response["choices"][0]["message"].get("content", "")
+ history += [[prompt, content]]
+ answer_result = AnswerResult()
+ answer_result.history = history
+ answer_result.llm_output = {"answer": content}
+ generate_with_callback(answer_result)
+
+
+if __name__ == "__main__":
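+ # minimal manual smoke test (assumes an OpenAI-compatible endpoint is reachable at api_base_url)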
+
+ import os
+
+ chain = FastChatOpenAILLMChain()
+
+ # never hard-code real API keys; read the key from the environment instead
+ chain.set_api_key(os.environ.get("OPENAI_API_KEY", ""))
+ # chain.set_api_base_url("https://api.openai.com/v1")
+ # chain.call_model_name("gpt-3.5-turbo")
+
+ answer_result_stream_result = chain({"streaming": True,
+ "prompt": "你好",
+ "history": []
+ })
+
+ for answer_result in answer_result_stream_result['answer_result_stream']:
+ resp = answer_result.llm_output["answer"]
+ print(resp)
diff --git a/models/llama_llm.py b/models/llama_llm.py
index 89d21ac..014fd81 100644
--- a/models/llama_llm.py
+++ b/models/llama_llm.py
@@ -186,7 +186,5 @@ class LLamaLLMChain(BaseAnswer, Chain, ABC):
answer_result = AnswerResult()
history += [[prompt, reply]]
answer_result.history = history
- if listenerQueue.listenerQueue.__len__() > 0:
- answer_result.listenerToken = listenerQueue.listenerQueue.pop()
answer_result.llm_output = {"answer": reply}
generate_with_callback(answer_result)
diff --git a/requirements.txt b/requirements.txt
index bffd30c..7f97f67 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,7 +11,7 @@ beautifulsoup4
icetk
cpm_kernels
faiss-cpu
-gradio==3.28.3
+gradio==3.37.0
fastapi~=0.95.0
uvicorn~=0.21.1
pypinyin~=0.48.0
diff --git a/webui_st.py b/webui_st.py
index 4300662..1584a55 100644
--- a/webui_st.py
+++ b/webui_st.py
@@ -1,5 +1,5 @@
import streamlit as st
-# from st_btn_select import st_btn_select
+from streamlit_chatbox import st_chatbox
import tempfile
###### 从webui借用的代码 #####
###### 做了少量修改 #####
@@ -23,6 +23,7 @@ def get_vs_list():
if not os.path.exists(KB_ROOT_PATH):
return lst_default
lst = os.listdir(KB_ROOT_PATH)
+ lst = [x for x in lst if os.path.isdir(os.path.join(KB_ROOT_PATH, x))]
if not lst:
return lst_default
lst.sort()
@@ -31,7 +32,6 @@ def get_vs_list():
embedding_model_dict_list = list(embedding_model_dict.keys())
llm_model_dict_list = list(llm_model_dict.keys())
-# flag_csv_logger = gr.CSVLogger()
def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD,
@@ -50,6 +50,9 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
history[-1][-1] += source
yield history, ""
elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path):
+ local_doc_qa.top_k = vector_search_top_k
+ local_doc_qa.chunk_conent = chunk_conent
+ local_doc_qa.chunk_size = chunk_size
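+ # apply the caller-provided retrieval parameters to the shared LocalDocQA instance before querying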
for resp, history in local_doc_qa.get_knowledge_based_answer(
query=query, vs_path=vs_path, chat_history=history, streaming=streaming):
source = "\n\n"
@@ -95,17 +98,101 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
"\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
yield history, ""
logger.info(f"flagging: username={FLAG_USER_NAME},query={query},vs_path={vs_path},mode={mode},history={history}")
- # flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME)
-def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec'):
+def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
+ vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
+ filelist = []
+ if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")):
+ os.makedirs(os.path.join(KB_ROOT_PATH, vs_id, "content"))
+ qa = st.session_state.local_doc_qa
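+ # assumes the initialized LocalDocQA instance was stored in st.session_state elsewhere in the app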
+ if qa.llm_model_chain and qa.embeddings:
+ if isinstance(files, list):
+ for file in files:
+ filename = os.path.split(file.name)[-1]
+ shutil.move(file.name, os.path.join(
+ KB_ROOT_PATH, vs_id, "content", filename))
+ filelist.append(os.path.join(
+ KB_ROOT_PATH, vs_id, "content", filename))
+ vs_path, loaded_files = qa.init_knowledge_vector_store(
+ filelist, vs_path, sentence_size)
+ else:
+ vs_path, loaded_files = qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation,
+ sentence_size)
+ if len(loaded_files):
+ file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问"
+ else:
+ file_status = "文件未成功加载,请重新上传文件"
+ else:
+ file_status = "模型未完成加载,请先在加载模型后再导入文件"
+ vs_path = None
+ logger.info(file_status)
+ return vs_path, None, history + [[None, file_status]]
+
+
+knowledge_base_test_mode_info = ("【注意】\n\n"
+ "1. 您已进入知识库测试模式,您输入的任何对话内容都将用于进行知识库查询,"
+ "并仅输出知识库匹配出的内容及相似度分值和及输入的文本源路径,查询的内容并不会进入模型查询。\n\n"
+ "2. 知识相关度 Score 经测试,建议设置为 500 或更低,具体设置情况请结合实际使用调整。"
+ """3. 使用"添加单条数据"添加文本至知识库时,内容如未分段,则内容越多越会稀释各查询内容与之关联的score阈值。\n\n"""
+ "4. 单条内容长度建议设置在100-150左右。")
+
+
+webui_title = """
+# 🎉langchain-ChatGLM WebUI🎉
+👍 [https://github.com/imClumsyPanda/langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM)
+"""
+###### #####
+
+
+###### todo #####
+# 1. Streamlit does not run like a typical web server, so a module-level singleton cannot be relied on;
+#    both shared and local_doc_qa therefore need to be made global.
+#    Globals for local_doc_qa and shared.loaderCheckPoint are already in place.
+# 2. local_doc_qa is currently a global variable: any session that modifies it affects every other
+#    session's conversation, and contention between concurrent session requests is also an open issue.
+#    This cannot be avoided for now, and on a modestly configured machine it can be ignored for the time being.
+# 3. Only the parameters used by get_answer are exposed so far; others such as temperature can be added later.
+###### #####
+
+
+###### 配置项 #####
+class ST_CONFIG:
+ default_mode = "知识库问答"
+ default_kb = ""
+###### #####
+
+
+class TempFile:
+ '''
+ Wraps a saved file path so Streamlit uploads can be handed to get_vector_store in the form it expects.
+ '''
+
+ def __init__(self, path):
+ self.name = path
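+ # downstream code (get_vector_store) only reads the .name attribute, so wrapping the saved path is sufficient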
+
+
+@st.cache_resource(show_spinner=False, max_entries=1)
+def load_model(
+ llm_model: str = LLM_MODEL,
+ embedding_model: str = EMBEDDING_MODEL,
+ use_ptuning_v2: bool = USE_PTUNING_V2,
+):
+ '''
+ Counterpart of init_model; uses Streamlit's resource cache to avoid reloading the model on every rerun.
+ '''
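+ # st.cache_resource keeps the loaded model across Streamlit reruns; max_entries=1 means
+ # requesting a different llm_model/embedding_model evicts and replaces the previous instance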
local_doc_qa = LocalDocQA()
# 初始化消息
args = parser.parse_args()
args_dict = vars(args)
args_dict.update(model=llm_model)
- shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
- llm_model_ins = shared.loaderLLM()
+ if shared.loaderCheckPoint is None:  # avoid reloading the checkpoint when the model is re-initialized
+ shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
+ # shared.loaderCheckPoint.model_name depends on no_remote_model;
+ # if it is not set properly, an error occurs when re-initializing the LLM model (issue #473).
+ # since no_remote_model was removed from model_config, derive it here automatically as a workaround.
+ local_model_path = llm_model_dict.get(llm_model, {}).get('local_model_path') or ''
+ no_remote_model = os.path.isdir(local_model_path)
+ llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2)
+ llm_model_ins.history_len = LLM_HISTORY_LEN
try:
local_doc_qa.init_cfg(llm_model=llm_model_ins,
@@ -128,235 +215,9 @@ def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec'
return local_doc_qa
-# 暂未使用到,先保留
-# def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k, history):
-# try:
-# llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2)
-# llm_model_ins.history_len = llm_history_len
-# local_doc_qa.init_cfg(llm_model=llm_model_ins,
-# embedding_model=embedding_model,
-# top_k=top_k)
-# model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
-# logger.info(model_status)
-# except Exception as e:
-# logger.error(e)
-# model_status = """模型未成功重新加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
-# logger.info(model_status)
-# return history + [[None, model_status]]
-
-
-def get_vector_store(local_doc_qa, vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
- vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
- filelist = []
- if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")):
- os.makedirs(os.path.join(KB_ROOT_PATH, vs_id, "content"))
- if local_doc_qa.llm and local_doc_qa.embeddings:
- if isinstance(files, list):
- for file in files:
- filename = os.path.split(file.name)[-1]
- shutil.move(file.name, os.path.join(
- KB_ROOT_PATH, vs_id, "content", filename))
- filelist.append(os.path.join(
- KB_ROOT_PATH, vs_id, "content", filename))
- vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(
- filelist, vs_path, sentence_size)
- else:
- vs_path, loaded_files = local_doc_qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation,
- sentence_size)
- if len(loaded_files):
- file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问"
- else:
- file_status = "文件未成功加载,请重新上传文件"
- else:
- file_status = "模型未完成加载,请先在加载模型后再导入文件"
- vs_path = None
- logger.info(file_status)
- return vs_path, None, history + [[None, file_status]]
-
-
-knowledge_base_test_mode_info = ("【注意】\n\n"
- "1. 您已进入知识库测试模式,您输入的任何对话内容都将用于进行知识库查询,"
- "并仅输出知识库匹配出的内容及相似度分值和及输入的文本源路径,查询的内容并不会进入模型查询。\n\n"
- "2. 知识相关度 Score 经测试,建议设置为 500 或更低,具体设置情况请结合实际使用调整。"
- """3. 使用"添加单条数据"添加文本至知识库时,内容如未分段,则内容越多越会稀释各查询内容与之关联的score阈值。\n\n"""
- "4. 单条内容长度建议设置在100-150左右。\n\n"
- "5. 本界面用于知识入库及知识匹配相关参数设定,但当前版本中,"
- "本界面中修改的参数并不会直接修改对话界面中参数,仍需前往`configs/model_config.py`修改后生效。"
- "相关参数将在后续版本中支持本界面直接修改。")
-
-
-webui_title = """
-# 🎉langchain-ChatGLM WebUI🎉
-👍 [https://github.com/imClumsyPanda/langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM)
-"""
-###### #####
-
-
-###### todo #####
-# 1. streamlit运行方式与一般web服务器不同,使用模块是无法实现单例模式的,所以shared和local_doc_qa都需要进行全局化处理。
-# 目前已经实现了local_doc_qa的全局化,后面要考虑shared。
-# 2. 当前local_doc_qa是一个全局变量,一方面:任何一个session对其做出修改,都会影响所有session的对话;另一方面,如何处理所有session的请求竞争也是问题。
-# 这个暂时无法避免,在配置普通的机器上暂时也无需考虑。
-# 3. 目前只包含了get_answer对应的参数,以后可以添加其他参数,如temperature。
-###### #####
-
-
-###### 配置项 #####
-class ST_CONFIG:
- user_bg_color = '#77ff77'
- user_icon = 'https://tse2-mm.cn.bing.net/th/id/OIP-C.LTTKrxNWDr_k74wz6jKqBgHaHa?w=203&h=203&c=7&r=0&o=5&pid=1.7'
- robot_bg_color = '#ccccee'
- robot_icon = 'https://ts1.cn.mm.bing.net/th/id/R-C.5302e2cc6f5c7c4933ebb3394e0c41bc?rik=z4u%2b7efba5Mgxw&riu=http%3a%2f%2fcomic-cons.xyz%2fwp-content%2fuploads%2fStar-Wars-avatar-icon-C3PO.png&ehk=kBBvCvpJMHPVpdfpw1GaH%2brbOaIoHjY5Ua9PKcIs%2bAc%3d&risl=&pid=ImgRaw&r=0'
- default_mode = '知识库问答'
- defalut_kb = ''
-###### #####
-
-
-class MsgType:
- '''
- 目前仅支持文本类型的输入输出,为以后多模态模型预留图像、视频、音频支持。
- '''
- TEXT = 1
- IMAGE = 2
- VIDEO = 3
- AUDIO = 4
-
-
-class TempFile:
- '''
- 为保持与get_vector_store的兼容性,需要将streamlit上传文件转化为其可以接受的方式
- '''
-
- def __init__(self, path):
- self.name = path
-
-
-def init_session():
- st.session_state.setdefault('history', [])
-
-
-# def get_query_params():
-# '''
-# 可以用url参数传递配置参数:llm_model, embedding_model, kb, mode。
-# 该参数将覆盖model_config中的配置。处于安全考虑,目前只支持kb和mode
-# 方便将固定的配置分享给特定的人。
-# '''
-# params = st.experimental_get_query_params()
-# return {k: v[0] for k, v in params.items() if v}
-
-
-def robot_say(msg, kb=''):
- st.session_state['history'].append(
- {'is_user': False, 'type': MsgType.TEXT, 'content': msg, 'kb': kb})
-
-
-def user_say(msg):
- st.session_state['history'].append(
- {'is_user': True, 'type': MsgType.TEXT, 'content': msg})
-
-
-def format_md(msg, is_user=False, bg_color='', margin='10%'):
- '''
- 将文本消息格式化为markdown文本
- '''
- if is_user:
- bg_color = bg_color or ST_CONFIG.user_bg_color
- text = f'''
-