diff --git a/.gitignore b/.gitignore index 7ee7991..73ebd21 100644 --- a/.gitignore +++ b/.gitignore @@ -174,8 +174,7 @@ embedding/* pyrightconfig.json loader/tmp_files -<<<<<<< HEAD flagged/* -======= -flagged/* ->>>>>>> 19147c3 (Update .gitignore) +ptuning-v2/*.json +ptuning-v2/*.bin + diff --git a/README.md b/README.md index 39c230f..5ee0ade 100644 --- a/README.md +++ b/README.md @@ -226,6 +226,10 @@ Web UI 可以实现如下功能: - [x] [THUDM/chatglm-6b-int4-qe](https://huggingface.co/THUDM/chatglm-6b-int4-qe) - [x] [ClueAI/ChatYuan-large-v2](https://huggingface.co/ClueAI/ChatYuan-large-v2) - [x] [fnlp/moss-moon-003-sft](https://huggingface.co/fnlp/moss-moon-003-sft) + - [x] [bigscience/bloomz-7b1](https://huggingface.co/bigscience/bloomz-7b1) + - [x] [bigscience/bloom-3b](https://huggingface.co/bigscience/bloom-3b) + - [x] [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B) + - [x] [lmsys/vicuna-13b-delta-v1.1](https://huggingface.co/lmsys/vicuna-13b-delta-v1.1) - [x] 支持通过调用 [fastchat](https://github.com/lm-sys/FastChat) api 调用 llm - [x] 增加更多 Embedding 模型支持 - [x] [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh) @@ -251,7 +255,7 @@ Web UI 可以实现如下功能: - [x] VUE 前端 ## 项目交流群 -二维码 +二维码 🎉 langchain-ChatGLM 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。 diff --git a/api.py b/api.py index 70dccc8..5c348a9 100644 --- a/api.py +++ b/api.py @@ -1,3 +1,4 @@ +#encoding:utf-8 import argparse import json import os @@ -373,7 +374,7 @@ async def bing_search_chat( async def chat( question: str = Body(..., description="Question", example="工伤保险是什么?"), - history: List[List[str]] = Body( + history: Optional[List[List[str]]] = Body( [], description="History of previous questions and answers", example=[ @@ -391,7 +392,6 @@ async def chat( resp = answer_result.llm_output["answer"] history = answer_result.history pass - return ChatMessage( question=question, response=resp, diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py index 2f7d8da..6085dfc 100644 --- a/chains/local_doc_qa.py +++ b/chains/local_doc_qa.py @@ -8,7 +8,6 @@ from typing import List from utils import torch_gc from tqdm import tqdm from pypinyin import lazy_pinyin -from loader import UnstructuredPaddleImageLoader, UnstructuredPaddlePDFLoader from models.base import (BaseAnswer, AnswerResult) from models.loader.args import parser @@ -59,6 +58,7 @@ def tree(filepath, ignore_dir_names=None, ignore_file_names=None): def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_TITLE_ENHANCE): + if filepath.lower().endswith(".md"): loader = UnstructuredFileLoader(filepath, mode="elements") docs = loader.load() @@ -67,10 +67,14 @@ def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_T textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size) docs = loader.load_and_split(textsplitter) elif filepath.lower().endswith(".pdf"): + # 暂且将paddle相关的loader改为动态加载,可以在不上传pdf/image知识文件的前提下使用protobuf=4.x + from loader import UnstructuredPaddlePDFLoader loader = UnstructuredPaddlePDFLoader(filepath) textsplitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size) docs = loader.load_and_split(textsplitter) elif filepath.lower().endswith(".jpg") or filepath.lower().endswith(".png"): + # 暂且将paddle相关的loader改为动态加载,可以在不上传pdf/image知识文件的前提下使用protobuf=4.x + from loader import UnstructuredPaddleImageLoader loader = UnstructuredPaddleImageLoader(filepath, mode="elements") textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size) docs = loader.load_and_split(text_splitter=textsplitter) diff --git a/chains/modules/embeddings.py b/chains/modules/embeddings.py deleted file mode 100644 index 3abeddf..0000000 --- a/chains/modules/embeddings.py +++ /dev/null @@ -1,34 +0,0 @@ -from langchain.embeddings.huggingface import HuggingFaceEmbeddings - -from typing import Any, List - - -class MyEmbeddings(HuggingFaceEmbeddings): - def __init__(self, **kwargs: Any): - super().__init__(**kwargs) - - def embed_documents(self, texts: List[str]) -> List[List[float]]: - """Compute doc embeddings using a HuggingFace transformer model. - - Args: - texts: The list of texts to embed. - - Returns: - List of embeddings, one for each text. - """ - texts = list(map(lambda x: x.replace("\n", " "), texts)) - embeddings = self.client.encode(texts, normalize_embeddings=True) - return embeddings.tolist() - - def embed_query(self, text: str) -> List[float]: - """Compute query embeddings using a HuggingFace transformer model. - - Args: - text: The text to embed. - - Returns: - Embeddings for the text. - """ - text = text.replace("\n", " ") - embedding = self.client.encode(text, normalize_embeddings=True) - return embedding.tolist() diff --git a/chains/modules/vectorstores.py b/chains/modules/vectorstores.py deleted file mode 100644 index da89775..0000000 --- a/chains/modules/vectorstores.py +++ /dev/null @@ -1,121 +0,0 @@ -from langchain.vectorstores import FAISS -from typing import Any, Callable, List, Optional, Tuple, Dict -from langchain.docstore.document import Document -from langchain.docstore.base import Docstore - -from langchain.vectorstores.utils import maximal_marginal_relevance -from langchain.embeddings.base import Embeddings -import uuid -from langchain.docstore.in_memory import InMemoryDocstore - -import numpy as np - -def dependable_faiss_import() -> Any: - """Import faiss if available, otherwise raise error.""" - try: - import faiss - except ImportError: - raise ValueError( - "Could not import faiss python package. " - "Please install it with `pip install faiss` " - "or `pip install faiss-cpu` (depending on Python version)." - ) - return faiss - -class FAISSVS(FAISS): - def __init__(self, - embedding_function: Callable[..., Any], - index: Any, - docstore: Docstore, - index_to_docstore_id: Dict[int, str]): - super().__init__(embedding_function, index, docstore, index_to_docstore_id) - - def max_marginal_relevance_search_by_vector( - self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any - ) -> List[Tuple[Document, float]]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - - Returns: - List of Documents with scores selected by maximal marginal relevance. - """ - scores, indices = self.index.search(np.array([embedding], dtype=np.float32), fetch_k) - # -1 happens when not enough docs are returned. - embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1] - mmr_selected = maximal_marginal_relevance( - np.array([embedding], dtype=np.float32), embeddings, k=k - ) - selected_indices = [indices[0][i] for i in mmr_selected] - selected_scores = [scores[0][i] for i in mmr_selected] - docs = [] - for i, score in zip(selected_indices, selected_scores): - if i == -1: - # This happens when not enough docs are returned. - continue - _id = self.index_to_docstore_id[i] - doc = self.docstore.search(_id) - if not isinstance(doc, Document): - raise ValueError(f"Could not find document for id {_id}, got {doc}") - docs.append((doc, score)) - return docs - - def max_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - - Returns: - List of Documents with scores selected by maximal marginal relevance. - """ - embedding = self.embedding_function(query) - docs = self.max_marginal_relevance_search_by_vector(embedding, k, fetch_k) - return docs - - @classmethod - def __from( - cls, - texts: List[str], - embeddings: List[List[float]], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - **kwargs: Any, - ) -> FAISS: - faiss = dependable_faiss_import() - index = faiss.IndexFlatIP(len(embeddings[0])) - index.add(np.array(embeddings, dtype=np.float32)) - - # # my code, for speeding up search - # quantizer = faiss.IndexFlatL2(len(embeddings[0])) - # index = faiss.IndexIVFFlat(quantizer, len(embeddings[0]), 100) - # index.train(np.array(embeddings, dtype=np.float32)) - # index.add(np.array(embeddings, dtype=np.float32)) - - documents = [] - for i, text in enumerate(texts): - metadata = metadatas[i] if metadatas else {} - documents.append(Document(page_content=text, metadata=metadata)) - index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))} - docstore = InMemoryDocstore( - {index_to_id[i]: doc for i, doc in enumerate(documents)} - ) - return cls(embedding.embed_query, index, docstore, index_to_id) - diff --git a/configs/model_config.py b/configs/model_config.py index ce883a6..60f26fd 100644 --- a/configs/model_config.py +++ b/configs/model_config.py @@ -94,6 +94,12 @@ llm_model_dict = { "local_model_path": None, "provides": "MOSSLLMChain" }, + "moss-int4": { + "name": "moss", + "pretrained_model_name": "fnlp/moss-moon-003-sft-int4", + "local_model_path": None, + "provides": "MOSSLLM" + }, "vicuna-13b-hf": { "name": "vicuna-13b-hf", "pretrained_model_name": "vicuna-13b-hf", @@ -155,6 +161,15 @@ llm_model_dict = { "provides": "FastChatOpenAILLMChain", # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain" "api_base_url": "http://localhost:8000/v1", # "name"修改为fastchat服务中的"api_base_url" "api_key": "EMPTY" + }, + # 通过 fastchat 调用的模型请参考如下格式 + "fastchat-chatglm-6b-int4": { + "name": "chatglm-6b-int4", # "name"修改为fastchat服务中的"model_name" + "pretrained_model_name": "chatglm-6b-int4", + "local_model_path": None, + "provides": "FastChatOpenAILLMChain", # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain" + "api_base_url": "http://localhost:8001/v1", # "name"修改为fastchat服务中的"api_base_url" + "api_key": "EMPTY" }, "fastchat-chatglm2-6b": { "name": "chatglm2-6b", # "name"修改为fastchat服务中的"model_name" @@ -176,11 +191,13 @@ llm_model_dict = { # 调用chatgpt时如果报出: urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443): # Max retries exceeded with url: /v1/chat/completions # 则需要将urllib3版本修改为1.25.11 + # 如果依然报urllib3.exceptions.MaxRetryError: HTTPSConnectionPool,则将https改为http + # 参考https://zhuanlan.zhihu.com/p/350015032 # 如果报出:raise NewConnectionError( # urllib3.exceptions.NewConnectionError: : # Failed to establish a new connection: [WinError 10060] - # 则是因为内地和香港的IP都被OPENAI封了,需要挂切换为日本、新加坡等地 + # 则是因为内地和香港的IP都被OPENAI封了,需要切换为日本、新加坡等地 "openai-chatgpt-3.5": { "name": "gpt-3.5-turbo", "pretrained_model_name": "gpt-3.5-turbo", @@ -210,7 +227,7 @@ STREAMING = True # Use p-tuning-v2 PrefixEncoder USE_PTUNING_V2 = False - +PTUNING_DIR='./ptuning-v2' # LLM running device LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" @@ -238,8 +255,8 @@ LLM_HISTORY_LEN = 3 # 知识库检索时返回的匹配内容条数 VECTOR_SEARCH_TOP_K = 5 -# 知识检索内容相关度 Score, 数值范围约为0-1100,如果为0,则不生效,经测试设置为小于500时,匹配结果更精准 -VECTOR_SEARCH_SCORE_THRESHOLD = 390 +# 知识检索内容相关度 Score, 数值范围约为0-1100,如果为0,则不生效,建议设置为500左右,经测试设置为小于500时,匹配结果更精准 +VECTOR_SEARCH_SCORE_THRESHOLD = 500 NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data") diff --git a/docs/FAQ.md b/docs/FAQ.md index f712477..ccc0f25 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -177,3 +177,22 @@ download_with_progressbar(url, tmp_path) Q14 调用api中的 `bing_search_chat`接口时,报出 `Failed to establish a new connection: [Errno 110] Connection timed out` 这是因为服务器加了防火墙,需要联系管理员加白名单,如果公司的服务器的话,就别想了GG--! + +--- + +Q15 加载chatglm-6b-int8或chatglm-6b-int4抛出 `RuntimeError: Only Tensors of floating point andcomplex dtype can require gradients` + +疑为chatglm的quantization的问题或torch版本差异问题,针对已经变为Parameter的torch.zeros矩阵也执行Parameter操作,从而抛出 `RuntimeError: Only Tensors of floating point andcomplex dtype can require gradients`。解决办法是在chatglm-项目的原始文件中的quantization.py文件374行改为: + +``` + try: + self.weight =Parameter(self.weight.to(kwargs["device"]), requires_grad=False) + except Exception as e: + pass +``` + + 如果上述方式不起作用,则在.cache/hugggingface/modules/目录下针对chatglm项目的原始文件中的quantization.py文件执行上述操作,若软链接不止一个,按照错误提示选择正确的路径。 + +注:虽然模型可以顺利加载但在cpu上仍存在推理失败的可能:即针对每个问题,模型一直输出gugugugu。 + + 因此,最好不要试图用cpu加载量化模型,原因可能是目前python主流量化包的量化操作是在gpu上执行的,会天然地存在gap。 diff --git a/docs/INSTALL.md b/docs/INSTALL.md index 6602973..2682c7b 100644 --- a/docs/INSTALL.md +++ b/docs/INSTALL.md @@ -49,7 +49,7 @@ $ python loader/image_loader.py ## llama-cpp模型调用的说明 -1. 首先从huggingface hub中下载对应的模型,如https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/的[ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin),建议使用huggingface_hub库的snapshot_download下载。 +1. 首先从huggingface hub中下载对应的模型,如 [https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/) 的 [ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin),建议使用huggingface_hub库的snapshot_download下载。 2. 将下载的模型重命名。通过huggingface_hub下载的模型会被重命名为随机序列,因此需要重命名为原始文件名,如[ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin)。 3. 基于下载模型的ggml的加载时间,推测对应的llama-cpp版本,下载对应的llama-cpp-python库的wheel文件,实测[ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin)与llama-cpp-python库兼容,然后手动安装wheel文件。 4. 将下载的模型信息写入configs/model_config.py文件里 `llm_model_dict`中,注意保证参数的兼容性,一些参数组合可能会报错. diff --git a/img/qr_code_42.jpg b/img/qr_code_42.jpg deleted file mode 100644 index 146b873..0000000 Binary files a/img/qr_code_42.jpg and /dev/null differ diff --git a/img/qr_code_45.jpg b/img/qr_code_45.jpg new file mode 100644 index 0000000..ad253c8 Binary files /dev/null and b/img/qr_code_45.jpg differ diff --git a/models/base/base.py b/models/base/base.py index 1b65b21..c6674c9 100644 --- a/models/base/base.py +++ b/models/base/base.py @@ -6,6 +6,7 @@ from queue import Queue from threading import Thread from langchain.callbacks.manager import CallbackManagerForChainRun from models.loader import LoaderCheckPoint +from pydantic import BaseModel import torch import transformers @@ -23,13 +24,12 @@ class ListenerToken: self._scores = _scores -class AnswerResult: +class AnswerResult(BaseModel): """ 消息实体 """ history: List[List[str]] = [] llm_output: Optional[dict] = None - listenerToken: ListenerToken = None class AnswerResultStream: @@ -167,8 +167,6 @@ class BaseAnswer(ABC): with generate_with_streaming(inputs=inputs, run_manager=run_manager) as generator: for answerResult in generator: - if answerResult.listenerToken: - output = answerResult.listenerToken.input_ids yield answerResult @abstractmethod diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py index c45cf3b..0d19ee6 100644 --- a/models/chatglm_llm.py +++ b/models/chatglm_llm.py @@ -2,14 +2,14 @@ from abc import ABC from langchain.chains.base import Chain from typing import Any, Dict, List, Optional, Generator from langchain.callbacks.manager import CallbackManagerForChainRun -from transformers.generation.logits_process import LogitsProcessor -from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList +# from transformers.generation.logits_process import LogitsProcessor +# from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList from models.loader import LoaderCheckPoint from models.base import (BaseAnswer, AnswerResult, AnswerResultStream, AnswerResultQueueSentinelTokenListenerQueue) -import torch +# import torch import transformers @@ -94,8 +94,6 @@ class ChatGLMLLMChain(BaseAnswer, Chain, ABC): answer_result = AnswerResult() answer_result.history = history answer_result.llm_output = {"answer": stream_resp} - if listenerQueue.listenerQueue.__len__() > 0: - answer_result.listenerToken = listenerQueue.listenerQueue.pop() generate_with_callback(answer_result) self.checkPoint.clear_torch_cache() else: @@ -114,8 +112,6 @@ class ChatGLMLLMChain(BaseAnswer, Chain, ABC): answer_result = AnswerResult() answer_result.history = history answer_result.llm_output = {"answer": response} - if listenerQueue.listenerQueue.__len__() > 0: - answer_result.listenerToken = listenerQueue.listenerQueue.pop() generate_with_callback(answer_result) diff --git a/models/fastchat_openai_llm.py b/models/fastchat_openai_llm.py index 398364b..217910a 100644 --- a/models/fastchat_openai_llm.py +++ b/models/fastchat_openai_llm.py @@ -1,6 +1,11 @@ from abc import ABC from langchain.chains.base import Chain -from typing import Any, Dict, List, Optional, Generator, Collection +from typing import ( + Any, Dict, List, Optional, Generator, Collection, Set, + Callable, + Tuple, + Union) + from models.loader import LoaderCheckPoint from langchain.callbacks.manager import CallbackManagerForChainRun from models.base import (BaseAnswer, @@ -8,9 +13,26 @@ from models.base import (BaseAnswer, AnswerResult, AnswerResultStream, AnswerResultQueueSentinelTokenListenerQueue) +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) +from pydantic import Extra, Field, root_validator + +from openai import ( + ChatCompletion +) + +import openai +import logging import torch import transformers +logger = logging.getLogger(__name__) + def _build_message_template() -> Dict[str, str]: """ @@ -25,15 +47,26 @@ def _build_message_template() -> Dict[str, str]: # 将历史对话数组转换为文本格式 def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str, str]]: build_messages: Collection[Dict[str, str]] = [] - for i, (old_query, response) in enumerate(history): - user_build_message = _build_message_template() - user_build_message['role'] = 'user' - user_build_message['content'] = old_query - system_build_message = _build_message_template() - system_build_message['role'] = 'system' - system_build_message['content'] = response - build_messages.append(user_build_message) - build_messages.append(system_build_message) + + system_build_message = _build_message_template() + system_build_message['role'] = 'system' + system_build_message['content'] = "You are a helpful assistant." + build_messages.append(system_build_message) + if history: + for i, (user, assistant) in enumerate(history): + if user: + + user_build_message = _build_message_template() + user_build_message['role'] = 'user' + user_build_message['content'] = user + build_messages.append(user_build_message) + + if not assistant: + raise RuntimeError("历史数据结构不正确") + system_build_message = _build_message_template() + system_build_message['role'] = 'assistant' + system_build_message['content'] = assistant + build_messages.append(system_build_message) user_build_message = _build_message_template() user_build_message['role'] = 'user' @@ -43,6 +76,9 @@ def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str, class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC): + client: Any + """Timeout for requests to OpenAI completion API. Default is 600 seconds.""" + max_retries: int = 6 api_base_url: str = "http://localhost:8000/v1" model_name: str = "chatglm-6b" max_token: int = 10000 @@ -108,6 +144,35 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC): def call_model_name(self, model_name): self.model_name = model_name + def _create_retry_decorator(self) -> Callable[[Any], Any]: + min_seconds = 1 + max_seconds = 60 + # Wait 2^x * 1 second between each retry starting with + # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + return retry( + reraise=True, + stop=stop_after_attempt(self.max_retries), + wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + retry=( + retry_if_exception_type(openai.error.Timeout) + | retry_if_exception_type(openai.error.APIError) + | retry_if_exception_type(openai.error.APIConnectionError) + | retry_if_exception_type(openai.error.RateLimitError) + | retry_if_exception_type(openai.error.ServiceUnavailableError) + ), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + def completion_with_retry(self, **kwargs: Any) -> Any: + """Use tenacity to retry the completion call.""" + retry_decorator = self._create_retry_decorator() + + @retry_decorator + def _completion_with_retry(**kwargs: Any) -> Any: + return self.client.create(**kwargs) + + return _completion_with_retry(**kwargs) + def _call( self, inputs: Dict[str, Any], @@ -121,32 +186,74 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC): run_manager: Optional[CallbackManagerForChainRun] = None, generate_with_callback: AnswerResultStream = None) -> None: - history = inputs[self.history_key] - streaming = inputs[self.streaming_key] + history = inputs.get(self.history_key, []) + streaming = inputs.get(self.streaming_key, False) prompt = inputs[self.prompt_key] + stop = inputs.get("stop", "stop") print(f"__call:{prompt}") try: - import openai # Not support yet # openai.api_key = "EMPTY" openai.api_key = self.api_key openai.api_base = self.api_base_url - except ImportError: + self.client = openai.ChatCompletion + except AttributeError: raise ValueError( - "Could not import openai python package. " - "Please install it with `pip install openai`." + "`openai` has no `ChatCompletion` attribute, this is likely " + "due to an old version of the openai package. Try upgrading it " + "with `pip install --upgrade openai`." ) - # create a chat completion - completion = openai.ChatCompletion.create( - model=self.model_name, - messages=build_message_list(prompt) - ) - print(f"response:{completion.choices[0].message.content}") - print(f"+++++++++++++++++++++++++++++++++++") + msg = build_message_list(prompt, history=history) - history += [[prompt, completion.choices[0].message.content]] - answer_result = AnswerResult() - answer_result.history = history - answer_result.llm_output = {"answer": completion.choices[0].message.content} - generate_with_callback(answer_result) + if streaming: + params = {"stream": streaming, + "model": self.model_name, + "stop": stop} + out_str = "" + for stream_resp in self.completion_with_retry( + messages=msg, + **params + ): + role = stream_resp["choices"][0]["delta"].get("role", "") + token = stream_resp["choices"][0]["delta"].get("content", "") + out_str += token + history[-1] = [prompt, out_str] + answer_result = AnswerResult() + answer_result.history = history + answer_result.llm_output = {"answer": out_str} + generate_with_callback(answer_result) + else: + + params = {"stream": streaming, + "model": self.model_name, + "stop": stop} + response = self.completion_with_retry( + messages=msg, + **params + ) + role = response["choices"][0]["message"].get("role", "") + content = response["choices"][0]["message"].get("content", "") + history += [[prompt, content]] + answer_result = AnswerResult() + answer_result.history = history + answer_result.llm_output = {"answer": content} + generate_with_callback(answer_result) + + +if __name__ == "__main__": + + chain = FastChatOpenAILLMChain() + + chain.set_api_key("EMPTY") + # chain.set_api_base_url("https://api.openai.com/v1") + # chain.call_model_name("gpt-3.5-turbo") + + answer_result_stream_result = chain({"streaming": True, + "prompt": "你好", + "history": [] + }) + + for answer_result in answer_result_stream_result['answer_result_stream']: + resp = answer_result.llm_output["answer"] + print(resp) diff --git a/models/llama_llm.py b/models/llama_llm.py index 89d21ac..014fd81 100644 --- a/models/llama_llm.py +++ b/models/llama_llm.py @@ -186,7 +186,5 @@ class LLamaLLMChain(BaseAnswer, Chain, ABC): answer_result = AnswerResult() history += [[prompt, reply]] answer_result.history = history - if listenerQueue.listenerQueue.__len__() > 0: - answer_result.listenerToken = listenerQueue.listenerQueue.pop() answer_result.llm_output = {"answer": reply} generate_with_callback(answer_result) diff --git a/models/loader/args.py b/models/loader/args.py index b15ad5e..cd3e78b 100644 --- a/models/loader/args.py +++ b/models/loader/args.py @@ -1,3 +1,4 @@ + import argparse import os from configs.model_config import * @@ -43,7 +44,8 @@ parser.add_argument('--no-remote-model', action='store_true', help='remote in th parser.add_argument('--model-name', type=str, default=LLM_MODEL, help='Name of the model to load by default.') parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.') parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras") - +parser.add_argument('--use-ptuning-v2',action='store_true',help="whether use ptuning-v2 checkpoint") +parser.add_argument("--ptuning-dir",type=str,default=PTUNING_DIR,help="the dir of ptuning-v2 checkpoint") # Accelerate/transformers parser.add_argument('--load-in-8bit', action='store_true', default=LOAD_IN_8BIT, help='Load the model with 8-bit precision.') diff --git a/models/loader/loader.py b/models/loader/loader.py index ab014ed..eb92f90 100644 --- a/models/loader/loader.py +++ b/models/loader/loader.py @@ -149,7 +149,7 @@ class LoaderCheckPoint: trust_remote_code=True).half() # 可传入device_map自定义每张卡的部署情况 if self.device_map is None: - if 'chatglm' in self.model_name.lower(): + if 'chatglm' in self.model_name.lower() and not "chatglm2" in self.model_name.lower(): self.device_map = self.chatglm_auto_configure_device_map(num_gpus) elif 'moss' in self.model_name.lower(): self.device_map = self.moss_auto_configure_device_map(num_gpus, checkpoint) @@ -165,13 +165,6 @@ class LoaderCheckPoint: dtype=torch.float16 if not self.load_in_8bit else torch.int8, max_memory=max_memory, no_split_module_classes=model._no_split_modules) - # 对于chaglm和moss意外的模型应使用自动指定,而非调用chatglm的配置方式 - # 其他模型定义的层类几乎不可能与chatglm和moss一致,使用chatglm_auto_configure_device_map - # 百分百会报错,使用infer_auto_device_map虽然可能导致负载不均衡,但至少不会报错 - # 实测在bloom模型上如此 - # self.device_map = infer_auto_device_map(model, - # dtype=torch.int8, - # no_split_module_classes=model._no_split_modules) model = dispatch_model(model, device_map=self.device_map) else: @@ -472,12 +465,13 @@ class LoaderCheckPoint: if self.use_ptuning_v2: try: - prefix_encoder_file = open(Path(f'{self.ptuning_dir}/config.json'), 'r') + prefix_encoder_file = open(Path(f'{os.path.abspath(self.ptuning_dir)}/config.json'), 'r') prefix_encoder_config = json.loads(prefix_encoder_file.read()) prefix_encoder_file.close() self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len'] self.model_config.prefix_projection = prefix_encoder_config['prefix_projection'] except Exception as e: + print(e) print("加载PrefixEncoder config.json失败") self.model, self.tokenizer = self._load_model() @@ -487,14 +481,16 @@ class LoaderCheckPoint: if self.use_ptuning_v2: try: - prefix_state_dict = torch.load(Path(f'{self.ptuning_dir}/pytorch_model.bin')) + prefix_state_dict = torch.load(Path(f'{os.path.abspath(self.ptuning_dir)}/pytorch_model.bin')) new_prefix_state_dict = {} for k, v in prefix_state_dict.items(): if k.startswith("transformer.prefix_encoder."): new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) self.model.transformer.prefix_encoder.float() + print("加载ptuning检查点成功!") except Exception as e: + print(e) print("加载PrefixEncoder模型参数失败") # llama-cpp模型(至少vicuna-13b)的eval方法就是自身,其没有eval方法 if not self.is_llamacpp and not self.is_chatgmlcpp: diff --git a/requirements.txt b/requirements.txt index bffd30c..7f97f67 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,7 @@ beautifulsoup4 icetk cpm_kernels faiss-cpu -gradio==3.28.3 +gradio==3.37.0 fastapi~=0.95.0 uvicorn~=0.21.1 pypinyin~=0.48.0 diff --git a/webui.py b/webui.py index b7ff9fd..c7e7880 100644 --- a/webui.py +++ b/webui.py @@ -104,6 +104,7 @@ def init_model(): args_dict = vars(args) shared.loaderCheckPoint = LoaderCheckPoint(args_dict) llm_model_ins = shared.loaderLLM() + llm_model_ins.history_len = LLM_HISTORY_LEN try: local_doc_qa.init_cfg(llm_model=llm_model_ins) answer_result_stream_result = local_doc_qa.llm_model_chain( diff --git a/webui_st.py b/webui_st.py index 4300662..1584a55 100644 --- a/webui_st.py +++ b/webui_st.py @@ -1,5 +1,5 @@ import streamlit as st -# from st_btn_select import st_btn_select +from streamlit_chatbox import st_chatbox import tempfile ###### 从webui借用的代码 ##### ###### 做了少量修改 ##### @@ -23,6 +23,7 @@ def get_vs_list(): if not os.path.exists(KB_ROOT_PATH): return lst_default lst = os.listdir(KB_ROOT_PATH) + lst = [x for x in lst if os.path.isdir(os.path.join(KB_ROOT_PATH, x))] if not lst: return lst_default lst.sort() @@ -31,7 +32,6 @@ def get_vs_list(): embedding_model_dict_list = list(embedding_model_dict.keys()) llm_model_dict_list = list(llm_model_dict.keys()) -# flag_csv_logger = gr.CSVLogger() def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD, @@ -50,6 +50,9 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR history[-1][-1] += source yield history, "" elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path): + local_doc_qa.top_k = vector_search_top_k + local_doc_qa.chunk_conent = chunk_conent + local_doc_qa.chunk_size = chunk_size for resp, history in local_doc_qa.get_knowledge_based_answer( query=query, vs_path=vs_path, chat_history=history, streaming=streaming): source = "\n\n" @@ -95,17 +98,101 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "") yield history, "" logger.info(f"flagging: username={FLAG_USER_NAME},query={query},vs_path={vs_path},mode={mode},history={history}") - # flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME) -def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec'): +def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation): + vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store") + filelist = [] + if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")): + os.makedirs(os.path.join(KB_ROOT_PATH, vs_id, "content")) + qa = st.session_state.local_doc_qa + if qa.llm_model_chain and qa.embeddings: + if isinstance(files, list): + for file in files: + filename = os.path.split(file.name)[-1] + shutil.move(file.name, os.path.join( + KB_ROOT_PATH, vs_id, "content", filename)) + filelist.append(os.path.join( + KB_ROOT_PATH, vs_id, "content", filename)) + vs_path, loaded_files = qa.init_knowledge_vector_store( + filelist, vs_path, sentence_size) + else: + vs_path, loaded_files = qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation, + sentence_size) + if len(loaded_files): + file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问" + else: + file_status = "文件未成功加载,请重新上传文件" + else: + file_status = "模型未完成加载,请先在加载模型后再导入文件" + vs_path = None + logger.info(file_status) + return vs_path, None, history + [[None, file_status]] + + +knowledge_base_test_mode_info = ("【注意】\n\n" + "1. 您已进入知识库测试模式,您输入的任何对话内容都将用于进行知识库查询," + "并仅输出知识库匹配出的内容及相似度分值和及输入的文本源路径,查询的内容并不会进入模型查询。\n\n" + "2. 知识相关度 Score 经测试,建议设置为 500 或更低,具体设置情况请结合实际使用调整。" + """3. 使用"添加单条数据"添加文本至知识库时,内容如未分段,则内容越多越会稀释各查询内容与之关联的score阈值。\n\n""" + "4. 单条内容长度建议设置在100-150左右。") + + +webui_title = """ +# 🎉langchain-ChatGLM WebUI🎉 +👍 [https://github.com/imClumsyPanda/langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM) +""" +###### ##### + + +###### todo ##### +# 1. streamlit运行方式与一般web服务器不同,使用模块是无法实现单例模式的,所以shared和local_doc_qa都需要进行全局化处理。 +# 目前已经实现了local_doc_qa和shared.loaderCheckPoint的全局化。 +# 2. 当前local_doc_qa是一个全局变量,一方面:任何一个session对其做出修改,都会影响所有session的对话;另一方面,如何处理所有session的请求竞争也是问题。 +# 这个暂时无法避免,在配置普通的机器上暂时也无需考虑。 +# 3. 目前只包含了get_answer对应的参数,以后可以添加其他参数,如temperature。 +###### ##### + + +###### 配置项 ##### +class ST_CONFIG: + default_mode = "知识库问答" + default_kb = "" +###### ##### + + +class TempFile: + ''' + 为保持与get_vector_store的兼容性,需要将streamlit上传文件转化为其可以接受的方式 + ''' + + def __init__(self, path): + self.name = path + + +@st.cache_resource(show_spinner=False, max_entries=1) +def load_model( + llm_model: str = LLM_MODEL, + embedding_model: str = EMBEDDING_MODEL, + use_ptuning_v2: bool = USE_PTUNING_V2, +): + ''' + 对应init_model,利用streamlit cache避免模型重复加载 + ''' local_doc_qa = LocalDocQA() # 初始化消息 args = parser.parse_args() args_dict = vars(args) args_dict.update(model=llm_model) - shared.loaderCheckPoint = LoaderCheckPoint(args_dict) - llm_model_ins = shared.loaderLLM() + if shared.loaderCheckPoint is None: # avoid checkpoint reloading when reinit model + shared.loaderCheckPoint = LoaderCheckPoint(args_dict) + # shared.loaderCheckPoint.model_name is different by no_remote_model. + # if it is not set properly error occurs when reinit llm model(issue#473). + # as no_remote_model is removed from model_config, need workaround to set it automaticlly. + local_model_path = llm_model_dict.get(llm_model, {}).get('local_model_path') or '' + no_remote_model = os.path.isdir(local_model_path) + llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2) + llm_model_ins.history_len = LLM_HISTORY_LEN try: local_doc_qa.init_cfg(llm_model=llm_model_ins, @@ -128,235 +215,9 @@ def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec' return local_doc_qa -# 暂未使用到,先保留 -# def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k, history): -# try: -# llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2) -# llm_model_ins.history_len = llm_history_len -# local_doc_qa.init_cfg(llm_model=llm_model_ins, -# embedding_model=embedding_model, -# top_k=top_k) -# model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话""" -# logger.info(model_status) -# except Exception as e: -# logger.error(e) -# model_status = """模型未成功重新加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮""" -# logger.info(model_status) -# return history + [[None, model_status]] - - -def get_vector_store(local_doc_qa, vs_id, files, sentence_size, history, one_conent, one_content_segmentation): - vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store") - filelist = [] - if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")): - os.makedirs(os.path.join(KB_ROOT_PATH, vs_id, "content")) - if local_doc_qa.llm and local_doc_qa.embeddings: - if isinstance(files, list): - for file in files: - filename = os.path.split(file.name)[-1] - shutil.move(file.name, os.path.join( - KB_ROOT_PATH, vs_id, "content", filename)) - filelist.append(os.path.join( - KB_ROOT_PATH, vs_id, "content", filename)) - vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store( - filelist, vs_path, sentence_size) - else: - vs_path, loaded_files = local_doc_qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation, - sentence_size) - if len(loaded_files): - file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问" - else: - file_status = "文件未成功加载,请重新上传文件" - else: - file_status = "模型未完成加载,请先在加载模型后再导入文件" - vs_path = None - logger.info(file_status) - return vs_path, None, history + [[None, file_status]] - - -knowledge_base_test_mode_info = ("【注意】\n\n" - "1. 您已进入知识库测试模式,您输入的任何对话内容都将用于进行知识库查询," - "并仅输出知识库匹配出的内容及相似度分值和及输入的文本源路径,查询的内容并不会进入模型查询。\n\n" - "2. 知识相关度 Score 经测试,建议设置为 500 或更低,具体设置情况请结合实际使用调整。" - """3. 使用"添加单条数据"添加文本至知识库时,内容如未分段,则内容越多越会稀释各查询内容与之关联的score阈值。\n\n""" - "4. 单条内容长度建议设置在100-150左右。\n\n" - "5. 本界面用于知识入库及知识匹配相关参数设定,但当前版本中," - "本界面中修改的参数并不会直接修改对话界面中参数,仍需前往`configs/model_config.py`修改后生效。" - "相关参数将在后续版本中支持本界面直接修改。") - - -webui_title = """ -# 🎉langchain-ChatGLM WebUI🎉 -👍 [https://github.com/imClumsyPanda/langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM) -""" -###### ##### - - -###### todo ##### -# 1. streamlit运行方式与一般web服务器不同,使用模块是无法实现单例模式的,所以shared和local_doc_qa都需要进行全局化处理。 -# 目前已经实现了local_doc_qa的全局化,后面要考虑shared。 -# 2. 当前local_doc_qa是一个全局变量,一方面:任何一个session对其做出修改,都会影响所有session的对话;另一方面,如何处理所有session的请求竞争也是问题。 -# 这个暂时无法避免,在配置普通的机器上暂时也无需考虑。 -# 3. 目前只包含了get_answer对应的参数,以后可以添加其他参数,如temperature。 -###### ##### - - -###### 配置项 ##### -class ST_CONFIG: - user_bg_color = '#77ff77' - user_icon = 'https://tse2-mm.cn.bing.net/th/id/OIP-C.LTTKrxNWDr_k74wz6jKqBgHaHa?w=203&h=203&c=7&r=0&o=5&pid=1.7' - robot_bg_color = '#ccccee' - robot_icon = 'https://ts1.cn.mm.bing.net/th/id/R-C.5302e2cc6f5c7c4933ebb3394e0c41bc?rik=z4u%2b7efba5Mgxw&riu=http%3a%2f%2fcomic-cons.xyz%2fwp-content%2fuploads%2fStar-Wars-avatar-icon-C3PO.png&ehk=kBBvCvpJMHPVpdfpw1GaH%2brbOaIoHjY5Ua9PKcIs%2bAc%3d&risl=&pid=ImgRaw&r=0' - default_mode = '知识库问答' - defalut_kb = '' -###### ##### - - -class MsgType: - ''' - 目前仅支持文本类型的输入输出,为以后多模态模型预留图像、视频、音频支持。 - ''' - TEXT = 1 - IMAGE = 2 - VIDEO = 3 - AUDIO = 4 - - -class TempFile: - ''' - 为保持与get_vector_store的兼容性,需要将streamlit上传文件转化为其可以接受的方式 - ''' - - def __init__(self, path): - self.name = path - - -def init_session(): - st.session_state.setdefault('history', []) - - -# def get_query_params(): -# ''' -# 可以用url参数传递配置参数:llm_model, embedding_model, kb, mode。 -# 该参数将覆盖model_config中的配置。处于安全考虑,目前只支持kb和mode -# 方便将固定的配置分享给特定的人。 -# ''' -# params = st.experimental_get_query_params() -# return {k: v[0] for k, v in params.items() if v} - - -def robot_say(msg, kb=''): - st.session_state['history'].append( - {'is_user': False, 'type': MsgType.TEXT, 'content': msg, 'kb': kb}) - - -def user_say(msg): - st.session_state['history'].append( - {'is_user': True, 'type': MsgType.TEXT, 'content': msg}) - - -def format_md(msg, is_user=False, bg_color='', margin='10%'): - ''' - 将文本消息格式化为markdown文本 - ''' - if is_user: - bg_color = bg_color or ST_CONFIG.user_bg_color - text = f''' -
- {msg} -
- ''' - else: - bg_color = bg_color or ST_CONFIG.robot_bg_color - text = f''' -
- {msg} -
- ''' - return text - - -def message(msg, - is_user=False, - msg_type=MsgType.TEXT, - icon='', - bg_color='', - margin='10%', - kb='', - ): - ''' - 渲染单条消息。目前仅支持文本 - ''' - cols = st.columns([1, 10, 1]) - empty = cols[1].empty() - if is_user: - icon = icon or ST_CONFIG.user_icon - bg_color = bg_color or ST_CONFIG.user_bg_color - cols[2].image(icon, width=40) - if msg_type == MsgType.TEXT: - text = format_md(msg, is_user, bg_color, margin) - empty.markdown(text, unsafe_allow_html=True) - else: - raise RuntimeError('only support text message now.') - else: - icon = icon or ST_CONFIG.robot_icon - bg_color = bg_color or ST_CONFIG.robot_bg_color - cols[0].image(icon, width=40) - if kb: - cols[0].write(f'({kb})') - if msg_type == MsgType.TEXT: - text = format_md(msg, is_user, bg_color, margin) - empty.markdown(text, unsafe_allow_html=True) - else: - raise RuntimeError('only support text message now.') - return empty - - -def output_messages( - user_bg_color='', - robot_bg_color='', - user_icon='', - robot_icon='', -): - with chat_box.container(): - last_response = None - for msg in st.session_state['history']: - bg_color = user_bg_color if msg['is_user'] else robot_bg_color - icon = user_icon if msg['is_user'] else robot_icon - empty = message(msg['content'], - is_user=msg['is_user'], - icon=icon, - msg_type=msg['type'], - bg_color=bg_color, - kb=msg.get('kb', '') - ) - if not msg['is_user']: - last_response = empty - return last_response - - -@st.cache_resource(show_spinner=False, max_entries=1) -def load_model(llm_model: str, embedding_model: str): - ''' - 对应init_model,利用streamlit cache避免模型重复加载 - ''' - local_doc_qa = init_model(llm_model, embedding_model) - robot_say('模型已成功加载,可以开始对话,或从左侧选择模式后开始对话。\n请尽量不要刷新页面,以免模型出错或重复加载。') - return local_doc_qa - - # @st.cache_data def answer(query, vs_path='', history=[], mode='', score_threshold=0, - vector_search_top_k=5, chunk_conent=True, chunk_size=100, qa=None + vector_search_top_k=5, chunk_conent=True, chunk_size=100 ): ''' 对应get_answer,--利用streamlit cache缓存相同问题的答案-- @@ -365,48 +226,24 @@ def answer(query, vs_path='', history=[], mode='', score_threshold=0, vector_search_top_k, chunk_conent, chunk_size) -def load_vector_store( - vs_id, - files, - sentence_size=100, - history=[], - one_conent=None, - one_content_segmentation=None, -): - return get_vector_store( - local_doc_qa, - vs_id, - files, - sentence_size, - history, - one_conent, - one_content_segmentation, - ) +def use_kb_mode(m): + return m in ["知识库问答", "知识库测试"] # main ui st.set_page_config(webui_title, layout='wide') -init_session() -# params = get_query_params() -# llm_model = params.get('llm_model', LLM_MODEL) -# embedding_model = params.get('embedding_model', EMBEDDING_MODEL) - -with st.spinner(f'正在加载模型({LLM_MODEL} + {EMBEDDING_MODEL}),请耐心等候...'): - local_doc_qa = load_model(LLM_MODEL, EMBEDDING_MODEL) - - -def use_kb_mode(m): - return m in ['知识库问答', '知识库测试'] +chat_box = st_chatbox(greetings=["模型已成功加载,可以开始对话,或从左侧选择模式后开始对话。"]) +# 使用 help(st_chatbox) 查看自定义参数 # sidebar modes = ['LLM 对话', '知识库问答', 'Bing搜索问答', '知识库测试'] with st.sidebar: def on_mode_change(): m = st.session_state.mode - robot_say(f'已切换到"{m}"模式') + chat_box.robot_say(f'已切换到"{m}"模式') if m == '知识库测试': - robot_say(knowledge_base_test_mode_info) + chat_box.robot_say(knowledge_base_test_mode_info) index = 0 try: @@ -416,7 +253,7 @@ with st.sidebar: mode = st.selectbox('对话模式', modes, index, on_change=on_mode_change, key='mode') - with st.expander('模型配置', '知识' not in mode): + with st.expander('模型配置', not use_kb_mode(mode)): with st.form('model_config'): index = 0 try: @@ -425,9 +262,8 @@ with st.sidebar: pass llm_model = st.selectbox('LLM模型', llm_model_dict_list, index) - no_remote_model = st.checkbox('加载本地模型', False) use_ptuning_v2 = st.checkbox('使用p-tuning-v2微调过的模型', False) - use_lora = st.checkbox('使用lora微调的权重', False) + try: index = embedding_model_dict_list.index(EMBEDDING_MODEL) except: @@ -437,44 +273,52 @@ with st.sidebar: btn_load_model = st.form_submit_button('重新加载模型') if btn_load_model: - local_doc_qa = load_model(llm_model, embedding_model) + local_doc_qa = load_model(llm_model, embedding_model, use_ptuning_v2) - if mode in ['知识库问答', '知识库测试']: + history_len = st.slider( + "LLM对话轮数", 1, 50, LLM_HISTORY_LEN) + + if use_kb_mode(mode): vs_list = get_vs_list() vs_list.remove('新建知识库') def on_new_kb(): name = st.session_state.kb_name - if name in vs_list: - st.error(f'名为“{name}”的知识库已存在。') + if not name: + st.sidebar.error(f'新建知识库名称不能为空!') + elif name in vs_list: + st.sidebar.error(f'名为“{name}”的知识库已存在。') else: - vs_list.append(name) st.session_state.vs_path = name + st.session_state.kb_name = '' + new_kb_dir = os.path.join(KB_ROOT_PATH, name) + if not os.path.exists(new_kb_dir): + os.makedirs(new_kb_dir) + st.sidebar.success(f'名为“{name}”的知识库创建成功,您可以开始添加文件。') def on_vs_change(): - robot_say(f'已加载知识库: {st.session_state.vs_path}') + chat_box.robot_say(f'已加载知识库: {st.session_state.vs_path}') with st.expander('知识库配置', True): cols = st.columns([12, 10]) kb_name = cols[0].text_input( - '新知识库名称', placeholder='新知识库名称', label_visibility='collapsed') - if 'kb_name' not in st.session_state: - st.session_state.kb_name = kb_name + '新知识库名称', placeholder='新知识库名称', label_visibility='collapsed', key='kb_name') cols[1].button('新建知识库', on_click=on_new_kb) + index = 0 + try: + index = vs_list.index(ST_CONFIG.default_kb) + except: + pass vs_path = st.selectbox( - '选择知识库', vs_list, on_change=on_vs_change, key='vs_path') + '选择知识库', vs_list, index, on_change=on_vs_change, key='vs_path') st.text('') score_threshold = st.slider( '知识相关度阈值', 0, 1000, VECTOR_SEARCH_SCORE_THRESHOLD) top_k = st.slider('向量匹配数量', 1, 20, VECTOR_SEARCH_TOP_K) - history_len = st.slider( - 'LLM对话轮数', 1, 50, LLM_HISTORY_LEN) # 也许要跟知识库分开设置 - # local_doc_qa.llm.set_history_len(history_len) chunk_conent = st.checkbox('启用上下文关联', False) - st.text('') - # chunk_conent = st.checkbox('分割文本', True) # 知识库文本分割入库 chunk_size = st.slider('上下文关联长度', 1, 1000, CHUNK_SIZE) + st.text('') sentence_size = st.slider('文本入库分句长度限制', 1, 1000, SENTENCE_SIZE) files = st.file_uploader('上传知识文件', ['docx', 'txt', 'md', 'csv', 'xlsx', 'pdf'], @@ -487,56 +331,61 @@ with st.sidebar: with open(file, 'wb') as fp: fp.write(f.getvalue()) file_list.append(TempFile(file)) - _, _, history = load_vector_store( + _, _, history = get_vector_store( vs_path, file_list, sentence_size, [], None, None) st.session_state.files = [] -# main body -chat_box = st.empty() +# load model after params rendered +with st.spinner(f"正在加载模型({llm_model} + {embedding_model}),请耐心等候..."): + local_doc_qa = load_model( + llm_model, + embedding_model, + use_ptuning_v2, + ) + local_doc_qa.llm_model_chain.history_len = history_len + if use_kb_mode(mode): + local_doc_qa.chunk_conent = chunk_conent + local_doc_qa.chunk_size = chunk_size + # local_doc_qa.llm_model_chain.temperature = temperature # 这样设置temperature似乎不起作用 + st.session_state.local_doc_qa = local_doc_qa -with st.form('my_form', clear_on_submit=True): +# input form +with st.form("my_form", clear_on_submit=True): cols = st.columns([8, 1]) - question = cols[0].text_input( + question = cols[0].text_area( 'temp', key='input_question', label_visibility='collapsed') - def on_send(): - q = st.session_state.input_question - if q: - user_say(q) + if cols[1].form_submit_button("发送"): + chat_box.user_say(question) + history = [] + if mode == "LLM 对话": + chat_box.robot_say("正在思考...") + chat_box.output_messages() + for history, _ in answer(question, + history=[], + mode=mode): + chat_box.update_last_box_text(history[-1][-1]) + elif use_kb_mode(mode): + chat_box.robot_say(f"正在查询 [{vs_path}] ...") + chat_box.output_messages() + for history, _ in answer(question, + vs_path=os.path.join( + KB_ROOT_PATH, vs_path, 'vector_store'), + history=[], + mode=mode, + score_threshold=score_threshold, + vector_search_top_k=top_k, + chunk_conent=chunk_conent, + chunk_size=chunk_size): + chat_box.update_last_box_text(history[-1][-1]) + else: + chat_box.robot_say(f"正在执行Bing搜索...") + chat_box.output_messages() + for history, _ in answer(question, + history=[], + mode=mode): + chat_box.update_last_box_text(history[-1][-1]) - if mode == 'LLM 对话': - robot_say('正在思考...') - last_response = output_messages() - for history, _ in answer(q, - history=[], - mode=mode): - last_response.markdown( - format_md(history[-1][-1], False), - unsafe_allow_html=True - ) - elif use_kb_mode(mode): - robot_say('正在思考...', vs_path) - last_response = output_messages() - for history, _ in answer(q, - vs_path=os.path.join( - KB_ROOT_PATH, vs_path, "vector_store"), - history=[], - mode=mode, - score_threshold=score_threshold, - vector_search_top_k=top_k, - chunk_conent=chunk_conent, - chunk_size=chunk_size): - last_response.markdown( - format_md(history[-1][-1], False, 'ligreen'), - unsafe_allow_html=True - ) - else: - robot_say('正在思考...') - last_response = output_messages() - st.session_state['history'][-1]['content'] = history[-1][-1] - submit = cols[1].form_submit_button('发送', on_click=on_send) - -output_messages() - -# st.write(st.session_state['history']) +# st.write(chat_box.history) +chat_box.output_messages()