1
0
Fork 0

feat: add parameter for ControlNet selection

This commit is contained in:
Sean Sube 2023-04-12 08:43:15 -05:00
parent fbf576746c
commit 9e017ee35d
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
11 changed files with 68 additions and 15 deletions

View File

@ -28,7 +28,7 @@ def blend_controlnet(
params = params.with_args(**kwargs) params = params.with_args(**kwargs)
source = stage_source or source source = stage_source or source
logger.info( logger.info(
"blending image using controlnet, %s steps: %s", params.steps, params.prompt "blending image using ControlNet, %s steps: %s", params.steps, params.prompt
) )
pipe = load_pipeline( pipe = load_pipeline(

View File

@ -174,7 +174,6 @@ def convert_diffusion_diffusers(
if is_torch_2_0: if is_torch_2_0:
pipe_cnet.set_attn_processor(CrossAttnProcessor()) pipe_cnet.set_attn_processor(CrossAttnProcessor())
cnet_path = output_path / "cnet" / ONNX_MODEL cnet_path = output_path / "cnet" / ONNX_MODEL
onnx_export( onnx_export(
pipe_cnet, pipe_cnet,

View File

@ -116,14 +116,13 @@ def load_pipeline(
scheduler_name: str, scheduler_name: str,
device: DeviceParams, device: DeviceParams,
lpw: bool, lpw: bool,
control: Optional[str] = None,
inversions: Optional[List[Tuple[str, float]]] = None, inversions: Optional[List[Tuple[str, float]]] = None,
loras: Optional[List[Tuple[str, float]]] = None, loras: Optional[List[Tuple[str, float]]] = None,
): ):
inversions = inversions or [] inversions = inversions or []
loras = loras or [] loras = loras or []
controlnet = "canny" # TODO; from params
torch_dtype = ( torch_dtype = (
torch.float16 if "torch-fp16" in server.optimizations else torch.float32 torch.float16 if "torch-fp16" in server.optimizations else torch.float32
) )
@ -186,6 +185,8 @@ def load_pipeline(
} }
text_encoder = None text_encoder = None
# Textual Inversion blending
if inversions is not None and len(inversions) > 0: if inversions is not None and len(inversions) > 0:
logger.debug("blending Textual Inversions from %s", inversions) logger.debug("blending Textual Inversions from %s", inversions)
inversion_names, inversion_weights = zip(*inversions) inversion_names, inversion_weights = zip(*inversions)
@ -225,7 +226,7 @@ def load_pipeline(
) )
) )
# test LoRA blending # LoRA blending
if loras is not None and len(loras) > 0: if loras is not None and len(loras) > 0:
lora_names, lora_weights = zip(*loras) lora_names, lora_weights = zip(*loras)
lora_models = [ lora_models = [
@ -278,8 +279,12 @@ def load_pipeline(
) )
) )
if controlnet is not None: if control is not None:
components["controlnet"] = OnnxRuntimeModel.from_pretrained(controlnet) components["controlnet"] = OnnxRuntimeModel(OnnxRuntimeModel.load_model(
control,
provider=device.ort_provider(),
sess_options=device.sess_options(),
))
pipe = pipeline.from_pretrained( pipe = pipeline.from_pretrained(
model, model,
@ -360,7 +365,7 @@ class UNetWrapper(object):
self.server = server self.server = server
self.wrapped = wrapped self.wrapped = wrapped
def __call__(self, sample=None, timestep=None, encoder_hidden_states=None): def __call__(self, sample=None, timestep=None, encoder_hidden_states=None, **kwargs):
global timestep_dtype global timestep_dtype
timestep_dtype = timestep.dtype timestep_dtype = timestep.dtype
@ -382,6 +387,7 @@ class UNetWrapper(object):
sample=sample, sample=sample,
timestep=timestep, timestep=timestep,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
**kwargs,
) )
def __getattr__(self, attr): def __getattr__(self, attr):
@ -393,7 +399,7 @@ class VAEWrapper(object):
self.server = server self.server = server
self.wrapped = wrapped self.wrapped = wrapped
def __call__(self, latent_sample=None): def __call__(self, latent_sample=None, **kwargs):
global timestep_dtype global timestep_dtype
logger.trace("VAE parameter types: %s", latent_sample.dtype) logger.trace("VAE parameter types: %s", latent_sample.dtype)
@ -401,7 +407,7 @@ class VAEWrapper(object):
logger.info("converting VAE sample dtype") logger.info("converting VAE sample dtype")
latent_sample = latent_sample.astype(timestep_dtype) latent_sample = latent_sample.astype(timestep_dtype)
return self.wrapped(latent_sample=latent_sample) return self.wrapped(latent_sample=latent_sample, **kwargs)
def __getattr__(self, attr): def __getattr__(self, attr):
return getattr(self.wrapped, attr) return getattr(self.wrapped, attr)

View File

@ -239,8 +239,9 @@ def run_img2img_pipeline(
params.scheduler, params.scheduler,
job.get_device(), job.get_device(),
params.lpw, params.lpw,
inversions, control=params.control,
loras, inversions=inversions,
loras=loras,
) )
progress = job.get_progress_callback() progress = job.get_progress_callback()
if params.lpw: if params.lpw:

View File

