feat: add parameter for ControlNet selection
This commit is contained in:
parent
fbf576746c
commit
9e017ee35d
|
@ -28,7 +28,7 @@ def blend_controlnet(
|
|||
params = params.with_args(**kwargs)
|
||||
source = stage_source or source
|
||||
logger.info(
|
||||
"blending image using controlnet, %s steps: %s", params.steps, params.prompt
|
||||
"blending image using ControlNet, %s steps: %s", params.steps, params.prompt
|
||||
)
|
||||
|
||||
pipe = load_pipeline(
|
||||
|
|
|
@ -174,7 +174,6 @@ def convert_diffusion_diffusers(
|
|||
if is_torch_2_0:
|
||||
pipe_cnet.set_attn_processor(CrossAttnProcessor())
|
||||
|
||||
|
||||
cnet_path = output_path / "cnet" / ONNX_MODEL
|
||||
onnx_export(
|
||||
pipe_cnet,
|
||||
|
|
|
@ -116,14 +116,13 @@ def load_pipeline(
|
|||
scheduler_name: str,
|
||||
device: DeviceParams,
|
||||
lpw: bool,
|
||||
control: Optional[str] = None,
|
||||
inversions: Optional[List[Tuple[str, float]]] = None,
|
||||
loras: Optional[List[Tuple[str, float]]] = None,
|
||||
):
|
||||
inversions = inversions or []
|
||||
loras = loras or []
|
||||
|
||||
controlnet = "canny" # TODO; from params
|
||||
|
||||
torch_dtype = (
|
||||
torch.float16 if "torch-fp16" in server.optimizations else torch.float32
|
||||
)
|
||||
|
@ -186,6 +185,8 @@ def load_pipeline(
|
|||
}
|
||||
|
||||
text_encoder = None
|
||||
|
||||
# Textual Inversion blending
|
||||
if inversions is not None and len(inversions) > 0:
|
||||
logger.debug("blending Textual Inversions from %s", inversions)
|
||||
inversion_names, inversion_weights = zip(*inversions)
|
||||
|
@ -225,7 +226,7 @@ def load_pipeline(
|
|||
)
|
||||
)
|
||||
|
||||
# test LoRA blending
|
||||
# LoRA blending
|
||||
if loras is not None and len(loras) > 0:
|
||||
lora_names, lora_weights = zip(*loras)
|
||||
lora_models = [
|
||||
|
@ -278,8 +279,12 @@ def load_pipeline(
|
|||
)
|
||||
)
|
||||
|
||||
if controlnet is not None:
|
||||
components["controlnet"] = OnnxRuntimeModel.from_pretrained(controlnet)
|
||||
if control is not None:
|
||||
components["controlnet"] = OnnxRuntimeModel(OnnxRuntimeModel.load_model(
|
||||
control,
|
||||
provider=device.ort_provider(),
|
||||
sess_options=device.sess_options(),
|
||||
))
|
||||
|
||||
pipe = pipeline.from_pretrained(
|
||||
model,
|
||||
|
@ -360,7 +365,7 @@ class UNetWrapper(object):
|
|||
self.server = server
|
||||
self.wrapped = wrapped
|
||||
|
||||
def __call__(self, sample=None, timestep=None, encoder_hidden_states=None):
|
||||
def __call__(self, sample=None, timestep=None, encoder_hidden_states=None, **kwargs):
|
||||
global timestep_dtype
|
||||
timestep_dtype = timestep.dtype
|
||||
|
||||
|
@ -382,6 +387,7 @@ class UNetWrapper(object):
|
|||
sample=sample,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
|
@ -393,7 +399,7 @@ class VAEWrapper(object):
|
|||
self.server = server
|
||||
self.wrapped = wrapped
|
||||
|
||||
def __call__(self, latent_sample=None):
|
||||
def __call__(self, latent_sample=None, **kwargs):
|
||||
global timestep_dtype
|
||||
|
||||
logger.trace("VAE parameter types: %s", latent_sample.dtype)
|
||||
|
@ -401,7 +407,7 @@ class VAEWrapper(object):
|
|||
logger.info("converting VAE sample dtype")
|
||||
latent_sample = latent_sample.astype(timestep_dtype)
|
||||
|
||||
return self.wrapped(latent_sample=latent_sample)
|
||||
return self.wrapped(latent_sample=latent_sample, **kwargs)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.wrapped, attr)
|
||||
|
|
|
@ -239,8 +239,9 @@ def run_img2img_pipeline(
|
|||
params.scheduler,
|
||||
job.get_device(),
|
||||
params.lpw,
|
||||
inversions,
|
||||
loras,
|
||||
control=params.control,
|
||||
inversions=inversions,
|
||||
loras=loras,
|
||||
)
|
||||
progress = job.get_progress_callback()
|
||||
if params.lpw:
|
||||
|
|
|
@ -160,6 +160,18 @@ class DeviceParams:
|
|||
|
||||
|
||||
class ImageParams:
|
||||
model: str
|
||||
scheduler: str
|
||||
prompt: str
|
||||
cfg: float
|
||||
steps: int
|
||||
seed: int
|
||||
negative_prompt: Optional[str]
|
||||
lpw: bool
|
||||
eta: float
|
||||
batch: int
|
||||
control: Optional[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -172,6 +184,7 @@ class ImageParams:
|
|||
lpw: bool = False,
|
||||
eta: float = 0.0,
|
||||
batch: int = 1,
|
||||
control: Optional[str] = None,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.scheduler = scheduler
|
||||
|
@ -183,6 +196,7 @@ class ImageParams:
|
|||
self.lpw = lpw or False
|
||||
self.eta = eta
|
||||
self.batch = batch
|
||||
self.control = control
|
||||
|
||||
def tojson(self) -> Dict[str, Optional[Param]]:
|
||||
return {
|
||||
|
@ -196,6 +210,7 @@ class ImageParams:
|
|||
"lpw": self.lpw,
|
||||
"eta": self.eta,
|
||||
"batch": self.batch,
|
||||
"control": self.control,
|
||||
}
|
||||
|
||||
def with_args(self, **kwargs):
|
||||
|
@ -210,6 +225,7 @@ class ImageParams:
|
|||
kwargs.get("lpw", self.lpw),
|
||||
kwargs.get("eta", self.eta),
|
||||
kwargs.get("batch", self.batch),
|
||||
kwargs.get("control", self.control),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -42,6 +42,7 @@ def pipeline_from_request(
|
|||
device = platform
|
||||
|
||||
# pipeline stuff
|
||||
control = get_not_empty(request.args, "control", get_config_value("control"))
|
||||
lpw = get_not_empty(request.args, "lpw", "false") == "true"
|
||||
model = get_not_empty(request.args, "model", get_config_value("model"))
|
||||
model_path = get_model_path(server, model)
|
||||
|
@ -132,6 +133,7 @@ def pipeline_from_request(
|
|||
lpw=lpw,
|
||||
negative_prompt=negative_prompt,
|
||||
batch=batch,
|
||||
control=control,
|
||||
)
|
||||
size = Size(width, height)
|
||||
return (device, params, size)
|
||||
|
|
|
@ -18,6 +18,10 @@
|
|||
"max": 30,
|
||||
"step": 0.1
|
||||
},
|
||||
"control": {
|
||||
"default": "",
|
||||
"keys": []
|
||||
},
|
||||
"correction": {
|
||||
"default": "",
|
||||
"keys": []
|
||||
|
|
|
@ -22,6 +22,10 @@
|
|||
"max": 30,
|
||||
"step": 0.1
|
||||
},
|
||||
"control": {
|
||||
"default": "",
|
||||
"keys": []
|
||||
},
|
||||
"correction": {
|
||||
"default": "",
|
||||
"keys": []
|
||||
|
|
|
@ -32,6 +32,11 @@ export interface ModelParams {
|
|||
* Use the long prompt weighting pipeline.
|
||||
*/
|
||||
lpw: boolean;
|
||||
|
||||
/**
|
||||
* ControlNet to be used.
|
||||
*/
|
||||
control: string;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -191,7 +196,7 @@ export interface ReadyResponse {
|
|||
|
||||
export interface NetworkModel {
|
||||
name: string;
|
||||
type: 'inversion' | 'lora';
|
||||
type: 'control' | 'inversion' | 'lora';
|
||||
// TODO: add token
|
||||
// TODO: add layer/token count
|
||||
}
|
||||
|
@ -392,6 +397,7 @@ export function appendModelToURL(url: URL, params: ModelParams) {
|
|||
url.searchParams.append('upscaling', params.upscaling);
|
||||
url.searchParams.append('correction', params.correction);
|
||||
url.searchParams.append('lpw', String(params.lpw));
|
||||
url.searchParams.append('control', params.control);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -116,6 +116,20 @@ export function ModelControl() {
|
|||
});
|
||||
}}
|
||||
/>
|
||||
<QueryMenu
|
||||
id='control'
|
||||
labelKey='model'
|
||||
name={t('modelType.control')}
|
||||
query={{
|
||||
result: models,
|
||||
selector: (result) => result.networks.filter((network) => network.type === 'control').map((network) => network.name),
|
||||
}}
|
||||
onSelect={(control) => {
|
||||
setModel({
|
||||
control,
|
||||
});
|
||||
}}
|
||||
/>
|
||||
</Stack>
|
||||
<Stack direction='row' spacing={2}>
|
||||
<FormControlLabel
|
||||
|
|
|
@ -500,11 +500,12 @@ export function createStateSlices(server: ServerParams) {
|
|||
|
||||
const createModelSlice: Slice<ModelSlice> = (set) => ({
|
||||
model: {
|
||||
control: server.control.default,
|
||||
correction: server.correction.default,
|
||||
lpw: false,
|
||||
model: server.model.default,
|
||||
platform: server.platform.default,
|
||||
upscaling: server.upscaling.default,
|
||||
correction: server.correction.default,
|
||||
lpw: false,
|
||||
},
|
||||
setModel(params) {
|
||||
set((prev) => ({
|
||||
|
|
Loading…
Reference in New Issue