I am using the MixStyle methodology for domain adaptation, which involves inserting a custom layer after every encoder stage. However, this layer causes VRAM usage to grow linearly, which eventually leads to an OOM error. The memory growth disappears when the layer is disabled, so I am sure that this particular layer is causing the issue. Any idea why this is happening?
This is the layer for reference -
class MixStyle(nn.Module):
    """MixStyle: mix per-instance feature statistics between batch samples.

    During training, each sample's per-channel mean and standard deviation
    (computed over dims 2 and 3) are replaced by a convex combination of its
    own statistics and those of another sample in the batch.

    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix='random'):
        """
        Args:
          p (float): probability of using MixStyle.
          alpha (float): parameter of the Beta distribution.
          eps (float): scaling parameter to avoid numerical issues.
          mix (str): how to mix ('random' or 'crossdomain').
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        self._activated = True

    def __repr__(self):
        return f'MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})'

    def set_activation_status(self, status=True):
        """Enable (True) or disable (False) mixing in ``forward``."""
        self._activated = status

    def update_mix_method(self, mix='random'):
        """Change how the partner sample is chosen."""
        self.mix = mix

    def forward(self, x):
        """Apply MixStyle to ``x``.

        ``x`` is indexed as a 4-D tensor: statistics are reduced over dims
        2 and 3 and the mixing weights have shape ``(B, 1, 1, 1)``.

        Returns ``x`` itself (unchanged) when the module is in eval mode,
        has been deactivated, or the random draw exceeds ``p``.

        Raises:
          NotImplementedError: if ``self.mix`` is not 'random' or
            'crossdomain'.
        """
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

        B = x.size(0)

        mu = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True)
        sig = (var + self.eps).sqrt()
        # Statistics are treated as constants: no gradient flows through them.
        mu, sig = mu.detach(), sig.detach()
        x_normed = (x - mu) / sig

        lmda = self.beta.sample((B, 1, 1, 1))
        lmda = lmda.to(x.device)

        # Index tensors are created on x's device so that indexing mu/sig
        # does not need a CPU index tensor.
        if self.mix == 'random':
            # random shuffle
            perm = torch.randperm(B, device=x.device)

        elif self.mix == 'crossdomain':
            # split into two halves and swap the order
            perm = torch.arange(B - 1, -1, -1, device=x.device)  # inverse index
            # Explicit slicing matches chunk(2) sizes (ceil / floor) but also
            # works for B == 1, where chunk(2) yields a single chunk.
            half = (B + 1) // 2
            perm_b, perm_a = perm[:half], perm[half:]
            # Shuffle each half by its OWN length; using B // 2 for both
            # drops an index when B is odd and breaks broadcasting below.
            perm_b = perm_b[torch.randperm(perm_b.size(0), device=x.device)]
            perm_a = perm_a[torch.randperm(perm_a.size(0), device=x.device)]
            perm = torch.cat([perm_b, perm_a], 0)

        else:
            raise NotImplementedError(f'Unknown mix method: {self.mix!r}')

        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu * lmda + mu2 * (1 - lmda)
        sig_mix = sig * lmda + sig2 * (1 - lmda)

        return x_normed * sig_mix + mu_mix
I tried deleting the intermediate tensors and then calling torch.cuda.empty_cache(), but the VRAM consumption still keeps growing.