Upgrade to langchain==0.0.287 and fschat==0.2.28; work around langchain.Milvus bug (#10492) (#1454)

* Fix several bugs in milvus_kb_service; sync data to the database after adding documents
* Upgrade to langchain==0.0.287 and fschat==0.2.28; work around langchain.Milvus bug (#10492)
* Fix model-switching bug: error when switching models away from an online API
liunux4odoo 2023-09-13 08:43:11 +08:00 committed by GitHub
parent efd6d4a251
commit a03b8d330d
9 changed files with 30 additions and 36 deletions

View File

@@ -59,6 +59,7 @@ FSCHAT_MODEL_WORKERS = {
         # "limit_worker_concurrency": 5,
         # "stream_interval": 2,
         # "no_register": False,
+        # "embed_in_truncate": False,
     },
     "baichuan-7b": {  # use the IP and port from the default entry
         "device": "cpu",

View File

@@ -1,8 +1,8 @@
-langchain==0.0.266
+langchain==0.0.287
+fschat[model_worker]==0.2.28
 openai
 zhipuai
 sentence_transformers
-fschat==0.2.24
 transformers>=4.31.0
 torch~=2.0.0
 fastapi~=0.99.1
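
A quick way to confirm the upgraded pins actually took effect after installation; a minimal sketch using only the standard library:

from importlib.metadata import version

# Verify the pinned versions from the requirements files are what's installed.
assert version("langchain") == "0.0.287"
assert version("fschat") == "0.2.28"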

View File

@@ -1,7 +1,7 @@
-langchain==0.0.266
+langchain==0.0.287
+fschat[model_worker]==0.2.28
 openai
 sentence_transformers
-fschat==0.2.24
 transformers>=4.31.0
 torch~=2.0.0
 fastapi~=0.99.1

View File

@@ -68,10 +68,14 @@ class MilvusKBService(KBService):
         return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))

     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
-        # TODO: workaround for bug #10492 in langchain==0.0.286
+        # TODO: workaround for bug #10492 in langchain
         for doc in docs:
             for k, v in doc.metadata.items():
                 doc.metadata[k] = str(v)
+            for field in self.milvus.fields:
+                doc.metadata.setdefault(field, "")
+            doc.metadata.pop(self.milvus._text_field, None)
+            doc.metadata.pop(self.milvus._vector_field, None)
         ids = self.milvus.add_documents(docs)
         doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
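
The workaround mirrors what langchain's Milvus wrapper expects of document metadata. A standalone sketch of the same normalization, with a hypothetical helper name:

from typing import Dict, List

def normalize_milvus_metadata(metadata: Dict, fields: List[str],
                              text_field: str, vector_field: str) -> Dict:
    # Milvus inserts fail if a document is missing a collection field, so
    # stringify every value, backfill missing fields with "", and drop the
    # reserved text/vector fields that langchain manages itself (#10492).
    meta = {k: str(v) for k, v in metadata.items()}
    for field in fields:
        meta.setdefault(field, "")
    meta.pop(text_field, None)
    meta.pop(vector_field, None)
    return meta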

View File

@@ -82,16 +82,16 @@ SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
 class CustomJSONLoader(langchain.document_loaders.JSONLoader):
     '''
-    langchain's JSONLoader requires jq, which is inconvenient on Windows; this class replaces it
+    langchain's JSONLoader requires jq, which is inconvenient on Windows; this class replaces it (targets langchain==0.0.286)
     '''

     def __init__(
-            self,
-            file_path: Union[str, Path],
-            content_key: Optional[str] = None,
-            metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
-            text_content: bool = True,
-            json_lines: bool = False,
+        self,
+        file_path: Union[str, Path],
+        content_key: Optional[str] = None,
+        metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
+        text_content: bool = True,
+        json_lines: bool = False,
     ):
         """Initialize the JSONLoader.
@@ -113,21 +113,6 @@ class CustomJSONLoader(langchain.document_loaders.JSONLoader):
         self._text_content = text_content
         self._json_lines = json_lines

-    # TODO: langchain's JSONLoader.load has a encoding bug, raise gbk encoding error on windows.
-    # This is a workaround for langchain==0.0.266. I have make a pr(#9785) to langchain, it should be deleted after langchain upgraded.
-    def load(self) -> List[Document]:
-        """Load and return documents from the JSON file."""
-        docs: List[Document] = []
-        if self._json_lines:
-            with self.file_path.open(encoding="utf-8") as f:
-                for line in f:
-                    line = line.strip()
-                    if line:
-                        self._parse(line, docs)
-        else:
-            self._parse(self.file_path.read_text(encoding="utf-8"), docs)
-        return docs
-
     def _parse(self, content: str, docs: List[Document]) -> None:
         """Convert given content to documents."""
         data = json.loads(content)
@@ -137,13 +122,14 @@ class CustomJSONLoader(langchain.document_loaders.JSONLoader):
         # and prevent the user from getting a cryptic error later on.
         if self._content_key is not None:
             self._validate_content_key(data)
+        if self._metadata_func is not None:
+            self._validate_metadata_func(data)

         for i, sample in enumerate(data, len(docs) + 1):
-            metadata = dict(
-                source=str(self.file_path),
-                seq_num=i,
+            text = self._get_text(sample=sample)
+            metadata = self._get_metadata(
+                sample=sample, source=str(self.file_path), seq_num=i
             )
-            text = self._get_text(sample=sample, metadata=metadata)
             docs.append(Document(page_content=text, metadata=metadata))
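
For reference, a hedged usage sketch of CustomJSONLoader matching the signature above; the file name, content_key, and metadata field are hypothetical:

def add_title(sample: dict, metadata: dict) -> dict:
    # Callable[[Dict, Dict], Dict]: merge fields from each record into metadata.
    metadata["title"] = sample.get("title", "")
    return metadata

loader = CustomJSONLoader(
    "corpus.jsonl",
    content_key="text",        # record field used as page_content
    metadata_func=add_title,   # validated via _validate_metadata_func above
    json_lines=True,           # one JSON object per line
)
docs = loader.load()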

View File

@@ -88,6 +88,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
     args.limit_worker_concurrency = 5
     args.stream_interval = 2
     args.no_register = False
+    args.embed_in_truncate = False

     for k, v in kwargs.items():
         setattr(args, k, v)
@@ -148,6 +149,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
         awq_config=awq_config,
         stream_interval=args.stream_interval,
         conv_template=args.conv_template,
+        embed_in_truncate=args.embed_in_truncate,
     )
     sys.modules["fastchat.serve.model_worker"].args = args
     sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
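
Because every per-model setting flows through kwargs and setattr(args, k, v), callers can still override the new default. A sketch of an assumed call site (the kwargs shown are illustrative, not the actual ones used by the launcher):

# Hypothetical call site: kwargs override the defaults set on args above.
app = create_model_worker_app(
    log_level="INFO",
    embed_in_truncate=True,  # overrides args.embed_in_truncate = False
)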

View File

@@ -14,7 +14,7 @@ from pprint import pprint

 api_base_url = api_address()
-api: ApiRequest = ApiRequest(api_base_url, no_remote_api=True)
+api: ApiRequest = ApiRequest(api_base_url, no_remote_api=False)

 kb = "kb_for_api_test"

View File

@@ -7,7 +7,7 @@ root_path = Path(__file__).parent.parent.parent
 sys.path.append(str(root_path))
 from configs.server_config import FSCHAT_MODEL_WORKERS
 from configs.model_config import LLM_MODEL, llm_model_dict
-from server.utils import api_address
+from server.utils import api_address, get_model_worker_config
 from pprint import pprint

 import random
@@ -65,7 +65,8 @@ def test_change_model(api="/llm_model/change"):
     assert len(availabel_new_models) > 0
     print(availabel_new_models)

-    model_name = random.choice(running_models)
+    local_models = [x for x in running_models if not get_model_worker_config(x).get("online_api")]
+    model_name = random.choice(local_models)
     new_model_name = random.choice(availabel_new_models)
     print(f"\nTrying to switch model from {model_name} to {new_model_name}")
     r = requests.post(url, json={"model_name": model_name, "new_model_name": new_model_name})
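
The new filter only matters when an online-API worker is among the running models; a self-contained illustration of the same list comprehension with made-up config data:

running_models = ["chatglm2-6b", "zhipu-api"]        # hypothetical
configs = {"chatglm2-6b": {}, "zhipu-api": {"online_api": True}}

# Keep only locally served models as switch sources, as the test now does.
local_models = [m for m in running_models if not configs[m].get("online_api")]
assert local_models == ["chatglm2-6b"]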

View File

@@ -80,7 +80,7 @@ def dialogue_page(api: ApiRequest):
         if x in config_models:
             config_models.remove(x)
     llm_models = running_models + config_models
-    cur_model = st.session_state.get("prev_llm_model", LLM_MODEL)
+    cur_model = st.session_state.get("cur_llm_model", LLM_MODEL)
     index = llm_models.index(cur_model)
     llm_model = st.selectbox("Select an LLM model",
                              llm_models,
@@ -93,7 +93,7 @@ def dialogue_page(api: ApiRequest):
             and not get_model_worker_config(llm_model).get("online_api")):
         with st.spinner(f"Loading model {llm_model}; please do not interact with or refresh the page"):
             r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
-            st.session_state["prev_llm_model"] = llm_model
+            st.session_state["cur_llm_model"] = llm_model

     history_len = st.number_input("Number of history turns:", 0, 10, HISTORY_LEN)
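
The fix separates two session keys: "cur_llm_model" remembers the selectbox default across Streamlit reruns, while "prev_llm_model" is the model a switch starts from. A minimal sketch of the corrected flow, simplified from the page above (llm_models, api, and LLM_MODEL are assumed to be in scope):

import streamlit as st

cur_model = st.session_state.get("cur_llm_model", LLM_MODEL)  # survives reruns
llm_model = st.selectbox("Select an LLM model", llm_models,
                         index=llm_models.index(cur_model))
if llm_model != cur_model:
    api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
    st.session_state["cur_llm_model"] = llm_model             # was "prev_llm_model", the bug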