feat(api): pass tile size param to most pipeline stages
This commit is contained in:
parent
c515d25dd4
commit
d8ec93a619
|
@ -7,7 +7,7 @@ from PIL import Image
|
||||||
|
|
||||||
from ..diffusers.load import load_pipeline
|
from ..diffusers.load import load_pipeline
|
||||||
from ..diffusers.utils import encode_prompt, get_latents_from_seed, parse_prompt
|
from ..diffusers.utils import encode_prompt, get_latents_from_seed, parse_prompt
|
||||||
from ..params import ImageParams, Size, StageParams
|
from ..params import ImageParams, Size, SizeChart, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .stage import BaseStage
|
from .stage import BaseStage
|
||||||
|
@ -16,6 +16,8 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SourceTxt2ImgStage(BaseStage):
|
class SourceTxt2ImgStage(BaseStage):
|
||||||
|
max_tile = SizeChart.unlimited
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
job: WorkerContext,
|
job: WorkerContext,
|
||||||
|
|
|
@ -42,7 +42,9 @@ def run_txt2img_pipeline(
|
||||||
) -> None:
|
) -> None:
|
||||||
# prepare the chain pipeline and first stage
|
# prepare the chain pipeline and first stage
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
stage = StageParams()
|
stage = StageParams(
|
||||||
|
tile_size=params.tiles,
|
||||||
|
)
|
||||||
chain.stage(
|
chain.stage(
|
||||||
SourceTxt2ImgStage(),
|
SourceTxt2ImgStage(),
|
||||||
stage,
|
stage,
|
||||||
|
@ -122,7 +124,9 @@ def run_img2img_pipeline(
|
||||||
|
|
||||||
# prepare the chain pipeline and first stage
|
# prepare the chain pipeline and first stage
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
stage = StageParams()
|
stage = StageParams(
|
||||||
|
tile_size=params.tiles,
|
||||||
|
)
|
||||||
chain.stage(
|
chain.stage(
|
||||||
BlendImg2ImgStage(),
|
BlendImg2ImgStage(),
|
||||||
stage,
|
stage,
|
||||||
|
@ -219,7 +223,7 @@ def run_inpaint_pipeline(
|
||||||
|
|
||||||
# set up the chain pipeline and base stage
|
# set up the chain pipeline and base stage
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
stage = StageParams(tile_order=tile_order)
|
stage = StageParams(tile_order=tile_order, tile_size=params.tiles)
|
||||||
chain.stage(
|
chain.stage(
|
||||||
UpscaleOutpaintStage(),
|
UpscaleOutpaintStage(),
|
||||||
stage,
|
stage,
|
||||||
|
@ -286,7 +290,7 @@ def run_upscale_pipeline(
|
||||||
) -> None:
|
) -> None:
|
||||||
# set up the chain pipeline, no base stage for upscaling
|
# set up the chain pipeline, no base stage for upscaling
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
stage = StageParams()
|
stage = StageParams(tile_size=params.tiles)
|
||||||
|
|
||||||
# apply upscaling and correction, before highres
|
# apply upscaling and correction, before highres
|
||||||
first_upscale, after_upscale = split_upscale(upscale)
|
first_upscale, after_upscale = split_upscale(upscale)
|
||||||
|
|
|
@ -14,6 +14,7 @@ Point = Tuple[int, int]
|
||||||
|
|
||||||
|
|
||||||
class SizeChart(IntEnum):
|
class SizeChart(IntEnum):
|
||||||
|
unlimited = 0
|
||||||
mini = 128 # small tile for very expensive models
|
mini = 128 # small tile for very expensive models
|
||||||
half = 256 # half tile for outpainting
|
half = 256 # half tile for outpainting
|
||||||
auto = 512 # auto tile size
|
auto = 512 # auto tile size
|
||||||
|
@ -22,6 +23,7 @@ class SizeChart(IntEnum):
|
||||||
hd4k = 2**12
|
hd4k = 2**12
|
||||||
hd8k = 2**13
|
hd8k = 2**13
|
||||||
hd16k = 2**14
|
hd16k = 2**14
|
||||||
|
hd32k = 2**15
|
||||||
hd64k = 2**16
|
hd64k = 2**16
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue