1
0
Fork 0

clean up lingering provider refs

This commit is contained in:
Sean Sube 2023-02-04 21:23:34 -06:00
parent 04a2faffd9
commit cc12cb0fcf
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 13 additions and 12 deletions

View File

@ -43,7 +43,8 @@ logger = getLogger(__name__)
def upscale_outpaint( def upscale_outpaint(
ctx: ServerContext, job: JobContext,
server: ServerContext,
stage: StageParams, stage: StageParams,
params: ImageParams, params: ImageParams,
source_image: Image.Image, source_image: Image.Image,
@ -76,9 +77,9 @@ def upscale_outpaint(
full_latents = get_latents_from_seed(params.seed, full_size) full_latents = get_latents_from_seed(params.seed, full_size)
if is_debug(): if is_debug():
save_image(ctx, 'last-source.png', source_image) save_image(server, 'last-source.png', source_image)
save_image(ctx, 'last-mask.png', mask_image) save_image(server, 'last-mask.png', mask_image)
save_image(ctx, 'last-noise.png', noise_image) save_image(server, 'last-noise.png', noise_image)
def outpaint(image: Image.Image, dims: Tuple[int, int, int]): def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims left, top, tile = dims
@ -86,11 +87,11 @@ def upscale_outpaint(
mask = mask_image.crop((left, top, left + tile, top + tile)) mask = mask_image.crop((left, top, left + tile, top + tile))
if is_debug(): if is_debug():
save_image(ctx, 'tile-source.png', image) save_image(server, 'tile-source.png', image)
save_image(ctx, 'tile-mask.png', mask) save_image(server, 'tile-mask.png', mask)
pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline, pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline,
params.model, params.provider, params.scheduler) params.model, params.scheduler, job.get_device())
latents = get_tile_latents(full_latents, dims) latents = get_tile_latents(full_latents, dims)
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)

View File

@ -140,6 +140,7 @@ def run_inpaint_pipeline(
stage = StageParams() stage = StageParams()
image = upscale_outpaint( image = upscale_outpaint(
job,
server, server,
stage, stage,
params, params,

View File

@ -74,7 +74,6 @@ def make_output_name(
hash_value(sha, mode) hash_value(sha, mode)
hash_value(sha, params.model) hash_value(sha, params.model)
hash_value(sha, params.provider)
hash_value(sha, params.scheduler.__name__) hash_value(sha, params.scheduler.__name__)
hash_value(sha, params.prompt) hash_value(sha, params.prompt)
hash_value(sha, params.negative_prompt) hash_value(sha, params.negative_prompt)

View File

@ -481,7 +481,7 @@ def img2img():
@app.route('/api/txt2img', methods=['POST']) @app.route('/api/txt2img', methods=['POST'])
def txt2img(): def txt2img():
device, params, size = pipeline_from_request() device, params, size = pipeline_from_request()
upscale = upscale_from_request(params.provider) upscale = upscale_from_request()
output = make_output_name( output = make_output_name(
context, context,
@ -512,7 +512,7 @@ def inpaint():
device, params, size = pipeline_from_request() device, params, size = pipeline_from_request()
expand = border_from_request() expand = border_from_request()
upscale = upscale_from_request(params.provider) upscale = upscale_from_request()
fill_color = get_not_empty(request.args, 'fillColor', 'white') fill_color = get_not_empty(request.args, 'fillColor', 'white')
mask_filter = get_from_map(request.args, 'filter', mask_filters, 'none') mask_filter = get_from_map(request.args, 'filter', mask_filters, 'none')
@ -573,7 +573,7 @@ def upscale():
source_image = Image.open(BytesIO(source_file.read())).convert('RGB') source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
device, params, size = pipeline_from_request() device, params, size = pipeline_from_request()
upscale = upscale_from_request(params.provider) upscale = upscale_from_request()
output = make_output_name( output = make_output_name(
context, context,
@ -629,7 +629,7 @@ def chain():
kwargs['border'] = border kwargs['border'] = border
if 'upscale' in kwargs: if 'upscale' in kwargs:
upscale = UpscaleParams(kwargs.get('upscale'), params.provider) upscale = UpscaleParams(kwargs.get('upscale'))
kwargs['upscale'] = upscale kwargs['upscale'] = upscale
stage_source_name = 'source:%s' % (stage.name) stage_source_name = 'source:%s' % (stage.name)