196 lines
5.8 KiB
Python
196 lines
5.8 KiB
Python
|
# This is an example that uses the websockets api to know when a prompt execution is done
|
||
|
# Once the prompt execution is done it downloads the images using the /history endpoint
|
||
|
|
||
|
import io
|
||
|
import json
|
||
|
import urllib.parse
|
||
|
import urllib.request
|
||
|
import uuid
|
||
|
from logging import getLogger
|
||
|
from os import environ, path
|
||
|
from random import choice, randint
|
||
|
from typing import List
|
||
|
|
||
|
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
|
||
|
from PIL import Image
|
||
|
|
||
|
logger = getLogger(__name__)
|
||
|
|
||
|
server_address = environ["COMFY_API"]
|
||
|
client_id = str(uuid.uuid4())
|
||
|
|
||
|
|
||
|
def generate_cfg():
|
||
|
return randint(5, 8)
|
||
|
|
||
|
|
||
|
def generate_steps():
|
||
|
return 30
|
||
|
|
||
|
|
||
|
def generate_batches(
|
||
|
count: int,
|
||
|
batch_size: int = 3,
|
||
|
) -> List[int]:
|
||
|
"""
|
||
|
Generate count images in batches of at most batch_size.
|
||
|
"""
|
||
|
|
||
|
batches = []
|
||
|
for i in range(0, count, batch_size):
|
||
|
batches.append(min(count - i, batch_size))
|
||
|
|
||
|
return batches
|
||
|
|
||
|
|
||
|
def queue_prompt(prompt):
|
||
|
p = {"prompt": prompt, "client_id": client_id}
|
||
|
data = json.dumps(p).encode("utf-8")
|
||
|
req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
|
||
|
return json.loads(urllib.request.urlopen(req).read())
|
||
|
|
||
|
|
||
|
def get_image(filename, subfolder, folder_type):
|
||
|
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
|
||
|
url_values = urllib.parse.urlencode(data)
|
||
|
with urllib.request.urlopen(
|
||
|
"http://{}/view?{}".format(server_address, url_values)
|
||
|
) as response:
|
||
|
return response.read()
|
||
|
|
||
|
|
||
|
def get_history(prompt_id):
|
||
|
with urllib.request.urlopen(
|
||
|
"http://{}/history/{}".format(server_address, prompt_id)
|
||
|
) as response:
|
||
|
return json.loads(response.read())
|
||
|
|
||
|
|
||
|
def get_images(ws, prompt):
|
||
|
prompt_id = queue_prompt(prompt)["prompt_id"]
|
||
|
output_images = {}
|
||
|
while True:
|
||
|
out = ws.recv()
|
||
|
if isinstance(out, str):
|
||
|
message = json.loads(out)
|
||
|
if message["type"] == "executing":
|
||
|
data = message["data"]
|
||
|
if data["node"] is None and data["prompt_id"] == prompt_id:
|
||
|
break # Execution is done
|
||
|
else:
|
||
|
continue # previews are binary data
|
||
|
|
||
|
history = get_history(prompt_id)[prompt_id]
|
||
|
for o in history["outputs"]:
|
||
|
for node_id in history["outputs"]:
|
||
|
node_output = history["outputs"][node_id]
|
||
|
if "images" in node_output:
|
||
|
images_output = []
|
||
|
for image in node_output["images"]:
|
||
|
image_data = get_image(
|
||
|
image["filename"], image["subfolder"], image["type"]
|
||
|
)
|
||
|
images_output.append(image_data)
|
||
|
output_images[node_id] = images_output
|
||
|
|
||
|
return output_images
|
||
|
|
||
|
|
||
|
def generate_image_tool(prompt, count, size="landscape"):
|
||
|
output_paths = []
|
||
|
for i, count in enumerate(generate_batches(count)):
|
||
|
results = generate_images(prompt, count, size, prefix=f"output-{i}")
|
||
|
output_paths.extend(results)
|
||
|
|
||
|
return output_paths
|
||
|
|
||
|
|
||
|
sizes = {
|
||
|
"landscape": (1024, 768),
|
||
|
"portrait": (768, 1024),
|
||
|
"square": (768, 768),
|
||
|
}
|
||
|
|
||
|
|
||
|
def generate_images(
|
||
|
prompt: str, count: int, size="landscape", prefix="output"
|
||
|
) -> List[str]:
|
||
|
cfg = generate_cfg()
|
||
|
width, height = sizes.get(size, (512, 512))
|
||
|
steps = generate_steps()
|
||
|
seed = randint(0, 10000000)
|
||
|
checkpoint = choice(["diffusion-sdxl-dynavision-0-5-5-7.safetensors"])
|
||
|
logger.info(
|
||
|
"generating %s images at %s by %s with prompt: %s", count, width, height, prompt
|
||
|
)
|
||
|
|
||
|
# parsing here helps ensure the template emits valid JSON
|
||
|
prompt_workflow = {
|
||
|
"3": {
|
||
|
"class_type": "KSampler",
|
||
|
"inputs": {
|
||
|
"cfg": cfg,
|
||
|
"denoise": 1,
|
||
|
"latent_image": ["5", 0],
|
||
|
"model": ["4", 0],
|
||
|
"negative": ["7", 0],
|
||
|
"positive": ["6", 0],
|
||
|
"sampler_name": "euler_ancestral",
|
||
|
"scheduler": "normal",
|
||
|
"seed": seed,
|
||
|
"steps": steps,
|
||
|
},
|
||
|
},
|
||
|
"4": {
|
||
|
"class_type": "CheckpointLoaderSimple",
|
||
|
"inputs": {"ckpt_name": checkpoint},
|
||
|
},
|
||
|
"5": {
|
||
|
"class_type": "EmptyLatentImage",
|
||
|
"inputs": {"batch_size": count, "height": height, "width": width},
|
||
|
},
|
||
|
"6": {
|
||
|
"class_type": "CLIPTextEncode",
|
||
|
"inputs": {"text": prompt, "clip": ["4", 1]},
|
||
|
},
|
||
|
"7": {"class_type": "CLIPTextEncode", "inputs": {"text": "", "clip": ["4", 1]}},
|
||
|
"8": {
|
||
|
"class_type": "VAEDecode",
|
||
|
"inputs": {"samples": ["3", 0], "vae": ["4", 2]},
|
||
|
},
|
||
|
"9": {
|
||
|
"class_type": "SaveImage",
|
||
|
"inputs": {"filename_prefix": prefix, "images": ["8", 0]},
|
||
|
},
|
||
|
}
|
||
|
|
||
|
logger.debug("Connecting to Comfy API at %s", server_address)
|
||
|
ws = websocket.WebSocket()
|
||
|
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
|
||
|
images = get_images(ws, prompt_workflow)
|
||
|
|
||
|
results = []
|
||
|
for node_id in images:
|
||
|
for image_data in images[node_id]:
|
||
|
image = Image.open(io.BytesIO(image_data))
|
||
|
results.append(image)
|
||
|
|
||
|
paths: List[str] = []
|
||
|
for j, image in enumerate(results):
|
||
|
image_path = path.join("/home/ssube/adventure-images", f"{prefix}-{j}.png")
|
||
|
with open(image_path, "wb") as f:
|
||
|
image_bytes = io.BytesIO()
|
||
|
image.save(image_bytes, format="PNG")
|
||
|
f.write(image_bytes.getvalue())
|
||
|
|
||
|
paths.append(image_path)
|
||
|
|
||
|
return paths
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
|
||
|
paths = generate_images(
|
||
|
"A painting of a beautiful sunset over a calm lake", 3, "landscape"
|
||
|
)
|
||
|
logger.info("Generated %d images: %s", len(paths), paths)
|