remove references to /tmp
This commit is contained in:
parent
7cfe619b4a
commit
8f7ef6dfce
|
@ -362,7 +362,7 @@ def export_unet(pipeline, output_path, unet_sample_size=1024):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_and_export(source="stabilityai/sd-x2-latent-upscaler"):
|
def load_and_export(output, source="stabilityai/sd-x2-latent-upscaler"):
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from diffusers import StableDiffusionLatentUpscalePipeline
|
from diffusers import StableDiffusionLatentUpscalePipeline
|
||||||
|
@ -370,11 +370,12 @@ def load_and_export(source="stabilityai/sd-x2-latent-upscaler"):
|
||||||
upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
|
upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
|
||||||
source, torch_dtype=torch.float32
|
source, torch_dtype=torch.float32
|
||||||
)
|
)
|
||||||
export_unet(upscaler, Path("/tmp/latent-upscaler"))
|
export_unet(upscaler, Path(output))
|
||||||
|
|
||||||
|
|
||||||
def load_and_run(
|
def load_and_run(
|
||||||
prompt,
|
prompt,
|
||||||
|
output,
|
||||||
source="stabilityai/sd-x2-latent-upscaler",
|
source="stabilityai/sd-x2-latent-upscaler",
|
||||||
checkpoint="../models/stable-diffusion-onnx-v1-5",
|
checkpoint="../models/stable-diffusion-onnx-v1-5",
|
||||||
):
|
):
|
||||||
|
@ -396,7 +397,7 @@ def load_and_run(
|
||||||
# run
|
# run
|
||||||
result = highres.text2img(prompt, num_inference_steps=25, num_upscale_steps=25)
|
result = highres.text2img(prompt, num_inference_steps=25, num_upscale_steps=25)
|
||||||
image = result.images[0]
|
image = result.images[0]
|
||||||
image.save("/tmp/highres.png")
|
image.save(output)
|
||||||
|
|
||||||
|
|
||||||
class RetorchModel:
|
class RetorchModel:
|
||||||
|
|
|
@ -95,6 +95,7 @@ def save_result(
|
||||||
|
|
||||||
thumbnails = []
|
thumbnails = []
|
||||||
for image, filename in zip(images, result.thumbnails):
|
for image, filename in zip(images, result.thumbnails):
|
||||||
|
# TODO: only make a thumbnail if the image is larger than the thumbnail size
|
||||||
thumbnail = image.copy()
|
thumbnail = image.copy()
|
||||||
thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))
|
thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue