1
0
Fork 0

fix(api): switch RNG based on LPW parameter

This commit is contained in:
Sean Sube 2023-02-05 17:24:08 -06:00
parent c47209cfbf
commit f3983a7917
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 23 additions and 7 deletions

View File

@ -1,6 +1,7 @@
from logging import getLogger from logging import getLogger
import torch import torch
import numpy as np
from diffusers import OnnxStableDiffusionImg2ImgPipeline from diffusers import OnnxStableDiffusionImg2ImgPipeline
from PIL import Image from PIL import Image
@ -35,8 +36,9 @@ def blend_img2img(
) )
if params.lpw: if params.lpw:
pipe = pipe.img2img pipe = pipe.img2img
rng = torch.manual_seed(params.seed)
rng = torch.manual_seed(params.seed) else:
rng = np.random.RandomState(params.seed)
result = pipe( result = pipe(
prompt, prompt,

View File

@ -2,6 +2,7 @@ from logging import getLogger
from typing import Callable, Tuple from typing import Callable, Tuple
import torch import torch
import numpy as np
from diffusers import OnnxStableDiffusionInpaintPipeline from diffusers import OnnxStableDiffusionInpaintPipeline
from PIL import Image from PIL import Image
@ -68,9 +69,11 @@ def blend_inpaint(
) )
if params.lpw: if params.lpw:
pipe = pipe.inpaint pipe = pipe.inpaint
rng = torch.manual_seed(params.seed)
else:
rng = np.random.RandomState(params.seed)
latents = get_latents_from_seed(params.seed, size) latents = get_latents_from_seed(params.seed, size)
rng = torch.manual_seed(params.seed)
result = pipe( result = pipe(
params.prompt, params.prompt,

View File

@ -1,6 +1,7 @@
from logging import getLogger from logging import getLogger
import torch import torch
import numpy as np
from diffusers import OnnxStableDiffusionPipeline from diffusers import OnnxStableDiffusionPipeline
from PIL import Image from PIL import Image
@ -36,9 +37,11 @@ def source_txt2img(
) )
if params.lpw: if params.lpw:
pipe = pipe.text2img pipe = pipe.text2img
rng = torch.manual_seed(params.seed)
else:
rng = np.random.RandomState(params.seed)
latents = get_latents_from_seed(params.seed, size) latents = get_latents_from_seed(params.seed, size)
rng = torch.manual_seed(params.seed)
result = pipe( result = pipe(
prompt, prompt,

View File

@ -2,6 +2,7 @@ from logging import getLogger
from typing import Callable, Tuple from typing import Callable, Tuple
import torch import torch
import numpy as np
from diffusers import OnnxStableDiffusionInpaintPipeline from diffusers import OnnxStableDiffusionInpaintPipeline
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
@ -73,9 +74,11 @@ def upscale_outpaint(
) )
if params.lpw: if params.lpw:
pipe = pipe.inpaint pipe = pipe.inpaint
rng = torch.manual_seed(params.seed)
else:
rng = np.random.RandomState(params.seed)
latents = get_tile_latents(full_latents, dims) latents = get_tile_latents(full_latents, dims)
rng = torch.manual_seed(params.seed)
result = pipe( result = pipe(
image, image,

View File

@ -2,6 +2,7 @@ from logging import getLogger
from typing import Any from typing import Any
import torch import torch
import numpy as np
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
from PIL import Image, ImageChops from PIL import Image, ImageChops
@ -29,9 +30,11 @@ def run_txt2img_pipeline(
) )
if params.lpw: if params.lpw:
pipe = pipe.text2img pipe = pipe.text2img
rng = torch.manual_seed(params.seed)
else:
rng = np.random.RandomState(params.seed)
latents = get_latents_from_seed(params.seed, size) latents = get_latents_from_seed(params.seed, size)
rng = torch.manual_seed(params.seed)
progress = job.get_progress_callback() progress = job.get_progress_callback()
result = pipe( result = pipe(
@ -78,8 +81,10 @@ def run_img2img_pipeline(
) )
if params.lpw: if params.lpw:
pipe = pipe.img2img pipe = pipe.img2img
rng = torch.manual_seed(params.seed)
else:
rng = np.random.RandomState(params.seed)
rng = torch.manual_seed(params.seed)
progress = job.get_progress_callback() progress = job.get_progress_callback()
result = pipe( result = pipe(