Merge branch 'dev' of github.com:chatchat-space/Langchain-Chatchat into dev
This commit is contained in:
commit
84b491b8b2
|
|
@ -5,7 +5,6 @@ fschat==0.2.20
|
||||||
transformers
|
transformers
|
||||||
torch~=2.0.0
|
torch~=2.0.0
|
||||||
fastapi~=0.99.1
|
fastapi~=0.99.1
|
||||||
fastapi-offline
|
|
||||||
nltk~=3.8.1
|
nltk~=3.8.1
|
||||||
uvicorn~=0.23.1
|
uvicorn~=0.23.1
|
||||||
starlette~=0.27.0
|
starlette~=0.27.0
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ fschat==0.2.20
|
||||||
transformers
|
transformers
|
||||||
torch~=2.0.0
|
torch~=2.0.0
|
||||||
fastapi~=0.99.1
|
fastapi~=0.99.1
|
||||||
fastapi-offline
|
|
||||||
nltk~=3.8.1
|
nltk~=3.8.1
|
||||||
uvicorn~=0.23.1
|
uvicorn~=0.23.1
|
||||||
starlette~=0.27.0
|
starlette~=0.27.0
|
||||||
|
|
|
||||||
|
|
@ -7,15 +7,17 @@ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||||
from configs.model_config import NLTK_DATA_PATH, OPEN_CROSS_DOMAIN
|
from configs.model_config import NLTK_DATA_PATH, OPEN_CROSS_DOMAIN
|
||||||
import argparse
|
import argparse
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from server.utils import FastAPIOffline as FastAPI
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from starlette.responses import RedirectResponse
|
from starlette.responses import RedirectResponse
|
||||||
from server.chat import (chat, knowledge_base_chat, openai_chat,
|
from server.chat import (chat, knowledge_base_chat, openai_chat,
|
||||||
search_engine_chat)
|
search_engine_chat)
|
||||||
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
|
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
|
||||||
from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
|
from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
|
||||||
update_doc, download_doc, recreate_vector_store)
|
update_doc, download_doc, recreate_vector_store,
|
||||||
from server.utils import BaseResponse, ListResponse
|
search_docs, DocumentWithScore)
|
||||||
|
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||||
|
|
||||||
|
|
@ -25,7 +27,8 @@ async def document():
|
||||||
|
|
||||||
|
|
||||||
def create_app():
|
def create_app():
|
||||||
app = FastAPI()
|
app = FastAPI(title="Langchain-Chatchat API Server")
|
||||||
|
MakeFastAPIOffline(app)
|
||||||
# Add CORS middleware to allow all origins
|
# Add CORS middleware to allow all origins
|
||||||
# 在config.py中设置OPEN_DOMAIN=True,允许跨域
|
# 在config.py中设置OPEN_DOMAIN=True,允许跨域
|
||||||
# set OPEN_DOMAIN=True in config.py to allow cross-domain
|
# set OPEN_DOMAIN=True in config.py to allow cross-domain
|
||||||
|
|
@ -83,6 +86,12 @@ def create_app():
|
||||||
summary="获取知识库内的文件列表"
|
summary="获取知识库内的文件列表"
|
||||||
)(list_docs)
|
)(list_docs)
|
||||||
|
|
||||||
|
app.post("/knowledge_base/search_docs",
|
||||||
|
tags=["Knowledge Base Management"],
|
||||||
|
response_model=List[DocumentWithScore],
|
||||||
|
summary="搜索知识库"
|
||||||
|
)(search_docs)
|
||||||
|
|
||||||
app.post("/knowledge_base/upload_doc",
|
app.post("/knowledge_base/upload_doc",
|
||||||
tags=["Knowledge Base Management"],
|
tags=["Knowledge Base Management"],
|
||||||
response_model=BaseResponse,
|
response_model=BaseResponse,
|
||||||
|
|
|
||||||
|
|
@ -1,26 +1,27 @@
|
||||||
from fastapi import Body, Request
|
from fastapi import Body, Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
|
from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
|
||||||
VECTOR_SEARCH_TOP_K)
|
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||||
from server.chat.utils import wrap_done
|
from server.chat.utils import wrap_done
|
||||||
from server.utils import BaseResponse
|
from server.utils import BaseResponse
|
||||||
from langchain.chat_models import ChatOpenAI
|
from langchain.chat_models import ChatOpenAI
|
||||||
from langchain import LLMChain
|
from langchain import LLMChain
|
||||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||||
from typing import AsyncIterable
|
from typing import AsyncIterable, List, Optional
|
||||||
import asyncio
|
import asyncio
|
||||||
from langchain.prompts.chat import ChatPromptTemplate
|
from langchain.prompts.chat import ChatPromptTemplate
|
||||||
from typing import List, Optional
|
|
||||||
from server.chat.utils import History
|
from server.chat.utils import History
|
||||||
from server.knowledge_base.kb_service.base import KBService, KBServiceFactory
|
from server.knowledge_base.kb_service.base import KBService, KBServiceFactory
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
from server.knowledge_base.kb_doc_api import search_docs
|
||||||
|
|
||||||
|
|
||||||
def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
|
def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
|
||||||
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||||
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
|
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
|
||||||
|
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
|
||||||
history: List[History] = Body([],
|
history: List[History] = Body([],
|
||||||
description="历史对话",
|
description="历史对话",
|
||||||
examples=[[
|
examples=[[
|
||||||
|
|
@ -53,7 +54,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
|
||||||
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
|
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||||
model_name=LLM_MODEL
|
model_name=LLM_MODEL
|
||||||
)
|
)
|
||||||
docs = kb.search_docs(query, top_k)
|
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
|
||||||
context = "\n".join([doc.page_content for doc in docs])
|
context = "\n".join([doc.page_content for doc in docs])
|
||||||
|
|
||||||
chat_prompt = ChatPromptTemplate.from_messages(
|
chat_prompt = ChatPromptTemplate.from_messages(
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,32 @@
|
||||||
import os
|
import os
|
||||||
import urllib
|
import urllib
|
||||||
from fastapi import File, Form, Body, Query, UploadFile
|
from fastapi import File, Form, Body, Query, UploadFile
|
||||||
from configs.model_config import DEFAULT_VS_TYPE, EMBEDDING_MODEL
|
from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||||
from server.utils import BaseResponse, ListResponse
|
from server.utils import BaseResponse, ListResponse
|
||||||
from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile
|
from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile
|
||||||
from fastapi.responses import StreamingResponse, FileResponse
|
from fastapi.responses import StreamingResponse, FileResponse
|
||||||
import json
|
import json
|
||||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||||
from typing import List
|
from typing import List, Dict
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentWithScore(Document):
|
||||||
|
score: float = None
|
||||||
|
|
||||||
|
|
||||||
|
def search_docs(query: str = Body(..., description="用户输入", examples=["你好"]),
|
||||||
|
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||||
|
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
|
||||||
|
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
|
||||||
|
) -> List[DocumentWithScore]:
|
||||||
|
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||||
|
if kb is None:
|
||||||
|
return {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}", "docs": []}
|
||||||
|
docs = kb.search_docs(query, top_k, score_threshold)
|
||||||
|
data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
async def list_docs(
|
async def list_docs(
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ from server.db.repository.knowledge_file_repository import (
|
||||||
list_docs_from_db, get_file_detail, delete_file_from_db
|
list_docs_from_db, get_file_detail, delete_file_from_db
|
||||||
)
|
)
|
||||||
|
|
||||||
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K,
|
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
||||||
EMBEDDING_DEVICE, EMBEDDING_MODEL)
|
EMBEDDING_DEVICE, EMBEDDING_MODEL)
|
||||||
from server.knowledge_base.utils import (
|
from server.knowledge_base.utils import (
|
||||||
get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
|
get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
|
||||||
|
|
@ -112,9 +112,10 @@ class KBService(ABC):
|
||||||
def search_docs(self,
|
def search_docs(self,
|
||||||
query: str,
|
query: str,
|
||||||
top_k: int = VECTOR_SEARCH_TOP_K,
|
top_k: int = VECTOR_SEARCH_TOP_K,
|
||||||
|
score_threshold: float = SCORE_THRESHOLD,
|
||||||
):
|
):
|
||||||
embeddings = self._load_embeddings()
|
embeddings = self._load_embeddings()
|
||||||
docs = self.do_search(query, top_k, embeddings)
|
docs = self.do_search(query, top_k, score_threshold, embeddings)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
|
||||||
|
|
@ -81,12 +81,13 @@ class FaissKBService(KBService):
|
||||||
def do_search(self,
|
def do_search(self,
|
||||||
query: str,
|
query: str,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
embeddings: Embeddings,
|
score_threshold: float = SCORE_THRESHOLD,
|
||||||
|
embeddings: Embeddings = None,
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
search_index = load_vector_store(self.kb_name,
|
search_index = load_vector_store(self.kb_name,
|
||||||
embeddings=embeddings,
|
embeddings=embeddings,
|
||||||
tick=_VECTOR_STORE_TICKS.get(self.kb_name))
|
tick=_VECTOR_STORE_TICKS.get(self.kb_name))
|
||||||
docs = search_index.similarity_search(query, k=top_k, score_threshold=SCORE_THRESHOLD)
|
docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
def do_add_doc(self,
|
def do_add_doc(self,
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,8 @@ class MilvusKBService(KBService):
|
||||||
def do_drop_kb(self):
|
def do_drop_kb(self):
|
||||||
self.milvus.col.drop()
|
self.milvus.col.drop()
|
||||||
|
|
||||||
def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]:
|
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
|
||||||
|
# todo: support score threshold
|
||||||
self._load_milvus(embeddings=embeddings)
|
self._load_milvus(embeddings=embeddings)
|
||||||
return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD)
|
return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,8 @@ class PGKBService(KBService):
|
||||||
'''))
|
'''))
|
||||||
connect.commit()
|
connect.commit()
|
||||||
|
|
||||||
def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]:
|
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
|
||||||
|
# todo: support score threshold
|
||||||
self._load_pg_vector(embeddings=embeddings)
|
self._load_pg_vector(embeddings=embeddings)
|
||||||
return self.pg_vector.similarity_search(query, top_k)
|
return self.pg_vector.similarity_search(query, top_k)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,8 @@ import os
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||||
from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
|
from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
|
||||||
|
from server.utils import MakeFastAPIOffline
|
||||||
|
|
||||||
|
|
||||||
host_ip = "0.0.0.0"
|
host_ip = "0.0.0.0"
|
||||||
controller_port = 20001
|
controller_port = 20001
|
||||||
|
|
@ -30,6 +32,8 @@ def create_controller_app(
|
||||||
controller = Controller(dispatch_method)
|
controller = Controller(dispatch_method)
|
||||||
sys.modules["fastchat.serve.controller"].controller = controller
|
sys.modules["fastchat.serve.controller"].controller = controller
|
||||||
|
|
||||||
|
MakeFastAPIOffline(app)
|
||||||
|
app.title = "FastChat Controller"
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -55,7 +59,6 @@ def create_model_worker_app(
|
||||||
import fastchat.constants
|
import fastchat.constants
|
||||||
fastchat.constants.LOGDIR = LOG_PATH
|
fastchat.constants.LOGDIR = LOG_PATH
|
||||||
from fastchat.serve.model_worker import app, GptqConfig, ModelWorker, worker_id
|
from fastchat.serve.model_worker import app, GptqConfig, ModelWorker, worker_id
|
||||||
from fastchat.serve import model_worker
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
@ -117,6 +120,8 @@ def create_model_worker_app(
|
||||||
sys.modules["fastchat.serve.model_worker"].args = args
|
sys.modules["fastchat.serve.model_worker"].args = args
|
||||||
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
|
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
|
||||||
|
|
||||||
|
MakeFastAPIOffline(app)
|
||||||
|
app.title = f"FastChat LLM Server ({LLM_MODEL})"
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -141,6 +146,8 @@ def create_openai_api_app(
|
||||||
app_settings.controller_address = controller_address
|
app_settings.controller_address = controller_address
|
||||||
app_settings.api_keys = api_keys
|
app_settings.api_keys = api_keys
|
||||||
|
|
||||||
|
MakeFastAPIOffline(app)
|
||||||
|
app.title = "FastChat OpeanAI API Server"
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Binary file not shown.
|
After Width: | Height: | Size: 7.1 KiB |
|
|
@ -2,14 +2,10 @@ import pydantic
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import List
|
from typing import List
|
||||||
import torch
|
import torch
|
||||||
from fastapi_offline import FastAPIOffline
|
from fastapi import FastAPI
|
||||||
import fastapi_offline
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
# patch fastapi_offline to use local static assests
|
|
||||||
fastapi_offline.core._STATIC_PATH = Path(__file__).parent / "static"
|
|
||||||
|
|
||||||
|
|
||||||
class BaseResponse(BaseModel):
|
class BaseResponse(BaseModel):
|
||||||
|
|
@ -112,3 +108,81 @@ def iter_over_async(ait, loop):
|
||||||
if done:
|
if done:
|
||||||
break
|
break
|
||||||
yield obj
|
yield obj
|
||||||
|
|
||||||
|
|
||||||
|
def MakeFastAPIOffline(
|
||||||
|
app: FastAPI,
|
||||||
|
static_dir = Path(__file__).parent / "static",
|
||||||
|
static_url = "/static-offline-docs",
|
||||||
|
docs_url: Optional[str] = "/docs",
|
||||||
|
redoc_url: Optional[str] = "/redoc",
|
||||||
|
) -> None:
|
||||||
|
"""patch the FastAPI obj that doesn't rely on CDN for the documentation page"""
|
||||||
|
from fastapi import Request
|
||||||
|
from fastapi.openapi.docs import (
|
||||||
|
get_redoc_html,
|
||||||
|
get_swagger_ui_html,
|
||||||
|
get_swagger_ui_oauth2_redirect_html,
|
||||||
|
)
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from starlette.responses import HTMLResponse
|
||||||
|
|
||||||
|
openapi_url = app.openapi_url
|
||||||
|
swagger_ui_oauth2_redirect_url = app.swagger_ui_oauth2_redirect_url
|
||||||
|
|
||||||
|
def remove_route(url: str) -> None:
|
||||||
|
'''
|
||||||
|
remove original route from app
|
||||||
|
'''
|
||||||
|
index = None
|
||||||
|
for i, r in enumerate(app.routes):
|
||||||
|
if r.path.lower() == url.lower():
|
||||||
|
index = i
|
||||||
|
break
|
||||||
|
if isinstance(index, int):
|
||||||
|
app.routes.pop(i)
|
||||||
|
|
||||||
|
# Set up static file mount
|
||||||
|
app.mount(
|
||||||
|
static_url,
|
||||||
|
StaticFiles(directory=Path(static_dir).as_posix()),
|
||||||
|
name="static-offline-docs",
|
||||||
|
)
|
||||||
|
|
||||||
|
if docs_url is not None:
|
||||||
|
remove_route(docs_url)
|
||||||
|
remove_route(swagger_ui_oauth2_redirect_url)
|
||||||
|
|
||||||
|
# Define the doc and redoc pages, pointing at the right files
|
||||||
|
@app.get(docs_url, include_in_schema=False)
|
||||||
|
async def custom_swagger_ui_html(request: Request) -> HTMLResponse:
|
||||||
|
root = request.scope.get("root_path")
|
||||||
|
favicon = f"{root}{static_url}/favicon.png"
|
||||||
|
return get_swagger_ui_html(
|
||||||
|
openapi_url=f"{root}{openapi_url}",
|
||||||
|
title=app.title + " - Swagger UI",
|
||||||
|
oauth2_redirect_url=swagger_ui_oauth2_redirect_url,
|
||||||
|
swagger_js_url=f"{root}{static_url}/swagger-ui-bundle.js",
|
||||||
|
swagger_css_url=f"{root}{static_url}/swagger-ui.css",
|
||||||
|
swagger_favicon_url=favicon,
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get(swagger_ui_oauth2_redirect_url, include_in_schema=False)
|
||||||
|
async def swagger_ui_redirect() -> HTMLResponse:
|
||||||
|
return get_swagger_ui_oauth2_redirect_html()
|
||||||
|
|
||||||
|
if redoc_url is not None:
|
||||||
|
remove_route(redoc_url)
|
||||||
|
|
||||||
|
@app.get(redoc_url, include_in_schema=False)
|
||||||
|
async def redoc_html(request: Request) -> HTMLResponse:
|
||||||
|
root = request.scope.get("root_path")
|
||||||
|
favicon = f"{root}{static_url}/favicon.png"
|
||||||
|
|
||||||
|
return get_redoc_html(
|
||||||
|
openapi_url=f"{root}{openapi_url}",
|
||||||
|
title=app.title + " - ReDoc",
|
||||||
|
redoc_js_url=f"{root}{static_url}/redoc.standalone.js",
|
||||||
|
with_google_fonts=False,
|
||||||
|
redoc_favicon_url=favicon,
|
||||||
|
)
|
||||||
|
|
|
||||||
6
webui.py
6
webui.py
|
|
@ -13,7 +13,11 @@ import os
|
||||||
api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False)
|
api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
st.set_page_config("Langchain-Chatchat WebUI", initial_sidebar_state="expanded")
|
st.set_page_config(
|
||||||
|
"Langchain-Chatchat WebUI",
|
||||||
|
os.path.join("img", "chatchat_icon_blue_square_v2.png"),
|
||||||
|
initial_sidebar_state="expanded",
|
||||||
|
)
|
||||||
|
|
||||||
if not chat_box.chat_inited:
|
if not chat_box.chat_inited:
|
||||||
st.toast(
|
st.toast(
|
||||||
|
|
|
||||||
|
|
@ -76,7 +76,7 @@ def dialogue_page(api: ApiRequest):
|
||||||
key="selected_kb",
|
key="selected_kb",
|
||||||
)
|
)
|
||||||
kb_top_k = st.number_input("匹配知识条数:", 1, 20, 3)
|
kb_top_k = st.number_input("匹配知识条数:", 1, 20, 3)
|
||||||
# score_threshold = st.slider("知识匹配分数阈值:", 0, 1, 0, disabled=True)
|
score_threshold = st.number_input("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
|
||||||
# chunk_content = st.checkbox("关联上下文", False, disabled=True)
|
# chunk_content = st.checkbox("关联上下文", False, disabled=True)
|
||||||
# chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
|
# chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
|
||||||
elif dialogue_mode == "搜索引擎问答":
|
elif dialogue_mode == "搜索引擎问答":
|
||||||
|
|
@ -111,8 +111,8 @@ def dialogue_page(api: ApiRequest):
|
||||||
Markdown("...", in_expander=True, title="知识库匹配结果"),
|
Markdown("...", in_expander=True, title="知识库匹配结果"),
|
||||||
])
|
])
|
||||||
text = ""
|
text = ""
|
||||||
for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, history):
|
for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history):
|
||||||
if error_msg := check_error_msg(t): # check whether error occured
|
if error_msg := check_error_msg(d): # check whether error occured
|
||||||
st.error(error_msg)
|
st.error(error_msg)
|
||||||
text += d["answer"]
|
text += d["answer"]
|
||||||
chat_box.update_msg(text, 0)
|
chat_box.update_msg(text, 0)
|
||||||
|
|
@ -125,7 +125,7 @@ def dialogue_page(api: ApiRequest):
|
||||||
])
|
])
|
||||||
text = ""
|
text = ""
|
||||||
for d in api.search_engine_chat(prompt, search_engine, se_top_k):
|
for d in api.search_engine_chat(prompt, search_engine, se_top_k):
|
||||||
if error_msg := check_error_msg(t): # check whether error occured
|
if error_msg := check_error_msg(d): # check whether error occured
|
||||||
st.error(error_msg)
|
st.error(error_msg)
|
||||||
text += d["answer"]
|
text += d["answer"]
|
||||||
chat_box.update_msg(text, 0)
|
chat_box.update_msg(text, 0)
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from configs.model_config import (
|
||||||
DEFAULT_VS_TYPE,
|
DEFAULT_VS_TYPE,
|
||||||
KB_ROOT_PATH,
|
KB_ROOT_PATH,
|
||||||
LLM_MODEL,
|
LLM_MODEL,
|
||||||
|
SCORE_THRESHOLD,
|
||||||
VECTOR_SEARCH_TOP_K,
|
VECTOR_SEARCH_TOP_K,
|
||||||
SEARCH_ENGINE_TOP_K,
|
SEARCH_ENGINE_TOP_K,
|
||||||
logger,
|
logger,
|
||||||
|
|
@ -312,6 +313,7 @@ class ApiRequest:
|
||||||
query: str,
|
query: str,
|
||||||
knowledge_base_name: str,
|
knowledge_base_name: str,
|
||||||
top_k: int = VECTOR_SEARCH_TOP_K,
|
top_k: int = VECTOR_SEARCH_TOP_K,
|
||||||
|
score_threshold: float = SCORE_THRESHOLD,
|
||||||
history: List[Dict] = [],
|
history: List[Dict] = [],
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
no_remote_api: bool = None,
|
no_remote_api: bool = None,
|
||||||
|
|
@ -326,6 +328,7 @@ class ApiRequest:
|
||||||
"query": query,
|
"query": query,
|
||||||
"knowledge_base_name": knowledge_base_name,
|
"knowledge_base_name": knowledge_base_name,
|
||||||
"top_k": top_k,
|
"top_k": top_k,
|
||||||
|
"score_threshold": score_threshold,
|
||||||
"history": history,
|
"history": history,
|
||||||
"stream": stream,
|
"stream": stream,
|
||||||
"local_doc_url": no_remote_api,
|
"local_doc_url": no_remote_api,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue