clean up lingering provider refs
This commit is contained in:
parent
04a2faffd9
commit
cc12cb0fcf
|
@ -43,7 +43,8 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
def upscale_outpaint(
|
||||
ctx: ServerContext,
|
||||
job: JobContext,
|
||||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
source_image: Image.Image,
|
||||
|
@ -76,9 +77,9 @@ def upscale_outpaint(
|
|||
full_latents = get_latents_from_seed(params.seed, full_size)
|
||||
|
||||
if is_debug():
|
||||
save_image(ctx, 'last-source.png', source_image)
|
||||
save_image(ctx, 'last-mask.png', mask_image)
|
||||
save_image(ctx, 'last-noise.png', noise_image)
|
||||
save_image(server, 'last-source.png', source_image)
|
||||
save_image(server, 'last-mask.png', mask_image)
|
||||
save_image(server, 'last-noise.png', noise_image)
|
||||
|
||||
def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
|
||||
left, top, tile = dims
|
||||
|
@ -86,11 +87,11 @@ def upscale_outpaint(
|
|||
mask = mask_image.crop((left, top, left + tile, top + tile))
|
||||
|
||||
if is_debug():
|
||||
save_image(ctx, 'tile-source.png', image)
|
||||
save_image(ctx, 'tile-mask.png', mask)
|
||||
save_image(server, 'tile-source.png', image)
|
||||
save_image(server, 'tile-mask.png', mask)
|
||||
|
||||
pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline,
|
||||
params.model, params.provider, params.scheduler)
|
||||
params.model, params.scheduler, job.get_device())
|
||||
|
||||
latents = get_tile_latents(full_latents, dims)
|
||||
rng = np.random.RandomState(params.seed)
|
||||
|
|
|
@ -140,6 +140,7 @@ def run_inpaint_pipeline(
|
|||
stage = StageParams()
|
||||
|
||||
image = upscale_outpaint(
|
||||
job,
|
||||
server,
|
||||
stage,
|
||||
params,
|
||||
|
|
|
@ -74,7 +74,6 @@ def make_output_name(
|
|||
|
||||
hash_value(sha, mode)
|
||||
hash_value(sha, params.model)
|
||||
hash_value(sha, params.provider)
|
||||
hash_value(sha, params.scheduler.__name__)
|
||||
hash_value(sha, params.prompt)
|
||||
hash_value(sha, params.negative_prompt)
|
||||
|
|
|
@ -481,7 +481,7 @@ def img2img():
|
|||
@app.route('/api/txt2img', methods=['POST'])
|
||||
def txt2img():
|
||||
device, params, size = pipeline_from_request()
|
||||
upscale = upscale_from_request(params.provider)
|
||||
upscale = upscale_from_request()
|
||||
|
||||
output = make_output_name(
|
||||
context,
|
||||
|
@ -512,7 +512,7 @@ def inpaint():
|
|||
|
||||
device, params, size = pipeline_from_request()
|
||||
expand = border_from_request()
|
||||
upscale = upscale_from_request(params.provider)
|
||||
upscale = upscale_from_request()
|
||||
|
||||
fill_color = get_not_empty(request.args, 'fillColor', 'white')
|
||||
mask_filter = get_from_map(request.args, 'filter', mask_filters, 'none')
|
||||
|
@ -573,7 +573,7 @@ def upscale():
|
|||
source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
|
||||
|
||||
device, params, size = pipeline_from_request()
|
||||
upscale = upscale_from_request(params.provider)
|
||||
upscale = upscale_from_request()
|
||||
|
||||
output = make_output_name(
|
||||
context,
|
||||
|
@ -629,7 +629,7 @@ def chain():
|
|||
kwargs['border'] = border
|
||||
|
||||
if 'upscale' in kwargs:
|
||||
upscale = UpscaleParams(kwargs.get('upscale'), params.provider)
|
||||
upscale = UpscaleParams(kwargs.get('upscale'))
|
||||
kwargs['upscale'] = upscale
|
||||
|
||||
stage_source_name = 'source:%s' % (stage.name)
|
||||
|
|
Loading…
Reference in New Issue