feat(api): add tile size and stride to image parameters
This commit is contained in:
parent
746e33b1f5
commit
95725fff79
|
@ -31,6 +31,7 @@ def blend_controlnet(
|
|||
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
params,
|
||||
"controlnet",
|
||||
params.model,
|
||||
params.scheduler,
|
||||
|
|
|
@ -33,6 +33,7 @@ def blend_img2img(
|
|||
pipe_type = "lpw" if params.lpw() else "img2img"
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
params,
|
||||
pipe_type,
|
||||
params.model,
|
||||
params.scheduler,
|
||||
|
|
|
@ -62,6 +62,7 @@ def blend_inpaint(
|
|||
pipe_type = "lpw" if params.lpw() else "inpaint"
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
params,
|
||||
pipe_type,
|
||||
params.model,
|
||||
params.scheduler,
|
||||
|
|
|
@ -34,6 +34,7 @@ def blend_pix2pix(
|
|||
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
params,
|
||||
"pix2pix",
|
||||
params.model,
|
||||
params.scheduler,
|
||||
|
|
|
@ -40,6 +40,7 @@ def source_txt2img(
|
|||
pipe_type = "lpw" if params.lpw() else "txt2img"
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
params,
|
||||
pipe_type,
|
||||
params.model,
|
||||
params.scheduler,
|
||||
|
|
|
@ -82,6 +82,7 @@ def upscale_outpaint(
|
|||
pipe_type = params.get_valid_pipeline("inpaint", params.pipeline)
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
params,
|
||||
pipe_type,
|
||||
params.model,
|
||||
params.scheduler,
|
||||
|
@ -125,9 +126,7 @@ def upscale_outpaint(
|
|||
|
||||
if params.pipeline == "panorama":
|
||||
logger.debug("outpainting with one shot panorama, no tiling")
|
||||
return outpaint(
|
||||
source, (source.width, source.height, max(source.width, source.height))
|
||||
)
|
||||
return outpaint(source, (0, 0, max(source.width, source.height)))
|
||||
if overlap == 0:
|
||||
logger.debug("outpainting with 0 margin, using grid tiling")
|
||||
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
|
||||
|
|
|
@ -37,6 +37,7 @@ def upscale_stable_diffusion(
|
|||
|
||||
pipeline = load_pipeline(
|
||||
server,
|
||||
params,
|
||||
"upscale",
|
||||
path.join(server.model_path, upscale.upscale_model),
|
||||
params.scheduler,
|
||||
|
|
|
@ -10,8 +10,7 @@ from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
|
|||
from ..convert.diffusion.textual_inversion import blend_textual_inversions
|
||||
from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
|
||||
from ..diffusers.utils import expand_prompt
|
||||
from ..models.meta import NetworkModel
|
||||
from ..params import DeviceParams
|
||||
from ..params import DeviceParams, ImageParams
|
||||
from ..server import ServerContext
|
||||
from ..utils import run_gc
|
||||
from .patches.unet import UNetWrapper
|
||||
|
@ -93,20 +92,20 @@ def get_scheduler_name(scheduler: Any) -> Optional[str]:
|
|||
|
||||
def load_pipeline(
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
pipeline: str,
|
||||
model: str,
|
||||
scheduler_name: str,
|
||||
device: DeviceParams,
|
||||
control: Optional[NetworkModel] = None,
|
||||
inversions: Optional[List[Tuple[str, float]]] = None,
|
||||
loras: Optional[List[Tuple[str, float]]] = None,
|
||||
):
|
||||
inversions = inversions or []
|
||||
loras = loras or []
|
||||
control_key = control.name if control is not None else None
|
||||
model = params.model
|
||||
|
||||
torch_dtype = server.torch_dtype()
|
||||
logger.debug("using Torch dtype %s for pipeline", torch_dtype)
|
||||
|
||||
control_key = params.control.name if params.control is not None else None
|
||||
pipe_key = (
|
||||
pipeline,
|
||||
model,
|
||||
|
@ -116,8 +115,8 @@ def load_pipeline(
|
|||
inversions,
|
||||
loras,
|
||||
)
|
||||
scheduler_key = (scheduler_name, model)
|
||||
scheduler_type = pipeline_schedulers[scheduler_name]
|
||||
scheduler_key = (params.scheduler, model)
|
||||
scheduler_type = pipeline_schedulers[params.scheduler]
|
||||
|
||||
cache_pipe = server.cache.get("diffusion", pipe_key)
|
||||
|
||||
|
@ -164,8 +163,10 @@ def load_pipeline(
|
|||
unet_type = "unet"
|
||||
|
||||
# ControlNet component
|
||||
if pipeline == "controlnet" and control is not None:
|
||||
cnet_path = path.join(server.model_path, "control", f"{control.name}.onnx")
|
||||
if pipeline == "controlnet" and params.control is not None:
|
||||
cnet_path = path.join(
|
||||
server.model_path, "control", f"{params.control.name}.onnx"
|
||||
)
|
||||
logger.debug("loading ControlNet weights from %s", cnet_path)
|
||||
components["controlnet"] = OnnxRuntimeModel(
|
||||
OnnxRuntimeModel.load_model(
|
||||
|
@ -317,6 +318,11 @@ def load_pipeline(
|
|||
)
|
||||
)
|
||||
|
||||
# additional options for panorama pipeline
|
||||
if pipeline == "panorama":
|
||||
components["window"] = params.tiles
|
||||
components["stride"] = params.stride()
|
||||
|
||||
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
|
||||
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
|
||||
pipe = pipeline_class.from_pretrained(
|
||||
|
@ -333,12 +339,12 @@ def load_pipeline(
|
|||
|
||||
optimize_pipeline(server, pipe)
|
||||
|
||||
# TODO: CPU VAE, etc
|
||||
# TODO: remove this, not relevant with ONNX
|
||||
if device is not None and hasattr(pipe, "to"):
|
||||
pipe = pipe.to(device.torch_str())
|
||||
|
||||
# monkey-patch pipeline
|
||||
patch_pipeline(server, pipe, pipeline)
|
||||
patch_pipeline(server, pipe, pipeline_class, params)
|
||||
|
||||
server.cache.set("diffusion", pipe_key, pipe)
|
||||
server.cache.set("scheduler", scheduler_key, components["scheduler"])
|
||||
|
@ -402,6 +408,7 @@ def patch_pipeline(
|
|||
server: ServerContext,
|
||||
pipe: StableDiffusionPipeline,
|
||||
pipeline: Any,
|
||||
params: ImageParams,
|
||||
) -> None:
|
||||
logger.debug("patching SD pipeline")
|
||||
pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
|
||||
|
@ -411,9 +418,21 @@ def patch_pipeline(
|
|||
|
||||
if hasattr(pipe, "vae_decoder"):
|
||||
original_decoder = pipe.vae_decoder
|
||||
pipe.vae_decoder = VAEWrapper(server, original_decoder, decoder=True)
|
||||
pipe.vae_decoder = VAEWrapper(
|
||||
server,
|
||||
original_decoder,
|
||||
decoder=True,
|
||||
tiles=params.tiles,
|
||||
stride=params.stride(),
|
||||
)
|
||||
original_encoder = pipe.vae_encoder
|
||||
pipe.vae_encoder = VAEWrapper(server, original_encoder, decoder=False)
|
||||
pipe.vae_encoder = VAEWrapper(
|
||||
server,
|
||||
original_encoder,
|
||||
decoder=False,
|
||||
tiles=params.tiles,
|
||||
stride=params.stride(),
|
||||
)
|
||||
elif hasattr(pipe, "vae"):
|
||||
pass # TODO: current wrapper does not work with upscaling VAE
|
||||
else:
|
||||
|
|
|
@ -12,8 +12,6 @@ from ...server import ServerContext
|
|||
logger = getLogger(__name__)
|
||||
|
||||
LATENT_CHANNELS = 4
|
||||
LATENT_SIZE = 32
|
||||
SAMPLE_SIZE = 256
|
||||
|
||||
# TODO: does this need to change for fp16 modes?
|
||||
timestep_dtype = np.float32
|
||||
|
@ -25,14 +23,23 @@ def set_vae_dtype(dtype):
|
|||
|
||||
|
||||
class VAEWrapper(object):
|
||||
def __init__(self, server: ServerContext, wrapped: OnnxRuntimeModel, decoder: bool):
|
||||
def __init__(
|
||||
self,
|
||||
server: ServerContext,
|
||||
wrapped: OnnxRuntimeModel,
|
||||
decoder: bool,
|
||||
tiles: int,
|
||||
stride: int,
|
||||
):
|
||||
self.server = server
|
||||
self.wrapped = wrapped
|
||||
self.decoder = decoder
|
||||
self.tiles = tiles
|
||||
self.stride = stride
|
||||
|
||||
self.tile_sample_min_size = SAMPLE_SIZE
|
||||
self.tile_latent_min_size = LATENT_SIZE
|
||||
self.tile_overlap_factor = 0.25
|
||||
self.tile_latent_min_size = tiles
|
||||
self.tile_sample_min_size = tiles * 8
|
||||
self.tile_overlap_factor = stride / tiles
|
||||
|
||||
def __call__(self, latent_sample=None, sample=None, **kwargs):
|
||||
global timestep_dtype
|
||||
|
@ -52,10 +59,16 @@ class VAEWrapper(object):
|
|||
logger.debug("converting VAE sample dtype")
|
||||
sample = sample.astype(timestep_dtype)
|
||||
|
||||
if self.decoder:
|
||||
return self.tiled_decode(latent_sample, **kwargs)
|
||||
if self.tiles is not None and self.stride is not None:
|
||||
if self.decoder:
|
||||
return self.tiled_decode(latent_sample, **kwargs)
|
||||
else:
|
||||
return self.tiled_encode(sample, **kwargs)
|
||||
else:
|
||||
return self.tiled_encode(sample, **kwargs)
|
||||
if self.decoder:
|
||||
return self.wrapped(latent_sample=latent_sample)
|
||||
else:
|
||||
return self.wrapped(sample=sample)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.wrapped, attr)
|
||||
|
|
|
@ -33,6 +33,9 @@ logger = logging.get_logger(__name__)
|
|||
NUM_UNET_INPUT_CHANNELS = 9
|
||||
NUM_LATENT_CHANNELS = 4
|
||||
|
||||
DEFAULT_WINDOW = 32
|
||||
DEFAULT_STRIDE = 8
|
||||
|
||||
|
||||
def preprocess(image):
|
||||
if isinstance(image, torch.Tensor):
|
||||
|
@ -105,9 +108,14 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
|||
safety_checker: OnnxRuntimeModel,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
requires_safety_checker: bool = True,
|
||||
window: Optional[int] = None,
|
||||
stride: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.window = window or DEFAULT_WINDOW
|
||||
self.stride = stride or DEFAULT_STRIDE
|
||||
|
||||
if (
|
||||
hasattr(scheduler.config, "steps_offset")
|
||||
and scheduler.config.steps_offset != 1
|
||||
|
@ -338,7 +346,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
|||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
def get_views(self, panorama_height, panorama_width, window_size=32, stride=8):
|
||||
def get_views(self, panorama_height, panorama_width, window_size, stride):
|
||||
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
|
||||
panorama_height /= 8
|
||||
panorama_width /= 8
|
||||
|
@ -514,7 +522,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
|||
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
||||
|
||||
# panorama additions
|
||||
views = self.get_views(height, width)
|
||||
views = self.get_views(height, width, self.window, self.stride)
|
||||
count = np.zeros_like(latents)
|
||||
value = np.zeros_like(latents)
|
||||
|
||||
|
@ -816,7 +824,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
|||
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
||||
|
||||
# panorama additions
|
||||
views = self.get_views(height, width)
|
||||
views = self.get_views(height, width, self.window, self.stride)
|
||||
count = np.zeros_like(latents)
|
||||
value = np.zeros_like(latents)
|
||||
|
||||
|
@ -1124,7 +1132,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
|||
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
||||
|
||||
# panorama additions
|
||||
views = self.get_views(height, width)
|
||||
views = self.get_views(height, width, self.window, self.stride)
|
||||
count = np.zeros_like(latents)
|
||||
value = np.zeros_like(latents)
|
||||
|
||||
|
|
|
@ -55,9 +55,8 @@ def run_loopback(
|
|||
|
||||
pipe = pipeline or load_pipeline(
|
||||
server,
|
||||
params,
|
||||
pipe_type,
|
||||
params.model,
|
||||
params.scheduler,
|
||||
job.get_device(),
|
||||
inversions=inversions,
|
||||
loras=loras,
|
||||
|
@ -140,9 +139,8 @@ def run_highres(
|
|||
|
||||
highres_pipe = pipeline or load_pipeline(
|
||||
server,
|
||||
params,
|
||||
pipe_type,
|
||||
params.model,
|
||||
params.scheduler,
|
||||
job.get_device(),
|
||||
inversions=inversions,
|
||||
loras=loras,
|
||||
|
@ -242,9 +240,8 @@ def run_txt2img_pipeline(
|
|||
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
params,
|
||||
pipe_type,
|
||||
params.model,
|
||||
params.scheduler,
|
||||
job.get_device(),
|
||||
inversions=inversions,
|
||||
loras=loras,
|
||||
|
@ -350,11 +347,9 @@ def run_img2img_pipeline(
|
|||
pipe_type = params.get_valid_pipeline("img2img")
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
params,
|
||||
pipe_type,
|
||||
params.model,
|
||||
params.scheduler,
|
||||
job.get_device(),
|
||||
control=params.control,
|
||||
inversions=inversions,
|
||||
loras=loras,
|
||||
)
|
||||
|
|
|
@ -191,6 +191,9 @@ class ImageParams:
|
|||
input_prompt: str
|
||||
input_negative_prompt: str
|
||||
loopback: int
|
||||
tiled_vae: bool
|
||||
tiles: int
|
||||
overlap: float
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -208,6 +211,9 @@ class ImageParams:
|
|||
input_prompt: Optional[str] = None,
|
||||
input_negative_prompt: Optional[str] = None,
|
||||
loopback: int = 0,
|
||||
tiled_vae: bool = False,
|
||||
tiles: int = 512,
|
||||
overlap: float = 0.25,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.pipeline = pipeline
|
||||
|
@ -223,6 +229,9 @@ class ImageParams:
|
|||
self.input_prompt = input_prompt or prompt
|
||||
self.input_negative_prompt = input_negative_prompt or negative_prompt
|
||||
self.loopback = loopback
|
||||
self.tiled_vae = tiled_vae
|
||||
self.tiles = tiles
|
||||
self.overlap = overlap
|
||||
|
||||
def do_cfg(self):
|
||||
return self.cfg > 1.0
|
||||
|
@ -251,6 +260,9 @@ class ImageParams:
|
|||
def lpw(self):
|
||||
return self.pipeline == "lpw"
|
||||
|
||||
def stride(self):
|
||||
return int(self.tiles * self.stride)
|
||||
|
||||
def tojson(self) -> Dict[str, Optional[Param]]:
|
||||
return {
|
||||
"model": self.model,
|
||||
|
@ -267,6 +279,9 @@ class ImageParams:
|
|||
"input_prompt": self.input_prompt,
|
||||
"input_negative_prompt": self.input_negative_prompt,
|
||||
"loopback": self.loopback,
|
||||
"tiled_vae": self.tiled_vae,
|
||||
"tiles": self.tiles,
|
||||
"overlap": self.overlap,
|
||||
}
|
||||
|
||||
def with_args(self, **kwargs):
|
||||
|
@ -285,6 +300,9 @@ class ImageParams:
|
|||
kwargs.get("input_prompt", self.input_prompt),
|
||||
kwargs.get("input_negative_prompt", self.input_negative_prompt),
|
||||
kwargs.get("loopback", self.loopback),
|
||||
kwargs.get("tiled_vae", self.tiled_vae),
|
||||
kwargs.get("tiles", self.tiles),
|
||||
kwargs.get("overlap", self.overlap),
|
||||
)
|
||||
|
||||
|
||||
|
@ -405,14 +423,12 @@ class HighresParams:
|
|||
strength: float,
|
||||
method: Literal["bilinear", "lanczos", "upscale"] = "lanczos",
|
||||
iterations: int = 1,
|
||||
tiled_vae: bool = False,
|
||||
):
|
||||
self.scale = scale
|
||||
self.steps = steps
|
||||
self.strength = strength
|
||||
self.method = method
|
||||
self.iterations = iterations
|
||||
self.tiled_vae = tiled_vae
|
||||
|
||||
def resize(self, size: Size) -> Size:
|
||||
return Size(
|
||||
|
@ -427,5 +443,4 @@ class HighresParams:
|
|||
"scale": self.scale,
|
||||
"steps": self.steps,
|
||||
"strength": self.strength,
|
||||
"tiled_vae": self.tiled_vae,
|
||||
}
|
||||
|
|
|
@ -125,6 +125,21 @@ def pipeline_from_request(
|
|||
get_config_value("width", "max"),
|
||||
get_config_value("width", "min"),
|
||||
)
|
||||
tiled_vae = get_boolean(request.args, "tiledVAE", get_config_value("tiledVAE"))
|
||||
tiles = get_and_clamp_int(
|
||||
request.args,
|
||||
"tiles",
|
||||
get_config_value("tiles"),
|
||||
get_config_value("tiles", "max"),
|
||||
get_config_value("tiles", "min"),
|
||||
)
|
||||
overlap = get_and_clamp_float(
|
||||
request.args,
|
||||
"overlap",
|
||||
get_config_value("overlap"),
|
||||
get_config_value("overlap", "max"),
|
||||
get_config_value("overlap", "min"),
|
||||
)
|
||||
|
||||
seed = int(request.args.get("seed", -1))
|
||||
if seed == -1:
|
||||
|
@ -159,6 +174,9 @@ def pipeline_from_request(
|
|||
batch=batch,
|
||||
control=control,
|
||||
loopback=loopback,
|
||||
tiled_vae=tiled_vae,
|
||||
tiles=tiles,
|
||||
overlap=overlap,
|
||||
)
|
||||
size = Size(width, height)
|
||||
return (device, params, size)
|
||||
|
@ -279,17 +297,13 @@ def highres_from_request() -> HighresParams:
|
|||
request.args,
|
||||
"highresStrength",
|
||||
get_config_value("highresStrength"),
|
||||
get_config_value("highresStrength"),
|
||||
get_config_value("highresStrength", "max"),
|
||||
get_config_value("highresStrength", "min"),
|
||||
)
|
||||
|
||||
tiled_vae = get_boolean(request.args, "tiledVAE", get_config_value("tiledVAE"))
|
||||
|
||||
return HighresParams(
|
||||
scale,
|
||||
steps,
|
||||
strength,
|
||||
method=method,
|
||||
iterations=iterations,
|
||||
tiled_vae=tiled_vae,
|
||||
)
|
||||
|
|
|
@ -130,6 +130,12 @@
|
|||
"max": 4,
|
||||
"step": 1
|
||||
},
|
||||
"overlap": {
|
||||
"default": 0.25,
|
||||
"min": 0.0,
|
||||
"max": 0.9,
|
||||
"step": 0.01
|
||||
},
|
||||
"pipeline": {
|
||||
"default": "",
|
||||
"keys": [
|
||||
|
@ -183,6 +189,12 @@
|
|||
"tiledVAE": {
|
||||
"default": false
|
||||
},
|
||||
"tiles": {
|
||||
"default": 512,
|
||||
"min": 128,
|
||||
"max": 1024,
|
||||
"step": 128
|
||||
},
|
||||
"tileOrder": {
|
||||
"default": "spiral",
|
||||
"keys": [
|
||||
|
|
|
@ -128,6 +128,12 @@
|
|||
"max": 4,
|
||||
"step": 1
|
||||
},
|
||||
"overlap": {
|
||||
"default": 0.25,
|
||||
"min": 0.0,
|
||||
"max": 0.9,
|
||||
"step": 0.01
|
||||
},
|
||||
"pipeline": {
|
||||
"default": "",
|
||||
"keys": [
|
||||
|
@ -181,6 +187,12 @@
|
|||
"tiledVAE": {
|
||||
"default": false
|
||||
},
|
||||
"tiles": {
|
||||
"default": 512,
|
||||
"min": 128,
|
||||
"max": 1024,
|
||||
"step": 128
|
||||
},
|
||||
"tileOrder": {
|
||||
"default": "spiral",
|
||||
"keys": [
|
||||
|
|
Loading…
Reference in New Issue