From b002a2879be3d49378455ded718a5714791493c1 Mon Sep 17 00:00:00 2001
From: "liqiankun.1111"
Date: Sun, 22 Oct 2023 15:34:02 +0800
Subject: [PATCH] feat: add completion api

---
 configs/prompt_config.py.example |  4 ++
 server/api.py                    |  6 ++-
 server/chat/__init__.py          |  1 +
 server/chat/completion.py        | 64 +++++++++++++++++++++++++++
 server/utils.py                  | 75 ++++++++++++++++++++++++++++++++
 5 files changed, 149 insertions(+), 1 deletion(-)
 create mode 100644 server/chat/completion.py

diff --git a/configs/prompt_config.py.example b/configs/prompt_config.py.example
index a52b1f7..7949a1f 100644
--- a/configs/prompt_config.py.example
+++ b/configs/prompt_config.py.example
@@ -19,6 +19,10 @@ PROMPT_TEMPLATES = {}
 
+PROMPT_TEMPLATES["completion"] = {
+    "default": "{input}"
+}
+
 PROMPT_TEMPLATES["llm_chat"] = {
     "default": "{{ input }}",
diff --git a/server/api.py b/server/api.py
index 370f7ed..35f9f7c 100644
--- a/server/api.py
+++ b/server/api.py
@@ -11,7 +11,7 @@ import argparse
 import uvicorn
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.responses import RedirectResponse
-from server.chat import (chat, knowledge_base_chat, openai_chat,
+from server.chat import (completion, chat, knowledge_base_chat, openai_chat,
                          search_engine_chat, agent_chat)
 from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
 from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
@@ -52,6 +52,10 @@ def create_app():
             response_model=BaseResponse,
             summary="swagger 文档")(document)
 
+    app.post("/completion",
+             tags=["Completion"],
+             summary="要求llm模型补全(通过LLMChain)")(completion)
+
     # Tag: Chat
     app.post("/chat/fastchat",
              tags=["Chat"],
diff --git a/server/chat/__init__.py b/server/chat/__init__.py
index 62fe430..294aeab 100644
--- a/server/chat/__init__.py
+++ b/server/chat/__init__.py
@@ -1,4 +1,5 @@
 from .chat import chat
+from .completion import completion
 from .knowledge_base_chat import knowledge_base_chat
 from .openai_chat import openai_chat
 from .search_engine_chat import search_engine_chat
diff --git a/server/chat/completion.py b/server/chat/completion.py
new file mode 100644
index 0000000..587c9d4
--- /dev/null
+++ b/server/chat/completion.py
@@ -0,0 +1,64 @@
+from fastapi import Body
+from fastapi.responses import StreamingResponse
+from configs import LLM_MODEL, TEMPERATURE
+from server.utils import wrap_done, get_OpenAI
+from langchain.chains import LLMChain
+from langchain.callbacks import AsyncIteratorCallbackHandler
+from typing import AsyncIterable
+import asyncio
+from langchain.prompts.chat import PromptTemplate
+from server.utils import get_prompt_template
+
+
+async def completion(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
+                     stream: bool = Body(False, description="流式输出"),
+                     echo: bool = Body(False, description="除了输出之外,还回显输入"),
+                     model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
+                     temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+                     max_tokens: int = Body(1024, description="限制LLM生成Token数量,默认None代表模型最大值"),
+                     # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
+                     prompt_name: str = Body("default",
+                                             description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
+                     ):
+
+    # todo 因ApiModelWorker 默认是按chat处理的,会对params["prompt"] 解析为messages,因此ApiModelWorker 使用时需要有相应处理
+    async def completion_iterator(query: str,
+                                  model_name: str = LLM_MODEL,
+                                  prompt_name: str = prompt_name,
+                                  echo: bool = echo,
+                                  ) -> AsyncIterable[str]:
+        callback = AsyncIteratorCallbackHandler()
+        model = get_OpenAI(
+            model_name=model_name,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            callbacks=[callback],
+            echo=echo
+        )
+
+        prompt_template = get_prompt_template("completion", prompt_name)
+        prompt = PromptTemplate.from_template(prompt_template)
+        chain = LLMChain(prompt=prompt, llm=model)
+
+        # Begin a task that runs in the background.
+        task = asyncio.create_task(wrap_done(
+            chain.acall({"input": query}),
+            callback.done),
+        )
+
+        if stream:
+            async for token in callback.aiter():
+                # Use server-sent-events to stream the response
+                yield token
+        else:
+            answer = ""
+            async for token in callback.aiter():
+                answer += token
+            yield answer
+
+        await task
+
+    return StreamingResponse(completion_iterator(query=query,
+                                                 model_name=model_name,
+                                                 prompt_name=prompt_name),
+                             media_type="text/event-stream")
diff --git a/server/utils.py b/server/utils.py
index 2069dc0..8e74a9c 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -10,6 +10,7 @@ from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from langchain.chat_models import ChatOpenAI, AzureChatOpenAI, ChatAnthropic
+from langchain.llms import OpenAI, AzureOpenAI, Anthropic
 import httpx
 from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union
@@ -100,6 +101,80 @@ def get_ChatOpenAI(
     return model
 
 
+def get_OpenAI(
+        model_name: str,
+        temperature: float,
+        max_tokens: int = None,
+        streaming: bool = True,
+        echo: bool = True,
+        callbacks: List[Callable] = [],
+        verbose: bool = True,
+        **kwargs: Any,
+) -> OpenAI:
+    ## 以下模型是Langchain原生支持的模型,这些模型不会走Fschat封装
+    config_models = list_config_llm_models()
+    if model_name in config_models.get("langchain", {}):
+        config = config_models["langchain"][model_name]
+        if model_name == "Azure-OpenAI":
+            model = AzureOpenAI(
+                streaming=streaming,
+                verbose=verbose,
+                callbacks=callbacks,
+                deployment_name=config.get("deployment_name"),
+                model_version=config.get("model_version"),
+                openai_api_type=config.get("openai_api_type"),
+                openai_api_base=config.get("api_base_url"),
+                openai_api_version=config.get("api_version"),
+                openai_api_key=config.get("api_key"),
+                openai_proxy=config.get("openai_proxy"),
+                temperature=temperature,
+                max_tokens=max_tokens,
+                echo=echo,
+            )
+
+        elif model_name == "OpenAI":
+            model = OpenAI(
+                streaming=streaming,
+                verbose=verbose,
+                callbacks=callbacks,
+                model_name=config.get("model_name"),
+                openai_api_base=config.get("api_base_url"),
+                openai_api_key=config.get("api_key"),
+                openai_proxy=config.get("openai_proxy"),
+                temperature=temperature,
+                max_tokens=max_tokens,
+                echo=echo,
+            )
+        elif model_name == "Anthropic":
+            model = Anthropic(
+                streaming=streaming,
+                verbose=verbose,
+                callbacks=callbacks,
+                model_name=config.get("model_name"),
+                anthropic_api_key=config.get("api_key"),
+                echo=echo,
+            )
+        ## TODO 支持其他的Langchain原生支持的模型
+    else:
+        ## 非Langchain原生支持的模型,走Fschat封装
+        config = get_model_worker_config(model_name)
+        model = OpenAI(
+            streaming=streaming,
+            verbose=verbose,
+            callbacks=callbacks,
+            openai_api_key=config.get("api_key", "EMPTY"),
+            openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
+            model_name=model_name,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            openai_proxy=config.get("openai_proxy"),
+            echo=echo,
+            **kwargs
+        )
+
+    return model
+
+
 class BaseResponse(BaseModel):
     code: int = pydantic.Field(200, description="API status code")
     msg: str = pydantic.Field("success", description="API status message")
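
For illustration only (not part of the commit above): a minimal client sketch that exercises the new POST /completion route. The JSON fields mirror the Body(...) parameters declared in server/chat/completion.py; the server address and port, and the use of the requests library, are assumptions rather than anything the patch specifies.

# Hypothetical client for the /completion endpoint added by this patch.
# Assumptions: the API server from server/api.py is listening on
# http://127.0.0.1:7861 and the `requests` package is installed.
import requests

payload = {
    "query": "恼羞成怒",          # example input taken from the handler's Body() docs
    "stream": True,               # ask the server to push tokens as they are generated
    "temperature": 0.7,
    "max_tokens": 100,
    "prompt_name": "default",     # selects PROMPT_TEMPLATES["completion"]["default"]
}

# The endpoint always returns a StreamingResponse (media_type text/event-stream);
# with stream=True the body arrives token by token, with stream=False as one final chunk.
with requests.post("http://127.0.0.1:7861/completion", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)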