1
0
Fork 0

remove references to /tmp

This commit is contained in:
Sean Sube 2024-01-28 20:50:22 -06:00
parent 7cfe619b4a
commit 8f7ef6dfce
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 5 additions and 3 deletions

View File

@ -362,7 +362,7 @@ def export_unet(pipeline, output_path, unet_sample_size=1024):
) )
def load_and_export(source="stabilityai/sd-x2-latent-upscaler"): def load_and_export(output, source="stabilityai/sd-x2-latent-upscaler"):
from pathlib import Path from pathlib import Path
from diffusers import StableDiffusionLatentUpscalePipeline from diffusers import StableDiffusionLatentUpscalePipeline
@ -370,11 +370,12 @@ def load_and_export(source="stabilityai/sd-x2-latent-upscaler"):
upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained( upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
source, torch_dtype=torch.float32 source, torch_dtype=torch.float32
) )
export_unet(upscaler, Path("/tmp/latent-upscaler")) export_unet(upscaler, Path(output))
def load_and_run( def load_and_run(
prompt, prompt,
output,
source="stabilityai/sd-x2-latent-upscaler", source="stabilityai/sd-x2-latent-upscaler",
checkpoint="../models/stable-diffusion-onnx-v1-5", checkpoint="../models/stable-diffusion-onnx-v1-5",
): ):
@ -396,7 +397,7 @@ def load_and_run(
# run # run
result = highres.text2img(prompt, num_inference_steps=25, num_upscale_steps=25) result = highres.text2img(prompt, num_inference_steps=25, num_upscale_steps=25)
image = result.images[0] image = result.images[0]
image.save("/tmp/highres.png") image.save(output)
class RetorchModel: class RetorchModel:

View File

@ -95,6 +95,7 @@ def save_result(
thumbnails = [] thumbnails = []
for image, filename in zip(images, result.thumbnails): for image, filename in zip(images, result.thumbnails):
# TODO: only make a thumbnail if the image is larger than the thumbnail size
thumbnail = image.copy() thumbnail = image.copy()
thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size)) thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))