* Fix some bugs in milvus_kb_service; sync data to the database after adding documents
* Upgrade to langchain==0.0.287 and fschat==0.2.28; work around the langchain.Milvus bug (#10492)
* Fix a model-switching bug: switching models away from an online API raised an error
This commit is contained in:
parent efd6d4a251
commit a03b8d330d
@@ -59,6 +59,7 @@ FSCHAT_MODEL_WORKERS = {
         # "limit_worker_concurrency": 5,
         # "stream_interval": 2,
         # "no_register": False,
+        # "embed_in_truncate": False,
     },
     "baichuan-7b": { # 使用default中的IP和端口
         "device": "cpu",
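For orientation, each non-"default" entry in FSCHAT_MODEL_WORKERS (such as "baichuan-7b" above, which reuses the IP and port from "default") only overrides selected keys. A minimal sketch of the assumed merge semantics; the real lookup is server.utils.get_model_worker_config:

    from configs.server_config import FSCHAT_MODEL_WORKERS

    def merged_worker_config(model_name: str) -> dict:
        # Start from the shared defaults (host/port, concurrency, ...),
        # then apply the per-model overrides such as "device": "cpu".
        config = dict(FSCHAT_MODEL_WORKERS["default"])
        config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
        return config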
@@ -1,8 +1,8 @@
-langchain==0.0.266
+langchain==0.0.287
+fschat[model_worker]==0.2.28
 openai
 zhipuai
 sentence_transformers
-fschat==0.2.24
 transformers>=4.31.0
 torch~=2.0.0
 fastapi~=0.99.1
@@ -1,7 +1,7 @@
-langchain==0.0.266
+langchain==0.0.287
+fschat[model_worker]==0.2.28
 openai
 sentence_transformers
-fschat==0.2.24
 transformers>=4.31.0
 torch~=2.0.0
 fastapi~=0.99.1
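A quick sanity check that the new pins are what is actually installed (a sketch; note that the fschat package imports as the fastchat module):

    import langchain
    import fastchat  # pip package "fschat" installs the fastchat module

    assert langchain.__version__ == "0.0.287", langchain.__version__
    assert fastchat.__version__ == "0.2.28", fastchat.__version__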
@@ -68,10 +68,14 @@ class MilvusKBService(KBService):
         return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))
 
     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
-        # TODO: workaround for bug #10492 in langchain==0.0.286
+        # TODO: workaround for bug #10492 in langchain
         for doc in docs:
             for k, v in doc.metadata.items():
                 doc.metadata[k] = str(v)
+            for field in self.milvus.fields:
+                doc.metadata.setdefault(field, "")
+            doc.metadata.pop(self.milvus._text_field, None)
+            doc.metadata.pop(self.milvus._vector_field, None)
 
         ids = self.milvus.add_documents(docs)
         doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
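As a standalone sketch of what the workaround above guards against: langchain's Milvus wrapper (bug #10492) derives the collection schema from document metadata, so every inserted document must carry every scalar field with a schema-compatible value. The helper name is invented; the attribute names mirror the wrapper fields used above:

    from typing import List
    from langchain.docstore.document import Document

    def normalize_for_milvus(docs: List[Document], fields: List[str],
                             text_field: str, vector_field: str) -> None:
        for doc in docs:
            # Cast every value to str so mixed types cannot violate the schema.
            for k, v in doc.metadata.items():
                doc.metadata[k] = str(v)
            # Every document must supply every field the collection defines.
            for field in fields:
                doc.metadata.setdefault(field, "")
            # The text and vector fields are populated by the wrapper itself.
            doc.metadata.pop(text_field, None)
            doc.metadata.pop(vector_field, None)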
@@ -82,16 +82,16 @@ SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
 
 class CustomJSONLoader(langchain.document_loaders.JSONLoader):
     '''
-    langchain的JSONLoader需要jq,在win上使用不便,进行替代。
+    langchain的JSONLoader需要jq,在win上使用不便,进行替代。针对langchain==0.0.286
     '''
 
     def __init__(
-        self,
-        file_path: Union[str, Path],
-        content_key: Optional[str] = None,
-        metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
-        text_content: bool = True,
-        json_lines: bool = False,
+            self,
+            file_path: Union[str, Path],
+            content_key: Optional[str] = None,
+            metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
+            text_content: bool = True,
+            json_lines: bool = False,
     ):
         """Initialize the JSONLoader.
 
@@ -113,21 +113,6 @@ class CustomJSONLoader(langchain.document_loaders.JSONLoader):
         self._text_content = text_content
         self._json_lines = json_lines
 
-    # TODO: langchain's JSONLoader.load has a encoding bug, raise gbk encoding error on windows.
-    # This is a workaround for langchain==0.0.266. I have make a pr(#9785) to langchain, it should be deleted after langchain upgraded.
-    def load(self) -> List[Document]:
-        """Load and return documents from the JSON file."""
-        docs: List[Document] = []
-        if self._json_lines:
-            with self.file_path.open(encoding="utf-8") as f:
-                for line in f:
-                    line = line.strip()
-                    if line:
-                        self._parse(line, docs)
-        else:
-            self._parse(self.file_path.read_text(encoding="utf-8"), docs)
-        return docs
-
     def _parse(self, content: str, docs: List[Document]) -> None:
         """Convert given content to documents."""
         data = json.loads(content)
@@ -137,13 +122,14 @@ class CustomJSONLoader(langchain.document_loaders.JSONLoader):
         # and prevent the user from getting a cryptic error later on.
         if self._content_key is not None:
             self._validate_content_key(data)
+        if self._metadata_func is not None:
+            self._validate_metadata_func(data)
 
         for i, sample in enumerate(data, len(docs) + 1):
-            metadata = dict(
-                source=str(self.file_path),
-                seq_num=i,
-            )
-            text = self._get_text(sample=sample)
+            metadata = self._get_metadata(
+                sample=sample, source=str(self.file_path), seq_num=i
+            )
+            text = self._get_text(sample=sample, metadata=metadata)
             docs.append(Document(page_content=text, metadata=metadata))
 
 
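With the switch to _get_metadata, a user-supplied metadata_func is now validated and applied per sample. A hypothetical usage sketch (the file name and JSON keys are invented):

    def collect_author(record: dict, metadata: dict) -> dict:
        # record is one parsed JSON sample; metadata holds source/seq_num.
        metadata["author"] = record.get("author", "unknown")
        return metadata

    loader = CustomJSONLoader("samples.json", content_key="text",
                              metadata_func=collect_author)
    docs = loader.load()  # each Document now carries an "author" entry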
@@ -88,6 +88,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
     args.limit_worker_concurrency = 5
     args.stream_interval = 2
     args.no_register = False
+    args.embed_in_truncate = False
 
     for k, v in kwargs.items():
         setattr(args, k, v)
@@ -148,6 +149,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
         awq_config=awq_config,
         stream_interval=args.stream_interval,
         conv_template=args.conv_template,
+        embed_in_truncate=args.embed_in_truncate,
     )
     sys.modules["fastchat.serve.model_worker"].args = args
     sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
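Since create_model_worker_app copies every keyword onto args via the setattr loop above, the new embed_in_truncate default can be overridden per worker. A hypothetical invocation (the kwargs here are illustrative; real values come from FSCHAT_MODEL_WORKERS entries):

    app = create_model_worker_app(
        log_level="INFO",
        model_names=["baichuan-7b"],
        device="cpu",
        embed_in_truncate=True,  # overrides the default set above
    )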
@@ -14,7 +14,7 @@ from pprint import pprint
 
 
 api_base_url = api_address()
-api: ApiRequest = ApiRequest(api_base_url, no_remote_api=True)
+api: ApiRequest = ApiRequest(api_base_url, no_remote_api=False)
 
 
 kb = "kb_for_api_test"
@@ -7,7 +7,7 @@ root_path = Path(__file__).parent.parent.parent
 sys.path.append(str(root_path))
 from configs.server_config import FSCHAT_MODEL_WORKERS
 from configs.model_config import LLM_MODEL, llm_model_dict
-from server.utils import api_address
+from server.utils import api_address, get_model_worker_config
 
 from pprint import pprint
 import random
@@ -65,7 +65,8 @@ def test_change_model(api="/llm_model/change"):
     assert len(availabel_new_models) > 0
     print(availabel_new_models)
 
-    model_name = random.choice(running_models)
+    local_models = [x for x in running_models if not get_model_worker_config(x).get("online_api")]
+    model_name = random.choice(local_models)
     new_model_name = random.choice(availabel_new_models)
     print(f"\n尝试将模型从 {model_name} 切换到 {new_model_name}")
     r = requests.post(url, json={"model_name": model_name, "new_model_name": new_model_name})
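The test now draws only from local models, since online-API entries have no local worker to unload. A sketch of the config shapes the filter relies on (keys and model names are illustrative):

    get_model_worker_config("chatglm2-6b")  # e.g. {"device": "cuda", ...} -> eligible
    get_model_worker_config("zhipu-api")    # e.g. {"online_api": True, ...} -> skipped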
@@ -80,7 +80,7 @@ def dialogue_page(api: ApiRequest):
         if x in config_models:
             config_models.remove(x)
     llm_models = running_models + config_models
-    cur_model = st.session_state.get("prev_llm_model", LLM_MODEL)
+    cur_model = st.session_state.get("cur_llm_model", LLM_MODEL)
     index = llm_models.index(cur_model)
     llm_model = st.selectbox("选择LLM模型:",
                              llm_models,
@@ -93,7 +93,7 @@ def dialogue_page(api: ApiRequest):
                 and not get_model_worker_config(llm_model).get("online_api")):
             with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
                 r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
-        st.session_state["prev_llm_model"] = llm_model
+        st.session_state["cur_llm_model"] = llm_model
 
     history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)
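Condensed, the session-state flow this fix establishes (a sketch, not the full page logic): cur_llm_model records what the selectbox should show on the next rerun, including online-API models, while prev_llm_model tracks the last locally loaded worker handed to change_llm_model:

    import streamlit as st

    def on_model_selected(api, llm_model: str):
        if (llm_model != st.session_state.get("prev_llm_model")
                and not get_model_worker_config(llm_model).get("online_api")):
            # release the old local worker, load the new one
            api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
        st.session_state["cur_llm_model"] = llm_model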