From 3f4b3fa32285f91ce56a769c4aaf3900c97e57a5 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 21 Feb 2023 22:12:12 -0600 Subject: [PATCH] load CLIP on training device --- api/onnx_web/convert/__main__.py | 2 +- api/onnx_web/convert/diffusion/textual_inversion.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index b1869fd2..09ce6e3d 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -234,7 +234,7 @@ def convert_models(ctx: ConversionContext, args, models: Models): inversion_name = inversion["name"] inversion_source = inversion["source"] inversion_source = fetch_model(ctx, f"{name}-inversion-{inversion_name}", inversion_source) - convert_diffusion_textual_inversion(ctx, inversion_name, source, inversion_source) + convert_diffusion_textual_inversion(ctx, inversion_name, model["source"], inversion_source) except Exception as e: logger.error("error converting diffusion model %s: %s", name, e) diff --git a/api/onnx_web/convert/diffusion/textual_inversion.py b/api/onnx_web/convert/diffusion/textual_inversion.py index d125f4b2..14ffa703 100644 --- a/api/onnx_web/convert/diffusion/textual_inversion.py +++ b/api/onnx_web/convert/diffusion/textual_inversion.py @@ -29,11 +29,11 @@ def convert_diffusion_textual_inversion(context: ConversionContext, name: str, b tokenizer = CLIPTokenizer.from_pretrained( base_model, subfolder="tokenizer", - ) + ).to(context.training_device) text_encoder = CLIPTextModel.from_pretrained( base_model, subfolder="text_encoder", - ) + ).to(context.training_device) loaded_embeds = torch.load(embeds_file, map_location=context.map_location)