feat: add completion api

Author: liqiankun.1111
Date:   2023-10-22 15:34:02 +08:00
Commit: b002a2879b (parent: 38a48bc831)
5 changed files with 149 additions and 1 deletion

configs/prompt_config.py
@@ -19,6 +19,10 @@
 PROMPT_TEMPLATES = {}
+
+PROMPT_TEMPLATES["completion"] = {
+    "default": "{input}"
+}
 PROMPT_TEMPLATES["llm_chat"] = {
     "default": "{{ input }}",

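For context: the new "completion" template uses a single-brace f-string placeholder, which is the default template format of LangChain's PromptTemplate.from_template (the helper used by the new endpoint below), so the "default" entry simply passes the user's query through to the model unchanged. A minimal standalone sketch, not part of the commit:

from langchain.prompts import PromptTemplate

# "{input}" is an f-string style placeholder, PromptTemplate's default format
prompt = PromptTemplate.from_template("{input}")
print(prompt.format(input="Why is the sky blue?"))  # prints the query unchanged
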
server/api.py
@@ -11,7 +11,7 @@ import argparse
 import uvicorn
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.responses import RedirectResponse
-from server.chat import (chat, knowledge_base_chat, openai_chat,
+from server.chat import (completion, chat, knowledge_base_chat, openai_chat,
                          search_engine_chat, agent_chat)
 from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
 from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
@@ -52,6 +52,10 @@ def create_app():
              response_model=BaseResponse,
              summary="swagger 文档")(document)

+    app.post("/completion",
+             tags=["Completion"],
+             summary="要求llm模型补全(通过LLMChain)")(completion)
+
     # Tag: Chat
     app.post("/chat/fastchat",
              tags=["Chat"],

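A note on the registration style used above: app.post(...) returns a decorator, so calling it immediately with an already-defined coroutine registers that coroutine as the route handler without a decorator line. A small sketch of the same pattern with a hypothetical handler, not part of the commit:

from fastapi import FastAPI

app = FastAPI()

async def ping() -> dict:
    # hypothetical handler, used only to illustrate the registration pattern
    return {"msg": "pong"}

# equivalent to stacking @app.get(...) above the function definition
app.get("/ping", tags=["Demo"], summary="health check")(ping)
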
server/chat/__init__.py
@@ -1,4 +1,5 @@
 from .chat import chat
+from .completion import completion
 from .knowledge_base_chat import knowledge_base_chat
 from .openai_chat import openai_chat
 from .search_engine_chat import search_engine_chat

server/chat/completion.py (new file, 64 lines)
@@ -0,0 +1,64 @@
from fastapi import Body
from fastapi.responses import StreamingResponse
from configs import LLM_MODEL, TEMPERATURE
from server.utils import wrap_done, get_OpenAI
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts.chat import PromptTemplate
from server.utils import get_prompt_template


async def completion(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                     stream: bool = Body(False, description="流式输出"),
                     echo: bool = Body(False, description="除了输出之外,还回显输入"),
                     model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                     temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                     max_tokens: int = Body(1024, description="限制LLM生成Token数量默认None代表模型最大值"),
                     # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
                     prompt_name: str = Body("default",
                                             description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                     ):
    # TODO: ApiModelWorker treats requests as chat by default and parses params["prompt"]
    # into messages, so ApiModelWorker needs corresponding handling before it can be used here.

    async def completion_iterator(query: str,
                                  model_name: str = LLM_MODEL,
                                  prompt_name: str = prompt_name,
                                  echo: bool = echo,
                                  ) -> AsyncIterable[str]:
        callback = AsyncIteratorCallbackHandler()
        model = get_OpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback],
            echo=echo
        )

        prompt_template = get_prompt_template("completion", prompt_name)
        prompt = PromptTemplate.from_template(prompt_template)
        chain = LLMChain(prompt=prompt, llm=model)

        # Begin a task that runs in the background.
        task = asyncio.create_task(wrap_done(
            chain.acall({"input": query}),
            callback.done),
        )

        if stream:
            async for token in callback.aiter():
                # Use server-sent-events to stream the response
                yield token
        else:
            answer = ""
            async for token in callback.aiter():
                answer += token
            yield answer

        await task

    return StreamingResponse(completion_iterator(query=query,
                                                 model_name=model_name,
                                                 prompt_name=prompt_name),
                             media_type="text/event-stream")

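A minimal client-side sketch of calling the new endpoint; the base URL (port 7861 is assumed to be where the API server runs) and the model name are assumptions, not part of the commit:

import requests

BASE_URL = "http://127.0.0.1:7861"  # assumed host/port of the API server

resp = requests.post(
    f"{BASE_URL}/completion",
    json={
        "query": "Why is the sky blue?",
        "stream": True,                 # ask for token-by-token output
        "model_name": "chatglm2-6b",    # hypothetical; use a model configured in the project
        "temperature": 0.7,
        "max_tokens": 256,
        "prompt_name": "default",
    },
    stream=True,  # the endpoint always returns a StreamingResponse
)
resp.raise_for_status()
for chunk in resp.iter_content(chunk_size=None):
    print(chunk.decode("utf-8"), end="", flush=True)

With stream=False the handler accumulates the tokens and yields the full answer as a single chunk, so the same reading loop works for both modes.
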
server/utils.py
@@ -10,6 +10,7 @@ from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from langchain.chat_models import ChatOpenAI, AzureChatOpenAI, ChatAnthropic
+from langchain.llms import OpenAI, AzureOpenAI, Anthropic
 import httpx
 from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union
@@ -100,6 +101,80 @@ def get_ChatOpenAI(
     return model

+
+def get_OpenAI(
+        model_name: str,
+        temperature: float,
+        max_tokens: int = None,
+        streaming: bool = True,
+        echo: bool = True,
+        callbacks: List[Callable] = [],
+        verbose: bool = True,
+        **kwargs: Any,
+) -> OpenAI:
+    ## The models below are supported natively by Langchain; they do not go through the Fschat wrapper
+    config_models = list_config_llm_models()
+    if model_name in config_models.get("langchain", {}):
+        config = config_models["langchain"][model_name]
+        if model_name == "Azure-OpenAI":
+            model = AzureOpenAI(
+                streaming=streaming,
+                verbose=verbose,
+                callbacks=callbacks,
+                deployment_name=config.get("deployment_name"),
+                model_version=config.get("model_version"),
+                openai_api_type=config.get("openai_api_type"),
+                openai_api_base=config.get("api_base_url"),
+                openai_api_version=config.get("api_version"),
+                openai_api_key=config.get("api_key"),
+                openai_proxy=config.get("openai_proxy"),
+                temperature=temperature,
+                max_tokens=max_tokens,
+                echo=echo,
+            )
+        elif model_name == "OpenAI":
+            model = OpenAI(
+                streaming=streaming,
+                verbose=verbose,
+                callbacks=callbacks,
+                model_name=config.get("model_name"),
+                openai_api_base=config.get("api_base_url"),
+                openai_api_key=config.get("api_key"),
+                openai_proxy=config.get("openai_proxy"),
+                temperature=temperature,
+                max_tokens=max_tokens,
+                echo=echo,
+            )
+        elif model_name == "Anthropic":
+            model = Anthropic(
+                streaming=streaming,
+                verbose=verbose,
+                callbacks=callbacks,
+                model_name=config.get("model_name"),
+                anthropic_api_key=config.get("api_key"),
+                echo=echo,
+            )
+        ## TODO: support the other models that Langchain supports natively
+    else:
+        ## Models not supported natively by Langchain go through the Fschat wrapper
+        config = get_model_worker_config(model_name)
+        model = OpenAI(
+            streaming=streaming,
+            verbose=verbose,
+            callbacks=callbacks,
+            openai_api_key=config.get("api_key", "EMPTY"),
+            openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
+            model_name=model_name,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            openai_proxy=config.get("openai_proxy"),
+            echo=echo,
+            **kwargs
+        )
+    return model
+
+
 class BaseResponse(BaseModel):
     code: int = pydantic.Field(200, description="API status code")
     msg: str = pydantic.Field("success", description="API status message")
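
Design note: the existing get_ChatOpenAI factory returns a chat model (the messages-based /chat/completions style API), while the new /completion endpoint needs a plain text-completion LLM, which is presumably why this commit adds get_OpenAI on top of langchain.llms.OpenAI. A small sketch of the distinction; the base URL and key are placeholders for the FastChat (fschat) OpenAI-compatible server used for non-native models:

from langchain.chat_models import ChatOpenAI  # messages-based: POST /v1/chat/completions
from langchain.llms import OpenAI             # prompt-based:   POST /v1/completions

# Both accept an OpenAI-compatible base URL, which is how models served through the
# fschat openai_api server are reached (compare fschat_openai_api_address above).
chat_llm = ChatOpenAI(openai_api_base="http://127.0.0.1:8888/v1", openai_api_key="EMPTY")
text_llm = OpenAI(openai_api_base="http://127.0.0.1:8888/v1", openai_api_key="EMPTY")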