lint(api): use more accurate worker name throughout
This commit is contained in:
parent
ada482c183
commit
b31f546516
|
@ -70,14 +70,14 @@ class ChainPipeline:
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
job: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: List[Image.Image],
|
||||||
callback: Optional[ProgressCallback],
|
callback: Optional[ProgressCallback],
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> List[Image.Image]:
|
) -> List[Image.Image]:
|
||||||
return self(job, server, params, sources=sources, callback=callback, **kwargs)
|
return self(worker, server, params, sources=sources, callback=callback, **kwargs)
|
||||||
|
|
||||||
def stage(self, callback: BaseStage, params: StageParams, **kwargs):
|
def stage(self, callback: BaseStage, params: StageParams, **kwargs):
|
||||||
self.stages.append((callback, params, kwargs))
|
self.stages.append((callback, params, kwargs))
|
||||||
|
@ -85,7 +85,7 @@ class ChainPipeline:
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
job: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: List[Image.Image],
|
||||||
|
@ -151,10 +151,10 @@ class ChainPipeline:
|
||||||
tile_mask: Image.Image,
|
tile_mask: Image.Image,
|
||||||
dims: Tuple[int, int, int],
|
dims: Tuple[int, int, int],
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
for i in range(job.retries):
|
for i in range(worker.retries):
|
||||||
try:
|
try:
|
||||||
output_tile = stage_pipe.run(
|
output_tile = stage_pipe.run(
|
||||||
job,
|
worker,
|
||||||
server,
|
server,
|
||||||
stage_params,
|
stage_params,
|
||||||
params,
|
params,
|
||||||
|
@ -175,8 +175,8 @@ class ChainPipeline:
|
||||||
i,
|
i,
|
||||||
)
|
)
|
||||||
server.cache.clear()
|
server.cache.clear()
|
||||||
run_gc([job.get_device()])
|
run_gc([worker.get_device()])
|
||||||
job.retries = job.retries - (i + 1)
|
worker.retries = worker.retries - (i + 1)
|
||||||
|
|
||||||
raise RetryException("exhausted retries on tile")
|
raise RetryException("exhausted retries on tile")
|
||||||
|
|
||||||
|
@ -193,10 +193,10 @@ class ChainPipeline:
|
||||||
stage_sources = stage_outputs
|
stage_sources = stage_outputs
|
||||||
else:
|
else:
|
||||||
logger.debug("image within tile size of %s, running stage", tile)
|
logger.debug("image within tile size of %s, running stage", tile)
|
||||||
for i in range(job.retries):
|
for i in range(worker.retries):
|
||||||
try:
|
try:
|
||||||
stage_outputs = stage_pipe.run(
|
stage_outputs = stage_pipe.run(
|
||||||
job,
|
worker,
|
||||||
server,
|
server,
|
||||||
stage_params,
|
stage_params,
|
||||||
params,
|
params,
|
||||||
|
@ -213,10 +213,10 @@ class ChainPipeline:
|
||||||
"error while running stage pipeline, retry %s of 3", i
|
"error while running stage pipeline, retry %s of 3", i
|
||||||
)
|
)
|
||||||
server.cache.clear()
|
server.cache.clear()
|
||||||
run_gc([job.get_device()])
|
run_gc([worker.get_device()])
|
||||||
job.retries = job.retries - (i + 1)
|
worker.retries = worker.retries - (i + 1)
|
||||||
|
|
||||||
if job.retries <= 0:
|
if worker.retries <= 0:
|
||||||
raise RetryException("exhausted retries on stage")
|
raise RetryException("exhausted retries on stage")
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|
|
@ -20,7 +20,7 @@ class BlendImg2ImgStage(BaseStage):
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
job: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
|
@ -45,7 +45,7 @@ class BlendImg2ImgStage(BaseStage):
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
pipe_type,
|
pipe_type,
|
||||||
job.get_device(),
|
worker.get_device(),
|
||||||
inversions=inversions,
|
inversions=inversions,
|
||||||
loras=loras,
|
loras=loras,
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,7 +14,7 @@ logger = getLogger(__name__)
|
||||||
class BlendLinearStage(BaseStage):
|
class BlendLinearStage(BaseStage):
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
_job: WorkerContext,
|
_worker: WorkerContext,
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
|
|
|
@ -16,7 +16,7 @@ logger = getLogger(__name__)
|
||||||
class BlendMaskStage(BaseStage):
|
class BlendMaskStage(BaseStage):
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
_job: WorkerContext,
|
_worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
|
|
|
@ -14,7 +14,7 @@ logger = getLogger(__name__)
|
||||||
class CorrectCodeformerStage(BaseStage):
|
class CorrectCodeformerStage(BaseStage):
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
job: WorkerContext,
|
worker: WorkerContext,
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
|
@ -30,6 +30,6 @@ class CorrectCodeformerStage(BaseStage):
|
||||||
|
|
||||||
upscale = upscale.with_args(**kwargs)
|
upscale = upscale.with_args(**kwargs)
|
||||||
|
|
||||||
device = job.get_device()
|
device = worker.get_device()
|
||||||
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
|
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
|
||||||
return [pipe(source) for source in sources]
|
return [pipe(source) for source in sources]
|
||||||
|
|
|
@ -15,7 +15,7 @@ logger = getLogger(__name__)
|
||||||
class PersistDiskStage(BaseStage):
|
class PersistDiskStage(BaseStage):
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
_job: WorkerContext,
|
_worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
|
|
|
@ -16,7 +16,7 @@ logger = getLogger(__name__)
|
||||||
class PersistS3Stage(BaseStage):
|
class PersistS3Stage(BaseStage):
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
_job: WorkerContext,
|
_worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
|
|
|
@ -14,7 +14,7 @@ logger = getLogger(__name__)
|
||||||
class ReduceCropStage(BaseStage):
|
class ReduceCropStage(BaseStage):
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
_job: WorkerContext,
|
_worker: WorkerContext,
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
|
|
|
@ -14,7 +14,7 @@ logger = getLogger(__name__)
|
||||||
class ReduceThumbnailStage(BaseStage):
|
class ReduceThumbnailStage(BaseStage):
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
_job: WorkerContext,
|
_worker: WorkerContext,
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
|
|
|
@ -14,7 +14,7 @@ logger = getLogger(__name__)
|
||||||
class SourceNoiseStage(BaseStage):
|
class SourceNoiseStage(BaseStage):
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
_job: WorkerContext,
|
_worker: WorkerContext,
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
|
|
|
@ -16,7 +16,7 @@ logger = getLogger(__name__)
|
||||||
class SourceS3Stage(BaseStage):
|
class SourceS3Stage(BaseStage):
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
_job: WorkerContext,
|
_worker: WorkerContext,
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
|
|
|
@ -25,7 +25,7 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
job: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
|
@ -67,7 +67,7 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
pipe_type,
|
pipe_type,
|
||||||
job.get_device(),
|
worker.get_device(),
|
||||||
inversions=inversions,
|
inversions=inversions,
|
||||||
loras=loras,
|
loras=loras,
|
||||||
)
|
)
|
||||||
|
|
|
@ -16,7 +16,7 @@ logger = getLogger(__name__)
|
||||||
class SourceURLStage(BaseStage):
|
class SourceURLStage(BaseStage):
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
_job: WorkerContext,
|
_worker: WorkerContext,
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
|
|
|
@ -12,7 +12,7 @@ class BaseStage:
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
job: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
|
|
|
@ -50,7 +50,7 @@ class UpscaleBSRGANStage(BaseStage):
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
job: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
|
@ -67,7 +67,7 @@ class UpscaleBSRGANStage(BaseStage):
|
||||||
return sources
|
return sources
|
||||||
|
|
||||||
logger.info("upscaling with BSRGAN model: %s", upscale.upscale_model)
|
logger.info("upscaling with BSRGAN model: %s", upscale.upscale_model)
|
||||||
device = job.get_device()
|
device = worker.get_device()
|
||||||
bsrgan = self.load(server, stage, upscale, device)
|
bsrgan = self.load(server, stage, upscale, device)
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
|
|
|
@ -16,7 +16,7 @@ logger = getLogger(__name__)
|
||||||
class UpscaleHighresStage(BaseStage):
|
class UpscaleHighresStage(BaseStage):
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
job: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
|
@ -35,7 +35,7 @@ class UpscaleHighresStage(BaseStage):
|
||||||
|
|
||||||
return [
|
return [
|
||||||
chain(
|
chain(
|
||||||
job,
|
worker,
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
source,
|
source,
|
||||||
|
|
|
@ -28,7 +28,7 @@ class UpscaleOutpaintStage(BaseStage):
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
job: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
|
@ -55,7 +55,7 @@ class UpscaleOutpaintStage(BaseStage):
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
pipe_type,
|
pipe_type,
|
||||||
job.get_device(),
|
worker.get_device(),
|
||||||
inversions=inversions,
|
inversions=inversions,
|
||||||
loras=loras,
|
loras=loras,
|
||||||
)
|
)
|
||||||
|
|
|
@ -73,7 +73,7 @@ class UpscaleRealESRGANStage(BaseStage):
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
job: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
|
@ -89,7 +89,7 @@ class UpscaleRealESRGANStage(BaseStage):
|
||||||
for source in sources:
|
for source in sources:
|
||||||
output = np.array(source)
|
output = np.array(source)
|
||||||
upsampler = self.load(
|
upsampler = self.load(
|
||||||
server, upscale, job.get_device(), tile=stage.tile_size
|
server, upscale, worker.get_device(), tile=stage.tile_size
|
||||||
)
|
)
|
||||||
|
|
||||||
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
|
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
|
||||||
|
|
|
@ -14,7 +14,7 @@ logger = getLogger(__name__)
|
||||||
class UpscaleSimpleStage(BaseStage):
|
class UpscaleSimpleStage(BaseStage):
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
_job: WorkerContext,
|
_worker: WorkerContext,
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
|
|
|
@ -18,7 +18,7 @@ logger = getLogger(__name__)
|
||||||
class UpscaleStableDiffusionStage(BaseStage):
|
class UpscaleStableDiffusionStage(BaseStage):
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
job: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
|
@ -43,7 +43,7 @@ class UpscaleStableDiffusionStage(BaseStage):
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
"upscale",
|
"upscale",
|
||||||
job.get_device(),
|
worker.get_device(),
|
||||||
model=path.join(server.model_path, upscale.upscale_model),
|
model=path.join(server.model_path, upscale.upscale_model),
|
||||||
)
|
)
|
||||||
generator = torch.manual_seed(params.seed)
|
generator = torch.manual_seed(params.seed)
|
||||||
|
|
|
@ -50,7 +50,7 @@ class UpscaleSwinIRStage(BaseStage):
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
job: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
|
@ -67,7 +67,7 @@ class UpscaleSwinIRStage(BaseStage):
|
||||||
return sources
|
return sources
|
||||||
|
|
||||||
logger.info("correcting faces with SwinIR model: %s", upscale.upscale_model)
|
logger.info("correcting faces with SwinIR model: %s", upscale.upscale_model)
|
||||||
device = job.get_device()
|
device = worker.get_device()
|
||||||
swinir = self.load(server, stage, upscale, device)
|
swinir = self.load(server, stage, upscale, device)
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
|
|
|
@ -34,7 +34,7 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def run_txt2img_pipeline(
|
def run_txt2img_pipeline(
|
||||||
job: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
|
@ -83,8 +83,8 @@ def run_txt2img_pipeline(
|
||||||
|
|
||||||
# run and save
|
# run and save
|
||||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||||
progress = job.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain.run(job, server, params, [], callback=progress, latents=latents)
|
images = chain.run(worker, server, params, [], callback=progress, latents=latents)
|
||||||
|
|
||||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||||
|
|
||||||
|
@ -102,7 +102,7 @@ def run_txt2img_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
# clean up
|
# clean up
|
||||||
run_gc([job.get_device()])
|
run_gc([worker.get_device()])
|
||||||
|
|
||||||
# notify the user
|
# notify the user
|
||||||
show_system_toast(f"finished txt2img job: {dest}")
|
show_system_toast(f"finished txt2img job: {dest}")
|
||||||
|
@ -110,7 +110,7 @@ def run_txt2img_pipeline(
|
||||||
|
|
||||||
|
|
||||||
def run_img2img_pipeline(
|
def run_img2img_pipeline(
|
||||||
job: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
outputs: List[str],
|
outputs: List[str],
|
||||||
|
@ -175,8 +175,8 @@ def run_img2img_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
# run and append the filtered source
|
# run and append the filtered source
|
||||||
progress = job.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain(job, server, params, [source], callback=progress)
|
images = chain(worker, server, params, [source], callback=progress)
|
||||||
|
|
||||||
if source_filter is not None and source_filter != "none":
|
if source_filter is not None and source_filter != "none":
|
||||||
images.append(source)
|
images.append(source)
|
||||||
|
@ -199,7 +199,7 @@ def run_img2img_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
# clean up
|
# clean up
|
||||||
run_gc([job.get_device()])
|
run_gc([worker.get_device()])
|
||||||
|
|
||||||
# notify the user
|
# notify the user
|
||||||
show_system_toast(f"finished img2img job: {dest}")
|
show_system_toast(f"finished img2img job: {dest}")
|
||||||
|
@ -207,7 +207,7 @@ def run_img2img_pipeline(
|
||||||
|
|
||||||
|
|
||||||
def run_inpaint_pipeline(
|
def run_inpaint_pipeline(
|
||||||
job: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
|
@ -353,8 +353,8 @@ def run_inpaint_pipeline(
|
||||||
|
|
||||||
# run and save
|
# run and save
|
||||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||||
progress = job.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain(job, server, params, [source], callback=progress, latents=latents)
|
images = chain(worker, server, params, [source], callback=progress, latents=latents)
|
||||||
|
|
||||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||||
for image, output in zip(images, outputs):
|
for image, output in zip(images, outputs):
|
||||||
|
@ -378,7 +378,7 @@ def run_inpaint_pipeline(
|
||||||
|
|
||||||
# clean up
|
# clean up
|
||||||
del image
|
del image
|
||||||
run_gc([job.get_device()])
|
run_gc([worker.get_device()])
|
||||||
|
|
||||||
# notify the user
|
# notify the user
|
||||||
show_system_toast(f"finished inpaint job: {dest}")
|
show_system_toast(f"finished inpaint job: {dest}")
|
||||||
|
@ -386,7 +386,7 @@ def run_inpaint_pipeline(
|
||||||
|
|
||||||
|
|
||||||
def run_upscale_pipeline(
|
def run_upscale_pipeline(
|
||||||
job: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
|
@ -427,8 +427,8 @@ def run_upscale_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
# run and save
|
# run and save
|
||||||
progress = job.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain(job, server, params, [source], callback=progress)
|
images = chain(worker, server, params, [source], callback=progress)
|
||||||
|
|
||||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||||
for image, output in zip(images, outputs):
|
for image, output in zip(images, outputs):
|
||||||
|
@ -445,7 +445,7 @@ def run_upscale_pipeline(
|
||||||
|
|
||||||
# clean up
|
# clean up
|
||||||
del image
|
del image
|
||||||
run_gc([job.get_device()])
|
run_gc([worker.get_device()])
|
||||||
|
|
||||||
# notify the user
|
# notify the user
|
||||||
show_system_toast(f"finished upscale job: {dest}")
|
show_system_toast(f"finished upscale job: {dest}")
|
||||||
|
@ -453,7 +453,7 @@ def run_upscale_pipeline(
|
||||||
|
|
||||||
|
|
||||||
def run_blend_pipeline(
|
def run_blend_pipeline(
|
||||||
job: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
|
@ -477,15 +477,15 @@ def run_blend_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
# run and save
|
# run and save
|
||||||
progress = job.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain(job, server, params, sources, callback=progress)
|
images = chain(worker, server, params, sources, callback=progress)
|
||||||
|
|
||||||
for image, output in zip(images, outputs):
|
for image, output in zip(images, outputs):
|
||||||
dest = save_image(server, output, image, params, size, upscale=upscale)
|
dest = save_image(server, output, image, params, size, upscale=upscale)
|
||||||
|
|
||||||
# clean up
|
# clean up
|
||||||
del image
|
del image
|
||||||
run_gc([job.get_device()])
|
run_gc([worker.get_device()])
|
||||||
|
|
||||||
# notify the user
|
# notify the user
|
||||||
show_system_toast(f"finished blend job: {dest}")
|
show_system_toast(f"finished blend job: {dest}")
|
||||||
|
|
|
@ -8,7 +8,7 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def run_txt2txt_pipeline(
|
def run_txt2txt_pipeline(
|
||||||
job: WorkerContext,
|
worker: WorkerContext,
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
_size: Size,
|
_size: Size,
|
||||||
|
@ -20,7 +20,7 @@ def run_txt2txt_pipeline(
|
||||||
model = "EleutherAI/gpt-j-6B"
|
model = "EleutherAI/gpt-j-6B"
|
||||||
tokens = 1024
|
tokens = 1024
|
||||||
|
|
||||||
device = job.get_device()
|
device = worker.get_device()
|
||||||
|
|
||||||
pipe = GPTJForCausalLM.from_pretrained(model).to(device.torch_str())
|
pipe = GPTJForCausalLM.from_pretrained(model).to(device.torch_str())
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model)
|
tokenizer = AutoTokenizer.from_pretrained(model)
|
||||||
|
|
Loading…
Reference in New Issue