From a03b8d330d09942161d8f80a6f18ee772520f3a2 Mon Sep 17 00:00:00 2001 From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com> Date: Wed, 13 Sep 2023 08:43:11 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8D=87=E7=BA=A7=E5=88=B0langchain=3D=3D0.0.2?= =?UTF-8?q?87,fschat=3D=3D0.2.28;=E5=A4=84=E7=90=86langchain.Milvus=20bug(?= =?UTF-8?q?#10492)=20(#1454)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 修复milvus_kb_service中一些bug,添加文档后将数据同步到数据库 * 升级到langchain==0.0.287,fschat==0.2.28;处理langchain.Milvus bug(#10492) * 修复切换模型BUG: 从在线API切换模型时出错 --- configs/server_config.py.example | 1 + requirements.txt | 4 +- requirements_api.txt | 4 +- .../kb_service/milvus_kb_service.py | 6 ++- server/knowledge_base/utils.py | 38 ++++++------------- startup.py | 2 + tests/api/test_kb_api_request.py | 2 +- tests/api/test_llm_api.py | 5 ++- webui_pages/dialogue/dialogue.py | 4 +- 9 files changed, 30 insertions(+), 36 deletions(-) diff --git a/configs/server_config.py.example b/configs/server_config.py.example index 2157071..0fdd492 100644 --- a/configs/server_config.py.example +++ b/configs/server_config.py.example @@ -59,6 +59,7 @@ FSCHAT_MODEL_WORKERS = { # "limit_worker_concurrency": 5, # "stream_interval": 2, # "no_register": False, + # "embed_in_truncate": False, }, "baichuan-7b": { # 使用default中的IP和端口 "device": "cpu", diff --git a/requirements.txt b/requirements.txt index 7897232..8badee0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ -langchain==0.0.266 +langchain==0.0.287 +fschat[model_worker]==0.2.28 openai zhipuai sentence_transformers -fschat==0.2.24 transformers>=4.31.0 torch~=2.0.0 fastapi~=0.99.1 diff --git a/requirements_api.txt b/requirements_api.txt index bdecf3c..fa4cf1e 100644 --- a/requirements_api.txt +++ b/requirements_api.txt @@ -1,7 +1,7 @@ -langchain==0.0.266 +langchain==0.0.287 +fschat[model_worker]==0.2.28 openai sentence_transformers -fschat==0.2.24 transformers>=4.31.0 torch~=2.0.0 
fastapi~=0.99.1 diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index f45a8d9..5ca425b 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -68,10 +68,14 @@ class MilvusKBService(KBService): return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k)) def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: - # TODO: workaround for bug #10492 in langchain==0.0.286 + # TODO: workaround for bug #10492 in langchain for doc in docs: + for k, v in doc.metadata.items(): + doc.metadata[k] = str(v) for field in self.milvus.fields: doc.metadata.setdefault(field, "") + doc.metadata.pop(self.milvus._text_field, None) + doc.metadata.pop(self.milvus._vector_field, None) ids = self.milvus.add_documents(docs) doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)] diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 906da4d..f822035 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -82,16 +82,16 @@ SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist] class CustomJSONLoader(langchain.document_loaders.JSONLoader): ''' - langchain的JSONLoader需要jq,在win上使用不便,进行替代。 + langchain的JSONLoader需要jq,在win上使用不便,进行替代。针对langchain==0.0.286 ''' def __init__( - self, - file_path: Union[str, Path], - content_key: Optional[str] = None, - metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None, - text_content: bool = True, - json_lines: bool = False, + self, + file_path: Union[str, Path], + content_key: Optional[str] = None, + metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None, + text_content: bool = True, + json_lines: bool = False, ): """Initialize the JSONLoader. 
@@ -113,21 +113,6 @@ class CustomJSONLoader(langchain.document_loaders.JSONLoader): self._text_content = text_content self._json_lines = json_lines - # TODO: langchain's JSONLoader.load has a encoding bug, raise gbk encoding error on windows. - # This is a workaround for langchain==0.0.266. I have make a pr(#9785) to langchain, it should be deleted after langchain upgraded. - def load(self) -> List[Document]: - """Load and return documents from the JSON file.""" - docs: List[Document] = [] - if self._json_lines: - with self.file_path.open(encoding="utf-8") as f: - for line in f: - line = line.strip() - if line: - self._parse(line, docs) - else: - self._parse(self.file_path.read_text(encoding="utf-8"), docs) - return docs - def _parse(self, content: str, docs: List[Document]) -> None: """Convert given content to documents.""" data = json.loads(content) @@ -137,13 +122,14 @@ class CustomJSONLoader(langchain.document_loaders.JSONLoader): # and prevent the user from getting a cryptic error later on. 
if self._content_key is not None: self._validate_content_key(data) + if self._metadata_func is not None: + self._validate_metadata_func(data) for i, sample in enumerate(data, len(docs) + 1): - metadata = dict( - source=str(self.file_path), - seq_num=i, + text = self._get_text(sample=sample) + metadata = self._get_metadata( + sample=sample, source=str(self.file_path), seq_num=i ) - text = self._get_text(sample=sample, metadata=metadata) docs.append(Document(page_content=text, metadata=metadata)) diff --git a/startup.py b/startup.py index ad7cc4a..d045849 100644 --- a/startup.py +++ b/startup.py @@ -88,6 +88,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI: args.limit_worker_concurrency = 5 args.stream_interval = 2 args.no_register = False + args.embed_in_truncate = False for k, v in kwargs.items(): setattr(args, k, v) @@ -148,6 +149,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI: awq_config=awq_config, stream_interval=args.stream_interval, conv_template=args.conv_template, + embed_in_truncate=args.embed_in_truncate, ) sys.modules["fastchat.serve.model_worker"].args = args sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config diff --git a/tests/api/test_kb_api_request.py b/tests/api/test_kb_api_request.py index 3b067b3..8645528 100644 --- a/tests/api/test_kb_api_request.py +++ b/tests/api/test_kb_api_request.py @@ -14,7 +14,7 @@ from pprint import pprint api_base_url = api_address() -api: ApiRequest = ApiRequest(api_base_url, no_remote_api=True) +api: ApiRequest = ApiRequest(api_base_url, no_remote_api=False) kb = "kb_for_api_test" diff --git a/tests/api/test_llm_api.py b/tests/api/test_llm_api.py index 2c5d039..af5ced8 100644 --- a/tests/api/test_llm_api.py +++ b/tests/api/test_llm_api.py @@ -7,7 +7,7 @@ root_path = Path(__file__).parent.parent.parent sys.path.append(str(root_path)) from configs.server_config import FSCHAT_MODEL_WORKERS from configs.model_config import LLM_MODEL, 
llm_model_dict -from server.utils import api_address +from server.utils import api_address, get_model_worker_config from pprint import pprint import random @@ -65,7 +65,8 @@ def test_change_model(api="/llm_model/change"): assert len(availabel_new_models) > 0 print(availabel_new_models) - model_name = random.choice(running_models) + local_models = [x for x in running_models if not get_model_worker_config(x).get("online_api")] + model_name = random.choice(local_models) new_model_name = random.choice(availabel_new_models) print(f"\n尝试将模型从 {model_name} 切换到 {new_model_name}") r = requests.post(url, json={"model_name": model_name, "new_model_name": new_model_name}) diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index 7d81d14..ec31093 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -80,7 +80,7 @@ def dialogue_page(api: ApiRequest): if x in config_models: config_models.remove(x) llm_models = running_models + config_models - cur_model = st.session_state.get("prev_llm_model", LLM_MODEL) + cur_model = st.session_state.get("cur_llm_model", LLM_MODEL) index = llm_models.index(cur_model) llm_model = st.selectbox("选择LLM模型:", llm_models, @@ -93,7 +93,7 @@ def dialogue_page(api: ApiRequest): and not get_model_worker_config(llm_model).get("online_api")): with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"): r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model) - st.session_state["prev_llm_model"] = llm_model + st.session_state["cur_llm_model"] = llm_model history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)