add filtered source to outputs
This commit is contained in:
parent
f986b75c85
commit
06d55b0f1f
|
@ -285,7 +285,11 @@ def run_img2img_pipeline(
|
|||
**pipe_params,
|
||||
)
|
||||
|
||||
for image, output in zip(result.images, outputs):
|
||||
images = result.images
|
||||
if source_filter is not None:
|
||||
images.append(source)
|
||||
|
||||
for image, output in zip(images, outputs):
|
||||
image = run_upscale_correction(
|
||||
job,
|
||||
server,
|
||||
|
|
|
@ -183,6 +183,7 @@ def source_filter_canny(
|
|||
|
||||
image = cv2.Canny(pil_to_cv2(source), low_threshold, high_threshold)
|
||||
image = Image.fromarray(image)
|
||||
image = image.convert("RGB")
|
||||
|
||||
return image
|
||||
|
||||
|
@ -192,7 +193,7 @@ def source_filter_openpose(server: ServerContext, source: Image.Image) -> Image.
|
|||
|
||||
model = OpenposeDetector.from_pretrained(
|
||||
"lllyasviel/ControlNet",
|
||||
cache_dir=server.cache_dir,
|
||||
cache_dir=server.cache_path,
|
||||
)
|
||||
image = model(source)
|
||||
|
||||
|
|
|
@ -69,7 +69,9 @@ def make_output_name(
|
|||
params: ImageParams,
|
||||
size: Size,
|
||||
extras: Optional[List[Optional[Param]]] = None,
|
||||
count: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
count = count or params.batch
|
||||
now = int(time())
|
||||
sha = sha256()
|
||||
|
||||
|
@ -93,7 +95,7 @@ def make_output_name(
|
|||
|
||||
return [
|
||||
f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}"
|
||||
for i in range(params.batch)
|
||||
for i in range(count)
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -176,8 +176,12 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
|||
get_config_value("strength", "min"),
|
||||
)
|
||||
|
||||
# TODO: add filtered source to outputs
|
||||
output = make_output_name(server, "img2img", params, size, extras=[strength])
|
||||
output_count = params.batch
|
||||
if source_filter is not None:
|
||||
output_count += 1
|
||||
|
||||
output = make_output_name(server, "img2img", params, size, extras=[strength], count=output_count)
|
||||
|
||||
job_name = output[0]
|
||||
logger.info("img2img job queued for: %s", job_name)
|
||||
|
||||
|
|
Loading…
Reference in New Issue