diff --git a/emblem5/ai/common.py b/emblem5/ai/common.py index f5e26e0..3d4c036 100644 --- a/emblem5/ai/common.py +++ b/emblem5/ai/common.py @@ -150,7 +150,7 @@ def make_stripe_img(left, right, nstripes): return ret def predict_multi(model, transforms, images, ncells=1): - results_per_img = ncells * ncells + results_per_img = ncells * ncells * 2 ret = [] with torch.no_grad(): tensors = []