fix(api): replace some numpy RNGs with torch equivalent
This commit is contained in:
parent
7158922039
commit
401df84069
|
@ -1,6 +1,6 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
|
||||||
import numpy as np
|
import torch
|
||||||
from diffusers import OnnxStableDiffusionImg2ImgPipeline
|
from diffusers import OnnxStableDiffusionImg2ImgPipeline
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ def blend_img2img(
|
||||||
job.get_device(),
|
job.get_device(),
|
||||||
)
|
)
|
||||||
|
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = torch.manual_seed(params.seed)
|
||||||
|
|
||||||
result = pipe(
|
result = pipe(
|
||||||
prompt,
|
prompt,
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Callable, Tuple
|
from typing import Callable, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import torch
|
||||||
from diffusers import OnnxStableDiffusionInpaintPipeline
|
from diffusers import OnnxStableDiffusionInpaintPipeline
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
@ -67,7 +67,7 @@ def blend_inpaint(
|
||||||
)
|
)
|
||||||
|
|
||||||
latents = get_latents_from_seed(params.seed, size)
|
latents = get_latents_from_seed(params.seed, size)
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = torch.manual_seed(params.seed)
|
||||||
|
|
||||||
result = pipe(
|
result = pipe(
|
||||||
params.prompt,
|
params.prompt,
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from diffusers import OnnxStableDiffusionPipeline
|
from diffusers import OnnxStableDiffusionPipeline
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
@ -36,7 +37,7 @@ def source_txt2img(
|
||||||
)
|
)
|
||||||
|
|
||||||
latents = get_latents_from_seed(params.seed, size)
|
latents = get_latents_from_seed(params.seed, size)
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = torch.manual_seed(params.seed)
|
||||||
|
|
||||||
result = pipe(
|
result = pipe(
|
||||||
prompt,
|
prompt,
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Callable, Tuple
|
from typing import Callable, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import torch
|
||||||
from diffusers import OnnxStableDiffusionInpaintPipeline
|
from diffusers import OnnxStableDiffusionInpaintPipeline
|
||||||
from PIL import Image, ImageDraw
|
from PIL import Image, ImageDraw
|
||||||
|
|
||||||
|
@ -73,7 +73,7 @@ def upscale_outpaint(
|
||||||
)
|
)
|
||||||
|
|
||||||
latents = get_tile_latents(full_latents, dims)
|
latents = get_tile_latents(full_latents, dims)
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = torch.manual_seed(params.seed)
|
||||||
|
|
||||||
result = pipe.inpaint(
|
result = pipe.inpaint(
|
||||||
image,
|
image,
|
||||||
|
|
|
@ -19,7 +19,8 @@ latent_factor = 8
|
||||||
|
|
||||||
def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
|
def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
From https://www.travelneil.com/stable-diffusion-updates.html
|
From https://www.travelneil.com/stable-diffusion-updates.html.
|
||||||
|
This one needs to use np.random because of the return type.
|
||||||
"""
|
"""
|
||||||
latents_shape = (
|
latents_shape = (
|
||||||
batch,
|
batch,
|
||||||
|
|
|
@ -2,6 +2,7 @@ from logging import getLogger
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
|
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
|
||||||
from PIL import Image, ImageChops
|
from PIL import Image, ImageChops
|
||||||
|
|
||||||
|
@ -29,7 +30,7 @@ def run_txt2img_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
latents = get_latents_from_seed(params.seed, size)
|
latents = get_latents_from_seed(params.seed, size)
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = torch.manual_seed(params.seed)
|
||||||
|
|
||||||
progress = job.get_progress_callback()
|
progress = job.get_progress_callback()
|
||||||
result = pipe.text2img(
|
result = pipe.text2img(
|
||||||
|
@ -74,7 +75,7 @@ def run_img2img_pipeline(
|
||||||
job.get_device(),
|
job.get_device(),
|
||||||
)
|
)
|
||||||
|
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = torch.manual_seed(params.seed)
|
||||||
|
|
||||||
progress = job.get_progress_callback()
|
progress = job.get_progress_callback()
|
||||||
result = pipe.img2img(
|
result = pipe.img2img(
|
||||||
|
|
|
@ -215,6 +215,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
|
||||||
|
|
||||||
seed = int(request.args.get("seed", -1))
|
seed = int(request.args.get("seed", -1))
|
||||||
if seed == -1:
|
if seed == -1:
|
||||||
|
# this one can safely use np.random because it produces a single value
|
||||||
seed = np.random.randint(np.iinfo(np.int32).max)
|
seed = np.random.randint(np.iinfo(np.int32).max)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
Loading…
Reference in New Issue