1
0
Fork 0

feat(scripts): add env debug script (#191)

This commit is contained in:
Sean Sube 2023-03-18 13:16:59 -05:00
parent 17e4fd7b06
commit 84718e5928
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 195 additions and 121 deletions

View File

@ -373,7 +373,7 @@ have added your own.
### Test the models ### Test the models
You should verify that all of the steps up to this point have worked correctly by attempting to run the You should verify that all of the steps up to this point have worked correctly by attempting to run the
`api/test-diffusers.py` script, which is a slight variation on the original txt2img script. `api/scripts/test-diffusers.py` script, which is a slight variation on the original txt2img script.
If the script works, there will be an image of an astronaut in `outputs/test.png`. If the script works, there will be an image of an astronaut in `outputs/test.png`.

View File

@ -95,7 +95,9 @@ def blend_textual_inversions(
# add the tokens to the tokenizer # add the tokens to the tokenizer
logger.debug( logger.debug(
"found embeddings for %s tokens: %s", len(embeds.keys()), list(embeds.keys()) "found embeddings for %s tokens: %s",
len(embeds.keys()),
list(embeds.keys()),
) )
num_added_tokens = tokenizer.add_tokens(list(embeds.keys())) num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
if num_added_tokens == 0: if num_added_tokens == 0:

62
api/scripts/check-env.py Normal file
View File

@ -0,0 +1,62 @@
from importlib.metadata import version
from typing import List
REQUIRED_MODULES = ["onnx", "diffusers", "safetensors", "torch"]
REQUIRED_RUNTIME = [
"onnxruntime",
"onnxruntime_gpu",
"onnxruntime_rocm",
"onnxruntime_training",
]
def check_modules() -> List[str]:
results = []
for name in REQUIRED_MODULES:
try:
__import__(name)
module_version = version(name)
results.append(
f"required module {name} is present at version {module_version}"
)
except ImportError as e:
results.append(f"unable to import required module {name}: {e}")
return results
def check_providers() -> List[str]:
results = []
try:
import onnxruntime
import torch
available = onnxruntime.get_available_providers()
for provider in onnxruntime.get_all_providers():
if provider in available:
results.append(f"onnxruntime provider {provider} is available")
else:
results.append(f"onnxruntime provider {provider} is missing")
except Exception as e:
results.append(f"unable to check runtime providers: {e}")
return results
ALL_CHECKS = [
check_modules,
check_providers,
]
def check_all():
results = []
for check in ALL_CHECKS:
results.extend(check())
print(results)
if __name__ == "__main__":
check_all()

124
api/scripts/check-model.py Normal file
View File

@ -0,0 +1,124 @@
from collections import Counter
from logging import getLogger
from os import path
from sys import argv
from typing import Callable, List
import torch
from onnx import load_model
from safetensors import safe_open
import onnx_web
logger = getLogger(__name__)
CheckFunc = Callable[[str], List[str]]
def check_file_extension(filename: str) -> List[str]:
"""
Check the file extension
"""
_name, ext = path.splitext(filename)
ext = ext.removeprefix(".")
if ext != "":
return [f"format:{ext}"]
return []
def check_file_diffusers(filename: str) -> List[str]:
"""
Check for a diffusers directory with model_index.json
"""
if path.isdir(filename) and path.exists(path.join(filename, "model_index.json")):
return ["model:diffusion"]
return []
def check_parser_safetensor(filename: str) -> List[str]:
"""
Attempt to load as a safetensor
"""
try:
if path.isfile(filename):
# only try to parse files
safe_open(filename, framework="pt")
return ["format:safetensor"]
except Exception as e:
logger.debug("error parsing as safetensor: %s", e)
return []
def check_parser_torch(filename: str) -> List[str]:
"""
Attempt to load as a torch tensor
"""
try:
if path.isfile(filename):
# only try to parse files
torch.load(filename)
return ["format:torch"]
except Exception as e:
logger.debug("error parsing as torch tensor: %s", e)
return []
def check_parser_onnx(filename: str) -> List[str]:
"""
Attempt to load as an ONNX model
"""
try:
if path.isfile(filename):
load_model(filename)
return ["format:onnx"]
except Exception as e:
logger.debug("error parsing as ONNX model: %s", e)
return []
def check_network_lora(filename: str) -> List[str]:
"""
TODO: Check for LoRA keys
"""
return []
def check_network_inversion(filename: str) -> List[str]:
"""
TODO: Check for Textual Inversion keys
"""
return []
ALL_CHECKS: List[CheckFunc] = [
check_file_diffusers,
check_file_extension,
check_network_inversion,
check_network_lora,
check_parser_onnx,
check_parser_safetensor,
check_parser_torch,
]
def check_file(filename: str) -> Counter:
logger.info("checking file: %s", filename)
counts = Counter()
for check in ALL_CHECKS:
logger.info("running check: %s", check.__name__)
counts.update(check(filename))
common = counts.most_common()
logger.info("file %s is most likely: %s", filename, common)
if __name__ == "__main__":
for file in argv[1:]:
check_file(file)

View File

@ -1,114 +0,0 @@
from collections import Counter
from logging import getLogger
from onnx import load_model
from os import path
from safetensors import safe_open
from sys import argv
from typing import Callable, List
import onnx_web
import torch
logger = getLogger(__name__)
CheckFunc = Callable[[str], List[str]]
def check_file_extension(filename: str) -> List[str]:
"""
Check the file extension
"""
_name, ext = path.splitext(filename)
ext = ext.removeprefix(".")
if ext != "":
return [f"format:{ext}"]
return []
def check_file_diffusers(filename: str) -> List[str]:
"""
Check for a diffusers directory with model_index.json
"""
if path.isdir(filename) and path.exists(path.join(filename, "model_index.json")):
return ["model:diffusion"]
return []
def check_parser_safetensor(filename: str) -> List[str]:
"""
Attempt to load as a safetensor
"""
try:
if path.isfile(filename):
# only try to parse files
safe_open(filename, framework="pt")
return ["format:safetensor"]
except Exception as e:
logger.debug("error parsing as safetensor: %s", e)
return []
def check_parser_torch(filename: str) -> List[str]:
"""
Attempt to load as a torch tensor
"""
try:
if path.isfile(filename):
# only try to parse files
torch.load(filename)
return ["format:torch"]
except Exception as e:
logger.debug("error parsing as torch tensor: %s", e)
return []
def check_parser_onnx(filename: str) -> List[str]:
"""
Attempt to load as an ONNX model
"""
try:
if path.isfile(filename):
load_model(filename)
return ["format:onnx"]
except Exception as e:
logger.debug("error parsing as ONNX model: %s", e)
return []
def check_network_lora(filename: str) -> List[str]:
"""
TODO: Check for LoRA keys
"""
return []
def check_network_inversion(filename: str) -> List[str]:
"""
TODO: Check for Textual Inversion keys
"""
return []
ALL_CHECKS: List[CheckFunc] = [
check_file_diffusers,
check_file_extension,
check_network_inversion,
check_network_lora,
check_parser_onnx,
check_parser_safetensor,
check_parser_torch,
]
def check_file(filename: str) -> Counter:
logger.info("checking file: %s", filename)
counts = Counter()
for check in ALL_CHECKS:
logger.info("running check: %s", check.__name__)
counts.update(check(filename))
common = counts.most_common()
logger.info("file %s is most likely: %s", filename, common)
if __name__ == "__main__":
for file in argv[1:]:
check_file(file)

View File

@ -12,9 +12,9 @@ steps = 22
height = 512 height = 512
width = 512 width = 512
model = path.join('..', 'models', 'stable-diffusion-onnx-v1-5') model = path.join('..', '..', 'models', 'stable-diffusion-onnx-v1-5')
prompt = 'an astronaut eating a hamburger' prompt = 'an astronaut eating a hamburger'
output = path.join('..', 'outputs', 'test.png') output = path.join('..', '..', 'outputs', 'test.png')
print('generating test image...') print('generating test image...')
pipe = OnnxStableDiffusionPipeline.from_pretrained(model, provider='DmlExecutionProvider', safety_checker=None) pipe = OnnxStableDiffusionPipeline.from_pretrained(model, provider='DmlExecutionProvider', safety_checker=None)

View File

@ -11,9 +11,9 @@ steps = 22
height = 512 height = 512
width = 512 width = 512
esrgan = path.join('..', 'models', 'RealESRGAN_x4plus.onnx') esrgan = path.join('..', '..', 'models', 'RealESRGAN_x4plus.onnx')
output = path.join('..', 'outputs', 'test.png') output = path.join('..', '..', 'outputs', 'test.png')
upscale = path.join('..', 'outputs', 'test-large.png') upscale = path.join('..', '..', 'outputs', 'test-large.png')
print('upscaling test image...') print('upscaling test image...')
session = ort.InferenceSession(esrgan, providers=['DmlExecutionProvider']) session = ort.InferenceSession(esrgan, providers=['DmlExecutionProvider'])