1
0
Fork 0

lint(api): name context params consistently (#278)

This commit is contained in:
Sean Sube 2023-04-09 20:33:03 -05:00
parent fea9185707
commit 9698e29268
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
19 changed files with 363 additions and 238 deletions

View File

@ -23,7 +23,7 @@ class StageCallback(Protocol):
def __call__(
self,
job: WorkerContext,
ctx: ServerContext,
server: ServerContext,
stage: StageParams,
params: ImageParams,
source: Image.Image,

View File

@ -140,7 +140,7 @@ base_models: Models = {
def fetch_model(
ctx: ConversionContext,
conversion: ConversionContext,
name: str,
source: str,
dest: Optional[str] = None,
@ -148,7 +148,7 @@ def fetch_model(
hf_hub_fetch: bool = False,
hf_hub_filename: Optional[str] = None,
) -> str:
cache_path = dest or ctx.cache_path
cache_path = dest or conversion.cache_path
cache_name = path.join(cache_path, name)
# add an extension if possible, some of the conversion code checks for it
@ -201,7 +201,7 @@ def fetch_model(
return source
def convert_models(ctx: ConversionContext, args, models: Models):
def convert_models(conversion: ConversionContext, args, models: Models):
if args.sources and "sources" in models:
for model in models.get("sources"):
model = tuple_to_source(model)
@ -214,7 +214,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
source = model["source"]
try:
dest = fetch_model(ctx, name, source, format=model_format)
dest = fetch_model(conversion, name, source, format=model_format)
logger.info("finished downloading source: %s -> %s", source, dest)
except Exception:
logger.exception("error fetching source %s", name)
@ -234,20 +234,20 @@ def convert_models(ctx: ConversionContext, args, models: Models):
try:
if network_type == "inversion" and network_model == "concept":
dest = fetch_model(
ctx,
conversion,
name,
source,
dest=path.join(ctx.model_path, network_type),
dest=path.join(conversion.model_path, network_type),
format=network_format,
hf_hub_fetch=True,
hf_hub_filename="learned_embeds.bin",
)
else:
dest = fetch_model(
ctx,
conversion,
name,
source,
dest=path.join(ctx.model_path, network_type),
dest=path.join(conversion.model_path, network_type),
format=network_format,
)
@ -267,19 +267,19 @@ def convert_models(ctx: ConversionContext, args, models: Models):
try:
source = fetch_model(
ctx, name, model["source"], format=model_format
conversion, name, model["source"], format=model_format
)
converted = False
if model_format in model_formats_original:
converted, dest = convert_diffusion_original(
ctx,
conversion,
model,
source,
)
else:
converted, dest = convert_diffusion_diffusers(
ctx,
conversion,
model,
source,
)
@ -289,8 +289,8 @@ def convert_models(ctx: ConversionContext, args, models: Models):
# keep track of which models have been blended
blend_models = {}
inversion_dest = path.join(ctx.model_path, "inversion")
lora_dest = path.join(ctx.model_path, "lora")
inversion_dest = path.join(conversion.model_path, "inversion")
lora_dest = path.join(conversion.model_path, "lora")
for inversion in model.get("inversions", []):
if "text_encoder" not in blend_models:
@ -314,7 +314,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
inversion_source = inversion["source"]
inversion_format = inversion.get("format", None)
inversion_source = fetch_model(
ctx,
conversion,
inversion_name,
inversion_source,
dest=inversion_dest,
@ -323,7 +323,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
inversion_weight = inversion.get("weight", 1.0)
blend_textual_inversions(
ctx,
conversion,
blend_models["text_encoder"],
blend_models["tokenizer"],
[
@ -355,7 +355,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
lora_name = lora["name"]
lora_source = lora["source"]
lora_source = fetch_model(
ctx,
conversion,
f"{name}-lora-{lora_name}",
lora_source,
dest=lora_dest,
@ -363,14 +363,14 @@ def convert_models(ctx: ConversionContext, args, models: Models):
lora_weight = lora.get("weight", 1.0)
blend_loras(
ctx,
conversion,
blend_models["text_encoder"],
[(lora_source, lora_weight)],
"text_encoder",
)
blend_loras(
ctx,
conversion,
blend_models["unet"],
[(lora_source, lora_weight)],
"unet",
@ -413,9 +413,9 @@ def convert_models(ctx: ConversionContext, args, models: Models):
try:
source = fetch_model(
ctx, name, model["source"], format=model_format
conversion, name, model["source"], format=model_format
)
convert_upscale_resrgan(ctx, model, source)
convert_upscale_resrgan(conversion, model, source)
except Exception:
logger.exception(
"error converting upscaling model %s",
@ -433,9 +433,9 @@ def convert_models(ctx: ConversionContext, args, models: Models):
model_format = source_format(model)
try:
source = fetch_model(
ctx, name, model["source"], format=model_format
conversion, name, model["source"], format=model_format
)
convert_correction_gfpgan(ctx, model, source)
convert_correction_gfpgan(conversion, model, source)
except Exception:
logger.exception(
"error converting correction model %s",
@ -482,21 +482,21 @@ def main() -> int:
args = parser.parse_args()
logger.info("CLI arguments: %s", args)
ctx = ConversionContext.from_environ()
ctx.half = args.half or "onnx-fp16" in ctx.optimizations
ctx.opset = args.opset
ctx.token = args.token
logger.info("converting models in %s using %s", ctx.model_path, ctx.training_device)
server = ConversionContext.from_environ()
server.half = args.half or "onnx-fp16" in server.optimizations
server.opset = args.opset
server.token = args.token
logger.info("converting models in %s using %s", server.model_path, server.training_device)
if not path.exists(ctx.model_path):
logger.info("model path does not existing, creating: %s", ctx.model_path)
makedirs(ctx.model_path)
if not path.exists(server.model_path):
logger.info("model path does not existing, creating: %s", server.model_path)
makedirs(server.model_path)
logger.info("converting base models")
convert_models(ctx, args, base_models)
convert_models(server, args, base_models)
extras = []
extras.extend(ctx.extra_models)
extras.extend(server.extra_models)
extras.extend(args.extras)
extras = list(set(extras))
extras.sort()
@ -516,7 +516,7 @@ def main() -> int:
try:
validate(data, extra_schema)
logger.info("converting extra models")
convert_models(ctx, args, data)
convert_models(server, args, data)
except ValidationError:
logger.exception("invalid data in extras file")
except Exception:

View File

@ -12,7 +12,7 @@ logger = getLogger(__name__)
@torch.no_grad()
def convert_correction_gfpgan(
ctx: ConversionContext,
conversion: ConversionContext,
model: ModelDict,
source: str,
):
@ -20,7 +20,7 @@ def convert_correction_gfpgan(
source = source or model.get("source")
scale = model.get("scale")
dest = path.join(ctx.model_path, name + ".onnx")
dest = path.join(conversion.model_path, name + ".onnx")
logger.info("converting GFPGAN model: %s -> %s", name, dest)
if path.isfile(dest):
@ -37,17 +37,17 @@ def convert_correction_gfpgan(
scale=scale,
)
torch_model = torch.load(source, map_location=ctx.map_location)
torch_model = torch.load(source, map_location=conversion.map_location)
# TODO: make sure strict=False is safe here
if "params_ema" in torch_model:
model.load_state_dict(torch_model["params_ema"], strict=False)
else:
model.load_state_dict(torch_model["params"], strict=False)
model.to(ctx.training_device).train(False)
model.to(conversion.training_device).train(False)
model.eval()
rng = torch.rand(1, 3, 64, 64, device=ctx.map_location)
rng = torch.rand(1, 3, 64, 64, device=conversion.map_location)
input_names = ["data"]
output_names = ["output"]
dynamic_axes = {
@ -63,7 +63,7 @@ def convert_correction_gfpgan(
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=ctx.opset,
opset_version=conversion.opset,
export_params=True,
)
logger.info("GFPGAN exported to ONNX successfully")

View File

@ -90,7 +90,7 @@ def onnx_export(
@torch.no_grad()
def convert_diffusion_diffusers(
ctx: ConversionContext,
conversion: ConversionContext,
model: Dict,
source: str,
) -> Tuple[bool, str]:
@ -102,10 +102,10 @@ def convert_diffusion_diffusers(
single_vae = model.get("single_vae")
replace_vae = model.get("vae")
dtype = ctx.torch_dtype()
dtype = conversion.torch_dtype()
logger.debug("using Torch dtype %s for pipeline", dtype)
dest_path = path.join(ctx.model_path, name)
dest_path = path.join(conversion.model_path, name)
model_index = path.join(dest_path, "model_index.json")
# diffusers go into a directory rather than .onnx file
@ -123,11 +123,11 @@ def convert_diffusion_diffusers(
pipeline = StableDiffusionPipeline.from_pretrained(
source,
torch_dtype=dtype,
use_auth_token=ctx.token,
).to(ctx.training_device)
use_auth_token=conversion.token,
).to(conversion.training_device)
output_path = Path(dest_path)
optimize_pipeline(ctx, pipeline)
optimize_pipeline(conversion, pipeline)
# TEXT ENCODER
num_tokens = pipeline.text_encoder.config.max_position_embeddings
@ -143,11 +143,11 @@ def convert_diffusion_diffusers(
pipeline.text_encoder,
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
model_args=(
text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32),
text_input.input_ids.to(device=conversion.training_device, dtype=torch.int32),
None, # attention mask
None, # position ids
None, # output attentions
torch.tensor(True).to(device=ctx.training_device, dtype=torch.bool),
torch.tensor(True).to(device=conversion.training_device, dtype=torch.bool),
),
output_path=output_path / "text_encoder" / ONNX_MODEL,
ordered_input_names=["input_ids"],
@ -155,8 +155,8 @@ def convert_diffusion_diffusers(
dynamic_axes={
"input_ids": {0: "batch", 1: "sequence"},
},
opset=ctx.opset,
half=ctx.half,
opset=conversion.opset,
half=conversion.half,
)
del pipeline.text_encoder
@ -165,11 +165,11 @@ def convert_diffusion_diffusers(
# UNET
if single_vae:
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.long)
unet_scale = torch.tensor(4).to(device=conversion.training_device, dtype=torch.long)
else:
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
unet_scale = torch.tensor(False).to(
device=ctx.training_device, dtype=torch.bool
device=conversion.training_device, dtype=torch.bool
)
if is_torch_2_0:
@ -182,11 +182,11 @@ def convert_diffusion_diffusers(
pipeline.unet,
model_args=(
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
device=ctx.training_device, dtype=dtype
device=conversion.training_device, dtype=dtype
),
torch.randn(2).to(device=ctx.training_device, dtype=dtype),
torch.randn(2).to(device=conversion.training_device, dtype=dtype),
torch.randn(2, num_tokens, text_hidden_size).to(
device=ctx.training_device, dtype=dtype
device=conversion.training_device, dtype=dtype
),
unet_scale,
),
@ -199,8 +199,8 @@ def convert_diffusion_diffusers(
"timestep": {0: "batch"},
"encoder_hidden_states": {0: "batch", 1: "sequence"},
},
opset=ctx.opset,
half=ctx.half,
opset=conversion.opset,
half=conversion.half,
external_data=True,
)
unet_model_path = str(unet_path.absolute().as_posix())
@ -238,7 +238,7 @@ def convert_diffusion_diffusers(
model_args=(
torch.randn(
1, vae_latent_channels, unet_sample_size, unet_sample_size
).to(device=ctx.training_device, dtype=dtype),
).to(device=conversion.training_device, dtype=dtype),
False,
),
output_path=output_path / "vae" / ONNX_MODEL,
@ -247,8 +247,8 @@ def convert_diffusion_diffusers(
dynamic_axes={
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
},
opset=ctx.opset,
half=ctx.half,
opset=conversion.opset,
half=conversion.half,
)
else:
# VAE ENCODER
@ -263,7 +263,7 @@ def convert_diffusion_diffusers(
vae_encoder,
model_args=(
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
device=ctx.training_device, dtype=dtype
device=conversion.training_device, dtype=dtype
),
False,
),
@ -273,7 +273,7 @@ def convert_diffusion_diffusers(
dynamic_axes={
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
},
opset=ctx.opset,
opset=conversion.opset,
half=False, # https://github.com/ssube/onnx-web/issues/290
)
@ -287,7 +287,7 @@ def convert_diffusion_diffusers(
model_args=(
torch.randn(
1, vae_latent_channels, unet_sample_size, unet_sample_size
).to(device=ctx.training_device, dtype=dtype),
).to(device=conversion.training_device, dtype=dtype),
False,
),
output_path=output_path / "vae_decoder" / ONNX_MODEL,
@ -296,8 +296,8 @@ def convert_diffusion_diffusers(
dynamic_axes={
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
},
opset=ctx.opset,
half=ctx.half,
opset=conversion.opset,
half=conversion.half,
)
del pipeline.vae

View File

@ -55,7 +55,7 @@ def fix_node_name(key: str):
def blend_loras(
_context: ServerContext,
_conversion: ServerContext,
base_name: Union[str, ModelProto],
loras: List[Tuple[str, float]],
model_type: Literal["text_encoder", "unet"],

View File

@ -146,7 +146,7 @@ class TrainingConfig:
def __init__(
self,
ctx: ConversionContext,
conversion: ConversionContext,
model_name: str = "",
scheduler: str = "ddim",
v2: bool = False,
@ -155,7 +155,7 @@ class TrainingConfig:
**kwargs,
):
model_name = sanitize_name(model_name)
model_dir = os.path.join(ctx.cache_path, model_name)
model_dir = os.path.join(conversion.cache_path, model_name)
working_dir = os.path.join(model_dir, "working")
if not os.path.exists(working_dir):
@ -1298,7 +1298,7 @@ def download_model(db_config: TrainingConfig, token):
def get_config_path(
context: ConversionContext,
conversion: ConversionContext,
model_version: str = "v1",
train_type: str = "default",
config_base_name: str = "training",
@ -1309,7 +1309,7 @@ def get_config_path(
)
parts = os.path.join(
context.model_path,
conversion.model_path,
"configs",
f"{model_version}-{config_base_name}-{train_type}.yaml",
)
@ -1317,7 +1317,7 @@ def get_config_path(
def get_config_file(
context: ConversionContext,
conversion: ConversionContext,
train_unfrozen=False,
v2=False,
prediction_type="epsilon",
@ -1343,7 +1343,7 @@ def get_config_file(
model_train_type = train_types["default"]
return get_config_path(
context,
conversion,
model_version_name,
model_train_type,
config_base_name,
@ -1352,7 +1352,7 @@ def get_config_file(
def extract_checkpoint(
context: ConversionContext,
conversion: ConversionContext,
new_model_name: str,
checkpoint_file: str,
scheduler_type="ddim",
@ -1396,7 +1396,7 @@ def extract_checkpoint(
# Create empty config
db_config = TrainingConfig(
context,
conversion,
model_name=new_model_name,
scheduler=scheduler_type,
src=checkpoint_file,
@ -1442,7 +1442,7 @@ def extract_checkpoint(
prediction_type = "epsilon"
original_config_file = get_config_file(
context, train_unfrozen, v2, prediction_type, config_file=config_file
conversion, train_unfrozen, v2, prediction_type, config_file=config_file
)
logger.info(
@ -1533,7 +1533,7 @@ def extract_checkpoint(
checkpoint, vae_config
)
else:
vae_file = os.path.join(context.model_path, vae_file)
vae_file = os.path.join(conversion.model_path, vae_file)
logger.debug("loading custom VAE: %s", vae_file)
vae_checkpoint = load_tensor(vae_file, map_location=map_location)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
@ -1658,14 +1658,14 @@ def extract_checkpoint(
@torch.no_grad()
def convert_diffusion_original(
ctx: ConversionContext,
conversion: ConversionContext,
model: ModelDict,
source: str,
) -> Tuple[bool, str]:
name = model["name"]
source = source or model["source"]
dest_path = os.path.join(ctx.model_path, name)
dest_path = os.path.join(conversion.model_path, name)
dest_index = os.path.join(dest_path, "model_index.json")
logger.info(
"converting original Diffusers checkpoint %s: %s -> %s", name, source, dest_path
@ -1676,8 +1676,8 @@ def convert_diffusion_original(
return (False, dest_path)
torch_name = name + "-torch"
torch_path = os.path.join(ctx.cache_path, torch_name)
working_name = os.path.join(ctx.cache_path, torch_name, "working")
torch_path = os.path.join(conversion.cache_path, torch_name)
working_name = os.path.join(conversion.cache_path, torch_name, "working")
model_index = os.path.join(working_name, "model_index.json")
if os.path.exists(torch_path) and os.path.exists(model_index):
@ -1689,7 +1689,7 @@ def convert_diffusion_original(
torch_path,
)
if extract_checkpoint(
ctx,
conversion,
torch_name,
source,
config_file=model.get("config"),
@ -1704,9 +1704,9 @@ def convert_diffusion_original(
if "vae" in model:
del model["vae"]
result = convert_diffusion_diffusers(ctx, model, working_name)
result = convert_diffusion_diffusers(conversion, model, working_name)
if "torch" in ctx.prune:
if "torch" in conversion.prune:
logger.info("removing intermediate Torch models: %s", torch_path)
shutil.rmtree(torch_path)

View File

@ -16,7 +16,7 @@ logger = getLogger(__name__)
@torch.no_grad()
def blend_textual_inversions(
context: ServerContext,
server: ServerContext,
text_encoder: ModelProto,
tokenizer: CLIPTokenizer,
inversions: List[Tuple[str, float, Optional[str], Optional[str]]],
@ -161,7 +161,7 @@ def blend_textual_inversions(
@torch.no_grad()
def convert_diffusion_textual_inversion(
context: ConversionContext,
conversion: ConversionContext,
name: str,
base_model: str,
inversion: str,
@ -169,7 +169,7 @@ def convert_diffusion_textual_inversion(
base_token: Optional[str] = None,
inversion_weight: Optional[float] = 1.0,
):
dest_path = path.join(context.model_path, f"inversion-{name}")
dest_path = path.join(conversion.model_path, f"inversion-{name}")
logger.info(
"converting Textual Inversion: %s + %s -> %s", base_model, inversion, dest_path
)
@ -194,7 +194,7 @@ def convert_diffusion_textual_inversion(
subfolder="tokenizer",
)
text_encoder, tokenizer = blend_textual_inversions(
context,
conversion,
text_encoder,
tokenizer,
[(inversion, inversion_weight, base_token, inversion_format)],

View File

@ -13,7 +13,7 @@ TAG_X4_V3 = "real-esrgan-x4-v3"
@torch.no_grad()
def convert_upscale_resrgan(
ctx: ConversionContext,
conversion: ConversionContext,
model: ModelDict,
source: str,
):
@ -24,7 +24,7 @@ def convert_upscale_resrgan(
source = source or model.get("source")
scale = model.get("scale")
dest = path.join(ctx.model_path, name + ".onnx")
dest = path.join(conversion.model_path, name + ".onnx")
logger.info("converting Real ESRGAN model: %s -> %s", name, dest)
if path.isfile(dest):
@ -53,16 +53,16 @@ def convert_upscale_resrgan(
scale=scale,
)
torch_model = torch.load(source, map_location=ctx.map_location)
torch_model = torch.load(source, map_location=conversion.map_location)
if "params_ema" in torch_model:
model.load_state_dict(torch_model["params_ema"])
else:
model.load_state_dict(torch_model["params"], strict=False)
model.to(ctx.training_device).train(False)
model.to(conversion.training_device).train(False)
model.eval()
rng = torch.rand(1, 3, 64, 64, device=ctx.map_location)
rng = torch.rand(1, 3, 64, 64, device=conversion.map_location)
input_names = ["data"]
output_names = ["output"]
dynamic_axes = {
@ -78,7 +78,7 @@ def convert_upscale_resrgan(
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=ctx.opset,
opset_version=conversion.opset,
export_params=True,
)
logger.info("real ESRGAN exported to ONNX successfully")

View File

@ -64,7 +64,7 @@ def json_params(
def make_output_name(
ctx: ServerContext,
server: ServerContext,
mode: str,
params: ImageParams,
size: Size,
@ -92,20 +92,20 @@ def make_output_name(
hash_value(sha, param)
return [
f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{ctx.image_format}"
f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}"
for i in range(params.batch)
]
def save_image(ctx: ServerContext, output: str, image: Image.Image) -> str:
path = base_join(ctx.output_path, output)
image.save(path, format=ctx.image_format)
def save_image(server: ServerContext, output: str, image: Image.Image) -> str:
path = base_join(server.output_path, output)
image.save(path, format=server.image_format)
logger.debug("saved output image to: %s", path)
return path
def save_params(
ctx: ServerContext,
server: ServerContext,
output: str,
params: ImageParams,
size: Size,
@ -113,7 +113,7 @@ def save_params(
border: Optional[Border] = None,
highres: Optional[HighresParams] = None,
) -> str:
path = base_join(ctx.output_path, f"{output}.json")
path = base_join(server.output_path, f"{output}.json")
json = json_params(
output, params, size, upscale=upscale, border=border, highres=highres
)

View File

@ -7,6 +7,10 @@ from .torch_before_ort import GraphOptimizationLevel, SessionOptions
logger = getLogger(__name__)
Param = Union[str, int, float]
Point = Tuple[int, int]
class SizeChart(IntEnum):
mini = 128 # small tile for very expensive models
half = 256 # half tile for outpainting
@ -25,10 +29,6 @@ class TileOrder:
spiral = "spiral"
Param = Union[str, int, float]
Point = Tuple[int, int]
class Border:
def __init__(self, left: int, right: int, top: int, bottom: int) -> None:
self.left = left
@ -37,7 +37,7 @@ class Border:
self.bottom = bottom
def __str__(self) -> str:
return "%s %s %s %s" % (self.left, self.top, self.right, self.bottom)
return "(%s, %s, %s, %s)" % (self.left, self.top, self.right, self.bottom)
def tojson(self):
return {
@ -145,6 +145,7 @@ class DeviceParams:
return sess
def torch_str(self) -> str:
# TODO: return cuda devices for ROCm as well
if self.device.startswith("cuda"):
return self.device
else:

View File

@ -93,7 +93,7 @@ def url_from_rule(rule) -> str:
return url_for(rule.endpoint, **options)
def introspect(context: ServerContext, app: Flask):
def introspect(server: ServerContext, app: Flask):
return {
"name": "onnx-web",
"routes": [
@ -103,15 +103,15 @@ def introspect(context: ServerContext, app: Flask):
}
def list_extra_strings(context: ServerContext):
def list_extra_strings(server: ServerContext):
return jsonify(get_extra_strings())
def list_mask_filters(context: ServerContext):
def list_mask_filters(server: ServerContext):
return jsonify(list(get_mask_filters().keys()))
def list_models(context: ServerContext):
def list_models(server: ServerContext):
return jsonify(
{
"correction": get_correction_models(),
@ -122,30 +122,30 @@ def list_models(context: ServerContext):
)
def list_noise_sources(context: ServerContext):
def list_noise_sources(server: ServerContext):
return jsonify(list(get_noise_sources().keys()))
def list_params(context: ServerContext):
def list_params(server: ServerContext):
return jsonify(get_config_params())
def list_platforms(context: ServerContext):
def list_platforms(server: ServerContext):
return jsonify([p.device for p in get_available_platforms()])
def list_schedulers(context: ServerContext):
def list_schedulers(server: ServerContext):
return jsonify(list(get_pipeline_schedulers().keys()))
def img2img(context: ServerContext, pool: DevicePoolExecutor):
def img2img(server: ServerContext, pool: DevicePoolExecutor):
source_file = request.files.get("source")
if source_file is None:
return error_reply("source image is required")
source = Image.open(BytesIO(source_file.read())).convert("RGB")
device, params, size = pipeline_from_request(context)
device, params, size = pipeline_from_request(server)
upscale = upscale_from_request()
strength = get_and_clamp_float(
@ -156,7 +156,7 @@ def img2img(context: ServerContext, pool: DevicePoolExecutor):
get_config_value("strength", "min"),
)
output = make_output_name(context, "img2img", params, size, extras=[strength])
output = make_output_name(server, "img2img", params, size, extras=[strength])
job_name = output[0]
logger.info("img2img job queued for: %s", job_name)
@ -164,7 +164,7 @@ def img2img(context: ServerContext, pool: DevicePoolExecutor):
pool.submit(
job_name,
run_img2img_pipeline,
context,
server,
params,
output,
upscale,
@ -176,19 +176,19 @@ def img2img(context: ServerContext, pool: DevicePoolExecutor):
return jsonify(json_params(output, params, size, upscale=upscale))
def txt2img(context: ServerContext, pool: DevicePoolExecutor):
device, params, size = pipeline_from_request(context)
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
device, params, size = pipeline_from_request(server)
upscale = upscale_from_request()
highres = highres_from_request()
output = make_output_name(context, "txt2img", params, size)
output = make_output_name(server, "txt2img", params, size)
job_name = output[0]
logger.info("txt2img job queued for: %s", job_name)
pool.submit(
job_name,
run_txt2img_pipeline,
context,
server,
params,
size,
output,
@ -200,7 +200,7 @@ def txt2img(context: ServerContext, pool: DevicePoolExecutor):
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
def inpaint(context: ServerContext, pool: DevicePoolExecutor):
def inpaint(server: ServerContext, pool: DevicePoolExecutor):
source_file = request.files.get("source")
if source_file is None:
return error_reply("source image is required")
@ -212,7 +212,7 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):
source = Image.open(BytesIO(source_file.read())).convert("RGB")
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
device, params, size = pipeline_from_request(context)
device, params, size = pipeline_from_request(server)
expand = border_from_request()
upscale = upscale_from_request()
@ -224,7 +224,7 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):
)
output = make_output_name(
context,
server,
"inpaint",
params,
size,
@ -247,7 +247,7 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):
pool.submit(
job_name,
run_inpaint_pipeline,
context,
server,
params,
size,
output,
@ -265,17 +265,17 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):
return jsonify(json_params(output, params, size, upscale=upscale, border=expand))
def upscale(context: ServerContext, pool: DevicePoolExecutor):
def upscale(server: ServerContext, pool: DevicePoolExecutor):
source_file = request.files.get("source")
if source_file is None:
return error_reply("source image is required")
source = Image.open(BytesIO(source_file.read())).convert("RGB")
device, params, size = pipeline_from_request(context)
device, params, size = pipeline_from_request(server)
upscale = upscale_from_request()
output = make_output_name(context, "upscale", params, size)
output = make_output_name(server, "upscale", params, size)
job_name = output[0]
logger.info("upscale job queued for: %s", job_name)
@ -283,7 +283,7 @@ def upscale(context: ServerContext, pool: DevicePoolExecutor):
pool.submit(
job_name,
run_upscale_pipeline,
context,
server,
params,
size,
output,
@ -295,7 +295,7 @@ def upscale(context: ServerContext, pool: DevicePoolExecutor):
return jsonify(json_params(output, params, size, upscale=upscale))
def chain(context: ServerContext, pool: DevicePoolExecutor):
def chain(server: ServerContext, pool: DevicePoolExecutor):
logger.debug(
"chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
)
@ -311,8 +311,8 @@ def chain(context: ServerContext, pool: DevicePoolExecutor):
validate(data, schema)
# get defaults from the regular parameters
device, params, size = pipeline_from_request(context)
output = make_output_name(context, "chain", params, size)
device, params, size = pipeline_from_request(server)
output = make_output_name(server, "chain", params, size)
job_name = output[0]
pipeline = ChainPipeline()
@ -371,7 +371,7 @@ def chain(context: ServerContext, pool: DevicePoolExecutor):
pool.submit(
job_name,
pipeline,
context,
server,
params,
empty_source,
output=output[0],
@ -382,7 +382,7 @@ def chain(context: ServerContext, pool: DevicePoolExecutor):
return jsonify(json_params(output, params, size))
def blend(context: ServerContext, pool: DevicePoolExecutor):
def blend(server: ServerContext, pool: DevicePoolExecutor):
mask_file = request.files.get("mask")
if mask_file is None:
return error_reply("mask image is required")
@ -402,17 +402,17 @@ def blend(context: ServerContext, pool: DevicePoolExecutor):
source = valid_image(source, mask.size, mask.size)
sources.append(source)
device, params, size = pipeline_from_request(context)
device, params, size = pipeline_from_request(server)
upscale = upscale_from_request()
output = make_output_name(context, "upscale", params, size)
output = make_output_name(server, "upscale", params, size)
job_name = output[0]
logger.info("upscale job queued for: %s", job_name)
pool.submit(
job_name,
run_blend_pipeline,
context,
server,
params,
size,
output,
@ -425,17 +425,17 @@ def blend(context: ServerContext, pool: DevicePoolExecutor):
return jsonify(json_params(output, params, size, upscale=upscale))
def txt2txt(context: ServerContext, pool: DevicePoolExecutor):
device, params, size = pipeline_from_request(context)
def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
device, params, size = pipeline_from_request(server)
output = make_output_name(context, "txt2txt", params, size)
output = make_output_name(server, "txt2txt", params, size)
job_name = output[0]
logger.info("upscale job queued for: %s", job_name)
pool.submit(
job_name,
run_txt2txt_pipeline,
context,
server,
params,
size,
output,
@ -445,7 +445,7 @@ def txt2txt(context: ServerContext, pool: DevicePoolExecutor):
return jsonify(json_params(output, params, size))
def cancel(context: ServerContext, pool: DevicePoolExecutor):
def cancel(server: ServerContext, pool: DevicePoolExecutor):
output_file = request.args.get("output", None)
if output_file is None:
return error_reply("output name is required")
@ -456,7 +456,7 @@ def cancel(context: ServerContext, pool: DevicePoolExecutor):
return ready_reply(cancelled=cancelled)
def ready(context: ServerContext, pool: DevicePoolExecutor):
def ready(server: ServerContext, pool: DevicePoolExecutor):
output_file = request.args.get("output", None)
if output_file is None:
return error_reply("output name is required")
@ -468,7 +468,7 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
return ready_reply(pending=True)
if progress is None:
output = base_join(context.output_path, output_file)
output = base_join(server.output_path, output_file)
if path.exists(output):
return ready_reply(ready=True)
else:
@ -485,44 +485,44 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
)
def status(context: ServerContext, pool: DevicePoolExecutor):
def status(server: ServerContext, pool: DevicePoolExecutor):
return jsonify(pool.status())
def register_api_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor):
def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
return [
app.route("/api")(wrap_route(introspect, context, app=app)),
app.route("/api/settings/masks")(wrap_route(list_mask_filters, context)),
app.route("/api/settings/models")(wrap_route(list_models, context)),
app.route("/api/settings/noises")(wrap_route(list_noise_sources, context)),
app.route("/api/settings/params")(wrap_route(list_params, context)),
app.route("/api/settings/platforms")(wrap_route(list_platforms, context)),
app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)),
app.route("/api/settings/strings")(wrap_route(list_extra_strings, context)),
app.route("/api")(wrap_route(introspect, server, app=app)),
app.route("/api/settings/masks")(wrap_route(list_mask_filters, server)),
app.route("/api/settings/models")(wrap_route(list_models, server)),
app.route("/api/settings/noises")(wrap_route(list_noise_sources, server)),
app.route("/api/settings/params")(wrap_route(list_params, server)),
app.route("/api/settings/platforms")(wrap_route(list_platforms, server)),
app.route("/api/settings/schedulers")(wrap_route(list_schedulers, server)),
app.route("/api/settings/strings")(wrap_route(list_extra_strings, server)),
app.route("/api/img2img", methods=["POST"])(
wrap_route(img2img, context, pool=pool)
wrap_route(img2img, server, pool=pool)
),
app.route("/api/txt2img", methods=["POST"])(
wrap_route(txt2img, context, pool=pool)
wrap_route(txt2img, server, pool=pool)
),
app.route("/api/txt2txt", methods=["POST"])(
wrap_route(txt2txt, context, pool=pool)
wrap_route(txt2txt, server, pool=pool)
),
app.route("/api/inpaint", methods=["POST"])(
wrap_route(inpaint, context, pool=pool)
wrap_route(inpaint, server, pool=pool)
),
app.route("/api/upscale", methods=["POST"])(
wrap_route(upscale, context, pool=pool)
wrap_route(upscale, server, pool=pool)
),
app.route("/api/chain", methods=["POST"])(
wrap_route(chain, context, pool=pool)
wrap_route(chain, server, pool=pool)
),
app.route("/api/blend", methods=["POST"])(
wrap_route(blend, context, pool=pool)
wrap_route(blend, server, pool=pool)
),
app.route("/api/cancel", methods=["PUT"])(
wrap_route(cancel, context, pool=pool)
wrap_route(cancel, server, pool=pool)
),
app.route("/api/ready")(wrap_route(ready, context, pool=pool)),
app.route("/api/status")(wrap_route(status, context, pool=pool)),
app.route("/api/ready")(wrap_route(ready, server, pool=pool)),
app.route("/api/status")(wrap_route(status, server, pool=pool)),
]

View File

@ -118,13 +118,13 @@ def patch_not_impl():
raise NotImplementedError()
def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str:
def patch_cache_path(server: ServerContext, url: str, **kwargs) -> str:
cache_path = cache_path_map.get(url, None)
if cache_path is None:
parsed = urlparse(url)
cache_path = path.basename(parsed.path)
cache_path = path.join(ctx.cache_path, cache_path)
cache_path = path.join(server.cache_path, cache_path)
logger.debug("patching download path: %s -> %s", url, cache_path)
if path.exists(cache_path):
@ -133,27 +133,27 @@ def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str:
raise FileNotFoundError("missing cache file: %s" % (cache_path))
def apply_patch_basicsr(ctx: ServerContext):
def apply_patch_basicsr(server: ServerContext):
logger.debug("patching BasicSR module")
basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, ctx)
basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, server)
def apply_patch_codeformer(ctx: ServerContext):
def apply_patch_codeformer(server: ServerContext):
logger.debug("patching CodeFormer module")
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, ctx)
codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, server)
def apply_patch_facexlib(ctx: ServerContext):
def apply_patch_facexlib(server: ServerContext):
logger.debug("patching Facexlib module")
facexlib.utils.load_file_from_url = partial(patch_cache_path, ctx)
facexlib.utils.load_file_from_url = partial(patch_cache_path, server)
def apply_patches(ctx: ServerContext):
apply_patch_basicsr(ctx)
apply_patch_codeformer(ctx)
apply_patch_facexlib(ctx)
def apply_patches(server: ServerContext):
apply_patch_basicsr(server)
apply_patch_codeformer(server)
apply_patch_facexlib(server)
unload(
[
"basicsr.utils.download_util",

View File

@ -115,7 +115,7 @@ def get_config_value(key: str, subkey: str = "default", default=None):
return config_params.get(key, {}).get(subkey, default)
def load_extras(context: ServerContext):
def load_extras(server: ServerContext):
"""
Load the extras file(s) and collect the relevant parts for the server: labels and strings
"""
@ -127,7 +127,7 @@ def load_extras(context: ServerContext):
with open("./schemas/extras.yaml", "r") as f:
extra_schema = safe_load(f.read())
for file in context.extra_models:
for file in server.extra_models:
if file is not None and file != "":
logger.info("loading extra models from %s", file)
try:
@ -209,11 +209,11 @@ IGNORE_EXTENSIONS = [".crdownload", ".lock", ".tmp"]
def list_model_globs(
context: ServerContext, globs: List[str], base_path: Optional[str] = None
server: ServerContext, globs: List[str], base_path: Optional[str] = None
) -> List[str]:
models = []
for pattern in globs:
pattern_path = path.join(base_path or context.model_path, pattern)
pattern_path = path.join(base_path or server.model_path, pattern)
logger.debug("loading models from %s", pattern_path)
for name in glob(pattern_path):
base = path.basename(name)
@ -226,7 +226,7 @@ def list_model_globs(
return unique_models
def load_models(context: ServerContext) -> None:
def load_models(server: ServerContext) -> None:
global correction_models
global diffusion_models
global network_models
@ -234,7 +234,7 @@ def load_models(context: ServerContext) -> None:
# main categories
diffusion_models = list_model_globs(
context,
server,
[
"diffusion-*",
"stable-diffusion-*",
@ -243,7 +243,7 @@ def load_models(context: ServerContext) -> None:
logger.debug("loaded diffusion models from disk: %s", diffusion_models)
correction_models = list_model_globs(
context,
server,
[
"correction-*",
],
@ -251,7 +251,7 @@ def load_models(context: ServerContext) -> None:
logger.debug("loaded correction models from disk: %s", correction_models)
upscaling_models = list_model_globs(
context,
server,
[
"upscaling-*",
],
@ -260,11 +260,11 @@ def load_models(context: ServerContext) -> None:
# additional networks
inversion_models = list_model_globs(
context,
server,
[
"*",
],
base_path=path.join(context.model_path, "inversion"),
base_path=path.join(server.model_path, "inversion"),
)
logger.debug("loaded Textual Inversion models from disk: %s", inversion_models)
network_models.extend(
@ -272,35 +272,35 @@ def load_models(context: ServerContext) -> None:
)
lora_models = list_model_globs(
context,
server,
[
"*",
],
base_path=path.join(context.model_path, "lora"),
base_path=path.join(server.model_path, "lora"),
)
logger.debug("loaded LoRA models from disk: %s", lora_models)
network_models.extend([NetworkModel(model, "lora") for model in lora_models])
def load_params(context: ServerContext) -> None:
def load_params(server: ServerContext) -> None:
global config_params
params_file = path.join(context.params_path, "params.json")
params_file = path.join(server.params_path, "params.json")
logger.debug("loading server parameters from file: %s", params_file)
with open(params_file, "r") as f:
config_params = yaml.safe_load(f)
if "platform" in config_params and context.default_platform is not None:
if "platform" in config_params and server.default_platform is not None:
logger.info(
"overriding default platform from environment: %s",
context.default_platform,
server.default_platform,
)
config_platform = config_params.get("platform", {})
config_platform["default"] = context.default_platform
config_platform["default"] = server.default_platform
def load_platforms(context: ServerContext) -> None:
def load_platforms(server: ServerContext) -> None:
global available_platforms
providers = list(get_available_providers())
@ -309,7 +309,7 @@ def load_platforms(context: ServerContext) -> None:
for potential in platform_providers:
if (
platform_providers[potential] in providers
and potential not in context.block_platforms
and potential not in server.block_platforms
):
if potential == "cuda":
for i in range(torch.cuda.device_count()):
@ -317,16 +317,16 @@ def load_platforms(context: ServerContext) -> None:
"device_id": i,
}
if context.memory_limit is not None:
if server.memory_limit is not None:
options["arena_extend_strategy"] = "kSameAsRequested"
options["gpu_mem_limit"] = context.memory_limit
options["gpu_mem_limit"] = server.memory_limit
available_platforms.append(
DeviceParams(
potential,
platform_providers[potential],
options,
context.optimizations,
server.optimizations,
)
)
else:
@ -335,18 +335,18 @@ def load_platforms(context: ServerContext) -> None:
potential,
platform_providers[potential],
None,
context.optimizations,
server.optimizations,
)
)
if context.any_platform:
if server.any_platform:
# the platform should be ignored when the job is scheduled, but set to CPU just in case
available_platforms.append(
DeviceParams(
"any",
platform_providers["cpu"],
None,
context.optimizations,
server.optimizations,
)
)

View File

@ -28,7 +28,7 @@ logger = getLogger(__name__)
def pipeline_from_request(
context: ServerContext,
server: ServerContext,
) -> Tuple[DeviceParams, ImageParams, Size]:
user = request.remote_addr
@ -44,7 +44,7 @@ def pipeline_from_request(
# pipeline stuff
lpw = get_not_empty(request.args, "lpw", "false") == "true"
model = get_not_empty(request.args, "model", get_config_value("model"))
model_path = get_model_path(context, model)
model_path = get_model_path(server, model)
scheduler = get_from_list(
request.args, "scheduler", list(pipeline_schedulers.keys())
)

View File

@ -7,30 +7,30 @@ from .context import ServerContext
from .utils import wrap_route
def serve_bundle_file(context: ServerContext, filename="index.html"):
return send_from_directory(path.join("..", context.bundle_path), filename)
def serve_bundle_file(server: ServerContext, filename="index.html"):
return send_from_directory(path.join("..", server.bundle_path), filename)
# non-API routes
def index(context: ServerContext):
return serve_bundle_file(context)
def index(server: ServerContext):
return serve_bundle_file(server)
def index_path(context: ServerContext, filename: str):
return serve_bundle_file(context, filename)
def index_path(server: ServerContext, filename: str):
return serve_bundle_file(server, filename)
def output(context: ServerContext, filename: str):
def output(server: ServerContext, filename: str):
return send_from_directory(
path.join("..", context.output_path), filename, as_attachment=False
path.join("..", server.output_path), filename, as_attachment=False
)
def register_static_routes(
app: Flask, context: ServerContext, _pool: DevicePoolExecutor
app: Flask, server: ServerContext, _pool: DevicePoolExecutor
):
return [
app.route("/")(wrap_route(index, context)),
app.route("/<path:filename>")(wrap_route(index_path, context)),
app.route("/output/<path:filename>")(wrap_route(output, context)),
app.route("/")(wrap_route(index, server)),
app.route("/<path:filename>")(wrap_route(index_path, server)),
app.route("/output/<path:filename>")(wrap_route(output, server)),
]

View File

@ -9,26 +9,26 @@ from ..worker.pool import DevicePoolExecutor
from .context import ServerContext
def check_paths(context: ServerContext) -> None:
if not path.exists(context.model_path):
def check_paths(server: ServerContext) -> None:
if not path.exists(server.model_path):
raise RuntimeError("model path must exist")
if not path.exists(context.output_path):
makedirs(context.output_path)
if not path.exists(server.output_path):
makedirs(server.output_path)
def get_model_path(context: ServerContext, model: str):
return base_join(context.model_path, model)
def get_model_path(server: ServerContext, model: str):
return base_join(server.model_path, model)
def register_routes(
app: Flask,
context: ServerContext,
server: ServerContext,
pool: DevicePoolExecutor,
routes: List[Tuple[str, Dict, Callable]],
):
for route, kwargs, method in routes:
app.route(route, **kwargs)(wrap_route(method, context, pool=pool))
app.route(route, **kwargs)(wrap_route(method, server, pool=pool))
def wrap_route(func, *args, **kwargs):

View File

@ -26,51 +26,51 @@ MEMORY_ERRORS = [
]
def worker_main(context: WorkerContext, server: ServerContext):
def worker_main(worker: WorkerContext, server: ServerContext):
apply_patches(server)
setproctitle("onnx-web worker: %s" % (context.device.device))
setproctitle("onnx-web worker: %s" % (worker.device.device))
logger.trace(
"checking in from worker with providers: %s", get_available_providers()
)
# make leaking workers easier to recycle
context.progress.cancel_join_thread()
worker.progress.cancel_join_thread()
while True:
try:
if not context.is_active():
if not worker.is_active():
logger.warning(
"worker %s has been replaced by %s, exiting",
getpid(),
context.get_active(),
worker.get_active(),
)
exit(EXIT_REPLACED)
# wait briefly for the next job
job = context.pending.get(timeout=1.0)
logger.info("worker %s got job: %s", context.device.device, job.name)
job = worker.pending.get(timeout=1.0)
logger.info("worker %s got job: %s", worker.device.device, job.name)
# clear flags and save the job name
context.start(job.name)
worker.start(job.name)
logger.info("starting job: %s", job.name)
# reset progress, which does a final check for cancellation
context.set_progress(0)
job.fn(context, *job.args, **job.kwargs)
worker.set_progress(0)
job.fn(worker, *job.args, **job.kwargs)
# confirm completion of the job
logger.info("job succeeded: %s", job.name)
context.finish()
worker.finish()
except Empty:
pass
except KeyboardInterrupt:
logger.info("worker got keyboard interrupt")
context.fail()
worker.fail()
exit(EXIT_INTERRUPT)
except ValueError:
logger.exception("value error in worker, exiting: %s")
context.fail()
worker.fail()
exit(EXIT_ERROR)
except Exception as e:
e_str = str(e)
@ -78,11 +78,11 @@ def worker_main(context: WorkerContext, server: ServerContext):
for e_mem in MEMORY_ERRORS:
if e_mem in e_str:
logger.error("detected out-of-memory error, exiting: %s", e)
context.fail()
worker.fail()
exit(EXIT_MEMORY)
# carry on for other errors
logger.exception(
"unrecognized error while running job",
)
context.fail()
worker.fail()

23
api/tests/test_output.py Normal file
View File

@ -0,0 +1,23 @@
import unittest
class TestHashValue(unittest.TestCase):
def test_hash_value(self):
pass
class TestJSONParams(unittest.TestCase):
def test_json_params(self):
pass
class TestMakeOutputName(unittest.TestCase):
pass
class TestSaveImage(unittest.TestCase):
pass
class TestSaveParams(unittest.TestCase):
pass

101
api/tests/test_params.py Normal file
View File

@ -0,0 +1,101 @@
import unittest
from onnx_web.params import Border, Size
class BorderTests(unittest.TestCase):
def test_json(self):
border = Border.even(0)
json = border.tojson()
self.assertIn("left", json)
self.assertIn("right", json)
self.assertIn("top", json)
self.assertIn("bottom", json)
def test_str(self):
border = Border.even(10)
bstr = str(border)
self.assertEqual("(10, 10, 10, 10)", bstr)
def test_uneven(self):
border = Border(1, 2, 3, 4)
self.assertEqual("(1, 2, 3, 4)", str(border))
def test_args(self):
pass
class SizeTests(unittest.TestCase):
def test_iter(self):
size = Size(1, 2)
self.assertEqual(list(size), [1, 2])
def test_str(self):
pass
def test_border(self):
pass
def test_json(self):
pass
def test_args(self):
pass
class DeviceParamsTests(unittest.TestCase):
def test_str(self):
pass
def test_provider(self):
pass
def test_options_optimizations(self):
pass
def test_options_cache(self):
pass
def test_torch_cuda(self):
pass
def test_torch_rocm(self):
pass
class ImageParamsTests(unittest.TestCase):
def test_json(self):
pass
def test_args(self):
pass
class StageParamsTests(unittest.TestCase):
def test_init(self):
pass
class UpscaleParamsTests(unittest.TestCase):
def test_rescale(self):
pass
def test_resize(self):
pass
def test_json(self):
pass
def test_args(self):
pass
class HighresParamsTests(unittest.TestCase):
def test_resize(self):
pass
def test_json(self):
pass