1
0
Fork 0

feat(api): pass tile size param to most pipeline stages

This commit is contained in:
Sean Sube 2023-07-02 18:54:10 -05:00
parent c515d25dd4
commit d8ec93a619
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 13 additions and 5 deletions

View File

@ -7,7 +7,7 @@ from PIL import Image
from ..diffusers.load import load_pipeline from ..diffusers.load import load_pipeline
from ..diffusers.utils import encode_prompt, get_latents_from_seed, parse_prompt from ..diffusers.utils import encode_prompt, get_latents_from_seed, parse_prompt
from ..params import ImageParams, Size, StageParams from ..params import ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage from .stage import BaseStage
@ -16,6 +16,8 @@ logger = getLogger(__name__)
class SourceTxt2ImgStage(BaseStage): class SourceTxt2ImgStage(BaseStage):
max_tile = SizeChart.unlimited
def run( def run(
self, self,
job: WorkerContext, job: WorkerContext,

View File

@ -42,7 +42,9 @@ def run_txt2img_pipeline(
) -> None: ) -> None:
# prepare the chain pipeline and first stage # prepare the chain pipeline and first stage
chain = ChainPipeline() chain = ChainPipeline()
stage = StageParams() stage = StageParams(
tile_size=params.tiles,
)
chain.stage( chain.stage(
SourceTxt2ImgStage(), SourceTxt2ImgStage(),
stage, stage,
@ -122,7 +124,9 @@ def run_img2img_pipeline(
# prepare the chain pipeline and first stage # prepare the chain pipeline and first stage
chain = ChainPipeline() chain = ChainPipeline()
stage = StageParams() stage = StageParams(
tile_size=params.tiles,
)
chain.stage( chain.stage(
BlendImg2ImgStage(), BlendImg2ImgStage(),
stage, stage,
@ -219,7 +223,7 @@ def run_inpaint_pipeline(
# set up the chain pipeline and base stage # set up the chain pipeline and base stage
chain = ChainPipeline() chain = ChainPipeline()
stage = StageParams(tile_order=tile_order) stage = StageParams(tile_order=tile_order, tile_size=params.tiles)
chain.stage( chain.stage(
UpscaleOutpaintStage(), UpscaleOutpaintStage(),
stage, stage,
@ -286,7 +290,7 @@ def run_upscale_pipeline(
) -> None: ) -> None:
# set up the chain pipeline, no base stage for upscaling # set up the chain pipeline, no base stage for upscaling
chain = ChainPipeline() chain = ChainPipeline()
stage = StageParams() stage = StageParams(tile_size=params.tiles)
# apply upscaling and correction, before highres # apply upscaling and correction, before highres
first_upscale, after_upscale = split_upscale(upscale) first_upscale, after_upscale = split_upscale(upscale)

View File

@ -14,6 +14,7 @@ Point = Tuple[int, int]
class SizeChart(IntEnum): class SizeChart(IntEnum):
unlimited = 0
mini = 128 # small tile for very expensive models mini = 128 # small tile for very expensive models
half = 256 # half tile for outpainting half = 256 # half tile for outpainting
auto = 512 # auto tile size auto = 512 # auto tile size
@ -22,6 +23,7 @@ class SizeChart(IntEnum):
hd4k = 2**12 hd4k = 2**12
hd8k = 2**13 hd8k = 2**13
hd16k = 2**14 hd16k = 2**14
hd32k = 2**15
hd64k = 2**16 hd64k = 2**16