1
0
Fork 0

fix(api): replace some numpy RNGs with torch equivalent

This commit is contained in:
Sean Sube 2023-02-05 13:43:33 -06:00
parent 7158922039
commit 401df84069
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 14 additions and 10 deletions

View File

@ -1,6 +1,6 @@
from logging import getLogger from logging import getLogger
import numpy as np import torch
from diffusers import OnnxStableDiffusionImg2ImgPipeline from diffusers import OnnxStableDiffusionImg2ImgPipeline
from PIL import Image from PIL import Image
@ -33,7 +33,7 @@ def blend_img2img(
job.get_device(), job.get_device(),
) )
rng = np.random.RandomState(params.seed) rng = torch.manual_seed(params.seed)
result = pipe( result = pipe(
prompt, prompt,

View File

@ -1,7 +1,7 @@
from logging import getLogger from logging import getLogger
from typing import Callable, Tuple from typing import Callable, Tuple
import numpy as np import torch
from diffusers import OnnxStableDiffusionInpaintPipeline from diffusers import OnnxStableDiffusionInpaintPipeline
from PIL import Image from PIL import Image
@ -67,7 +67,7 @@ def blend_inpaint(
) )
latents = get_latents_from_seed(params.seed, size) latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed) rng = torch.manual_seed(params.seed)
result = pipe( result = pipe(
params.prompt, params.prompt,

View File

@ -1,6 +1,7 @@
from logging import getLogger from logging import getLogger
import numpy as np import numpy as np
import torch
from diffusers import OnnxStableDiffusionPipeline from diffusers import OnnxStableDiffusionPipeline
from PIL import Image from PIL import Image
@ -36,7 +37,7 @@ def source_txt2img(
) )
latents = get_latents_from_seed(params.seed, size) latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed) rng = torch.manual_seed(params.seed)
result = pipe( result = pipe(
prompt, prompt,

View File

@ -1,7 +1,7 @@
from logging import getLogger from logging import getLogger
from typing import Callable, Tuple from typing import Callable, Tuple
import numpy as np import torch
from diffusers import OnnxStableDiffusionInpaintPipeline from diffusers import OnnxStableDiffusionInpaintPipeline
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
@ -73,7 +73,7 @@ def upscale_outpaint(
) )
latents = get_tile_latents(full_latents, dims) latents = get_tile_latents(full_latents, dims)
rng = np.random.RandomState(params.seed) rng = torch.manual_seed(params.seed)
result = pipe.inpaint( result = pipe.inpaint(
image, image,

View File

@ -19,7 +19,8 @@ latent_factor = 8
def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray: def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
""" """
From https://www.travelneil.com/stable-diffusion-updates.html From https://www.travelneil.com/stable-diffusion-updates.html.
This one needs to use np.random because of the return type.
""" """
latents_shape = ( latents_shape = (
batch, batch,

View File

@ -2,6 +2,7 @@ from logging import getLogger
from typing import Any from typing import Any
import numpy as np import numpy as np
import torch
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
from PIL import Image, ImageChops from PIL import Image, ImageChops
@ -29,7 +30,7 @@ def run_txt2img_pipeline(
) )
latents = get_latents_from_seed(params.seed, size) latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed) rng = torch.manual_seed(params.seed)
progress = job.get_progress_callback() progress = job.get_progress_callback()
result = pipe.text2img( result = pipe.text2img(
@ -74,7 +75,7 @@ def run_img2img_pipeline(
job.get_device(), job.get_device(),
) )
rng = np.random.RandomState(params.seed) rng = torch.manual_seed(params.seed)
progress = job.get_progress_callback() progress = job.get_progress_callback()
result = pipe.img2img( result = pipe.img2img(

View File

@ -215,6 +215,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
seed = int(request.args.get("seed", -1)) seed = int(request.args.get("seed", -1))
if seed == -1: if seed == -1:
# this one can safely use np.random because it produces a single value
seed = np.random.randint(np.iinfo(np.int32).max) seed = np.random.randint(np.iinfo(np.int32).max)
logger.info( logger.info(