
I am using the MixStyle method for domain adaptation, which involves a custom layer inserted after every encoder stage. However, it causes VRAM usage to grow linearly over training until I hit an OOM error. The memory growth disappears when I disable the layer, so I am confident this particular layer is the cause. Any idea why this is happening?

This is the layer, for reference:

import random

import torch
import torch.nn as nn


class MixStyle(nn.Module):
    """MixStyle.
    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix='random'):
        """
        Args:
          p (float): probability of using MixStyle.
          alpha (float): parameter of the Beta distribution.
          eps (float): scaling parameter to avoid numerical issues.
          mix (str): how to mix.
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        self._activated = True

    def __repr__(self):
        return f'MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})'

    def set_activation_status(self, status=True):
        self._activated = status

    def update_mix_method(self, mix='random'):
        self.mix = mix

    def forward(self, x):
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

        B = x.size(0)

        mu = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True)
        sig = (var + self.eps).sqrt()
        mu, sig = mu.detach(), sig.detach()
        x_normed = (x-mu) / sig

        lmda = self.beta.sample((B, 1, 1, 1))
        lmda = lmda.to(x.device)

        if self.mix == 'random':
            # random shuffle
            perm = torch.randperm(B)

        elif self.mix == 'crossdomain':
            # split into two halves and swap the order
            perm = torch.arange(B - 1, -1, -1) # inverse index
            perm_b, perm_a = perm.chunk(2)
            perm_b = perm_b[torch.randperm(B // 2)]
            perm_a = perm_a[torch.randperm(B // 2)]
            perm = torch.cat([perm_b, perm_a], 0)

        else:
            raise NotImplementedError

        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu*lmda + mu2 * (1-lmda)
        sig_mix = sig*lmda + sig2 * (1-lmda)

        return x_normed*sig_mix + mu_mix
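
For reference, this is roughly how the layer is wired in after each encoder stage (minimal toy encoder; the stage definitions, channel sizes, and names here are illustrative, not my actual backbone):

class ToyEncoder(nn.Module):
    """Toy 3-stage encoder with MixStyle applied after every stage (illustrative only)."""

    def __init__(self):
        super().__init__()
        self.stage1 = nn.Sequential(nn.Conv2d(3, 32, 3, stride=2, padding=1), nn.ReLU())
        self.stage2 = nn.Sequential(nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU())
        self.stage3 = nn.Sequential(nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU())
        self.mixstyle = MixStyle(p=0.5, alpha=0.1)

    def forward(self, x):
        # apply MixStyle after each encoder stage
        x = self.mixstyle(self.stage1(x))
        x = self.mixstyle(self.stage2(x))
        x = self.mixstyle(self.stage3(x))
        return x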

I tried deleting the intermediate tensors and then calling torch.cuda.empty_cache(), but the VRAM consumption keeps growing.
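
For completeness, this is the kind of training loop in which I observe the growth and where I tried the manual cleanup (minimal sketch with dummy data and a dummy loss, using the toy encoder above; my real pipeline is more involved):

model = ToyEncoder().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for step in range(1000):
    x = torch.randn(16, 3, 64, 64, device='cuda')  # dummy batch
    out = model(x)
    loss = out.mean()                               # dummy loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # attempted cleanup -- does not stop the growth
    del out, loss
    torch.cuda.empty_cache()

    if step % 100 == 0:
        print(step, torch.cuda.memory_allocated() // 2**20, 'MiB allocated')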

  • Can you turn that into a minimal reproducible example? Then it would be suitable as evidence for a bug report. Commented Sep 28 at 21:29
  • I was having the same problem when working on a custom NN layer of my own. Based on your code, my guess is that the computational graph from previous forward passes is not being freed: torch.distributions.Beta.sample() returns a tensor that tracks gradients if it is used in the middle of a differentiable graph, so even though you only use it for mixing, it is still part of the autograd graph. Just wrap the sampling and mixing operations inside a torch.no_grad() block. I can give you sample code for your question. Commented Oct 15 at 12:19

1 Answer

  1. torch.distributions.Beta.sample() returns a tensor that tracks gradients if it is used in the middle of a differentiable graph.

  2. Even though you only use it for mixing, it is still part of the autograd graph.

  3. Every iteration creates new lmda tensors whose computation history references the distribution parameters (which are registered as buffers inside the module). PyTorch accumulates these in the autograd graph, causing VRAM to grow linearly.

Wrap the sampling and permutation logic in forward inside a torch.no_grad() block. Hope this solution helps you.

with torch.no_grad():
    lmda = self.beta.sample((B, 1, 1, 1)).to(x.device)
    if self.mix == 'random':
        perm = torch.randperm(B)
    elif self.mix == 'crossdomain':
        perm = torch.arange(B - 1, -1, -1)
        perm_b, perm_a = perm.chunk(2)
        perm_b = perm_b[torch.randperm(B // 2)]
        perm_a = perm_a[torch.randperm(B // 2)]
        perm = torch.cat([perm_b, perm_a], 0)
    else:
        raise NotImplementedError
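
If it helps, here is a quick sanity check you can run after applying the change (illustrative snippet; it assumes the MixStyle class from the question is defined in the current scope):

import torch

mix = MixStyle(p=1.0).cuda().train()   # p=1.0 so the mixing branch always runs
x = torch.randn(8, 32, 16, 16, device='cuda', requires_grad=True)

out = mix(x)
print(out.requires_grad)               # True: gradients still flow back to x

for step in range(5):
    out = mix(x)
    out.mean().backward()
    # with the sampling wrapped in no_grad, this should stay roughly flat
    print(step, torch.cuda.memory_allocated())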