commit some debug log

parent e50a06fd00
commit 68bd4ffef1

@@ -40,7 +40,7 @@ LLM_MODEL_CONFIG:
     prompt_name: default
     callbacks: false
   llm_model:
-    model: ''
+    model: 'qwen2-instruct'
     temperature: 0.9
     max_tokens: 4096
     history_len: 10

@@ -13,6 +13,7 @@ llm_model:
     The AI is talkative and provides lots of specific details from its context.\n
     If the AI does not know the answer to a question, it truthfully says it does not
     know.\n\nCurrent conversation:\n{{history}}\nHuman: {{input}}\nAI:"
+  intention: "你是一个意图识别专家,你主要的任务是根据用户的输入提取出用户的意图,意图主要有两类,第一类:打开页面,第二类:切换页面,请提取出以下这种格式的数据表明是打开还是切换,及具体的模块名{'action':'打开','module':'模块名'} 注意当用户只说一个名词时,默认是切换页面,名词为模块名"

 # RAG 用模板,可用于知识库问答、文件对话、搜索引擎对话
 rag:

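For illustration, the output contract the added intention prompt asks the model to follow — this example is an assumption derived from the prompt text, not part of the commit:

# assumed example — user input: "打开知识库管理"
# expected model output: {'action':'打开','module':'知识库管理'}
# per the prompt, a bare noun such as "知识库管理" defaults to the switch action:
# {'action':'切换','module':'知识库管理'}
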
@@ -14,7 +14,7 @@ search_local_knowledgebase:
   # 搜索引擎工具配置项。推荐自己部署 searx 搜索引擎,国内使用最方便。
   search_internet:
     use: false
-    search_engine_name: duckduckgo
+    search_engine_name: searx
     search_engine_config:
       bing:
         bing_search_url: https://api.bing.microsoft.com/v7.0/search

@@ -16,6 +16,7 @@ from .tools_registry import BaseToolOutput, regist_tool, format_context


 def searx_search(text ,config, top_k: int):
+    print(f"searx_search: text: {text},config:{config},top_k:{top_k}")
     search = SearxSearchWrapper(
         searx_host=config["host"],
         engines=config["engines"],

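For context, the object being configured above is LangChain's SearxSearchWrapper. A minimal standalone sketch, assuming langchain_community is installed and a searx instance runs at the placeholder host:

from langchain_community.utilities import SearxSearchWrapper

# host and engines are placeholders; point searx_host at your own searx instance
search = SearxSearchWrapper(searx_host="http://127.0.0.1:8888", engines=["bing"])
for hit in search.results("LangChain", num_results=3):
    # each hit is a dict with "title", "link" and "snippet" keys
    print(hit["title"], hit["link"])
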
@@ -111,6 +112,7 @@ def search_engine(query: str, top_k:int=0, engine_name: str="", config: dict={})
     if top_k <= 0:
         top_k = config.get("top_k", Settings.kb_settings.SEARCH_ENGINE_TOP_K)
     engine_name = engine_name or config.get("search_engine_name")
+    print(f"search_engine: query: {query},engine_name:{engine_name},top_k:{top_k}")
     search_engine_use = SEARCH_ENGINES[engine_name]
     results = search_engine_use(
         text=query, config=config["search_engine_config"][engine_name], top_k=top_k

@@ -122,4 +124,5 @@ def search_engine(query: str, top_k:int=0, engine_name: str="", config: dict={})
 @regist_tool(title="互联网搜索")
 def search_internet(query: str = Field(description="query for Internet search")):
     """Use this tool to use bing search engine to search the internet and get information."""
+    print(f"search_internet: query: {query}")
     return BaseToolOutput(search_engine(query=query), format=format_context)

@@ -11,13 +11,18 @@ from chatchat.server.knowledge_base.kb_doc_api import search_docs
 from chatchat.server.pydantic_v1 import Field
 from chatchat.server.utils import get_tool_config

+# template = (
+#     "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on "
+#     "this knowledge use this tool. The 'database' should be one of the above [{key}]."
+# )
 template = (
-    "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on "
-    "this knowledge use this tool. The 'database' should be one of the above [{key}]."
+    "使用本地知识库里\n{KB_info}\n去查与大模型技术栈相关的问题时,只有当用户的问题在本地知识库里时才使用这个工具查询"
+    "'database' 应该是上面的 [{key}] 之一."
 )
 KB_info_str = "\n".join([f"{key}: {value}" for key, value in Settings.kb_settings.KB_INFO.items()])
 template_knowledge = template.format(KB_info=KB_info_str, key="samples")

+print(f"template_knowledge 模版:{template_knowledge}")

 def search_knowledgebase(query: str, database: str, config: dict):
     docs = search_docs(

@@ -0,0 +1,20 @@
+import requests
+import json
+url = f"https://api.seniverse.com/v3/weather/now.json?key=SCXuHjCcMZfxWOy-R&location=合肥&language=zh-Hans&unit=c"
+print(f"url:{url}")
+response = requests.post(url,proxies={"http": None, "https": None})
+
+print(f"response.text:{response.text}")
+# 检查响应状态码
+if response.status_code == 200:
+    try:
+        response.encoding = 'utf-8'
+        # 尝试将响应内容解析为 JSON
+        data = response.json()  # 转换为 Python 字典
+        print(f"JSON Response:{data}")
+    except requests.exceptions.JSONDecodeError:
+        print("The response is not a valid JSON.")
+else:
+    print(f"Request failed with status code: {response.status_code}")
+    print(f"Raw Response: {response.text}")
+

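The test script above issues a POST with a hard-coded key, while the weather_check tool changed later in this commit calls the same endpoint with requests.get. A minimal GET-based sketch of the same check, assuming a placeholder key and city:

import requests

# placeholders: substitute a real Seniverse API key and a city name
params = {"key": "YOUR_API_KEY", "location": "合肥", "language": "zh-Hans", "unit": "c"}
resp = requests.get("https://api.seniverse.com/v3/weather/now.json", params=params)
resp.raise_for_status()
now = resp.json()["results"][0]["now"]  # same field layout weather_check reads below
print(now["temperature"], now["text"])
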
@@ -1,8 +1,10 @@
 """
 简单的单参数输入工具实现,用于查询现在天气的情况
 """
+import requests
+import logging
+
 import requests
 import warnings
 from chatchat.server.pydantic_v1 import Field
 from chatchat.server.utils import get_tool_config

@@ -15,17 +17,22 @@ def weather_check(
 ):
     """Use this tool to check the weather at a specific city"""

+    print(f"weather_check,city{city}")
+    # warnings.filterwarnings("ignore", message="Unverified HTTPS request")
+    print(f"weather_check tool内部调用,city{city}")
     tool_config = get_tool_config("weather_check")
     api_key = tool_config.get("api_key")
-    url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={city}&language=zh-Hans&unit=c"
+    url = f"http://api.seniverse.com/v3/weather/now.json?key={api_key}&location={city}&language=zh-Hans&unit=c"
+    logging.info(f"url:{url}")
     response = requests.get(url)
     if response.status_code == 200:
         data = response.json()
+        logging.info(f"response.json():{data}")
         weather = {
             "temperature": data["results"][0]["now"]["temperature"],
             "description": data["results"][0]["now"]["text"],
         }
         return BaseToolOutput(weather)
     else:
         logging.error(f"Failed to retrieve weather: {response.status_code}")
         raise Exception(f"Failed to retrieve weather: {response.status_code}")

@@ -11,6 +11,8 @@ from .tools_registry import BaseToolOutput, regist_tool
 @regist_tool(title="维基百科搜索")
 def wikipedia_search(query: str = Field(description="The search query")):
     """ A wrapper that uses Wikipedia to search."""
+
+    print(f"wikipedia_search tool内部调用,str:{query}")
     api_wrapper = WikipediaAPIWrapper(lang="zh")
     tool = WikipediaQueryRun(api_wrapper=api_wrapper)
     return BaseToolOutput(tool.run(tool_input=query))

@@ -41,6 +41,8 @@ chat_router.post(

 chat_router.post("/kb_chat", summary="知识库对话")(kb_chat)
 chat_router.post("/file_chat", summary="文件对话")(file_chat)
+#本来想增加存LLM指定模版的对话,后来发现框架已经提供
+# chat_router.post("/llm_chat", summary="llm对话")(completion)


 @chat_router.post("/chat/completions", summary="兼容 openai 的统一 chat 接口")

@@ -62,28 +64,36 @@ async def chat_completions(
     """
     # import rich
     # rich.print(body)

     # 当调用本接口且 body 中没有传入 "max_tokens" 参数时, 默认使用配置中定义的值
+    logger.info(f"body.model_config:{body.model_config},body.tools: {body.tools},body.messages:{body.messages}")
     if body.max_tokens in [None, 0]:
         body.max_tokens = Settings.model_settings.MAX_TOKENS

+    import time
+    current_time = time.time()
+    logger.info(f"current time:{time.time()}")
     client = get_OpenAIClient(model_name=body.model, is_async=True)
+    logger.info(f"")
+    end_time = time.time()
+    logger.info(f"get_OpenAIClient takes{end_time - current_time}")
     extra = {**body.model_extra} or {}
     for key in list(extra):
         delattr(body, key)

     # check tools & tool_choice in request body
     if isinstance(body.tool_choice, str):
+        logger.info(f"tool_choice")
+        logger.info(f"isinstance(body.tool_choice, str)")
         if t := get_tool(body.tool_choice):
+            logger.info(f"function name: {t.name}")
             body.tool_choice = {"function": {"name": t.name}, "type": "function"}
     if isinstance(body.tools, list):
+        logger.info(f"tools")
+        logger.info(f"isinstance(body.tools, list)")
+        logger.info(f"body.tools:{body.tools}")
         for i in range(len(body.tools)):
             if isinstance(body.tools[i], str):
                 if t := get_tool(body.tools[i]):
+                    logger.info(f"function name: {t.name}")
+                    logger.info(f"parameters: {t.args}")
                     body.tools[i] = {
                         "type": "function",
                         "function": {

@@ -93,12 +103,17 @@ async def chat_completions(
                         },
                     }

+    # "14c68cfa65ab4d7b8f5dba05f20f4eec"#
     conversation_id = extra.get("conversation_id")

+    logger.info(f"conversation_id:{conversation_id}")

     # chat based on result from one choiced tool
     if body.tool_choice:
+        logger.info(f"if body.tool_choice:{body.tool_choice}")
         tool = get_tool(body.tool_choice["function"]["name"])
         if not body.tools:
+            logger.info(f"if not body.tools")
             body.tools = [
                 {
                     "type": "function",

@@ -110,6 +125,7 @@ async def chat_completions(
                 }
             ]
         if tool_input := extra.get("tool_input"):
+            logger.info(f"if tool_input := extra.get")
             try:
                 message_id = (
                     add_message_to_db(

@@ -125,6 +141,7 @@ async def chat_completions(
                 message_id = None

             tool_result = await tool.ainvoke(tool_input)
+            logger.info(f"tool result: {tool_result}")
             prompt_template = PromptTemplate.from_template(
                 get_prompt_template("rag", "default"), template_format="jinja2"
             )

@@ -161,7 +178,10 @@ async def chat_completions(

     # agent chat with tool calls
     if body.tools:
+        logger.info(f"if body.tools:{body.tools}")
         try:
+            message_str = body.messages[-1]["content"]
+            logger.info(f"message_str:{message_str}")
             message_id = (
                 add_message_to_db(
                     chat_type="agent_chat",

@@ -171,6 +191,7 @@ async def chat_completions(
                     if conversation_id
                     else None
                 )
+            logger.info(f"message_id:{message_id}")
         except Exception as e:
             logger.warning(f"failed to add message to db: {e}")
             message_id = None

@@ -190,9 +211,12 @@ async def chat_completions(
             tool_config=extra.get("tool_config", tool_config),
             max_tokens=body.max_tokens,
         )
+        logger.info(f"chat result: {result}")
         return result
     else:  # LLM chat directly
-        try:  # query is complex object that unable add to db when using qwen-vl-chat
+        logger.info(f"LLM chat directly")
+        try:  # query is complex object that unable add to db when using qwen-vl-chat
+            logger.info(f"conversation_id:{conversation_id}")
             message_id = (
                 add_message_to_db(
                     chat_type="llm_chat",

|
@ -202,14 +226,25 @@ async def chat_completions(
|
|||
if conversation_id
|
||||
else None
|
||||
)
|
||||
logger.info(f"message_id:{message_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"failed to add message to db: {e}")
|
||||
message_id = None
|
||||
|
||||
logger.info(f"**message_id:{message_id}")
|
||||
extra_json = {
|
||||
"message_id": message_id,
|
||||
"status": None,
|
||||
}
|
||||
return await openai_request(
|
||||
client.chat.completions.create, body, extra_json=extra_json
|
||||
)
|
||||
logger.info(f"extra_json:{extra_json}")
|
||||
|
||||
if message_id is not None:
|
||||
logger.info(f"message_id is not None")
|
||||
return await openai_request(
|
||||
client.chat.completions.create, body, extra_json=extra_json
|
||||
)
|
||||
else:
|
||||
logger.info(f"message_id is None")
|
||||
return await openai_request(
|
||||
client.chat.completions.create, body)
|
||||
|
||||
|
|
|
|||
|
|
@@ -72,31 +72,46 @@ async def openai_request(

     async def generator():
         try:
+            # logger.info(f"extra_json is:{extra_json},body:{body}")
+            # logger.info(f"****openai_request")
             for x in header:
+                # logger.info(f"****for x in header")
                 if isinstance(x, str):
                     x = OpenAIChatOutput(content=x, object="chat.completion.chunk")
                 elif isinstance(x, dict):
                     x = OpenAIChatOutput.model_validate(x)
                 else:
                     raise RuntimeError(f"unsupported value: {header}")
-                for k, v in extra_json.items():
-                    setattr(x, k, v)
+                if extra_json is not None:
+                    for k, v in extra_json.items():
+                        setattr(x, k, v)
                 yield x.model_dump_json()

+            # logger.info(f"****async for chunk in await method(**params) before")
+            # logger.info(f"**params:{params}")
             async for chunk in await method(**params):
-                for k, v in extra_json.items():
-                    setattr(chunk, k, v)
+                # logger.info(f"****async for chunk in await method(**params) after")
+                if extra_json is not None:
+                    for k, v in extra_json.items():
+                        setattr(chunk, k, v)
                 yield chunk.model_dump_json()

+            # logger.info(f"****for x in tail")
             for x in tail:
+                # logger.info(f"****for x in tail")
                 if isinstance(x, str):
+                    # logger.info(f"11111for x in tail")
                     x = OpenAIChatOutput(content=x, object="chat.completion.chunk")
                 elif isinstance(x, dict):
+                    # logger.info(f"22222for x in tail")
                     x = OpenAIChatOutput.model_validate(x)
                 else:
+                    # logger.info(f"33333for x in tail")
                     raise RuntimeError(f"unsupported value: {tail}")
-                for k, v in extra_json.items():
-                    setattr(x, k, v)
+                if extra_json is not None:
+                    # logger.info(f"extra_json is not None")
+                    for k, v in extra_json.items():
+                        setattr(x, k, v)
                 yield x.model_dump_json()
         except asyncio.exceptions.CancelledError:
             logger.warning("streaming progress has been interrupted by user.")

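The three edits in this hunk share one pattern: guard the extra_json merge so a None value no longer raises AttributeError on .items(). A condensed sketch of that pattern (the helper name is hypothetical, not in the commit):

# condensed form of the guard added above; only merge extras when provided
def apply_extra(obj, extra_json=None):
    if extra_json is not None:
        for k, v in extra_json.items():
            setattr(obj, k, v)
    return obj
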
@@ -110,11 +125,14 @@ async def openai_request(
         params["max_tokens"] = Settings.model_settings.MAX_TOKENS

     if hasattr(body, "stream") and body.stream:
+        # logger.info(f"*** body.stream")
         return EventSourceResponse(generator())
     else:
+        # logger.info(f"*** not body.stream")
         result = await method(**params)
-        for k, v in extra_json.items():
-            setattr(result, k, v)
+        if extra_json is not None:
+            for k, v in extra_json.items():
+                setattr(result, k, v)
         return result.model_dump()

@@ -148,6 +166,7 @@ async def list_models() -> Dict:
 async def create_chat_completions(
     body: OpenAIChatInput,
 ):
+    logger.info(f"*****/chat/completions")
     async with get_model_client(body.model) as client:
         result = await openai_request(client.chat.completions.create, body)
         return result

@@ -158,6 +177,7 @@ async def create_completions(
     request: Request,
     body: OpenAIChatInput,
 ):
+    logger.info(f"*****/completions")
     async with get_model_client(body.model) as client:
         return await openai_request(client.completions.create, body)

@@ -55,6 +55,7 @@ def create_models_from_config(configs, callbacks, stream, max_tokens):
         prompt_name = params.get("prompt_name", "default")
         prompt_template = get_prompt_template(type=model_type, name=prompt_name)
         prompts[model_type] = prompt_template
+        logger.info(f"model_type:{model_type}, model_name:{model_name}, prompt_name:{prompt_name}")
     return models, prompts

@@ -72,17 +73,20 @@ def create_models_chains(
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_template() for i in history] + [input_msg]
         )
+        logger.info(f"history_len:{history_len},chat_prompt:{chat_prompt}")
     elif conversation_id and history_len > 0:
         memory = ConversationBufferDBMemory(
             conversation_id=conversation_id,
             llm=models["llm_model"],
             message_limit=history_len,
         )
+        logger.info(f"conversation_id:{conversation_id}")
     else:
         input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(
             False
         )
         chat_prompt = ChatPromptTemplate.from_messages([input_msg])
+        logger.info(f"chat_prompt:{chat_prompt}")

     if "action_model" in models and tools:
         llm = models["action_model"]

@@ -122,6 +126,9 @@ async def chat(
 ):
     """Agent 对话"""

+    # logger.info(f"query:{query},metadata:{metadata}, conversation_id:{conversation_id},message_id:{message_id},stream:{stream},chat_model_config:{chat_model_config}"
+    #             f"tool_config:{tool_config}, max_tokens:{max_tokens}")
+
     async def chat_iterator() -> AsyncIterable[OpenAIChatOutput]:
         try:
             callback = AgentExecutorAsyncIteratorCallbackHandler()

@@ -139,11 +146,14 @@ async def chat(

             langfuse_handler = CallbackHandler()
             callbacks.append(langfuse_handler)
+            logger.info(f"langfuse_secret_key:{langfuse_secret_key}, langfuse_public_key:{langfuse_public_key},langfuse_host:{langfuse_host}")

         models, prompts = create_models_from_config(
             callbacks=callbacks, configs=chat_model_config, stream=stream, max_tokens=max_tokens
         )
         all_tools = get_tool().values()
+        logger.info(f"all_tools:{all_tools}")
+        logger.info(f"metadata:{metadata}")
         tools = [tool for tool in all_tools if tool.name in tool_config]
         tools = [t.copy(update={"callbacks": callbacks}) for t in tools]
         full_chain = create_models_chains(

@@ -40,6 +40,7 @@ async def completion(
     if isinstance(max_tokens, int) and max_tokens <= 0:
         max_tokens = None

+    logger.info(f"model_name:{model_name}, prompt_name:{prompt_name}, echo:{echo}")
     model = get_OpenAI(
         model_name=model_name,
         temperature=temperature,

@@ -51,6 +52,7 @@ async def completion(

     prompt_template = get_prompt_template("llm_model", prompt_name)
     prompt = PromptTemplate.from_template(prompt_template, template_format="jinja2")
+    logger.info(f"prompt_template:{prompt}")
     chain = LLMChain(prompt=prompt, llm=model)

     # Begin a task that runs in the background.

@@ -373,6 +373,7 @@ def get_OpenAIClient(
                 f"cannot find configured platform for model: {model_name}"
             )
         platform_name = platform_info.get("platform_name")
+        logger.info(f"platform_name:{platform_name}")
     platform_info = get_config_platforms().get(platform_name)
     assert platform_info, f"cannot find configured platform: {platform_name}"
     params = {

@@ -863,7 +864,8 @@ def update_search_local_knowledgebase_tool():
     from chatchat.server.db.repository.knowledge_base_repository import list_kbs_from_db

     kbs = list_kbs_from_db()
-    template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on this knowledge use this tool. The 'database' should be one of the above [{key}]."
+    #template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on this knowledge use this tool. The 'database' should be one of the above [{key}]."
+    template = "使用本地知识库里\n{KB_info}\n去查与大模型技术栈相关的问题时,只有当用户的问题在本地知识库里时才使用这个工具查询.'database' 应该是上面的 [{key}] 之一"
     KB_info_str = "\n".join([f"{kb.kb_name}: {kb.kb_info}" for kb in kbs])
     KB_name_info_str = "\n".join([f"{kb.kb_name}" for kb in kbs])
     template_knowledge = template.format(KB_info=KB_info_str, key=KB_name_info_str)

@@ -58,6 +58,7 @@ if __name__ == "__main__":
                 sac.MenuItem("多功能对话", icon="chat"),
                 sac.MenuItem("RAG 对话", icon="database"),
                 sac.MenuItem("知识库管理", icon="hdd-stack"),
+                sac.MenuItem("LLM对话", icon="chat"),
             ],
             key="selected_page",
             open_index=0,

@@ -70,4 +71,9 @@ if __name__ == "__main__":
     elif selected_page == "RAG 对话":
         kb_chat(api=api)
     else:
+        #多功能对话
         dialogue_page(api=api, is_lite=is_lite)
+    #本来想增加LLM可选模版的对话,后来发现多功能对话已经包含
+    # else:
+    #     #LLM对话
+    #     llm_dialogue_page(api=api, is_lite=is_lite)

@@ -268,6 +268,41 @@ class ApiRequest:
         response = self.post("/server/get_prompt_template", json=data, **kwargs)
         return self._get_response_value(response, value_func=lambda r: r.text)

+    #LLM对话
+    def chat_completion(
+        self,
+        query: str,
+        conversation_id: str = None,
+        history_len: int = -1,
+        history: List[Dict] = [],
+        stream: bool = True,
+        model: str = Settings.model_settings.DEFAULT_LLM_MODEL,
+        temperature: float = 0.6,
+        max_tokens: int = None,
+        prompt_name: str = "default",
+        **kwargs,
+    ):
+        '''
+        对应api.py/chat/completion接口
+        '''
+        data = {
+            "query": query,
+            "conversation_id": conversation_id,
+            "history_len": history_len,
+            "history": history,
+            "stream": stream,
+            "model_name": model,
+            "temperature": temperature,
+            "max_tokens": max_tokens,
+            "prompt_name": prompt_name,
+        }
+
+        # print(f"received input message:")
+        # pprint(data)
+        logger.info(f"chat_completion:/chat/llm_chat")
+        response = self.post("/chat/llm_chat", json=data, stream=True, **kwargs)
+        return self._httpx_stream2generator(response, as_json=True)
+
     # 对话相关操作
     def chat_chat(
         self,

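A hypothetical call of the new client method; the base_url value and the shape of the streamed chunks are assumptions for illustration:

# assumed usage; ApiRequest is the class patched above
api = ApiRequest(base_url="http://127.0.0.1:7861")
for chunk in api.chat_completion("你好", model="qwen2-instruct", stream=True):
    # _httpx_stream2generator(as_json=True) yields parsed JSON dicts
    print(chunk)

Note that the /chat/llm_chat route is left commented out in the router hunk above, so this call would only succeed if that route were enabled.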