1
0
Fork 0

feat(api): enable LPW custom pipeline (#27)

This commit is contained in:
Sean Sube 2023-02-04 17:44:54 -06:00
parent d636ce3eef
commit 70dedf811a
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 8 additions and 6 deletions

View File

@ -92,14 +92,14 @@ def upscale_outpaint(
latents = get_tile_latents(full_latents, dims) latents = get_tile_latents(full_latents, dims)
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
result = pipe( result = pipe.inpaint(
image,
mask,
prompt, prompt,
generator=rng, generator=rng,
guidance_scale=params.cfg, guidance_scale=params.cfg,
height=size.height, height=size.height,
image=image,
latents=latents, latents=latents,
mask_image=mask,
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
width=size.width, width=size.width,

View File

@ -63,6 +63,8 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
logger.debug('loading new diffusion pipeline from %s', model) logger.debug('loading new diffusion pipeline from %s', model)
pipe = pipeline.from_pretrained( pipe = pipeline.from_pretrained(
model, model,
custom_pipeline='lpw_stable_diffusion_onnx',
revision='onnx',
provider=provider, provider=provider,
safety_checker=None, safety_checker=None,
scheduler=scheduler.from_pretrained(model, subfolder='scheduler') scheduler=scheduler.from_pretrained(model, subfolder='scheduler')

View File

@ -56,7 +56,7 @@ def run_txt2img_pipeline(
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
progress = job.get_progress_callback() progress = job.get_progress_callback()
result = pipe( result = pipe.txt2img(
params.prompt, params.prompt,
height=size.height, height=size.height,
width=size.width, width=size.width,
@ -97,11 +97,11 @@ def run_img2img_pipeline(
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
progress = job.get_progress_callback() progress = job.get_progress_callback()
result = pipe( result = pipe.img2img(
source_image,
params.prompt, params.prompt,
generator=rng, generator=rng,
guidance_scale=params.cfg, guidance_scale=params.cfg,
image=source_image,
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
strength=strength, strength=strength,