feat: add completion api

liqiankun.1111 2023-10-22 15:34:02 +08:00
parent 38a48bc831
commit b002a2879b
5 changed files with 149 additions and 1 deletion

configs/prompt_config.py

@@ -19,6 +19,10 @@
PROMPT_TEMPLATES = {}

PROMPT_TEMPLATES["completion"] = {
    "default": "{input}",
}

PROMPT_TEMPLATES["llm_chat"] = {
    "default": "{{ input }}",

server/api.py

@@ -11,7 +11,7 @@ import argparse
import uvicorn
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from server.chat import (chat, knowledge_base_chat, openai_chat,
from server.chat import (completion, chat, knowledge_base_chat, openai_chat,
search_engine_chat, agent_chat)
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
@@ -52,6 +52,10 @@ def create_app():
             response_model=BaseResponse,
             summary="Swagger docs")(document)

    app.post("/completion",
             tags=["Completion"],
             summary="Ask the LLM for a text completion (via LLMChain)")(completion)

    # Tag: Chat
    app.post("/chat/fastchat",
             tags=["Chat"],

server/chat/__init__.py

@@ -1,4 +1,5 @@
from .chat import chat
from .completion import completion
from .knowledge_base_chat import knowledge_base_chat
from .openai_chat import openai_chat
from .search_engine_chat import search_engine_chat

server/chat/completion.py (new file, 64 lines)

@@ -0,0 +1,64 @@
from fastapi import Body
from fastapi.responses import StreamingResponse
from configs import LLM_MODEL, TEMPERATURE
from server.utils import wrap_done, get_OpenAI
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts.chat import PromptTemplate
from server.utils import get_prompt_template


async def completion(query: str = Body(..., description="User input", examples=["恼羞成怒"]),
                     stream: bool = Body(False, description="Stream the output"),
                     echo: bool = Body(False, description="Echo the input in addition to the output"),
                     model_name: str = Body(LLM_MODEL, description="LLM model name."),
                     temperature: float = Body(TEMPERATURE, description="LLM sampling temperature", ge=0.0, le=1.0),
                     max_tokens: int = Body(1024, description="Cap on the number of tokens the LLM may generate; None means the model's maximum"),
                     # top_p: float = Body(TOP_P, description="LLM nucleus sampling; do not set together with temperature", gt=0.0, lt=1.0),
                     prompt_name: str = Body("default",
                                             description="Name of the prompt template to use (configured in configs/prompt_config.py)"),
                     ):
    # TODO: ApiModelWorker treats requests as chat by default and parses params["prompt"]
    # into messages, so corresponding handling is needed when an ApiModelWorker is used.

    async def completion_iterator(query: str,
                                  model_name: str = LLM_MODEL,
                                  prompt_name: str = prompt_name,
                                  echo: bool = echo,
                                  ) -> AsyncIterable[str]:
        # Callback that exposes generated tokens as an async iterator.
        callback = AsyncIteratorCallbackHandler()
        model = get_OpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback],
            echo=echo,
        )

        prompt_template = get_prompt_template("completion", prompt_name)
        prompt = PromptTemplate.from_template(prompt_template)
        chain = LLMChain(prompt=prompt, llm=model)

        # Begin a task that runs in the background.
        task = asyncio.create_task(wrap_done(
            chain.acall({"input": query}),
            callback.done),
        )

        if stream:
            async for token in callback.aiter():
                # Use server-sent-events to stream the response
                yield token
        else:
            # Collect all tokens and yield the answer once.
            answer = ""
            async for token in callback.aiter():
                answer += token
            yield answer

        await task

    return StreamingResponse(completion_iterator(query=query,
                                                 model_name=model_name,
                                                 prompt_name=prompt_name),
                             media_type="text/event-stream")

server/utils.py

@@ -10,6 +10,7 @@ from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI, ChatAnthropic
from langchain.llms import OpenAI, AzureOpenAI, Anthropic
import httpx
from typing import List, Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union
@@ -100,6 +101,80 @@ def get_ChatOpenAI(
    return model


def get_OpenAI(
        model_name: str,
        temperature: float,
        max_tokens: Optional[int] = None,
        streaming: bool = True,
        echo: bool = True,
        callbacks: List[Callable] = [],
        verbose: bool = True,
        **kwargs: Any,
) -> OpenAI:
    ## The models below are supported natively by Langchain and do not go
    ## through the Fschat wrapper
    config_models = list_config_llm_models()
    if model_name in config_models.get("langchain", {}):
        config = config_models["langchain"][model_name]
        if model_name == "Azure-OpenAI":
            model = AzureOpenAI(
                streaming=streaming,
                verbose=verbose,
                callbacks=callbacks,
                deployment_name=config.get("deployment_name"),
                model_version=config.get("model_version"),
                openai_api_type=config.get("openai_api_type"),
                openai_api_base=config.get("api_base_url"),
                openai_api_version=config.get("api_version"),
                openai_api_key=config.get("api_key"),
                openai_proxy=config.get("openai_proxy"),
                temperature=temperature,
                max_tokens=max_tokens,
                echo=echo,
            )
        elif model_name == "OpenAI":
            model = OpenAI(
                streaming=streaming,
                verbose=verbose,
                callbacks=callbacks,
                model_name=config.get("model_name"),
                openai_api_base=config.get("api_base_url"),
                openai_api_key=config.get("api_key"),
                openai_proxy=config.get("openai_proxy"),
                temperature=temperature,
                max_tokens=max_tokens,
                echo=echo,
            )
        elif model_name == "Anthropic":
            model = Anthropic(
                streaming=streaming,
                verbose=verbose,
                callbacks=callbacks,
                model_name=config.get("model_name"),
                anthropic_api_key=config.get("api_key"),
                echo=echo,
            )
        ## TODO: support the other models Langchain supports natively
    else:
        ## Models Langchain does not support natively go through the Fschat wrapper
        config = get_model_worker_config(model_name)
        model = OpenAI(
            streaming=streaming,
            verbose=verbose,
            callbacks=callbacks,
            openai_api_key=config.get("api_key", "EMPTY"),
            openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            openai_proxy=config.get("openai_proxy"),
            echo=echo,
            **kwargs
        )
    return model
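
A usage sketch for get_OpenAI, assuming a model name that is not Langchain-native and therefore resolves through the Fschat-wrapped OpenAI client above; the prompt text is illustrative:

from configs import LLM_MODEL
from server.utils import get_OpenAI

# streaming=False and no callbacks: the call blocks and returns the full text.
model = get_OpenAI(model_name=LLM_MODEL,
                   temperature=0.7,
                   max_tokens=128,
                   streaming=False,
                   echo=False)
print(model("Finish this proverb: a journey of a thousand miles"))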
class BaseResponse(BaseModel):
code: int = pydantic.Field(200, description="API status code")
msg: str = pydantic.Field("success", description="API status message")