Merge pull request #1828 from qiankunli/feat/add-completion-api
Add an API endpoint for the completion feature. Note: online_api models are not supported yet.
commit be67ea43d8
configs/prompt_config.py
@@ -19,6 +19,10 @@
 PROMPT_TEMPLATES = {}

+PROMPT_TEMPLATES["completion"] = {
+    "default": "{input}"
+}
+
 PROMPT_TEMPLATES["llm_chat"] = {
     "default": "{{ input }}",
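The new "completion" template group is looked up by `get_prompt_template("completion", prompt_name)`. As a hedged illustration, a deployment could register extra templates alongside "default"; the "continue_story" name below is hypothetical and not part of this commit:

```python
# Hypothetical extra completion template added in configs/prompt_config.py;
# only "default" ships with this commit.
PROMPT_TEMPLATES["completion"]["continue_story"] = (
    "Continue the following text in the same style:\n{input}"
)
```

A request can then select it by passing `prompt_name="continue_story"`.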
server/api.py
@@ -11,7 +11,7 @@ import argparse
 import uvicorn
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.responses import RedirectResponse
-from server.chat import (chat, knowledge_base_chat, openai_chat,
+from server.chat import (completion, chat, knowledge_base_chat, openai_chat,
                          search_engine_chat, agent_chat)
 from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
 from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
@@ -52,6 +52,10 @@ def create_app():
             response_model=BaseResponse,
             summary="Swagger docs")(document)

+    app.post("/completion",
+             tags=["Completion"],
+             summary="Ask the LLM model for a completion (via LLMChain)")(completion)
+
     # Tag: Chat
     app.post("/chat/fastchat",
             tags=["Chat"],
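With the route registered, the endpoint accepts a JSON body matching the `Body(...)` parameters defined in server/chat/completion.py. A minimal client sketch, assuming the API server listens on the project's default port 7861 (adjust to your deployment):

```python
# Minimal non-streaming call to the new /completion endpoint.
import requests

resp = requests.post(
    "http://127.0.0.1:7861/completion",
    json={
        "query": "Once upon a time",
        "stream": False,
        "max_tokens": 64,
        "prompt_name": "default",
    },
)
print(resp.text)  # the accumulated completion, yielded once
```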
server/chat/__init__.py
@@ -1,4 +1,5 @@
 from .chat import chat
+from .completion import completion
 from .knowledge_base_chat import knowledge_base_chat
 from .openai_chat import openai_chat
 from .search_engine_chat import search_engine_chat
server/chat/completion.py
@@ -0,0 +1,64 @@
+from fastapi import Body
+from fastapi.responses import StreamingResponse
+from configs import LLM_MODEL, TEMPERATURE
+from server.utils import wrap_done, get_OpenAI
+from langchain.chains import LLMChain
+from langchain.callbacks import AsyncIteratorCallbackHandler
+from typing import AsyncIterable
+import asyncio
+from langchain.prompts.chat import PromptTemplate
+from server.utils import get_prompt_template
+
+
+async def completion(query: str = Body(..., description="User input", examples=["恼羞成怒"]),
+                     stream: bool = Body(False, description="Stream the output"),
+                     echo: bool = Body(False, description="Echo the input in addition to the output"),
+                     model_name: str = Body(LLM_MODEL, description="LLM model name."),
+                     temperature: float = Body(TEMPERATURE, description="LLM sampling temperature", ge=0.0, le=1.0),
+                     max_tokens: int = Body(1024, description="Limit the number of tokens the LLM generates; set to None for the model's maximum"),
+                     # top_p: float = Body(TOP_P, description="LLM nucleus sampling; do not set together with temperature", gt=0.0, lt=1.0),
+                     prompt_name: str = Body("default",
+                                             description="Name of the prompt template to use (configured in configs/prompt_config.py)"),
+                     ):
+
+    # TODO: ApiModelWorker handles requests as chat by default and parses params["prompt"]
+    # into messages, so ApiModelWorker needs matching handling before it can be used here.
+    async def completion_iterator(query: str,
+                                  model_name: str = LLM_MODEL,
+                                  prompt_name: str = prompt_name,
+                                  echo: bool = echo,
+                                  ) -> AsyncIterable[str]:
+        callback = AsyncIteratorCallbackHandler()
+        model = get_OpenAI(
+            model_name=model_name,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            callbacks=[callback],
+            echo=echo,
+        )
+
+        prompt_template = get_prompt_template("completion", prompt_name)
+        prompt = PromptTemplate.from_template(prompt_template)
+        chain = LLMChain(prompt=prompt, llm=model)
+
+        # Begin a task that runs in the background.
+        task = asyncio.create_task(wrap_done(
+            chain.acall({"input": query}),
+            callback.done),
+        )
+
+        if stream:
+            async for token in callback.aiter():
+                # Use server-sent events to stream the response
+                yield token
+        else:
+            answer = ""
+            async for token in callback.aiter():
+                answer += token
+            yield answer
+
+        await task
+
+    return StreamingResponse(completion_iterator(query=query,
+                                                 model_name=model_name,
+                                                 prompt_name=prompt_name),
+                             media_type="text/event-stream")
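When `stream=True` the iterator yields tokens one by one over the StreamingResponse; when `stream=False` it accumulates the tokens and yields the full answer once. A hedged sketch of consuming the streaming path with httpx (already a dependency, imported in server/utils.py), again assuming the default port:

```python
# Consume the streamed completion token by token.
import httpx

with httpx.stream(
    "POST",
    "http://127.0.0.1:7861/completion",
    json={"query": "The quick brown fox", "stream": True},
    timeout=None,
) as resp:
    for chunk in resp.iter_text():
        print(chunk, end="", flush=True)
```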
server/utils.py
@@ -10,6 +10,7 @@ from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from langchain.chat_models import ChatOpenAI, AzureChatOpenAI, ChatAnthropic
+from langchain.llms import OpenAI, AzureOpenAI, Anthropic
 import httpx
 from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union
@@ -100,6 +101,80 @@ def get_ChatOpenAI(
     return model


+def get_OpenAI(
+        model_name: str,
+        temperature: float,
+        max_tokens: int = None,
+        streaming: bool = True,
+        echo: bool = True,
+        callbacks: List[Callable] = [],
+        verbose: bool = True,
+        **kwargs: Any,
+) -> OpenAI:
+    ## The models below are natively supported by Langchain; they do not go through the Fschat wrapper
+    config_models = list_config_llm_models()
+    if model_name in config_models.get("langchain", {}):
+        config = config_models["langchain"][model_name]
+        if model_name == "Azure-OpenAI":
+            model = AzureOpenAI(
+                streaming=streaming,
+                verbose=verbose,
+                callbacks=callbacks,
+                deployment_name=config.get("deployment_name"),
+                model_version=config.get("model_version"),
+                openai_api_type=config.get("openai_api_type"),
+                openai_api_base=config.get("api_base_url"),
+                openai_api_version=config.get("api_version"),
+                openai_api_key=config.get("api_key"),
+                openai_proxy=config.get("openai_proxy"),
+                temperature=temperature,
+                max_tokens=max_tokens,
+                echo=echo,
+            )
+
+        elif model_name == "OpenAI":
+            model = OpenAI(
+                streaming=streaming,
+                verbose=verbose,
+                callbacks=callbacks,
+                model_name=config.get("model_name"),
+                openai_api_base=config.get("api_base_url"),
+                openai_api_key=config.get("api_key"),
+                openai_proxy=config.get("openai_proxy"),
+                temperature=temperature,
+                max_tokens=max_tokens,
+                echo=echo,
+            )
+        elif model_name == "Anthropic":
+            model = Anthropic(
+                streaming=streaming,
+                verbose=verbose,
+                callbacks=callbacks,
+                model_name=config.get("model_name"),
+                anthropic_api_key=config.get("api_key"),
+                echo=echo,
+            )
+        ## TODO: support the other models natively supported by Langchain
+    else:
+        ## Models not natively supported by Langchain go through the Fschat wrapper
+        config = get_model_worker_config(model_name)
+        model = OpenAI(
+            streaming=streaming,
+            verbose=verbose,
+            callbacks=callbacks,
+            openai_api_key=config.get("api_key", "EMPTY"),
+            openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
+            model_name=model_name,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            openai_proxy=config.get("openai_proxy"),
+            echo=echo,
+            **kwargs
+        )
+
+    return model
+
+
 class BaseResponse(BaseModel):
     code: int = pydantic.Field(200, description="API status code")
     msg: str = pydantic.Field("success", description="API status message")
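`get_OpenAI` can also be used outside the completion endpoint. A minimal sketch under stated assumptions: "chatglm2-6b" is a hypothetical locally served model name, so the call falls through to the Fschat-wrapped OpenAI client rather than a Langchain-native branch:

```python
# Direct use of get_OpenAI; not a Langchain-native config entry,
# so it is routed through the Fschat OpenAI-compatible endpoint.
from server.utils import get_OpenAI

llm = get_OpenAI(
    model_name="chatglm2-6b",  # hypothetical locally served model
    temperature=0.7,
    max_tokens=128,
    streaming=False,
    echo=False,
)
print(llm("Write one sentence about autumn:"))
```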