feat: add completion api
parent 38a48bc831
commit b002a2879b
@@ -19,6 +19,10 @@
 PROMPT_TEMPLATES = {}
 
+PROMPT_TEMPLATES["completion"] = {
+    "default": "{input}"
+}
+
 PROMPT_TEMPLATES["llm_chat"] = {
     "default": "{{ input }}",
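For reference, the new "completion" template above is what get_prompt_template("completion", prompt_name) returns to the handler, which then renders it with PromptTemplate.from_template. A minimal standalone sketch of that rendering (the input string is illustrative only):

    from langchain.prompts.chat import PromptTemplate

    # "{input}" is the "default" template registered under PROMPT_TEMPLATES["completion"] above.
    template = "{input}"
    prompt = PromptTemplate.from_template(template)
    print(prompt.format(input="The weather today is"))  # -> "The weather today is"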
@@ -11,7 +11,7 @@ import argparse
 import uvicorn
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.responses import RedirectResponse
-from server.chat import (chat, knowledge_base_chat, openai_chat,
+from server.chat import (completion, chat, knowledge_base_chat, openai_chat,
                          search_engine_chat, agent_chat)
 from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
 from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
@@ -52,6 +52,10 @@ def create_app():
             response_model=BaseResponse,
             summary="Swagger docs")(document)
 
+    app.post("/completion",
+             tags=["Completion"],
+             summary="Ask the LLM model for a text completion (via LLMChain)")(completion)
+
     # Tag: Chat
     app.post("/chat/fastchat",
              tags=["Chat"],
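With the route registered, /completion can be exercised like any other FastAPI endpoint. A hedged usage sketch follows; the port (7861) and model name are assumptions and should be adjusted to your deployment:

    import requests

    # All Body(...) parameters of the handler are passed as top-level JSON fields.
    resp = requests.post(
        "http://127.0.0.1:7861/completion",          # assumed host/port; adjust to your server
        json={
            "query": "Once upon a time",
            "stream": True,                          # consume the token stream as it arrives
            "model_name": "chatglm2-6b",             # illustrative; use a model configured on your server
            "temperature": 0.7,
            "max_tokens": 256,
            "prompt_name": "default",
        },
        stream=True,
    )
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)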
@@ -1,4 +1,5 @@
 from .chat import chat
+from .completion import completion
 from .knowledge_base_chat import knowledge_base_chat
 from .openai_chat import openai_chat
 from .search_engine_chat import search_engine_chat
@@ -0,0 +1,64 @@
+from fastapi import Body
+from fastapi.responses import StreamingResponse
+from configs import LLM_MODEL, TEMPERATURE
+from server.utils import wrap_done, get_OpenAI
+from langchain.chains import LLMChain
+from langchain.callbacks import AsyncIteratorCallbackHandler
+from typing import AsyncIterable
+import asyncio
+from langchain.prompts.chat import PromptTemplate
+from server.utils import get_prompt_template
+
+
+async def completion(query: str = Body(..., description="User input", examples=["恼羞成怒"]),
+                     stream: bool = Body(False, description="Stream the output"),
+                     echo: bool = Body(False, description="Echo the input in addition to the output"),
+                     model_name: str = Body(LLM_MODEL, description="Name of the LLM model."),
+                     temperature: float = Body(TEMPERATURE, description="LLM sampling temperature", ge=0.0, le=1.0),
+                     max_tokens: int = Body(1024, description="Maximum number of tokens for the LLM to generate; None means the model's maximum"),
+                     # top_p: float = Body(TOP_P, description="LLM nucleus sampling; do not set together with temperature", gt=0.0, lt=1.0),
+                     prompt_name: str = Body("default",
+                                             description="Name of the prompt template to use (configured in configs/prompt_config.py)"),
+                     ):
+
+    # TODO: ApiModelWorker handles requests as chat by default and parses params["prompt"] into messages, so it needs corresponding handling before being used here.
+    async def completion_iterator(query: str,
+                                  model_name: str = LLM_MODEL,
+                                  prompt_name: str = prompt_name,
+                                  echo: bool = echo,
+                                  ) -> AsyncIterable[str]:
+        callback = AsyncIteratorCallbackHandler()
+        model = get_OpenAI(
+            model_name=model_name,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            callbacks=[callback],
+            echo=echo
+        )
+
+        prompt_template = get_prompt_template("completion", prompt_name)
+        prompt = PromptTemplate.from_template(prompt_template)
+        chain = LLMChain(prompt=prompt, llm=model)
+
+        # Begin a task that runs in the background.
+        task = asyncio.create_task(wrap_done(
+            chain.acall({"input": query}),
+            callback.done),
+        )
+
+        if stream:
+            async for token in callback.aiter():
+                # Use server-sent events to stream the response.
+                yield token
+        else:
+            answer = ""
+            async for token in callback.aiter():
+                answer += token
+            yield answer
+
+        await task
+
+    return StreamingResponse(completion_iterator(query=query,
+                                                 model_name=model_name,
+                                                 prompt_name=prompt_name),
+                             media_type="text/event-stream")
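The handler above leans on wrap_done (imported from server.utils but not shown in this diff) to close out the stream: it awaits the chain call and then sets callback.done so aiter() terminates. Its real implementation is not part of this commit; a hypothetical sketch of the expected behavior:

    import asyncio
    import logging
    from typing import Awaitable

    # Hypothetical reconstruction (not from this diff): await the wrapped coroutine,
    # then set the event so AsyncIteratorCallbackHandler.aiter() stops cleanly.
    async def wrap_done(fn: Awaitable, event: asyncio.Event):
        try:
            await fn
        except Exception:
            logging.exception("wrapped coroutine raised")
        finally:
            event.set()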
@@ -10,6 +10,7 @@ from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from langchain.chat_models import ChatOpenAI, AzureChatOpenAI, ChatAnthropic
+from langchain.llms import OpenAI, AzureOpenAI, Anthropic
 import httpx
 from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union
@@ -100,6 +101,80 @@ def get_ChatOpenAI(
     return model
 
 
+def get_OpenAI(
+        model_name: str,
+        temperature: float,
+        max_tokens: int = None,
+        streaming: bool = True,
+        echo: bool = True,
+        callbacks: List[Callable] = [],
+        verbose: bool = True,
+        **kwargs: Any,
+) -> OpenAI:
+    # The following models are natively supported by Langchain and do not go through the Fschat wrapper.
+    config_models = list_config_llm_models()
+    if model_name in config_models.get("langchain", {}):
+        config = config_models["langchain"][model_name]
+        if model_name == "Azure-OpenAI":
+            model = AzureOpenAI(
+                streaming=streaming,
+                verbose=verbose,
+                callbacks=callbacks,
+                deployment_name=config.get("deployment_name"),
+                model_version=config.get("model_version"),
+                openai_api_type=config.get("openai_api_type"),
+                openai_api_base=config.get("api_base_url"),
+                openai_api_version=config.get("api_version"),
+                openai_api_key=config.get("api_key"),
+                openai_proxy=config.get("openai_proxy"),
+                temperature=temperature,
+                max_tokens=max_tokens,
+                echo=echo,
+            )
+
+        elif model_name == "OpenAI":
+            model = OpenAI(
+                streaming=streaming,
+                verbose=verbose,
+                callbacks=callbacks,
+                model_name=config.get("model_name"),
+                openai_api_base=config.get("api_base_url"),
+                openai_api_key=config.get("api_key"),
+                openai_proxy=config.get("openai_proxy"),
+                temperature=temperature,
+                max_tokens=max_tokens,
+                echo=echo,
+            )
+        elif model_name == "Anthropic":
+            model = Anthropic(
+                streaming=streaming,
+                verbose=verbose,
+                callbacks=callbacks,
+                model_name=config.get("model_name"),
+                anthropic_api_key=config.get("api_key"),
+                echo=echo,
+            )
+        # TODO: support the other models natively supported by Langchain.
+    else:
+        # Models not natively supported by Langchain go through the Fschat wrapper.
+        config = get_model_worker_config(model_name)
+        model = OpenAI(
+            streaming=streaming,
+            verbose=verbose,
+            callbacks=callbacks,
+            openai_api_key=config.get("api_key", "EMPTY"),
+            openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
+            model_name=model_name,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            openai_proxy=config.get("openai_proxy"),
+            echo=echo,
+            **kwargs
+        )
+
+    return model
+
+
 class BaseResponse(BaseModel):
     code: int = pydantic.Field(200, description="API status code")
     msg: str = pydantic.Field("success", description="API status message")
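A hedged example of calling the new helper directly; the model name is illustrative, and which branch runs (a Langchain-native client vs. the Fschat-backed OpenAI client) depends on your model configuration:

    from server.utils import get_OpenAI

    # Illustrative model name: anything not listed under the "langchain" config
    # falls through to the Fschat-wrapped OpenAI client.
    llm = get_OpenAI(
        model_name="chatglm2-6b",
        temperature=0.7,
        max_tokens=256,
        streaming=False,
        echo=False,
    )
    print(llm("The three primary colors are"))  # plain text completion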