lint(api): name context params consistently (#278)
This commit is contained in:
parent
fea9185707
commit
9698e29268
|
@ -23,7 +23,7 @@ class StageCallback(Protocol):
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
job: WorkerContext,
|
job: WorkerContext,
|
||||||
ctx: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
source: Image.Image,
|
source: Image.Image,
|
||||||
|
|
|
@ -140,7 +140,7 @@ base_models: Models = {
|
||||||
|
|
||||||
|
|
||||||
def fetch_model(
|
def fetch_model(
|
||||||
ctx: ConversionContext,
|
conversion: ConversionContext,
|
||||||
name: str,
|
name: str,
|
||||||
source: str,
|
source: str,
|
||||||
dest: Optional[str] = None,
|
dest: Optional[str] = None,
|
||||||
|
@ -148,7 +148,7 @@ def fetch_model(
|
||||||
hf_hub_fetch: bool = False,
|
hf_hub_fetch: bool = False,
|
||||||
hf_hub_filename: Optional[str] = None,
|
hf_hub_filename: Optional[str] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
cache_path = dest or ctx.cache_path
|
cache_path = dest or conversion.cache_path
|
||||||
cache_name = path.join(cache_path, name)
|
cache_name = path.join(cache_path, name)
|
||||||
|
|
||||||
# add an extension if possible, some of the conversion code checks for it
|
# add an extension if possible, some of the conversion code checks for it
|
||||||
|
@ -201,7 +201,7 @@ def fetch_model(
|
||||||
return source
|
return source
|
||||||
|
|
||||||
|
|
||||||
def convert_models(ctx: ConversionContext, args, models: Models):
|
def convert_models(conversion: ConversionContext, args, models: Models):
|
||||||
if args.sources and "sources" in models:
|
if args.sources and "sources" in models:
|
||||||
for model in models.get("sources"):
|
for model in models.get("sources"):
|
||||||
model = tuple_to_source(model)
|
model = tuple_to_source(model)
|
||||||
|
@ -214,7 +214,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
source = model["source"]
|
source = model["source"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
dest = fetch_model(ctx, name, source, format=model_format)
|
dest = fetch_model(conversion, name, source, format=model_format)
|
||||||
logger.info("finished downloading source: %s -> %s", source, dest)
|
logger.info("finished downloading source: %s -> %s", source, dest)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("error fetching source %s", name)
|
logger.exception("error fetching source %s", name)
|
||||||
|
@ -234,20 +234,20 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
try:
|
try:
|
||||||
if network_type == "inversion" and network_model == "concept":
|
if network_type == "inversion" and network_model == "concept":
|
||||||
dest = fetch_model(
|
dest = fetch_model(
|
||||||
ctx,
|
conversion,
|
||||||
name,
|
name,
|
||||||
source,
|
source,
|
||||||
dest=path.join(ctx.model_path, network_type),
|
dest=path.join(conversion.model_path, network_type),
|
||||||
format=network_format,
|
format=network_format,
|
||||||
hf_hub_fetch=True,
|
hf_hub_fetch=True,
|
||||||
hf_hub_filename="learned_embeds.bin",
|
hf_hub_filename="learned_embeds.bin",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
dest = fetch_model(
|
dest = fetch_model(
|
||||||
ctx,
|
conversion,
|
||||||
name,
|
name,
|
||||||
source,
|
source,
|
||||||
dest=path.join(ctx.model_path, network_type),
|
dest=path.join(conversion.model_path, network_type),
|
||||||
format=network_format,
|
format=network_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -267,19 +267,19 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
source = fetch_model(
|
source = fetch_model(
|
||||||
ctx, name, model["source"], format=model_format
|
conversion, name, model["source"], format=model_format
|
||||||
)
|
)
|
||||||
|
|
||||||
converted = False
|
converted = False
|
||||||
if model_format in model_formats_original:
|
if model_format in model_formats_original:
|
||||||
converted, dest = convert_diffusion_original(
|
converted, dest = convert_diffusion_original(
|
||||||
ctx,
|
conversion,
|
||||||
model,
|
model,
|
||||||
source,
|
source,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
converted, dest = convert_diffusion_diffusers(
|
converted, dest = convert_diffusion_diffusers(
|
||||||
ctx,
|
conversion,
|
||||||
model,
|
model,
|
||||||
source,
|
source,
|
||||||
)
|
)
|
||||||
|
@ -289,8 +289,8 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
# keep track of which models have been blended
|
# keep track of which models have been blended
|
||||||
blend_models = {}
|
blend_models = {}
|
||||||
|
|
||||||
inversion_dest = path.join(ctx.model_path, "inversion")
|
inversion_dest = path.join(conversion.model_path, "inversion")
|
||||||
lora_dest = path.join(ctx.model_path, "lora")
|
lora_dest = path.join(conversion.model_path, "lora")
|
||||||
|
|
||||||
for inversion in model.get("inversions", []):
|
for inversion in model.get("inversions", []):
|
||||||
if "text_encoder" not in blend_models:
|
if "text_encoder" not in blend_models:
|
||||||
|
@ -314,7 +314,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
inversion_source = inversion["source"]
|
inversion_source = inversion["source"]
|
||||||
inversion_format = inversion.get("format", None)
|
inversion_format = inversion.get("format", None)
|
||||||
inversion_source = fetch_model(
|
inversion_source = fetch_model(
|
||||||
ctx,
|
conversion,
|
||||||
inversion_name,
|
inversion_name,
|
||||||
inversion_source,
|
inversion_source,
|
||||||
dest=inversion_dest,
|
dest=inversion_dest,
|
||||||
|
@ -323,7 +323,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
inversion_weight = inversion.get("weight", 1.0)
|
inversion_weight = inversion.get("weight", 1.0)
|
||||||
|
|
||||||
blend_textual_inversions(
|
blend_textual_inversions(
|
||||||
ctx,
|
conversion,
|
||||||
blend_models["text_encoder"],
|
blend_models["text_encoder"],
|
||||||
blend_models["tokenizer"],
|
blend_models["tokenizer"],
|
||||||
[
|
[
|
||||||
|
@ -355,7 +355,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
lora_name = lora["name"]
|
lora_name = lora["name"]
|
||||||
lora_source = lora["source"]
|
lora_source = lora["source"]
|
||||||
lora_source = fetch_model(
|
lora_source = fetch_model(
|
||||||
ctx,
|
conversion,
|
||||||
f"{name}-lora-{lora_name}",
|
f"{name}-lora-{lora_name}",
|
||||||
lora_source,
|
lora_source,
|
||||||
dest=lora_dest,
|
dest=lora_dest,
|
||||||
|
@ -363,14 +363,14 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
lora_weight = lora.get("weight", 1.0)
|
lora_weight = lora.get("weight", 1.0)
|
||||||
|
|
||||||
blend_loras(
|
blend_loras(
|
||||||
ctx,
|
conversion,
|
||||||
blend_models["text_encoder"],
|
blend_models["text_encoder"],
|
||||||
[(lora_source, lora_weight)],
|
[(lora_source, lora_weight)],
|
||||||
"text_encoder",
|
"text_encoder",
|
||||||
)
|
)
|
||||||
|
|
||||||
blend_loras(
|
blend_loras(
|
||||||
ctx,
|
conversion,
|
||||||
blend_models["unet"],
|
blend_models["unet"],
|
||||||
[(lora_source, lora_weight)],
|
[(lora_source, lora_weight)],
|
||||||
"unet",
|
"unet",
|
||||||
|
@ -413,9 +413,9 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
source = fetch_model(
|
source = fetch_model(
|
||||||
ctx, name, model["source"], format=model_format
|
conversion, name, model["source"], format=model_format
|
||||||
)
|
)
|
||||||
convert_upscale_resrgan(ctx, model, source)
|
convert_upscale_resrgan(conversion, model, source)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"error converting upscaling model %s",
|
"error converting upscaling model %s",
|
||||||
|
@ -433,9 +433,9 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
model_format = source_format(model)
|
model_format = source_format(model)
|
||||||
try:
|
try:
|
||||||
source = fetch_model(
|
source = fetch_model(
|
||||||
ctx, name, model["source"], format=model_format
|
conversion, name, model["source"], format=model_format
|
||||||
)
|
)
|
||||||
convert_correction_gfpgan(ctx, model, source)
|
convert_correction_gfpgan(conversion, model, source)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"error converting correction model %s",
|
"error converting correction model %s",
|
||||||
|
@ -482,21 +482,21 @@ def main() -> int:
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
logger.info("CLI arguments: %s", args)
|
logger.info("CLI arguments: %s", args)
|
||||||
|
|
||||||
ctx = ConversionContext.from_environ()
|
server = ConversionContext.from_environ()
|
||||||
ctx.half = args.half or "onnx-fp16" in ctx.optimizations
|
server.half = args.half or "onnx-fp16" in server.optimizations
|
||||||
ctx.opset = args.opset
|
server.opset = args.opset
|
||||||
ctx.token = args.token
|
server.token = args.token
|
||||||
logger.info("converting models in %s using %s", ctx.model_path, ctx.training_device)
|
logger.info("converting models in %s using %s", server.model_path, server.training_device)
|
||||||
|
|
||||||
if not path.exists(ctx.model_path):
|
if not path.exists(server.model_path):
|
||||||
logger.info("model path does not existing, creating: %s", ctx.model_path)
|
logger.info("model path does not existing, creating: %s", server.model_path)
|
||||||
makedirs(ctx.model_path)
|
makedirs(server.model_path)
|
||||||
|
|
||||||
logger.info("converting base models")
|
logger.info("converting base models")
|
||||||
convert_models(ctx, args, base_models)
|
convert_models(server, args, base_models)
|
||||||
|
|
||||||
extras = []
|
extras = []
|
||||||
extras.extend(ctx.extra_models)
|
extras.extend(server.extra_models)
|
||||||
extras.extend(args.extras)
|
extras.extend(args.extras)
|
||||||
extras = list(set(extras))
|
extras = list(set(extras))
|
||||||
extras.sort()
|
extras.sort()
|
||||||
|
@ -516,7 +516,7 @@ def main() -> int:
|
||||||
try:
|
try:
|
||||||
validate(data, extra_schema)
|
validate(data, extra_schema)
|
||||||
logger.info("converting extra models")
|
logger.info("converting extra models")
|
||||||
convert_models(ctx, args, data)
|
convert_models(server, args, data)
|
||||||
except ValidationError:
|
except ValidationError:
|
||||||
logger.exception("invalid data in extras file")
|
logger.exception("invalid data in extras file")
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
|
@ -12,7 +12,7 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_correction_gfpgan(
|
def convert_correction_gfpgan(
|
||||||
ctx: ConversionContext,
|
conversion: ConversionContext,
|
||||||
model: ModelDict,
|
model: ModelDict,
|
||||||
source: str,
|
source: str,
|
||||||
):
|
):
|
||||||
|
@ -20,7 +20,7 @@ def convert_correction_gfpgan(
|
||||||
source = source or model.get("source")
|
source = source or model.get("source")
|
||||||
scale = model.get("scale")
|
scale = model.get("scale")
|
||||||
|
|
||||||
dest = path.join(ctx.model_path, name + ".onnx")
|
dest = path.join(conversion.model_path, name + ".onnx")
|
||||||
logger.info("converting GFPGAN model: %s -> %s", name, dest)
|
logger.info("converting GFPGAN model: %s -> %s", name, dest)
|
||||||
|
|
||||||
if path.isfile(dest):
|
if path.isfile(dest):
|
||||||
|
@ -37,17 +37,17 @@ def convert_correction_gfpgan(
|
||||||
scale=scale,
|
scale=scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
torch_model = torch.load(source, map_location=ctx.map_location)
|
torch_model = torch.load(source, map_location=conversion.map_location)
|
||||||
# TODO: make sure strict=False is safe here
|
# TODO: make sure strict=False is safe here
|
||||||
if "params_ema" in torch_model:
|
if "params_ema" in torch_model:
|
||||||
model.load_state_dict(torch_model["params_ema"], strict=False)
|
model.load_state_dict(torch_model["params_ema"], strict=False)
|
||||||
else:
|
else:
|
||||||
model.load_state_dict(torch_model["params"], strict=False)
|
model.load_state_dict(torch_model["params"], strict=False)
|
||||||
|
|
||||||
model.to(ctx.training_device).train(False)
|
model.to(conversion.training_device).train(False)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
rng = torch.rand(1, 3, 64, 64, device=ctx.map_location)
|
rng = torch.rand(1, 3, 64, 64, device=conversion.map_location)
|
||||||
input_names = ["data"]
|
input_names = ["data"]
|
||||||
output_names = ["output"]
|
output_names = ["output"]
|
||||||
dynamic_axes = {
|
dynamic_axes = {
|
||||||
|
@ -63,7 +63,7 @@ def convert_correction_gfpgan(
|
||||||
input_names=input_names,
|
input_names=input_names,
|
||||||
output_names=output_names,
|
output_names=output_names,
|
||||||
dynamic_axes=dynamic_axes,
|
dynamic_axes=dynamic_axes,
|
||||||
opset_version=ctx.opset,
|
opset_version=conversion.opset,
|
||||||
export_params=True,
|
export_params=True,
|
||||||
)
|
)
|
||||||
logger.info("GFPGAN exported to ONNX successfully")
|
logger.info("GFPGAN exported to ONNX successfully")
|
||||||
|
|
|
@ -90,7 +90,7 @@ def onnx_export(
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_diffusion_diffusers(
|
def convert_diffusion_diffusers(
|
||||||
ctx: ConversionContext,
|
conversion: ConversionContext,
|
||||||
model: Dict,
|
model: Dict,
|
||||||
source: str,
|
source: str,
|
||||||
) -> Tuple[bool, str]:
|
) -> Tuple[bool, str]:
|
||||||
|
@ -102,10 +102,10 @@ def convert_diffusion_diffusers(
|
||||||
single_vae = model.get("single_vae")
|
single_vae = model.get("single_vae")
|
||||||
replace_vae = model.get("vae")
|
replace_vae = model.get("vae")
|
||||||
|
|
||||||
dtype = ctx.torch_dtype()
|
dtype = conversion.torch_dtype()
|
||||||
logger.debug("using Torch dtype %s for pipeline", dtype)
|
logger.debug("using Torch dtype %s for pipeline", dtype)
|
||||||
|
|
||||||
dest_path = path.join(ctx.model_path, name)
|
dest_path = path.join(conversion.model_path, name)
|
||||||
model_index = path.join(dest_path, "model_index.json")
|
model_index = path.join(dest_path, "model_index.json")
|
||||||
|
|
||||||
# diffusers go into a directory rather than .onnx file
|
# diffusers go into a directory rather than .onnx file
|
||||||
|
@ -123,11 +123,11 @@ def convert_diffusion_diffusers(
|
||||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||||
source,
|
source,
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
use_auth_token=ctx.token,
|
use_auth_token=conversion.token,
|
||||||
).to(ctx.training_device)
|
).to(conversion.training_device)
|
||||||
output_path = Path(dest_path)
|
output_path = Path(dest_path)
|
||||||
|
|
||||||
optimize_pipeline(ctx, pipeline)
|
optimize_pipeline(conversion, pipeline)
|
||||||
|
|
||||||
# TEXT ENCODER
|
# TEXT ENCODER
|
||||||
num_tokens = pipeline.text_encoder.config.max_position_embeddings
|
num_tokens = pipeline.text_encoder.config.max_position_embeddings
|
||||||
|
@ -143,11 +143,11 @@ def convert_diffusion_diffusers(
|
||||||
pipeline.text_encoder,
|
pipeline.text_encoder,
|
||||||
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
||||||
model_args=(
|
model_args=(
|
||||||
text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32),
|
text_input.input_ids.to(device=conversion.training_device, dtype=torch.int32),
|
||||||
None, # attention mask
|
None, # attention mask
|
||||||
None, # position ids
|
None, # position ids
|
||||||
None, # output attentions
|
None, # output attentions
|
||||||
torch.tensor(True).to(device=ctx.training_device, dtype=torch.bool),
|
torch.tensor(True).to(device=conversion.training_device, dtype=torch.bool),
|
||||||
),
|
),
|
||||||
output_path=output_path / "text_encoder" / ONNX_MODEL,
|
output_path=output_path / "text_encoder" / ONNX_MODEL,
|
||||||
ordered_input_names=["input_ids"],
|
ordered_input_names=["input_ids"],
|
||||||
|
@ -155,8 +155,8 @@ def convert_diffusion_diffusers(
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"input_ids": {0: "batch", 1: "sequence"},
|
"input_ids": {0: "batch", 1: "sequence"},
|
||||||
},
|
},
|
||||||
opset=ctx.opset,
|
opset=conversion.opset,
|
||||||
half=ctx.half,
|
half=conversion.half,
|
||||||
)
|
)
|
||||||
del pipeline.text_encoder
|
del pipeline.text_encoder
|
||||||
|
|
||||||
|
@ -165,11 +165,11 @@ def convert_diffusion_diffusers(
|
||||||
# UNET
|
# UNET
|
||||||
if single_vae:
|
if single_vae:
|
||||||
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
|
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
|
||||||
unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.long)
|
unet_scale = torch.tensor(4).to(device=conversion.training_device, dtype=torch.long)
|
||||||
else:
|
else:
|
||||||
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
|
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
|
||||||
unet_scale = torch.tensor(False).to(
|
unet_scale = torch.tensor(False).to(
|
||||||
device=ctx.training_device, dtype=torch.bool
|
device=conversion.training_device, dtype=torch.bool
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_torch_2_0:
|
if is_torch_2_0:
|
||||||
|
@ -182,11 +182,11 @@ def convert_diffusion_diffusers(
|
||||||
pipeline.unet,
|
pipeline.unet,
|
||||||
model_args=(
|
model_args=(
|
||||||
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
|
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
|
||||||
device=ctx.training_device, dtype=dtype
|
device=conversion.training_device, dtype=dtype
|
||||||
),
|
),
|
||||||
torch.randn(2).to(device=ctx.training_device, dtype=dtype),
|
torch.randn(2).to(device=conversion.training_device, dtype=dtype),
|
||||||
torch.randn(2, num_tokens, text_hidden_size).to(
|
torch.randn(2, num_tokens, text_hidden_size).to(
|
||||||
device=ctx.training_device, dtype=dtype
|
device=conversion.training_device, dtype=dtype
|
||||||
),
|
),
|
||||||
unet_scale,
|
unet_scale,
|
||||||
),
|
),
|
||||||
|
@ -199,8 +199,8 @@ def convert_diffusion_diffusers(
|
||||||
"timestep": {0: "batch"},
|
"timestep": {0: "batch"},
|
||||||
"encoder_hidden_states": {0: "batch", 1: "sequence"},
|
"encoder_hidden_states": {0: "batch", 1: "sequence"},
|
||||||
},
|
},
|
||||||
opset=ctx.opset,
|
opset=conversion.opset,
|
||||||
half=ctx.half,
|
half=conversion.half,
|
||||||
external_data=True,
|
external_data=True,
|
||||||
)
|
)
|
||||||
unet_model_path = str(unet_path.absolute().as_posix())
|
unet_model_path = str(unet_path.absolute().as_posix())
|
||||||
|
@ -238,7 +238,7 @@ def convert_diffusion_diffusers(
|
||||||
model_args=(
|
model_args=(
|
||||||
torch.randn(
|
torch.randn(
|
||||||
1, vae_latent_channels, unet_sample_size, unet_sample_size
|
1, vae_latent_channels, unet_sample_size, unet_sample_size
|
||||||
).to(device=ctx.training_device, dtype=dtype),
|
).to(device=conversion.training_device, dtype=dtype),
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
output_path=output_path / "vae" / ONNX_MODEL,
|
output_path=output_path / "vae" / ONNX_MODEL,
|
||||||
|
@ -247,8 +247,8 @@ def convert_diffusion_diffusers(
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||||
},
|
},
|
||||||
opset=ctx.opset,
|
opset=conversion.opset,
|
||||||
half=ctx.half,
|
half=conversion.half,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# VAE ENCODER
|
# VAE ENCODER
|
||||||
|
@ -263,7 +263,7 @@ def convert_diffusion_diffusers(
|
||||||
vae_encoder,
|
vae_encoder,
|
||||||
model_args=(
|
model_args=(
|
||||||
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
|
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
|
||||||
device=ctx.training_device, dtype=dtype
|
device=conversion.training_device, dtype=dtype
|
||||||
),
|
),
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
|
@ -273,7 +273,7 @@ def convert_diffusion_diffusers(
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||||
},
|
},
|
||||||
opset=ctx.opset,
|
opset=conversion.opset,
|
||||||
half=False, # https://github.com/ssube/onnx-web/issues/290
|
half=False, # https://github.com/ssube/onnx-web/issues/290
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -287,7 +287,7 @@ def convert_diffusion_diffusers(
|
||||||
model_args=(
|
model_args=(
|
||||||
torch.randn(
|
torch.randn(
|
||||||
1, vae_latent_channels, unet_sample_size, unet_sample_size
|
1, vae_latent_channels, unet_sample_size, unet_sample_size
|
||||||
).to(device=ctx.training_device, dtype=dtype),
|
).to(device=conversion.training_device, dtype=dtype),
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
output_path=output_path / "vae_decoder" / ONNX_MODEL,
|
output_path=output_path / "vae_decoder" / ONNX_MODEL,
|
||||||
|
@ -296,8 +296,8 @@ def convert_diffusion_diffusers(
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||||
},
|
},
|
||||||
opset=ctx.opset,
|
opset=conversion.opset,
|
||||||
half=ctx.half,
|
half=conversion.half,
|
||||||
)
|
)
|
||||||
|
|
||||||
del pipeline.vae
|
del pipeline.vae
|
||||||
|
|
|
@ -55,7 +55,7 @@ def fix_node_name(key: str):
|
||||||
|
|
||||||
|
|
||||||
def blend_loras(
|
def blend_loras(
|
||||||
_context: ServerContext,
|
_conversion: ServerContext,
|
||||||
base_name: Union[str, ModelProto],
|
base_name: Union[str, ModelProto],
|
||||||
loras: List[Tuple[str, float]],
|
loras: List[Tuple[str, float]],
|
||||||
model_type: Literal["text_encoder", "unet"],
|
model_type: Literal["text_encoder", "unet"],
|
||||||
|
|
|
@ -146,7 +146,7 @@ class TrainingConfig:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
ctx: ConversionContext,
|
conversion: ConversionContext,
|
||||||
model_name: str = "",
|
model_name: str = "",
|
||||||
scheduler: str = "ddim",
|
scheduler: str = "ddim",
|
||||||
v2: bool = False,
|
v2: bool = False,
|
||||||
|
@ -155,7 +155,7 @@ class TrainingConfig:
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
model_name = sanitize_name(model_name)
|
model_name = sanitize_name(model_name)
|
||||||
model_dir = os.path.join(ctx.cache_path, model_name)
|
model_dir = os.path.join(conversion.cache_path, model_name)
|
||||||
working_dir = os.path.join(model_dir, "working")
|
working_dir = os.path.join(model_dir, "working")
|
||||||
|
|
||||||
if not os.path.exists(working_dir):
|
if not os.path.exists(working_dir):
|
||||||
|
@ -1298,7 +1298,7 @@ def download_model(db_config: TrainingConfig, token):
|
||||||
|
|
||||||
|
|
||||||
def get_config_path(
|
def get_config_path(
|
||||||
context: ConversionContext,
|
conversion: ConversionContext,
|
||||||
model_version: str = "v1",
|
model_version: str = "v1",
|
||||||
train_type: str = "default",
|
train_type: str = "default",
|
||||||
config_base_name: str = "training",
|
config_base_name: str = "training",
|
||||||
|
@ -1309,7 +1309,7 @@ def get_config_path(
|
||||||
)
|
)
|
||||||
|
|
||||||
parts = os.path.join(
|
parts = os.path.join(
|
||||||
context.model_path,
|
conversion.model_path,
|
||||||
"configs",
|
"configs",
|
||||||
f"{model_version}-{config_base_name}-{train_type}.yaml",
|
f"{model_version}-{config_base_name}-{train_type}.yaml",
|
||||||
)
|
)
|
||||||
|
@ -1317,7 +1317,7 @@ def get_config_path(
|
||||||
|
|
||||||
|
|
||||||
def get_config_file(
|
def get_config_file(
|
||||||
context: ConversionContext,
|
conversion: ConversionContext,
|
||||||
train_unfrozen=False,
|
train_unfrozen=False,
|
||||||
v2=False,
|
v2=False,
|
||||||
prediction_type="epsilon",
|
prediction_type="epsilon",
|
||||||
|
@ -1343,7 +1343,7 @@ def get_config_file(
|
||||||
model_train_type = train_types["default"]
|
model_train_type = train_types["default"]
|
||||||
|
|
||||||
return get_config_path(
|
return get_config_path(
|
||||||
context,
|
conversion,
|
||||||
model_version_name,
|
model_version_name,
|
||||||
model_train_type,
|
model_train_type,
|
||||||
config_base_name,
|
config_base_name,
|
||||||
|
@ -1352,7 +1352,7 @@ def get_config_file(
|
||||||
|
|
||||||
|
|
||||||
def extract_checkpoint(
|
def extract_checkpoint(
|
||||||
context: ConversionContext,
|
conversion: ConversionContext,
|
||||||
new_model_name: str,
|
new_model_name: str,
|
||||||
checkpoint_file: str,
|
checkpoint_file: str,
|
||||||
scheduler_type="ddim",
|
scheduler_type="ddim",
|
||||||
|
@ -1396,7 +1396,7 @@ def extract_checkpoint(
|
||||||
|
|
||||||
# Create empty config
|
# Create empty config
|
||||||
db_config = TrainingConfig(
|
db_config = TrainingConfig(
|
||||||
context,
|
conversion,
|
||||||
model_name=new_model_name,
|
model_name=new_model_name,
|
||||||
scheduler=scheduler_type,
|
scheduler=scheduler_type,
|
||||||
src=checkpoint_file,
|
src=checkpoint_file,
|
||||||
|
@ -1442,7 +1442,7 @@ def extract_checkpoint(
|
||||||
prediction_type = "epsilon"
|
prediction_type = "epsilon"
|
||||||
|
|
||||||
original_config_file = get_config_file(
|
original_config_file = get_config_file(
|
||||||
context, train_unfrozen, v2, prediction_type, config_file=config_file
|
conversion, train_unfrozen, v2, prediction_type, config_file=config_file
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -1533,7 +1533,7 @@ def extract_checkpoint(
|
||||||
checkpoint, vae_config
|
checkpoint, vae_config
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
vae_file = os.path.join(context.model_path, vae_file)
|
vae_file = os.path.join(conversion.model_path, vae_file)
|
||||||
logger.debug("loading custom VAE: %s", vae_file)
|
logger.debug("loading custom VAE: %s", vae_file)
|
||||||
vae_checkpoint = load_tensor(vae_file, map_location=map_location)
|
vae_checkpoint = load_tensor(vae_file, map_location=map_location)
|
||||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
|
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
|
||||||
|
@ -1658,14 +1658,14 @@ def extract_checkpoint(
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_diffusion_original(
|
def convert_diffusion_original(
|
||||||
ctx: ConversionContext,
|
conversion: ConversionContext,
|
||||||
model: ModelDict,
|
model: ModelDict,
|
||||||
source: str,
|
source: str,
|
||||||
) -> Tuple[bool, str]:
|
) -> Tuple[bool, str]:
|
||||||
name = model["name"]
|
name = model["name"]
|
||||||
source = source or model["source"]
|
source = source or model["source"]
|
||||||
|
|
||||||
dest_path = os.path.join(ctx.model_path, name)
|
dest_path = os.path.join(conversion.model_path, name)
|
||||||
dest_index = os.path.join(dest_path, "model_index.json")
|
dest_index = os.path.join(dest_path, "model_index.json")
|
||||||
logger.info(
|
logger.info(
|
||||||
"converting original Diffusers checkpoint %s: %s -> %s", name, source, dest_path
|
"converting original Diffusers checkpoint %s: %s -> %s", name, source, dest_path
|
||||||
|
@ -1676,8 +1676,8 @@ def convert_diffusion_original(
|
||||||
return (False, dest_path)
|
return (False, dest_path)
|
||||||
|
|
||||||
torch_name = name + "-torch"
|
torch_name = name + "-torch"
|
||||||
torch_path = os.path.join(ctx.cache_path, torch_name)
|
torch_path = os.path.join(conversion.cache_path, torch_name)
|
||||||
working_name = os.path.join(ctx.cache_path, torch_name, "working")
|
working_name = os.path.join(conversion.cache_path, torch_name, "working")
|
||||||
model_index = os.path.join(working_name, "model_index.json")
|
model_index = os.path.join(working_name, "model_index.json")
|
||||||
|
|
||||||
if os.path.exists(torch_path) and os.path.exists(model_index):
|
if os.path.exists(torch_path) and os.path.exists(model_index):
|
||||||
|
@ -1689,7 +1689,7 @@ def convert_diffusion_original(
|
||||||
torch_path,
|
torch_path,
|
||||||
)
|
)
|
||||||
if extract_checkpoint(
|
if extract_checkpoint(
|
||||||
ctx,
|
conversion,
|
||||||
torch_name,
|
torch_name,
|
||||||
source,
|
source,
|
||||||
config_file=model.get("config"),
|
config_file=model.get("config"),
|
||||||
|
@ -1704,9 +1704,9 @@ def convert_diffusion_original(
|
||||||
if "vae" in model:
|
if "vae" in model:
|
||||||
del model["vae"]
|
del model["vae"]
|
||||||
|
|
||||||
result = convert_diffusion_diffusers(ctx, model, working_name)
|
result = convert_diffusion_diffusers(conversion, model, working_name)
|
||||||
|
|
||||||
if "torch" in ctx.prune:
|
if "torch" in conversion.prune:
|
||||||
logger.info("removing intermediate Torch models: %s", torch_path)
|
logger.info("removing intermediate Torch models: %s", torch_path)
|
||||||
shutil.rmtree(torch_path)
|
shutil.rmtree(torch_path)
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def blend_textual_inversions(
|
def blend_textual_inversions(
|
||||||
context: ServerContext,
|
server: ServerContext,
|
||||||
text_encoder: ModelProto,
|
text_encoder: ModelProto,
|
||||||
tokenizer: CLIPTokenizer,
|
tokenizer: CLIPTokenizer,
|
||||||
inversions: List[Tuple[str, float, Optional[str], Optional[str]]],
|
inversions: List[Tuple[str, float, Optional[str], Optional[str]]],
|
||||||
|
@ -161,7 +161,7 @@ def blend_textual_inversions(
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_diffusion_textual_inversion(
|
def convert_diffusion_textual_inversion(
|
||||||
context: ConversionContext,
|
conversion: ConversionContext,
|
||||||
name: str,
|
name: str,
|
||||||
base_model: str,
|
base_model: str,
|
||||||
inversion: str,
|
inversion: str,
|
||||||
|
@ -169,7 +169,7 @@ def convert_diffusion_textual_inversion(
|
||||||
base_token: Optional[str] = None,
|
base_token: Optional[str] = None,
|
||||||
inversion_weight: Optional[float] = 1.0,
|
inversion_weight: Optional[float] = 1.0,
|
||||||
):
|
):
|
||||||
dest_path = path.join(context.model_path, f"inversion-{name}")
|
dest_path = path.join(conversion.model_path, f"inversion-{name}")
|
||||||
logger.info(
|
logger.info(
|
||||||
"converting Textual Inversion: %s + %s -> %s", base_model, inversion, dest_path
|
"converting Textual Inversion: %s + %s -> %s", base_model, inversion, dest_path
|
||||||
)
|
)
|
||||||
|
@ -194,7 +194,7 @@ def convert_diffusion_textual_inversion(
|
||||||
subfolder="tokenizer",
|
subfolder="tokenizer",
|
||||||
)
|
)
|
||||||
text_encoder, tokenizer = blend_textual_inversions(
|
text_encoder, tokenizer = blend_textual_inversions(
|
||||||
context,
|
conversion,
|
||||||
text_encoder,
|
text_encoder,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
[(inversion, inversion_weight, base_token, inversion_format)],
|
[(inversion, inversion_weight, base_token, inversion_format)],
|
||||||
|
|
|
@ -13,7 +13,7 @@ TAG_X4_V3 = "real-esrgan-x4-v3"
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_upscale_resrgan(
|
def convert_upscale_resrgan(
|
||||||
ctx: ConversionContext,
|
conversion: ConversionContext,
|
||||||
model: ModelDict,
|
model: ModelDict,
|
||||||
source: str,
|
source: str,
|
||||||
):
|
):
|
||||||
|
@ -24,7 +24,7 @@ def convert_upscale_resrgan(
|
||||||
source = source or model.get("source")
|
source = source or model.get("source")
|
||||||
scale = model.get("scale")
|
scale = model.get("scale")
|
||||||
|
|
||||||
dest = path.join(ctx.model_path, name + ".onnx")
|
dest = path.join(conversion.model_path, name + ".onnx")
|
||||||
logger.info("converting Real ESRGAN model: %s -> %s", name, dest)
|
logger.info("converting Real ESRGAN model: %s -> %s", name, dest)
|
||||||
|
|
||||||
if path.isfile(dest):
|
if path.isfile(dest):
|
||||||
|
@ -53,16 +53,16 @@ def convert_upscale_resrgan(
|
||||||
scale=scale,
|
scale=scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
torch_model = torch.load(source, map_location=ctx.map_location)
|
torch_model = torch.load(source, map_location=conversion.map_location)
|
||||||
if "params_ema" in torch_model:
|
if "params_ema" in torch_model:
|
||||||
model.load_state_dict(torch_model["params_ema"])
|
model.load_state_dict(torch_model["params_ema"])
|
||||||
else:
|
else:
|
||||||
model.load_state_dict(torch_model["params"], strict=False)
|
model.load_state_dict(torch_model["params"], strict=False)
|
||||||
|
|
||||||
model.to(ctx.training_device).train(False)
|
model.to(conversion.training_device).train(False)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
rng = torch.rand(1, 3, 64, 64, device=ctx.map_location)
|
rng = torch.rand(1, 3, 64, 64, device=conversion.map_location)
|
||||||
input_names = ["data"]
|
input_names = ["data"]
|
||||||
output_names = ["output"]
|
output_names = ["output"]
|
||||||
dynamic_axes = {
|
dynamic_axes = {
|
||||||
|
@ -78,7 +78,7 @@ def convert_upscale_resrgan(
|
||||||
input_names=input_names,
|
input_names=input_names,
|
||||||
output_names=output_names,
|
output_names=output_names,
|
||||||
dynamic_axes=dynamic_axes,
|
dynamic_axes=dynamic_axes,
|
||||||
opset_version=ctx.opset,
|
opset_version=conversion.opset,
|
||||||
export_params=True,
|
export_params=True,
|
||||||
)
|
)
|
||||||
logger.info("real ESRGAN exported to ONNX successfully")
|
logger.info("real ESRGAN exported to ONNX successfully")
|
||||||
|
|
|
@ -64,7 +64,7 @@ def json_params(
|
||||||
|
|
||||||
|
|
||||||
def make_output_name(
|
def make_output_name(
|
||||||
ctx: ServerContext,
|
server: ServerContext,
|
||||||
mode: str,
|
mode: str,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
|
@ -92,20 +92,20 @@ def make_output_name(
|
||||||
hash_value(sha, param)
|
hash_value(sha, param)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{ctx.image_format}"
|
f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}"
|
||||||
for i in range(params.batch)
|
for i in range(params.batch)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def save_image(ctx: ServerContext, output: str, image: Image.Image) -> str:
|
def save_image(server: ServerContext, output: str, image: Image.Image) -> str:
|
||||||
path = base_join(ctx.output_path, output)
|
path = base_join(server.output_path, output)
|
||||||
image.save(path, format=ctx.image_format)
|
image.save(path, format=server.image_format)
|
||||||
logger.debug("saved output image to: %s", path)
|
logger.debug("saved output image to: %s", path)
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
def save_params(
|
def save_params(
|
||||||
ctx: ServerContext,
|
server: ServerContext,
|
||||||
output: str,
|
output: str,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
|
@ -113,7 +113,7 @@ def save_params(
|
||||||
border: Optional[Border] = None,
|
border: Optional[Border] = None,
|
||||||
highres: Optional[HighresParams] = None,
|
highres: Optional[HighresParams] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
path = base_join(ctx.output_path, f"{output}.json")
|
path = base_join(server.output_path, f"{output}.json")
|
||||||
json = json_params(
|
json = json_params(
|
||||||
output, params, size, upscale=upscale, border=border, highres=highres
|
output, params, size, upscale=upscale, border=border, highres=highres
|
||||||
)
|
)
|
||||||
|
|
|
@ -7,6 +7,10 @@ from .torch_before_ort import GraphOptimizationLevel, SessionOptions
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
Param = Union[str, int, float]
|
||||||
|
Point = Tuple[int, int]
|
||||||
|
|
||||||
|
|
||||||
class SizeChart(IntEnum):
|
class SizeChart(IntEnum):
|
||||||
mini = 128 # small tile for very expensive models
|
mini = 128 # small tile for very expensive models
|
||||||
half = 256 # half tile for outpainting
|
half = 256 # half tile for outpainting
|
||||||
|
@ -25,10 +29,6 @@ class TileOrder:
|
||||||
spiral = "spiral"
|
spiral = "spiral"
|
||||||
|
|
||||||
|
|
||||||
Param = Union[str, int, float]
|
|
||||||
Point = Tuple[int, int]
|
|
||||||
|
|
||||||
|
|
||||||
class Border:
|
class Border:
|
||||||
def __init__(self, left: int, right: int, top: int, bottom: int) -> None:
|
def __init__(self, left: int, right: int, top: int, bottom: int) -> None:
|
||||||
self.left = left
|
self.left = left
|
||||||
|
@ -37,7 +37,7 @@ class Border:
|
||||||
self.bottom = bottom
|
self.bottom = bottom
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return "%s %s %s %s" % (self.left, self.top, self.right, self.bottom)
|
return "(%s, %s, %s, %s)" % (self.left, self.top, self.right, self.bottom)
|
||||||
|
|
||||||
def tojson(self):
|
def tojson(self):
|
||||||
return {
|
return {
|
||||||
|
@ -145,6 +145,7 @@ class DeviceParams:
|
||||||
return sess
|
return sess
|
||||||
|
|
||||||
def torch_str(self) -> str:
|
def torch_str(self) -> str:
|
||||||
|
# TODO: return cuda devices for ROCm as well
|
||||||
if self.device.startswith("cuda"):
|
if self.device.startswith("cuda"):
|
||||||
return self.device
|
return self.device
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -93,7 +93,7 @@ def url_from_rule(rule) -> str:
|
||||||
return url_for(rule.endpoint, **options)
|
return url_for(rule.endpoint, **options)
|
||||||
|
|
||||||
|
|
||||||
def introspect(context: ServerContext, app: Flask):
|
def introspect(server: ServerContext, app: Flask):
|
||||||
return {
|
return {
|
||||||
"name": "onnx-web",
|
"name": "onnx-web",
|
||||||
"routes": [
|
"routes": [
|
||||||
|
@ -103,15 +103,15 @@ def introspect(context: ServerContext, app: Flask):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def list_extra_strings(context: ServerContext):
|
def list_extra_strings(server: ServerContext):
|
||||||
return jsonify(get_extra_strings())
|
return jsonify(get_extra_strings())
|
||||||
|
|
||||||
|
|
||||||
def list_mask_filters(context: ServerContext):
|
def list_mask_filters(server: ServerContext):
|
||||||
return jsonify(list(get_mask_filters().keys()))
|
return jsonify(list(get_mask_filters().keys()))
|
||||||
|
|
||||||
|
|
||||||
def list_models(context: ServerContext):
|
def list_models(server: ServerContext):
|
||||||
return jsonify(
|
return jsonify(
|
||||||
{
|
{
|
||||||
"correction": get_correction_models(),
|
"correction": get_correction_models(),
|
||||||
|
@ -122,30 +122,30 @@ def list_models(context: ServerContext):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def list_noise_sources(context: ServerContext):
|
def list_noise_sources(server: ServerContext):
|
||||||
return jsonify(list(get_noise_sources().keys()))
|
return jsonify(list(get_noise_sources().keys()))
|
||||||
|
|
||||||
|
|
||||||
def list_params(context: ServerContext):
|
def list_params(server: ServerContext):
|
||||||
return jsonify(get_config_params())
|
return jsonify(get_config_params())
|
||||||
|
|
||||||
|
|
||||||
def list_platforms(context: ServerContext):
|
def list_platforms(server: ServerContext):
|
||||||
return jsonify([p.device for p in get_available_platforms()])
|
return jsonify([p.device for p in get_available_platforms()])
|
||||||
|
|
||||||
|
|
||||||
def list_schedulers(context: ServerContext):
|
def list_schedulers(server: ServerContext):
|
||||||
return jsonify(list(get_pipeline_schedulers().keys()))
|
return jsonify(list(get_pipeline_schedulers().keys()))
|
||||||
|
|
||||||
|
|
||||||
def img2img(context: ServerContext, pool: DevicePoolExecutor):
|
def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
source_file = request.files.get("source")
|
source_file = request.files.get("source")
|
||||||
if source_file is None:
|
if source_file is None:
|
||||||
return error_reply("source image is required")
|
return error_reply("source image is required")
|
||||||
|
|
||||||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||||
|
|
||||||
device, params, size = pipeline_from_request(context)
|
device, params, size = pipeline_from_request(server)
|
||||||
upscale = upscale_from_request()
|
upscale = upscale_from_request()
|
||||||
|
|
||||||
strength = get_and_clamp_float(
|
strength = get_and_clamp_float(
|
||||||
|
@ -156,7 +156,7 @@ def img2img(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
get_config_value("strength", "min"),
|
get_config_value("strength", "min"),
|
||||||
)
|
)
|
||||||
|
|
||||||
output = make_output_name(context, "img2img", params, size, extras=[strength])
|
output = make_output_name(server, "img2img", params, size, extras=[strength])
|
||||||
job_name = output[0]
|
job_name = output[0]
|
||||||
logger.info("img2img job queued for: %s", job_name)
|
logger.info("img2img job queued for: %s", job_name)
|
||||||
|
|
||||||
|
@ -164,7 +164,7 @@ def img2img(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
pool.submit(
|
pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
run_img2img_pipeline,
|
run_img2img_pipeline,
|
||||||
context,
|
server,
|
||||||
params,
|
params,
|
||||||
output,
|
output,
|
||||||
upscale,
|
upscale,
|
||||||
|
@ -176,19 +176,19 @@ def img2img(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
return jsonify(json_params(output, params, size, upscale=upscale))
|
return jsonify(json_params(output, params, size, upscale=upscale))
|
||||||
|
|
||||||
|
|
||||||
def txt2img(context: ServerContext, pool: DevicePoolExecutor):
|
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
device, params, size = pipeline_from_request(context)
|
device, params, size = pipeline_from_request(server)
|
||||||
upscale = upscale_from_request()
|
upscale = upscale_from_request()
|
||||||
highres = highres_from_request()
|
highres = highres_from_request()
|
||||||
|
|
||||||
output = make_output_name(context, "txt2img", params, size)
|
output = make_output_name(server, "txt2img", params, size)
|
||||||
job_name = output[0]
|
job_name = output[0]
|
||||||
logger.info("txt2img job queued for: %s", job_name)
|
logger.info("txt2img job queued for: %s", job_name)
|
||||||
|
|
||||||
pool.submit(
|
pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
run_txt2img_pipeline,
|
run_txt2img_pipeline,
|
||||||
context,
|
server,
|
||||||
params,
|
params,
|
||||||
size,
|
size,
|
||||||
output,
|
output,
|
||||||
|
@ -200,7 +200,7 @@ def txt2img(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
|
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
|
||||||
|
|
||||||
|
|
||||||
def inpaint(context: ServerContext, pool: DevicePoolExecutor):
|
def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
source_file = request.files.get("source")
|
source_file = request.files.get("source")
|
||||||
if source_file is None:
|
if source_file is None:
|
||||||
return error_reply("source image is required")
|
return error_reply("source image is required")
|
||||||
|
@ -212,7 +212,7 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||||
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
|
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
|
||||||
|
|
||||||
device, params, size = pipeline_from_request(context)
|
device, params, size = pipeline_from_request(server)
|
||||||
expand = border_from_request()
|
expand = border_from_request()
|
||||||
upscale = upscale_from_request()
|
upscale = upscale_from_request()
|
||||||
|
|
||||||
|
@ -224,7 +224,7 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
)
|
)
|
||||||
|
|
||||||
output = make_output_name(
|
output = make_output_name(
|
||||||
context,
|
server,
|
||||||
"inpaint",
|
"inpaint",
|
||||||
params,
|
params,
|
||||||
size,
|
size,
|
||||||
|
@ -247,7 +247,7 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
pool.submit(
|
pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
run_inpaint_pipeline,
|
run_inpaint_pipeline,
|
||||||
context,
|
server,
|
||||||
params,
|
params,
|
||||||
size,
|
size,
|
||||||
output,
|
output,
|
||||||
|
@ -265,17 +265,17 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
return jsonify(json_params(output, params, size, upscale=upscale, border=expand))
|
return jsonify(json_params(output, params, size, upscale=upscale, border=expand))
|
||||||
|
|
||||||
|
|
||||||
def upscale(context: ServerContext, pool: DevicePoolExecutor):
|
def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
source_file = request.files.get("source")
|
source_file = request.files.get("source")
|
||||||
if source_file is None:
|
if source_file is None:
|
||||||
return error_reply("source image is required")
|
return error_reply("source image is required")
|
||||||
|
|
||||||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||||
|
|
||||||
device, params, size = pipeline_from_request(context)
|
device, params, size = pipeline_from_request(server)
|
||||||
upscale = upscale_from_request()
|
upscale = upscale_from_request()
|
||||||
|
|
||||||
output = make_output_name(context, "upscale", params, size)
|
output = make_output_name(server, "upscale", params, size)
|
||||||
job_name = output[0]
|
job_name = output[0]
|
||||||
logger.info("upscale job queued for: %s", job_name)
|
logger.info("upscale job queued for: %s", job_name)
|
||||||
|
|
||||||
|
@ -283,7 +283,7 @@ def upscale(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
pool.submit(
|
pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
run_upscale_pipeline,
|
run_upscale_pipeline,
|
||||||
context,
|
server,
|
||||||
params,
|
params,
|
||||||
size,
|
size,
|
||||||
output,
|
output,
|
||||||
|
@ -295,7 +295,7 @@ def upscale(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
return jsonify(json_params(output, params, size, upscale=upscale))
|
return jsonify(json_params(output, params, size, upscale=upscale))
|
||||||
|
|
||||||
|
|
||||||
def chain(context: ServerContext, pool: DevicePoolExecutor):
|
def chain(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
|
"chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
|
||||||
)
|
)
|
||||||
|
@ -311,8 +311,8 @@ def chain(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
validate(data, schema)
|
validate(data, schema)
|
||||||
|
|
||||||
# get defaults from the regular parameters
|
# get defaults from the regular parameters
|
||||||
device, params, size = pipeline_from_request(context)
|
device, params, size = pipeline_from_request(server)
|
||||||
output = make_output_name(context, "chain", params, size)
|
output = make_output_name(server, "chain", params, size)
|
||||||
job_name = output[0]
|
job_name = output[0]
|
||||||
|
|
||||||
pipeline = ChainPipeline()
|
pipeline = ChainPipeline()
|
||||||
|
@ -371,7 +371,7 @@ def chain(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
pool.submit(
|
pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
pipeline,
|
pipeline,
|
||||||
context,
|
server,
|
||||||
params,
|
params,
|
||||||
empty_source,
|
empty_source,
|
||||||
output=output[0],
|
output=output[0],
|
||||||
|
@ -382,7 +382,7 @@ def chain(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
return jsonify(json_params(output, params, size))
|
return jsonify(json_params(output, params, size))
|
||||||
|
|
||||||
|
|
||||||
def blend(context: ServerContext, pool: DevicePoolExecutor):
|
def blend(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
mask_file = request.files.get("mask")
|
mask_file = request.files.get("mask")
|
||||||
if mask_file is None:
|
if mask_file is None:
|
||||||
return error_reply("mask image is required")
|
return error_reply("mask image is required")
|
||||||
|
@ -402,17 +402,17 @@ def blend(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
source = valid_image(source, mask.size, mask.size)
|
source = valid_image(source, mask.size, mask.size)
|
||||||
sources.append(source)
|
sources.append(source)
|
||||||
|
|
||||||
device, params, size = pipeline_from_request(context)
|
device, params, size = pipeline_from_request(server)
|
||||||
upscale = upscale_from_request()
|
upscale = upscale_from_request()
|
||||||
|
|
||||||
output = make_output_name(context, "upscale", params, size)
|
output = make_output_name(server, "upscale", params, size)
|
||||||
job_name = output[0]
|
job_name = output[0]
|
||||||
logger.info("upscale job queued for: %s", job_name)
|
logger.info("upscale job queued for: %s", job_name)
|
||||||
|
|
||||||
pool.submit(
|
pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
run_blend_pipeline,
|
run_blend_pipeline,
|
||||||
context,
|
server,
|
||||||
params,
|
params,
|
||||||
size,
|
size,
|
||||||
output,
|
output,
|
||||||
|
@ -425,17 +425,17 @@ def blend(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
return jsonify(json_params(output, params, size, upscale=upscale))
|
return jsonify(json_params(output, params, size, upscale=upscale))
|
||||||
|
|
||||||
|
|
||||||
def txt2txt(context: ServerContext, pool: DevicePoolExecutor):
|
def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
device, params, size = pipeline_from_request(context)
|
device, params, size = pipeline_from_request(server)
|
||||||
|
|
||||||
output = make_output_name(context, "txt2txt", params, size)
|
output = make_output_name(server, "txt2txt", params, size)
|
||||||
job_name = output[0]
|
job_name = output[0]
|
||||||
logger.info("upscale job queued for: %s", job_name)
|
logger.info("upscale job queued for: %s", job_name)
|
||||||
|
|
||||||
pool.submit(
|
pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
run_txt2txt_pipeline,
|
run_txt2txt_pipeline,
|
||||||
context,
|
server,
|
||||||
params,
|
params,
|
||||||
size,
|
size,
|
||||||
output,
|
output,
|
||||||
|
@ -445,7 +445,7 @@ def txt2txt(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
return jsonify(json_params(output, params, size))
|
return jsonify(json_params(output, params, size))
|
||||||
|
|
||||||
|
|
||||||
def cancel(context: ServerContext, pool: DevicePoolExecutor):
|
def cancel(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
output_file = request.args.get("output", None)
|
output_file = request.args.get("output", None)
|
||||||
if output_file is None:
|
if output_file is None:
|
||||||
return error_reply("output name is required")
|
return error_reply("output name is required")
|
||||||
|
@ -456,7 +456,7 @@ def cancel(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
return ready_reply(cancelled=cancelled)
|
return ready_reply(cancelled=cancelled)
|
||||||
|
|
||||||
|
|
||||||
def ready(context: ServerContext, pool: DevicePoolExecutor):
|
def ready(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
output_file = request.args.get("output", None)
|
output_file = request.args.get("output", None)
|
||||||
if output_file is None:
|
if output_file is None:
|
||||||
return error_reply("output name is required")
|
return error_reply("output name is required")
|
||||||
|
@ -468,7 +468,7 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
return ready_reply(pending=True)
|
return ready_reply(pending=True)
|
||||||
|
|
||||||
if progress is None:
|
if progress is None:
|
||||||
output = base_join(context.output_path, output_file)
|
output = base_join(server.output_path, output_file)
|
||||||
if path.exists(output):
|
if path.exists(output):
|
||||||
return ready_reply(ready=True)
|
return ready_reply(ready=True)
|
||||||
else:
|
else:
|
||||||
|
@ -485,44 +485,44 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def status(context: ServerContext, pool: DevicePoolExecutor):
|
def status(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
return jsonify(pool.status())
|
return jsonify(pool.status())
|
||||||
|
|
||||||
|
|
||||||
def register_api_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor):
|
def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
|
||||||
return [
|
return [
|
||||||
app.route("/api")(wrap_route(introspect, context, app=app)),
|
app.route("/api")(wrap_route(introspect, server, app=app)),
|
||||||
app.route("/api/settings/masks")(wrap_route(list_mask_filters, context)),
|
app.route("/api/settings/masks")(wrap_route(list_mask_filters, server)),
|
||||||
app.route("/api/settings/models")(wrap_route(list_models, context)),
|
app.route("/api/settings/models")(wrap_route(list_models, server)),
|
||||||
app.route("/api/settings/noises")(wrap_route(list_noise_sources, context)),
|
app.route("/api/settings/noises")(wrap_route(list_noise_sources, server)),
|
||||||
app.route("/api/settings/params")(wrap_route(list_params, context)),
|
app.route("/api/settings/params")(wrap_route(list_params, server)),
|
||||||
app.route("/api/settings/platforms")(wrap_route(list_platforms, context)),
|
app.route("/api/settings/platforms")(wrap_route(list_platforms, server)),
|
||||||
app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)),
|
app.route("/api/settings/schedulers")(wrap_route(list_schedulers, server)),
|
||||||
app.route("/api/settings/strings")(wrap_route(list_extra_strings, context)),
|
app.route("/api/settings/strings")(wrap_route(list_extra_strings, server)),
|
||||||
app.route("/api/img2img", methods=["POST"])(
|
app.route("/api/img2img", methods=["POST"])(
|
||||||
wrap_route(img2img, context, pool=pool)
|
wrap_route(img2img, server, pool=pool)
|
||||||
),
|
),
|
||||||
app.route("/api/txt2img", methods=["POST"])(
|
app.route("/api/txt2img", methods=["POST"])(
|
||||||
wrap_route(txt2img, context, pool=pool)
|
wrap_route(txt2img, server, pool=pool)
|
||||||
),
|
),
|
||||||
app.route("/api/txt2txt", methods=["POST"])(
|
app.route("/api/txt2txt", methods=["POST"])(
|
||||||
wrap_route(txt2txt, context, pool=pool)
|
wrap_route(txt2txt, server, pool=pool)
|
||||||
),
|
),
|
||||||
app.route("/api/inpaint", methods=["POST"])(
|
app.route("/api/inpaint", methods=["POST"])(
|
||||||
wrap_route(inpaint, context, pool=pool)
|
wrap_route(inpaint, server, pool=pool)
|
||||||
),
|
),
|
||||||
app.route("/api/upscale", methods=["POST"])(
|
app.route("/api/upscale", methods=["POST"])(
|
||||||
wrap_route(upscale, context, pool=pool)
|
wrap_route(upscale, server, pool=pool)
|
||||||
),
|
),
|
||||||
app.route("/api/chain", methods=["POST"])(
|
app.route("/api/chain", methods=["POST"])(
|
||||||
wrap_route(chain, context, pool=pool)
|
wrap_route(chain, server, pool=pool)
|
||||||
),
|
),
|
||||||
app.route("/api/blend", methods=["POST"])(
|
app.route("/api/blend", methods=["POST"])(
|
||||||
wrap_route(blend, context, pool=pool)
|
wrap_route(blend, server, pool=pool)
|
||||||
),
|
),
|
||||||
app.route("/api/cancel", methods=["PUT"])(
|
app.route("/api/cancel", methods=["PUT"])(
|
||||||
wrap_route(cancel, context, pool=pool)
|
wrap_route(cancel, server, pool=pool)
|
||||||
),
|
),
|
||||||
app.route("/api/ready")(wrap_route(ready, context, pool=pool)),
|
app.route("/api/ready")(wrap_route(ready, server, pool=pool)),
|
||||||
app.route("/api/status")(wrap_route(status, context, pool=pool)),
|
app.route("/api/status")(wrap_route(status, server, pool=pool)),
|
||||||
]
|
]
|
||||||
|
|
|
@ -118,13 +118,13 @@ def patch_not_impl():
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str:
|
def patch_cache_path(server: ServerContext, url: str, **kwargs) -> str:
|
||||||
cache_path = cache_path_map.get(url, None)
|
cache_path = cache_path_map.get(url, None)
|
||||||
if cache_path is None:
|
if cache_path is None:
|
||||||
parsed = urlparse(url)
|
parsed = urlparse(url)
|
||||||
cache_path = path.basename(parsed.path)
|
cache_path = path.basename(parsed.path)
|
||||||
|
|
||||||
cache_path = path.join(ctx.cache_path, cache_path)
|
cache_path = path.join(server.cache_path, cache_path)
|
||||||
logger.debug("patching download path: %s -> %s", url, cache_path)
|
logger.debug("patching download path: %s -> %s", url, cache_path)
|
||||||
|
|
||||||
if path.exists(cache_path):
|
if path.exists(cache_path):
|
||||||
|
@ -133,27 +133,27 @@ def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str:
|
||||||
raise FileNotFoundError("missing cache file: %s" % (cache_path))
|
raise FileNotFoundError("missing cache file: %s" % (cache_path))
|
||||||
|
|
||||||
|
|
||||||
def apply_patch_basicsr(ctx: ServerContext):
|
def apply_patch_basicsr(server: ServerContext):
|
||||||
logger.debug("patching BasicSR module")
|
logger.debug("patching BasicSR module")
|
||||||
basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
|
basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
|
||||||
basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, ctx)
|
basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, server)
|
||||||
|
|
||||||
|
|
||||||
def apply_patch_codeformer(ctx: ServerContext):
|
def apply_patch_codeformer(server: ServerContext):
|
||||||
logger.debug("patching CodeFormer module")
|
logger.debug("patching CodeFormer module")
|
||||||
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
|
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
|
||||||
codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, ctx)
|
codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, server)
|
||||||
|
|
||||||
|
|
||||||
def apply_patch_facexlib(ctx: ServerContext):
|
def apply_patch_facexlib(server: ServerContext):
|
||||||
logger.debug("patching Facexlib module")
|
logger.debug("patching Facexlib module")
|
||||||
facexlib.utils.load_file_from_url = partial(patch_cache_path, ctx)
|
facexlib.utils.load_file_from_url = partial(patch_cache_path, server)
|
||||||
|
|
||||||
|
|
||||||
def apply_patches(ctx: ServerContext):
|
def apply_patches(server: ServerContext):
|
||||||
apply_patch_basicsr(ctx)
|
apply_patch_basicsr(server)
|
||||||
apply_patch_codeformer(ctx)
|
apply_patch_codeformer(server)
|
||||||
apply_patch_facexlib(ctx)
|
apply_patch_facexlib(server)
|
||||||
unload(
|
unload(
|
||||||
[
|
[
|
||||||
"basicsr.utils.download_util",
|
"basicsr.utils.download_util",
|
||||||
|
|
|
@ -115,7 +115,7 @@ def get_config_value(key: str, subkey: str = "default", default=None):
|
||||||
return config_params.get(key, {}).get(subkey, default)
|
return config_params.get(key, {}).get(subkey, default)
|
||||||
|
|
||||||
|
|
||||||
def load_extras(context: ServerContext):
|
def load_extras(server: ServerContext):
|
||||||
"""
|
"""
|
||||||
Load the extras file(s) and collect the relevant parts for the server: labels and strings
|
Load the extras file(s) and collect the relevant parts for the server: labels and strings
|
||||||
"""
|
"""
|
||||||
|
@ -127,7 +127,7 @@ def load_extras(context: ServerContext):
|
||||||
with open("./schemas/extras.yaml", "r") as f:
|
with open("./schemas/extras.yaml", "r") as f:
|
||||||
extra_schema = safe_load(f.read())
|
extra_schema = safe_load(f.read())
|
||||||
|
|
||||||
for file in context.extra_models:
|
for file in server.extra_models:
|
||||||
if file is not None and file != "":
|
if file is not None and file != "":
|
||||||
logger.info("loading extra models from %s", file)
|
logger.info("loading extra models from %s", file)
|
||||||
try:
|
try:
|
||||||
|
@ -209,11 +209,11 @@ IGNORE_EXTENSIONS = [".crdownload", ".lock", ".tmp"]
|
||||||
|
|
||||||
|
|
||||||
def list_model_globs(
|
def list_model_globs(
|
||||||
context: ServerContext, globs: List[str], base_path: Optional[str] = None
|
server: ServerContext, globs: List[str], base_path: Optional[str] = None
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
models = []
|
models = []
|
||||||
for pattern in globs:
|
for pattern in globs:
|
||||||
pattern_path = path.join(base_path or context.model_path, pattern)
|
pattern_path = path.join(base_path or server.model_path, pattern)
|
||||||
logger.debug("loading models from %s", pattern_path)
|
logger.debug("loading models from %s", pattern_path)
|
||||||
for name in glob(pattern_path):
|
for name in glob(pattern_path):
|
||||||
base = path.basename(name)
|
base = path.basename(name)
|
||||||
|
@ -226,7 +226,7 @@ def list_model_globs(
|
||||||
return unique_models
|
return unique_models
|
||||||
|
|
||||||
|
|
||||||
def load_models(context: ServerContext) -> None:
|
def load_models(server: ServerContext) -> None:
|
||||||
global correction_models
|
global correction_models
|
||||||
global diffusion_models
|
global diffusion_models
|
||||||
global network_models
|
global network_models
|
||||||
|
@ -234,7 +234,7 @@ def load_models(context: ServerContext) -> None:
|
||||||
|
|
||||||
# main categories
|
# main categories
|
||||||
diffusion_models = list_model_globs(
|
diffusion_models = list_model_globs(
|
||||||
context,
|
server,
|
||||||
[
|
[
|
||||||
"diffusion-*",
|
"diffusion-*",
|
||||||
"stable-diffusion-*",
|
"stable-diffusion-*",
|
||||||
|
@ -243,7 +243,7 @@ def load_models(context: ServerContext) -> None:
|
||||||
logger.debug("loaded diffusion models from disk: %s", diffusion_models)
|
logger.debug("loaded diffusion models from disk: %s", diffusion_models)
|
||||||
|
|
||||||
correction_models = list_model_globs(
|
correction_models = list_model_globs(
|
||||||
context,
|
server,
|
||||||
[
|
[
|
||||||
"correction-*",
|
"correction-*",
|
||||||
],
|
],
|
||||||
|
@ -251,7 +251,7 @@ def load_models(context: ServerContext) -> None:
|
||||||
logger.debug("loaded correction models from disk: %s", correction_models)
|
logger.debug("loaded correction models from disk: %s", correction_models)
|
||||||
|
|
||||||
upscaling_models = list_model_globs(
|
upscaling_models = list_model_globs(
|
||||||
context,
|
server,
|
||||||
[
|
[
|
||||||
"upscaling-*",
|
"upscaling-*",
|
||||||
],
|
],
|
||||||
|
@ -260,11 +260,11 @@ def load_models(context: ServerContext) -> None:
|
||||||
|
|
||||||
# additional networks
|
# additional networks
|
||||||
inversion_models = list_model_globs(
|
inversion_models = list_model_globs(
|
||||||
context,
|
server,
|
||||||
[
|
[
|
||||||
"*",
|
"*",
|
||||||
],
|
],
|
||||||
base_path=path.join(context.model_path, "inversion"),
|
base_path=path.join(server.model_path, "inversion"),
|
||||||
)
|
)
|
||||||
logger.debug("loaded Textual Inversion models from disk: %s", inversion_models)
|
logger.debug("loaded Textual Inversion models from disk: %s", inversion_models)
|
||||||
network_models.extend(
|
network_models.extend(
|
||||||
|
@ -272,35 +272,35 @@ def load_models(context: ServerContext) -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
lora_models = list_model_globs(
|
lora_models = list_model_globs(
|
||||||
context,
|
server,
|
||||||
[
|
[
|
||||||
"*",
|
"*",
|
||||||
],
|
],
|
||||||
base_path=path.join(context.model_path, "lora"),
|
base_path=path.join(server.model_path, "lora"),
|
||||||
)
|
)
|
||||||
logger.debug("loaded LoRA models from disk: %s", lora_models)
|
logger.debug("loaded LoRA models from disk: %s", lora_models)
|
||||||
network_models.extend([NetworkModel(model, "lora") for model in lora_models])
|
network_models.extend([NetworkModel(model, "lora") for model in lora_models])
|
||||||
|
|
||||||
|
|
||||||
def load_params(context: ServerContext) -> None:
|
def load_params(server: ServerContext) -> None:
|
||||||
global config_params
|
global config_params
|
||||||
|
|
||||||
params_file = path.join(context.params_path, "params.json")
|
params_file = path.join(server.params_path, "params.json")
|
||||||
logger.debug("loading server parameters from file: %s", params_file)
|
logger.debug("loading server parameters from file: %s", params_file)
|
||||||
|
|
||||||
with open(params_file, "r") as f:
|
with open(params_file, "r") as f:
|
||||||
config_params = yaml.safe_load(f)
|
config_params = yaml.safe_load(f)
|
||||||
|
|
||||||
if "platform" in config_params and context.default_platform is not None:
|
if "platform" in config_params and server.default_platform is not None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"overriding default platform from environment: %s",
|
"overriding default platform from environment: %s",
|
||||||
context.default_platform,
|
server.default_platform,
|
||||||
)
|
)
|
||||||
config_platform = config_params.get("platform", {})
|
config_platform = config_params.get("platform", {})
|
||||||
config_platform["default"] = context.default_platform
|
config_platform["default"] = server.default_platform
|
||||||
|
|
||||||
|
|
||||||
def load_platforms(context: ServerContext) -> None:
|
def load_platforms(server: ServerContext) -> None:
|
||||||
global available_platforms
|
global available_platforms
|
||||||
|
|
||||||
providers = list(get_available_providers())
|
providers = list(get_available_providers())
|
||||||
|
@ -309,7 +309,7 @@ def load_platforms(context: ServerContext) -> None:
|
||||||
for potential in platform_providers:
|
for potential in platform_providers:
|
||||||
if (
|
if (
|
||||||
platform_providers[potential] in providers
|
platform_providers[potential] in providers
|
||||||
and potential not in context.block_platforms
|
and potential not in server.block_platforms
|
||||||
):
|
):
|
||||||
if potential == "cuda":
|
if potential == "cuda":
|
||||||
for i in range(torch.cuda.device_count()):
|
for i in range(torch.cuda.device_count()):
|
||||||
|
@ -317,16 +317,16 @@ def load_platforms(context: ServerContext) -> None:
|
||||||
"device_id": i,
|
"device_id": i,
|
||||||
}
|
}
|
||||||
|
|
||||||
if context.memory_limit is not None:
|
if server.memory_limit is not None:
|
||||||
options["arena_extend_strategy"] = "kSameAsRequested"
|
options["arena_extend_strategy"] = "kSameAsRequested"
|
||||||
options["gpu_mem_limit"] = context.memory_limit
|
options["gpu_mem_limit"] = server.memory_limit
|
||||||
|
|
||||||
available_platforms.append(
|
available_platforms.append(
|
||||||
DeviceParams(
|
DeviceParams(
|
||||||
potential,
|
potential,
|
||||||
platform_providers[potential],
|
platform_providers[potential],
|
||||||
options,
|
options,
|
||||||
context.optimizations,
|
server.optimizations,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -335,18 +335,18 @@ def load_platforms(context: ServerContext) -> None:
|
||||||
potential,
|
potential,
|
||||||
platform_providers[potential],
|
platform_providers[potential],
|
||||||
None,
|
None,
|
||||||
context.optimizations,
|
server.optimizations,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if context.any_platform:
|
if server.any_platform:
|
||||||
# the platform should be ignored when the job is scheduled, but set to CPU just in case
|
# the platform should be ignored when the job is scheduled, but set to CPU just in case
|
||||||
available_platforms.append(
|
available_platforms.append(
|
||||||
DeviceParams(
|
DeviceParams(
|
||||||
"any",
|
"any",
|
||||||
platform_providers["cpu"],
|
platform_providers["cpu"],
|
||||||
None,
|
None,
|
||||||
context.optimizations,
|
server.optimizations,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,7 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def pipeline_from_request(
|
def pipeline_from_request(
|
||||||
context: ServerContext,
|
server: ServerContext,
|
||||||
) -> Tuple[DeviceParams, ImageParams, Size]:
|
) -> Tuple[DeviceParams, ImageParams, Size]:
|
||||||
user = request.remote_addr
|
user = request.remote_addr
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ def pipeline_from_request(
|
||||||
# pipeline stuff
|
# pipeline stuff
|
||||||
lpw = get_not_empty(request.args, "lpw", "false") == "true"
|
lpw = get_not_empty(request.args, "lpw", "false") == "true"
|
||||||
model = get_not_empty(request.args, "model", get_config_value("model"))
|
model = get_not_empty(request.args, "model", get_config_value("model"))
|
||||||
model_path = get_model_path(context, model)
|
model_path = get_model_path(server, model)
|
||||||
scheduler = get_from_list(
|
scheduler = get_from_list(
|
||||||
request.args, "scheduler", list(pipeline_schedulers.keys())
|
request.args, "scheduler", list(pipeline_schedulers.keys())
|
||||||
)
|
)
|
||||||
|
|
|
@ -7,30 +7,30 @@ from .context import ServerContext
|
||||||
from .utils import wrap_route
|
from .utils import wrap_route
|
||||||
|
|
||||||
|
|
||||||
def serve_bundle_file(context: ServerContext, filename="index.html"):
|
def serve_bundle_file(server: ServerContext, filename="index.html"):
|
||||||
return send_from_directory(path.join("..", context.bundle_path), filename)
|
return send_from_directory(path.join("..", server.bundle_path), filename)
|
||||||
|
|
||||||
|
|
||||||
# non-API routes
|
# non-API routes
|
||||||
def index(context: ServerContext):
|
def index(server: ServerContext):
|
||||||
return serve_bundle_file(context)
|
return serve_bundle_file(server)
|
||||||
|
|
||||||
|
|
||||||
def index_path(context: ServerContext, filename: str):
|
def index_path(server: ServerContext, filename: str):
|
||||||
return serve_bundle_file(context, filename)
|
return serve_bundle_file(server, filename)
|
||||||
|
|
||||||
|
|
||||||
def output(context: ServerContext, filename: str):
|
def output(server: ServerContext, filename: str):
|
||||||
return send_from_directory(
|
return send_from_directory(
|
||||||
path.join("..", context.output_path), filename, as_attachment=False
|
path.join("..", server.output_path), filename, as_attachment=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def register_static_routes(
|
def register_static_routes(
|
||||||
app: Flask, context: ServerContext, _pool: DevicePoolExecutor
|
app: Flask, server: ServerContext, _pool: DevicePoolExecutor
|
||||||
):
|
):
|
||||||
return [
|
return [
|
||||||
app.route("/")(wrap_route(index, context)),
|
app.route("/")(wrap_route(index, server)),
|
||||||
app.route("/<path:filename>")(wrap_route(index_path, context)),
|
app.route("/<path:filename>")(wrap_route(index_path, server)),
|
||||||
app.route("/output/<path:filename>")(wrap_route(output, context)),
|
app.route("/output/<path:filename>")(wrap_route(output, server)),
|
||||||
]
|
]
|
||||||
|
|
|
@ -9,26 +9,26 @@ from ..worker.pool import DevicePoolExecutor
|
||||||
from .context import ServerContext
|
from .context import ServerContext
|
||||||
|
|
||||||
|
|
||||||
def check_paths(context: ServerContext) -> None:
|
def check_paths(server: ServerContext) -> None:
|
||||||
if not path.exists(context.model_path):
|
if not path.exists(server.model_path):
|
||||||
raise RuntimeError("model path must exist")
|
raise RuntimeError("model path must exist")
|
||||||
|
|
||||||
if not path.exists(context.output_path):
|
if not path.exists(server.output_path):
|
||||||
makedirs(context.output_path)
|
makedirs(server.output_path)
|
||||||
|
|
||||||
|
|
||||||
def get_model_path(context: ServerContext, model: str):
|
def get_model_path(server: ServerContext, model: str):
|
||||||
return base_join(context.model_path, model)
|
return base_join(server.model_path, model)
|
||||||
|
|
||||||
|
|
||||||
def register_routes(
|
def register_routes(
|
||||||
app: Flask,
|
app: Flask,
|
||||||
context: ServerContext,
|
server: ServerContext,
|
||||||
pool: DevicePoolExecutor,
|
pool: DevicePoolExecutor,
|
||||||
routes: List[Tuple[str, Dict, Callable]],
|
routes: List[Tuple[str, Dict, Callable]],
|
||||||
):
|
):
|
||||||
for route, kwargs, method in routes:
|
for route, kwargs, method in routes:
|
||||||
app.route(route, **kwargs)(wrap_route(method, context, pool=pool))
|
app.route(route, **kwargs)(wrap_route(method, server, pool=pool))
|
||||||
|
|
||||||
|
|
||||||
def wrap_route(func, *args, **kwargs):
|
def wrap_route(func, *args, **kwargs):
|
||||||
|
|
|
@ -26,51 +26,51 @@ MEMORY_ERRORS = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def worker_main(context: WorkerContext, server: ServerContext):
|
def worker_main(worker: WorkerContext, server: ServerContext):
|
||||||
apply_patches(server)
|
apply_patches(server)
|
||||||
setproctitle("onnx-web worker: %s" % (context.device.device))
|
setproctitle("onnx-web worker: %s" % (worker.device.device))
|
||||||
|
|
||||||
logger.trace(
|
logger.trace(
|
||||||
"checking in from worker with providers: %s", get_available_providers()
|
"checking in from worker with providers: %s", get_available_providers()
|
||||||
)
|
)
|
||||||
|
|
||||||
# make leaking workers easier to recycle
|
# make leaking workers easier to recycle
|
||||||
context.progress.cancel_join_thread()
|
worker.progress.cancel_join_thread()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
if not context.is_active():
|
if not worker.is_active():
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"worker %s has been replaced by %s, exiting",
|
"worker %s has been replaced by %s, exiting",
|
||||||
getpid(),
|
getpid(),
|
||||||
context.get_active(),
|
worker.get_active(),
|
||||||
)
|
)
|
||||||
exit(EXIT_REPLACED)
|
exit(EXIT_REPLACED)
|
||||||
|
|
||||||
# wait briefly for the next job
|
# wait briefly for the next job
|
||||||
job = context.pending.get(timeout=1.0)
|
job = worker.pending.get(timeout=1.0)
|
||||||
logger.info("worker %s got job: %s", context.device.device, job.name)
|
logger.info("worker %s got job: %s", worker.device.device, job.name)
|
||||||
|
|
||||||
# clear flags and save the job name
|
# clear flags and save the job name
|
||||||
context.start(job.name)
|
worker.start(job.name)
|
||||||
logger.info("starting job: %s", job.name)
|
logger.info("starting job: %s", job.name)
|
||||||
|
|
||||||
# reset progress, which does a final check for cancellation
|
# reset progress, which does a final check for cancellation
|
||||||
context.set_progress(0)
|
worker.set_progress(0)
|
||||||
job.fn(context, *job.args, **job.kwargs)
|
job.fn(worker, *job.args, **job.kwargs)
|
||||||
|
|
||||||
# confirm completion of the job
|
# confirm completion of the job
|
||||||
logger.info("job succeeded: %s", job.name)
|
logger.info("job succeeded: %s", job.name)
|
||||||
context.finish()
|
worker.finish()
|
||||||
except Empty:
|
except Empty:
|
||||||
pass
|
pass
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("worker got keyboard interrupt")
|
logger.info("worker got keyboard interrupt")
|
||||||
context.fail()
|
worker.fail()
|
||||||
exit(EXIT_INTERRUPT)
|
exit(EXIT_INTERRUPT)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.exception("value error in worker, exiting: %s")
|
logger.exception("value error in worker, exiting: %s")
|
||||||
context.fail()
|
worker.fail()
|
||||||
exit(EXIT_ERROR)
|
exit(EXIT_ERROR)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
e_str = str(e)
|
e_str = str(e)
|
||||||
|
@ -78,11 +78,11 @@ def worker_main(context: WorkerContext, server: ServerContext):
|
||||||
for e_mem in MEMORY_ERRORS:
|
for e_mem in MEMORY_ERRORS:
|
||||||
if e_mem in e_str:
|
if e_mem in e_str:
|
||||||
logger.error("detected out-of-memory error, exiting: %s", e)
|
logger.error("detected out-of-memory error, exiting: %s", e)
|
||||||
context.fail()
|
worker.fail()
|
||||||
exit(EXIT_MEMORY)
|
exit(EXIT_MEMORY)
|
||||||
|
|
||||||
# carry on for other errors
|
# carry on for other errors
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"unrecognized error while running job",
|
"unrecognized error while running job",
|
||||||
)
|
)
|
||||||
context.fail()
|
worker.fail()
|
||||||
|
|
|
@ -0,0 +1,23 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class TestHashValue(unittest.TestCase):
|
||||||
|
def test_hash_value(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestJSONParams(unittest.TestCase):
|
||||||
|
def test_json_params(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestMakeOutputName(unittest.TestCase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestSaveImage(unittest.TestCase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestSaveParams(unittest.TestCase):
|
||||||
|
pass
|
|
@ -0,0 +1,101 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from onnx_web.params import Border, Size
|
||||||
|
|
||||||
|
class BorderTests(unittest.TestCase):
|
||||||
|
def test_json(self):
|
||||||
|
border = Border.even(0)
|
||||||
|
json = border.tojson()
|
||||||
|
|
||||||
|
self.assertIn("left", json)
|
||||||
|
self.assertIn("right", json)
|
||||||
|
self.assertIn("top", json)
|
||||||
|
self.assertIn("bottom", json)
|
||||||
|
|
||||||
|
def test_str(self):
|
||||||
|
border = Border.even(10)
|
||||||
|
bstr = str(border)
|
||||||
|
|
||||||
|
self.assertEqual("(10, 10, 10, 10)", bstr)
|
||||||
|
|
||||||
|
def test_uneven(self):
|
||||||
|
border = Border(1, 2, 3, 4)
|
||||||
|
|
||||||
|
self.assertEqual("(1, 2, 3, 4)", str(border))
|
||||||
|
|
||||||
|
def test_args(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SizeTests(unittest.TestCase):
|
||||||
|
def test_iter(self):
|
||||||
|
size = Size(1, 2)
|
||||||
|
|
||||||
|
self.assertEqual(list(size), [1, 2])
|
||||||
|
|
||||||
|
def test_str(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_border(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_json(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_args(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceParamsTests(unittest.TestCase):
|
||||||
|
def test_str(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_provider(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_options_optimizations(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_options_cache(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_torch_cuda(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_torch_rocm(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ImageParamsTests(unittest.TestCase):
|
||||||
|
def test_json(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_args(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class StageParamsTests(unittest.TestCase):
|
||||||
|
def test_init(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UpscaleParamsTests(unittest.TestCase):
|
||||||
|
def test_rescale(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_resize(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_json(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_args(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class HighresParamsTests(unittest.TestCase):
|
||||||
|
def test_resize(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_json(self):
|
||||||
|
pass
|
Loading…
Reference in New Issue