1
0
Fork 0
taleweave-ai/adventure/render_comfy.py

303 lines
8.7 KiB
Python

# Renders game events to images with ComfyUI. Workflows are queued over HTTP, the
# websocket API reports when a queued prompt has finished executing, and the
# resulting images are then looked up via the /history endpoint and downloaded from /view.
import io
import json
import urllib.parse
import urllib.request
import uuid
from logging import getLogger
from os import environ, path
from queue import Queue
from random import choice, randint
from threading import Thread
from typing import List
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
from PIL import Image
from adventure.context import broadcast
from adventure.models.event import (
ActionEvent,
GameEvent,
RenderEvent,
ReplyEvent,
ResultEvent,
StatusEvent,
)
logger = getLogger(__name__)
# Address of the ComfyUI server (presumably host or host:port); it is interpolated
# into the http:// and ws:// URLs below. Read once at import time, so importing this
# module raises KeyError when COMFY_API is not set.
server_address = environ["COMFY_API"]
# Random per-process id sent with every queued prompt and on the websocket URL.
# NOTE(review): presumably lets ComfyUI route execution messages for our prompts to
# our websocket — confirm against the ComfyUI API.
client_id = str(uuid.uuid4())
def generate_cfg():
    """
    Pick the CFG scale for the next KSampler run.

    Returns:
        A random integer between 5 and 8, inclusive.
    """
    lowest, highest = 5, 8
    return randint(lowest, highest)
def generate_steps():
    """Return the KSampler step count; currently a fixed value rather than random."""
    fixed_steps = 30
    return fixed_steps
