Langchain-Chatchat/libs/chatchat-server/chatchat/server/agent/tools_factory/text2image.py

75 lines
2.3 KiB
Python
Raw Normal View History

2024-12-20 16:04:03 +08:00
import base64
from datetime import datetime
import os
import uuid
from typing import List, Literal
import openai
from PIL import Image
from chatchat.settings import Settings
from chatchat.server.pydantic_v1 import Field
from chatchat.server.utils import MsgType, get_tool_config, get_model_info
from .tools_registry import BaseToolOutput, regist_tool
2025-02-14 07:40:25 +08:00
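# NOTE: tool registration is currently disabled (the decorator below is commented out).
# The title "文生图" means "text-to-image".
# @regist_tool(title="文生图", return_direct=True)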
2024-12-20 16:04:03 +08:00
def text2images(
    prompt: str,
    n: int = Field(1, description="需生成图片的数量"),
    width: Literal[256, 512, 1024] = Field(512, description="生成图片的宽度"),
    height: Literal[256, 512, 1024] = Field(512, description="生成图片的高度"),
) -> BaseToolOutput:
    """根据用户的描述生成图片"""
    # Generate images from the user's text description.
    #
    # NOTE(review): the docstring and the ``Field`` descriptions above are kept
    # verbatim (in Chinese) because tool frameworks commonly expose them as the
    # tool / argument descriptions at runtime -- confirm before translating.
    #
    # Args:
    #     prompt: Text description of the image to generate.
    #     n: Number of images to generate.
    #     width: Width of each generated image (256, 512 or 1024).
    #     height: Height of each generated image (256, 512 or 1024).
    #
    # Returns:
    #     A ``BaseToolOutput`` (format "json") wrapping a dict with
    #     ``message_type`` set to ``MsgType.IMAGE`` and ``images`` set to the
    #     list of saved file names, each relative to
    #     ``Settings.basic_settings.MEDIA_PATH`` (``image/<YYYY-MM-DD>/<hex>.png``).
    #
    # Raises:
    #     AssertionError: if no model info is found for the configured model.
    tool_config = get_tool_config("text2images")
    model_config = get_model_info(tool_config["model"])
    # Raise explicitly instead of using ``assert`` so the check still runs under
    # ``python -O``; exception type and message match the previous behaviour.
    if not model_config:
        raise AssertionError("请正确配置文生图模型")

    client = openai.Client(
        base_url=model_config["api_base_url"],
        api_key=model_config["api_key"],
        timeout=600,  # generous timeout: image generation can be slow
    )
    resp = client.images.generate(
        prompt=prompt,
        n=n,
        # NOTE(review): the "*" separator is kept as-is; OpenAI's own API
        # documents "WIDTHxHEIGHT", so presumably the configured backend
        # expects "*" -- confirm against the serving framework.
        size=f"{width}*{height}",
        response_format="b64_json",
        model=model_config["model_name"],
    )

    # ``data`` may be missing on an empty response; treat that as "no images".
    generated = resp.data or []

    # The date folder is the same for every image in this call, so compute it
    # (and create the directory) once instead of once per image.
    media_root = Settings.basic_settings.MEDIA_PATH
    today = datetime.now().strftime("%Y-%m-%d")
    if generated:
        os.makedirs(os.path.join(media_root, "image", today), exist_ok=True)

    images = []
    for item in generated:
        # Random hex name avoids collisions between images and between calls.
        filename = f"image/{today}/{uuid.uuid4().hex}.png"
        with open(os.path.join(media_root, filename), "wb") as fp:
            fp.write(base64.b64decode(item.b64_json))
        images.append(filename)

    return BaseToolOutput(
        {"message_type": MsgType.IMAGE, "images": images}, format="json"
    )
if __name__ == "__main__":
    # Manual smoke test: build the tool arguments, run the tool once and show
    # the first result with matplotlib.
    #
    # NOTE(review): `args_schema` and `invoke` are not attributes of a plain
    # function; this demo presumably relies on the `@regist_tool` decorator,
    # which is commented out above `text2images` -- confirm before running.
    # NOTE(review): `text2images` returns a BaseToolOutput built from saved
    # image file names, whereas the code below indexes the result and
    # base64-decodes it -- this looks stale; verify against the current output.
    import sys
    from io import BytesIO
    from pathlib import Path

    import matplotlib.pyplot as plt

    # Four directory levels above this file are appended to the import path.
    package_root = Path(__file__).parent.parent.parent.parent
    sys.path.append(str(package_root))

    # The second assignment wins; the English prompt is kept as an alternative.
    prompt = "draw a house with trees and river"
    prompt = "画一个带树、草、河流的山中小屋"

    # Let the tool's argument schema fill in the defaults for n / width / height.
    tool_args = text2images.args_schema.parse_obj({"prompt": prompt}).dict()
    print(tool_args)

    first_result = text2images.invoke(tool_args)[0]
    decoded = BytesIO(base64.b64decode(first_result))
    picture = Image.open(decoded)
    plt.imshow(picture)
    plt.show()