1
0
Fork 0

fix(api): mark all convert methods as no_grad

This commit is contained in:
Sean Sube 2023-03-01 08:26:40 -06:00
parent b44e644f9e
commit 21fc7c5968
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 10 additions and 0 deletions

View File

@ -1650,6 +1650,7 @@ def extract_checkpoint(
logger.info(result_status)
@torch.no_grad()
def convert_diffusion_original(
ctx: ConversionContext,
model: ModelDict,

View File

@ -11,6 +11,7 @@ from ..utils import ConversionContext
logger = getLogger(__name__)
@torch.no_grad()
def convert_diffusion_textual_inversion(
context: ConversionContext, name: str, base_model: str, inversion: str
):

View File

@ -74,14 +74,22 @@ class WorkerContext:
class JobStatus:
name: str
device: str
progress: int
cancelled: bool
finished: bool
def __init__(
self,
name: str,
device: DeviceParams,
progress: int = 0,
cancelled: bool = False,
finished: bool = False,
) -> None:
self.name = name
self.device = device.device
self.progress = progress
self.cancelled = cancelled
self.finished = finished