feat: add parameter for ControlNet selection
This commit is contained in:
parent
fbf576746c
commit
9e017ee35d
|
@ -28,7 +28,7 @@ def blend_controlnet(
|
||||||
params = params.with_args(**kwargs)
|
params = params.with_args(**kwargs)
|
||||||
source = stage_source or source
|
source = stage_source or source
|
||||||
logger.info(
|
logger.info(
|
||||||
"blending image using controlnet, %s steps: %s", params.steps, params.prompt
|
"blending image using ControlNet, %s steps: %s", params.steps, params.prompt
|
||||||
)
|
)
|
||||||
|
|
||||||
pipe = load_pipeline(
|
pipe = load_pipeline(
|
||||||
|
|
|
@ -174,7 +174,6 @@ def convert_diffusion_diffusers(
|
||||||
if is_torch_2_0:
|
if is_torch_2_0:
|
||||||
pipe_cnet.set_attn_processor(CrossAttnProcessor())
|
pipe_cnet.set_attn_processor(CrossAttnProcessor())
|
||||||
|
|
||||||
|
|
||||||
cnet_path = output_path / "cnet" / ONNX_MODEL
|
cnet_path = output_path / "cnet" / ONNX_MODEL
|
||||||
onnx_export(
|
onnx_export(
|
||||||
pipe_cnet,
|
pipe_cnet,
|
||||||
|
|
|
@ -116,14 +116,13 @@ def load_pipeline(
|
||||||
scheduler_name: str,
|
scheduler_name: str,
|
||||||
device: DeviceParams,
|
device: DeviceParams,
|
||||||
lpw: bool,
|
lpw: bool,
|
||||||
|
control: Optional[str] = None,
|
||||||
inversions: Optional[List[Tuple[str, float]]] = None,
|
inversions: Optional[List[Tuple[str, float]]] = None,
|
||||||
loras: Optional[List[Tuple[str, float]]] = None,
|
loras: Optional[List[Tuple[str, float]]] = None,
|
||||||
):
|
):
|
||||||
inversions = inversions or []
|
inversions = inversions or []
|
||||||
loras = loras or []
|
loras = loras or []
|
||||||
|
|
||||||
controlnet = "canny" # TODO; from params
|
|
||||||
|
|
||||||
torch_dtype = (
|
torch_dtype = (
|
||||||
torch.float16 if "torch-fp16" in server.optimizations else torch.float32
|
torch.float16 if "torch-fp16" in server.optimizations else torch.float32
|
||||||
)
|
)
|
||||||
|
@ -186,6 +185,8 @@ def load_pipeline(
|
||||||
}
|
}
|
||||||
|
|
||||||
text_encoder = None
|
text_encoder = None
|
||||||
|
|
||||||
|
# Textual Inversion blending
|
||||||
if inversions is not None and len(inversions) > 0:
|
if inversions is not None and len(inversions) > 0:
|
||||||
logger.debug("blending Textual Inversions from %s", inversions)
|
logger.debug("blending Textual Inversions from %s", inversions)
|
||||||
inversion_names, inversion_weights = zip(*inversions)
|
inversion_names, inversion_weights = zip(*inversions)
|
||||||
|
@ -225,7 +226,7 @@ def load_pipeline(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# test LoRA blending
|
# LoRA blending
|
||||||
if loras is not None and len(loras) > 0:
|
if loras is not None and len(loras) > 0:
|
||||||
lora_names, lora_weights = zip(*loras)
|
lora_names, lora_weights = zip(*loras)
|
||||||
lora_models = [
|
lora_models = [
|
||||||
|
@ -278,8 +279,12 @@ def load_pipeline(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if controlnet is not None:
|
if control is not None:
|
||||||
components["controlnet"] = OnnxRuntimeModel.from_pretrained(controlnet)
|
components["controlnet"] = OnnxRuntimeModel(OnnxRuntimeModel.load_model(
|
||||||
|
control,
|
||||||
|
provider=device.ort_provider(),
|
||||||
|
sess_options=device.sess_options(),
|
||||||
|
))
|
||||||
|
|
||||||
pipe = pipeline.from_pretrained(
|
pipe = pipeline.from_pretrained(
|
||||||
model,
|
model,
|
||||||
|
@ -360,7 +365,7 @@ class UNetWrapper(object):
|
||||||
self.server = server
|
self.server = server
|
||||||
self.wrapped = wrapped
|
self.wrapped = wrapped
|
||||||
|
|
||||||
def __call__(self, sample=None, timestep=None, encoder_hidden_states=None):
|
def __call__(self, sample=None, timestep=None, encoder_hidden_states=None, **kwargs):
|
||||||
global timestep_dtype
|
global timestep_dtype
|
||||||
timestep_dtype = timestep.dtype
|
timestep_dtype = timestep.dtype
|
||||||
|
|
||||||
|
@ -382,6 +387,7 @@ class UNetWrapper(object):
|
||||||
sample=sample,
|
sample=sample,
|
||||||
timestep=timestep,
|
timestep=timestep,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
def __getattr__(self, attr):
|
||||||
|
@ -393,7 +399,7 @@ class VAEWrapper(object):
|
||||||
self.server = server
|
self.server = server
|
||||||
self.wrapped = wrapped
|
self.wrapped = wrapped
|
||||||
|
|
||||||
def __call__(self, latent_sample=None):
|
def __call__(self, latent_sample=None, **kwargs):
|
||||||
global timestep_dtype
|
global timestep_dtype
|
||||||
|
|
||||||
logger.trace("VAE parameter types: %s", latent_sample.dtype)
|
logger.trace("VAE parameter types: %s", latent_sample.dtype)
|
||||||
|
@ -401,7 +407,7 @@ class VAEWrapper(object):
|
||||||
logger.info("converting VAE sample dtype")
|
logger.info("converting VAE sample dtype")
|
||||||
latent_sample = latent_sample.astype(timestep_dtype)
|
latent_sample = latent_sample.astype(timestep_dtype)
|
||||||
|
|
||||||
return self.wrapped(latent_sample=latent_sample)
|
return self.wrapped(latent_sample=latent_sample, **kwargs)
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
def __getattr__(self, attr):
|
||||||
return getattr(self.wrapped, attr)
|
return getattr(self.wrapped, attr)
|
||||||
|
|
|
@ -239,8 +239,9 @@ def run_img2img_pipeline(
|
||||||
params.scheduler,
|
params.scheduler,
|
||||||
job.get_device(),
|
job.get_device(),
|
||||||
params.lpw,
|
params.lpw,
|
||||||
inversions,
|
control=params.control,
|
||||||
loras,
|
inversions=inversions,
|
||||||
|
loras=loras,
|
||||||
)
|
)
|
||||||
progress = job.get_progress_callback()
|
progress = job.get_progress_callback()
|
||||||
if params.lpw:
|
if params.lpw:
|
||||||
|
|
|
@ -160,6 +160,18 @@ class DeviceParams:
|
||||||
|
|
||||||
|
|
||||||
class ImageParams:
|
class ImageParams:
|
||||||
|
model: str
|
||||||
|
scheduler: str
|
||||||
|
prompt: str
|
||||||
|
cfg: float
|
||||||
|
steps: int
|
||||||
|
seed: int
|
||||||
|
negative_prompt: Optional[str]
|
||||||
|
lpw: bool
|
||||||
|
eta: float
|
||||||
|
batch: int
|
||||||
|
control: Optional[str]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
|
@ -172,6 +184,7 @@ class ImageParams:
|
||||||
lpw: bool = False,
|
lpw: bool = False,
|
||||||
eta: float = 0.0,
|
eta: float = 0.0,
|
||||||
batch: int = 1,
|
batch: int = 1,
|
||||||
|
control: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model = model
|
self.model = model
|
||||||
self.scheduler = scheduler
|
self.scheduler = scheduler
|
||||||
|
@ -183,6 +196,7 @@ class ImageParams:
|
||||||
self.lpw = lpw or False
|
self.lpw = lpw or False
|
||||||
self.eta = eta
|
self.eta = eta
|
||||||
self.batch = batch
|
self.batch = batch
|
||||||
|
self.control = control
|
||||||
|
|
||||||
def tojson(self) -> Dict[str, Optional[Param]]:
|
def tojson(self) -> Dict[str, Optional[Param]]:
|
||||||
return {
|
return {
|
||||||
|
@ -196,6 +210,7 @@ class ImageParams:
|
||||||
"lpw": self.lpw,
|
"lpw": self.lpw,
|
||||||
"eta": self.eta,
|
"eta": self.eta,
|
||||||
"batch": self.batch,
|
"batch": self.batch,
|
||||||
|
"control": self.control,
|
||||||
}
|
}
|
||||||
|
|
||||||
def with_args(self, **kwargs):
|
def with_args(self, **kwargs):
|
||||||
|
@ -210,6 +225,7 @@ class ImageParams:
|
||||||
kwargs.get("lpw", self.lpw),
|
kwargs.get("lpw", self.lpw),
|
||||||
kwargs.get("eta", self.eta),
|
kwargs.get("eta", self.eta),
|
||||||
kwargs.get("batch", self.batch),
|
kwargs.get("batch", self.batch),
|
||||||
|
kwargs.get("control", self.control),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -42,6 +42,7 @@ def pipeline_from_request(
|
||||||
device = platform
|
device = platform
|
||||||
|
|
||||||
# pipeline stuff
|
# pipeline stuff
|
||||||
|
control = get_not_empty(request.args, "control", get_config_value("control"))
|
||||||
lpw = get_not_empty(request.args, "lpw", "false") == "true"
|
lpw = get_not_empty(request.args, "lpw", "false") == "true"
|
||||||
model = get_not_empty(request.args, "model", get_config_value("model"))
|
model = get_not_empty(request.args, "model", get_config_value("model"))
|
||||||
model_path = get_model_path(server, model)
|
model_path = get_model_path(server, model)
|
||||||
|
@ -132,6 +133,7 @@ def pipeline_from_request(
|
||||||
lpw=lpw,
|
lpw=lpw,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
|
control=control,
|
||||||
)
|
)
|
||||||
size = Size(width, height)
|
size = Size(width, height)
|
||||||
return (device, params, size)
|
return (device, params, size)
|
||||||
|
|
|
@ -18,6 +18,10 @@
|
||||||
"max": 30,
|
"max": 30,
|
||||||
"step": 0.1
|
"step": 0.1
|
||||||
},
|
},
|
||||||
|
"control": {
|
||||||
|
"default": "",
|
||||||
|
"keys": []
|
||||||
|
},
|
||||||
"correction": {
|
"correction": {
|
||||||
"default": "",
|
"default": "",
|
||||||
"keys": []
|
"keys": []
|
||||||
|
|
|
@ -22,6 +22,10 @@
|
||||||
"max": 30,
|
"max": 30,
|
||||||
"step": 0.1
|
"step": 0.1
|
||||||
},
|
},
|
||||||
|
"control": {
|
||||||
|
"default": "",
|
||||||
|
"keys": []
|
||||||
|
},
|
||||||
"correction": {
|
"correction": {
|
||||||
"default": "",
|
"default": "",
|
||||||
"keys": []
|
"keys": []
|
||||||
|
|
|
@ -32,6 +32,11 @@ export interface ModelParams {
|
||||||
* Use the long prompt weighting pipeline.
|
* Use the long prompt weighting pipeline.
|
||||||
*/
|
*/
|
||||||
lpw: boolean;
|
lpw: boolean;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ControlNet to be used.
|
||||||
|
*/
|
||||||
|
control: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -191,7 +196,7 @@ export interface ReadyResponse {
|
||||||
|
|
||||||
export interface NetworkModel {
|
export interface NetworkModel {
|
||||||
name: string;
|
name: string;
|
||||||
type: 'inversion' | 'lora';
|
type: 'control' | 'inversion' | 'lora';
|
||||||
// TODO: add token
|
// TODO: add token
|
||||||
// TODO: add layer/token count
|
// TODO: add layer/token count
|
||||||
}
|
}
|
||||||
|
@ -392,6 +397,7 @@ export function appendModelToURL(url: URL, params: ModelParams) {
|
||||||
url.searchParams.append('upscaling', params.upscaling);
|
url.searchParams.append('upscaling', params.upscaling);
|
||||||
url.searchParams.append('correction', params.correction);
|
url.searchParams.append('correction', params.correction);
|
||||||
url.searchParams.append('lpw', String(params.lpw));
|
url.searchParams.append('lpw', String(params.lpw));
|
||||||
|
url.searchParams.append('control', params.control);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -116,6 +116,20 @@ export function ModelControl() {
|
||||||
});
|
});
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
|
<QueryMenu
|
||||||
|
id='control'
|
||||||
|
labelKey='model'
|
||||||
|
name={t('modelType.control')}
|
||||||
|
query={{
|
||||||
|
result: models,
|
||||||
|
selector: (result) => result.networks.filter((network) => network.type === 'control').map((network) => network.name),
|
||||||
|
}}
|
||||||
|
onSelect={(control) => {
|
||||||
|
setModel({
|
||||||
|
control,
|
||||||
|
});
|
||||||
|
}}
|
||||||
|
/>
|
||||||
</Stack>
|
</Stack>
|
||||||
<Stack direction='row' spacing={2}>
|
<Stack direction='row' spacing={2}>
|
||||||
<FormControlLabel
|
<FormControlLabel
|
||||||
|
|
|
@ -500,11 +500,12 @@ export function createStateSlices(server: ServerParams) {
|
||||||
|
|
||||||
const createModelSlice: Slice<ModelSlice> = (set) => ({
|
const createModelSlice: Slice<ModelSlice> = (set) => ({
|
||||||
model: {
|
model: {
|
||||||
|
control: server.control.default,
|
||||||
|
correction: server.correction.default,
|
||||||
|
lpw: false,
|
||||||
model: server.model.default,
|
model: server.model.default,
|
||||||
platform: server.platform.default,
|
platform: server.platform.default,
|
||||||
upscaling: server.upscaling.default,
|
upscaling: server.upscaling.default,
|
||||||
correction: server.correction.default,
|
|
||||||
lpw: false,
|
|
||||||
},
|
},
|
||||||
setModel(params) {
|
setModel(params) {
|
||||||
set((prev) => ({
|
set((prev) => ({
|
||||||
|
|
Loading…
Reference in New Issue