* Fix several bugs in milvus_kb_service; sync data to the database after adding documents
* Upgrade to langchain==0.0.287 and fschat==0.2.28; work around the langchain.Milvus bug (#10492)
* Fix a model-switching bug: an error occurred when switching models away from an online API
This commit is contained in:
parent efd6d4a251
commit a03b8d330d
@@ -59,6 +59,7 @@ FSCHAT_MODEL_WORKERS = {
         # "limit_worker_concurrency": 5,
         # "stream_interval": 2,
         # "no_register": False,
+        # "embed_in_truncate": False,
     },
     "baichuan-7b": {  # uses the IP and port from "default"
         "device": "cpu",
@@ -1,8 +1,8 @@
-langchain==0.0.266
+langchain==0.0.287
+fschat[model_worker]==0.2.28
 openai
 zhipuai
 sentence_transformers
-fschat==0.2.24
 transformers>=4.31.0
 torch~=2.0.0
 fastapi~=0.99.1
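Note: the standalone fschat==0.2.24 pin is replaced by fschat[model_worker]==0.2.28, which also pulls in the model-worker extras that the FastChat worker startup depends on. A quick post-upgrade sanity check one could run (a sketch, assuming both packages are installed in the current environment):

# Sketch: verify the upgraded pins; distribution names as in requirements.
from importlib.metadata import version

assert version("langchain") == "0.0.287"
assert version("fschat") == "0.2.28"
print("dependency pins OK")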
@@ -1,7 +1,7 @@
-langchain==0.0.266
+langchain==0.0.287
+fschat[model_worker]==0.2.28
 openai
 sentence_transformers
-fschat==0.2.24
 transformers>=4.31.0
 torch~=2.0.0
 fastapi~=0.99.1
@@ -68,10 +68,14 @@ class MilvusKBService(KBService):
         return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))
 
     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
-        # TODO: workaround for bug #10492 in langchain==0.0.286
+        # TODO: workaround for bug #10492 in langchain
         for doc in docs:
+            for k, v in doc.metadata.items():
+                doc.metadata[k] = str(v)
             for field in self.milvus.fields:
                 doc.metadata.setdefault(field, "")
+            doc.metadata.pop(self.milvus._text_field, None)
+            doc.metadata.pop(self.milvus._vector_field, None)
 
         ids = self.milvus.add_documents(docs)
         doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
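For context, langchain's Milvus wrapper at this version builds insert rows strictly from the collection's declared fields, so documents whose metadata is missing a field, carries non-string values, or collides with the text/vector columns fail to insert (bug #10492). A minimal sketch of the same normalization outside the service class; the field names here are illustrative:

# Sketch of the #10492 workaround; "text"/"vector" field names are illustrative.
from typing import Dict, List

def normalize_metadata(metadata: Dict, fields: List[str],
                       text_field: str = "text",
                       vector_field: str = "vector") -> Dict:
    # Coerce every value to str so it matches the VARCHAR fields Milvus declares.
    metadata = {k: str(v) for k, v in metadata.items()}
    # Ensure every declared field is present; rows with missing keys are rejected.
    for field in fields:
        metadata.setdefault(field, "")
    # The text and vector columns are filled by the wrapper itself,
    # so they must not also appear in the metadata dict.
    metadata.pop(text_field, None)
    metadata.pop(vector_field, None)
    return metadata

print(normalize_metadata({"page": 3}, ["source", "page"]))
# {'page': '3', 'source': ''}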
@@ -82,16 +82,16 @@ SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
 
 class CustomJSONLoader(langchain.document_loaders.JSONLoader):
     '''
-    A replacement for langchain's JSONLoader, which requires jq and is awkward to use on Windows.
+    A replacement for langchain's JSONLoader, which requires jq and is awkward to use on Windows. Written against langchain==0.0.286.
     '''
 
     def __init__(
             self,
             file_path: Union[str, Path],
             content_key: Optional[str] = None,
             metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
             text_content: bool = True,
             json_lines: bool = False,
    ):
        """Initialize the JSONLoader.
 
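A hypothetical usage sketch of this loader; the module path assumes this repo's layout, and the file name and content_key are illustrative, not taken from the diff:

# Hypothetical usage; module path, file name and content_key are assumptions.
from server.knowledge_base.utils import CustomJSONLoader

loader = CustomJSONLoader("samples/faq.jsonl", content_key="answer", json_lines=True)
for doc in loader.load():  # load() is now inherited from langchain's JSONLoader
    print(doc.metadata["seq_num"], doc.page_content[:40])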
@@ -113,21 +113,6 @@ class CustomJSONLoader(langchain.document_loaders.JSONLoader):
         self._text_content = text_content
         self._json_lines = json_lines
 
-    # TODO: langchain's JSONLoader.load has an encoding bug that raises a gbk encoding error on Windows.
-    # This is a workaround for langchain==0.0.266. I have made a PR (#9785) to langchain; it should be deleted after langchain is upgraded.
-    def load(self) -> List[Document]:
-        """Load and return documents from the JSON file."""
-        docs: List[Document] = []
-        if self._json_lines:
-            with self.file_path.open(encoding="utf-8") as f:
-                for line in f:
-                    line = line.strip()
-                    if line:
-                        self._parse(line, docs)
-        else:
-            self._parse(self.file_path.read_text(encoding="utf-8"), docs)
-        return docs
-
     def _parse(self, content: str, docs: List[Document]) -> None:
         """Convert given content to documents."""
         data = json.loads(content)
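The load() override can be dropped here presumably because the upgrade picks up the utf-8 encoding fix the deleted comment refers to (langchain PR #9785), making the inherited JSONLoader.load safe on Windows.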
@@ -137,13 +122,14 @@ class CustomJSONLoader(langchain.document_loaders.JSONLoader):
         # and prevent the user from getting a cryptic error later on.
         if self._content_key is not None:
             self._validate_content_key(data)
+        if self._metadata_func is not None:
+            self._validate_metadata_func(data)
 
         for i, sample in enumerate(data, len(docs) + 1):
-            metadata = dict(
-                source=str(self.file_path),
-                seq_num=i,
+            text = self._get_text(sample=sample)
+            metadata = self._get_metadata(
+                sample=sample, source=str(self.file_path), seq_num=i
             )
-            text = self._get_text(sample=sample, metadata=metadata)
             docs.append(Document(page_content=text, metadata=metadata))
 
 
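A rough sketch of what the new _get_metadata call is doing, inferred from this diff rather than from the upstream langchain source; the helper name and flow follow the hunk above:

# Sketch of the _get_metadata flow inferred from the diff; not the upstream code.
from typing import Callable, Dict, Optional

def get_metadata_sketch(sample: Dict, source: str, seq_num: int,
                        metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None) -> Dict:
    # Base metadata that JSONLoader attaches to every Document.
    metadata = {"source": source, "seq_num": seq_num}
    # In langchain>=0.0.286 any user-supplied metadata_func is applied
    # inside a helper instead of directly in _parse.
    if metadata_func is not None:
        metadata = metadata_func(sample, metadata)
    return metadata

print(get_metadata_sketch({"title": "hello"}, "faq.json", 1))
# {'source': 'faq.json', 'seq_num': 1}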
@@ -88,6 +88,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
     args.limit_worker_concurrency = 5
     args.stream_interval = 2
     args.no_register = False
+    args.embed_in_truncate = False
 
     for k, v in kwargs.items():
         setattr(args, k, v)
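The new default only takes effect when the caller does not override it, because every keyword passed to create_model_worker_app wins via the setattr loop. A self-contained sketch of that pattern:

# Sketch of the defaults-then-overrides pattern used above.
from argparse import Namespace

def apply_overrides(args: Namespace, **kwargs) -> Namespace:
    # Mirrors the loop above: caller-supplied keywords replace the defaults.
    for k, v in kwargs.items():
        setattr(args, k, v)
    return args

defaults = Namespace(stream_interval=2, no_register=False, embed_in_truncate=False)
print(apply_overrides(defaults, embed_in_truncate=True).embed_in_truncate)  # True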
@@ -148,6 +149,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
         awq_config=awq_config,
         stream_interval=args.stream_interval,
         conv_template=args.conv_template,
+        embed_in_truncate=args.embed_in_truncate,
     )
     sys.modules["fastchat.serve.model_worker"].args = args
     sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
@@ -14,7 +14,7 @@ from pprint import pprint
 
 
 api_base_url = api_address()
-api: ApiRequest = ApiRequest(api_base_url, no_remote_api=True)
+api: ApiRequest = ApiRequest(api_base_url, no_remote_api=False)
 
 
 kb = "kb_for_api_test"
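With no_remote_api=False, the knowledge-base tests presumably exercise the running HTTP server at api_address() instead of calling the handlers in-process, which matches the commit's goal of syncing added documents through to the database.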
@@ -7,7 +7,7 @@ root_path = Path(__file__).parent.parent.parent
 sys.path.append(str(root_path))
 from configs.server_config import FSCHAT_MODEL_WORKERS
 from configs.model_config import LLM_MODEL, llm_model_dict
-from server.utils import api_address
+from server.utils import api_address, get_model_worker_config
 
 from pprint import pprint
 import random
@@ -65,7 +65,8 @@ def test_change_model(api="/llm_model/change"):
     assert len(availabel_new_models) > 0
     print(availabel_new_models)
 
-    model_name = random.choice(running_models)
+    local_models = [x for x in running_models if not get_model_worker_config(x).get("online_api")]
+    model_name = random.choice(local_models)
     new_model_name = random.choice(availabel_new_models)
     print(f"\nTrying to switch the model from {model_name} to {new_model_name}")
     r = requests.post(url, json={"model_name": model_name, "new_model_name": new_model_name})
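This stops the test from picking an online-API worker as the switch-out target, which is the model-switching error named in the commit message. A self-contained sketch of the filter, with a stand-in for get_model_worker_config:

# Stand-in config; real values come from configs via get_model_worker_config.
WORKERS = {
    "chatglm2-6b": {"device": "cuda"},
    "zhipu-api": {"online_api": True},
}

def get_worker_config(name: str) -> dict:
    return WORKERS.get(name, {})

running_models = list(WORKERS)
local_models = [x for x in running_models if not get_worker_config(x).get("online_api")]
print(local_models)  # ['chatglm2-6b'] -- only local workers are valid switch sources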
@@ -80,7 +80,7 @@ def dialogue_page(api: ApiRequest):
         if x in config_models:
             config_models.remove(x)
     llm_models = running_models + config_models
-    cur_model = st.session_state.get("prev_llm_model", LLM_MODEL)
+    cur_model = st.session_state.get("cur_llm_model", LLM_MODEL)
     index = llm_models.index(cur_model)
     llm_model = st.selectbox("Select LLM model:",
                              llm_models,
@@ -93,7 +93,7 @@ def dialogue_page(api: ApiRequest):
             and not get_model_worker_config(llm_model).get("online_api")):
         with st.spinner(f"Loading model: {llm_model}. Please do not interact with or refresh the page."):
             r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
-    st.session_state["prev_llm_model"] = llm_model
+    st.session_state["cur_llm_model"] = llm_model
 
     history_len = st.number_input("Number of history turns:", 0, 10, HISTORY_LEN)
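Tracking the active model under its own cur_llm_model key keeps the selectbox default from snapping back to the previous model on rerun. A minimal sketch of the corrected bookkeeping, with a plain dict standing in for st.session_state:

# Sketch; a dict stands in for st.session_state, model names are illustrative.
session_state = {}

def select_model(llm_model: str):
    prev = session_state.get("prev_llm_model")
    if llm_model != prev:
        pass  # here the real page calls api.change_llm_model(prev, llm_model)
    # The fix: remember the active model under its own key, so the
    # selectbox default (read from "cur_llm_model") survives reruns.
    session_state["cur_llm_model"] = llm_model

select_model("baichuan-7b")
print(session_state.get("cur_llm_model", "chatglm2-6b"))  # baichuan-7b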