feat: show model in image card, use labels for model and scheduler (#104)
This commit is contained in:
parent
9c5043e9d0
commit
27a3fa8f51
|
@ -2,7 +2,21 @@ from logging import getLogger
|
||||||
from typing import Any, Optional, Tuple
|
from typing import Any, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import (
|
||||||
|
DDIMScheduler,
|
||||||
|
DDPMScheduler,
|
||||||
|
DiffusionPipeline,
|
||||||
|
DPMSolverMultistepScheduler,
|
||||||
|
DPMSolverSinglestepScheduler,
|
||||||
|
EulerAncestralDiscreteScheduler,
|
||||||
|
EulerDiscreteScheduler,
|
||||||
|
HeunDiscreteScheduler,
|
||||||
|
KarrasVeScheduler,
|
||||||
|
KDPM2AncestralDiscreteScheduler,
|
||||||
|
KDPM2DiscreteScheduler,
|
||||||
|
LMSDiscreteScheduler,
|
||||||
|
PNDMScheduler,
|
||||||
|
)
|
||||||
|
|
||||||
from ..params import DeviceParams, Size
|
from ..params import DeviceParams, Size
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
|
@ -22,6 +36,28 @@ last_pipeline_scheduler: Any = None
|
||||||
latent_channels = 4
|
latent_channels = 4
|
||||||
latent_factor = 8
|
latent_factor = 8
|
||||||
|
|
||||||
|
pipeline_schedulers = {
|
||||||
|
"ddim": DDIMScheduler,
|
||||||
|
"ddpm": DDPMScheduler,
|
||||||
|
"dpm-multi": DPMSolverMultistepScheduler,
|
||||||
|
"dpm-single": DPMSolverSinglestepScheduler,
|
||||||
|
"euler": EulerDiscreteScheduler,
|
||||||
|
"euler-a": EulerAncestralDiscreteScheduler,
|
||||||
|
"heun": HeunDiscreteScheduler,
|
||||||
|
"k-dpm-2-a": KDPM2AncestralDiscreteScheduler,
|
||||||
|
"k-dpm-2": KDPM2DiscreteScheduler,
|
||||||
|
"karras-ve": KarrasVeScheduler,
|
||||||
|
"lms-discrete": LMSDiscreteScheduler,
|
||||||
|
"pndm": PNDMScheduler,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_scheduler_name(scheduler: Any) -> Optional[str]:
|
||||||
|
for k, v in pipeline_schedulers.items():
|
||||||
|
if scheduler == v or scheduler == v.__name__:
|
||||||
|
return k
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
|
def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,12 +1,14 @@
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from json import dumps
|
from json import dumps
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
from os import path
|
||||||
from struct import pack
|
from struct import pack
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Any, Optional, Tuple
|
from typing import Any, Optional, Tuple
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from .diffusion.load import get_scheduler_name
|
||||||
from .params import Border, ImageParams, Param, Size, UpscaleParams
|
from .params import Border, ImageParams, Param, Size, UpscaleParams
|
||||||
from .utils import ServerContext, base_join
|
from .utils import ServerContext, base_join
|
||||||
|
|
||||||
|
@ -38,6 +40,9 @@ def json_params(
|
||||||
"params": params.tojson(),
|
"params": params.tojson(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
json["params"]["model"] = path.basename(params.model)
|
||||||
|
json["params"]["scheduler"] = get_scheduler_name(params.scheduler)
|
||||||
|
|
||||||
if upscale is not None and border is not None:
|
if upscale is not None and border is not None:
|
||||||
size = upscale.resize(size.add_border(border))
|
size = upscale.resize(size.add_border(border))
|
||||||
|
|
||||||
|
|
|
@ -9,28 +9,12 @@ from typing import Dict, List, Tuple, Union
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
from diffusers import (
|
|
||||||
DDIMScheduler,
|
|
||||||
DDPMScheduler,
|
|
||||||
DPMSolverMultistepScheduler,
|
|
||||||
DPMSolverSinglestepScheduler,
|
|
||||||
EulerAncestralDiscreteScheduler,
|
|
||||||
EulerDiscreteScheduler,
|
|
||||||
HeunDiscreteScheduler,
|
|
||||||
KarrasVeScheduler,
|
|
||||||
KDPM2AncestralDiscreteScheduler,
|
|
||||||
KDPM2DiscreteScheduler,
|
|
||||||
LMSDiscreteScheduler,
|
|
||||||
PNDMScheduler,
|
|
||||||
)
|
|
||||||
from flask import Flask, jsonify, make_response, request, send_from_directory, url_for
|
from flask import Flask, jsonify, make_response, request, send_from_directory, url_for
|
||||||
from flask_cors import CORS
|
from flask_cors import CORS
|
||||||
from jsonschema import validate
|
from jsonschema import validate
|
||||||
from onnxruntime import get_available_providers
|
from onnxruntime import get_available_providers
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from onnx_web.hacks import apply_patches
|
|
||||||
|
|
||||||
from .chain import (
|
from .chain import (
|
||||||
ChainPipeline,
|
ChainPipeline,
|
||||||
blend_img2img,
|
blend_img2img,
|
||||||
|
@ -48,12 +32,14 @@ from .chain import (
|
||||||
upscale_stable_diffusion,
|
upscale_stable_diffusion,
|
||||||
)
|
)
|
||||||
from .device_pool import DevicePoolExecutor
|
from .device_pool import DevicePoolExecutor
|
||||||
|
from .diffusion.load import pipeline_schedulers
|
||||||
from .diffusion.run import (
|
from .diffusion.run import (
|
||||||
run_img2img_pipeline,
|
run_img2img_pipeline,
|
||||||
run_inpaint_pipeline,
|
run_inpaint_pipeline,
|
||||||
run_txt2img_pipeline,
|
run_txt2img_pipeline,
|
||||||
run_upscale_pipeline,
|
run_upscale_pipeline,
|
||||||
)
|
)
|
||||||
|
from .hacks import apply_patches
|
||||||
from .image import ( # mask filters; noise sources
|
from .image import ( # mask filters; noise sources
|
||||||
mask_filter_gaussian_multiply,
|
mask_filter_gaussian_multiply,
|
||||||
mask_filter_gaussian_screen,
|
mask_filter_gaussian_screen,
|
||||||
|
@ -99,20 +85,7 @@ platform_providers = {
|
||||||
"directml": "DmlExecutionProvider",
|
"directml": "DmlExecutionProvider",
|
||||||
"rocm": "ROCMExecutionProvider",
|
"rocm": "ROCMExecutionProvider",
|
||||||
}
|
}
|
||||||
pipeline_schedulers = {
|
|
||||||
"ddim": DDIMScheduler,
|
|
||||||
"ddpm": DDPMScheduler,
|
|
||||||
"dpm-multi": DPMSolverMultistepScheduler,
|
|
||||||
"dpm-single": DPMSolverSinglestepScheduler,
|
|
||||||
"euler": EulerDiscreteScheduler,
|
|
||||||
"euler-a": EulerAncestralDiscreteScheduler,
|
|
||||||
"heun": HeunDiscreteScheduler,
|
|
||||||
"k-dpm-2-a": KDPM2AncestralDiscreteScheduler,
|
|
||||||
"k-dpm-2": KDPM2DiscreteScheduler,
|
|
||||||
"karras-ve": KarrasVeScheduler,
|
|
||||||
"lms-discrete": LMSDiscreteScheduler,
|
|
||||||
"pndm": PNDMScheduler,
|
|
||||||
}
|
|
||||||
noise_sources = {
|
noise_sources = {
|
||||||
"fill-edge": noise_source_fill_edge,
|
"fill-edge": noise_source_fill_edge,
|
||||||
"fill-mask": noise_source_fill_mask,
|
"fill-mask": noise_source_fill_mask,
|
||||||
|
|
|
@ -137,7 +137,7 @@ export interface ImageResponse {
|
||||||
key: string;
|
key: string;
|
||||||
url: string;
|
url: string;
|
||||||
};
|
};
|
||||||
params: Required<BaseImgParams>;
|
params: Required<BaseImgParams> & Required<ModelParams>;
|
||||||
size: {
|
size: {
|
||||||
width: number;
|
width: number;
|
||||||
height: number;
|
height: number;
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import { doesExist, mustExist } from '@apextoaster/js-utils';
|
import { doesExist, mustDefault, mustExist } from '@apextoaster/js-utils';
|
||||||
import { Brush, ContentCopy, Delete, Download } from '@mui/icons-material';
|
import { Brush, ContentCopy, Delete, Download } from '@mui/icons-material';
|
||||||
import { Box, Card, CardContent, CardMedia, Grid, IconButton, Paper, Tooltip } from '@mui/material';
|
import { Box, Card, CardContent, CardMedia, Grid, IconButton, Paper, Tooltip } from '@mui/material';
|
||||||
import * as React from 'react';
|
import * as React from 'react';
|
||||||
|
@ -8,6 +8,7 @@ import { useStore } from 'zustand';
|
||||||
|
|
||||||
import { ImageResponse } from '../client.js';
|
import { ImageResponse } from '../client.js';
|
||||||
import { ConfigContext, StateContext } from '../state.js';
|
import { ConfigContext, StateContext } from '../state.js';
|
||||||
|
import { MODEL_LABELS, SCHEDULER_LABELS } from '../strings.js';
|
||||||
|
|
||||||
export interface ImageCardProps {
|
export interface ImageCardProps {
|
||||||
value: ImageResponse;
|
value: ImageResponse;
|
||||||
|
@ -64,6 +65,9 @@ export function ImageCard(props: ImageCardProps) {
|
||||||
window.open(output.url, '_blank');
|
window.open(output.url, '_blank');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const model = mustDefault(MODEL_LABELS[params.model], params.model);
|
||||||
|
const scheduler = mustDefault(SCHEDULER_LABELS[params.scheduler], params.scheduler);
|
||||||
|
|
||||||
return <Card sx={{ maxWidth: config.params.width.default }} elevation={2}>
|
return <Card sx={{ maxWidth: config.params.width.default }} elevation={2}>
|
||||||
<CardMedia sx={{ height: config.params.height.default }}
|
<CardMedia sx={{ height: config.params.height.default }}
|
||||||
component='img'
|
component='img'
|
||||||
|
@ -73,11 +77,12 @@ export function ImageCard(props: ImageCardProps) {
|
||||||
<CardContent>
|
<CardContent>
|
||||||
<Box textAlign='center'>
|
<Box textAlign='center'>
|
||||||
<Grid container spacing={2}>
|
<Grid container spacing={2}>
|
||||||
|
<GridItem xs={4}>Model: {model}</GridItem>
|
||||||
|
<GridItem xs={4}>Scheduler: {scheduler}</GridItem>
|
||||||
|
<GridItem xs={4}>Seed: {params.seed}</GridItem>
|
||||||
<GridItem xs={4}>CFG: {params.cfg}</GridItem>
|
<GridItem xs={4}>CFG: {params.cfg}</GridItem>
|
||||||
<GridItem xs={4}>Steps: {params.steps}</GridItem>
|
<GridItem xs={4}>Steps: {params.steps}</GridItem>
|
||||||
<GridItem xs={4}>Size: {size.width}x{size.height}</GridItem>
|
<GridItem xs={4}>Size: {size.width}x{size.height}</GridItem>
|
||||||
<GridItem xs={4}>Seed: {params.seed}</GridItem>
|
|
||||||
<GridItem xs={8}>Scheduler: {params.scheduler}</GridItem>
|
|
||||||
<GridItem xs={12}>
|
<GridItem xs={12}>
|
||||||
<Box textAlign='left'>{params.prompt}</Box>
|
<Box textAlign='left'>{params.prompt}</Box>
|
||||||
</GridItem>
|
</GridItem>
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
// TODO: set up i18next
|
// TODO: set up i18next
|
||||||
export const MODEL_LABELS = {
|
export const MODEL_LABELS: Record<string, string> = {
|
||||||
'stable-diffusion-onnx-v1-4': 'Stable Diffusion v1.4',
|
'stable-diffusion-onnx-v1-4': 'Stable Diffusion v1.4',
|
||||||
'stable-diffusion-onnx-v1-5': 'Stable Diffusion v1.5',
|
'stable-diffusion-onnx-v1-5': 'Stable Diffusion v1.5',
|
||||||
'stable-diffusion-onnx-v1-inpainting': 'SD Inpainting v1',
|
'stable-diffusion-onnx-v1-inpainting': 'SD Inpainting v1',
|
||||||
|
|
Loading…
Reference in New Issue