1
0
Fork 0

fix(api): switch RNG based on LPW parameter

This commit is contained in:
Sean Sube 2023-02-05 17:24:08 -06:00
parent c47209cfbf
commit f3983a7917
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 23 additions and 7 deletions

View File

@ -1,6 +1,7 @@
from logging import getLogger
import torch
import numpy as np
from diffusers import OnnxStableDiffusionImg2ImgPipeline
from PIL import Image
@ -35,8 +36,9 @@ def blend_img2img(
)
if params.lpw:
pipe = pipe.img2img
rng = torch.manual_seed(params.seed)
rng = torch.manual_seed(params.seed)
else:
rng = np.random.RandomState(params.seed)
result = pipe(
prompt,

View File

@ -2,6 +2,7 @@ from logging import getLogger
from typing import Callable, Tuple
import torch
import numpy as np
from diffusers import OnnxStableDiffusionInpaintPipeline
from PIL import Image
@ -68,9 +69,11 @@ def blend_inpaint(
)
if params.lpw:
pipe = pipe.inpaint
rng = torch.manual_seed(params.seed)
else:
rng = np.random.RandomState(params.seed)
latents = get_latents_from_seed(params.seed, size)
rng = torch.manual_seed(params.seed)
result = pipe(
params.prompt,

View File

@ -1,6 +1,7 @@
from logging import getLogger
import torch
import numpy as np
from diffusers import OnnxStableDiffusionPipeline
from PIL import Image
@ -36,9 +37,11 @@ def source_txt2img(
)
if params.lpw:
pipe = pipe.text2img
rng = torch.manual_seed(params.seed)
else:
rng = np.random.RandomState(params.seed)
latents = get_latents_from_seed(params.seed, size)
rng = torch.manual_seed(params.seed)
result = pipe(
prompt,

View File

@ -2,6 +2,7 @@ from logging import getLogger
from typing import Callable, Tuple
import torch
import numpy as np
from diffusers import OnnxStableDiffusionInpaintPipeline
from PIL import Image, ImageDraw
@ -73,9 +74,11 @@ def upscale_outpaint(
)
if params.lpw:
pipe = pipe.inpaint
rng = torch.manual_seed(params.seed)
else:
rng = np.random.RandomState(params.seed)
latents = get_tile_latents(full_latents, dims)
rng = torch.manual_seed(params.seed)
result = pipe(
image,

View File

@ -2,6 +2,7 @@ from logging import getLogger
from typing import Any
import torch
import numpy as np
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
from PIL import Image, ImageChops
@ -29,9 +30,11 @@ def run_txt2img_pipeline(
)
if params.lpw:
pipe = pipe.text2img
rng = torch.manual_seed(params.seed)
else:
rng = np.random.RandomState(params.seed)
latents = get_latents_from_seed(params.seed, size)
rng = torch.manual_seed(params.seed)
progress = job.get_progress_callback()
result = pipe(
@ -78,8 +81,10 @@ def run_img2img_pipeline(
)
if params.lpw:
pipe = pipe.img2img
rng = torch.manual_seed(params.seed)
else:
rng = np.random.RandomState(params.seed)
rng = torch.manual_seed(params.seed)
progress = job.get_progress_callback()
result = pipe(