
lint(api): name context params consistently (#278)

Sean Sube 2023-04-09 20:33:03 -05:00
parent fea9185707
commit 9698e29268
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
19 changed files with 363 additions and 238 deletions
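The rename gives each context parameter a role-specific name instead of the generic ctx/context: ServerContext parameters become server, ConversionContext parameters become conversion, and WorkerContext parameters become worker. A minimal sketch of the convention (illustrative only; the dataclass below is a stand-in, not the real ConversionContext class):

from dataclasses import dataclass
from os import path


@dataclass
class ConversionContext:  # stand-in for the real context class in onnx-web
    model_path: str
    cache_path: str


# before: the parameter name ctx did not say which context the function needed
def fetch_model_before(ctx: ConversionContext, name: str) -> str:
    return path.join(ctx.cache_path, name)


# after: the parameter name states the role, so call sites read unambiguously
def fetch_model_after(conversion: ConversionContext, name: str) -> str:
    return path.join(conversion.cache_path, name)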

View File

@@ -23,7 +23,7 @@ class StageCallback(Protocol):
     def __call__(
         self,
         job: WorkerContext,
-        ctx: ServerContext,
+        server: ServerContext,
         stage: StageParams,
         params: ImageParams,
         source: Image.Image,

View File

@@ -140,7 +140,7 @@ base_models: Models = {
 def fetch_model(
-    ctx: ConversionContext,
+    conversion: ConversionContext,
     name: str,
     source: str,
     dest: Optional[str] = None,
@@ -148,7 +148,7 @@ def fetch_model(
     hf_hub_fetch: bool = False,
     hf_hub_filename: Optional[str] = None,
 ) -> str:
-    cache_path = dest or ctx.cache_path
+    cache_path = dest or conversion.cache_path
     cache_name = path.join(cache_path, name)
     # add an extension if possible, some of the conversion code checks for it
@@ -201,7 +201,7 @@ def fetch_model(
     return source
-def convert_models(ctx: ConversionContext, args, models: Models):
+def convert_models(conversion: ConversionContext, args, models: Models):
     if args.sources and "sources" in models:
         for model in models.get("sources"):
             model = tuple_to_source(model)
@@ -214,7 +214,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
             source = model["source"]
             try:
-                dest = fetch_model(ctx, name, source, format=model_format)
+                dest = fetch_model(conversion, name, source, format=model_format)
                 logger.info("finished downloading source: %s -> %s", source, dest)
             except Exception:
                 logger.exception("error fetching source %s", name)
@@ -234,20 +234,20 @@ def convert_models(ctx: ConversionContext, args, models: Models):
             try:
                 if network_type == "inversion" and network_model == "concept":
                     dest = fetch_model(
-                        ctx,
+                        conversion,
                         name,
                         source,
-                        dest=path.join(ctx.model_path, network_type),
+                        dest=path.join(conversion.model_path, network_type),
                         format=network_format,
                         hf_hub_fetch=True,
                         hf_hub_filename="learned_embeds.bin",
                     )
                 else:
                     dest = fetch_model(
-                        ctx,
+                        conversion,
                         name,
                         source,
-                        dest=path.join(ctx.model_path, network_type),
+                        dest=path.join(conversion.model_path, network_type),
                         format=network_format,
                     )
@@ -267,19 +267,19 @@ def convert_models(ctx: ConversionContext, args, models: Models):
             try:
                 source = fetch_model(
-                    ctx, name, model["source"], format=model_format
+                    conversion, name, model["source"], format=model_format
                 )
                 converted = False
                 if model_format in model_formats_original:
                     converted, dest = convert_diffusion_original(
-                        ctx,
+                        conversion,
                         model,
                         source,
                     )
                 else:
                     converted, dest = convert_diffusion_diffusers(
-                        ctx,
+                        conversion,
                         model,
                         source,
                     )
@@ -289,8 +289,8 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                 # keep track of which models have been blended
                 blend_models = {}
-                inversion_dest = path.join(ctx.model_path, "inversion")
-                lora_dest = path.join(ctx.model_path, "lora")
+                inversion_dest = path.join(conversion.model_path, "inversion")
+                lora_dest = path.join(conversion.model_path, "lora")
                 for inversion in model.get("inversions", []):
                     if "text_encoder" not in blend_models:
@@ -314,7 +314,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                     inversion_source = inversion["source"]
                     inversion_format = inversion.get("format", None)
                     inversion_source = fetch_model(
-                        ctx,
+                        conversion,
                         inversion_name,
                         inversion_source,
                         dest=inversion_dest,
@@ -323,7 +323,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                     inversion_weight = inversion.get("weight", 1.0)
                     blend_textual_inversions(
-                        ctx,
+                        conversion,
                         blend_models["text_encoder"],
                         blend_models["tokenizer"],
                         [
@@ -355,7 +355,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                     lora_name = lora["name"]
                     lora_source = lora["source"]
                     lora_source = fetch_model(
-                        ctx,
+                        conversion,
                         f"{name}-lora-{lora_name}",
                         lora_source,
                         dest=lora_dest,
@@ -363,14 +363,14 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                     lora_weight = lora.get("weight", 1.0)
                     blend_loras(
-                        ctx,
+                        conversion,
                         blend_models["text_encoder"],
                         [(lora_source, lora_weight)],
                         "text_encoder",
                     )
                     blend_loras(
-                        ctx,
+                        conversion,
                         blend_models["unet"],
                         [(lora_source, lora_weight)],
                         "unet",
@@ -413,9 +413,9 @@ def convert_models(ctx: ConversionContext, args, models: Models):
             try:
                 source = fetch_model(
-                    ctx, name, model["source"], format=model_format
+                    conversion, name, model["source"], format=model_format
                 )
-                convert_upscale_resrgan(ctx, model, source)
+                convert_upscale_resrgan(conversion, model, source)
             except Exception:
                 logger.exception(
                     "error converting upscaling model %s",
@@ -433,9 +433,9 @@ def convert_models(ctx: ConversionContext, args, models: Models):
             model_format = source_format(model)
             try:
                 source = fetch_model(
-                    ctx, name, model["source"], format=model_format
+                    conversion, name, model["source"], format=model_format
                 )
-                convert_correction_gfpgan(ctx, model, source)
+                convert_correction_gfpgan(conversion, model, source)
             except Exception:
                 logger.exception(
                     "error converting correction model %s",
@@ -482,21 +482,21 @@ def main() -> int:
     args = parser.parse_args()
     logger.info("CLI arguments: %s", args)
-    ctx = ConversionContext.from_environ()
-    ctx.half = args.half or "onnx-fp16" in ctx.optimizations
-    ctx.opset = args.opset
-    ctx.token = args.token
-    logger.info("converting models in %s using %s", ctx.model_path, ctx.training_device)
-    if not path.exists(ctx.model_path):
-        logger.info("model path does not existing, creating: %s", ctx.model_path)
-        makedirs(ctx.model_path)
+    server = ConversionContext.from_environ()
+    server.half = args.half or "onnx-fp16" in server.optimizations
+    server.opset = args.opset
+    server.token = args.token
+    logger.info("converting models in %s using %s", server.model_path, server.training_device)
+    if not path.exists(server.model_path):
+        logger.info("model path does not existing, creating: %s", server.model_path)
+        makedirs(server.model_path)
     logger.info("converting base models")
-    convert_models(ctx, args, base_models)
+    convert_models(server, args, base_models)
     extras = []
-    extras.extend(ctx.extra_models)
+    extras.extend(server.extra_models)
     extras.extend(args.extras)
     extras = list(set(extras))
     extras.sort()
@@ -516,7 +516,7 @@ def main() -> int:
         try:
             validate(data, extra_schema)
             logger.info("converting extra models")
-            convert_models(ctx, args, data)
+            convert_models(server, args, data)
         except ValidationError:
             logger.exception("invalid data in extras file")
         except Exception:

View File

@@ -12,7 +12,7 @@ logger = getLogger(__name__)
 @torch.no_grad()
 def convert_correction_gfpgan(
-    ctx: ConversionContext,
+    conversion: ConversionContext,
     model: ModelDict,
     source: str,
 ):
@@ -20,7 +20,7 @@ def convert_correction_gfpgan(
     source = source or model.get("source")
     scale = model.get("scale")
-    dest = path.join(ctx.model_path, name + ".onnx")
+    dest = path.join(conversion.model_path, name + ".onnx")
     logger.info("converting GFPGAN model: %s -> %s", name, dest)
     if path.isfile(dest):
@@ -37,17 +37,17 @@ def convert_correction_gfpgan(
         scale=scale,
     )
-    torch_model = torch.load(source, map_location=ctx.map_location)
+    torch_model = torch.load(source, map_location=conversion.map_location)
     # TODO: make sure strict=False is safe here
     if "params_ema" in torch_model:
         model.load_state_dict(torch_model["params_ema"], strict=False)
     else:
         model.load_state_dict(torch_model["params"], strict=False)
-    model.to(ctx.training_device).train(False)
+    model.to(conversion.training_device).train(False)
     model.eval()
-    rng = torch.rand(1, 3, 64, 64, device=ctx.map_location)
+    rng = torch.rand(1, 3, 64, 64, device=conversion.map_location)
     input_names = ["data"]
     output_names = ["output"]
     dynamic_axes = {
@@ -63,7 +63,7 @@ def convert_correction_gfpgan(
         input_names=input_names,
         output_names=output_names,
         dynamic_axes=dynamic_axes,
-        opset_version=ctx.opset,
+        opset_version=conversion.opset,
         export_params=True,
     )
     logger.info("GFPGAN exported to ONNX successfully")

View File

@@ -90,7 +90,7 @@ def onnx_export(
 @torch.no_grad()
 def convert_diffusion_diffusers(
-    ctx: ConversionContext,
+    conversion: ConversionContext,
     model: Dict,
     source: str,
 ) -> Tuple[bool, str]:
@@ -102,10 +102,10 @@ def convert_diffusion_diffusers(
     single_vae = model.get("single_vae")
     replace_vae = model.get("vae")
-    dtype = ctx.torch_dtype()
+    dtype = conversion.torch_dtype()
     logger.debug("using Torch dtype %s for pipeline", dtype)
-    dest_path = path.join(ctx.model_path, name)
+    dest_path = path.join(conversion.model_path, name)
     model_index = path.join(dest_path, "model_index.json")
     # diffusers go into a directory rather than .onnx file
@@ -123,11 +123,11 @@ def convert_diffusion_diffusers(
     pipeline = StableDiffusionPipeline.from_pretrained(
         source,
         torch_dtype=dtype,
-        use_auth_token=ctx.token,
-    ).to(ctx.training_device)
+        use_auth_token=conversion.token,
+    ).to(conversion.training_device)
     output_path = Path(dest_path)
-    optimize_pipeline(ctx, pipeline)
+    optimize_pipeline(conversion, pipeline)
     # TEXT ENCODER
     num_tokens = pipeline.text_encoder.config.max_position_embeddings
@@ -143,11 +143,11 @@ def convert_diffusion_diffusers(
         pipeline.text_encoder,
         # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
         model_args=(
-            text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32),
+            text_input.input_ids.to(device=conversion.training_device, dtype=torch.int32),
             None,  # attention mask
             None,  # position ids
             None,  # output attentions
-            torch.tensor(True).to(device=ctx.training_device, dtype=torch.bool),
+            torch.tensor(True).to(device=conversion.training_device, dtype=torch.bool),
         ),
         output_path=output_path / "text_encoder" / ONNX_MODEL,
         ordered_input_names=["input_ids"],
@@ -155,8 +155,8 @@ def convert_diffusion_diffusers(
         dynamic_axes={
             "input_ids": {0: "batch", 1: "sequence"},
         },
-        opset=ctx.opset,
-        half=ctx.half,
+        opset=conversion.opset,
+        half=conversion.half,
     )
     del pipeline.text_encoder
@@ -165,11 +165,11 @@ def convert_diffusion_diffusers(
     # UNET
     if single_vae:
         unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
-        unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.long)
+        unet_scale = torch.tensor(4).to(device=conversion.training_device, dtype=torch.long)
     else:
         unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
         unet_scale = torch.tensor(False).to(
-            device=ctx.training_device, dtype=torch.bool
+            device=conversion.training_device, dtype=torch.bool
         )
     if is_torch_2_0:
@@ -182,11 +182,11 @@ def convert_diffusion_diffusers(
         pipeline.unet,
         model_args=(
             torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
-                device=ctx.training_device, dtype=dtype
+                device=conversion.training_device, dtype=dtype
             ),
-            torch.randn(2).to(device=ctx.training_device, dtype=dtype),
+            torch.randn(2).to(device=conversion.training_device, dtype=dtype),
             torch.randn(2, num_tokens, text_hidden_size).to(
-                device=ctx.training_device, dtype=dtype
+                device=conversion.training_device, dtype=dtype
             ),
             unet_scale,
         ),
@@ -199,8 +199,8 @@ def convert_diffusion_diffusers(
             "timestep": {0: "batch"},
             "encoder_hidden_states": {0: "batch", 1: "sequence"},
         },
-        opset=ctx.opset,
-        half=ctx.half,
+        opset=conversion.opset,
+        half=conversion.half,
         external_data=True,
     )
     unet_model_path = str(unet_path.absolute().as_posix())
@@ -238,7 +238,7 @@ def convert_diffusion_diffusers(
             model_args=(
                 torch.randn(
                     1, vae_latent_channels, unet_sample_size, unet_sample_size
-                ).to(device=ctx.training_device, dtype=dtype),
+                ).to(device=conversion.training_device, dtype=dtype),
                 False,
             ),
             output_path=output_path / "vae" / ONNX_MODEL,
@@ -247,8 +247,8 @@ def convert_diffusion_diffusers(
             dynamic_axes={
                 "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
             },
-            opset=ctx.opset,
-            half=ctx.half,
+            opset=conversion.opset,
+            half=conversion.half,
         )
     else:
         # VAE ENCODER
@@ -263,7 +263,7 @@ def convert_diffusion_diffusers(
             vae_encoder,
             model_args=(
                 torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
-                    device=ctx.training_device, dtype=dtype
+                    device=conversion.training_device, dtype=dtype
                 ),
                 False,
             ),
@@ -273,7 +273,7 @@ def convert_diffusion_diffusers(
             dynamic_axes={
                 "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
             },
-            opset=ctx.opset,
+            opset=conversion.opset,
             half=False,  # https://github.com/ssube/onnx-web/issues/290
         )
@@ -287,7 +287,7 @@ def convert_diffusion_diffusers(
             model_args=(
                 torch.randn(
                     1, vae_latent_channels, unet_sample_size, unet_sample_size
-                ).to(device=ctx.training_device, dtype=dtype),
+                ).to(device=conversion.training_device, dtype=dtype),
                 False,
             ),
             output_path=output_path / "vae_decoder" / ONNX_MODEL,
@@ -296,8 +296,8 @@ def convert_diffusion_diffusers(
             dynamic_axes={
                 "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
             },
-            opset=ctx.opset,
-            half=ctx.half,
+            opset=conversion.opset,
+            half=conversion.half,
         )
     del pipeline.vae

View File

@@ -55,7 +55,7 @@ def fix_node_name(key: str):
 def blend_loras(
-    _context: ServerContext,
+    _conversion: ServerContext,
     base_name: Union[str, ModelProto],
     loras: List[Tuple[str, float]],
     model_type: Literal["text_encoder", "unet"],

View File

@@ -146,7 +146,7 @@ class TrainingConfig:
     def __init__(
         self,
-        ctx: ConversionContext,
+        conversion: ConversionContext,
         model_name: str = "",
         scheduler: str = "ddim",
         v2: bool = False,
@@ -155,7 +155,7 @@ class TrainingConfig:
         **kwargs,
     ):
         model_name = sanitize_name(model_name)
-        model_dir = os.path.join(ctx.cache_path, model_name)
+        model_dir = os.path.join(conversion.cache_path, model_name)
         working_dir = os.path.join(model_dir, "working")
         if not os.path.exists(working_dir):
@@ -1298,7 +1298,7 @@ def download_model(db_config: TrainingConfig, token):
 def get_config_path(
-    context: ConversionContext,
+    conversion: ConversionContext,
     model_version: str = "v1",
     train_type: str = "default",
     config_base_name: str = "training",
@@ -1309,7 +1309,7 @@ def get_config_path(
     )
     parts = os.path.join(
-        context.model_path,
+        conversion.model_path,
        "configs",
        f"{model_version}-{config_base_name}-{train_type}.yaml",
    )
@@ -1317,7 +1317,7 @@ def get_config_path(
 def get_config_file(
-    context: ConversionContext,
+    conversion: ConversionContext,
     train_unfrozen=False,
     v2=False,
     prediction_type="epsilon",
@@ -1343,7 +1343,7 @@ def get_config_file(
         model_train_type = train_types["default"]
     return get_config_path(
-        context,
+        conversion,
         model_version_name,
         model_train_type,
         config_base_name,
@@ -1352,7 +1352,7 @@ def get_config_file(
 def extract_checkpoint(
-    context: ConversionContext,
+    conversion: ConversionContext,
     new_model_name: str,
     checkpoint_file: str,
     scheduler_type="ddim",
@@ -1396,7 +1396,7 @@ def extract_checkpoint(
     # Create empty config
     db_config = TrainingConfig(
-        context,
+        conversion,
        model_name=new_model_name,
        scheduler=scheduler_type,
        src=checkpoint_file,
@@ -1442,7 +1442,7 @@ def extract_checkpoint(
         prediction_type = "epsilon"
     original_config_file = get_config_file(
-        context, train_unfrozen, v2, prediction_type, config_file=config_file
+        conversion, train_unfrozen, v2, prediction_type, config_file=config_file
    )
    logger.info(
@@ -1533,7 +1533,7 @@ def extract_checkpoint(
                 checkpoint, vae_config
             )
         else:
-            vae_file = os.path.join(context.model_path, vae_file)
+            vae_file = os.path.join(conversion.model_path, vae_file)
             logger.debug("loading custom VAE: %s", vae_file)
             vae_checkpoint = load_tensor(vae_file, map_location=map_location)
             converted_vae_checkpoint = convert_ldm_vae_checkpoint(
@@ -1658,14 +1658,14 @@ def extract_checkpoint(
 @torch.no_grad()
 def convert_diffusion_original(
-    ctx: ConversionContext,
+    conversion: ConversionContext,
     model: ModelDict,
     source: str,
 ) -> Tuple[bool, str]:
     name = model["name"]
     source = source or model["source"]
-    dest_path = os.path.join(ctx.model_path, name)
+    dest_path = os.path.join(conversion.model_path, name)
     dest_index = os.path.join(dest_path, "model_index.json")
     logger.info(
         "converting original Diffusers checkpoint %s: %s -> %s", name, source, dest_path
@@ -1676,8 +1676,8 @@ def convert_diffusion_original(
         return (False, dest_path)
     torch_name = name + "-torch"
-    torch_path = os.path.join(ctx.cache_path, torch_name)
-    working_name = os.path.join(ctx.cache_path, torch_name, "working")
+    torch_path = os.path.join(conversion.cache_path, torch_name)
+    working_name = os.path.join(conversion.cache_path, torch_name, "working")
     model_index = os.path.join(working_name, "model_index.json")
     if os.path.exists(torch_path) and os.path.exists(model_index):
@@ -1689,7 +1689,7 @@ def convert_diffusion_original(
             torch_path,
         )
         if extract_checkpoint(
-            ctx,
+            conversion,
             torch_name,
             source,
             config_file=model.get("config"),
@@ -1704,9 +1704,9 @@ def convert_diffusion_original(
     if "vae" in model:
         del model["vae"]
-    result = convert_diffusion_diffusers(ctx, model, working_name)
-    if "torch" in ctx.prune:
+    result = convert_diffusion_diffusers(conversion, model, working_name)
+    if "torch" in conversion.prune:
         logger.info("removing intermediate Torch models: %s", torch_path)
         shutil.rmtree(torch_path)

View File

@@ -16,7 +16,7 @@ logger = getLogger(__name__)
 @torch.no_grad()
 def blend_textual_inversions(
-    context: ServerContext,
+    server: ServerContext,
     text_encoder: ModelProto,
     tokenizer: CLIPTokenizer,
     inversions: List[Tuple[str, float, Optional[str], Optional[str]]],
@@ -161,7 +161,7 @@ def blend_textual_inversions(
 @torch.no_grad()
 def convert_diffusion_textual_inversion(
-    context: ConversionContext,
+    conversion: ConversionContext,
     name: str,
     base_model: str,
     inversion: str,
@@ -169,7 +169,7 @@ def convert_diffusion_textual_inversion(
     base_token: Optional[str] = None,
     inversion_weight: Optional[float] = 1.0,
 ):
-    dest_path = path.join(context.model_path, f"inversion-{name}")
+    dest_path = path.join(conversion.model_path, f"inversion-{name}")
     logger.info(
         "converting Textual Inversion: %s + %s -> %s", base_model, inversion, dest_path
     )
@@ -194,7 +194,7 @@ def convert_diffusion_textual_inversion(
         subfolder="tokenizer",
     )
     text_encoder, tokenizer = blend_textual_inversions(
-        context,
+        conversion,
         text_encoder,
         tokenizer,
         [(inversion, inversion_weight, base_token, inversion_format)],

View File

@@ -13,7 +13,7 @@ TAG_X4_V3 = "real-esrgan-x4-v3"
 @torch.no_grad()
 def convert_upscale_resrgan(
-    ctx: ConversionContext,
+    conversion: ConversionContext,
     model: ModelDict,
     source: str,
 ):
@@ -24,7 +24,7 @@ def convert_upscale_resrgan(
     source = source or model.get("source")
     scale = model.get("scale")
-    dest = path.join(ctx.model_path, name + ".onnx")
+    dest = path.join(conversion.model_path, name + ".onnx")
     logger.info("converting Real ESRGAN model: %s -> %s", name, dest)
     if path.isfile(dest):
@@ -53,16 +53,16 @@ def convert_upscale_resrgan(
         scale=scale,
     )
-    torch_model = torch.load(source, map_location=ctx.map_location)
+    torch_model = torch.load(source, map_location=conversion.map_location)
     if "params_ema" in torch_model:
         model.load_state_dict(torch_model["params_ema"])
     else:
         model.load_state_dict(torch_model["params"], strict=False)
-    model.to(ctx.training_device).train(False)
+    model.to(conversion.training_device).train(False)
     model.eval()
-    rng = torch.rand(1, 3, 64, 64, device=ctx.map_location)
+    rng = torch.rand(1, 3, 64, 64, device=conversion.map_location)
     input_names = ["data"]
     output_names = ["output"]
     dynamic_axes = {
@@ -78,7 +78,7 @@ def convert_upscale_resrgan(
         input_names=input_names,
         output_names=output_names,
         dynamic_axes=dynamic_axes,
-        opset_version=ctx.opset,
+        opset_version=conversion.opset,
         export_params=True,
     )
     logger.info("real ESRGAN exported to ONNX successfully")

View File

@@ -64,7 +64,7 @@ def json_params(
 def make_output_name(
-    ctx: ServerContext,
+    server: ServerContext,
     mode: str,
     params: ImageParams,
     size: Size,
@@ -92,20 +92,20 @@ def make_output_name(
         hash_value(sha, param)
     return [
-        f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{ctx.image_format}"
+        f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}"
         for i in range(params.batch)
     ]
-def save_image(ctx: ServerContext, output: str, image: Image.Image) -> str:
-    path = base_join(ctx.output_path, output)
-    image.save(path, format=ctx.image_format)
+def save_image(server: ServerContext, output: str, image: Image.Image) -> str:
+    path = base_join(server.output_path, output)
+    image.save(path, format=server.image_format)
     logger.debug("saved output image to: %s", path)
     return path
 def save_params(
-    ctx: ServerContext,
+    server: ServerContext,
     output: str,
     params: ImageParams,
     size: Size,
@@ -113,7 +113,7 @@ def save_params(
     border: Optional[Border] = None,
     highres: Optional[HighresParams] = None,
 ) -> str:
-    path = base_join(ctx.output_path, f"{output}.json")
+    path = base_join(server.output_path, f"{output}.json")
     json = json_params(
         output, params, size, upscale=upscale, border=border, highres=highres
     )

View File

@@ -7,6 +7,10 @@ from .torch_before_ort import GraphOptimizationLevel, SessionOptions
 logger = getLogger(__name__)
+Param = Union[str, int, float]
+Point = Tuple[int, int]
 class SizeChart(IntEnum):
     mini = 128  # small tile for very expensive models
     half = 256  # half tile for outpainting
@@ -25,10 +29,6 @@ class TileOrder:
     spiral = "spiral"
-Param = Union[str, int, float]
-Point = Tuple[int, int]
 class Border:
     def __init__(self, left: int, right: int, top: int, bottom: int) -> None:
         self.left = left
@@ -37,7 +37,7 @@ class Border:
         self.bottom = bottom
     def __str__(self) -> str:
-        return "%s %s %s %s" % (self.left, self.top, self.right, self.bottom)
+        return "(%s, %s, %s, %s)" % (self.left, self.top, self.right, self.bottom)
     def tojson(self):
         return {
@@ -145,6 +145,7 @@ class DeviceParams:
         return sess
     def torch_str(self) -> str:
+        # TODO: return cuda devices for ROCm as well
         if self.device.startswith("cuda"):
             return self.device
         else:

View File

@@ -93,7 +93,7 @@ def url_from_rule(rule) -> str:
     return url_for(rule.endpoint, **options)
-def introspect(context: ServerContext, app: Flask):
+def introspect(server: ServerContext, app: Flask):
     return {
         "name": "onnx-web",
         "routes": [
@@ -103,15 +103,15 @@ def introspect(context: ServerContext, app: Flask):
     }
-def list_extra_strings(context: ServerContext):
+def list_extra_strings(server: ServerContext):
     return jsonify(get_extra_strings())
-def list_mask_filters(context: ServerContext):
+def list_mask_filters(server: ServerContext):
     return jsonify(list(get_mask_filters().keys()))
-def list_models(context: ServerContext):
+def list_models(server: ServerContext):
     return jsonify(
         {
             "correction": get_correction_models(),
@@ -122,30 +122,30 @@ def list_models(context: ServerContext):
     )
-def list_noise_sources(context: ServerContext):
+def list_noise_sources(server: ServerContext):
     return jsonify(list(get_noise_sources().keys()))
-def list_params(context: ServerContext):
+def list_params(server: ServerContext):
     return jsonify(get_config_params())
-def list_platforms(context: ServerContext):
+def list_platforms(server: ServerContext):
     return jsonify([p.device for p in get_available_platforms()])
-def list_schedulers(context: ServerContext):
+def list_schedulers(server: ServerContext):
     return jsonify(list(get_pipeline_schedulers().keys()))
-def img2img(context: ServerContext, pool: DevicePoolExecutor):
+def img2img(server: ServerContext, pool: DevicePoolExecutor):
     source_file = request.files.get("source")
     if source_file is None:
         return error_reply("source image is required")
     source = Image.open(BytesIO(source_file.read())).convert("RGB")
-    device, params, size = pipeline_from_request(context)
+    device, params, size = pipeline_from_request(server)
     upscale = upscale_from_request()
     strength = get_and_clamp_float(
@@ -156,7 +156,7 @@ def img2img(context: ServerContext, pool: DevicePoolExecutor):
         get_config_value("strength", "min"),
     )
-    output = make_output_name(context, "img2img", params, size, extras=[strength])
+    output = make_output_name(server, "img2img", params, size, extras=[strength])
     job_name = output[0]
     logger.info("img2img job queued for: %s", job_name)
@@ -164,7 +164,7 @@ def img2img(context: ServerContext, pool: DevicePoolExecutor):
     pool.submit(
         job_name,
         run_img2img_pipeline,
-        context,
+        server,
         params,
         output,
         upscale,
@@ -176,19 +176,19 @@ def img2img(context: ServerContext, pool: DevicePoolExecutor):
     return jsonify(json_params(output, params, size, upscale=upscale))
-def txt2img(context: ServerContext, pool: DevicePoolExecutor):
-    device, params, size = pipeline_from_request(context)
+def txt2img(server: ServerContext, pool: DevicePoolExecutor):
+    device, params, size = pipeline_from_request(server)
     upscale = upscale_from_request()
     highres = highres_from_request()
-    output = make_output_name(context, "txt2img", params, size)
+    output = make_output_name(server, "txt2img", params, size)
     job_name = output[0]
     logger.info("txt2img job queued for: %s", job_name)
     pool.submit(
         job_name,
         run_txt2img_pipeline,
-        context,
+        server,
         params,
         size,
         output,
@@ -200,7 +200,7 @@ def txt2img(context: ServerContext, pool: DevicePoolExecutor):
     return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
-def inpaint(context: ServerContext, pool: DevicePoolExecutor):
+def inpaint(server: ServerContext, pool: DevicePoolExecutor):
     source_file = request.files.get("source")
     if source_file is None:
         return error_reply("source image is required")
@@ -212,7 +212,7 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):
     source = Image.open(BytesIO(source_file.read())).convert("RGB")
     mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
-    device, params, size = pipeline_from_request(context)
+    device, params, size = pipeline_from_request(server)
     expand = border_from_request()
     upscale = upscale_from_request()
@@ -224,7 +224,7 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):
     )
     output = make_output_name(
-        context,
+        server,
         "inpaint",
         params,
         size,
@@ -247,7 +247,7 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):
     pool.submit(
         job_name,
         run_inpaint_pipeline,
-        context,
+        server,
         params,
         size,
         output,
@@ -265,17 +265,17 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):
     return jsonify(json_params(output, params, size, upscale=upscale, border=expand))
-def upscale(context: ServerContext, pool: DevicePoolExecutor):
+def upscale(server: ServerContext, pool: DevicePoolExecutor):
     source_file = request.files.get("source")
     if source_file is None:
         return error_reply("source image is required")
     source = Image.open(BytesIO(source_file.read())).convert("RGB")
-    device, params, size = pipeline_from_request(context)
+    device, params, size = pipeline_from_request(server)
     upscale = upscale_from_request()
-    output = make_output_name(context, "upscale", params, size)
+    output = make_output_name(server, "upscale", params, size)
     job_name = output[0]
     logger.info("upscale job queued for: %s", job_name)
@@ -283,7 +283,7 @@ def upscale(context: ServerContext, pool: DevicePoolExecutor):
     pool.submit(
         job_name,
         run_upscale_pipeline,
-        context,
+        server,
         params,
         size,
         output,
@@ -295,7 +295,7 @@ def upscale(context: ServerContext, pool: DevicePoolExecutor):
     return jsonify(json_params(output, params, size, upscale=upscale))
-def chain(context: ServerContext, pool: DevicePoolExecutor):
+def chain(server: ServerContext, pool: DevicePoolExecutor):
     logger.debug(
         "chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
     )
@@ -311,8 +311,8 @@ def chain(context: ServerContext, pool: DevicePoolExecutor):
     validate(data, schema)
     # get defaults from the regular parameters
-    device, params, size = pipeline_from_request(context)
-    output = make_output_name(context, "chain", params, size)
+    device, params, size = pipeline_from_request(server)
+    output = make_output_name(server, "chain", params, size)
     job_name = output[0]
     pipeline = ChainPipeline()
@@ -371,7 +371,7 @@ def chain(context: ServerContext, pool: DevicePoolExecutor):
     pool.submit(
         job_name,
         pipeline,
-        context,
+        server,
         params,
         empty_source,
         output=output[0],
@@ -382,7 +382,7 @@ def chain(context: ServerContext, pool: DevicePoolExecutor):
     return jsonify(json_params(output, params, size))
-def blend(context: ServerContext, pool: DevicePoolExecutor):
+def blend(server: ServerContext, pool: DevicePoolExecutor):
     mask_file = request.files.get("mask")
     if mask_file is None:
         return error_reply("mask image is required")
@@ -402,17 +402,17 @@ def blend(context: ServerContext, pool: DevicePoolExecutor):
         source = valid_image(source, mask.size, mask.size)
         sources.append(source)
-    device, params, size = pipeline_from_request(context)
+    device, params, size = pipeline_from_request(server)
     upscale = upscale_from_request()
-    output = make_output_name(context, "upscale", params, size)
+    output = make_output_name(server, "upscale", params, size)
     job_name = output[0]
     logger.info("upscale job queued for: %s", job_name)
     pool.submit(
         job_name,
         run_blend_pipeline,
-        context,
+        server,
         params,
         size,
         output,
@@ -425,17 +425,17 @@ def blend(context: ServerContext, pool: DevicePoolExecutor):
     return jsonify(json_params(output, params, size, upscale=upscale))
-def txt2txt(context: ServerContext, pool: DevicePoolExecutor):
-    device, params, size = pipeline_from_request(context)
-    output = make_output_name(context, "txt2txt", params, size)
+def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
+    device, params, size = pipeline_from_request(server)
+    output = make_output_name(server, "txt2txt", params, size)
     job_name = output[0]
     logger.info("upscale job queued for: %s", job_name)
     pool.submit(
         job_name,
         run_txt2txt_pipeline,
-        context,
+        server,
         params,
         size,
         output,
@@ -445,7 +445,7 @@ def txt2txt(context: ServerContext, pool: DevicePoolExecutor):
     return jsonify(json_params(output, params, size))
-def cancel(context: ServerContext, pool: DevicePoolExecutor):
+def cancel(server: ServerContext, pool: DevicePoolExecutor):
     output_file = request.args.get("output", None)
     if output_file is None:
         return error_reply("output name is required")
@@ -456,7 +456,7 @@ def cancel(context: ServerContext, pool: DevicePoolExecutor):
     return ready_reply(cancelled=cancelled)
-def ready(context: ServerContext, pool: DevicePoolExecutor):
+def ready(server: ServerContext, pool: DevicePoolExecutor):
     output_file = request.args.get("output", None)
     if output_file is None:
         return error_reply("output name is required")
@@ -468,7 +468,7 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
         return ready_reply(pending=True)
     if progress is None:
-        output = base_join(context.output_path, output_file)
+        output = base_join(server.output_path, output_file)
         if path.exists(output):
             return ready_reply(ready=True)
         else:
@@ -485,44 +485,44 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
     )
-def status(context: ServerContext, pool: DevicePoolExecutor):
+def status(server: ServerContext, pool: DevicePoolExecutor):
     return jsonify(pool.status())
-def register_api_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor):
+def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
     return [
-        app.route("/api")(wrap_route(introspect, context, app=app)),
-        app.route("/api/settings/masks")(wrap_route(list_mask_filters, context)),
-        app.route("/api/settings/models")(wrap_route(list_models, context)),
-        app.route("/api/settings/noises")(wrap_route(list_noise_sources, context)),
-        app.route("/api/settings/params")(wrap_route(list_params, context)),
-        app.route("/api/settings/platforms")(wrap_route(list_platforms, context)),
-        app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)),
-        app.route("/api/settings/strings")(wrap_route(list_extra_strings, context)),
+        app.route("/api")(wrap_route(introspect, server, app=app)),
+        app.route("/api/settings/masks")(wrap_route(list_mask_filters, server)),
+        app.route("/api/settings/models")(wrap_route(list_models, server)),
+        app.route("/api/settings/noises")(wrap_route(list_noise_sources, server)),
+        app.route("/api/settings/params")(wrap_route(list_params, server)),
+        app.route("/api/settings/platforms")(wrap_route(list_platforms, server)),
+        app.route("/api/settings/schedulers")(wrap_route(list_schedulers, server)),
+        app.route("/api/settings/strings")(wrap_route(list_extra_strings, server)),
         app.route("/api/img2img", methods=["POST"])(
-            wrap_route(img2img, context, pool=pool)
+            wrap_route(img2img, server, pool=pool)
         ),
         app.route("/api/txt2img", methods=["POST"])(
-            wrap_route(txt2img, context, pool=pool)
+            wrap_route(txt2img, server, pool=pool)
         ),
         app.route("/api/txt2txt", methods=["POST"])(
-            wrap_route(txt2txt, context, pool=pool)
+            wrap_route(txt2txt, server, pool=pool)
        ),
        app.route("/api/inpaint", methods=["POST"])(
-            wrap_route(inpaint, context, pool=pool)
+            wrap_route(inpaint, server, pool=pool)
        ),
        app.route("/api/upscale", methods=["POST"])(
-            wrap_route(upscale, context, pool=pool)
+            wrap_route(upscale, server, pool=pool)
        ),
        app.route("/api/chain", methods=["POST"])(
-            wrap_route(chain, context, pool=pool)
+            wrap_route(chain, server, pool=pool)
        ),
        app.route("/api/blend", methods=["POST"])(
-            wrap_route(blend, context, pool=pool)
+            wrap_route(blend, server, pool=pool)
        ),
        app.route("/api/cancel", methods=["PUT"])(
-            wrap_route(cancel, context, pool=pool)
+            wrap_route(cancel, server, pool=pool)
        ),
-        app.route("/api/ready")(wrap_route(ready, context, pool=pool)),
-        app.route("/api/status")(wrap_route(status, context, pool=pool)),
+        app.route("/api/ready")(wrap_route(ready, server, pool=pool)),
+        app.route("/api/status")(wrap_route(status, server, pool=pool)),
     ]

View File

@@ -118,13 +118,13 @@ def patch_not_impl():
     raise NotImplementedError()
-def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str:
+def patch_cache_path(server: ServerContext, url: str, **kwargs) -> str:
     cache_path = cache_path_map.get(url, None)
     if cache_path is None:
         parsed = urlparse(url)
         cache_path = path.basename(parsed.path)
-    cache_path = path.join(ctx.cache_path, cache_path)
+    cache_path = path.join(server.cache_path, cache_path)
     logger.debug("patching download path: %s -> %s", url, cache_path)
     if path.exists(cache_path):
@@ -133,27 +133,27 @@ def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str:
         raise FileNotFoundError("missing cache file: %s" % (cache_path))
-def apply_patch_basicsr(ctx: ServerContext):
+def apply_patch_basicsr(server: ServerContext):
     logger.debug("patching BasicSR module")
     basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
-    basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, ctx)
+    basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, server)
-def apply_patch_codeformer(ctx: ServerContext):
+def apply_patch_codeformer(server: ServerContext):
     logger.debug("patching CodeFormer module")
     codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
-    codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, ctx)
+    codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, server)
-def apply_patch_facexlib(ctx: ServerContext):
+def apply_patch_facexlib(server: ServerContext):
     logger.debug("patching Facexlib module")
-    facexlib.utils.load_file_from_url = partial(patch_cache_path, ctx)
+    facexlib.utils.load_file_from_url = partial(patch_cache_path, server)
-def apply_patches(ctx: ServerContext):
-    apply_patch_basicsr(ctx)
-    apply_patch_codeformer(ctx)
-    apply_patch_facexlib(ctx)
+def apply_patches(server: ServerContext):
+    apply_patch_basicsr(server)
+    apply_patch_codeformer(server)
+    apply_patch_facexlib(server)
     unload(
         [
             "basicsr.utils.download_util",

View File

@@ -115,7 +115,7 @@ def get_config_value(key: str, subkey: str = "default", default=None):
     return config_params.get(key, {}).get(subkey, default)
-def load_extras(context: ServerContext):
+def load_extras(server: ServerContext):
     """
     Load the extras file(s) and collect the relevant parts for the server: labels and strings
     """
@@ -127,7 +127,7 @@ def load_extras(context: ServerContext):
     with open("./schemas/extras.yaml", "r") as f:
         extra_schema = safe_load(f.read())
-    for file in context.extra_models:
+    for file in server.extra_models:
         if file is not None and file != "":
             logger.info("loading extra models from %s", file)
             try:
@@ -209,11 +209,11 @@ IGNORE_EXTENSIONS = [".crdownload", ".lock", ".tmp"]
 def list_model_globs(
-    context: ServerContext, globs: List[str], base_path: Optional[str] = None
+    server: ServerContext, globs: List[str], base_path: Optional[str] = None
 ) -> List[str]:
     models = []
     for pattern in globs:
-        pattern_path = path.join(base_path or context.model_path, pattern)
+        pattern_path = path.join(base_path or server.model_path, pattern)
         logger.debug("loading models from %s", pattern_path)
         for name in glob(pattern_path):
             base = path.basename(name)
@@ -226,7 +226,7 @@ def list_model_globs(
     return unique_models
-def load_models(context: ServerContext) -> None:
+def load_models(server: ServerContext) -> None:
     global correction_models
     global diffusion_models
     global network_models
@@ -234,7 +234,7 @@ def load_models(context: ServerContext) -> None:
     # main categories
     diffusion_models = list_model_globs(
-        context,
+        server,
         [
             "diffusion-*",
             "stable-diffusion-*",
@@ -243,7 +243,7 @@ def load_models(context: ServerContext) -> None:
     logger.debug("loaded diffusion models from disk: %s", diffusion_models)
     correction_models = list_model_globs(
-        context,
+        server,
         [
             "correction-*",
         ],
@@ -251,7 +251,7 @@ def load_models(context: ServerContext) -> None:
     logger.debug("loaded correction models from disk: %s", correction_models)
     upscaling_models = list_model_globs(
-        context,
+        server,
         [
             "upscaling-*",
         ],
@@ -260,11 +260,11 @@ def load_models(context: ServerContext) -> None:
     # additional networks
     inversion_models = list_model_globs(
-        context,
+        server,
         [
             "*",
         ],
-        base_path=path.join(context.model_path, "inversion"),
+        base_path=path.join(server.model_path, "inversion"),
     )
     logger.debug("loaded Textual Inversion models from disk: %s", inversion_models)
     network_models.extend(
@@ -272,35 +272,35 @@ def load_models(context: ServerContext) -> None:
     )
     lora_models = list_model_globs(
-        context,
+        server,
         [
             "*",
         ],
-        base_path=path.join(context.model_path, "lora"),
+        base_path=path.join(server.model_path, "lora"),
     )
     logger.debug("loaded LoRA models from disk: %s", lora_models)
     network_models.extend([NetworkModel(model, "lora") for model in lora_models])
-def load_params(context: ServerContext) -> None:
+def load_params(server: ServerContext) -> None:
     global config_params
-    params_file = path.join(context.params_path, "params.json")
+    params_file = path.join(server.params_path, "params.json")
     logger.debug("loading server parameters from file: %s", params_file)
     with open(params_file, "r") as f:
         config_params = yaml.safe_load(f)
-        if "platform" in config_params and context.default_platform is not None:
+        if "platform" in config_params and server.default_platform is not None:
             logger.info(
                 "overriding default platform from environment: %s",
-                context.default_platform,
+                server.default_platform,
            )
            config_platform = config_params.get("platform", {})
-            config_platform["default"] = context.default_platform
+            config_platform["default"] = server.default_platform
-def load_platforms(context: ServerContext) -> None:
+def load_platforms(server: ServerContext) -> None:
     global available_platforms
     providers = list(get_available_providers())
@@ -309,7 +309,7 @@ def load_platforms(context: ServerContext) -> None:
     for potential in platform_providers:
         if (
             platform_providers[potential] in providers
-            and potential not in context.block_platforms
+            and potential not in server.block_platforms
         ):
             if potential == "cuda":
                 for i in range(torch.cuda.device_count()):
@@ -317,16 +317,16 @@ def load_platforms(context: ServerContext) -> None:
                         "device_id": i,
                     }
-                    if context.memory_limit is not None:
+                    if server.memory_limit is not None:
                         options["arena_extend_strategy"] = "kSameAsRequested"
-                        options["gpu_mem_limit"] = context.memory_limit
+                        options["gpu_mem_limit"] = server.memory_limit
                     available_platforms.append(
                         DeviceParams(
                             potential,
                             platform_providers[potential],
                             options,
-                            context.optimizations,
+                            server.optimizations,
                         )
                     )
             else:
@@ -335,18 +335,18 @@ def load_platforms(context: ServerContext) -> None:
                         potential,
                         platform_providers[potential],
                         None,
-                        context.optimizations,
+                        server.optimizations,
                     )
                 )
-    if context.any_platform:
+    if server.any_platform:
         # the platform should be ignored when the job is scheduled, but set to CPU just in case
         available_platforms.append(
             DeviceParams(
                 "any",
                 platform_providers["cpu"],
                 None,
-                context.optimizations,
+                server.optimizations,
             )
        )

View File

@@ -28,7 +28,7 @@ logger = getLogger(__name__)
def pipeline_from_request(
-context: ServerContext,
+server: ServerContext,
) -> Tuple[DeviceParams, ImageParams, Size]:
user = request.remote_addr
@@ -44,7 +44,7 @@ def pipeline_from_request(
# pipeline stuff
lpw = get_not_empty(request.args, "lpw", "false") == "true"
model = get_not_empty(request.args, "model", get_config_value("model"))
-model_path = get_model_path(context, model)
+model_path = get_model_path(server, model)
scheduler = get_from_list(
request.args, "scheduler", list(pipeline_schedulers.keys())
)
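Reviewer note: pipeline_from_request leans on two small request helpers. Their bodies are not part of this diff; the sketch below only captures the behavior implied by the call sites (non-empty-or-default lookup, and allow-list validation) and is not the project's code:

def get_not_empty(args, key, default):
    # treat missing and empty query parameters the same way
    value = args.get(key, None)
    if value is None or len(value) == 0:
        return default
    return value


def get_from_list(args, key, values):
    # only accept a value from the allowed list, otherwise fall back to None
    selected = args.get(key, None)
    if selected in values:
        return selected
    return None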

View File

@@ -7,30 +7,30 @@ from .context import ServerContext
from .utils import wrap_route
-def serve_bundle_file(context: ServerContext, filename="index.html"):
+def serve_bundle_file(server: ServerContext, filename="index.html"):
-return send_from_directory(path.join("..", context.bundle_path), filename)
+return send_from_directory(path.join("..", server.bundle_path), filename)
# non-API routes
-def index(context: ServerContext):
+def index(server: ServerContext):
-return serve_bundle_file(context)
+return serve_bundle_file(server)
-def index_path(context: ServerContext, filename: str):
+def index_path(server: ServerContext, filename: str):
-return serve_bundle_file(context, filename)
+return serve_bundle_file(server, filename)
-def output(context: ServerContext, filename: str):
+def output(server: ServerContext, filename: str):
return send_from_directory(
-path.join("..", context.output_path), filename, as_attachment=False
+path.join("..", server.output_path), filename, as_attachment=False
)
def register_static_routes(
-app: Flask, context: ServerContext, _pool: DevicePoolExecutor
+app: Flask, server: ServerContext, _pool: DevicePoolExecutor
):
return [
-app.route("/")(wrap_route(index, context)),
+app.route("/")(wrap_route(index, server)),
-app.route("/<path:filename>")(wrap_route(index_path, context)),
+app.route("/<path:filename>")(wrap_route(index_path, server)),
-app.route("/output/<path:filename>")(wrap_route(output, context)),
+app.route("/output/<path:filename>")(wrap_route(output, server)),
]
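For context, these registrations are called from the app factory with a ServerContext and worker pool that already exist; a minimal sketch of the wiring (the factory name and the construction of both objects are placeholders, not the project's code):

from flask import Flask


def make_app(server, pool):
    # server: ServerContext, pool: DevicePoolExecutor, built elsewhere
    app = Flask(__name__)
    register_static_routes(app, server, pool)  # serves /, /<path>, and /output/<path>
    return app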

View File

@@ -9,26 +9,26 @@ from ..worker.pool import DevicePoolExecutor
from .context import ServerContext
-def check_paths(context: ServerContext) -> None:
+def check_paths(server: ServerContext) -> None:
-if not path.exists(context.model_path):
+if not path.exists(server.model_path):
raise RuntimeError("model path must exist")
-if not path.exists(context.output_path):
+if not path.exists(server.output_path):
-makedirs(context.output_path)
+makedirs(server.output_path)
-def get_model_path(context: ServerContext, model: str):
+def get_model_path(server: ServerContext, model: str):
-return base_join(context.model_path, model)
+return base_join(server.model_path, model)
def register_routes(
app: Flask,
-context: ServerContext,
+server: ServerContext,
pool: DevicePoolExecutor,
routes: List[Tuple[str, Dict, Callable]],
):
for route, kwargs, method in routes:
-app.route(route, **kwargs)(wrap_route(method, context, pool=pool))
+app.route(route, **kwargs)(wrap_route(method, server, pool=pool))
def wrap_route(func, *args, **kwargs):
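The body of wrap_route is cut off by the hunk boundary. The pattern it implements, binding the context arguments ahead of time so Flask only supplies the path parameters, looks roughly like this sketch (an assumption, not the exact implementation):

from functools import wraps


def wrap_route(func, *args, **kwargs):
    # pre-bind ServerContext (and pool, when given); keep func's name so Flask
    # derives a unique endpoint name from the wrapped handler
    @wraps(func)
    def route(*route_args, **route_kwargs):
        return func(*args, *route_args, **route_kwargs, **kwargs)

    return route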

View File

@@ -26,51 +26,51 @@ MEMORY_ERRORS = [
]
-def worker_main(context: WorkerContext, server: ServerContext):
+def worker_main(worker: WorkerContext, server: ServerContext):
apply_patches(server)
-setproctitle("onnx-web worker: %s" % (context.device.device))
+setproctitle("onnx-web worker: %s" % (worker.device.device))
logger.trace(
"checking in from worker with providers: %s", get_available_providers()
)
# make leaking workers easier to recycle
-context.progress.cancel_join_thread()
+worker.progress.cancel_join_thread()
while True:
try:
-if not context.is_active():
+if not worker.is_active():
logger.warning(
"worker %s has been replaced by %s, exiting",
getpid(),
-context.get_active(),
+worker.get_active(),
)
exit(EXIT_REPLACED)
# wait briefly for the next job
-job = context.pending.get(timeout=1.0)
+job = worker.pending.get(timeout=1.0)
-logger.info("worker %s got job: %s", context.device.device, job.name)
+logger.info("worker %s got job: %s", worker.device.device, job.name)
# clear flags and save the job name
-context.start(job.name)
+worker.start(job.name)
logger.info("starting job: %s", job.name)
# reset progress, which does a final check for cancellation
-context.set_progress(0)
+worker.set_progress(0)
-job.fn(context, *job.args, **job.kwargs)
+job.fn(worker, *job.args, **job.kwargs)
# confirm completion of the job
logger.info("job succeeded: %s", job.name)
-context.finish()
+worker.finish()
except Empty:
pass
except KeyboardInterrupt:
logger.info("worker got keyboard interrupt")
-context.fail()
+worker.fail()
exit(EXIT_INTERRUPT)
except ValueError:
logger.exception("value error in worker, exiting: %s")
-context.fail()
+worker.fail()
exit(EXIT_ERROR)
except Exception as e:
e_str = str(e)
@@ -78,11 +78,11 @@ def worker_main(context: WorkerContext, server: ServerContext):
for e_mem in MEMORY_ERRORS:
if e_mem in e_str:
logger.error("detected out-of-memory error, exiting: %s", e)
-context.fail()
+worker.fail()
exit(EXIT_MEMORY)
# carry on for other errors
logger.exception(
"unrecognized error while running job",
)
-context.fail()
+worker.fail()
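The loop above drains a per-device queue, so each entry has to carry the callable plus its arguments, and the 1.0 second timeout on pending.get keeps the loop checking for cancellation and replacement between jobs. A rough sketch of the job shape the loop assumes (illustrative, not the project's exact class):

from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Tuple


@dataclass
class Job:
    name: str
    fn: Callable[..., None]  # invoked as fn(worker, *job.args, **job.kwargs)
    args: Tuple[Any, ...] = ()
    kwargs: Dict[str, Any] = field(default_factory=dict)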

api/tests/test_output.py (new file, 23 lines)
View File

@@ -0,0 +1,23 @@
import unittest


class TestHashValue(unittest.TestCase):
    def test_hash_value(self):
        pass


class TestJSONParams(unittest.TestCase):
    def test_json_params(self):
        pass


class TestMakeOutputName(unittest.TestCase):
    pass


class TestSaveImage(unittest.TestCase):
    pass


class TestSaveParams(unittest.TestCase):
    pass

api/tests/test_params.py (new file, 101 lines)
View File

@@ -0,0 +1,101 @@
import unittest

from onnx_web.params import Border, Size


class BorderTests(unittest.TestCase):
    def test_json(self):
        border = Border.even(0)
        json = border.tojson()
        self.assertIn("left", json)
        self.assertIn("right", json)
        self.assertIn("top", json)
        self.assertIn("bottom", json)

    def test_str(self):
        border = Border.even(10)
        bstr = str(border)
        self.assertEqual("(10, 10, 10, 10)", bstr)

    def test_uneven(self):
        border = Border(1, 2, 3, 4)
        self.assertEqual("(1, 2, 3, 4)", str(border))

    def test_args(self):
        pass


class SizeTests(unittest.TestCase):
    def test_iter(self):
        size = Size(1, 2)
        self.assertEqual(list(size), [1, 2])

    def test_str(self):
        pass

    def test_border(self):
        pass

    def test_json(self):
        pass

    def test_args(self):
        pass


class DeviceParamsTests(unittest.TestCase):
    def test_str(self):
        pass

    def test_provider(self):
        pass

    def test_options_optimizations(self):
        pass

    def test_options_cache(self):
        pass

    def test_torch_cuda(self):
        pass

    def test_torch_rocm(self):
        pass


class ImageParamsTests(unittest.TestCase):
    def test_json(self):
        pass

    def test_args(self):
        pass


class StageParamsTests(unittest.TestCase):
    def test_init(self):
        pass


class UpscaleParamsTests(unittest.TestCase):
    def test_rescale(self):
        pass

    def test_resize(self):
        pass

    def test_json(self):
        pass

    def test_args(self):
        pass


class HighresParamsTests(unittest.TestCase):
    def test_resize(self):
        pass

    def test_json(self):
        pass
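The passing assertions in BorderTests pin down most of the Border surface already; a compatible sketch of the class (not the actual onnx_web.params implementation) that would satisfy them:

class Border:
    def __init__(self, left, right, top, bottom):
        self.left = left
        self.right = right
        self.top = top
        self.bottom = bottom

    def __str__(self):
        return "(%s, %s, %s, %s)" % (self.left, self.right, self.top, self.bottom)

    def tojson(self):
        return {
            "left": self.left,
            "right": self.right,
            "top": self.top,
            "bottom": self.bottom,
        }

    @classmethod
    def even(cls, value):
        # same margin on all four sides
        return cls(value, value, value, value)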