def generate_batches(
    count: int,
    batch_size: int = 3,
) -> List[int]:
    """
    Split a request for ``count`` images into batches of at most ``batch_size``.

    Every batch holds ``batch_size`` images except possibly the last, which holds
    the remainder. A ``count`` of zero or less yields an empty list.

    Args:
        count: total number of images to generate.
        batch_size: maximum number of images per batch; must be at least 1.

    Returns:
        A list of batch sizes summing to ``max(count, 0)``.

    Raises:
        ValueError: if ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        # range() already rejects a zero step, but a negative step would silently
        # produce no batches at all
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")
    return [min(count - start, batch_size) for start in range(0, count, batch_size)]
def queue_prompt(prompt):
    """
    Submit a workflow graph to the ComfyUI ``/prompt`` endpoint.

    Args:
        prompt: the workflow graph; it must be JSON-serializable.

    Returns:
        The decoded JSON response. Callers read its ``"prompt_id"`` key.
    """
    payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode("utf-8")
    req = urllib.request.Request("http://{}/prompt".format(server_address), data=payload)
    # close the HTTP response (and its socket) as soon as the body has been read
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())
def get_image(filename, subfolder, folder_type):
    """
    Download one generated image from ComfyUI's ``/view`` endpoint.

    Args:
        filename: image file name as reported in the prompt history.
        subfolder: subfolder as reported in the prompt history.
        folder_type: the history entry's ``type`` field for the image.

    Returns:
        The raw response body as bytes.
    """
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    view_url = f"http://{server_address}/view?{query}"
    with urllib.request.urlopen(view_url) as resp:
        body = resp.read()
    return body
def get_history(prompt_id):
    """
    Fetch the execution history record for a queued prompt.

    Returns:
        The decoded JSON body of ``/history/<prompt_id>``. Callers index it by
        ``prompt_id`` to reach that prompt's entry.
    """
    history_url = f"http://{server_address}/history/{prompt_id}"
    with urllib.request.urlopen(history_url) as resp:
        raw = resp.read()
    return json.loads(raw)
def _wait_for_prompt(ws, prompt_id):
    """Block on ``ws`` until ComfyUI reports that ``prompt_id`` finished executing."""
    while True:
        frame = ws.recv()
        if not isinstance(frame, str):
            continue  # previews are binary data
        message = json.loads(frame)
        if message["type"] != "executing":
            continue
        data = message["data"]
        # an "executing" message with no node for our prompt means execution is done
        if data["node"] is None and data["prompt_id"] == prompt_id:
            return


def get_images(ws, prompt):
    """
    Queue a workflow, wait for it to finish, and download every image it produced.

    Args:
        ws: a websocket already connected to the ComfyUI ``/ws`` endpoint.
        prompt: the workflow graph to execute.

    Returns:
        A dict mapping each output node id to the list of raw image bytes that node
        produced. Output nodes without an ``"images"`` entry are omitted.
    """
    prompt_id = queue_prompt(prompt)["prompt_id"]
    _wait_for_prompt(ws, prompt_id)

    history = get_history(prompt_id)[prompt_id]
    output_images = {}
    # a single pass over the output nodes, so each image is downloaded exactly once
    for node_id, node_output in history["outputs"].items():
        if "images" not in node_output:
            continue
        output_images[node_id] = [
            get_image(image["filename"], image["subfolder"], image["type"])
            for image in node_output["images"]
        ]
    return output_images
def generate_image_tool(prompt, count, size="landscape"):
    """
    Generate ``count`` images for ``prompt``, split into batches.

    The work is split with ``generate_batches``. Each batch is rendered by its own
    ``generate_images`` call, using the filename prefix ``output-<batch index>``.

    Args:
        prompt: text prompt for the image model.
        count: total number of images to generate.
        size: a key of ``sizes``; unknown values fall back to ``generate_images``'s
            default dimensions.

    Returns:
        Paths of all saved images, in batch order.
    """
    output_paths = []
    # use a distinct name for the per-batch count so the ``count`` argument is not shadowed
    for batch_index, batch_count in enumerate(generate_batches(count)):
        output_paths.extend(
            generate_images(prompt, batch_count, size, prefix=f"output-{batch_index}")
        )
    return output_paths
# Named image-size presets for generate_images, each a (width, height) pair.
sizes = dict(
    landscape=(1024, 768),
    portrait=(768, 1024),
    square=(768, 768),
)
def _build_workflow(
    prompt: str,
    count: int,
    width: int,
    height: int,
    cfg: int,
    steps: int,
    seed: int,
    checkpoint: str,
    prefix: str,
) -> dict:
    """Build the ComfyUI API-format workflow graph for one batch of images."""
    # Nodes are keyed by string id. List values such as ["4", 0] reference another
    # node's output by node id and output slot (node "4" supplies model=0, clip=1, vae=2).
    return {
        "3": {
            "class_type": "KSampler",
            "inputs": {
                "cfg": cfg,
                "denoise": 1,
                "latent_image": ["5", 0],
                "model": ["4", 0],
                "negative": ["7", 0],
                "positive": ["6", 0],
                "sampler_name": "euler_ancestral",
                "scheduler": "normal",
                "seed": seed,
                "steps": steps,
            },
        },
        "4": {
            "class_type": "CheckpointLoaderSimple",
            "inputs": {"ckpt_name": checkpoint},
        },
        "5": {
            "class_type": "EmptyLatentImage",
            "inputs": {"batch_size": count, "height": height, "width": width},
        },
        "6": {
            "class_type": "smZ CLIPTextEncode",
            "inputs": {
                "text": prompt,
                "parser": "compel",
                "mean_normalization": True,
                "multi_conditioning": True,
                "use_old_emphasis_implementation": False,
                "with_SDXL": False,
                "ascore": 6,
                "width": width,
                "height": height,
                "crop_w": 0,
                "crop_h": 0,
                "target_width": width,
                "target_height": height,
                "text_g": "",
                "text_l": "",
                "smZ_steps": 1,
                "clip": ["4", 1],
            },
        },
        # empty negative prompt
        "7": {"class_type": "CLIPTextEncode", "inputs": {"text": "", "clip": ["4", 1]}},
        "8": {
            "class_type": "VAEDecode",
            "inputs": {"samples": ["3", 0], "vae": ["4", 2]},
        },
        "9": {
            "class_type": "SaveImage",
            "inputs": {"filename_prefix": prefix, "images": ["8", 0]},
        },
    }


def _save_images(images, prefix: str, output_dir: str) -> List[str]:
    """Decode downloaded images and write them to ``<output_dir>/<prefix>-<n>.png``."""
    # n counts across all output nodes, in the order get_images returned them
    decoded = [
        Image.open(io.BytesIO(image_data))
        for node_images in images.values()
        for image_data in node_images
    ]
    paths: List[str] = []
    for j, image in enumerate(decoded):
        # NOTE(review): prefix is joined into the path unsanitized, and callers such as
        # render_loop derive it from event names — confirm those cannot contain "/" or "..".
        image_path = path.join(output_dir, f"{prefix}-{j}.png")
        # NOTE(review): this re-encodes through PIL instead of writing the downloaded
        # bytes as-is; confirm whether the decode/encode round-trip is intentional.
        image.save(image_path, format="PNG")
        paths.append(image_path)
    return paths


def generate_images(
    prompt: str,
    count: int,
    size="landscape",
    prefix="output",
    output_dir: str = "/home/ssube/adventure-images",
) -> List[str]:
    """
    Render ``count`` images for ``prompt`` on the ComfyUI server and save them locally.

    Sampling settings are chosen on every call: a random CFG scale
    (``generate_cfg``), a step count from ``generate_steps``, and a random seed.
    All ``count`` images are produced as a single ComfyUI batch.

    Args:
        prompt: positive text prompt.
        count: number of images in the batch.
        size: a key of ``sizes``; unknown values fall back to 512x512.
        prefix: filename prefix used for ComfyUI's SaveImage node and for the
            local files.
        output_dir: directory the PNG files are written to. It must already exist.
            TODO: default this from an environment variable.

    Returns:
        Paths of the saved files, named ``<output_dir>/<prefix>-<n>.png``.
    """
    cfg = generate_cfg()
    width, height = sizes.get(size, (512, 512))
    steps = generate_steps()
    seed = randint(0, 10000000)
    checkpoint = choice(["diffusion-sdxl-dynavision-0-5-5-7.safetensors"])
    logger.info(
        "generating %s images at %s by %s with prompt: %s", count, width, height, prompt
    )

    workflow = _build_workflow(
        prompt=prompt,
        count=count,
        width=width,
        height=height,
        cfg=cfg,
        steps=steps,
        seed=seed,
        checkpoint=checkpoint,
        prefix=prefix,
    )

    logger.debug("Connecting to Comfy API at %s", server_address)
    ws = websocket.WebSocket()
    ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
    try:
        images = get_images(ws, workflow)
    finally:
        # release the connection even if waiting or downloading fails
        ws.close()

    return _save_images(images, prefix, output_dir)
def prompt_from_event(event: GameEvent) -> str | None:
if isinstance(event, ActionEvent):
if event.item:
return f"{event.actor.name} uses the {event.item.name}. {event.item.description}. {event.actor.description}. {event.room.description}."
return f"{event.actor.name} {event.action}. {event.actor.description}. {event.room.description}."
if isinstance(event, ReplyEvent):
return event.text
if isinstance(event, ResultEvent):
return f"{event.result}. {event.actor.description}. {event.room.description}."
if isinstance(event, StatusEvent):
if event.room:
if event.actor:
return f"{event.text}. {event.actor.description}. {event.room.description}."
return f"{event.text}. {event.room.description}."
return event.text
return None
def prefix_from_event(event: GameEvent) -> str:
if isinstance(event, ActionEvent):
return (
f"{event.actor.name}-{event.action}-{event.item.name if event.item else ''}"
)
if isinstance(event, ReplyEvent):
return f"{event.actor.name}-reply"
if isinstance(event, ResultEvent):
return f"{event.actor.name}-result"
if isinstance(event, StatusEvent):
return "status"
return "unknown"
# Requests to generate images for game events: filled by render_event and drained
# by the render_loop worker thread. Queue() has no maxsize, so put() never blocks.
render_queue: Queue[GameEvent] = Queue()
def render_loop():
    """
    Worker loop: render each queued game event and broadcast the resulting images.

    Runs forever, taking one event at a time from ``render_queue``. Events that
    produce no prompt are logged and skipped. If rendering or broadcasting one
    event raises, the error is logged and the loop continues with the next event.
    """
    while True:
        event = render_queue.get()
        try:
            prompt = prompt_from_event(event)
            if prompt:
                logger.info("rendering prompt for event %s: %s", event, prompt)
                prefix = prefix_from_event(event)
                image_paths = generate_images(prompt, 2, prefix=prefix)
                broadcast(RenderEvent(paths=image_paths, source=event))
            else:
                logger.warning("no prompt for event %s", event)
        except Exception:
            # This loop is the thread's top-level boundary. An uncaught error would end
            # the thread and leave every later event in the queue unrendered.
            logger.exception("failed to render event %s", event)
        finally:
            # mark the item processed so render_queue.join() can complete
            render_queue.task_done()
def render_event(event: GameEvent):
    """Schedule ``event`` to be rendered by the background worker thread."""
    render_queue.put(item=event)
# Most recently started render worker thread, set by launch_render().
render_thread = None


def launch_render():
    """
    Start a daemon thread that runs ``render_loop``.

    Each call starts a new thread and stores it in the module-level
    ``render_thread``.

    Returns:
        A one-element list containing the started thread.
    """
    global render_thread
    worker = Thread(target=render_loop, daemon=True)
    render_thread = worker
    worker.start()
    return [worker]
if __name__ == "__main__":
    # No logging configuration is visible in this module. The root logger defaults to
    # WARNING, so without this the INFO summary below would never be shown.
    from logging import INFO, basicConfig

    basicConfig(level=INFO)
    paths = generate_images(
        "A painting of a beautiful sunset over a calm lake", 3, "landscape"
    )
    logger.info("Generated %d images: %s", len(paths), paths)