wire up pipeline parameter, apply lint
This commit is contained in:
parent
2af1530a7e
commit
4aabf1ee27
|
@ -2,11 +2,9 @@ from logging import getLogger
|
|||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ..diffusers.load import load_pipeline
|
||||
from ..diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
|
@ -47,7 +45,7 @@ def blend_controlnet(
|
|||
image=source,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
strength=params.strength, # TODO: ControlNet strength
|
||||
strength=params.strength, # TODO: ControlNet strength
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
|
|
|
@ -3,7 +3,6 @@ from typing import Optional
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import OnnxStableDiffusionImg2ImgPipeline
|
||||
from PIL import Image
|
||||
|
||||
from ..diffusers.load import load_pipeline
|
||||
|
@ -33,7 +32,7 @@ def blend_img2img(
|
|||
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
params.pipeline,
|
||||
"img2img",
|
||||
params.model,
|
||||
params.scheduler,
|
||||
job.get_device(),
|
||||
|
|
|
@ -3,7 +3,6 @@ from typing import Callable, Optional, Tuple
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import OnnxStableDiffusionInpaintPipeline
|
||||
from PIL import Image
|
||||
|
||||
from ..diffusers.load import get_latents_from_seed, load_pipeline
|
||||
|
@ -59,6 +58,14 @@ def blend_inpaint(
|
|||
save_image(server, "last-mask.png", stage_mask)
|
||||
save_image(server, "last-noise.png", noise)
|
||||
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
"inpaint",
|
||||
params.model,
|
||||
params.scheduler,
|
||||
job.get_device(),
|
||||
)
|
||||
|
||||
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
|
||||
left, top, tile = dims
|
||||
size = Size(*tile_source.size)
|
||||
|
@ -69,14 +76,6 @@ def blend_inpaint(
|
|||
save_image(server, "tile-mask.png", tile_mask)
|
||||
|
||||
latents = get_latents_from_seed(params.seed, size)
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
OnnxStableDiffusionInpaintPipeline,
|
||||
params.model,
|
||||
params.scheduler,
|
||||
job.get_device(),
|
||||
)
|
||||
|
||||
if params.lpw():
|
||||
logger.debug("using LPW pipeline for inpaint")
|
||||
rng = torch.manual_seed(params.seed)
|
||||
|
|
|
@ -6,7 +6,6 @@ import torch
|
|||
from PIL import Image
|
||||
|
||||
from ..diffusers.load import load_pipeline
|
||||
from ..diffusers.pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
|
|
|
@ -3,7 +3,6 @@ from typing import Optional
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import OnnxStableDiffusionPipeline
|
||||
from PIL import Image
|
||||
|
||||
from ..diffusers.load import get_latents_from_seed, load_pipeline
|
||||
|
|
|
@ -3,7 +3,6 @@ from typing import Callable, Optional, Tuple
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import OnnxStableDiffusionInpaintPipeline
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from ..diffusers.load import get_latents_from_seed, get_tile_latents, load_pipeline
|
||||
|
|
|
@ -7,7 +7,6 @@ import torch
|
|||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
DiffusionPipeline,
|
||||
DPMSolverMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
|
@ -41,9 +40,11 @@ except ImportError:
|
|||
from ..constants import ONNX_MODEL
|
||||
from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
|
||||
from ..convert.diffusion.textual_inversion import blend_textual_inversions
|
||||
from ..diffusers.lpw_stable_diffusion_onnx import (
|
||||
OnnxStableDiffusionLongPromptWeightingPipeline,
|
||||
)
|
||||
from ..diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
|
||||
from ..diffusers.pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline
|
||||
from ..diffusers.lpw_stable_diffusion_onnx import OnnxStableDiffusionLongPromptWeightingPipeline
|
||||
from ..diffusers.utils import expand_prompt
|
||||
from ..params import DeviceParams, Size
|
||||
from ..server import ServerContext
|
||||
|
@ -150,7 +151,6 @@ def load_pipeline(
|
|||
model,
|
||||
device.device,
|
||||
device.provider,
|
||||
lpw,
|
||||
control_key,
|
||||
inversions,
|
||||
loras,
|
||||
|
@ -294,7 +294,7 @@ def load_pipeline(
|
|||
)
|
||||
|
||||
# ControlNet component
|
||||
if control is not None:
|
||||
if pipeline == "controlnet" and control is not None:
|
||||
components["controlnet"] = OnnxRuntimeModel(
|
||||
OnnxRuntimeModel.load_model(
|
||||
path.join(server.model_path, "control", f"{control.name}.onnx"),
|
||||
|
|
|
@ -3,13 +3,11 @@ from typing import Any, List
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
|
||||
from PIL import Image
|
||||
|
||||
from ..chain import blend_mask, upscale_outpaint
|
||||
from ..chain.base import ChainProgress
|
||||
from ..chain.utils import process_tile_order
|
||||
from ..diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
|
||||
from ..output import save_image, save_params
|
||||
from ..params import (
|
||||
Border,
|
||||
|
@ -231,7 +229,7 @@ def run_img2img_pipeline(
|
|||
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
"img2img",
|
||||
params.pipeline, # this is one of the only places this can actually vary between different pipelines
|
||||
params.model,
|
||||
params.scheduler,
|
||||
job.get_device(),
|
||||
|
|
|
@ -145,7 +145,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||
|
||||
device, params, size = pipeline_from_request(server)
|
||||
device, params, size = pipeline_from_request(server, "img2img")
|
||||
upscale = upscale_from_request()
|
||||
|
||||
strength = get_and_clamp_float(
|
||||
|
@ -177,7 +177,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
|
||||
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||
device, params, size = pipeline_from_request(server)
|
||||
device, params, size = pipeline_from_request(server, "txt2img")
|
||||
upscale = upscale_from_request()
|
||||
highres = highres_from_request()
|
||||
|
||||
|
@ -212,7 +212,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
|||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
|
||||
|
||||
device, params, size = pipeline_from_request(server)
|
||||
device, params, size = pipeline_from_request(server, "inpaint")
|
||||
expand = border_from_request()
|
||||
upscale = upscale_from_request()
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ from typing import Tuple
|
|||
import numpy as np
|
||||
from flask import request
|
||||
|
||||
from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers, pipeline_schedulers
|
||||
from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers
|
||||
from ..params import (
|
||||
Border,
|
||||
DeviceParams,
|
||||
|
@ -30,7 +30,7 @@ logger = getLogger(__name__)
|
|||
|
||||
def pipeline_from_request(
|
||||
server: ServerContext,
|
||||
default_pipeline: str,
|
||||
default_pipeline: str = "txt2img",
|
||||
) -> Tuple[DeviceParams, ImageParams, Size]:
|
||||
user = request.remote_addr
|
||||
|
||||
|
@ -49,9 +49,10 @@ def pipeline_from_request(
|
|||
control = get_from_list(request.args, "control", get_network_models())
|
||||
|
||||
# pipeline stuff
|
||||
pipeline = get_from_list(request.args, "pipeline", get_available_pipelines(), default_pipeline)
|
||||
scheduler = get_from_list(
|
||||
request.args, "scheduler", get_pipeline_schedulers())
|
||||
pipeline = get_from_list(
|
||||
request.args, "pipeline", get_available_pipelines(), default_pipeline
|
||||
)
|
||||
scheduler = get_from_list(request.args, "scheduler", get_pipeline_schedulers())
|
||||
|
||||
if scheduler is None:
|
||||
scheduler = get_config_value("scheduler")
|
||||
|
@ -135,7 +136,6 @@ def pipeline_from_request(
|
|||
steps,
|
||||
seed,
|
||||
eta=eta,
|
||||
lpw=lpw,
|
||||
negative_prompt=negative_prompt,
|
||||
batch=batch,
|
||||
control=control,
|
||||
|
|
Loading…
Reference in New Issue