diff --git a/api/serve.py b/api/serve.py index 435f7e54..684c3ff5 100644 --- a/api/serve.py +++ b/api/serve.py @@ -30,6 +30,12 @@ max_width = 512 model_path = environ.get('ONNX_WEB_MODEL_PATH', "../models/stable-diffusion-onnx-v1-5") output_path = environ.get('ONNX_WEB_OUTPUT_PATH', "../outputs") +# platforms +platform_providers = { + 'amd': 'DmlExecutionProvider', + 'cpu': 'CPUExecutionProvider', +} + # schedulers pipeline_schedulers = { 'ddim': DDIMScheduler.from_pretrained(model_path, subfolder="scheduler"), @@ -85,6 +91,7 @@ def txt2img(): user = request.remote_addr prompt = request.args.get('prompt', default_prompt) + provider = get_from_map(request.args, 'provider', platform_providers, 'amd') scheduler = get_from_map(request.args, 'scheduler', pipeline_schedulers, 'euler-a') cfg = get_and_clamp(request.args, 'cfg', default_cfg, max_cfg, 0) @@ -103,7 +110,7 @@ def txt2img(): pipe = OnnxStableDiffusionPipeline.from_pretrained( model_path, - provider="DmlExecutionProvider", + provider=provider, safety_checker=None, scheduler=scheduler )