@ -160,6 +160,18 @@ class DeviceParams:
class ImageParams: class ImageParams:
model: str
scheduler: str
prompt: str
cfg: float
steps: int
seed: int
negative_prompt: Optional[str]
lpw: bool
eta: float
batch: int
control: Optional[str]
def __init__( def __init__(
self, self,
model: str, model: str,
@ -172,6 +184,7 @@ class ImageParams:
lpw: bool = False, lpw: bool = False,
eta: float = 0.0, eta: float = 0.0,
batch: int = 1, batch: int = 1,
control: Optional[str] = None,
) -> None: ) -> None:
self.model = model self.model = model
self.scheduler = scheduler self.scheduler = scheduler
@ -183,6 +196,7 @@ class ImageParams:
self.lpw = lpw or False self.lpw = lpw or False
self.eta = eta self.eta = eta
self.batch = batch self.batch = batch
self.control = control
def tojson(self) -> Dict[str, Optional[Param]]: def tojson(self) -> Dict[str, Optional[Param]]:
return { return {
@ -196,6 +210,7 @@ class ImageParams:
"lpw": self.lpw, "lpw": self.lpw,
"eta": self.eta, "eta": self.eta,
"batch": self.batch, "batch": self.batch,
"control": self.control,
} }
def with_args(self, **kwargs): def with_args(self, **kwargs):
@ -210,6 +225,7 @@ class ImageParams:
kwargs.get("lpw", self.lpw), kwargs.get("lpw", self.lpw),
kwargs.get("eta", self.eta), kwargs.get("eta", self.eta),
kwargs.get("batch", self.batch), kwargs.get("batch", self.batch),
kwargs.get("control", self.control),
) )

View File

@ -42,6 +42,7 @@ def pipeline_from_request(
device = platform device = platform
# pipeline stuff # pipeline stuff
control = get_not_empty(request.args, "control", get_config_value("control"))
lpw = get_not_empty(request.args, "lpw", "false") == "true" lpw = get_not_empty(request.args, "lpw", "false") == "true"
model = get_not_empty(request.args, "model", get_config_value("model")) model = get_not_empty(request.args, "model", get_config_value("model"))
model_path = get_model_path(server, model) model_path = get_model_path(server, model)
@ -132,6 +133,7 @@ def pipeline_from_request(
lpw=lpw, lpw=lpw,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
batch=batch, batch=batch,
control=control,
) )
size = Size(width, height) size = Size(width, height)
return (device, params, size) return (device, params, size)

View File

@ -18,6 +18,10 @@
"max": 30, "max": 30,
"step": 0.1 "step": 0.1
}, },
"control": {
"default": "",
"keys": []
},
"correction": { "correction": {
"default": "", "default": "",
"keys": [] "keys": []

View File

@ -22,6 +22,10 @@
"max": 30, "max": 30,
"step": 0.1 "step": 0.1
}, },
"control": {
"default": "",
"keys": []
},
"correction": { "correction": {
"default": "", "default": "",
"keys": [] "keys": []

View File

@ -32,6 +32,11 @@ export interface ModelParams {
* Use the long prompt weighting pipeline. * Use the long prompt weighting pipeline.
*/ */
lpw: boolean; lpw: boolean;
/**
* ControlNet to be used.
*/
control: string;
} }
/** /**
@ -191,7 +196,7 @@ export interface ReadyResponse {
export interface NetworkModel { export interface NetworkModel {
name: string; name: string;
type: 'inversion' | 'lora'; type: 'control' | 'inversion' | 'lora';
// TODO: add token // TODO: add token
// TODO: add layer/token count // TODO: add layer/token count
} }
@ -392,6 +397,7 @@ export function appendModelToURL(url: URL, params: ModelParams) {
url.searchParams.append('upscaling', params.upscaling); url.searchParams.append('upscaling', params.upscaling);
url.searchParams.append('correction', params.correction); url.searchParams.append('correction', params.correction);
url.searchParams.append('lpw', String(params.lpw)); url.searchParams.append('lpw', String(params.lpw));
url.searchParams.append('control', params.control);
} }
/** /**

View File

@ -116,6 +116,20 @@ export function ModelControl() {
}); });
}} }}
/> />
<QueryMenu
id='control'
labelKey='model'
name={t('modelType.control')}
query={{
result: models,
selector: (result) => result.networks.filter((network) => network.type === 'control').map((network) => network.name),
}}
onSelect={(control) => {
setModel({
control,
});
}}
/>
</Stack> </Stack>
<Stack direction='row' spacing={2}> <Stack direction='row' spacing={2}>
<FormControlLabel <FormControlLabel

View File

@ -500,11 +500,12 @@ export function createStateSlices(server: ServerParams) {
const createModelSlice: Slice<ModelSlice> = (set) => ({ const createModelSlice: Slice<ModelSlice> = (set) => ({
model: { model: {
control: server.control.default,
correction: server.correction.default,
lpw: false,
model: server.model.default, model: server.model.default,
platform: server.platform.default, platform: server.platform.default,
upscaling: server.upscaling.default, upscaling: server.upscaling.default,
correction: server.correction.default,
lpw: false,
}, },
setModel(params) { setModel(params) {
set((prev) => ({ set((prev) => ({