diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index adace06f..cc9fd56b 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -23,7 +23,7 @@ class StageCallback(Protocol): def __call__( self, job: WorkerContext, - ctx: ServerContext, + server: ServerContext, stage: StageParams, params: ImageParams, source: Image.Image, diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index 6107476e..6c275789 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -140,7 +140,7 @@ base_models: Models = { def fetch_model( - ctx: ConversionContext, + conversion: ConversionContext, name: str, source: str, dest: Optional[str] = None, @@ -148,7 +148,7 @@ def fetch_model( hf_hub_fetch: bool = False, hf_hub_filename: Optional[str] = None, ) -> str: - cache_path = dest or ctx.cache_path + cache_path = dest or conversion.cache_path cache_name = path.join(cache_path, name) # add an extension if possible, some of the conversion code checks for it @@ -201,7 +201,7 @@ def fetch_model( return source -def convert_models(ctx: ConversionContext, args, models: Models): +def convert_models(conversion: ConversionContext, args, models: Models): if args.sources and "sources" in models: for model in models.get("sources"): model = tuple_to_source(model) @@ -214,7 +214,7 @@ def convert_models(ctx: ConversionContext, args, models: Models): source = model["source"] try: - dest = fetch_model(ctx, name, source, format=model_format) + dest = fetch_model(conversion, name, source, format=model_format) logger.info("finished downloading source: %s -> %s", source, dest) except Exception: logger.exception("error fetching source %s", name) @@ -234,20 +234,20 @@ def convert_models(ctx: ConversionContext, args, models: Models): try: if network_type == "inversion" and network_model == "concept": dest = fetch_model( - ctx, + conversion, name, source, - dest=path.join(ctx.model_path, network_type), + dest=path.join(conversion.model_path, network_type), format=network_format, hf_hub_fetch=True, hf_hub_filename="learned_embeds.bin", ) else: dest = fetch_model( - ctx, + conversion, name, source, - dest=path.join(ctx.model_path, network_type), + dest=path.join(conversion.model_path, network_type), format=network_format, ) @@ -267,19 +267,19 @@ def convert_models(ctx: ConversionContext, args, models: Models): try: source = fetch_model( - ctx, name, model["source"], format=model_format + conversion, name, model["source"], format=model_format ) converted = False if model_format in model_formats_original: converted, dest = convert_diffusion_original( - ctx, + conversion, model, source, ) else: converted, dest = convert_diffusion_diffusers( - ctx, + conversion, model, source, ) @@ -289,8 +289,8 @@ def convert_models(ctx: ConversionContext, args, models: Models): # keep track of which models have been blended blend_models = {} - inversion_dest = path.join(ctx.model_path, "inversion") - lora_dest = path.join(ctx.model_path, "lora") + inversion_dest = path.join(conversion.model_path, "inversion") + lora_dest = path.join(conversion.model_path, "lora") for inversion in model.get("inversions", []): if "text_encoder" not in blend_models: @@ -314,7 +314,7 @@ def convert_models(ctx: ConversionContext, args, models: Models): inversion_source = inversion["source"] inversion_format = inversion.get("format", None) inversion_source = fetch_model( - ctx, + conversion, inversion_name, inversion_source, dest=inversion_dest, @@ -323,7 +323,7 @@ def convert_models(ctx: 
ConversionContext, args, models: Models): inversion_weight = inversion.get("weight", 1.0) blend_textual_inversions( - ctx, + conversion, blend_models["text_encoder"], blend_models["tokenizer"], [ @@ -355,7 +355,7 @@ def convert_models(ctx: ConversionContext, args, models: Models): lora_name = lora["name"] lora_source = lora["source"] lora_source = fetch_model( - ctx, + conversion, f"{name}-lora-{lora_name}", lora_source, dest=lora_dest, @@ -363,14 +363,14 @@ def convert_models(ctx: ConversionContext, args, models: Models): lora_weight = lora.get("weight", 1.0) blend_loras( - ctx, + conversion, blend_models["text_encoder"], [(lora_source, lora_weight)], "text_encoder", ) blend_loras( - ctx, + conversion, blend_models["unet"], [(lora_source, lora_weight)], "unet", @@ -413,9 +413,9 @@ def convert_models(ctx: ConversionContext, args, models: Models): try: source = fetch_model( - ctx, name, model["source"], format=model_format + conversion, name, model["source"], format=model_format ) - convert_upscale_resrgan(ctx, model, source) + convert_upscale_resrgan(conversion, model, source) except Exception: logger.exception( "error converting upscaling model %s", @@ -433,9 +433,9 @@ def convert_models(ctx: ConversionContext, args, models: Models): model_format = source_format(model) try: source = fetch_model( - ctx, name, model["source"], format=model_format + conversion, name, model["source"], format=model_format ) - convert_correction_gfpgan(ctx, model, source) + convert_correction_gfpgan(conversion, model, source) except Exception: logger.exception( "error converting correction model %s", @@ -482,21 +482,21 @@ def main() -> int: args = parser.parse_args() logger.info("CLI arguments: %s", args) - ctx = ConversionContext.from_environ() - ctx.half = args.half or "onnx-fp16" in ctx.optimizations - ctx.opset = args.opset - ctx.token = args.token - logger.info("converting models in %s using %s", ctx.model_path, ctx.training_device) + server = ConversionContext.from_environ() + server.half = args.half or "onnx-fp16" in server.optimizations + server.opset = args.opset + server.token = args.token + logger.info("converting models in %s using %s", server.model_path, server.training_device) - if not path.exists(ctx.model_path): - logger.info("model path does not existing, creating: %s", ctx.model_path) - makedirs(ctx.model_path) + if not path.exists(server.model_path): + logger.info("model path does not existing, creating: %s", server.model_path) + makedirs(server.model_path) logger.info("converting base models") - convert_models(ctx, args, base_models) + convert_models(server, args, base_models) extras = [] - extras.extend(ctx.extra_models) + extras.extend(server.extra_models) extras.extend(args.extras) extras = list(set(extras)) extras.sort() @@ -516,7 +516,7 @@ def main() -> int: try: validate(data, extra_schema) logger.info("converting extra models") - convert_models(ctx, args, data) + convert_models(server, args, data) except ValidationError: logger.exception("invalid data in extras file") except Exception: diff --git a/api/onnx_web/convert/correction_gfpgan.py b/api/onnx_web/convert/correction_gfpgan.py index 957c6f17..a56243a7 100644 --- a/api/onnx_web/convert/correction_gfpgan.py +++ b/api/onnx_web/convert/correction_gfpgan.py @@ -12,7 +12,7 @@ logger = getLogger(__name__) @torch.no_grad() def convert_correction_gfpgan( - ctx: ConversionContext, + conversion: ConversionContext, model: ModelDict, source: str, ): @@ -20,7 +20,7 @@ def convert_correction_gfpgan( source = source or model.get("source") scale 
= model.get("scale") - dest = path.join(ctx.model_path, name + ".onnx") + dest = path.join(conversion.model_path, name + ".onnx") logger.info("converting GFPGAN model: %s -> %s", name, dest) if path.isfile(dest): @@ -37,17 +37,17 @@ def convert_correction_gfpgan( scale=scale, ) - torch_model = torch.load(source, map_location=ctx.map_location) + torch_model = torch.load(source, map_location=conversion.map_location) # TODO: make sure strict=False is safe here if "params_ema" in torch_model: model.load_state_dict(torch_model["params_ema"], strict=False) else: model.load_state_dict(torch_model["params"], strict=False) - model.to(ctx.training_device).train(False) + model.to(conversion.training_device).train(False) model.eval() - rng = torch.rand(1, 3, 64, 64, device=ctx.map_location) + rng = torch.rand(1, 3, 64, 64, device=conversion.map_location) input_names = ["data"] output_names = ["output"] dynamic_axes = { @@ -63,7 +63,7 @@ def convert_correction_gfpgan( input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, - opset_version=ctx.opset, + opset_version=conversion.opset, export_params=True, ) logger.info("GFPGAN exported to ONNX successfully") diff --git a/api/onnx_web/convert/diffusion/diffusers.py b/api/onnx_web/convert/diffusion/diffusers.py index 1d88fc3b..473d5215 100644 --- a/api/onnx_web/convert/diffusion/diffusers.py +++ b/api/onnx_web/convert/diffusion/diffusers.py @@ -90,7 +90,7 @@ def onnx_export( @torch.no_grad() def convert_diffusion_diffusers( - ctx: ConversionContext, + conversion: ConversionContext, model: Dict, source: str, ) -> Tuple[bool, str]: @@ -102,10 +102,10 @@ def convert_diffusion_diffusers( single_vae = model.get("single_vae") replace_vae = model.get("vae") - dtype = ctx.torch_dtype() + dtype = conversion.torch_dtype() logger.debug("using Torch dtype %s for pipeline", dtype) - dest_path = path.join(ctx.model_path, name) + dest_path = path.join(conversion.model_path, name) model_index = path.join(dest_path, "model_index.json") # diffusers go into a directory rather than .onnx file @@ -123,11 +123,11 @@ def convert_diffusion_diffusers( pipeline = StableDiffusionPipeline.from_pretrained( source, torch_dtype=dtype, - use_auth_token=ctx.token, - ).to(ctx.training_device) + use_auth_token=conversion.token, + ).to(conversion.training_device) output_path = Path(dest_path) - optimize_pipeline(ctx, pipeline) + optimize_pipeline(conversion, pipeline) # TEXT ENCODER num_tokens = pipeline.text_encoder.config.max_position_embeddings @@ -143,11 +143,11 @@ def convert_diffusion_diffusers( pipeline.text_encoder, # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files model_args=( - text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32), + text_input.input_ids.to(device=conversion.training_device, dtype=torch.int32), None, # attention mask None, # position ids None, # output attentions - torch.tensor(True).to(device=ctx.training_device, dtype=torch.bool), + torch.tensor(True).to(device=conversion.training_device, dtype=torch.bool), ), output_path=output_path / "text_encoder" / ONNX_MODEL, ordered_input_names=["input_ids"], @@ -155,8 +155,8 @@ def convert_diffusion_diffusers( dynamic_axes={ "input_ids": {0: "batch", 1: "sequence"}, }, - opset=ctx.opset, - half=ctx.half, + opset=conversion.opset, + half=conversion.half, ) del pipeline.text_encoder @@ -165,11 +165,11 @@ def convert_diffusion_diffusers( # UNET if single_vae: unet_inputs = ["sample", "timestep", "encoder_hidden_states", 
"class_labels"] - unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.long) + unet_scale = torch.tensor(4).to(device=conversion.training_device, dtype=torch.long) else: unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"] unet_scale = torch.tensor(False).to( - device=ctx.training_device, dtype=torch.bool + device=conversion.training_device, dtype=torch.bool ) if is_torch_2_0: @@ -182,11 +182,11 @@ def convert_diffusion_diffusers( pipeline.unet, model_args=( torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to( - device=ctx.training_device, dtype=dtype + device=conversion.training_device, dtype=dtype ), - torch.randn(2).to(device=ctx.training_device, dtype=dtype), + torch.randn(2).to(device=conversion.training_device, dtype=dtype), torch.randn(2, num_tokens, text_hidden_size).to( - device=ctx.training_device, dtype=dtype + device=conversion.training_device, dtype=dtype ), unet_scale, ), @@ -199,8 +199,8 @@ def convert_diffusion_diffusers( "timestep": {0: "batch"}, "encoder_hidden_states": {0: "batch", 1: "sequence"}, }, - opset=ctx.opset, - half=ctx.half, + opset=conversion.opset, + half=conversion.half, external_data=True, ) unet_model_path = str(unet_path.absolute().as_posix()) @@ -238,7 +238,7 @@ def convert_diffusion_diffusers( model_args=( torch.randn( 1, vae_latent_channels, unet_sample_size, unet_sample_size - ).to(device=ctx.training_device, dtype=dtype), + ).to(device=conversion.training_device, dtype=dtype), False, ), output_path=output_path / "vae" / ONNX_MODEL, @@ -247,8 +247,8 @@ def convert_diffusion_diffusers( dynamic_axes={ "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, }, - opset=ctx.opset, - half=ctx.half, + opset=conversion.opset, + half=conversion.half, ) else: # VAE ENCODER @@ -263,7 +263,7 @@ def convert_diffusion_diffusers( vae_encoder, model_args=( torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to( - device=ctx.training_device, dtype=dtype + device=conversion.training_device, dtype=dtype ), False, ), @@ -273,7 +273,7 @@ def convert_diffusion_diffusers( dynamic_axes={ "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, }, - opset=ctx.opset, + opset=conversion.opset, half=False, # https://github.com/ssube/onnx-web/issues/290 ) @@ -287,7 +287,7 @@ def convert_diffusion_diffusers( model_args=( torch.randn( 1, vae_latent_channels, unet_sample_size, unet_sample_size - ).to(device=ctx.training_device, dtype=dtype), + ).to(device=conversion.training_device, dtype=dtype), False, ), output_path=output_path / "vae_decoder" / ONNX_MODEL, @@ -296,8 +296,8 @@ def convert_diffusion_diffusers( dynamic_axes={ "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, }, - opset=ctx.opset, - half=ctx.half, + opset=conversion.opset, + half=conversion.half, ) del pipeline.vae diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py index d2b84255..3f6395f4 100644 --- a/api/onnx_web/convert/diffusion/lora.py +++ b/api/onnx_web/convert/diffusion/lora.py @@ -55,7 +55,7 @@ def fix_node_name(key: str): def blend_loras( - _context: ServerContext, + _conversion: ServerContext, base_name: Union[str, ModelProto], loras: List[Tuple[str, float]], model_type: Literal["text_encoder", "unet"], diff --git a/api/onnx_web/convert/diffusion/original.py b/api/onnx_web/convert/diffusion/original.py index 7d4f6300..f4e40780 100644 --- a/api/onnx_web/convert/diffusion/original.py +++ b/api/onnx_web/convert/diffusion/original.py @@ -146,7 +146,7 
@@ class TrainingConfig: def __init__( self, - ctx: ConversionContext, + conversion: ConversionContext, model_name: str = "", scheduler: str = "ddim", v2: bool = False, @@ -155,7 +155,7 @@ class TrainingConfig: **kwargs, ): model_name = sanitize_name(model_name) - model_dir = os.path.join(ctx.cache_path, model_name) + model_dir = os.path.join(conversion.cache_path, model_name) working_dir = os.path.join(model_dir, "working") if not os.path.exists(working_dir): @@ -1298,7 +1298,7 @@ def download_model(db_config: TrainingConfig, token): def get_config_path( - context: ConversionContext, + conversion: ConversionContext, model_version: str = "v1", train_type: str = "default", config_base_name: str = "training", @@ -1309,7 +1309,7 @@ def get_config_path( ) parts = os.path.join( - context.model_path, + conversion.model_path, "configs", f"{model_version}-{config_base_name}-{train_type}.yaml", ) @@ -1317,7 +1317,7 @@ def get_config_path( def get_config_file( - context: ConversionContext, + conversion: ConversionContext, train_unfrozen=False, v2=False, prediction_type="epsilon", @@ -1343,7 +1343,7 @@ def get_config_file( model_train_type = train_types["default"] return get_config_path( - context, + conversion, model_version_name, model_train_type, config_base_name, @@ -1352,7 +1352,7 @@ def get_config_file( def extract_checkpoint( - context: ConversionContext, + conversion: ConversionContext, new_model_name: str, checkpoint_file: str, scheduler_type="ddim", @@ -1396,7 +1396,7 @@ def extract_checkpoint( # Create empty config db_config = TrainingConfig( - context, + conversion, model_name=new_model_name, scheduler=scheduler_type, src=checkpoint_file, @@ -1442,7 +1442,7 @@ def extract_checkpoint( prediction_type = "epsilon" original_config_file = get_config_file( - context, train_unfrozen, v2, prediction_type, config_file=config_file + conversion, train_unfrozen, v2, prediction_type, config_file=config_file ) logger.info( @@ -1533,7 +1533,7 @@ def extract_checkpoint( checkpoint, vae_config ) else: - vae_file = os.path.join(context.model_path, vae_file) + vae_file = os.path.join(conversion.model_path, vae_file) logger.debug("loading custom VAE: %s", vae_file) vae_checkpoint = load_tensor(vae_file, map_location=map_location) converted_vae_checkpoint = convert_ldm_vae_checkpoint( @@ -1658,14 +1658,14 @@ def extract_checkpoint( @torch.no_grad() def convert_diffusion_original( - ctx: ConversionContext, + conversion: ConversionContext, model: ModelDict, source: str, ) -> Tuple[bool, str]: name = model["name"] source = source or model["source"] - dest_path = os.path.join(ctx.model_path, name) + dest_path = os.path.join(conversion.model_path, name) dest_index = os.path.join(dest_path, "model_index.json") logger.info( "converting original Diffusers checkpoint %s: %s -> %s", name, source, dest_path @@ -1676,8 +1676,8 @@ def convert_diffusion_original( return (False, dest_path) torch_name = name + "-torch" - torch_path = os.path.join(ctx.cache_path, torch_name) - working_name = os.path.join(ctx.cache_path, torch_name, "working") + torch_path = os.path.join(conversion.cache_path, torch_name) + working_name = os.path.join(conversion.cache_path, torch_name, "working") model_index = os.path.join(working_name, "model_index.json") if os.path.exists(torch_path) and os.path.exists(model_index): @@ -1689,7 +1689,7 @@ def convert_diffusion_original( torch_path, ) if extract_checkpoint( - ctx, + conversion, torch_name, source, config_file=model.get("config"), @@ -1704,9 +1704,9 @@ def convert_diffusion_original( if "vae" 
in model: del model["vae"] - result = convert_diffusion_diffusers(ctx, model, working_name) + result = convert_diffusion_diffusers(conversion, model, working_name) - if "torch" in ctx.prune: + if "torch" in conversion.prune: logger.info("removing intermediate Torch models: %s", torch_path) shutil.rmtree(torch_path) diff --git a/api/onnx_web/convert/diffusion/textual_inversion.py b/api/onnx_web/convert/diffusion/textual_inversion.py index 336074e4..6a30173e 100644 --- a/api/onnx_web/convert/diffusion/textual_inversion.py +++ b/api/onnx_web/convert/diffusion/textual_inversion.py @@ -16,7 +16,7 @@ logger = getLogger(__name__) @torch.no_grad() def blend_textual_inversions( - context: ServerContext, + server: ServerContext, text_encoder: ModelProto, tokenizer: CLIPTokenizer, inversions: List[Tuple[str, float, Optional[str], Optional[str]]], @@ -161,7 +161,7 @@ def blend_textual_inversions( @torch.no_grad() def convert_diffusion_textual_inversion( - context: ConversionContext, + conversion: ConversionContext, name: str, base_model: str, inversion: str, @@ -169,7 +169,7 @@ def convert_diffusion_textual_inversion( base_token: Optional[str] = None, inversion_weight: Optional[float] = 1.0, ): - dest_path = path.join(context.model_path, f"inversion-{name}") + dest_path = path.join(conversion.model_path, f"inversion-{name}") logger.info( "converting Textual Inversion: %s + %s -> %s", base_model, inversion, dest_path ) @@ -194,7 +194,7 @@ def convert_diffusion_textual_inversion( subfolder="tokenizer", ) text_encoder, tokenizer = blend_textual_inversions( - context, + conversion, text_encoder, tokenizer, [(inversion, inversion_weight, base_token, inversion_format)], diff --git a/api/onnx_web/convert/upscale_resrgan.py b/api/onnx_web/convert/upscale_resrgan.py index ef44c790..666df601 100644 --- a/api/onnx_web/convert/upscale_resrgan.py +++ b/api/onnx_web/convert/upscale_resrgan.py @@ -13,7 +13,7 @@ TAG_X4_V3 = "real-esrgan-x4-v3" @torch.no_grad() def convert_upscale_resrgan( - ctx: ConversionContext, + conversion: ConversionContext, model: ModelDict, source: str, ): @@ -24,7 +24,7 @@ def convert_upscale_resrgan( source = source or model.get("source") scale = model.get("scale") - dest = path.join(ctx.model_path, name + ".onnx") + dest = path.join(conversion.model_path, name + ".onnx") logger.info("converting Real ESRGAN model: %s -> %s", name, dest) if path.isfile(dest): @@ -53,16 +53,16 @@ def convert_upscale_resrgan( scale=scale, ) - torch_model = torch.load(source, map_location=ctx.map_location) + torch_model = torch.load(source, map_location=conversion.map_location) if "params_ema" in torch_model: model.load_state_dict(torch_model["params_ema"]) else: model.load_state_dict(torch_model["params"], strict=False) - model.to(ctx.training_device).train(False) + model.to(conversion.training_device).train(False) model.eval() - rng = torch.rand(1, 3, 64, 64, device=ctx.map_location) + rng = torch.rand(1, 3, 64, 64, device=conversion.map_location) input_names = ["data"] output_names = ["output"] dynamic_axes = { @@ -78,7 +78,7 @@ def convert_upscale_resrgan( input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, - opset_version=ctx.opset, + opset_version=conversion.opset, export_params=True, ) logger.info("real ESRGAN exported to ONNX successfully") diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index 453b8101..013882fc 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -64,7 +64,7 @@ def json_params( def make_output_name( - ctx: ServerContext, + server: 
ServerContext, mode: str, params: ImageParams, size: Size, @@ -92,20 +92,20 @@ def make_output_name( hash_value(sha, param) return [ - f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{ctx.image_format}" + f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}" for i in range(params.batch) ] -def save_image(ctx: ServerContext, output: str, image: Image.Image) -> str: - path = base_join(ctx.output_path, output) - image.save(path, format=ctx.image_format) +def save_image(server: ServerContext, output: str, image: Image.Image) -> str: + path = base_join(server.output_path, output) + image.save(path, format=server.image_format) logger.debug("saved output image to: %s", path) return path def save_params( - ctx: ServerContext, + server: ServerContext, output: str, params: ImageParams, size: Size, @@ -113,7 +113,7 @@ def save_params( border: Optional[Border] = None, highres: Optional[HighresParams] = None, ) -> str: - path = base_join(ctx.output_path, f"{output}.json") + path = base_join(server.output_path, f"{output}.json") json = json_params( output, params, size, upscale=upscale, border=border, highres=highres ) diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 3fa92467..c5d83264 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -7,6 +7,10 @@ from .torch_before_ort import GraphOptimizationLevel, SessionOptions logger = getLogger(__name__) +Param = Union[str, int, float] +Point = Tuple[int, int] + + class SizeChart(IntEnum): mini = 128 # small tile for very expensive models half = 256 # half tile for outpainting @@ -25,10 +29,6 @@ class TileOrder: spiral = "spiral" -Param = Union[str, int, float] -Point = Tuple[int, int] - - class Border: def __init__(self, left: int, right: int, top: int, bottom: int) -> None: self.left = left @@ -37,7 +37,7 @@ class Border: self.bottom = bottom def __str__(self) -> str: - return "%s %s %s %s" % (self.left, self.top, self.right, self.bottom) + return "(%s, %s, %s, %s)" % (self.left, self.top, self.right, self.bottom) def tojson(self): return { @@ -145,6 +145,7 @@ class DeviceParams: return sess def torch_str(self) -> str: + # TODO: return cuda devices for ROCm as well if self.device.startswith("cuda"): return self.device else: diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 41233c11..b587fff8 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -93,7 +93,7 @@ def url_from_rule(rule) -> str: return url_for(rule.endpoint, **options) -def introspect(context: ServerContext, app: Flask): +def introspect(server: ServerContext, app: Flask): return { "name": "onnx-web", "routes": [ @@ -103,15 +103,15 @@ def introspect(context: ServerContext, app: Flask): } -def list_extra_strings(context: ServerContext): +def list_extra_strings(server: ServerContext): return jsonify(get_extra_strings()) -def list_mask_filters(context: ServerContext): +def list_mask_filters(server: ServerContext): return jsonify(list(get_mask_filters().keys())) -def list_models(context: ServerContext): +def list_models(server: ServerContext): return jsonify( { "correction": get_correction_models(), @@ -122,30 +122,30 @@ def list_models(context: ServerContext): ) -def list_noise_sources(context: ServerContext): +def list_noise_sources(server: ServerContext): return jsonify(list(get_noise_sources().keys())) -def list_params(context: ServerContext): +def list_params(server: ServerContext): return jsonify(get_config_params()) -def list_platforms(context: ServerContext): +def list_platforms(server: 
ServerContext): return jsonify([p.device for p in get_available_platforms()]) -def list_schedulers(context: ServerContext): +def list_schedulers(server: ServerContext): return jsonify(list(get_pipeline_schedulers().keys())) -def img2img(context: ServerContext, pool: DevicePoolExecutor): +def img2img(server: ServerContext, pool: DevicePoolExecutor): source_file = request.files.get("source") if source_file is None: return error_reply("source image is required") source = Image.open(BytesIO(source_file.read())).convert("RGB") - device, params, size = pipeline_from_request(context) + device, params, size = pipeline_from_request(server) upscale = upscale_from_request() strength = get_and_clamp_float( @@ -156,7 +156,7 @@ def img2img(context: ServerContext, pool: DevicePoolExecutor): get_config_value("strength", "min"), ) - output = make_output_name(context, "img2img", params, size, extras=[strength]) + output = make_output_name(server, "img2img", params, size, extras=[strength]) job_name = output[0] logger.info("img2img job queued for: %s", job_name) @@ -164,7 +164,7 @@ def img2img(context: ServerContext, pool: DevicePoolExecutor): pool.submit( job_name, run_img2img_pipeline, - context, + server, params, output, upscale, @@ -176,19 +176,19 @@ def img2img(context: ServerContext, pool: DevicePoolExecutor): return jsonify(json_params(output, params, size, upscale=upscale)) -def txt2img(context: ServerContext, pool: DevicePoolExecutor): - device, params, size = pipeline_from_request(context) +def txt2img(server: ServerContext, pool: DevicePoolExecutor): + device, params, size = pipeline_from_request(server) upscale = upscale_from_request() highres = highres_from_request() - output = make_output_name(context, "txt2img", params, size) + output = make_output_name(server, "txt2img", params, size) job_name = output[0] logger.info("txt2img job queued for: %s", job_name) pool.submit( job_name, run_txt2img_pipeline, - context, + server, params, size, output, @@ -200,7 +200,7 @@ def txt2img(context: ServerContext, pool: DevicePoolExecutor): return jsonify(json_params(output, params, size, upscale=upscale, highres=highres)) -def inpaint(context: ServerContext, pool: DevicePoolExecutor): +def inpaint(server: ServerContext, pool: DevicePoolExecutor): source_file = request.files.get("source") if source_file is None: return error_reply("source image is required") @@ -212,7 +212,7 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor): source = Image.open(BytesIO(source_file.read())).convert("RGB") mask = Image.open(BytesIO(mask_file.read())).convert("RGB") - device, params, size = pipeline_from_request(context) + device, params, size = pipeline_from_request(server) expand = border_from_request() upscale = upscale_from_request() @@ -224,7 +224,7 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor): ) output = make_output_name( - context, + server, "inpaint", params, size, @@ -247,7 +247,7 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor): pool.submit( job_name, run_inpaint_pipeline, - context, + server, params, size, output, @@ -265,17 +265,17 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor): return jsonify(json_params(output, params, size, upscale=upscale, border=expand)) -def upscale(context: ServerContext, pool: DevicePoolExecutor): +def upscale(server: ServerContext, pool: DevicePoolExecutor): source_file = request.files.get("source") if source_file is None: return error_reply("source image is required") source = 
Image.open(BytesIO(source_file.read())).convert("RGB") - device, params, size = pipeline_from_request(context) + device, params, size = pipeline_from_request(server) upscale = upscale_from_request() - output = make_output_name(context, "upscale", params, size) + output = make_output_name(server, "upscale", params, size) job_name = output[0] logger.info("upscale job queued for: %s", job_name) @@ -283,7 +283,7 @@ def upscale(context: ServerContext, pool: DevicePoolExecutor): pool.submit( job_name, run_upscale_pipeline, - context, + server, params, size, output, @@ -295,7 +295,7 @@ def upscale(context: ServerContext, pool: DevicePoolExecutor): return jsonify(json_params(output, params, size, upscale=upscale)) -def chain(context: ServerContext, pool: DevicePoolExecutor): +def chain(server: ServerContext, pool: DevicePoolExecutor): logger.debug( "chain pipeline request: %s, %s", request.form.keys(), request.files.keys() ) @@ -311,8 +311,8 @@ def chain(context: ServerContext, pool: DevicePoolExecutor): validate(data, schema) # get defaults from the regular parameters - device, params, size = pipeline_from_request(context) - output = make_output_name(context, "chain", params, size) + device, params, size = pipeline_from_request(server) + output = make_output_name(server, "chain", params, size) job_name = output[0] pipeline = ChainPipeline() @@ -371,7 +371,7 @@ def chain(context: ServerContext, pool: DevicePoolExecutor): pool.submit( job_name, pipeline, - context, + server, params, empty_source, output=output[0], @@ -382,7 +382,7 @@ def chain(context: ServerContext, pool: DevicePoolExecutor): return jsonify(json_params(output, params, size)) -def blend(context: ServerContext, pool: DevicePoolExecutor): +def blend(server: ServerContext, pool: DevicePoolExecutor): mask_file = request.files.get("mask") if mask_file is None: return error_reply("mask image is required") @@ -402,17 +402,17 @@ def blend(context: ServerContext, pool: DevicePoolExecutor): source = valid_image(source, mask.size, mask.size) sources.append(source) - device, params, size = pipeline_from_request(context) + device, params, size = pipeline_from_request(server) upscale = upscale_from_request() - output = make_output_name(context, "upscale", params, size) + output = make_output_name(server, "upscale", params, size) job_name = output[0] logger.info("upscale job queued for: %s", job_name) pool.submit( job_name, run_blend_pipeline, - context, + server, params, size, output, @@ -425,17 +425,17 @@ def blend(context: ServerContext, pool: DevicePoolExecutor): return jsonify(json_params(output, params, size, upscale=upscale)) -def txt2txt(context: ServerContext, pool: DevicePoolExecutor): - device, params, size = pipeline_from_request(context) +def txt2txt(server: ServerContext, pool: DevicePoolExecutor): + device, params, size = pipeline_from_request(server) - output = make_output_name(context, "txt2txt", params, size) + output = make_output_name(server, "txt2txt", params, size) job_name = output[0] logger.info("upscale job queued for: %s", job_name) pool.submit( job_name, run_txt2txt_pipeline, - context, + server, params, size, output, @@ -445,7 +445,7 @@ def txt2txt(context: ServerContext, pool: DevicePoolExecutor): return jsonify(json_params(output, params, size)) -def cancel(context: ServerContext, pool: DevicePoolExecutor): +def cancel(server: ServerContext, pool: DevicePoolExecutor): output_file = request.args.get("output", None) if output_file is None: return error_reply("output name is required") @@ -456,7 +456,7 @@ def 
cancel(context: ServerContext, pool: DevicePoolExecutor): return ready_reply(cancelled=cancelled) -def ready(context: ServerContext, pool: DevicePoolExecutor): +def ready(server: ServerContext, pool: DevicePoolExecutor): output_file = request.args.get("output", None) if output_file is None: return error_reply("output name is required") @@ -468,7 +468,7 @@ def ready(context: ServerContext, pool: DevicePoolExecutor): return ready_reply(pending=True) if progress is None: - output = base_join(context.output_path, output_file) + output = base_join(server.output_path, output_file) if path.exists(output): return ready_reply(ready=True) else: @@ -485,44 +485,44 @@ def ready(context: ServerContext, pool: DevicePoolExecutor): ) -def status(context: ServerContext, pool: DevicePoolExecutor): +def status(server: ServerContext, pool: DevicePoolExecutor): return jsonify(pool.status()) -def register_api_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor): +def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor): return [ - app.route("/api")(wrap_route(introspect, context, app=app)), - app.route("/api/settings/masks")(wrap_route(list_mask_filters, context)), - app.route("/api/settings/models")(wrap_route(list_models, context)), - app.route("/api/settings/noises")(wrap_route(list_noise_sources, context)), - app.route("/api/settings/params")(wrap_route(list_params, context)), - app.route("/api/settings/platforms")(wrap_route(list_platforms, context)), - app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)), - app.route("/api/settings/strings")(wrap_route(list_extra_strings, context)), + app.route("/api")(wrap_route(introspect, server, app=app)), + app.route("/api/settings/masks")(wrap_route(list_mask_filters, server)), + app.route("/api/settings/models")(wrap_route(list_models, server)), + app.route("/api/settings/noises")(wrap_route(list_noise_sources, server)), + app.route("/api/settings/params")(wrap_route(list_params, server)), + app.route("/api/settings/platforms")(wrap_route(list_platforms, server)), + app.route("/api/settings/schedulers")(wrap_route(list_schedulers, server)), + app.route("/api/settings/strings")(wrap_route(list_extra_strings, server)), app.route("/api/img2img", methods=["POST"])( - wrap_route(img2img, context, pool=pool) + wrap_route(img2img, server, pool=pool) ), app.route("/api/txt2img", methods=["POST"])( - wrap_route(txt2img, context, pool=pool) + wrap_route(txt2img, server, pool=pool) ), app.route("/api/txt2txt", methods=["POST"])( - wrap_route(txt2txt, context, pool=pool) + wrap_route(txt2txt, server, pool=pool) ), app.route("/api/inpaint", methods=["POST"])( - wrap_route(inpaint, context, pool=pool) + wrap_route(inpaint, server, pool=pool) ), app.route("/api/upscale", methods=["POST"])( - wrap_route(upscale, context, pool=pool) + wrap_route(upscale, server, pool=pool) ), app.route("/api/chain", methods=["POST"])( - wrap_route(chain, context, pool=pool) + wrap_route(chain, server, pool=pool) ), app.route("/api/blend", methods=["POST"])( - wrap_route(blend, context, pool=pool) + wrap_route(blend, server, pool=pool) ), app.route("/api/cancel", methods=["PUT"])( - wrap_route(cancel, context, pool=pool) + wrap_route(cancel, server, pool=pool) ), - app.route("/api/ready")(wrap_route(ready, context, pool=pool)), - app.route("/api/status")(wrap_route(status, context, pool=pool)), + app.route("/api/ready")(wrap_route(ready, server, pool=pool)), + app.route("/api/status")(wrap_route(status, server, pool=pool)), ] diff 
--git a/api/onnx_web/server/hacks.py b/api/onnx_web/server/hacks.py index 7847ccf4..f0727815 100644 --- a/api/onnx_web/server/hacks.py +++ b/api/onnx_web/server/hacks.py @@ -118,13 +118,13 @@ def patch_not_impl(): raise NotImplementedError() -def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str: +def patch_cache_path(server: ServerContext, url: str, **kwargs) -> str: cache_path = cache_path_map.get(url, None) if cache_path is None: parsed = urlparse(url) cache_path = path.basename(parsed.path) - cache_path = path.join(ctx.cache_path, cache_path) + cache_path = path.join(server.cache_path, cache_path) logger.debug("patching download path: %s -> %s", url, cache_path) if path.exists(cache_path): @@ -133,27 +133,27 @@ def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str: raise FileNotFoundError("missing cache file: %s" % (cache_path)) -def apply_patch_basicsr(ctx: ServerContext): +def apply_patch_basicsr(server: ServerContext): logger.debug("patching BasicSR module") basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl - basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, ctx) + basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, server) -def apply_patch_codeformer(ctx: ServerContext): +def apply_patch_codeformer(server: ServerContext): logger.debug("patching CodeFormer module") codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl - codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, ctx) + codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, server) -def apply_patch_facexlib(ctx: ServerContext): +def apply_patch_facexlib(server: ServerContext): logger.debug("patching Facexlib module") - facexlib.utils.load_file_from_url = partial(patch_cache_path, ctx) + facexlib.utils.load_file_from_url = partial(patch_cache_path, server) -def apply_patches(ctx: ServerContext): - apply_patch_basicsr(ctx) - apply_patch_codeformer(ctx) - apply_patch_facexlib(ctx) +def apply_patches(server: ServerContext): + apply_patch_basicsr(server) + apply_patch_codeformer(server) + apply_patch_facexlib(server) unload( [ "basicsr.utils.download_util", diff --git a/api/onnx_web/server/load.py b/api/onnx_web/server/load.py index 6224a864..442bccdf 100644 --- a/api/onnx_web/server/load.py +++ b/api/onnx_web/server/load.py @@ -115,7 +115,7 @@ def get_config_value(key: str, subkey: str = "default", default=None): return config_params.get(key, {}).get(subkey, default) -def load_extras(context: ServerContext): +def load_extras(server: ServerContext): """ Load the extras file(s) and collect the relevant parts for the server: labels and strings """ @@ -127,7 +127,7 @@ def load_extras(context: ServerContext): with open("./schemas/extras.yaml", "r") as f: extra_schema = safe_load(f.read()) - for file in context.extra_models: + for file in server.extra_models: if file is not None and file != "": logger.info("loading extra models from %s", file) try: @@ -209,11 +209,11 @@ IGNORE_EXTENSIONS = [".crdownload", ".lock", ".tmp"] def list_model_globs( - context: ServerContext, globs: List[str], base_path: Optional[str] = None + server: ServerContext, globs: List[str], base_path: Optional[str] = None ) -> List[str]: models = [] for pattern in globs: - pattern_path = path.join(base_path or context.model_path, pattern) + pattern_path = path.join(base_path or server.model_path, pattern) logger.debug("loading models from %s", pattern_path) for name in glob(pattern_path): base = 
path.basename(name) @@ -226,7 +226,7 @@ def list_model_globs( return unique_models -def load_models(context: ServerContext) -> None: +def load_models(server: ServerContext) -> None: global correction_models global diffusion_models global network_models @@ -234,7 +234,7 @@ def load_models(context: ServerContext) -> None: # main categories diffusion_models = list_model_globs( - context, + server, [ "diffusion-*", "stable-diffusion-*", @@ -243,7 +243,7 @@ def load_models(context: ServerContext) -> None: logger.debug("loaded diffusion models from disk: %s", diffusion_models) correction_models = list_model_globs( - context, + server, [ "correction-*", ], @@ -251,7 +251,7 @@ def load_models(context: ServerContext) -> None: logger.debug("loaded correction models from disk: %s", correction_models) upscaling_models = list_model_globs( - context, + server, [ "upscaling-*", ], @@ -260,11 +260,11 @@ def load_models(context: ServerContext) -> None: # additional networks inversion_models = list_model_globs( - context, + server, [ "*", ], - base_path=path.join(context.model_path, "inversion"), + base_path=path.join(server.model_path, "inversion"), ) logger.debug("loaded Textual Inversion models from disk: %s", inversion_models) network_models.extend( @@ -272,35 +272,35 @@ def load_models(context: ServerContext) -> None: ) lora_models = list_model_globs( - context, + server, [ "*", ], - base_path=path.join(context.model_path, "lora"), + base_path=path.join(server.model_path, "lora"), ) logger.debug("loaded LoRA models from disk: %s", lora_models) network_models.extend([NetworkModel(model, "lora") for model in lora_models]) -def load_params(context: ServerContext) -> None: +def load_params(server: ServerContext) -> None: global config_params - params_file = path.join(context.params_path, "params.json") + params_file = path.join(server.params_path, "params.json") logger.debug("loading server parameters from file: %s", params_file) with open(params_file, "r") as f: config_params = yaml.safe_load(f) - if "platform" in config_params and context.default_platform is not None: + if "platform" in config_params and server.default_platform is not None: logger.info( "overriding default platform from environment: %s", - context.default_platform, + server.default_platform, ) config_platform = config_params.get("platform", {}) - config_platform["default"] = context.default_platform + config_platform["default"] = server.default_platform -def load_platforms(context: ServerContext) -> None: +def load_platforms(server: ServerContext) -> None: global available_platforms providers = list(get_available_providers()) @@ -309,7 +309,7 @@ def load_platforms(context: ServerContext) -> None: for potential in platform_providers: if ( platform_providers[potential] in providers - and potential not in context.block_platforms + and potential not in server.block_platforms ): if potential == "cuda": for i in range(torch.cuda.device_count()): @@ -317,16 +317,16 @@ def load_platforms(context: ServerContext) -> None: "device_id": i, } - if context.memory_limit is not None: + if server.memory_limit is not None: options["arena_extend_strategy"] = "kSameAsRequested" - options["gpu_mem_limit"] = context.memory_limit + options["gpu_mem_limit"] = server.memory_limit available_platforms.append( DeviceParams( potential, platform_providers[potential], options, - context.optimizations, + server.optimizations, ) ) else: @@ -335,18 +335,18 @@ def load_platforms(context: ServerContext) -> None: potential, platform_providers[potential], None, - 
context.optimizations, + server.optimizations, ) ) - if context.any_platform: + if server.any_platform: # the platform should be ignored when the job is scheduled, but set to CPU just in case available_platforms.append( DeviceParams( "any", platform_providers["cpu"], None, - context.optimizations, + server.optimizations, ) ) diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py index d87d611c..d21de826 100644 --- a/api/onnx_web/server/params.py +++ b/api/onnx_web/server/params.py @@ -28,7 +28,7 @@ logger = getLogger(__name__) def pipeline_from_request( - context: ServerContext, + server: ServerContext, ) -> Tuple[DeviceParams, ImageParams, Size]: user = request.remote_addr @@ -44,7 +44,7 @@ def pipeline_from_request( # pipeline stuff lpw = get_not_empty(request.args, "lpw", "false") == "true" model = get_not_empty(request.args, "model", get_config_value("model")) - model_path = get_model_path(context, model) + model_path = get_model_path(server, model) scheduler = get_from_list( request.args, "scheduler", list(pipeline_schedulers.keys()) ) diff --git a/api/onnx_web/server/static.py b/api/onnx_web/server/static.py index 72d46dc9..bda9bcab 100644 --- a/api/onnx_web/server/static.py +++ b/api/onnx_web/server/static.py @@ -7,30 +7,30 @@ from .context import ServerContext from .utils import wrap_route -def serve_bundle_file(context: ServerContext, filename="index.html"): - return send_from_directory(path.join("..", context.bundle_path), filename) +def serve_bundle_file(server: ServerContext, filename="index.html"): + return send_from_directory(path.join("..", server.bundle_path), filename) # non-API routes -def index(context: ServerContext): - return serve_bundle_file(context) +def index(server: ServerContext): + return serve_bundle_file(server) -def index_path(context: ServerContext, filename: str): - return serve_bundle_file(context, filename) +def index_path(server: ServerContext, filename: str): + return serve_bundle_file(server, filename) -def output(context: ServerContext, filename: str): +def output(server: ServerContext, filename: str): return send_from_directory( - path.join("..", context.output_path), filename, as_attachment=False + path.join("..", server.output_path), filename, as_attachment=False ) def register_static_routes( - app: Flask, context: ServerContext, _pool: DevicePoolExecutor + app: Flask, server: ServerContext, _pool: DevicePoolExecutor ): return [ - app.route("/")(wrap_route(index, context)), - app.route("/")(wrap_route(index_path, context)), - app.route("/output/")(wrap_route(output, context)), + app.route("/")(wrap_route(index, server)), + app.route("/")(wrap_route(index_path, server)), + app.route("/output/")(wrap_route(output, server)), ] diff --git a/api/onnx_web/server/utils.py b/api/onnx_web/server/utils.py index 4280b43f..be6dbe78 100644 --- a/api/onnx_web/server/utils.py +++ b/api/onnx_web/server/utils.py @@ -9,26 +9,26 @@ from ..worker.pool import DevicePoolExecutor from .context import ServerContext -def check_paths(context: ServerContext) -> None: - if not path.exists(context.model_path): +def check_paths(server: ServerContext) -> None: + if not path.exists(server.model_path): raise RuntimeError("model path must exist") - if not path.exists(context.output_path): - makedirs(context.output_path) + if not path.exists(server.output_path): + makedirs(server.output_path) -def get_model_path(context: ServerContext, model: str): - return base_join(context.model_path, model) +def get_model_path(server: ServerContext, model: str): + return 
base_join(server.model_path, model) def register_routes( app: Flask, - context: ServerContext, + server: ServerContext, pool: DevicePoolExecutor, routes: List[Tuple[str, Dict, Callable]], ): for route, kwargs, method in routes: - app.route(route, **kwargs)(wrap_route(method, context, pool=pool)) + app.route(route, **kwargs)(wrap_route(method, server, pool=pool)) def wrap_route(func, *args, **kwargs): diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index fe329784..4b174891 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -26,51 +26,51 @@ MEMORY_ERRORS = [ ] -def worker_main(context: WorkerContext, server: ServerContext): +def worker_main(worker: WorkerContext, server: ServerContext): apply_patches(server) - setproctitle("onnx-web worker: %s" % (context.device.device)) + setproctitle("onnx-web worker: %s" % (worker.device.device)) logger.trace( "checking in from worker with providers: %s", get_available_providers() ) # make leaking workers easier to recycle - context.progress.cancel_join_thread() + worker.progress.cancel_join_thread() while True: try: - if not context.is_active(): + if not worker.is_active(): logger.warning( "worker %s has been replaced by %s, exiting", getpid(), - context.get_active(), + worker.get_active(), ) exit(EXIT_REPLACED) # wait briefly for the next job - job = context.pending.get(timeout=1.0) - logger.info("worker %s got job: %s", context.device.device, job.name) + job = worker.pending.get(timeout=1.0) + logger.info("worker %s got job: %s", worker.device.device, job.name) # clear flags and save the job name - context.start(job.name) + worker.start(job.name) logger.info("starting job: %s", job.name) # reset progress, which does a final check for cancellation - context.set_progress(0) - job.fn(context, *job.args, **job.kwargs) + worker.set_progress(0) + job.fn(worker, *job.args, **job.kwargs) # confirm completion of the job logger.info("job succeeded: %s", job.name) - context.finish() + worker.finish() except Empty: pass except KeyboardInterrupt: logger.info("worker got keyboard interrupt") - context.fail() + worker.fail() exit(EXIT_INTERRUPT) except ValueError: logger.exception("value error in worker, exiting: %s") - context.fail() + worker.fail() exit(EXIT_ERROR) except Exception as e: e_str = str(e) @@ -78,11 +78,11 @@ def worker_main(context: WorkerContext, server: ServerContext): for e_mem in MEMORY_ERRORS: if e_mem in e_str: logger.error("detected out-of-memory error, exiting: %s", e) - context.fail() + worker.fail() exit(EXIT_MEMORY) # carry on for other errors logger.exception( "unrecognized error while running job", ) - context.fail() + worker.fail() diff --git a/api/tests/test_output.py b/api/tests/test_output.py new file mode 100644 index 00000000..fb3262a9 --- /dev/null +++ b/api/tests/test_output.py @@ -0,0 +1,23 @@ +import unittest + + +class TestHashValue(unittest.TestCase): + def test_hash_value(self): + pass + + +class TestJSONParams(unittest.TestCase): + def test_json_params(self): + pass + + +class TestMakeOutputName(unittest.TestCase): + pass + + +class TestSaveImage(unittest.TestCase): + pass + + +class TestSaveParams(unittest.TestCase): + pass diff --git a/api/tests/test_params.py b/api/tests/test_params.py new file mode 100644 index 00000000..0f84cfab --- /dev/null +++ b/api/tests/test_params.py @@ -0,0 +1,101 @@ +import unittest + +from onnx_web.params import Border, Size + +class BorderTests(unittest.TestCase): + def test_json(self): + border = Border.even(0) + json = border.tojson() 
+ + self.assertIn("left", json) + self.assertIn("right", json) + self.assertIn("top", json) + self.assertIn("bottom", json) + + def test_str(self): + border = Border.even(10) + bstr = str(border) + + self.assertEqual("(10, 10, 10, 10)", bstr) + + def test_uneven(self): + border = Border(1, 2, 3, 4) + + self.assertEqual("(1, 2, 3, 4)", str(border)) + + def test_args(self): + pass + + +class SizeTests(unittest.TestCase): + def test_iter(self): + size = Size(1, 2) + + self.assertEqual(list(size), [1, 2]) + + def test_str(self): + pass + + def test_border(self): + pass + + def test_json(self): + pass + + def test_args(self): + pass + + +class DeviceParamsTests(unittest.TestCase): + def test_str(self): + pass + + def test_provider(self): + pass + + def test_options_optimizations(self): + pass + + def test_options_cache(self): + pass + + def test_torch_cuda(self): + pass + + def test_torch_rocm(self): + pass + + +class ImageParamsTests(unittest.TestCase): + def test_json(self): + pass + + def test_args(self): + pass + + +class StageParamsTests(unittest.TestCase): + def test_init(self): + pass + + +class UpscaleParamsTests(unittest.TestCase): + def test_rescale(self): + pass + + def test_resize(self): + pass + + def test_json(self): + pass + + def test_args(self): + pass + + +class HighresParamsTests(unittest.TestCase): + def test_resize(self): + pass + + def test_json(self): + pass