1
0
Fork 0

feat: show model in image card, use labels for model and scheduler (#104)

This commit is contained in:
Sean Sube 2023-02-12 09:51:35 -06:00
parent 9c5043e9d0
commit 27a3fa8f51
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
6 changed files with 55 additions and 36 deletions

View File

@ -2,7 +2,21 @@ from logging import getLogger
from typing import Any, Optional, Tuple
import numpy as np
from diffusers import DiffusionPipeline
from diffusers import (
DDIMScheduler,
DDPMScheduler,
DiffusionPipeline,
DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
KarrasVeScheduler,
KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ..params import DeviceParams, Size
from ..utils import run_gc
@ -22,6 +36,28 @@ last_pipeline_scheduler: Any = None
latent_channels = 4
latent_factor = 8
pipeline_schedulers = {
"ddim": DDIMScheduler,
"ddpm": DDPMScheduler,
"dpm-multi": DPMSolverMultistepScheduler,
"dpm-single": DPMSolverSinglestepScheduler,
"euler": EulerDiscreteScheduler,
"euler-a": EulerAncestralDiscreteScheduler,
"heun": HeunDiscreteScheduler,
"k-dpm-2-a": KDPM2AncestralDiscreteScheduler,
"k-dpm-2": KDPM2DiscreteScheduler,
"karras-ve": KarrasVeScheduler,
"lms-discrete": LMSDiscreteScheduler,
"pndm": PNDMScheduler,
}
def get_scheduler_name(scheduler: Any) -> Optional[str]:
for k, v in pipeline_schedulers.items():
if scheduler == v or scheduler == v.__name__:
return k
return None
def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
"""

View File

@ -1,12 +1,14 @@
from hashlib import sha256
from json import dumps
from logging import getLogger
from os import path
from struct import pack
from time import time
from typing import Any, Optional, Tuple
from PIL import Image
from .diffusion.load import get_scheduler_name
from .params import Border, ImageParams, Param, Size, UpscaleParams
from .utils import ServerContext, base_join
@ -38,6 +40,9 @@ def json_params(
"params": params.tojson(),
}
json["params"]["model"] = path.basename(params.model)
json["params"]["scheduler"] = get_scheduler_name(params.scheduler)
if upscale is not None and border is not None:
size = upscale.resize(size.add_border(border))

View File

@ -9,28 +9,12 @@ from typing import Dict, List, Tuple, Union
import numpy as np
import torch
import yaml
from diffusers import (
DDIMScheduler,
DDPMScheduler,
DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
KarrasVeScheduler,
KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from flask import Flask, jsonify, make_response, request, send_from_directory, url_for
from flask_cors import CORS
from jsonschema import validate
from onnxruntime import get_available_providers
from PIL import Image
from onnx_web.hacks import apply_patches
from .chain import (
ChainPipeline,
blend_img2img,
@ -48,12 +32,14 @@ from .chain import (
upscale_stable_diffusion,
)
from .device_pool import DevicePoolExecutor
from .diffusion.load import pipeline_schedulers
from .diffusion.run import (
run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline,
run_upscale_pipeline,
)
from .hacks import apply_patches
from .image import ( # mask filters; noise sources
mask_filter_gaussian_multiply,
mask_filter_gaussian_screen,
@ -99,20 +85,7 @@ platform_providers = {
"directml": "DmlExecutionProvider",
"rocm": "ROCMExecutionProvider",
}
pipeline_schedulers = {
"ddim": DDIMScheduler,
"ddpm": DDPMScheduler,
"dpm-multi": DPMSolverMultistepScheduler,
"dpm-single": DPMSolverSinglestepScheduler,
"euler": EulerDiscreteScheduler,
"euler-a": EulerAncestralDiscreteScheduler,
"heun": HeunDiscreteScheduler,
"k-dpm-2-a": KDPM2AncestralDiscreteScheduler,
"k-dpm-2": KDPM2DiscreteScheduler,
"karras-ve": KarrasVeScheduler,
"lms-discrete": LMSDiscreteScheduler,
"pndm": PNDMScheduler,
}
noise_sources = {
"fill-edge": noise_source_fill_edge,
"fill-mask": noise_source_fill_mask,

View File

@ -137,7 +137,7 @@ export interface ImageResponse {
key: string;
url: string;
};
params: Required<BaseImgParams>;
params: Required<BaseImgParams> & Required<ModelParams>;
size: {
width: number;
height: number;

View File

@ -1,4 +1,4 @@
import { doesExist, mustExist } from '@apextoaster/js-utils';
import { doesExist, mustDefault, mustExist } from '@apextoaster/js-utils';
import { Brush, ContentCopy, Delete, Download } from '@mui/icons-material';
import { Box, Card, CardContent, CardMedia, Grid, IconButton, Paper, Tooltip } from '@mui/material';
import * as React from 'react';
@ -8,6 +8,7 @@ import { useStore } from 'zustand';
import { ImageResponse } from '../client.js';
import { ConfigContext, StateContext } from '../state.js';
import { MODEL_LABELS, SCHEDULER_LABELS } from '../strings.js';
export interface ImageCardProps {
value: ImageResponse;
@ -64,6 +65,9 @@ export function ImageCard(props: ImageCardProps) {
window.open(output.url, '_blank');
}
const model = mustDefault(MODEL_LABELS[params.model], params.model);
const scheduler = mustDefault(SCHEDULER_LABELS[params.scheduler], params.scheduler);
return <Card sx={{ maxWidth: config.params.width.default }} elevation={2}>
<CardMedia sx={{ height: config.params.height.default }}
component='img'
@ -73,11 +77,12 @@ export function ImageCard(props: ImageCardProps) {
<CardContent>
<Box textAlign='center'>
<Grid container spacing={2}>
<GridItem xs={4}>Model: {model}</GridItem>
<GridItem xs={4}>Scheduler: {scheduler}</GridItem>
<GridItem xs={4}>Seed: {params.seed}</GridItem>
<GridItem xs={4}>CFG: {params.cfg}</GridItem>
<GridItem xs={4}>Steps: {params.steps}</GridItem>
<GridItem xs={4}>Size: {size.width}x{size.height}</GridItem>
<GridItem xs={4}>Seed: {params.seed}</GridItem>
<GridItem xs={8}>Scheduler: {params.scheduler}</GridItem>
<GridItem xs={12}>
<Box textAlign='left'>{params.prompt}</Box>
</GridItem>

View File

@ -1,5 +1,5 @@
// TODO: set up i18next
export const MODEL_LABELS = {
export const MODEL_LABELS: Record<string, string> = {
'stable-diffusion-onnx-v1-4': 'Stable Diffusion v1.4',
'stable-diffusion-onnx-v1-5': 'Stable Diffusion v1.5',
'stable-diffusion-onnx-v1-inpainting': 'SD Inpainting v1',