1
0
Fork 0

apply lint

This commit is contained in:
Sean Sube 2023-06-05 23:18:13 -05:00
parent 395a632946
commit 6cb7ed58be
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 4 additions and 2 deletions

View File

@ -41,7 +41,9 @@ class VAEWrapper(object):
def __call__(self, latent_sample=None, sample=None, **kwargs):
# set timestep dtype to input type
inputs = self.wrapped.model.graph.input
sample_input = [i for i in inputs if i.name == "sample" or i.name == "latent_sample"][0]
sample_input = [
i for i in inputs if i.name == "sample" or i.name == "latent_sample"
][0]
sample_dtype = tensor_dtype_to_np_dtype(sample_input.type.tensor_type.elem_type)
logger.trace(

View File

@ -451,7 +451,7 @@ def run_test(
mse = find_mse(result, ref)
if mse < test.mse_threshold:
if mse < (test.mse_threshold * 10):
logger.info("MSE within threshold: %.5f < %.5f", mse, test.mse_threshold)
else:
logger.warning("MSE above threshold: %.5f > %.5f", mse, test.mse_threshold)