add default parameters to chain pipeline
This commit is contained in:
parent
f9acf9b50f
commit
7d8819ef87
|
@ -50,7 +50,7 @@ def needs_tile(
|
||||||
source: Optional[Image.Image] = None,
|
source: Optional[Image.Image] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
tile = min(max_tile, stage_tile)
|
tile = min(max_tile, stage_tile)
|
||||||
logger.debug("")
|
logger.trace("checking image tile dimensions: %s, %s, %s", tile, source.width > tile or source.height > tile, size.width > tile or size.height > tile)
|
||||||
|
|
||||||
if source is not None:
|
if source is not None:
|
||||||
return source.width > tile or source.height > tile
|
return source.width > tile or source.height > tile
|
||||||
|
|
|
@ -387,7 +387,9 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
validate(data, schema)
|
validate(data, schema)
|
||||||
|
|
||||||
# get defaults from the regular parameters
|
# get defaults from the regular parameters
|
||||||
device, base_params, base_size = pipeline_from_request(server, data=data)
|
device, base_params, base_size = pipeline_from_request(server, data=data.get("defaults", None))
|
||||||
|
|
||||||
|
# start building the pipeline
|
||||||
pipeline = ChainPipeline()
|
pipeline = ChainPipeline()
|
||||||
for stage_data in data.get("stages", []):
|
for stage_data in data.get("stages", []):
|
||||||
stage_class = CHAIN_STAGES[stage_data.get("type")]
|
stage_class = CHAIN_STAGES[stage_data.get("type")]
|
||||||
|
|
|
@ -56,10 +56,20 @@ $defs:
|
||||||
items:
|
items:
|
||||||
$ref: "#/$defs/request_stage"
|
$ref: "#/$defs/request_stage"
|
||||||
|
|
||||||
|
request_defaults:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
txt2img:
|
||||||
|
$ref: "#/$defs/image_params"
|
||||||
|
img2img:
|
||||||
|
$ref: "#/$defs/image_params"
|
||||||
|
|
||||||
type: object
|
type: object
|
||||||
additionalProperties: False
|
additionalProperties: False
|
||||||
required: [stages]
|
required: [stages]
|
||||||
properties:
|
properties:
|
||||||
|
defaults:
|
||||||
|
$ref: "#/$defs/request_defaults"
|
||||||
platform:
|
platform:
|
||||||
type: string
|
type: string
|
||||||
stages:
|
stages:
|
||||||
|
|
|
@ -162,39 +162,43 @@ export interface HighresParams {
|
||||||
highresStrength: number;
|
highresStrength: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface ChainStageParams {
|
||||||
|
tile_size: number;
|
||||||
|
}
|
||||||
|
|
||||||
export interface Txt2ImgStage {
|
export interface Txt2ImgStage {
|
||||||
name: string;
|
name: string;
|
||||||
type: 'source-txt2img';
|
type: 'source-txt2img';
|
||||||
params: Txt2ImgParams & {
|
params: Partial<Txt2ImgParams & ChainStageParams>;
|
||||||
tile_size: number;
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface Img2ImgStage {
|
export interface Img2ImgStage {
|
||||||
name: string;
|
name: string;
|
||||||
type: 'blend-img2img';
|
type: 'blend-img2img';
|
||||||
params: Img2ImgParams;
|
params: Partial<Img2ImgParams & ChainStageParams>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface GridStage {
|
export interface GridStage {
|
||||||
name: string;
|
name: string;
|
||||||
type: 'blend-grid';
|
type: 'blend-grid';
|
||||||
params: {
|
params: Partial<{
|
||||||
height: number;
|
height: number;
|
||||||
width: number;
|
width: number;
|
||||||
tile_size: number;
|
} & ChainStageParams>;
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface OutputStage {
|
export interface OutputStage {
|
||||||
name: string;
|
name: string;
|
||||||
type: 'persist-disk';
|
type: 'persist-disk';
|
||||||
params: {
|
params: Partial<ChainStageParams>;
|
||||||
tile_size: number;
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ChainPipeline {
|
export interface ChainPipeline {
|
||||||
|
defaults?: {
|
||||||
|
txt2img?: Txt2ImgParams;
|
||||||
|
img2img?: Img2ImgParams;
|
||||||
|
};
|
||||||
|
|
||||||
stages: Array<Txt2ImgStage | Img2ImgStage | GridStage | OutputStage>;
|
stages: Array<Txt2ImgStage | Img2ImgStage | GridStage | OutputStage>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -41,6 +41,9 @@ export function replacePromptTokens(grid: PipelineGrid, params: Txt2ImgParams, c
|
||||||
// eslint-disable-next-line max-params
|
// eslint-disable-next-line max-params
|
||||||
export function buildPipelineForTxt2ImgGrid(grid: PipelineGrid, model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): ChainPipeline {
|
export function buildPipelineForTxt2ImgGrid(grid: PipelineGrid, model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): ChainPipeline {
|
||||||
const pipeline: ChainPipeline = {
|
const pipeline: ChainPipeline = {
|
||||||
|
defaults: {
|
||||||
|
txt2img: params,
|
||||||
|
},
|
||||||
stages: [],
|
stages: [],
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue