/* eslint-disable no-null/no-null */
import { Maybe } from '@apextoaster/js-utils';
import { createContext } from 'react';
import { StateCreator, StoreApi } from 'zustand';

import {
  ApiClient,
  ApiResponse,
  BaseImgParams,
  BrushParams,
  Img2ImgParams,
  InpaintParams,
  OutpaintPixels,
  paramsFromConfig,
  Txt2ImgParams,
  UpscaleParams,
} from './api/client.js';
import { ConfigFiles, ConfigParams, ConfigState } from './config.js';

/**
 * Stored form state for one tab: the `ConfigFiles` view and the `ConfigState` view
 * of that tab's params (with every field required), intersected together.
 */
type TabState<TabParams extends BaseImgParams> = ConfigFiles<Required<TabParams>> & ConfigState<Required<TabParams>>;

interface Txt2ImgSlice {
  txt2img: TabState<Txt2ImgParams>;

  setTxt2Img(params: Partial<Txt2ImgParams>): void;
  resetTxt2Img(): void;
}

interface Img2ImgSlice {
  img2img: TabState<Img2ImgParams>;

  setImg2Img(params: Partial<Img2ImgParams>): void;
  resetImg2Img(): void;
}

interface InpaintSlice {
  inpaint: TabState<InpaintParams>;

  setInpaint(params: Partial<InpaintParams>): void;
  resetInpaint(): void;
}

interface HistorySlice {
  history: Array<ApiResponse>;
  limit: number;
  loading: Maybe<ApiResponse>;

  pushHistory(image: ApiResponse): void;
  removeHistory(image: ApiResponse): void;
  setLimit(limit: number): void;
  setLoading(image: Maybe<ApiResponse>): void;
}

interface DefaultSlice {
  defaults: TabState<BaseImgParams>;

  setDefaults(param: Partial<BaseImgParams>): void;
}

interface OutpaintSlice {
  outpaint: OutpaintPixels;

  setOutpaint(pixels: Partial<OutpaintPixels>): void;
}

interface BrushSlice {
  brush: BrushParams;

  setBrush(brush: Partial<BrushParams>): void;
}

interface UpscaleSlice {
  upscale: UpscaleParams;

  setUpscale(upscale: Partial<UpscaleParams>): void;
}

/**
 * Combined application state: the intersection of every slice defined above.
 */
export type OnnxState = BrushSlice
  & DefaultSlice
  & HistorySlice
  & Img2ImgSlice
  & InpaintSlice
  & OutpaintSlice
  & Txt2ImgSlice
  & UpscaleSlice;

/**
 * Build the zustand slice creators for the application store.
 *
 * The shared image params are derived once from `base` via `paramsFromConfig`. Each
 * tab's initial state layers its tab-specific fields (also read from `base` where
 * applicable) on top of those shared params.
 *
 * @param base - server-provided parameter config; its `.default` values seed the tab state
 * @returns one `StateCreator` per slice, to be combined into a single `OnnxState` store
 */
