2024-05-08 01:42:10 +00:00
|
|
|
import io
|
|
|
|
import json
|
|
|
|
import urllib.parse
|
|
|
|
import urllib.request
|
|
|
|
from logging import getLogger
|
|
|
|
from os import environ, path
|
2024-05-12 05:08:53 +00:00
|
|
|
from queue import Queue
|
2024-05-08 01:42:10 +00:00
|
|
|
from random import choice, randint
|
2024-05-13 04:33:47 +00:00
|
|
|
from re import sub
|
2024-05-12 05:08:53 +00:00
|
|
|
from threading import Thread
|
2024-05-08 01:42:10 +00:00
|
|
|
from typing import List
|
|
|
|
|
|
|
|
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
|
2024-05-14 01:08:19 +00:00
|
|
|
from fnvhash import fnv1a_32
|
2024-05-12 20:47:18 +00:00
|
|
|
from jinja2 import Environment, FileSystemLoader, select_autoescape
|
2024-05-08 01:42:10 +00:00
|
|
|
from PIL import Image
|
|
|
|
|
2024-05-27 13:10:24 +00:00
|
|
|
from taleweave.context import broadcast
|
2024-05-27 17:03:39 +00:00
|
|
|
from taleweave.models.base import uuid
|
2024-05-27 13:10:24 +00:00
|
|
|
from taleweave.models.config import DEFAULT_CONFIG, RenderConfig
|
|
|
|
from taleweave.models.entity import WorldEntity
|
|
|
|
from taleweave.models.event import (
|
2024-05-12 05:08:53 +00:00
|
|
|
ActionEvent,
|
|
|
|
GameEvent,
|
2024-05-18 21:58:11 +00:00
|
|
|
GenerateEvent,
|
2024-05-12 05:08:53 +00:00
|
|
|
RenderEvent,
|
|
|
|
ReplyEvent,
|
|
|
|
ResultEvent,
|
|
|
|
StatusEvent,
|
|
|
|
)
|
2024-05-27 13:10:24 +00:00
|
|
|
from taleweave.utils.random import resolve_int_range
|
2024-05-19 20:27:56 +00:00
|
|
|
|
|
|
|
from .prompt import prompt_from_entity, prompt_from_event
|
2024-05-12 05:08:53 +00:00
|
|
|
|
2024-05-08 01:42:10 +00:00
|
|
|
# module-level logger for the Comfy render pipeline
logger = getLogger(__name__)

# host[:port] of the ComfyUI server, used for both the HTTP and websocket URLs;
# read at import time, so importing this module raises KeyError if COMFY_API is unset
server_address = environ["COMFY_API"]

# identifies this process to ComfyUI; sent with queued prompts and in the websocket URL
client_id = uuid()

# active render settings; replaced by launch_render()
render_config: RenderConfig = DEFAULT_CONFIG.render

# requests to generate images for game events and world entities; consumed by render_loop()
render_queue: Queue[GameEvent | WorldEntity] = Queue()

# background thread running render_loop(); set by launch_render()
render_thread: Thread | None = None
|
2024-05-08 01:42:10 +00:00
|
|
|
|
|
|
|
|
|
|
|
def generate_cfg():
    """Resolve the configured ``render_config.cfg`` setting to a single CFG value."""
    cfg_setting = render_config.cfg
    return resolve_int_range(cfg_setting)
|
2024-05-08 01:42:10 +00:00
|
|
|
|
|
|
|
|
|
|
|
def generate_steps():
    """Resolve the configured ``render_config.steps`` setting to a single step count."""
    steps_setting = render_config.steps
    return resolve_int_range(steps_setting)
|
2024-05-08 01:42:10 +00:00
|
|
|
|
|
|
|
|
|
|
|
def generate_batches(
    count: int,
    batch_size: int = 3,
) -> List[int]:
    """
    Split ``count`` images into batches of at most ``batch_size`` images each.

    Every batch is full except possibly the last one, e.g. ``7 -> [3, 3, 1]``.
    A ``count`` of zero yields an empty list.
    """
    return [
        min(batch_size, count - start) for start in range(0, count, batch_size)
    ]
|
|
|
|
|
|
|
|
|
|
|
|
def queue_prompt(prompt):
    """
    Submit a workflow to the Comfy ``/prompt`` endpoint.

    Args:
        prompt: the workflow to queue; serialized to JSON together with ``client_id``.

    Returns:
        The decoded JSON response from the server (callers read its ``prompt_id``).
    """
    payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode("utf-8")
    request = urllib.request.Request(
        "http://{}/prompt".format(server_address), data=payload
    )
    # close the HTTP response deterministically, like get_image/get_history do
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read())
|
|
|
|
|
|
|
|
|
|
|
|
def get_image(filename, subfolder, folder_type):
    """Download one output image's raw bytes from the Comfy ``/view`` endpoint."""
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    view_url = f"http://{server_address}/view?{query}"
    with urllib.request.urlopen(view_url) as response:
        return response.read()
