
I made an example diagram of a scaled-down version of what I'm trying to implement:

[Figure: network diagram]

The top two input nodes are fully connected only to the top three output nodes, and likewise the bottom two input nodes are fully connected only to the bottom three output nodes. So far I've come up with two ways of implementing this in PyTorch, neither of which is optimal.
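Put differently, the layer in the diagram is a single linear map whose weight matrix is block diagonal, where each block is an ordinary dense layer for one group of inputs. A minimal sketch, assuming the diagram's sizes (two groups, each 2 → 3; the variable names are illustrative), checking that two small `nn.Linear` layers give the same result as one matmul with the `torch.block_diag` of their weights:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# two independent 2 -> 3 layers, one per input group (as in the diagram)
top = nn.Linear(2, 3)
bottom = nn.Linear(2, 3)

x = torch.randn(5, 4)  # batch of 5 samples, 4 input features

# per-group computation: split the features, apply each layer, concatenate
y_groups = torch.cat([top(x[:, :2]), bottom(x[:, 2:])], dim=1)

# the same map as one (6, 4) block-diagonal weight matrix
W = torch.block_diag(top.weight, bottom.weight)  # zeros outside the two blocks
b = torch.cat([top.bias, bottom.bias])           # shape (6,)
y_block = x @ W.T + b

print(W.shape)                                       # torch.Size([6, 4])
print(torch.allclose(y_groups, y_block, atol=1e-6))  # True
```

Half of the 24 entries of `W` are structural zeros; those are exactly the entries a masked `nn.Linear(4, 6)` still stores.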

The first is to create an nn.ModuleList of several smaller Linear layers and, during the forward pass, pass each chunk of the input through its own layer. For the diagram's example, that would look something like this:

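import torch
import torch.nn as nn

class Module(nn.Module):
  def __init__(self):
    super().__init__()
    # one small Linear(2, 3) per group of two inputs
    self.layers = nn.ModuleList([nn.Linear(2, 3) for i in range(2)])

  def forward(self, input):
    # (4,) -> (2, 2): row i holds the inputs of group i
    chunks = input.view(2, 2)
    # apply each group's layer to its own chunk, one after the other
    output = torch.stack([self.layers[i](chunks[i]) for i in range(2)])
    return output.flatten()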

This reproduces the network in the diagram, but the main issue is that it's very slow. I assume this is because the Python for loop runs the small layers one after another, so PyTorch can't process the input tensor in parallel.

To "vectorize" the module such that PyTorch can run it quicker, I have this implementation:

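import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

class Module(nn.Module):
  def __init__(self):
    super().__init__()
    self.layer = nn.Linear(4, 6)
    # mask of ones and zeros that "blocks" certain connections:
    # outputs 0-2 see inputs 0-1, outputs 3-5 see inputs 2-3
    mask = torch.block_diag(torch.ones(3, 2), torch.ones(3, 2))
    # apply the mask once; prune re-applies it automatically before each forward
    prune.custom_from_mask(self.layer, name='weight', mask=mask)

  def forward(self, input):
    return self.layer(input)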

This also reproduces the diagram's network, using weight pruning to ensure that certain weights in the fully connected layer are always zero (e.g. the weight connecting the top input node to the bottom output node is always zero, so it's effectively "disconnected"). This module is much faster than the previous one, as there is no for loop. The problem now is that it takes up significantly more memory. This is likely because, even though most of the layer's weights are zero, PyTorch still stores and computes with all of them. This implementation keeps many more weights around than it needs to.
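Part of the extra memory comes from how `torch.nn.utils.prune` works: it does not shrink the layer, it reparametrizes it. After `custom_from_mask`, the module keeps the original dense tensor as a parameter `weight_orig`, the mask as a buffer `weight_mask`, and a recomputed `weight = weight_orig * weight_mask`. A small sketch to inspect this, assuming the diagram's 4 → 6 layout with the mask built by `torch.block_diag`:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 6)
# 1 where a connection exists, 0 where it is "cut" (block-diagonal, shape (6, 4))
mask = torch.block_diag(torch.ones(3, 2), torch.ones(3, 2))
prune.custom_from_mask(layer, name="weight", mask=mask)

# the dense original weight is still a trainable parameter, now named weight_orig
print([name for name, _ in layer.named_parameters()])
# the mask is stored as a buffer with the same dense shape
print([name for name, _ in layer.named_buffers()])
# and `weight` is a third dense tensor, recomputed from the two above
print(layer.weight.shape)  # torch.Size([6, 4])
```

So every masked-out connection costs memory three times (original weight, mask entry, pruned weight), on top of its compute in the dense matmul.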

Has anyone encountered this issue before and come up with an efficient solution?

4 Answers


If weight sharing is OK (the same weights are used for every group of inputs), then a 1D convolution with kernel_size=1 should solve the problem:

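import torch
import torch.nn as nn

class Module(nn.Module):
  def __init__(self):
    super().__init__()
    # one 2 -> 3 map, shared by every chunk (kernel_size=1 acts per position)
    self.layers = nn.Conv1d(in_channels=2, out_channels=3, kernel_size=1)
    self._n_splits = 2

  def forward(self, input):
    B, C = input.shape
    # (B, C) -> (B, n_splits, C // n_splits) -> (B, C // n_splits, n_splits):
    # each chunk becomes one position, its features become the channels
    x = input.view(B, self._n_splits, -1).transpose(1, 2)
    output = self.layers(x)  # (B, 3, n_splits)
    # put each chunk's 3 outputs next to each other: (B, n_splits * 3)
    return output.transpose(1, 2).reshape(B, -1)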

If weight sharing is NOT OK, then you can use a group convolution: self.layers = nn.Conv1d(in_channels=4, out_channels=6, kernel_size=1, groups=2). The number of groups must divide both in_channels and out_channels, so this only works for equal-sized splits; see the documentation: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html

A 1D convolution with kernel_size=1 acts as a fully connected layer across all the input channels at each position. A group convolution divides the channels into groups and performs a separate convolution on each group (which is what you want).
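The memory saving of a group convolution shows up in its weight shape: `nn.Conv1d` stores weights of shape `(out_channels, in_channels // groups, kernel_size)`. For the diagram's 4 inputs, 6 outputs and 2 groups, that means only the 12 real connections are stored, against 24 for a dense (masked) `nn.Linear(4, 6)`. A quick check, assuming those sizes:

```python
import torch.nn as nn

grouped = nn.Conv1d(in_channels=4, out_channels=6, kernel_size=1, groups=2)
dense = nn.Linear(4, 6)

print(grouped.weight.shape)    # torch.Size([6, 2, 1])
print(grouped.weight.numel())  # 12
print(dense.weight.numel())    # 24
```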

The implementation will look something like:

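import torch
import torch.nn as nn

class Module(nn.Module):
  def __init__(self):
    super().__init__()
    # 4 inputs in 2 groups of 2; each group has its own 2 -> 3 weights (6 outputs total)
    self.layers = nn.Conv1d(in_channels=4, out_channels=6, kernel_size=1, groups=2)

  def forward(self, input):
    # (B, 4) -> (B, 4, 1): the features become channels, sequence length 1
    output = self.layers(input.unsqueeze(-1))
    # (B, 6, 1) -> (B, 6); squeeze only the last dim so that B == 1 still works
    return output.squeeze(-1)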
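To check that a grouped `Conv1d` with `kernel_size=1` behaves like one independent linear layer per group, here is a sketch that copies the weights of two illustrative `nn.Linear(2, 3)` layers into the grouped conv and compares outputs. It relies on the weight layout `(out_channels, in_channels // groups, kernel_size)`, in which group 0's output channels come first:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
linears = [nn.Linear(2, 3) for _ in range(2)]  # one layer per input group
conv = nn.Conv1d(in_channels=4, out_channels=6, kernel_size=1, groups=2)

with torch.no_grad():
    # stack the two (3, 2) weights into (6, 2), then add the kernel dim -> (6, 2, 1)
    conv.weight.copy_(torch.cat([l.weight for l in linears]).unsqueeze(-1))
    conv.bias.copy_(torch.cat([l.bias for l in linears]))

x = torch.randn(8, 4)
y_conv = conv(x.unsqueeze(-1)).squeeze(-1)  # (8, 6)
y_linear = torch.cat([linears[0](x[:, :2]), linears[1](x[:, 2:])], dim=1)
print(torch.allclose(y_conv, y_linear, atol=1e-6))  # True
```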

EDIT:

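If you need an odd number of output channels, note that groups must divide both in_channels and out_channels, so with 2 groups a single group conv always produces an even number of outputs. Stacking a second group conv does not get around this: to map 4 channels to 3, its groups must divide both 4 and 3, which leaves only groups=1, and with groups=1 the channels are mixed, which effectively renders the first group conv layer useless. One workaround that keeps the groups separate is to round out_channels up to a multiple of groups and slice off the surplus outputs:

import torch
import torch.nn as nn

class Module(nn.Module):
  def __init__(self, n_out=3):
    super().__init__()
    # out_channels must be a multiple of groups, so allocate 4 outputs and keep n_out of them
    self.layers = nn.Conv1d(in_channels=4, out_channels=4, kernel_size=1, groups=2)
    self.n_out = n_out

  def forward(self, input):
    output = self.layers(input.unsqueeze(-1)).squeeze(-1)  # (B, 4)
    # outputs 0-1 depend only on inputs 0-1, outputs 2-3 only on inputs 2-3;
    # keeping the first 3 gives 2 outputs for the top group and 1 for the bottom group
    return output[:, :self.n_out]

From a theoretical perspective, if you stack several group convolutions there is no need for activation functions in between them, since composing linear maps is still linear. However, it is possible that adding an activation function will improve performance.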


2 Comments

Yes, weights should NOT be shared. I think your second solution is close to something I could use, but it may need some tweaking. For example, if I try to create your Conv1d layer, I get an "out_channels must be divisible by groups" error.
That means that if you have 2 groups you need an even number of out channels, because groups must divide both in_channels and out_channels. For the diagram, nn.Conv1d(in_channels=4, out_channels=6, kernel_size=1, groups=2) satisfies this. Stacking a second group convolution to get an odd number of out channels does not help, though: nn.Conv1d(in_channels=4, out_channels=3, kernel_size=1, groups=3) fails the same check, since 4 is not divisible by 3.

If your input size is B (an even number, B = 2b), each half should map to C outputs, and weight sharing is acceptable, you can simply reshape the input to (2, b) and apply a single linear layer nn.Linear(b, C) to the last dimension, giving 2C outputs in total.
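A minimal sketch of that reshape, assuming each half has `b` features and is mapped to `C` outputs (the sizes below are illustrative):

```python
import torch
import torch.nn as nn

b, C = 2, 3               # each half has b features and is mapped to C outputs
shared = nn.Linear(b, C)  # one set of weights, reused for both halves

x = torch.randn(5, 2 * b)             # (batch, B) with B = 2 * b
halves = x.reshape(5, 2, b)           # (batch, 2, b): one row per half
y = shared(halves).reshape(5, 2 * C)  # nn.Linear acts on the last dimension

# same as applying the layer to each half separately and concatenating
y_ref = torch.cat([shared(x[:, :b]), shared(x[:, b:])], dim=1)
print(torch.allclose(y, y_ref, atol=1e-6))  # True
```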


Assuming the following:

  1. The input of shape (batch_size, d_in) is broken into n_chunks equal-sized chunks of shape (batch_size, d_in//n_chunks)
  2. Each chunk is mapped to an output of shape (batch_size, d_out//n_chunks)
  3. Each chunk is mapped with a different set of weights
  4. The output chunks are concatenated to a final shape of (batch_size, d_out)

You can do the following:

import torch
import torch.nn as nn

class ChunkedLinear(nn.Module):
    def __init__(self, d_in, d_out, n_chunks):
        super().__init__()
        assert d_in % n_chunks == 0, "d_in should be divisible by n_chunks"
        assert d_out % n_chunks == 0, "d_out should be divisible by n_chunks"

        # one (d_out//n_chunks, d_in//n_chunks) weight matrix and bias vector per chunk
        self.weight = nn.Parameter(torch.randn(n_chunks, d_out//n_chunks, d_in//n_chunks))
        self.bias = nn.Parameter(torch.randn(n_chunks, d_out//n_chunks))
        self.n_chunks = n_chunks

    def forward(self, x):
        batch_size, d_in = x.shape
        assert d_in % self.n_chunks == 0, "x.shape[1] should be divisible by n_chunks"

        # (batch, d_in) -> (batch, n_chunks, d_in//n_chunks)
        x = x.reshape(batch_size, self.n_chunks, -1)
        # batched matmul: chunk n of every sample is multiplied by weight[n]
        x = torch.einsum('bni,noi->bno', x, self.weight) + self.bias
        # concatenate the chunk outputs: (batch, d_out)
        x = x.reshape(batch_size, -1)
        return x

For comparison with using multiple linear layers:

import torch
import torch.nn as nn

# set up sizes
batch_size = 12
d_in = 6
d_out = 8
n_chunks = 2
c_in = d_in//n_chunks
c_out = d_out//n_chunks

# create separate linear layers
layers = [nn.Linear(c_in, c_out, bias=True) for i in range(n_chunks)]

# create chunked layer
chunked_layer = ChunkedLinear(d_in, d_out, n_chunks)

# copy linear layer weights and biases to the chunked layer for comparison
with torch.no_grad():
    chunked_layer.weight.copy_(torch.stack([l.weight for l in layers]))
    chunked_layer.bias.copy_(torch.stack([l.bias for l in layers]))

# create input
x = torch.randn(batch_size, d_in)

# compute `y1` by chunking and applying separate linear layers
x_chunked = x.chunk(n_chunks, dim=-1)
y_chunked = [layers[i](x_chunked[i]) for i in range(n_chunks)]
y1 = torch.cat(y_chunked, -1)

# compute `y2` with `ChunkedLinear`
y2 = chunked_layer(x)

# compare outputs (allclose rather than ==, since the two paths may sum in a different order)
torch.allclose(y1, y2, atol=1e-6)
> True
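One thing to watch: `ChunkedLinear` initializes with `torch.randn`, which gives much larger weights than `nn.Linear`'s default. `nn.Linear` draws both its weight and bias from U(-1/sqrt(fan_in), 1/sqrt(fan_in)) (that is what its `kaiming_uniform_(a=sqrt(5))` default works out to), where fan_in is the number of inputs per output, here `d_in // n_chunks`. A sketch of a hypothetical helper, `reset_chunked_parameters`, that mimics this per chunk:

```python
import math
import torch
import torch.nn as nn

def reset_chunked_parameters(weight: torch.Tensor, bias: torch.Tensor) -> None:
    """Hypothetical helper: nn.Linear-style init for weight (n_chunks, c_out, c_in)."""
    c_in = weight.shape[-1]  # inputs seen by each output within one chunk
    bound = 1.0 / math.sqrt(c_in)
    with torch.no_grad():
        # same bound nn.Linear uses for both weight and bias
        weight.uniform_(-bound, bound)
        bias.uniform_(-bound, bound)

n_chunks, c_out, c_in = 2, 4, 3
weight = nn.Parameter(torch.empty(n_chunks, c_out, c_in))
bias = nn.Parameter(torch.empty(n_chunks, c_out))
reset_chunked_parameters(weight, bias)
print(weight.abs().max().item() <= 1 / math.sqrt(c_in))  # True
```

Inside `ChunkedLinear.__init__`, the parameters could be created with `torch.empty` and passed to such a helper instead of using `torch.randn`.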


I think you could use a buffer in this case. The documentation describes register_buffer as the way to add a tensor that is not a parameter (and so is not trained) to a module:

self.register_buffer(name, tensor)
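A minimal sketch of how a buffer could replace `torch.nn.utils.prune` for the question's masked layer, assuming the block-diagonal mask from the diagram. The mask is registered with `register_buffer` (so it is saved in the `state_dict` and moved by `.to(device)`, but not trained), and the weight is multiplied by it in `forward`. The class name `MaskedLinear` is hypothetical. This avoids keeping a separate pruned copy of the weight on the module, but it still stores the full dense weight and mask:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedLinear(nn.Linear):
    """Hypothetical nn.Linear whose weight is multiplied by a fixed 0/1 mask."""

    def __init__(self, in_features, out_features, mask):
        super().__init__(in_features, out_features)
        # buffer: part of state_dict and moved by .to(), but not a trainable parameter
        self.register_buffer("mask", mask.to(self.weight.dtype))

    def forward(self, input):
        # masked entries contribute nothing and receive zero gradient
        return F.linear(input, self.weight * self.mask, self.bias)

# block-diagonal mask: outputs 0-2 <- inputs 0-1, outputs 3-5 <- inputs 2-3
mask = torch.block_diag(torch.ones(3, 2), torch.ones(3, 2))
layer = MaskedLinear(4, 6, mask)

x = torch.randn(5, 4)
x_changed = x.clone()
x_changed[:, 2:] += 1.0  # change only the bottom two inputs
# the top three outputs do not depend on the bottom two inputs
print(torch.allclose(layer(x)[:, :3], layer(x_changed)[:, :3]))  # True
print("mask" in layer.state_dict())                               # True
```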

1 Comment

Welcome to Stack Overflow, and thanks for the answer. If you have some time, you could improve this answer by adding at least a minimal example of how to use the method register_buffer() to solve the original question (it's not really clear to me, and maybe not to the original poster either). It would also be nice to add a link to the official documentation; I think it's this: pytorch.org/docs/stable/generated/torch.nn.Module.html. Thanks for your help.