export function createStateSlices(base: ConfigParams) {
  const defaults = paramsFromConfig(base);

  // Each factory returns a fresh object, so the initial state and the reset state are
  // defined in exactly one place and never share a reference.

  /** Initial (and reset) state for the txt2img tab. */
  function initialTxt2Img(): TabState<Txt2ImgParams> {
    return {
      ...defaults,
      width: base.width.default,
      height: base.height.default,
    };
  }

  /** Initial (and reset) state for the img2img tab; no source image is selected. */
  function initialImg2Img(): TabState<Img2ImgParams> {
    return {
      ...defaults,
      source: null,
      strength: base.strength.default,
    };
  }

  /** Initial (and reset) state for the inpaint tab; no source or mask is selected. */
  function initialInpaint(): TabState<InpaintParams> {
    return {
      ...defaults,
      filter: 'none',
      mask: null,
      noise: 'histogram',
      source: null,
    };
  }

  const createTxt2ImgSlice: StateCreator<OnnxState, [], [], Txt2ImgSlice> = (set) => ({
    txt2img: initialTxt2Img(),
    setTxt2Img(params) {
      set((prev) => ({
        txt2img: {
          ...prev.txt2img,
          ...params,
        },
      }));
    },
    resetTxt2Img() {
      set({
        txt2img: initialTxt2Img(),
      });
    },
  });

  const createImg2ImgSlice: StateCreator<OnnxState, [], [], Img2ImgSlice> = (set) => ({
    img2img: initialImg2Img(),
    setImg2Img(params) {
      set((prev) => ({
        img2img: {
          ...prev.img2img,
          ...params,
        },
      }));
    },
    resetImg2Img() {
      set({
        img2img: initialImg2Img(),
      });
    },
  });

  const createInpaintSlice: StateCreator<OnnxState, [], [], InpaintSlice> = (set) => ({
    inpaint: initialInpaint(),
    setInpaint(params) {
      set((prev) => ({
        inpaint: {
          ...prev.inpaint,
          ...params,
        },
      }));
    },
    resetInpaint() {
      set({
        inpaint: initialInpaint(),
      });
    },
  });

  // zustand's `set` shallow-merges the returned partial into the current state, so these
  // updaters return only the keys they change (the same approach as the other slices).
  const createHistorySlice: StateCreator<OnnxState, [], [], HistorySlice> = (set) => ({
    history: [],
    limit: 4,
    loading: null,
    pushHistory(image) {
      // Newest image goes first. A finished image also clears the pending `loading` entry.
      // NOTE(review): `history` is not trimmed to `limit` here; presumably the consumer
      // applies the limit when rendering. Confirm.
      set((prev) => ({
        history: [
          image,
          ...prev.history,
        ],
        loading: null,
      }));
    },
    removeHistory(image) {
      // NOTE(review): `!==` compares `output` by identity (or value, if it is a primitive).
      // This assumes callers pass back the same `output` they received. Verify.
      set((prev) => ({
        history: prev.history.filter((it) => it.output !== image.output),
      }));
    },
    setLimit(limit) {
      set({
        limit,
      });
    },
    setLoading(loading) {
      set({
        loading,
      });
    },
  });

  const createOutpaintSlice: StateCreator<OnnxState, [], [], OutpaintSlice> = (set) => ({
    outpaint: {
      enabled: false,
      left: 0,
      right: 0,
      top: 0,
      bottom: 0,
    },
    setOutpaint(pixels) {
      set((prev) => ({
        outpaint: {
          ...prev.outpaint,
          ...pixels,
        },
      }));
    },
  });

  const createBrushSlice: StateCreator<OnnxState, [], [], BrushSlice> = (set) => ({
    brush: {
      color: 255,
      size: 8,
      strength: 0.5,
    },
    setBrush(brush) {
      set((prev) => ({
        brush: {
          ...prev.brush,
          ...brush,
        },
      }));
    },
  });

  const createUpscaleSlice: StateCreator<OnnxState, [], [], UpscaleSlice> = (set) => ({
    upscale: {
      denoise: 0.5,
      enabled: false,
      faces: false,
      scale: 1,
    },
    setUpscale(upscale) {
      set((prev) => ({
        upscale: {
          ...prev.upscale,
          ...upscale,
        },
      }));
    },
  });

  const createDefaultSlice: StateCreator<OnnxState, [], [], DefaultSlice> = (set) => ({
    defaults: {
      ...defaults,
    },
    setDefaults(params) {
      set((prev) => ({
        defaults: {
          ...prev.defaults,
          ...params,
        },
      }));
    },
  });

  return {
    createBrushSlice,
    createDefaultSlice,
    createHistorySlice,
    createImg2ImgSlice,
    createInpaintSlice,
    createOutpaintSlice,
    createTxt2ImgSlice,
    createUpscaleSlice,
  };
}

/** React context holding the API client; `undefined` until a provider supplies one. */
export const ClientContext = createContext<Maybe<ApiClient>>(undefined);

/** React context holding the zustand store; `undefined` until a provider supplies one. */
export const StateContext = createContext<Maybe<StoreApi<OnnxState>>>(undefined);