diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 5d0da615..d51aa590 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -1,6 +1,6 @@ from logging import getLogger -import numpy as np +import torch from diffusers import OnnxStableDiffusionImg2ImgPipeline from PIL import Image @@ -33,7 +33,7 @@ def blend_img2img( job.get_device(), ) - rng = np.random.RandomState(params.seed) + rng = torch.manual_seed(params.seed) result = pipe( prompt, diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py index d59459c3..7b23405a 100644 --- a/api/onnx_web/chain/blend_inpaint.py +++ b/api/onnx_web/chain/blend_inpaint.py @@ -1,7 +1,7 @@ from logging import getLogger from typing import Callable, Tuple -import numpy as np +import torch from diffusers import OnnxStableDiffusionInpaintPipeline from PIL import Image @@ -67,7 +67,7 @@ def blend_inpaint( ) latents = get_latents_from_seed(params.seed, size) - rng = np.random.RandomState(params.seed) + rng = torch.manual_seed(params.seed) result = pipe( params.prompt, diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index 6122999e..76072cce 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -1,6 +1,7 @@ from logging import getLogger import numpy as np +import torch from diffusers import OnnxStableDiffusionPipeline from PIL import Image @@ -36,7 +37,7 @@ def source_txt2img( ) latents = get_latents_from_seed(params.seed, size) - rng = np.random.RandomState(params.seed) + rng = torch.manual_seed(params.seed) result = pipe( prompt, diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 4f6ecbb5..253a881c 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -1,7 +1,7 @@ from logging import getLogger from typing import Callable, Tuple -import numpy as np +import torch from diffusers import OnnxStableDiffusionInpaintPipeline from PIL import Image, ImageDraw @@ -73,7 +73,7 @@ def upscale_outpaint( ) latents = get_tile_latents(full_latents, dims) - rng = np.random.RandomState(params.seed) + rng = torch.manual_seed(params.seed) result = pipe.inpaint( image, diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py index 3a530bd0..808ab7c0 100644 --- a/api/onnx_web/diffusion/load.py +++ b/api/onnx_web/diffusion/load.py @@ -19,7 +19,8 @@ latent_factor = 8 def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray: """ - From https://www.travelneil.com/stable-diffusion-updates.html + From https://www.travelneil.com/stable-diffusion-updates.html. + This one needs to use np.random because of the return type. """ latents_shape = ( batch, diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index 78c6e6b0..a9794091 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -2,6 +2,7 @@ from logging import getLogger from typing import Any import numpy as np +import torch from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline from PIL import Image, ImageChops @@ -29,7 +30,7 @@ def run_txt2img_pipeline( ) latents = get_latents_from_seed(params.seed, size) - rng = np.random.RandomState(params.seed) + rng = torch.manual_seed(params.seed) progress = job.get_progress_callback() result = pipe.text2img( @@ -74,7 +75,7 @@ def run_img2img_pipeline( job.get_device(), ) - rng = np.random.RandomState(params.seed) + rng = torch.manual_seed(params.seed) progress = job.get_progress_callback() result = pipe.img2img( diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index ea62b330..28680c88 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -215,6 +215,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]: seed = int(request.args.get("seed", -1)) if seed == -1: + # this one can safely use np.random because it produces a single value seed = np.random.randint(np.iinfo(np.int32).max) logger.info(