# taleweave-ai/adventure/render_comfy.py
# This is an example that uses the websockets API to know when a prompt execution is done.
# Once execution finishes, it downloads the images using the /history endpoint.
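#
# Flow: POST the workflow to /prompt, watch the websocket for an "executing"
# message with node=None and a matching prompt_id (meaning the queue finished),
# then list the outputs via /history and download each image from /view.
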
import io
import json
import urllib.parse
import urllib.request
import uuid
from logging import INFO, basicConfig, getLogger
from os import environ, path
from random import choice, randint
from typing import List

import websocket  # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
from PIL import Image

logger = getLogger(__name__)

# Host and port of the ComfyUI server, without a scheme (the http:// and ws://
# prefixes are added below), e.g. COMFY_API=127.0.0.1:8188 for a local server.
server_address = environ["COMFY_API"]
client_id = str(uuid.uuid4())


def generate_cfg():
    return randint(5, 8)


def generate_steps():
    return 30


def generate_batches(
    count: int,
    batch_size: int = 3,
) -> List[int]:
    """
    Generate count images in batches of at most batch_size.
    """
    batches = []
    for i in range(0, count, batch_size):
        batches.append(min(count - i, batch_size))
    return batches
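

# Example behavior: generate_batches(7) -> [3, 3, 1]; generate_batches(2) -> [2].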


def queue_prompt(prompt):
    p = {"prompt": prompt, "client_id": client_id}
    data = json.dumps(p).encode("utf-8")
    req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
    return json.loads(urllib.request.urlopen(req).read())
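

# The /prompt response is JSON; only its "prompt_id" field is used below to
# look up the finished job in /history, so any other fields are ignored.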


def get_image(filename, subfolder, folder_type):
    data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
    url_values = urllib.parse.urlencode(data)
    with urllib.request.urlopen(
        "http://{}/view?{}".format(server_address, url_values)
    ) as response:
        return response.read()


def get_history(prompt_id):
    with urllib.request.urlopen(
        "http://{}/history/{}".format(server_address, prompt_id)
    ) as response:
        return json.loads(response.read())


def get_images(ws, prompt):
    prompt_id = queue_prompt(prompt)["prompt_id"]
    output_images = {}
    while True:
        out = ws.recv()
        if isinstance(out, str):
            message = json.loads(out)
            if message["type"] == "executing":
                data = message["data"]
                if data["node"] is None and data["prompt_id"] == prompt_id:
                    break  # Execution is done
        else:
            continue  # previews are binary data

    history = get_history(prompt_id)[prompt_id]
    for node_id in history["outputs"]:
        node_output = history["outputs"][node_id]
        if "images" in node_output:
            images_output = []
            for image in node_output["images"]:
                image_data = get_image(
                    image["filename"], image["subfolder"], image["type"]
                )
                images_output.append(image_data)
            output_images[node_id] = images_output

    return output_images
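

# get_images blocks until the queued workflow finishes, then returns a dict of
# {node_id: [raw image bytes, ...]} for every node that produced images.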


def generate_image_tool(prompt, count, size="landscape"):
    output_paths = []
    # Render in batches of at most batch_size images each (see generate_batches).
    for i, batch_count in enumerate(generate_batches(count)):
        results = generate_images(prompt, batch_count, size, prefix=f"output-{i}")
        output_paths.extend(results)
    return output_paths
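

# For example, generate_image_tool("a castle", 7) runs three batches (3 + 3 + 1)
# and saves images with the prefixes output-0, output-1, and output-2.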


sizes = {
    "landscape": (1024, 768),
    "portrait": (768, 1024),
    "square": (768, 768),
}
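# Unknown size names fall back to (512, 512) in generate_images below.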


def generate_images(
    prompt: str, count: int, size="landscape", prefix="output"
) -> List[str]:
    cfg = generate_cfg()
    width, height = sizes.get(size, (512, 512))
    steps = generate_steps()
    seed = randint(0, 10000000)
    checkpoint = choice(["diffusion-sdxl-dynavision-0-5-5-7.safetensors"])
    logger.info(
        "generating %s images at %s by %s with prompt: %s", count, width, height, prompt
    )
    # ComfyUI API-format workflow graph: each key is a node ID mapping to a
    # class_type and its inputs; values like ["4", 0] reference another node's
    # output by node ID and slot index.
    prompt_workflow = {
        "3": {
            "class_type": "KSampler",
            "inputs": {
                "cfg": cfg,
                "denoise": 1,
                "latent_image": ["5", 0],
                "model": ["4", 0],
                "negative": ["7", 0],
                "positive": ["6", 0],
                "sampler_name": "euler_ancestral",
                "scheduler": "normal",
                "seed": seed,
                "steps": steps,
            },
        },
        "4": {
            "class_type": "CheckpointLoaderSimple",
            "inputs": {"ckpt_name": checkpoint},
        },
        "5": {
            "class_type": "EmptyLatentImage",
            "inputs": {"batch_size": count, "height": height, "width": width},
        },
        "6": {
            "class_type": "smZ CLIPTextEncode",
            "inputs": {
                "text": prompt,
                "parser": "compel",
                "mean_normalization": True,
                "multi_conditioning": True,
                "use_old_emphasis_implementation": False,
                "with_SDXL": False,
                "ascore": 6,
                "width": width,
                "height": height,
                "crop_w": 0,
                "crop_h": 0,
                "target_width": width,
                "target_height": height,
                "text_g": "",
                "text_l": "",
                "smZ_steps": 1,
                "clip": ["4", 1],
            },
        },
        "7": {"class_type": "CLIPTextEncode", "inputs": {"text": "", "clip": ["4", 1]}},
        "8": {
            "class_type": "VAEDecode",
            "inputs": {"samples": ["3", 0], "vae": ["4", 2]},
        },
        "9": {
            "class_type": "SaveImage",
            "inputs": {"filename_prefix": prefix, "images": ["8", 0]},
        },
    }
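    # Graph wiring: node 4 loads the checkpoint, 5 allocates an empty latent
    # batch, 6 and 7 encode the positive and (empty) negative prompts, 3 runs
    # the KSampler, 8 decodes the latents with the VAE, and 9 saves the images
    # on the server side.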
    logger.debug("Connecting to Comfy API at %s", server_address)
    ws = websocket.WebSocket()
    ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
    try:
        images = get_images(ws, prompt_workflow)
    finally:
        ws.close()

    results = []
    for node_id in images:
        for image_data in images[node_id]:
            image = Image.open(io.BytesIO(image_data))
            results.append(image)

    paths: List[str] = []
    for j, image in enumerate(results):
        image_path = path.join("/home/ssube/adventure-images", f"{prefix}-{j}.png")
        image.save(image_path, format="PNG")
        paths.append(image_path)

    return paths


if __name__ == "__main__":
    # The module logger has no handlers when run as a script, so configure
    # logging to make the summary line below visible.
    basicConfig(level=INFO)
    paths = generate_images(
        "A painting of a beautiful sunset over a calm lake", 3, "landscape"
    )
    logger.info("Generated %d images: %s", len(paths), paths)