1
0
Fork 0

feat: add a way to select textual inversions

This commit is contained in:
Sean Sube 2023-02-21 23:08:13 -06:00
parent 45f5fca383
commit 2e7de16778
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
12 changed files with 55 additions and 2 deletions

View File

@@ -36,6 +36,7 @@ def blend_img2img(
params.scheduler,
job.get_device(),
params.lpw,
params.inversion,
)
if params.lpw:
logger.debug("using LPW pipeline for img2img")

View File

@@ -75,6 +75,7 @@ def blend_inpaint(
params.scheduler,
job.get_device(),
params.lpw,
params.inversion,
)
if params.lpw:

View File

@@ -42,6 +42,7 @@ def source_txt2img(
params.scheduler,
job.get_device(),
params.lpw,
params.inversion,
)
if params.lpw:

View File

@@ -79,6 +79,7 @@ def upscale_outpaint(
params.scheduler,
job.get_device(),
params.lpw,
params.inversion,
)
if params.lpw:
logger.debug("using LPW pipeline for inpaint")

View File

@@ -18,6 +18,7 @@ from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionPipeline,
OnnxRuntimeModel,
)
try:
@@ -138,8 +139,9 @@ def load_pipeline(
scheduler_type: Any,
device: DeviceParams,
lpw: bool,
inversion: Optional[str],
):
pipe_key = (pipeline, model, device.device, device.provider, lpw)
pipe_key = (pipeline, model, device.device, device.provider, lpw, inversion)
scheduler_key = (scheduler_type, model)
cache_pipe = server.cache.get("diffusion", pipe_key)
@@ -182,6 +184,17 @@ def load_pipeline(
sess_options=device.sess_options(),
subfolder="scheduler",
)
text_encoder = None
if inversion is not None:
logger.debug("loading text encoder from %s", inversion)
text_encoder = OnnxRuntimeModel.from_pretrained(
inversion,
provider=device.ort_provider(),
sess_options=device.sess_options(),
subfolder="text_encoder",
)
pipe = pipeline.from_pretrained(
model,
custom_pipeline=custom_pipeline,
@@ -190,6 +203,7 @@ def load_pipeline(
revision="onnx",
safety_checker=None,
scheduler=scheduler,
text_encoder=text_encoder,
)
if not server.show_progress:

View File

@@ -36,6 +36,7 @@ def run_txt2img_pipeline(
params.scheduler,
job.get_device(),
params.lpw,
params.inversion,
)
progress = job.get_progress_callback()
@@ -109,6 +110,7 @@ def run_img2img_pipeline(
params.scheduler,
job.get_device(),
params.lpw,
params.inversion,
)
progress = job.get_progress_callback()
if params.lpw:

View File

@@ -19,6 +19,8 @@ logger = getLogger(__name__)
def hash_value(sha, param: Param):
if param is None:
return
elif isinstance(param, bool):
sha.update(bytearray(pack("!B", param)))
elif isinstance(param, float):
sha.update(bytearray(pack("!f", param)))
elif isinstance(param, int):
@@ -73,8 +75,12 @@ def make_output_name(
hash_value(sha, params.prompt)
hash_value(sha, params.negative_prompt)
hash_value(sha, params.cfg)
hash_value(sha, params.steps)
hash_value(sha, params.seed)
hash_value(sha, params.steps)
hash_value(sha, params.lpw)
hash_value(sha, params.eta)
hash_value(sha, params.batch)
hash_value(sha, params.inversion)
hash_value(sha, size.width)
hash_value(sha, size.height)

View File

@@ -157,6 +157,7 @@ class ImageParams:
lpw: bool = False,
eta: float = 0.0,
batch: int = 1,
inversion: str = None,
) -> None:
self.model = model
self.scheduler = scheduler
@@ -168,6 +169,7 @@ class ImageParams:
self.lpw = lpw or False
self.eta = eta
self.batch = batch
self.inversion = inversion
def tojson(self) -> Dict[str, Optional[Param]]:
return {
@@ -181,6 +183,7 @@ class ImageParams:
"lpw": self.lpw,
"eta": self.eta,
"batch": self.batch,
"inversion": self.inversion,
}
def with_args(self, **kwargs):
@@ -195,6 +198,7 @@ class ImageParams:
kwargs.get("lpw", self.lpw),
kwargs.get("eta", self.eta),
kwargs.get("batch", self.batch),
kwargs.get("inversion", self.inversion),
)

View File

@@ -159,6 +159,8 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
scheduler = get_from_map(
request.args, "scheduler", pipeline_schedulers, get_config_value("scheduler")
)
inversion = get_not_empty(request.args, "inversion", get_config_value("inversion"))
inversion_path = get_model_path(inversion)
# image params
prompt = get_not_empty(request.args, "prompt", get_config_value("prompt"))
@@ -240,6 +242,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
lpw=lpw,
negative_prompt=negative_prompt,
batch=batch,
inversion=inversion_path,
)
size = Size(width, height)
return (device, params, size)

View File

@@ -32,6 +32,8 @@ export interface ModelParams {
* Use the long prompt weighting pipeline.
*/
lpw: boolean;
inversion: string;
}
/**
@@ -183,6 +185,7 @@ export interface ReadyResponse {
export interface ModelsResponse {
diffusion: Array<string>;
correction: Array<string>;
inversion: Array<string>;
upscaling: Array<string>;
}
@@ -325,6 +328,7 @@ export function appendModelToURL(url: URL, params: ModelParams) {
url.searchParams.append('upscaling', params.upscaling);
url.searchParams.append('correction', params.correction);
url.searchParams.append('lpw', String(params.lpw));
url.searchParams.append('inversion', params.inversion);
}
/**

View File

@@ -54,6 +54,21 @@ export function ModelControl() {
});
}}
/>
<QueryList
id='inversion'
labels={MODEL_LABELS}
name='Textual Inversion'
query={{
result: models,
selector: (result) => result.inversion,
}}
value={params.inversion}
onChange={(inversion) => {
setModel({
inversion,
});
}}
/>
<QueryList
id='upscaling'
labels={MODEL_LABELS}

View File

@@ -487,6 +487,7 @@ export function createStateSlices(server: ServerParams) {
platform: server.platform.default,
upscaling: server.upscaling.default,
correction: server.correction.default,
inversion: server.inversion.default,
lpw: false,
},
setModel(params) {