|
|
|
|
|
|
|
|
|
|
|
|
def get_history(prompt_id):
    """Fetch the Comfy ``/history`` record for ``prompt_id`` and decode it from JSON."""
    history_url = f"http://{server_address}/history/{prompt_id}"
    with urllib.request.urlopen(history_url) as response:
        body = response.read()
    return json.loads(body)
|
|
|
|
|
|
|
|
|
2024-05-13 04:33:47 +00:00
|
|
|
def get_images(ws_host, prompt, max_retries=3):
    """
    Queue a workflow on the Comfy API, wait for it to finish, and download its images.

    Args:
        ws_host: websocket URL on which Comfy sends execution updates.
        prompt: the (parsed) workflow to queue.
        max_retries: number of websocket receive timeouts tolerated before giving up.

    Returns:
        A dict mapping output node id to the list of raw image bytes for that node.
        Empty if the prompt has no history entry (e.g. after giving up on the wait).
    """
    prompt_id = queue_prompt(prompt)["prompt_id"]
    _wait_for_prompt(ws_host, prompt_id, max_retries)

    history = get_history(prompt_id).get(prompt_id)
    if history is None:
        # giving up on the wait can leave the prompt unfinished; don't crash the caller
        logger.error("no history found for prompt %s", prompt_id)
        return {}

    output_images = {}
    # each output node is visited exactly once, so each image is downloaded once
    for node_id, node_output in history["outputs"].items():
        if "images" in node_output:
            output_images[node_id] = [
                get_image(image["filename"], image["subfolder"], image["type"])
                for image in node_output["images"]
            ]

    return output_images


def _wait_for_prompt(ws_host, prompt_id, max_retries):
    """Block until Comfy reports ``prompt_id`` finished, or ``max_retries`` timeouts occur."""
    retry = 0
    ws = websocket.WebSocket()
    ws.connect(ws_host, timeout=60)
    try:
        while True:
            try:
                out = ws.recv()
            except websocket.WebSocketTimeoutException:
                logger.warning("timeout while waiting for image data")
                retry += 1
                if retry >= max_retries:
                    logger.error("max retries exceeded, giving up")
                    return

                # reconnect, closing the stale socket first so it is not leaked
                ws.close()
                ws = websocket.WebSocket()
                ws.connect(ws_host, timeout=60)
                continue

            if not isinstance(out, str):
                continue  # previews are binary data

            message = json.loads(out)
            if message["type"] == "executing":
                data = message["data"]
                if data["node"] is None and data["prompt_id"] == prompt_id:
                    return  # execution is done
    finally:
        ws.close()
|
|
|
|
|
|
|
|
|
|
|
|
def generate_image_tool(prompt, count, size="landscape"):
    """
    Generate ``count`` images for ``prompt`` in batches and return every saved path.

    Each batch is saved under its own ``output-{batch_index}`` prefix, so the files of
    one batch do not overwrite those of another.
    """
    output_paths = []
    for batch_index, batch_count in enumerate(generate_batches(count)):
        # distinct name so the ``count`` argument is not shadowed by the loop variable
        results = generate_images(
            prompt, batch_count, size, prefix=f"output-{batch_index}"
        )
        output_paths.extend(results)

    return output_paths
|
|
|
|
|
|
|
|
|
|
|
|
def generate_images(
    prompt: str, count: int, size="landscape", prefix="output"
) -> List[str]:
    """
    Render ``count`` images for ``prompt`` through the Comfy API and save them as PNGs.

    Sampling parameters (cfg, steps, seed, checkpoint) are drawn from ``render_config``.

    Args:
        prompt: text prompt; newlines are replaced with ". " before templating.
        count: number of images requested from the workflow template.
        size: key into ``render_config.sizes`` selecting the width and height.
        prefix: file name prefix; images are saved as ``{prefix}-{n}.png`` in
            ``render_config.path``.

    Returns:
        Paths of the saved images, in the order Comfy returned them.

    Raises:
        json.JSONDecodeError: if the rendered workflow template is not valid JSON.
    """
    cfg = generate_cfg()
    dims = render_config.sizes[size]
    steps = generate_steps()
    seed = randint(0, 10000000)
    checkpoint = choice(render_config.checkpoints)

    logger.info(
        "generating %s images at %s by %s with prompt: %s",
        count,
        dims.width,
        dims.height,
        prompt,
    )

    # NOTE(review): the prompt reaches the template unescaped -- select_autoescape(["json"])
    # does not match the ".j2" template name (and HTML escaping would be wrong for JSON).
    # If comfy.json.j2 places it inside a JSON string, quotes or backslashes in the prompt
    # will make json.loads fail; confirm against the template.
    prompt_workflow = _render_workflow(
        cfg=cfg,
        height=dims.height,
        width=dims.width,
        steps=steps,
        seed=seed,
        checkpoint=checkpoint,
        prompt=prompt.replace("\n", ". "),
        negative_prompt="",
        count=count,
        prefix=prefix,
    )

    logger.debug("connecting to Comfy API at %s", server_address)
    images = get_images(
        "ws://{}/ws?clientId={}".format(server_address, client_id), prompt_workflow
    )

    return _save_images(images, prefix)


