diff --git a/api/onnx_web/chain/tile.py b/api/onnx_web/chain/tile.py index 8f228392..ed2d09a4 100644 --- a/api/onnx_web/chain/tile.py +++ b/api/onnx_web/chain/tile.py @@ -50,7 +50,7 @@ def needs_tile( source: Optional[Image.Image] = None, ) -> bool: tile = min(max_tile, stage_tile) - logger.debug("") + logger.trace("checking image tile dimensions: %s, %s, %s", tile, source.width > tile or source.height > tile, size.width > tile or size.height > tile) if source is not None: return source.width > tile or source.height > tile diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index e5142756..62cefbf7 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -387,7 +387,9 @@ def chain(server: ServerContext, pool: DevicePoolExecutor): validate(data, schema) # get defaults from the regular parameters - device, base_params, base_size = pipeline_from_request(server, data=data) + device, base_params, base_size = pipeline_from_request(server, data=data.get("defaults", None)) + + # start building the pipeline pipeline = ChainPipeline() for stage_data in data.get("stages", []): stage_class = CHAIN_STAGES[stage_data.get("type")] diff --git a/api/schemas/chain.yaml b/api/schemas/chain.yaml index 211e65ac..e24593f6 100644 --- a/api/schemas/chain.yaml +++ b/api/schemas/chain.yaml @@ -56,10 +56,20 @@ $defs: items: $ref: "#/$defs/request_stage" + request_defaults: + type: object + properties: + txt2img: + $ref: "#/$defs/image_params" + img2img: + $ref: "#/$defs/image_params" + type: object additionalProperties: False required: [stages] properties: + defaults: + $ref: "#/$defs/request_defaults" platform: type: string stages: diff --git a/gui/src/client/types.ts b/gui/src/client/types.ts index 9be13103..dfc240bb 100644 --- a/gui/src/client/types.ts +++ b/gui/src/client/types.ts @@ -162,39 +162,43 @@ export interface HighresParams { highresStrength: number; } +export interface ChainStageParams { + tile_size: number; +} + export interface Txt2ImgStage { name: string; type: 'source-txt2img'; - params: Txt2ImgParams & { - tile_size: number; - }; + params: Partial; } export interface Img2ImgStage { name: string; type: 'blend-img2img'; - params: Img2ImgParams; + params: Partial; } export interface GridStage { name: string; type: 'blend-grid'; - params: { + params: Partial<{ height: number; width: number; - tile_size: number; - }; + } & ChainStageParams>; } export interface OutputStage { name: string; type: 'persist-disk'; - params: { - tile_size: number; - }; + params: Partial; } export interface ChainPipeline { + defaults?: { + txt2img?: Txt2ImgParams; + img2img?: Img2ImgParams; + }; + stages: Array; } diff --git a/gui/src/client/utils.ts b/gui/src/client/utils.ts index e771fc8b..88d79ff1 100644 --- a/gui/src/client/utils.ts +++ b/gui/src/client/utils.ts @@ -41,6 +41,9 @@ export function replacePromptTokens(grid: PipelineGrid, params: Txt2ImgParams, c // eslint-disable-next-line max-params export function buildPipelineForTxt2ImgGrid(grid: PipelineGrid, model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): ChainPipeline { const pipeline: ChainPipeline = { + defaults: { + txt2img: params, + }, stages: [], };