diff --git a/README.md b/README.md index a1148944..e45b1688 100644 --- a/README.md +++ b/README.md @@ -373,7 +373,7 @@ have added your own. ### Test the models You should verify that all of the steps up to this point have worked correctly by attempting to run the -`api/test-diffusers.py` script, which is a slight variation on the original txt2img script. +`api/scripts/test-diffusers.py` script, which is a slight variation on the original txt2img script. If the script works, there will be an image of an astronaut in `outputs/test.png`. diff --git a/api/onnx_web/convert/diffusion/textual_inversion.py b/api/onnx_web/convert/diffusion/textual_inversion.py index 447e7dda..d0b5d2fb 100644 --- a/api/onnx_web/convert/diffusion/textual_inversion.py +++ b/api/onnx_web/convert/diffusion/textual_inversion.py @@ -95,7 +95,9 @@ def blend_textual_inversions( # add the tokens to the tokenizer logger.debug( - "found embeddings for %s tokens: %s", len(embeds.keys()), list(embeds.keys()) + "found embeddings for %s tokens: %s", + len(embeds.keys()), + list(embeds.keys()), ) num_added_tokens = tokenizer.add_tokens(list(embeds.keys())) if num_added_tokens == 0: diff --git a/api/scripts/check-env.py b/api/scripts/check-env.py new file mode 100644 index 00000000..774ffd6d --- /dev/null +++ b/api/scripts/check-env.py @@ -0,0 +1,62 @@ +from importlib.metadata import version +from typing import List + +REQUIRED_MODULES = ["onnx", "diffusers", "safetensors", "torch"] + +REQUIRED_RUNTIME = [ + "onnxruntime", + "onnxruntime_gpu", + "onnxruntime_rocm", + "onnxruntime_training", +] + + +def check_modules() -> List[str]: + results = [] + for name in REQUIRED_MODULES: + try: + __import__(name) + module_version = version(name) + results.append( + f"required module {name} is present at version {module_version}" + ) + except ImportError as e: + results.append(f"unable to import required module {name}: {e}") + + return results + + +def check_providers() -> List[str]: + results = [] + try: + import onnxruntime + import torch + + available = onnxruntime.get_available_providers() + for provider in onnxruntime.get_all_providers(): + if provider in available: + results.append(f"onnxruntime provider {provider} is available") + else: + results.append(f"onnxruntime provider {provider} is missing") + except Exception as e: + results.append(f"unable to check runtime providers: {e}") + + return results + + +ALL_CHECKS = [ + check_modules, + check_providers, +] + + +def check_all(): + results = [] + for check in ALL_CHECKS: + results.extend(check()) + + print(results) + + +if __name__ == "__main__": + check_all() diff --git a/api/scripts/check-model.py b/api/scripts/check-model.py new file mode 100644 index 00000000..ddf9f007 --- /dev/null +++ b/api/scripts/check-model.py @@ -0,0 +1,124 @@ +from collections import Counter +from logging import getLogger +from os import path +from sys import argv +from typing import Callable, List + +import torch +from onnx import load_model +from safetensors import safe_open + +import onnx_web + +logger = getLogger(__name__) + +CheckFunc = Callable[[str], List[str]] + + +def check_file_extension(filename: str) -> List[str]: + """ + Check the file extension + """ + _name, ext = path.splitext(filename) + ext = ext.removeprefix(".") + + if ext != "": + return [f"format:{ext}"] + + return [] + + +def check_file_diffusers(filename: str) -> List[str]: + """ + Check for a diffusers directory with model_index.json + """ + if path.isdir(filename) and path.exists(path.join(filename, "model_index.json")): + return ["model:diffusion"] + + return [] + + +def check_parser_safetensor(filename: str) -> List[str]: + """ + Attempt to load as a safetensor + """ + try: + if path.isfile(filename): + # only try to parse files + safe_open(filename, framework="pt") + return ["format:safetensor"] + except Exception as e: + logger.debug("error parsing as safetensor: %s", e) + + return [] + + +def check_parser_torch(filename: str) -> List[str]: + """ + Attempt to load as a torch tensor + """ + try: + if path.isfile(filename): + # only try to parse files + torch.load(filename) + return ["format:torch"] + except Exception as e: + logger.debug("error parsing as torch tensor: %s", e) + + return [] + + +def check_parser_onnx(filename: str) -> List[str]: + """ + Attempt to load as an ONNX model + """ + try: + if path.isfile(filename): + load_model(filename) + return ["format:onnx"] + except Exception as e: + logger.debug("error parsing as ONNX model: %s", e) + + return [] + + +def check_network_lora(filename: str) -> List[str]: + """ + TODO: Check for LoRA keys + """ + return [] + + +def check_network_inversion(filename: str) -> List[str]: + """ + TODO: Check for Textual Inversion keys + """ + return [] + + +ALL_CHECKS: List[CheckFunc] = [ + check_file_diffusers, + check_file_extension, + check_network_inversion, + check_network_lora, + check_parser_onnx, + check_parser_safetensor, + check_parser_torch, +] + + +def check_file(filename: str) -> Counter: + logger.info("checking file: %s", filename) + + counts = Counter() + for check in ALL_CHECKS: + logger.info("running check: %s", check.__name__) + counts.update(check(filename)) + + common = counts.most_common() + logger.info("file %s is most likely: %s", filename, common) + + +if __name__ == "__main__": + for file in argv[1:]: + check_file(file) diff --git a/api/scripts/model-guess.py b/api/scripts/model-guess.py deleted file mode 100644 index 81f18869..00000000 --- a/api/scripts/model-guess.py +++ /dev/null @@ -1,114 +0,0 @@ -from collections import Counter -from logging import getLogger -from onnx import load_model -from os import path -from safetensors import safe_open -from sys import argv -from typing import Callable, List - -import onnx_web -import torch - -logger = getLogger(__name__) - -CheckFunc = Callable[[str], List[str]] - -def check_file_extension(filename: str) -> List[str]: - """ - Check the file extension - """ - _name, ext = path.splitext(filename) - ext = ext.removeprefix(".") - - if ext != "": - return [f"format:{ext}"] - - return [] - -def check_file_diffusers(filename: str) -> List[str]: - """ - Check for a diffusers directory with model_index.json - """ - if path.isdir(filename) and path.exists(path.join(filename, "model_index.json")): - return ["model:diffusion"] - - return [] - -def check_parser_safetensor(filename: str) -> List[str]: - """ - Attempt to load as a safetensor - """ - try: - if path.isfile(filename): - # only try to parse files - safe_open(filename, framework="pt") - return ["format:safetensor"] - except Exception as e: - logger.debug("error parsing as safetensor: %s", e) - - return [] - -def check_parser_torch(filename: str) -> List[str]: - """ - Attempt to load as a torch tensor - """ - try: - if path.isfile(filename): - # only try to parse files - torch.load(filename) - return ["format:torch"] - except Exception as e: - logger.debug("error parsing as torch tensor: %s", e) - - return [] - -def check_parser_onnx(filename: str) -> List[str]: - """ - Attempt to load as an ONNX model - """ - try: - if path.isfile(filename): - load_model(filename) - return ["format:onnx"] - except Exception as e: - logger.debug("error parsing as ONNX model: %s", e) - - return [] - -def check_network_lora(filename: str) -> List[str]: - """ - TODO: Check for LoRA keys - """ - return [] - -def check_network_inversion(filename: str) -> List[str]: - """ - TODO: Check for Textual Inversion keys - """ - return [] - - -ALL_CHECKS: List[CheckFunc] = [ - check_file_diffusers, - check_file_extension, - check_network_inversion, - check_network_lora, - check_parser_onnx, - check_parser_safetensor, - check_parser_torch, -] - -def check_file(filename: str) -> Counter: - logger.info("checking file: %s", filename) - - counts = Counter() - for check in ALL_CHECKS: - logger.info("running check: %s", check.__name__) - counts.update(check(filename)) - - common = counts.most_common() - logger.info("file %s is most likely: %s", filename, common) - -if __name__ == "__main__": - for file in argv[1:]: - check_file(file) \ No newline at end of file diff --git a/api/test-diffusers.py b/api/scripts/test-diffusers.py similarity index 81% rename from api/test-diffusers.py rename to api/scripts/test-diffusers.py index 5852eb6b..27b018f5 100644 --- a/api/test-diffusers.py +++ b/api/scripts/test-diffusers.py @@ -12,9 +12,9 @@ steps = 22 height = 512 width = 512 -model = path.join('..', 'models', 'stable-diffusion-onnx-v1-5') +model = path.join('..', '..', 'models', 'stable-diffusion-onnx-v1-5') prompt = 'an astronaut eating a hamburger' -output = path.join('..', 'outputs', 'test.png') +output = path.join('..', '..', 'outputs', 'test.png') print('generating test image...') pipe = OnnxStableDiffusionPipeline.from_pretrained(model, provider='DmlExecutionProvider', safety_checker=None) diff --git a/api/test-resrgan.py b/api/scripts/test-resrgan.py similarity index 86% rename from api/test-resrgan.py rename to api/scripts/test-resrgan.py index efebb80b..afcbe7ea 100644 --- a/api/test-resrgan.py +++ b/api/scripts/test-resrgan.py @@ -11,9 +11,9 @@ steps = 22 height = 512 width = 512 -esrgan = path.join('..', 'models', 'RealESRGAN_x4plus.onnx') -output = path.join('..', 'outputs', 'test.png') -upscale = path.join('..', 'outputs', 'test-large.png') +esrgan = path.join('..', '..', 'models', 'RealESRGAN_x4plus.onnx') +output = path.join('..', '..', 'outputs', 'test.png') +upscale = path.join('..', '..', 'outputs', 'test-large.png') print('upscaling test image...') session = ort.InferenceSession(esrgan, providers=['DmlExecutionProvider'])