def _render_workflow(**params) -> dict:
    """Fill the ``comfy.json.j2`` workflow template with ``params`` and parse it as JSON."""
    env = Environment(
        loader=FileSystemLoader(["taleweave/templates"]),
        autoescape=select_autoescape(["json"]),
    )
    template = env.get_template("comfy.json.j2")
    result = template.render(**params)

    # parsing here helps ensure the template emits valid JSON
    logger.debug("template workflow: %s", result)
    return json.loads(result)


def _save_images(images, prefix: str) -> List[str]:
    """Decode each returned image and save it as ``{prefix}-{n}.png`` in ``render_config.path``."""
    from os import makedirs

    # an empty path means "current directory", which always exists
    if render_config.path:
        makedirs(render_config.path, exist_ok=True)

    paths: List[str] = []
    for node_id in images:
        for image_data in images[node_id]:
            image_path = path.join(render_config.path, f"{prefix}-{len(paths)}.png")
            with Image.open(io.BytesIO(image_data)) as image:
                image.save(image_path, format="PNG")

            paths.append(image_path)

    return paths
|
|
|
|
|
|
|
|
|
2024-05-13 04:33:47 +00:00
|
|
|
def sanitize_name(name: str) -> str:
    """
    Turn an arbitrary name into a lowercase, filename-safe slug.

    Every character other than a letter, digit, "-" or "_" becomes "-", runs of "-"
    are collapsed into one, and any leading/trailing "-" and "_" are removed.

    >>> sanitize_name("Hello World!")
    'hello-world'
    """
    slug = "".join(c if c.isalnum() or c in "-_" else "-" for c in name)
    slug = sub(r"-+", "-", slug)
    # strip both separators in one pass; chaining strip("-").strip("_") left a
    # separator behind for names like "_-abc-_"
    return slug.strip("-_").lower()
|
|
|
|
|
|
|
|
|
2024-05-14 01:08:19 +00:00
|
|
|
def fast_hash(text: str) -> str:
    """Hash the UTF-8 bytes of ``text`` with ``fnv1a_32`` and return it as a ``0x`` hex string."""
    digest = fnv1a_32(text.encode("utf-8"))
    return f"{digest:#x}"
|
|
|
|
|
|
|
|
|
2024-05-12 20:47:18 +00:00
|
|
|
def get_image_prefix(event: GameEvent | WorldEntity) -> str:
    """
    Build a filename-safe prefix for the images rendered for an event or entity.

    Images are saved and looked up by this prefix, so distinct events must not share one.
    Event types without a dedicated rule fall back to their ``type`` and ``id`` when both
    are present; only objects lacking them get the shared "unknown" prefix.
    """
    if isinstance(event, ActionEvent):
        return sanitize_name(f"event-action-{event.character.name}-{event.action}")

    if isinstance(event, ReplyEvent):
        return sanitize_name(
            f"event-reply-{event.speaker.name}-{fast_hash(event.text)}"
        )

    if isinstance(event, ResultEvent):
        return sanitize_name(
            f"event-result-{event.character.name}-{fast_hash(event.result)}"
        )

    if isinstance(event, StatusEvent):
        return sanitize_name(f"event-status-{fast_hash(event.text)}")

    if isinstance(event, WorldEntity):
        return sanitize_name(f"entity-{event.__class__.__name__.lower()}-{event.name}")

    # previously every other event shared "unknown", so the first rendered image set was
    # reused for all unrelated events of unhandled types
    event_type = getattr(event, "type", None)
    event_id = getattr(event, "id", None)
    if event_type and event_id:
        return sanitize_name(f"event-{event_type}-{event_id}")

    return "unknown"
|
2024-05-12 05:08:53 +00:00
|
|
|
|
|
|
|
|
|
|
|
def render_loop():
    """
    Consume render requests from ``render_queue`` forever.

    Started on a daemon thread by launch_render(). Each queued event or entity is either
    served from previously saved images or rendered, and the result is broadcast as a
    RenderEvent. A failure while handling one item is logged and does not stop the loop.
    """
    while True:
        event = render_queue.get()
        try:
            _render_queued(event)
        except Exception:
            # keep the render thread alive: an uncaught error here would silently stop
            # every later render for the rest of the process
            logger.exception("error rendering event %s", event)
        finally:
            render_queue.task_done()


def _render_title(event: GameEvent | WorldEntity) -> str:
    """Pick the display title for a render: the entity's name, or the event's type."""
    if isinstance(event, WorldEntity):
        return event.name  # TODO: generate a real title

    return event.type  # TODO: generate a real title


def _find_existing_images(prefix: str) -> List[str]:
    """
    Collect previously saved images for ``prefix``.

    Checks ``{prefix}-0.png``, ``{prefix}-1.png``, ... in ``render_config.path`` and stops
    at the first index that does not exist.
    """
    existing_images: List[str] = []
    image_index = 0
    image_path = path.join(render_config.path, f"{prefix}-{image_index}.png")
    while path.exists(image_path):
        existing_images.append(image_path)
        image_index += 1
        image_path = path.join(render_config.path, f"{prefix}-{image_index}.png")

    return existing_images


def _render_queued(event: GameEvent | WorldEntity) -> None:
    """Render (or reuse images for) a single queued item and broadcast the result."""
    prefix = get_image_prefix(event)
    title = _render_title(event)

    # check if images already exist
    existing_images = _find_existing_images(prefix)
    if existing_images:
        logger.info(
            "using existing images for event %s: %s", event, existing_images
        )
        broadcast(
            RenderEvent(
                paths=existing_images,
                prompt="reusing existing images",
                source=event,
                title=title,
            )
        )
        return

    # generate the prompt
    if isinstance(event, WorldEntity):
        logger.info("rendering entity %s", event.name)
        prompt = prompt_from_entity(event)
    else:
        logger.info("rendering event %s", event.id)
        prompt = prompt_from_event(event)

    # render or not
    if not prompt:
        logger.warning("no prompt for event %s", event)
        return

    logger.debug("rendering prompt for event %s: %s", event, prompt)
    image_paths = generate_images(prompt, render_config.count, prefix=prefix)
    broadcast(
        RenderEvent(paths=image_paths, prompt=prompt, source=event, title=title)
    )
|
|
|
|
|
|
|
|
|
2024-05-12 20:47:18 +00:00
|
|
|
def render_entity(entity: WorldEntity):
    """Enqueue a world entity for the background render loop."""
    render_queue.put(entity, block=True)
|
2024-05-12 05:08:53 +00:00
|
|
|
|
|
|
|
|
2024-05-12 20:47:18 +00:00
|
|
|
def render_event(event: GameEvent):
    """Enqueue a game event for the background render loop."""
    render_queue.put(event, block=True)
|
2024-05-12 05:08:53 +00:00
|
|
|
|
|
|
|
|
2024-05-18 21:58:11 +00:00
|
|
|
def render_generated(event: GameEvent):
    """Queue a render for the entity carried by a GenerateEvent; ignore any other event."""
    if not isinstance(event, GenerateEvent):
        return

    entity = event.entity
    if not entity:
        return

    logger.info("rendering generated entity: %s", entity.name)
    render_entity(entity)
|
|
|
|
|
|
|
|
|
2024-05-12 20:47:18 +00:00
|
|
|
def launch_render(config: RenderConfig):
    """
    Install ``config`` as the active render config and start the render thread.

    Returns:
        A one-element list holding the started daemon thread running render_loop().
    """
    global render_config, render_thread

    # update the config
    logger.info("updating render config: %s", config)
    render_config = config

    # start the render thread
    logger.info("launching render thread")
    worker = Thread(target=render_loop, daemon=True)
    render_thread = worker
    worker.start()

    return [worker]
|
|
|
|
|
|
|
|
|
2024-05-08 01:42:10 +00:00
|
|
|
if __name__ == "__main__":
    from logging import INFO, basicConfig

    # the summary below is logged at INFO, which is not shown unless logging is configured
    basicConfig(level=INFO)

    paths = generate_images(
        "A painting of a beautiful sunset over a calm lake", 3, "landscape"
    )
    logger.info("generated %d images: %s", len(paths), paths)
|