lint(api): name context params consistently (#278)
parent fea9185707
commit 9698e29268

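This commit renames every context-style parameter after the type it carries, rather than the generic `ctx`/`context` used before: `ConversionContext` parameters become `conversion`, `ServerContext` parameters become `server`, and `WorkerContext` parameters become `worker`. A minimal sketch of the convention, with stub classes standing in for the real onnx-web types (illustrative only, not part of the diff):

    # stubs for illustration; the real classes live in onnx_web and carry many more fields
    class ConversionContext: ...
    class ServerContext: ...
    class WorkerContext: ...

    def fetch_model(conversion: ConversionContext, name: str, source: str) -> str:
        # conversion scripts name the parameter "conversion"
        return f"{name}:{source}"

    def list_params(server: ServerContext):
        # server routes and helpers name it "server"
        return []

    def worker_main(worker: WorkerContext, server: ServerContext):
        # the worker loop names it "worker"
        pass
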
@@ -23,7 +23,7 @@ class StageCallback(Protocol):
     def __call__(
         self,
         job: WorkerContext,
-        ctx: ServerContext,
+        server: ServerContext,
         stage: StageParams,
         params: ImageParams,
         source: Image.Image,

@@ -140,7 +140,7 @@ base_models: Models = {


 def fetch_model(
-    ctx: ConversionContext,
+    conversion: ConversionContext,
     name: str,
     source: str,
     dest: Optional[str] = None,
@@ -148,7 +148,7 @@ def fetch_model(
     hf_hub_fetch: bool = False,
     hf_hub_filename: Optional[str] = None,
 ) -> str:
-    cache_path = dest or ctx.cache_path
+    cache_path = dest or conversion.cache_path
     cache_name = path.join(cache_path, name)

     # add an extension if possible, some of the conversion code checks for it
@@ -201,7 +201,7 @@ def fetch_model(
     return source


-def convert_models(ctx: ConversionContext, args, models: Models):
+def convert_models(conversion: ConversionContext, args, models: Models):
     if args.sources and "sources" in models:
         for model in models.get("sources"):
             model = tuple_to_source(model)
@@ -214,7 +214,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
             source = model["source"]

             try:
-                dest = fetch_model(ctx, name, source, format=model_format)
+                dest = fetch_model(conversion, name, source, format=model_format)
                 logger.info("finished downloading source: %s -> %s", source, dest)
             except Exception:
                 logger.exception("error fetching source %s", name)
@@ -234,20 +234,20 @@ def convert_models(ctx: ConversionContext, args, models: Models):
             try:
                 if network_type == "inversion" and network_model == "concept":
                     dest = fetch_model(
-                        ctx,
+                        conversion,
                         name,
                         source,
-                        dest=path.join(ctx.model_path, network_type),
+                        dest=path.join(conversion.model_path, network_type),
                         format=network_format,
                         hf_hub_fetch=True,
                         hf_hub_filename="learned_embeds.bin",
                     )
                 else:
                     dest = fetch_model(
-                        ctx,
+                        conversion,
                         name,
                         source,
-                        dest=path.join(ctx.model_path, network_type),
+                        dest=path.join(conversion.model_path, network_type),
                         format=network_format,
                     )

@@ -267,19 +267,19 @@ def convert_models(ctx: ConversionContext, args, models: Models):

             try:
                 source = fetch_model(
-                    ctx, name, model["source"], format=model_format
+                    conversion, name, model["source"], format=model_format
                 )

                 converted = False
                 if model_format in model_formats_original:
                     converted, dest = convert_diffusion_original(
-                        ctx,
+                        conversion,
                         model,
                         source,
                     )
                 else:
                     converted, dest = convert_diffusion_diffusers(
-                        ctx,
+                        conversion,
                         model,
                         source,
                     )
@@ -289,8 +289,8 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                 # keep track of which models have been blended
                 blend_models = {}

-                inversion_dest = path.join(ctx.model_path, "inversion")
-                lora_dest = path.join(ctx.model_path, "lora")
+                inversion_dest = path.join(conversion.model_path, "inversion")
+                lora_dest = path.join(conversion.model_path, "lora")

                 for inversion in model.get("inversions", []):
                     if "text_encoder" not in blend_models:
@@ -314,7 +314,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                     inversion_source = inversion["source"]
                     inversion_format = inversion.get("format", None)
                     inversion_source = fetch_model(
-                        ctx,
+                        conversion,
                         inversion_name,
                         inversion_source,
                         dest=inversion_dest,
@@ -323,7 +323,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                     inversion_weight = inversion.get("weight", 1.0)

                     blend_textual_inversions(
-                        ctx,
+                        conversion,
                         blend_models["text_encoder"],
                         blend_models["tokenizer"],
                         [
@@ -355,7 +355,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                     lora_name = lora["name"]
                     lora_source = lora["source"]
                     lora_source = fetch_model(
-                        ctx,
+                        conversion,
                         f"{name}-lora-{lora_name}",
                         lora_source,
                         dest=lora_dest,
@@ -363,14 +363,14 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                     lora_weight = lora.get("weight", 1.0)

                     blend_loras(
-                        ctx,
+                        conversion,
                         blend_models["text_encoder"],
                         [(lora_source, lora_weight)],
                         "text_encoder",
                     )

                     blend_loras(
-                        ctx,
+                        conversion,
                         blend_models["unet"],
                         [(lora_source, lora_weight)],
                         "unet",
@@ -413,9 +413,9 @@ def convert_models(ctx: ConversionContext, args, models: Models):

             try:
                 source = fetch_model(
-                    ctx, name, model["source"], format=model_format
+                    conversion, name, model["source"], format=model_format
                 )
-                convert_upscale_resrgan(ctx, model, source)
+                convert_upscale_resrgan(conversion, model, source)
             except Exception:
                 logger.exception(
                     "error converting upscaling model %s",
@@ -433,9 +433,9 @@ def convert_models(ctx: ConversionContext, args, models: Models):
             model_format = source_format(model)
             try:
                 source = fetch_model(
-                    ctx, name, model["source"], format=model_format
+                    conversion, name, model["source"], format=model_format
                 )
-                convert_correction_gfpgan(ctx, model, source)
+                convert_correction_gfpgan(conversion, model, source)
             except Exception:
                 logger.exception(
                     "error converting correction model %s",
@@ -482,21 +482,21 @@ def main() -> int:
     args = parser.parse_args()
     logger.info("CLI arguments: %s", args)

-    ctx = ConversionContext.from_environ()
-    ctx.half = args.half or "onnx-fp16" in ctx.optimizations
-    ctx.opset = args.opset
-    ctx.token = args.token
-    logger.info("converting models in %s using %s", ctx.model_path, ctx.training_device)
+    server = ConversionContext.from_environ()
+    server.half = args.half or "onnx-fp16" in server.optimizations
+    server.opset = args.opset
+    server.token = args.token
+    logger.info("converting models in %s using %s", server.model_path, server.training_device)

-    if not path.exists(ctx.model_path):
-        logger.info("model path does not existing, creating: %s", ctx.model_path)
-        makedirs(ctx.model_path)
+    if not path.exists(server.model_path):
+        logger.info("model path does not existing, creating: %s", server.model_path)
+        makedirs(server.model_path)

     logger.info("converting base models")
-    convert_models(ctx, args, base_models)
+    convert_models(server, args, base_models)

     extras = []
-    extras.extend(ctx.extra_models)
+    extras.extend(server.extra_models)
     extras.extend(args.extras)
     extras = list(set(extras))
     extras.sort()
@@ -516,7 +516,7 @@ def main() -> int:
             try:
                 validate(data, extra_schema)
                 logger.info("converting extra models")
-                convert_models(ctx, args, data)
+                convert_models(server, args, data)
             except ValidationError:
                 logger.exception("invalid data in extras file")
             except Exception:

@@ -12,7 +12,7 @@ logger = getLogger(__name__)

 @torch.no_grad()
 def convert_correction_gfpgan(
-    ctx: ConversionContext,
+    conversion: ConversionContext,
     model: ModelDict,
     source: str,
 ):
@@ -20,7 +20,7 @@ def convert_correction_gfpgan(
     source = source or model.get("source")
     scale = model.get("scale")

-    dest = path.join(ctx.model_path, name + ".onnx")
+    dest = path.join(conversion.model_path, name + ".onnx")
     logger.info("converting GFPGAN model: %s -> %s", name, dest)

     if path.isfile(dest):
@@ -37,17 +37,17 @@ def convert_correction_gfpgan(
         scale=scale,
     )

-    torch_model = torch.load(source, map_location=ctx.map_location)
+    torch_model = torch.load(source, map_location=conversion.map_location)
     # TODO: make sure strict=False is safe here
     if "params_ema" in torch_model:
         model.load_state_dict(torch_model["params_ema"], strict=False)
     else:
         model.load_state_dict(torch_model["params"], strict=False)

-    model.to(ctx.training_device).train(False)
+    model.to(conversion.training_device).train(False)
     model.eval()

-    rng = torch.rand(1, 3, 64, 64, device=ctx.map_location)
+    rng = torch.rand(1, 3, 64, 64, device=conversion.map_location)
     input_names = ["data"]
     output_names = ["output"]
     dynamic_axes = {
@@ -63,7 +63,7 @@ def convert_correction_gfpgan(
         input_names=input_names,
         output_names=output_names,
         dynamic_axes=dynamic_axes,
-        opset_version=ctx.opset,
+        opset_version=conversion.opset,
         export_params=True,
     )
     logger.info("GFPGAN exported to ONNX successfully")

@@ -90,7 +90,7 @@ def onnx_export(

 @torch.no_grad()
 def convert_diffusion_diffusers(
-    ctx: ConversionContext,
+    conversion: ConversionContext,
     model: Dict,
     source: str,
 ) -> Tuple[bool, str]:
@@ -102,10 +102,10 @@ def convert_diffusion_diffusers(
     single_vae = model.get("single_vae")
     replace_vae = model.get("vae")

-    dtype = ctx.torch_dtype()
+    dtype = conversion.torch_dtype()
     logger.debug("using Torch dtype %s for pipeline", dtype)

-    dest_path = path.join(ctx.model_path, name)
+    dest_path = path.join(conversion.model_path, name)
     model_index = path.join(dest_path, "model_index.json")

     # diffusers go into a directory rather than .onnx file
@@ -123,11 +123,11 @@ def convert_diffusion_diffusers(
     pipeline = StableDiffusionPipeline.from_pretrained(
         source,
         torch_dtype=dtype,
-        use_auth_token=ctx.token,
-    ).to(ctx.training_device)
+        use_auth_token=conversion.token,
+    ).to(conversion.training_device)
     output_path = Path(dest_path)

-    optimize_pipeline(ctx, pipeline)
+    optimize_pipeline(conversion, pipeline)

     # TEXT ENCODER
     num_tokens = pipeline.text_encoder.config.max_position_embeddings
@@ -143,11 +143,11 @@ def convert_diffusion_diffusers(
         pipeline.text_encoder,
         # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
         model_args=(
-            text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32),
+            text_input.input_ids.to(device=conversion.training_device, dtype=torch.int32),
             None,  # attention mask
             None,  # position ids
             None,  # output attentions
-            torch.tensor(True).to(device=ctx.training_device, dtype=torch.bool),
+            torch.tensor(True).to(device=conversion.training_device, dtype=torch.bool),
         ),
         output_path=output_path / "text_encoder" / ONNX_MODEL,
         ordered_input_names=["input_ids"],
@@ -155,8 +155,8 @@ def convert_diffusion_diffusers(
         dynamic_axes={
             "input_ids": {0: "batch", 1: "sequence"},
         },
-        opset=ctx.opset,
-        half=ctx.half,
+        opset=conversion.opset,
+        half=conversion.half,
     )
     del pipeline.text_encoder

@@ -165,11 +165,11 @@ def convert_diffusion_diffusers(
     # UNET
     if single_vae:
         unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
-        unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.long)
+        unet_scale = torch.tensor(4).to(device=conversion.training_device, dtype=torch.long)
     else:
         unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
         unet_scale = torch.tensor(False).to(
-            device=ctx.training_device, dtype=torch.bool
+            device=conversion.training_device, dtype=torch.bool
         )

     if is_torch_2_0:
@@ -182,11 +182,11 @@ def convert_diffusion_diffusers(
         pipeline.unet,
         model_args=(
             torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
-                device=ctx.training_device, dtype=dtype
+                device=conversion.training_device, dtype=dtype
             ),
-            torch.randn(2).to(device=ctx.training_device, dtype=dtype),
+            torch.randn(2).to(device=conversion.training_device, dtype=dtype),
             torch.randn(2, num_tokens, text_hidden_size).to(
-                device=ctx.training_device, dtype=dtype
+                device=conversion.training_device, dtype=dtype
             ),
             unet_scale,
         ),
@@ -199,8 +199,8 @@ def convert_diffusion_diffusers(
             "timestep": {0: "batch"},
             "encoder_hidden_states": {0: "batch", 1: "sequence"},
         },
-        opset=ctx.opset,
-        half=ctx.half,
+        opset=conversion.opset,
+        half=conversion.half,
         external_data=True,
     )
     unet_model_path = str(unet_path.absolute().as_posix())
@@ -238,7 +238,7 @@ def convert_diffusion_diffusers(
            model_args=(
                torch.randn(
                    1, vae_latent_channels, unet_sample_size, unet_sample_size
-                ).to(device=ctx.training_device, dtype=dtype),
+                ).to(device=conversion.training_device, dtype=dtype),
                False,
            ),
            output_path=output_path / "vae" / ONNX_MODEL,
@@ -247,8 +247,8 @@ def convert_diffusion_diffusers(
            dynamic_axes={
                "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            },
-            opset=ctx.opset,
-            half=ctx.half,
+            opset=conversion.opset,
+            half=conversion.half,
        )
    else:
        # VAE ENCODER
@@ -263,7 +263,7 @@ def convert_diffusion_diffusers(
            vae_encoder,
            model_args=(
                torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
-                    device=ctx.training_device, dtype=dtype
+                    device=conversion.training_device, dtype=dtype
                ),
                False,
            ),
@@ -273,7 +273,7 @@ def convert_diffusion_diffusers(
            dynamic_axes={
                "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            },
-            opset=ctx.opset,
+            opset=conversion.opset,
            half=False,  # https://github.com/ssube/onnx-web/issues/290
        )

@@ -287,7 +287,7 @@ def convert_diffusion_diffusers(
            model_args=(
                torch.randn(
                    1, vae_latent_channels, unet_sample_size, unet_sample_size
-                ).to(device=ctx.training_device, dtype=dtype),
+                ).to(device=conversion.training_device, dtype=dtype),
                False,
            ),
            output_path=output_path / "vae_decoder" / ONNX_MODEL,
@@ -296,8 +296,8 @@ def convert_diffusion_diffusers(
            dynamic_axes={
                "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            },
-            opset=ctx.opset,
-            half=ctx.half,
+            opset=conversion.opset,
+            half=conversion.half,
        )

    del pipeline.vae

@@ -55,7 +55,7 @@ def fix_node_name(key: str):


 def blend_loras(
-    _context: ServerContext,
+    _conversion: ServerContext,
     base_name: Union[str, ModelProto],
     loras: List[Tuple[str, float]],
     model_type: Literal["text_encoder", "unet"],

@@ -146,7 +146,7 @@ class TrainingConfig:

     def __init__(
         self,
-        ctx: ConversionContext,
+        conversion: ConversionContext,
         model_name: str = "",
         scheduler: str = "ddim",
         v2: bool = False,
@@ -155,7 +155,7 @@ class TrainingConfig:
         **kwargs,
     ):
         model_name = sanitize_name(model_name)
-        model_dir = os.path.join(ctx.cache_path, model_name)
+        model_dir = os.path.join(conversion.cache_path, model_name)
         working_dir = os.path.join(model_dir, "working")

         if not os.path.exists(working_dir):
@@ -1298,7 +1298,7 @@ def download_model(db_config: TrainingConfig, token):


 def get_config_path(
-    context: ConversionContext,
+    conversion: ConversionContext,
     model_version: str = "v1",
     train_type: str = "default",
     config_base_name: str = "training",
@@ -1309,7 +1309,7 @@ def get_config_path(
     )

     parts = os.path.join(
-        context.model_path,
+        conversion.model_path,
         "configs",
         f"{model_version}-{config_base_name}-{train_type}.yaml",
     )
@@ -1317,7 +1317,7 @@ def get_config_path(


 def get_config_file(
-    context: ConversionContext,
+    conversion: ConversionContext,
     train_unfrozen=False,
     v2=False,
     prediction_type="epsilon",
@@ -1343,7 +1343,7 @@ def get_config_file(
         model_train_type = train_types["default"]

     return get_config_path(
-        context,
+        conversion,
         model_version_name,
         model_train_type,
         config_base_name,
@@ -1352,7 +1352,7 @@ def get_config_file(


 def extract_checkpoint(
-    context: ConversionContext,
+    conversion: ConversionContext,
     new_model_name: str,
     checkpoint_file: str,
     scheduler_type="ddim",
@@ -1396,7 +1396,7 @@ def extract_checkpoint(

     # Create empty config
     db_config = TrainingConfig(
-        context,
+        conversion,
         model_name=new_model_name,
         scheduler=scheduler_type,
         src=checkpoint_file,
@@ -1442,7 +1442,7 @@ def extract_checkpoint(
         prediction_type = "epsilon"

     original_config_file = get_config_file(
-        context, train_unfrozen, v2, prediction_type, config_file=config_file
+        conversion, train_unfrozen, v2, prediction_type, config_file=config_file
     )

     logger.info(
@@ -1533,7 +1533,7 @@ def extract_checkpoint(
                 checkpoint, vae_config
             )
         else:
-            vae_file = os.path.join(context.model_path, vae_file)
+            vae_file = os.path.join(conversion.model_path, vae_file)
             logger.debug("loading custom VAE: %s", vae_file)
             vae_checkpoint = load_tensor(vae_file, map_location=map_location)
             converted_vae_checkpoint = convert_ldm_vae_checkpoint(
@@ -1658,14 +1658,14 @@ def extract_checkpoint(

 @torch.no_grad()
 def convert_diffusion_original(
-    ctx: ConversionContext,
+    conversion: ConversionContext,
     model: ModelDict,
     source: str,
 ) -> Tuple[bool, str]:
     name = model["name"]
     source = source or model["source"]

-    dest_path = os.path.join(ctx.model_path, name)
+    dest_path = os.path.join(conversion.model_path, name)
     dest_index = os.path.join(dest_path, "model_index.json")
     logger.info(
         "converting original Diffusers checkpoint %s: %s -> %s", name, source, dest_path
@@ -1676,8 +1676,8 @@ def convert_diffusion_original(
         return (False, dest_path)

     torch_name = name + "-torch"
-    torch_path = os.path.join(ctx.cache_path, torch_name)
-    working_name = os.path.join(ctx.cache_path, torch_name, "working")
+    torch_path = os.path.join(conversion.cache_path, torch_name)
+    working_name = os.path.join(conversion.cache_path, torch_name, "working")
     model_index = os.path.join(working_name, "model_index.json")

     if os.path.exists(torch_path) and os.path.exists(model_index):
@@ -1689,7 +1689,7 @@ def convert_diffusion_original(
             torch_path,
         )
         if extract_checkpoint(
-            ctx,
+            conversion,
             torch_name,
             source,
             config_file=model.get("config"),
@@ -1704,9 +1704,9 @@ def convert_diffusion_original(
     if "vae" in model:
         del model["vae"]

-    result = convert_diffusion_diffusers(ctx, model, working_name)
+    result = convert_diffusion_diffusers(conversion, model, working_name)

-    if "torch" in ctx.prune:
+    if "torch" in conversion.prune:
         logger.info("removing intermediate Torch models: %s", torch_path)
         shutil.rmtree(torch_path)

@@ -16,7 +16,7 @@ logger = getLogger(__name__)

 @torch.no_grad()
 def blend_textual_inversions(
-    context: ServerContext,
+    server: ServerContext,
     text_encoder: ModelProto,
     tokenizer: CLIPTokenizer,
     inversions: List[Tuple[str, float, Optional[str], Optional[str]]],
@@ -161,7 +161,7 @@ def blend_textual_inversions(

 @torch.no_grad()
 def convert_diffusion_textual_inversion(
-    context: ConversionContext,
+    conversion: ConversionContext,
     name: str,
     base_model: str,
     inversion: str,
@@ -169,7 +169,7 @@ def convert_diffusion_textual_inversion(
     base_token: Optional[str] = None,
     inversion_weight: Optional[float] = 1.0,
 ):
-    dest_path = path.join(context.model_path, f"inversion-{name}")
+    dest_path = path.join(conversion.model_path, f"inversion-{name}")
     logger.info(
         "converting Textual Inversion: %s + %s -> %s", base_model, inversion, dest_path
     )
@@ -194,7 +194,7 @@ def convert_diffusion_textual_inversion(
         subfolder="tokenizer",
     )
     text_encoder, tokenizer = blend_textual_inversions(
-        context,
+        conversion,
         text_encoder,
         tokenizer,
         [(inversion, inversion_weight, base_token, inversion_format)],

@ -13,7 +13,7 @@ TAG_X4_V3 = "real-esrgan-x4-v3"
|
|||
|
||||
@torch.no_grad()
|
||||
def convert_upscale_resrgan(
|
||||
ctx: ConversionContext,
|
||||
conversion: ConversionContext,
|
||||
model: ModelDict,
|
||||
source: str,
|
||||
):
|
||||
|
@ -24,7 +24,7 @@ def convert_upscale_resrgan(
|
|||
source = source or model.get("source")
|
||||
scale = model.get("scale")
|
||||
|
||||
dest = path.join(ctx.model_path, name + ".onnx")
|
||||
dest = path.join(conversion.model_path, name + ".onnx")
|
||||
logger.info("converting Real ESRGAN model: %s -> %s", name, dest)
|
||||
|
||||
if path.isfile(dest):
|
||||
|
@ -53,16 +53,16 @@ def convert_upscale_resrgan(
|
|||
scale=scale,
|
||||
)
|
||||
|
||||
torch_model = torch.load(source, map_location=ctx.map_location)
|
||||
torch_model = torch.load(source, map_location=conversion.map_location)
|
||||
if "params_ema" in torch_model:
|
||||
model.load_state_dict(torch_model["params_ema"])
|
||||
else:
|
||||
model.load_state_dict(torch_model["params"], strict=False)
|
||||
|
||||
model.to(ctx.training_device).train(False)
|
||||
model.to(conversion.training_device).train(False)
|
||||
model.eval()
|
||||
|
||||
rng = torch.rand(1, 3, 64, 64, device=ctx.map_location)
|
||||
rng = torch.rand(1, 3, 64, 64, device=conversion.map_location)
|
||||
input_names = ["data"]
|
||||
output_names = ["output"]
|
||||
dynamic_axes = {
|
||||
|
@ -78,7 +78,7 @@ def convert_upscale_resrgan(
|
|||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=ctx.opset,
|
||||
opset_version=conversion.opset,
|
||||
export_params=True,
|
||||
)
|
||||
logger.info("real ESRGAN exported to ONNX successfully")
|
||||
|
|
|
@@ -64,7 +64,7 @@ def json_params(


 def make_output_name(
-    ctx: ServerContext,
+    server: ServerContext,
     mode: str,
     params: ImageParams,
     size: Size,
@@ -92,20 +92,20 @@ def make_output_name(
         hash_value(sha, param)

     return [
-        f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{ctx.image_format}"
+        f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}"
         for i in range(params.batch)
     ]


-def save_image(ctx: ServerContext, output: str, image: Image.Image) -> str:
-    path = base_join(ctx.output_path, output)
-    image.save(path, format=ctx.image_format)
+def save_image(server: ServerContext, output: str, image: Image.Image) -> str:
+    path = base_join(server.output_path, output)
+    image.save(path, format=server.image_format)
     logger.debug("saved output image to: %s", path)
     return path


 def save_params(
-    ctx: ServerContext,
+    server: ServerContext,
     output: str,
     params: ImageParams,
     size: Size,
@@ -113,7 +113,7 @@ def save_params(
     border: Optional[Border] = None,
     highres: Optional[HighresParams] = None,
 ) -> str:
-    path = base_join(ctx.output_path, f"{output}.json")
+    path = base_join(server.output_path, f"{output}.json")
     json = json_params(
         output, params, size, upscale=upscale, border=border, highres=highres
     )

@@ -7,6 +7,10 @@ from .torch_before_ort import GraphOptimizationLevel, SessionOptions
 logger = getLogger(__name__)


+Param = Union[str, int, float]
+Point = Tuple[int, int]
+
+
 class SizeChart(IntEnum):
     mini = 128  # small tile for very expensive models
     half = 256  # half tile for outpainting
@@ -25,10 +29,6 @@ class TileOrder:
     spiral = "spiral"


-Param = Union[str, int, float]
-Point = Tuple[int, int]
-
-
 class Border:
     def __init__(self, left: int, right: int, top: int, bottom: int) -> None:
         self.left = left
@@ -37,7 +37,7 @@ class Border:
         self.bottom = bottom

     def __str__(self) -> str:
-        return "%s %s %s %s" % (self.left, self.top, self.right, self.bottom)
+        return "(%s, %s, %s, %s)" % (self.left, self.top, self.right, self.bottom)

     def tojson(self):
         return {
@@ -145,6 +145,7 @@ class DeviceParams:
         return sess

     def torch_str(self) -> str:
+        # TODO: return cuda devices for ROCm as well
         if self.device.startswith("cuda"):
             return self.device
         else:

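Note on the `Border.__str__` change above: the output switches from space-separated ("10 10 10 10") to tuple-style. A minimal check, mirroring `BorderTests.test_str` in the new test module at the end of this commit:

    from onnx_web.params import Border

    assert str(Border.even(10)) == "(10, 10, 10, 10)"
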
@ -93,7 +93,7 @@ def url_from_rule(rule) -> str:
|
|||
return url_for(rule.endpoint, **options)
|
||||
|
||||
|
||||
def introspect(context: ServerContext, app: Flask):
|
||||
def introspect(server: ServerContext, app: Flask):
|
||||
return {
|
||||
"name": "onnx-web",
|
||||
"routes": [
|
||||
|
@ -103,15 +103,15 @@ def introspect(context: ServerContext, app: Flask):
|
|||
}
|
||||
|
||||
|
||||
def list_extra_strings(context: ServerContext):
|
||||
def list_extra_strings(server: ServerContext):
|
||||
return jsonify(get_extra_strings())
|
||||
|
||||
|
||||
def list_mask_filters(context: ServerContext):
|
||||
def list_mask_filters(server: ServerContext):
|
||||
return jsonify(list(get_mask_filters().keys()))
|
||||
|
||||
|
||||
def list_models(context: ServerContext):
|
||||
def list_models(server: ServerContext):
|
||||
return jsonify(
|
||||
{
|
||||
"correction": get_correction_models(),
|
||||
|
@ -122,30 +122,30 @@ def list_models(context: ServerContext):
|
|||
)
|
||||
|
||||
|
||||
def list_noise_sources(context: ServerContext):
|
||||
def list_noise_sources(server: ServerContext):
|
||||
return jsonify(list(get_noise_sources().keys()))
|
||||
|
||||
|
||||
def list_params(context: ServerContext):
|
||||
def list_params(server: ServerContext):
|
||||
return jsonify(get_config_params())
|
||||
|
||||
|
||||
def list_platforms(context: ServerContext):
|
||||
def list_platforms(server: ServerContext):
|
||||
return jsonify([p.device for p in get_available_platforms()])
|
||||
|
||||
|
||||
def list_schedulers(context: ServerContext):
|
||||
def list_schedulers(server: ServerContext):
|
||||
return jsonify(list(get_pipeline_schedulers().keys()))
|
||||
|
||||
|
||||
def img2img(context: ServerContext, pool: DevicePoolExecutor):
|
||||
def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||
source_file = request.files.get("source")
|
||||
if source_file is None:
|
||||
return error_reply("source image is required")
|
||||
|
||||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||
|
||||
device, params, size = pipeline_from_request(context)
|
||||
device, params, size = pipeline_from_request(server)
|
||||
upscale = upscale_from_request()
|
||||
|
||||
strength = get_and_clamp_float(
|
||||
|
@ -156,7 +156,7 @@ def img2img(context: ServerContext, pool: DevicePoolExecutor):
|
|||
get_config_value("strength", "min"),
|
||||
)
|
||||
|
||||
output = make_output_name(context, "img2img", params, size, extras=[strength])
|
||||
output = make_output_name(server, "img2img", params, size, extras=[strength])
|
||||
job_name = output[0]
|
||||
logger.info("img2img job queued for: %s", job_name)
|
||||
|
||||
|
@ -164,7 +164,7 @@ def img2img(context: ServerContext, pool: DevicePoolExecutor):
|
|||
pool.submit(
|
||||
job_name,
|
||||
run_img2img_pipeline,
|
||||
context,
|
||||
server,
|
||||
params,
|
||||
output,
|
||||
upscale,
|
||||
|
@ -176,19 +176,19 @@ def img2img(context: ServerContext, pool: DevicePoolExecutor):
|
|||
return jsonify(json_params(output, params, size, upscale=upscale))
|
||||
|
||||
|
||||
def txt2img(context: ServerContext, pool: DevicePoolExecutor):
|
||||
device, params, size = pipeline_from_request(context)
|
||||
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||
device, params, size = pipeline_from_request(server)
|
||||
upscale = upscale_from_request()
|
||||
highres = highres_from_request()
|
||||
|
||||
output = make_output_name(context, "txt2img", params, size)
|
||||
output = make_output_name(server, "txt2img", params, size)
|
||||
job_name = output[0]
|
||||
logger.info("txt2img job queued for: %s", job_name)
|
||||
|
||||
pool.submit(
|
||||
job_name,
|
||||
run_txt2img_pipeline,
|
||||
context,
|
||||
server,
|
||||
params,
|
||||
size,
|
||||
output,
|
||||
|
@ -200,7 +200,7 @@ def txt2img(context: ServerContext, pool: DevicePoolExecutor):
|
|||
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
|
||||
|
||||
|
||||
def inpaint(context: ServerContext, pool: DevicePoolExecutor):
|
||||
def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
||||
source_file = request.files.get("source")
|
||||
if source_file is None:
|
||||
return error_reply("source image is required")
|
||||
|
@ -212,7 +212,7 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):
|
|||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
|
||||
|
||||
device, params, size = pipeline_from_request(context)
|
||||
device, params, size = pipeline_from_request(server)
|
||||
expand = border_from_request()
|
||||
upscale = upscale_from_request()
|
||||
|
||||
|
@ -224,7 +224,7 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):
|
|||
)
|
||||
|
||||
output = make_output_name(
|
||||
context,
|
||||
server,
|
||||
"inpaint",
|
||||
params,
|
||||
size,
|
||||
|
@ -247,7 +247,7 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):
|
|||
pool.submit(
|
||||
job_name,
|
||||
run_inpaint_pipeline,
|
||||
context,
|
||||
server,
|
||||
params,
|
||||
size,
|
||||
output,
|
||||
|
@ -265,17 +265,17 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):
|
|||
return jsonify(json_params(output, params, size, upscale=upscale, border=expand))
|
||||
|
||||
|
||||
def upscale(context: ServerContext, pool: DevicePoolExecutor):
|
||||
def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
||||
source_file = request.files.get("source")
|
||||
if source_file is None:
|
||||
return error_reply("source image is required")
|
||||
|
||||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||
|
||||
device, params, size = pipeline_from_request(context)
|
||||
device, params, size = pipeline_from_request(server)
|
||||
upscale = upscale_from_request()
|
||||
|
||||
output = make_output_name(context, "upscale", params, size)
|
||||
output = make_output_name(server, "upscale", params, size)
|
||||
job_name = output[0]
|
||||
logger.info("upscale job queued for: %s", job_name)
|
||||
|
||||
|
@ -283,7 +283,7 @@ def upscale(context: ServerContext, pool: DevicePoolExecutor):
|
|||
pool.submit(
|
||||
job_name,
|
||||
run_upscale_pipeline,
|
||||
context,
|
||||
server,
|
||||
params,
|
||||
size,
|
||||
output,
|
||||
|
@ -295,7 +295,7 @@ def upscale(context: ServerContext, pool: DevicePoolExecutor):
|
|||
return jsonify(json_params(output, params, size, upscale=upscale))
|
||||
|
||||
|
||||
def chain(context: ServerContext, pool: DevicePoolExecutor):
|
||||
def chain(server: ServerContext, pool: DevicePoolExecutor):
|
||||
logger.debug(
|
||||
"chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
|
||||
)
|
||||
|
@ -311,8 +311,8 @@ def chain(context: ServerContext, pool: DevicePoolExecutor):
|
|||
validate(data, schema)
|
||||
|
||||
# get defaults from the regular parameters
|
||||
device, params, size = pipeline_from_request(context)
|
||||
output = make_output_name(context, "chain", params, size)
|
||||
device, params, size = pipeline_from_request(server)
|
||||
output = make_output_name(server, "chain", params, size)
|
||||
job_name = output[0]
|
||||
|
||||
pipeline = ChainPipeline()
|
||||
|
@ -371,7 +371,7 @@ def chain(context: ServerContext, pool: DevicePoolExecutor):
|
|||
pool.submit(
|
||||
job_name,
|
||||
pipeline,
|
||||
context,
|
||||
server,
|
||||
params,
|
||||
empty_source,
|
||||
output=output[0],
|
||||
|
@ -382,7 +382,7 @@ def chain(context: ServerContext, pool: DevicePoolExecutor):
|
|||
return jsonify(json_params(output, params, size))
|
||||
|
||||
|
||||
def blend(context: ServerContext, pool: DevicePoolExecutor):
|
||||
def blend(server: ServerContext, pool: DevicePoolExecutor):
|
||||
mask_file = request.files.get("mask")
|
||||
if mask_file is None:
|
||||
return error_reply("mask image is required")
|
||||
|
@ -402,17 +402,17 @@ def blend(context: ServerContext, pool: DevicePoolExecutor):
|
|||
source = valid_image(source, mask.size, mask.size)
|
||||
sources.append(source)
|
||||
|
||||
device, params, size = pipeline_from_request(context)
|
||||
device, params, size = pipeline_from_request(server)
|
||||
upscale = upscale_from_request()
|
||||
|
||||
output = make_output_name(context, "upscale", params, size)
|
||||
output = make_output_name(server, "upscale", params, size)
|
||||
job_name = output[0]
|
||||
logger.info("upscale job queued for: %s", job_name)
|
||||
|
||||
pool.submit(
|
||||
job_name,
|
||||
run_blend_pipeline,
|
||||
context,
|
||||
server,
|
||||
params,
|
||||
size,
|
||||
output,
|
||||
|
@ -425,17 +425,17 @@ def blend(context: ServerContext, pool: DevicePoolExecutor):
|
|||
return jsonify(json_params(output, params, size, upscale=upscale))
|
||||
|
||||
|
||||
def txt2txt(context: ServerContext, pool: DevicePoolExecutor):
|
||||
device, params, size = pipeline_from_request(context)
|
||||
def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
|
||||
device, params, size = pipeline_from_request(server)
|
||||
|
||||
output = make_output_name(context, "txt2txt", params, size)
|
||||
output = make_output_name(server, "txt2txt", params, size)
|
||||
job_name = output[0]
|
||||
logger.info("upscale job queued for: %s", job_name)
|
||||
|
||||
pool.submit(
|
||||
job_name,
|
||||
run_txt2txt_pipeline,
|
||||
context,
|
||||
server,
|
||||
params,
|
||||
size,
|
||||
output,
|
||||
|
@ -445,7 +445,7 @@ def txt2txt(context: ServerContext, pool: DevicePoolExecutor):
|
|||
return jsonify(json_params(output, params, size))
|
||||
|
||||
|
||||
def cancel(context: ServerContext, pool: DevicePoolExecutor):
|
||||
def cancel(server: ServerContext, pool: DevicePoolExecutor):
|
||||
output_file = request.args.get("output", None)
|
||||
if output_file is None:
|
||||
return error_reply("output name is required")
|
||||
|
@ -456,7 +456,7 @@ def cancel(context: ServerContext, pool: DevicePoolExecutor):
|
|||
return ready_reply(cancelled=cancelled)
|
||||
|
||||
|
||||
def ready(context: ServerContext, pool: DevicePoolExecutor):
|
||||
def ready(server: ServerContext, pool: DevicePoolExecutor):
|
||||
output_file = request.args.get("output", None)
|
||||
if output_file is None:
|
||||
return error_reply("output name is required")
|
||||
|
@ -468,7 +468,7 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
|
|||
return ready_reply(pending=True)
|
||||
|
||||
if progress is None:
|
||||
output = base_join(context.output_path, output_file)
|
||||
output = base_join(server.output_path, output_file)
|
||||
if path.exists(output):
|
||||
return ready_reply(ready=True)
|
||||
else:
|
||||
|
@ -485,44 +485,44 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
|
|||
)
|
||||
|
||||
|
||||
def status(context: ServerContext, pool: DevicePoolExecutor):
|
||||
def status(server: ServerContext, pool: DevicePoolExecutor):
|
||||
return jsonify(pool.status())
|
||||
|
||||
|
||||
def register_api_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor):
|
||||
def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
|
||||
return [
|
||||
app.route("/api")(wrap_route(introspect, context, app=app)),
|
||||
app.route("/api/settings/masks")(wrap_route(list_mask_filters, context)),
|
||||
app.route("/api/settings/models")(wrap_route(list_models, context)),
|
||||
app.route("/api/settings/noises")(wrap_route(list_noise_sources, context)),
|
||||
app.route("/api/settings/params")(wrap_route(list_params, context)),
|
||||
app.route("/api/settings/platforms")(wrap_route(list_platforms, context)),
|
||||
app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)),
|
||||
app.route("/api/settings/strings")(wrap_route(list_extra_strings, context)),
|
||||
app.route("/api")(wrap_route(introspect, server, app=app)),
|
||||
app.route("/api/settings/masks")(wrap_route(list_mask_filters, server)),
|
||||
app.route("/api/settings/models")(wrap_route(list_models, server)),
|
||||
app.route("/api/settings/noises")(wrap_route(list_noise_sources, server)),
|
||||
app.route("/api/settings/params")(wrap_route(list_params, server)),
|
||||
app.route("/api/settings/platforms")(wrap_route(list_platforms, server)),
|
||||
app.route("/api/settings/schedulers")(wrap_route(list_schedulers, server)),
|
||||
app.route("/api/settings/strings")(wrap_route(list_extra_strings, server)),
|
||||
app.route("/api/img2img", methods=["POST"])(
|
||||
wrap_route(img2img, context, pool=pool)
|
||||
wrap_route(img2img, server, pool=pool)
|
||||
),
|
||||
app.route("/api/txt2img", methods=["POST"])(
|
||||
wrap_route(txt2img, context, pool=pool)
|
||||
wrap_route(txt2img, server, pool=pool)
|
||||
),
|
||||
app.route("/api/txt2txt", methods=["POST"])(
|
||||
wrap_route(txt2txt, context, pool=pool)
|
||||
wrap_route(txt2txt, server, pool=pool)
|
||||
),
|
||||
app.route("/api/inpaint", methods=["POST"])(
|
||||
wrap_route(inpaint, context, pool=pool)
|
||||
wrap_route(inpaint, server, pool=pool)
|
||||
),
|
||||
app.route("/api/upscale", methods=["POST"])(
|
||||
wrap_route(upscale, context, pool=pool)
|
||||
wrap_route(upscale, server, pool=pool)
|
||||
),
|
||||
app.route("/api/chain", methods=["POST"])(
|
||||
wrap_route(chain, context, pool=pool)
|
||||
wrap_route(chain, server, pool=pool)
|
||||
),
|
||||
app.route("/api/blend", methods=["POST"])(
|
||||
wrap_route(blend, context, pool=pool)
|
||||
wrap_route(blend, server, pool=pool)
|
||||
),
|
||||
app.route("/api/cancel", methods=["PUT"])(
|
||||
wrap_route(cancel, context, pool=pool)
|
||||
wrap_route(cancel, server, pool=pool)
|
||||
),
|
||||
app.route("/api/ready")(wrap_route(ready, context, pool=pool)),
|
||||
app.route("/api/status")(wrap_route(status, context, pool=pool)),
|
||||
app.route("/api/ready")(wrap_route(ready, server, pool=pool)),
|
||||
app.route("/api/status")(wrap_route(status, server, pool=pool)),
|
||||
]
|
||||
|
|
|
@@ -118,13 +118,13 @@ def patch_not_impl():
     raise NotImplementedError()


-def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str:
+def patch_cache_path(server: ServerContext, url: str, **kwargs) -> str:
     cache_path = cache_path_map.get(url, None)
     if cache_path is None:
         parsed = urlparse(url)
         cache_path = path.basename(parsed.path)

-    cache_path = path.join(ctx.cache_path, cache_path)
+    cache_path = path.join(server.cache_path, cache_path)
     logger.debug("patching download path: %s -> %s", url, cache_path)

     if path.exists(cache_path):
@@ -133,27 +133,27 @@ def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str:
         raise FileNotFoundError("missing cache file: %s" % (cache_path))


-def apply_patch_basicsr(ctx: ServerContext):
+def apply_patch_basicsr(server: ServerContext):
     logger.debug("patching BasicSR module")
     basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
-    basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, ctx)
+    basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, server)


-def apply_patch_codeformer(ctx: ServerContext):
+def apply_patch_codeformer(server: ServerContext):
     logger.debug("patching CodeFormer module")
     codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
-    codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, ctx)
+    codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, server)


-def apply_patch_facexlib(ctx: ServerContext):
+def apply_patch_facexlib(server: ServerContext):
     logger.debug("patching Facexlib module")
-    facexlib.utils.load_file_from_url = partial(patch_cache_path, ctx)
+    facexlib.utils.load_file_from_url = partial(patch_cache_path, server)


-def apply_patches(ctx: ServerContext):
-    apply_patch_basicsr(ctx)
-    apply_patch_codeformer(ctx)
-    apply_patch_facexlib(ctx)
+def apply_patches(server: ServerContext):
+    apply_patch_basicsr(server)
+    apply_patch_codeformer(server)
+    apply_patch_facexlib(server)
     unload(
         [
             "basicsr.utils.download_util",

@ -115,7 +115,7 @@ def get_config_value(key: str, subkey: str = "default", default=None):
|
|||
return config_params.get(key, {}).get(subkey, default)
|
||||
|
||||
|
||||
def load_extras(context: ServerContext):
|
||||
def load_extras(server: ServerContext):
|
||||
"""
|
||||
Load the extras file(s) and collect the relevant parts for the server: labels and strings
|
||||
"""
|
||||
|
@ -127,7 +127,7 @@ def load_extras(context: ServerContext):
|
|||
with open("./schemas/extras.yaml", "r") as f:
|
||||
extra_schema = safe_load(f.read())
|
||||
|
||||
for file in context.extra_models:
|
||||
for file in server.extra_models:
|
||||
if file is not None and file != "":
|
||||
logger.info("loading extra models from %s", file)
|
||||
try:
|
||||
|
@ -209,11 +209,11 @@ IGNORE_EXTENSIONS = [".crdownload", ".lock", ".tmp"]
|
|||
|
||||
|
||||
def list_model_globs(
|
||||
context: ServerContext, globs: List[str], base_path: Optional[str] = None
|
||||
server: ServerContext, globs: List[str], base_path: Optional[str] = None
|
||||
) -> List[str]:
|
||||
models = []
|
||||
for pattern in globs:
|
||||
pattern_path = path.join(base_path or context.model_path, pattern)
|
||||
pattern_path = path.join(base_path or server.model_path, pattern)
|
||||
logger.debug("loading models from %s", pattern_path)
|
||||
for name in glob(pattern_path):
|
||||
base = path.basename(name)
|
||||
|
@ -226,7 +226,7 @@ def list_model_globs(
|
|||
return unique_models
|
||||
|
||||
|
||||
def load_models(context: ServerContext) -> None:
|
||||
def load_models(server: ServerContext) -> None:
|
||||
global correction_models
|
||||
global diffusion_models
|
||||
global network_models
|
||||
|
@ -234,7 +234,7 @@ def load_models(context: ServerContext) -> None:
|
|||
|
||||
# main categories
|
||||
diffusion_models = list_model_globs(
|
||||
context,
|
||||
server,
|
||||
[
|
||||
"diffusion-*",
|
||||
"stable-diffusion-*",
|
||||
|
@ -243,7 +243,7 @@ def load_models(context: ServerContext) -> None:
|
|||
logger.debug("loaded diffusion models from disk: %s", diffusion_models)
|
||||
|
||||
correction_models = list_model_globs(
|
||||
context,
|
||||
server,
|
||||
[
|
||||
"correction-*",
|
||||
],
|
||||
|
@ -251,7 +251,7 @@ def load_models(context: ServerContext) -> None:
|
|||
logger.debug("loaded correction models from disk: %s", correction_models)
|
||||
|
||||
upscaling_models = list_model_globs(
|
||||
context,
|
||||
server,
|
||||
[
|
||||
"upscaling-*",
|
||||
],
|
||||
|
@ -260,11 +260,11 @@ def load_models(context: ServerContext) -> None:
|
|||
|
||||
# additional networks
|
||||
inversion_models = list_model_globs(
|
||||
context,
|
||||
server,
|
||||
[
|
||||
"*",
|
||||
],
|
||||
base_path=path.join(context.model_path, "inversion"),
|
||||
base_path=path.join(server.model_path, "inversion"),
|
||||
)
|
||||
logger.debug("loaded Textual Inversion models from disk: %s", inversion_models)
|
||||
network_models.extend(
|
||||
|
@ -272,35 +272,35 @@ def load_models(context: ServerContext) -> None:
|
|||
)
|
||||
|
||||
lora_models = list_model_globs(
|
||||
context,
|
||||
server,
|
||||
[
|
||||
"*",
|
||||
],
|
||||
base_path=path.join(context.model_path, "lora"),
|
||||
base_path=path.join(server.model_path, "lora"),
|
||||
)
|
||||
logger.debug("loaded LoRA models from disk: %s", lora_models)
|
||||
network_models.extend([NetworkModel(model, "lora") for model in lora_models])
|
||||
|
||||
|
||||
def load_params(context: ServerContext) -> None:
|
||||
def load_params(server: ServerContext) -> None:
|
||||
global config_params
|
||||
|
||||
params_file = path.join(context.params_path, "params.json")
|
||||
params_file = path.join(server.params_path, "params.json")
|
||||
logger.debug("loading server parameters from file: %s", params_file)
|
||||
|
||||
with open(params_file, "r") as f:
|
||||
config_params = yaml.safe_load(f)
|
||||
|
||||
if "platform" in config_params and context.default_platform is not None:
|
||||
if "platform" in config_params and server.default_platform is not None:
|
||||
logger.info(
|
||||
"overriding default platform from environment: %s",
|
||||
context.default_platform,
|
||||
server.default_platform,
|
||||
)
|
||||
config_platform = config_params.get("platform", {})
|
||||
config_platform["default"] = context.default_platform
|
||||
config_platform["default"] = server.default_platform
|
||||
|
||||
|
||||
def load_platforms(context: ServerContext) -> None:
|
||||
def load_platforms(server: ServerContext) -> None:
|
||||
global available_platforms
|
||||
|
||||
providers = list(get_available_providers())
|
||||
|
@ -309,7 +309,7 @@ def load_platforms(context: ServerContext) -> None:
|
|||
for potential in platform_providers:
|
||||
if (
|
||||
platform_providers[potential] in providers
|
||||
and potential not in context.block_platforms
|
||||
and potential not in server.block_platforms
|
||||
):
|
||||
if potential == "cuda":
|
||||
for i in range(torch.cuda.device_count()):
|
||||
|
@ -317,16 +317,16 @@ def load_platforms(context: ServerContext) -> None:
|
|||
"device_id": i,
|
||||
}
|
||||
|
||||
if context.memory_limit is not None:
|
||||
if server.memory_limit is not None:
|
||||
options["arena_extend_strategy"] = "kSameAsRequested"
|
||||
options["gpu_mem_limit"] = context.memory_limit
|
||||
options["gpu_mem_limit"] = server.memory_limit
|
||||
|
||||
available_platforms.append(
|
||||
DeviceParams(
|
||||
potential,
|
||||
platform_providers[potential],
|
||||
options,
|
||||
context.optimizations,
|
||||
server.optimizations,
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
@ -335,18 +335,18 @@ def load_platforms(context: ServerContext) -> None:
|
|||
potential,
|
||||
platform_providers[potential],
|
||||
None,
|
||||
context.optimizations,
|
||||
server.optimizations,
|
||||
)
|
||||
)
|
||||
|
||||
if context.any_platform:
|
||||
if server.any_platform:
|
||||
# the platform should be ignored when the job is scheduled, but set to CPU just in case
|
||||
available_platforms.append(
|
||||
DeviceParams(
|
||||
"any",
|
||||
platform_providers["cpu"],
|
||||
None,
|
||||
context.optimizations,
|
||||
server.optimizations,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@@ -28,7 +28,7 @@ logger = getLogger(__name__)


 def pipeline_from_request(
-    context: ServerContext,
+    server: ServerContext,
 ) -> Tuple[DeviceParams, ImageParams, Size]:
     user = request.remote_addr

@@ -44,7 +44,7 @@ def pipeline_from_request(
     # pipeline stuff
     lpw = get_not_empty(request.args, "lpw", "false") == "true"
     model = get_not_empty(request.args, "model", get_config_value("model"))
-    model_path = get_model_path(context, model)
+    model_path = get_model_path(server, model)
     scheduler = get_from_list(
         request.args, "scheduler", list(pipeline_schedulers.keys())
     )

@@ -7,30 +7,30 @@ from .context import ServerContext
 from .utils import wrap_route


-def serve_bundle_file(context: ServerContext, filename="index.html"):
-    return send_from_directory(path.join("..", context.bundle_path), filename)
+def serve_bundle_file(server: ServerContext, filename="index.html"):
+    return send_from_directory(path.join("..", server.bundle_path), filename)


 # non-API routes
-def index(context: ServerContext):
-    return serve_bundle_file(context)
+def index(server: ServerContext):
+    return serve_bundle_file(server)


-def index_path(context: ServerContext, filename: str):
-    return serve_bundle_file(context, filename)
+def index_path(server: ServerContext, filename: str):
+    return serve_bundle_file(server, filename)


-def output(context: ServerContext, filename: str):
+def output(server: ServerContext, filename: str):
     return send_from_directory(
-        path.join("..", context.output_path), filename, as_attachment=False
+        path.join("..", server.output_path), filename, as_attachment=False
     )


 def register_static_routes(
-    app: Flask, context: ServerContext, _pool: DevicePoolExecutor
+    app: Flask, server: ServerContext, _pool: DevicePoolExecutor
 ):
     return [
-        app.route("/")(wrap_route(index, context)),
-        app.route("/<path:filename>")(wrap_route(index_path, context)),
-        app.route("/output/<path:filename>")(wrap_route(output, context)),
+        app.route("/")(wrap_route(index, server)),
+        app.route("/<path:filename>")(wrap_route(index_path, server)),
+        app.route("/output/<path:filename>")(wrap_route(output, server)),
     ]

@@ -9,26 +9,26 @@ from ..worker.pool import DevicePoolExecutor
 from .context import ServerContext


-def check_paths(context: ServerContext) -> None:
-    if not path.exists(context.model_path):
+def check_paths(server: ServerContext) -> None:
+    if not path.exists(server.model_path):
         raise RuntimeError("model path must exist")

-    if not path.exists(context.output_path):
-        makedirs(context.output_path)
+    if not path.exists(server.output_path):
+        makedirs(server.output_path)


-def get_model_path(context: ServerContext, model: str):
-    return base_join(context.model_path, model)
+def get_model_path(server: ServerContext, model: str):
+    return base_join(server.model_path, model)


 def register_routes(
     app: Flask,
-    context: ServerContext,
+    server: ServerContext,
     pool: DevicePoolExecutor,
     routes: List[Tuple[str, Dict, Callable]],
 ):
     for route, kwargs, method in routes:
-        app.route(route, **kwargs)(wrap_route(method, context, pool=pool))
+        app.route(route, **kwargs)(wrap_route(method, server, pool=pool))


 def wrap_route(func, *args, **kwargs):

@@ -26,51 +26,51 @@ MEMORY_ERRORS = [
 ]


-def worker_main(context: WorkerContext, server: ServerContext):
+def worker_main(worker: WorkerContext, server: ServerContext):
     apply_patches(server)
-    setproctitle("onnx-web worker: %s" % (context.device.device))
+    setproctitle("onnx-web worker: %s" % (worker.device.device))

     logger.trace(
         "checking in from worker with providers: %s", get_available_providers()
     )

     # make leaking workers easier to recycle
-    context.progress.cancel_join_thread()
+    worker.progress.cancel_join_thread()

     while True:
         try:
-            if not context.is_active():
+            if not worker.is_active():
                 logger.warning(
                     "worker %s has been replaced by %s, exiting",
                     getpid(),
-                    context.get_active(),
+                    worker.get_active(),
                 )
                 exit(EXIT_REPLACED)

             # wait briefly for the next job
-            job = context.pending.get(timeout=1.0)
-            logger.info("worker %s got job: %s", context.device.device, job.name)
+            job = worker.pending.get(timeout=1.0)
+            logger.info("worker %s got job: %s", worker.device.device, job.name)

             # clear flags and save the job name
-            context.start(job.name)
+            worker.start(job.name)
             logger.info("starting job: %s", job.name)

             # reset progress, which does a final check for cancellation
-            context.set_progress(0)
-            job.fn(context, *job.args, **job.kwargs)
+            worker.set_progress(0)
+            job.fn(worker, *job.args, **job.kwargs)

             # confirm completion of the job
             logger.info("job succeeded: %s", job.name)
-            context.finish()
+            worker.finish()
         except Empty:
             pass
         except KeyboardInterrupt:
             logger.info("worker got keyboard interrupt")
-            context.fail()
+            worker.fail()
             exit(EXIT_INTERRUPT)
         except ValueError:
             logger.exception("value error in worker, exiting: %s")
-            context.fail()
+            worker.fail()
             exit(EXIT_ERROR)
         except Exception as e:
             e_str = str(e)
@@ -78,11 +78,11 @@ def worker_main(context: WorkerContext, server: ServerContext):
             for e_mem in MEMORY_ERRORS:
                 if e_mem in e_str:
                     logger.error("detected out-of-memory error, exiting: %s", e)
-                    context.fail()
+                    worker.fail()
                     exit(EXIT_MEMORY)

             # carry on for other errors
             logger.exception(
                 "unrecognized error while running job",
             )
-            context.fail()
+            worker.fail()

@@ -0,0 +1,23 @@
+import unittest
+
+
+class TestHashValue(unittest.TestCase):
+    def test_hash_value(self):
+        pass
+
+
+class TestJSONParams(unittest.TestCase):
+    def test_json_params(self):
+        pass
+
+
+class TestMakeOutputName(unittest.TestCase):
+    pass
+
+
+class TestSaveImage(unittest.TestCase):
+    pass
+
+
+class TestSaveParams(unittest.TestCase):
+    pass

@@ -0,0 +1,101 @@
+import unittest
+
+from onnx_web.params import Border, Size
+
+class BorderTests(unittest.TestCase):
+    def test_json(self):
+        border = Border.even(0)
+        json = border.tojson()
+
+        self.assertIn("left", json)
+        self.assertIn("right", json)
+        self.assertIn("top", json)
+        self.assertIn("bottom", json)
+
+    def test_str(self):
+        border = Border.even(10)
+        bstr = str(border)
+
+        self.assertEqual("(10, 10, 10, 10)", bstr)
+
+    def test_uneven(self):
+        border = Border(1, 2, 3, 4)
+
+        self.assertEqual("(1, 2, 3, 4)", str(border))
+
+    def test_args(self):
+        pass
+
+
+class SizeTests(unittest.TestCase):
+    def test_iter(self):
+        size = Size(1, 2)
+
+        self.assertEqual(list(size), [1, 2])
+
+    def test_str(self):
+        pass
+
+    def test_border(self):
+        pass
+
+    def test_json(self):
+        pass
+
+    def test_args(self):
+        pass
+
+
+class DeviceParamsTests(unittest.TestCase):
+    def test_str(self):
+        pass
+
+    def test_provider(self):
+        pass
+
+    def test_options_optimizations(self):
+        pass
+
+    def test_options_cache(self):
+        pass
+
+    def test_torch_cuda(self):
+        pass
+
+    def test_torch_rocm(self):
+        pass
+
+
+class ImageParamsTests(unittest.TestCase):
+    def test_json(self):
+        pass
+
+    def test_args(self):
+        pass
+
+
+class StageParamsTests(unittest.TestCase):
+    def test_init(self):
+        pass
+
+
+class UpscaleParamsTests(unittest.TestCase):
+    def test_rescale(self):
+        pass
+
+    def test_resize(self):
+        pass
+
+    def test_json(self):
+        pass
+
+    def test_args(self):
+        pass
+
+
+class HighresParamsTests(unittest.TestCase):
+    def test_resize(self):
+        pass
+
+    def test_json(self):
+        pass
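
The new test modules are largely stubs (`pass` bodies) that pin down the surface to cover. Assuming they sit on a standard unittest discovery path (the directory layout is not shown in this diff), they would run with:

    python -m unittest discover -v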