1
0
Fork 0

fix(api): set CUDA device in ORT session

This commit is contained in:
Sean Sube 2023-02-04 21:17:39 -06:00
parent d636ce3eef
commit 04a2faffd9
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
21 changed files with 116 additions and 52 deletions

6
api/dev-requirements.txt Normal file
View File

@ -0,0 +1,6 @@
mypy
types-Flask-Cors
types-jsonschema
types-Pillow
types-PyYAML

View File

@ -28,6 +28,7 @@ logger = getLogger(__name__)
class StageCallback(Protocol): class StageCallback(Protocol):
def __call__( def __call__(
self, self,
job: JobContext,
ctx: ServerContext, ctx: ServerContext,
stage: StageParams, stage: StageParams,
params: ImageParams, params: ImageParams,
@ -83,7 +84,7 @@ class ChainPipeline:
stage_params.tile_size) stage_params.tile_size)
def stage_tile(tile: Image.Image, _dims) -> Image.Image: def stage_tile(tile: Image.Image, _dims) -> Image.Image:
tile = stage_pipe(server, stage_params, params, tile, tile = stage_pipe(job, server, stage_params, params, tile,
**kwargs) **kwargs)
if is_debug(): if is_debug():
@ -95,7 +96,7 @@ class ChainPipeline:
image, stage_params.tile_size, stage_params.outscale, [stage_tile]) image, stage_params.tile_size, stage_params.outscale, [stage_tile])
else: else:
logger.info('image within tile size, running stage') logger.info('image within tile size, running stage')
image = stage_pipe(server, stage_params, params, image, image = stage_pipe(job, server, stage_params, params, image,
**kwargs) **kwargs)
logger.info('finished stage %s, result size: %sx%s', logger.info('finished stage %s, result size: %sx%s',

View File

@ -4,6 +4,9 @@ from diffusers import (
from logging import getLogger from logging import getLogger
from PIL import Image from PIL import Image
from ..device_pool import (
JobContext,
)
from ..diffusion.load import ( from ..diffusion.load import (
load_pipeline, load_pipeline,
) )
@ -21,7 +24,8 @@ logger = getLogger(__name__)
def blend_img2img( def blend_img2img(
_ctx: ServerContext, job: JobContext,
_server: ServerContext,
_stage: StageParams, _stage: StageParams,
params: ImageParams, params: ImageParams,
source_image: Image.Image, source_image: Image.Image,
@ -34,7 +38,7 @@ def blend_img2img(
logger.info('generating image using img2img, %s steps: %s', params.steps, prompt) logger.info('generating image using img2img, %s steps: %s', params.steps, prompt)
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline, pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
params.model, params.provider, params.scheduler) params.model, params.scheduler, job.get_device())
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)

View File

@ -5,6 +5,9 @@ from logging import getLogger
from PIL import Image from PIL import Image
from typing import Callable, Tuple from typing import Callable, Tuple
from ..device_pool import (
JobContext,
)
from ..diffusion.load import ( from ..diffusion.load import (
get_latents_from_seed, get_latents_from_seed,
load_pipeline, load_pipeline,
@ -38,7 +41,8 @@ logger = getLogger(__name__)
def blend_inpaint( def blend_inpaint(
ctx: ServerContext, job: JobContext,
server: ServerContext,
stage: StageParams, stage: StageParams,
params: ImageParams, params: ImageParams,
source_image: Image.Image, source_image: Image.Image,
@ -65,9 +69,9 @@ def blend_inpaint(
mask_filter=mask_filter) mask_filter=mask_filter)
if is_debug(): if is_debug():
save_image(ctx, 'last-source.png', source_image) save_image(server, 'last-source.png', source_image)
save_image(ctx, 'last-mask.png', mask_image) save_image(server, 'last-mask.png', mask_image)
save_image(ctx, 'last-noise.png', noise_image) save_image(server, 'last-noise.png', noise_image)
def outpaint(image: Image.Image, dims: Tuple[int, int, int]): def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims left, top, tile = dims
@ -75,11 +79,11 @@ def blend_inpaint(
mask = mask_image.crop((left, top, left + tile, top + tile)) mask = mask_image.crop((left, top, left + tile, top + tile))
if is_debug(): if is_debug():
save_image(ctx, 'tile-source.png', image) save_image(server, 'tile-source.png', image)
save_image(ctx, 'tile-mask.png', mask) save_image(server, 'tile-mask.png', mask)
pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline, pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline,
params.model, params.provider, params.scheduler) params.model, params.scheduler, job.get_device())
latents = get_latents_from_seed(params.seed, size) latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)

View File

@ -5,6 +5,9 @@ from PIL import Image
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
from typing import Optional from typing import Optional
from ..device_pool import (
JobContext,
)
from ..params import ( from ..params import (
ImageParams, ImageParams,
StageParams, StageParams,
@ -60,7 +63,8 @@ def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[
def correct_gfpgan( def correct_gfpgan(
ctx: ServerContext, _job: JobContext,
server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
source_image: Image.Image, source_image: Image.Image,
@ -74,7 +78,7 @@ def correct_gfpgan(
return source_image return source_image
logger.info('correcting faces with GFPGAN model: %s', upscale.correction_model) logger.info('correcting faces with GFPGAN model: %s', upscale.correction_model)
gfpgan = load_gfpgan(ctx, upscale, upsampler=upsampler) gfpgan = load_gfpgan(server, upscale, upsampler=upsampler)
output = np.array(source_image) output = np.array(source_image)
_, _, output = gfpgan.enhance( _, _, output = gfpgan.enhance(

View File

@ -1,6 +1,9 @@
from logging import getLogger from logging import getLogger
from PIL import Image from PIL import Image
from ..device_pool import (
JobContext,
)
from ..params import ( from ..params import (
ImageParams, ImageParams,
StageParams, StageParams,
@ -16,6 +19,7 @@ logger = getLogger(__name__)
def persist_disk( def persist_disk(
_job: JobContext,
ctx: ServerContext, ctx: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,

View File

@ -5,6 +5,9 @@ from io import BytesIO
from logging import getLogger from logging import getLogger
from PIL import Image from PIL import Image
from ..device_pool import (
JobContext,
)
from ..params import ( from ..params import (
ImageParams, ImageParams,
StageParams, StageParams,

View File

@ -1,6 +1,9 @@
from logging import getLogger from logging import getLogger
from PIL import Image from PIL import Image
from ..device_pool import (
JobContext,
)
from ..params import ( from ..params import (
ImageParams, ImageParams,
Size, Size,

View File

@ -1,6 +1,9 @@
from logging import getLogger from logging import getLogger
from PIL import Image from PIL import Image
from ..device_pool import (
JobContext,
)
from ..params import ( from ..params import (
ImageParams, ImageParams,
Size, Size,

View File

@ -2,6 +2,9 @@ from logging import getLogger
from PIL import Image from PIL import Image
from typing import Callable from typing import Callable
from ..device_pool import (
JobContext,
)
from ..params import ( from ..params import (
ImageParams, ImageParams,
Size, Size,

View File

@ -4,6 +4,9 @@ from diffusers import (
from logging import getLogger from logging import getLogger
from PIL import Image from PIL import Image
from ..device_pool import (
JobContext,
)
from ..diffusion.load import ( from ..diffusion.load import (
get_latents_from_seed, get_latents_from_seed,
load_pipeline, load_pipeline,
@ -23,7 +26,8 @@ logger = getLogger(__name__)
def source_txt2img( def source_txt2img(
ctx: ServerContext, job: JobContext,
server: ServerContext,
stage: StageParams, stage: StageParams,
params: ImageParams, params: ImageParams,
source_image: Image.Image, source_image: Image.Image,
@ -39,7 +43,7 @@ def source_txt2img(
logger.warn('a source image was passed to a txt2img stage, but will be discarded') logger.warn('a source image was passed to a txt2img stage, but will be discarded')
pipe = load_pipeline(OnnxStableDiffusionPipeline, pipe = load_pipeline(OnnxStableDiffusionPipeline,
params.model, params.provider, params.scheduler) params.model, params.scheduler, job.get_device())
latents = get_latents_from_seed(params.seed, size) latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)

View File

@ -5,6 +5,9 @@ from logging import getLogger
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
from typing import Callable, Tuple from typing import Callable, Tuple
from ..device_pool import (
JobContext,
)
from ..diffusion.load import ( from ..diffusion.load import (
get_latents_from_seed, get_latents_from_seed,
get_tile_latents, get_tile_latents,

View File

@ -4,10 +4,14 @@ from os import path
from PIL import Image from PIL import Image
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
from ..device_pool import (
JobContext,
)
from ..onnx import ( from ..onnx import (
OnnxNet, OnnxNet,
) )
from ..params import ( from ..params import (
DeviceParams,
ImageParams, ImageParams,
StageParams, StageParams,
UpscaleParams, UpscaleParams,
@ -25,7 +29,7 @@ last_pipeline_instance = None
last_pipeline_params = (None, None) last_pipeline_params = (None, None)
def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0): def load_resrgan(ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0):
global last_pipeline_instance global last_pipeline_instance
global last_pipeline_params global last_pipeline_params
@ -41,7 +45,7 @@ def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
# use ONNX acceleration, if available # use ONNX acceleration, if available
if params.format == 'onnx': if params.format == 'onnx':
model = OnnxNet(ctx, model_file, provider=params.provider) model = OnnxNet(ctx, model_file, provider=device.provider, sess_options=device.options)
elif params.format == 'pth': elif params.format == 'pth':
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=params.scale) num_block=23, num_grow_ch=32, scale=params.scale)
@ -76,7 +80,8 @@ def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
def upscale_resrgan( def upscale_resrgan(
ctx: ServerContext, job: JobContext,
server: ServerContext,
stage: StageParams, stage: StageParams,
_params: ImageParams, _params: ImageParams,
source_image: Image.Image, source_image: Image.Image,
@ -87,7 +92,7 @@ def upscale_resrgan(
logger.info('upscaling image with Real ESRGAN: x%s', upscale.scale) logger.info('upscaling image with Real ESRGAN: x%s', upscale.scale)
output = np.array(source_image) output = np.array(source_image)
upsampler = load_resrgan(ctx, upscale, tile=stage.tile_size) upsampler = load_resrgan(server, upscale, job.get_device(), tile=stage.tile_size)
output, _ = upsampler.enhance(output, outscale=upscale.outscale) output, _ = upsampler.enhance(output, outscale=upscale.outscale)

View File

@ -5,10 +5,14 @@ from logging import getLogger
from os import path from os import path
from PIL import Image from PIL import Image
from ..device_pool import (
JobContext,
)
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import ( from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
OnnxStableDiffusionUpscalePipeline, OnnxStableDiffusionUpscalePipeline,
) )
from ..params import ( from ..params import (
DeviceParams,
ImageParams, ImageParams,
StageParams, StageParams,
UpscaleParams, UpscaleParams,
@ -27,7 +31,7 @@ last_pipeline_instance = None
last_pipeline_params = (None, None) last_pipeline_params = (None, None)
def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams): def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams):
global last_pipeline_instance global last_pipeline_instance
global last_pipeline_params global last_pipeline_params
@ -39,11 +43,11 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
return last_pipeline_instance return last_pipeline_instance
if upscale.format == 'onnx': if upscale.format == 'onnx':
logger.debug('loading Stable Diffusion upscale ONNX model from %s, using provider %s', model_path, upscale.provider) logger.debug('loading Stable Diffusion upscale ONNX model from %s, using provider %s', model_path, device.provider)
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=upscale.provider) pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, sess_options=device.options)
else: else:
logger.debug('loading Stable Diffusion upscale model from %s, using provider %s', model_path, upscale.provider) logger.debug('loading Stable Diffusion upscale model from %s, using provider %s', model_path, device.provider)
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=upscale.provider) pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, sess_options=device.options)
last_pipeline_instance = pipeline last_pipeline_instance = pipeline
last_pipeline_params = cache_params last_pipeline_params = cache_params
@ -53,7 +57,8 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
def upscale_stable_diffusion( def upscale_stable_diffusion(
ctx: ServerContext, job: JobContext,
server: ServerContext,
_stage: StageParams, _stage: StageParams,
params: ImageParams, params: ImageParams,
source: Image.Image, source: Image.Image,
@ -65,7 +70,7 @@ def upscale_stable_diffusion(
prompt = prompt or params.prompt prompt = prompt or params.prompt
logger.info('upscaling with Stable Diffusion, %s steps: %s', params.steps, prompt) logger.info('upscaling with Stable Diffusion, %s steps: %s', params.steps, prompt)
pipeline = load_stable_diffusion(ctx, upscale) pipeline = load_stable_diffusion(server, upscale, job.get_device())
generator = torch.manual_seed(params.seed) generator = torch.manual_seed(params.seed)
return pipeline( return pipeline(

View File

@ -133,6 +133,9 @@ class DevicePoolExecutor:
return True return True
else: else:
job.set_cancel() job.set_cancel()
return True
return False
def done(self, key: str) -> Tuple[bool, int]: def done(self, key: str) -> Tuple[bool, int]:
for job in self.jobs: for job in self.jobs:

View File

@ -5,6 +5,7 @@ from logging import getLogger
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple
from ..params import ( from ..params import (
DeviceParams,
Size, Size,
) )
from ..utils import ( from ..utils import (
@ -45,12 +46,12 @@ def get_tile_latents(full_latents: np.ndarray, dims: Tuple[int, int, int]) -> np
return full_latents[:, :, y:yt, x:xt] return full_latents[:, :, y:yt, x:xt]
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Optional[str] = None): def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, device: DeviceParams):
global last_pipeline_instance global last_pipeline_instance
global last_pipeline_scheduler global last_pipeline_scheduler
global last_pipeline_options global last_pipeline_options
options = (pipeline, model, provider) options = (pipeline, model, device.provider)
if last_pipeline_instance != None and last_pipeline_options == options: if last_pipeline_instance != None and last_pipeline_options == options:
logger.debug('reusing existing diffusion pipeline') logger.debug('reusing existing diffusion pipeline')
pipe = last_pipeline_instance pipe = last_pipeline_instance
@ -61,11 +62,18 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
run_gc() run_gc()
logger.debug('loading new diffusion pipeline from %s', model) logger.debug('loading new diffusion pipeline from %s', model)
scheduler = scheduler.from_pretrained(
model,
provider=device.provider,
sess_options=device.options,
subfolder='scheduler',
)
pipe = pipeline.from_pretrained( pipe = pipeline.from_pretrained(
model, model,
provider=provider, provider=device.provider,
safety_checker=None, safety_checker=None,
scheduler=scheduler.from_pretrained(model, subfolder='scheduler') scheduler=scheduler,
sess_options=device.options,
) )
if device is not None and hasattr(pipe, 'to'): if device is not None and hasattr(pipe, 'to'):
@ -78,7 +86,11 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
if last_pipeline_scheduler != scheduler: if last_pipeline_scheduler != scheduler:
logger.debug('loading new diffusion scheduler') logger.debug('loading new diffusion scheduler')
scheduler = scheduler.from_pretrained( scheduler = scheduler.from_pretrained(
model, subfolder='scheduler') model,
provider=device.provider,
sess_options=device.options,
subfolder='scheduler',
)
if device is not None and hasattr(scheduler, 'to'): if device is not None and hasattr(scheduler, 'to'):
scheduler = scheduler.to(device) scheduler = scheduler.to(device)

View File

@ -48,9 +48,8 @@ def run_txt2img_pipeline(
output: str, output: str,
upscale: UpscaleParams upscale: UpscaleParams
) -> None: ) -> None:
device = job.get_device()
pipe = load_pipeline(OnnxStableDiffusionPipeline, pipe = load_pipeline(OnnxStableDiffusionPipeline,
params.model, device.provider, params.scheduler, device=device.torch_device()) params.model, params.scheduler, job.get_device())
latents = get_latents_from_seed(params.seed, size) latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
@ -90,9 +89,8 @@ def run_img2img_pipeline(
source_image: Image.Image, source_image: Image.Image,
strength: float, strength: float,
) -> None: ) -> None:
device = job.get_device()
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline, pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
params.model, device.provider, params.scheduler, device=device.torch_device()) params.model, params.scheduler, job.get_device())
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)

View File

@ -1,6 +1,6 @@
from onnxruntime import InferenceSession from onnxruntime import InferenceSession
from os import path from os import path
from typing import Any from typing import Any, Optional
import numpy as np import numpy as np
import torch import torch
@ -43,13 +43,19 @@ class OnnxNet():
Provides the RRDBNet interface using an ONNX session for DirectML acceleration. Provides the RRDBNet interface using an ONNX session for DirectML acceleration.
''' '''
def __init__(self, ctx: ServerContext, model: str, provider='DmlExecutionProvider') -> None: def __init__(
self,
server: ServerContext,
model: str,
provider: str = 'DmlExecutionProvider',
sess_options: Optional[dict] = None,
) -> None:
''' '''
TODO: get platform provider from request params TODO: get platform provider from request params
''' '''
model_path = path.join(ctx.model_path, model) model_path = path.join(server.model_path, model)
self.session = InferenceSession( self.session = InferenceSession(
model_path, providers=[provider]) model_path, providers=[provider], sess_options=sess_options)
def __call__(self, image: Any) -> Any: def __call__(self, image: Any) -> Any:
input_name = self.session.get_inputs()[0].name input_name = self.session.get_inputs()[0].name

View File

@ -79,7 +79,6 @@ class ImageParams:
def __init__( def __init__(
self, self,
model: str, model: str,
provider: str,
scheduler: Any, scheduler: Any,
prompt: str, prompt: str,
negative_prompt: Optional[str], negative_prompt: Optional[str],
@ -88,7 +87,6 @@ class ImageParams:
seed: int seed: int
) -> None: ) -> None:
self.model = model self.model = model
self.provider = provider
self.scheduler = scheduler self.scheduler = scheduler
self.prompt = prompt self.prompt = prompt
self.negative_prompt = negative_prompt self.negative_prompt = negative_prompt
@ -96,10 +94,9 @@ class ImageParams:
self.steps = steps self.steps = steps
self.seed = seed self.seed = seed
def tojson(self) -> Dict[str, Param]: def tojson(self) -> Dict[str, Optional[Param]]:
return { return {
'model': self.model, 'model': self.model,
'provider': self.provider,
'scheduler': self.scheduler.__name__, 'scheduler': self.scheduler.__name__,
'seed': self.seed, 'seed': self.seed,
'prompt': self.prompt, 'prompt': self.prompt,
@ -130,7 +127,6 @@ class UpscaleParams():
def __init__( def __init__(
self, self,
upscale_model: str, upscale_model: str,
provider: str,
correction_model: Optional[str] = None, correction_model: Optional[str] = None,
denoise: float = 0.5, denoise: float = 0.5,
faces=True, faces=True,
@ -143,7 +139,6 @@ class UpscaleParams():
tile_pad: int = 10, tile_pad: int = 10,
) -> None: ) -> None:
self.upscale_model = upscale_model self.upscale_model = upscale_model
self.provider = provider
self.correction_model = correction_model self.correction_model = correction_model
self.denoise = denoise self.denoise = denoise
self.faces = faces self.faces = faces
@ -158,7 +153,6 @@ class UpscaleParams():
def rescale(self, scale: int): def rescale(self, scale: int):
return UpscaleParams( return UpscaleParams(
self.upscale_model, self.upscale_model,
self.provider,
correction_model=self.correction_model, correction_model=self.correction_model,
denoise=self.denoise, denoise=self.denoise,
faces=self.faces, faces=self.faces,

View File

@ -226,7 +226,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
logger.info("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s", logger.info("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s",
user, steps, scheduler.__name__, model_path, device.provider, width, height, cfg, seed, prompt) user, steps, scheduler.__name__, model_path, device.provider, width, height, cfg, seed, prompt)
params = ImageParams(model_path, device.provider, scheduler, prompt, params = ImageParams(model_path, scheduler, prompt,
negative_prompt, cfg, steps, seed) negative_prompt, cfg, steps, seed)
size = Size(width, height) size = Size(width, height)
return (device, params, size) return (device, params, size)
@ -245,7 +245,7 @@ def border_from_request() -> Border:
return Border(left, right, top, bottom) return Border(left, right, top, bottom)
def upscale_from_request(provider: str) -> UpscaleParams: def upscale_from_request() -> UpscaleParams:
denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0) denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0)
scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1) scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1)
outscale = get_and_clamp_int(request.args, 'outscale', 1, 4, 1) outscale = get_and_clamp_int(request.args, 'outscale', 1, 4, 1)
@ -257,7 +257,6 @@ def upscale_from_request(provider: str) -> UpscaleParams:
return UpscaleParams( return UpscaleParams(
upscaling, upscaling,
provider,
correction_model=correction, correction_model=correction,
denoise=denoise, denoise=denoise,
faces=faces, faces=faces,
@ -327,7 +326,7 @@ def load_platforms():
if platform_providers[potential] in providers and potential not in context.block_platforms: if platform_providers[potential] in providers and potential not in context.block_platforms:
if potential == 'cuda': if potential == 'cuda':
for i in range(torch.cuda.device_count()): for i in range(torch.cuda.device_count()):
available_platforms.append(DeviceParams('%s:%s' % (potential, i), platform_providers[potential], { available_platforms.append(DeviceParams(potential, platform_providers[potential], {
'device_id': i, 'device_id': i,
})) }))
else: else:
@ -344,7 +343,6 @@ def load_platforms():
return -1 return -1
available_platforms = sorted(available_platforms, key=cmp_to_key(cpu_last)) available_platforms = sorted(available_platforms, key=cmp_to_key(cpu_last))
logger.info('available acceleration platforms: %s', logger.info('available acceleration platforms: %s',
@ -456,7 +454,7 @@ def img2img():
source_image = Image.open(BytesIO(source_file.read())).convert('RGB') source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
device, params, size = pipeline_from_request() device, params, size = pipeline_from_request()
upscale = upscale_from_request(params.provider) upscale = upscale_from_request()
strength = get_and_clamp_float( strength = get_and_clamp_float(
request.args, request.args,

View File

@ -54,6 +54,7 @@
"scandir", "scandir",
"scipy", "scipy",
"scrollback", "scrollback",
"sess",
"Singlestep", "Singlestep",
"spacy", "spacy",
"spinalcase", "spinalcase",