1
0
Fork 0

fix(api): replace some numpy RNGs with torch equivalent

This commit is contained in:
Sean Sube 2023-02-05 13:43:33 -06:00
parent 7158922039
commit 401df84069
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 14 additions and 10 deletions

View File

@ -1,6 +1,6 @@
from logging import getLogger
import numpy as np
import torch
from diffusers import OnnxStableDiffusionImg2ImgPipeline
from PIL import Image
@ -33,7 +33,7 @@ def blend_img2img(
job.get_device(),
)
rng = np.random.RandomState(params.seed)
rng = torch.manual_seed(params.seed)
result = pipe(
prompt,

View File

@ -1,7 +1,7 @@
from logging import getLogger
from typing import Callable, Tuple
import numpy as np
import torch
from diffusers import OnnxStableDiffusionInpaintPipeline
from PIL import Image
@ -67,7 +67,7 @@ def blend_inpaint(
)
latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed)
rng = torch.manual_seed(params.seed)
result = pipe(
params.prompt,

View File

@ -1,6 +1,7 @@
from logging import getLogger
import numpy as np
import torch
from diffusers import OnnxStableDiffusionPipeline
from PIL import Image
@ -36,7 +37,7 @@ def source_txt2img(
)
latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed)
rng = torch.manual_seed(params.seed)
result = pipe(
prompt,

View File

@ -1,7 +1,7 @@
from logging import getLogger
from typing import Callable, Tuple
import numpy as np
import torch
from diffusers import OnnxStableDiffusionInpaintPipeline
from PIL import Image, ImageDraw
@ -73,7 +73,7 @@ def upscale_outpaint(
)
latents = get_tile_latents(full_latents, dims)
rng = np.random.RandomState(params.seed)
rng = torch.manual_seed(params.seed)
result = pipe.inpaint(
image,

View File

@ -19,7 +19,8 @@ latent_factor = 8
def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
"""
From https://www.travelneil.com/stable-diffusion-updates.html
From https://www.travelneil.com/stable-diffusion-updates.html.
This one needs to use np.random because of the return type.
"""
latents_shape = (
batch,

View File

@ -2,6 +2,7 @@ from logging import getLogger
from typing import Any
import numpy as np
import torch
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
from PIL import Image, ImageChops
@ -29,7 +30,7 @@ def run_txt2img_pipeline(
)
latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed)
rng = torch.manual_seed(params.seed)
progress = job.get_progress_callback()
result = pipe.text2img(
@ -74,7 +75,7 @@ def run_img2img_pipeline(
job.get_device(),
)
rng = np.random.RandomState(params.seed)
rng = torch.manual_seed(params.seed)
progress = job.get_progress_callback()
result = pipe.img2img(

View File

@ -215,6 +215,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
seed = int(request.args.get("seed", -1))
if seed == -1:
# this one can safely use np.random because it produces a single value
seed = np.random.randint(np.iinfo(np.int32).max)
logger.info(