1
0
Fork 0

wire up pipeline parameter, apply lint

This commit is contained in:
Sean Sube 2023-04-12 23:11:53 -05:00
parent 2af1530a7e
commit 4aabf1ee27
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
10 changed files with 24 additions and 33 deletions

View File

@ -2,11 +2,9 @@ from logging import getLogger
from typing import Optional
import numpy as np
import torch
from PIL import Image
from ..diffusers.load import load_pipeline
from ..diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
@ -47,7 +45,7 @@ def blend_controlnet(
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=params.strength, # TODO: ControlNet strength
strength=params.strength, # TODO: ControlNet strength
callback=callback,
)

View File

@ -3,7 +3,6 @@ from typing import Optional
import numpy as np
import torch
from diffusers import OnnxStableDiffusionImg2ImgPipeline
from PIL import Image
from ..diffusers.load import load_pipeline
@ -33,7 +32,7 @@ def blend_img2img(
pipe = load_pipeline(
server,
params.pipeline,
"img2img",
params.model,
params.scheduler,
job.get_device(),

View File

@ -3,7 +3,6 @@ from typing import Callable, Optional, Tuple
import numpy as np
import torch
from diffusers import OnnxStableDiffusionInpaintPipeline
from PIL import Image
from ..diffusers.load import get_latents_from_seed, load_pipeline
@ -59,6 +58,14 @@ def blend_inpaint(
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-noise.png", noise)
pipe = load_pipeline(
server,
"inpaint",
params.model,
params.scheduler,
job.get_device(),
)
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims
size = Size(*tile_source.size)
@ -69,14 +76,6 @@ def blend_inpaint(
save_image(server, "tile-mask.png", tile_mask)
latents = get_latents_from_seed(params.seed, size)
pipe = load_pipeline(
server,
OnnxStableDiffusionInpaintPipeline,
params.model,
params.scheduler,
job.get_device(),
)
if params.lpw():
logger.debug("using LPW pipeline for inpaint")
rng = torch.manual_seed(params.seed)

View File

@ -6,7 +6,6 @@ import torch
from PIL import Image
from ..diffusers.load import load_pipeline
from ..diffusers.pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext

View File

@ -3,7 +3,6 @@ from typing import Optional
import numpy as np
import torch
from diffusers import OnnxStableDiffusionPipeline
from PIL import Image
from ..diffusers.load import get_latents_from_seed, load_pipeline

View File

@ -3,7 +3,6 @@ from typing import Callable, Optional, Tuple
import numpy as np
import torch
from diffusers import OnnxStableDiffusionInpaintPipeline
from PIL import Image, ImageDraw
from ..diffusers.load import get_latents_from_seed, get_tile_latents, load_pipeline

View File

@ -7,7 +7,6 @@ import torch
from diffusers import (
DDIMScheduler,
DDPMScheduler,
DiffusionPipeline,
DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
EulerAncestralDiscreteScheduler,
@ -41,9 +40,11 @@ except ImportError:
from ..constants import ONNX_MODEL
from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
from ..convert.diffusion.textual_inversion import blend_textual_inversions
from ..diffusers.lpw_stable_diffusion_onnx import (
OnnxStableDiffusionLongPromptWeightingPipeline,
)
from ..diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
from ..diffusers.pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline
from ..diffusers.lpw_stable_diffusion_onnx import OnnxStableDiffusionLongPromptWeightingPipeline
from ..diffusers.utils import expand_prompt
from ..params import DeviceParams, Size
from ..server import ServerContext
@ -150,7 +151,6 @@ def load_pipeline(
model,
device.device,
device.provider,
lpw,
control_key,
inversions,
loras,
@ -294,7 +294,7 @@ def load_pipeline(
)
# ControlNet component
if control is not None:
if pipeline == "controlnet" and control is not None:
components["controlnet"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
path.join(server.model_path, "control", f"{control.name}.onnx"),

View File

@ -3,13 +3,11 @@ from typing import Any, List
import numpy as np
import torch
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
from PIL import Image
from ..chain import blend_mask, upscale_outpaint
from ..chain.base import ChainProgress
from ..chain.utils import process_tile_order
from ..diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
from ..output import save_image, save_params
from ..params import (
Border,
@ -231,7 +229,7 @@ def run_img2img_pipeline(
pipe = load_pipeline(
server,
"img2img",
params.pipeline, # this is one of the only places this can actually vary between different pipelines
params.model,
params.scheduler,
job.get_device(),

View File

@ -145,7 +145,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
source = Image.open(BytesIO(source_file.read())).convert("RGB")
device, params, size = pipeline_from_request(server)
device, params, size = pipeline_from_request(server, "img2img")
upscale = upscale_from_request()
strength = get_and_clamp_float(
@ -177,7 +177,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
device, params, size = pipeline_from_request(server)
device, params, size = pipeline_from_request(server, "txt2img")
upscale = upscale_from_request()
highres = highres_from_request()
@ -212,7 +212,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
source = Image.open(BytesIO(source_file.read())).convert("RGB")
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
device, params, size = pipeline_from_request(server)
device, params, size = pipeline_from_request(server, "inpaint")
expand = border_from_request()
upscale = upscale_from_request()

View File

@ -4,7 +4,7 @@ from typing import Tuple
import numpy as np
from flask import request
from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers, pipeline_schedulers
from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers
from ..params import (
Border,
DeviceParams,
@ -30,7 +30,7 @@ logger = getLogger(__name__)
def pipeline_from_request(
server: ServerContext,
default_pipeline: str,
default_pipeline: str = "txt2img",
) -> Tuple[DeviceParams, ImageParams, Size]:
user = request.remote_addr
@ -49,9 +49,10 @@ def pipeline_from_request(
control = get_from_list(request.args, "control", get_network_models())
# pipeline stuff
pipeline = get_from_list(request.args, "pipeline", get_available_pipelines(), default_pipeline)
scheduler = get_from_list(
request.args, "scheduler", get_pipeline_schedulers())
pipeline = get_from_list(
request.args, "pipeline", get_available_pipelines(), default_pipeline
)
scheduler = get_from_list(request.args, "scheduler", get_pipeline_schedulers())
if scheduler is None:
scheduler = get_config_value("scheduler")
@ -135,7 +136,6 @@ def pipeline_from_request(
steps,
seed,
eta=eta,
lpw=lpw,
negative_prompt=negative_prompt,
batch=batch,
control=control,