1
0
Fork 0

add default parameters to chain pipeline

This commit is contained in:
Sean Sube 2023-09-12 19:06:13 -05:00
parent f9acf9b50f
commit 7d8819ef87
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 31 additions and 12 deletions

View File

@ -50,7 +50,7 @@ def needs_tile(
source: Optional[Image.Image] = None, source: Optional[Image.Image] = None,
) -> bool: ) -> bool:
tile = min(max_tile, stage_tile) tile = min(max_tile, stage_tile)
logger.debug("") logger.trace("checking image tile dimensions: %s, %s, %s", tile, source.width > tile or source.height > tile, size.width > tile or size.height > tile)
if source is not None: if source is not None:
return source.width > tile or source.height > tile return source.width > tile or source.height > tile

View File

@ -387,7 +387,9 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
validate(data, schema) validate(data, schema)
# get defaults from the regular parameters # get defaults from the regular parameters
device, base_params, base_size = pipeline_from_request(server, data=data) device, base_params, base_size = pipeline_from_request(server, data=data.get("defaults", None))
# start building the pipeline
pipeline = ChainPipeline() pipeline = ChainPipeline()
for stage_data in data.get("stages", []): for stage_data in data.get("stages", []):
stage_class = CHAIN_STAGES[stage_data.get("type")] stage_class = CHAIN_STAGES[stage_data.get("type")]

View File

@ -56,10 +56,20 @@ $defs:
items: items:
$ref: "#/$defs/request_stage" $ref: "#/$defs/request_stage"
request_defaults:
type: object
properties:
txt2img:
$ref: "#/$defs/image_params"
img2img:
$ref: "#/$defs/image_params"
type: object type: object
additionalProperties: False additionalProperties: False
required: [stages] required: [stages]
properties: properties:
defaults:
$ref: "#/$defs/request_defaults"
platform: platform:
type: string type: string
stages: stages:

View File

@ -162,39 +162,43 @@ export interface HighresParams {
highresStrength: number; highresStrength: number;
} }
export interface ChainStageParams {
tile_size: number;
}
export interface Txt2ImgStage { export interface Txt2ImgStage {
name: string; name: string;
type: 'source-txt2img'; type: 'source-txt2img';
params: Txt2ImgParams & { params: Partial<Txt2ImgParams & ChainStageParams>;
tile_size: number;
};
} }
export interface Img2ImgStage { export interface Img2ImgStage {
name: string; name: string;
type: 'blend-img2img'; type: 'blend-img2img';
params: Img2ImgParams; params: Partial<Img2ImgParams & ChainStageParams>;
} }
export interface GridStage { export interface GridStage {
name: string; name: string;
type: 'blend-grid'; type: 'blend-grid';
params: { params: Partial<{
height: number; height: number;
width: number; width: number;
tile_size: number; } & ChainStageParams>;
};
} }
export interface OutputStage { export interface OutputStage {
name: string; name: string;
type: 'persist-disk'; type: 'persist-disk';
params: { params: Partial<ChainStageParams>;
tile_size: number;
};
} }
export interface ChainPipeline { export interface ChainPipeline {
defaults?: {
txt2img?: Txt2ImgParams;
img2img?: Img2ImgParams;
};
stages: Array<Txt2ImgStage | Img2ImgStage | GridStage | OutputStage>; stages: Array<Txt2ImgStage | Img2ImgStage | GridStage | OutputStage>;
} }

View File

@ -41,6 +41,9 @@ export function replacePromptTokens(grid: PipelineGrid, params: Txt2ImgParams, c
// eslint-disable-next-line max-params // eslint-disable-next-line max-params
export function buildPipelineForTxt2ImgGrid(grid: PipelineGrid, model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): ChainPipeline { export function buildPipelineForTxt2ImgGrid(grid: PipelineGrid, model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): ChainPipeline {
const pipeline: ChainPipeline = { const pipeline: ChainPipeline = {
defaults: {
txt2img: params,
},
stages: [], stages: [],
}; };