I have a complicated model in PyTorch. How can I print the names (or IDs) of the layers connected to a given layer's input? To start, I want to find them for a Concat layer. See the example code below:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Concat(nn.Module):
    def __init__(self, dimension=1):
        super().__init__()
        self.d = dimension

    def forward(self, x):
        return torch.cat(x, self.d)


class SomeModel(nn.Module):
    def __init__(self):
        super(SomeModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.conc = Concat(1)
        self.linear = nn.Linear(8192, 1)

    def forward(self, x):
        out1 = F.relu(self.bn1(self.conv1(x)))
        out2 = F.relu(self.conv2(x))
        out = self.conc([out1, out2])
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


if __name__ == '__main__':
    model = SomeModel()
    print(model)
    y = model(torch.randn(1, 3, 32, 32))
    print(y.size())
    for name, m in model.named_modules():
        if 'Concat' in m.__class__.__name__:
            print(name, m, m.__class__.__name__)
            # Here print names of all input layers for Concat

2 Answers


You can use type(module).__name__ to get the nn.Module class name:

>>> model = SomeModel()
>>> y = model(torch.randn(1, 3, 32, 32))
>>> for name, m in model.named_modules():
...     if 'Concat' == type(m).__name__:
...         print(name, m)
conc Concat()

Edit: You can actually get the list of operators used to compute the inputs of Concat. However, I don't think you can get the attribute names of the nn.Modules associated with these operators; this kind of information is neither available nor needed at inference time.

This solution requires registering a hook on the layer. Using nn.Module.register_forward_pre_hook is more appropriate here than nn.Module.register_forward_hook, since we only look at the inputs and do not need the output. Perform one inference to trigger the hook, then remove it. Inside the hook you have access to the tuple of inputs, and you can extract the name of the operator that produced each input from its grad_fn attribute.

>>> def op_name(x):
...     return type(x.grad_fn).__name__.replace('Backward0', '')

>>> def forward_hook(module, ins):
...     print([op_name(x) for x in ins[0]])

Attach the hook on model.conc, trigger it and then clean up:

>>> handle = model.conc.register_forward_pre_hook(forward_hook)
>>> model(torch.empty(2, 3, 10, 10, requires_grad=True))
['Relu', 'Relu']

>>> handle.remove()
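Putting the pieces together, here is a self-contained sketch of the hook approach (the model is copied from the question; the `collected` list and hook name are my own additions, and the exact `Backward0` suffix can vary between PyTorch versions):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Concat(nn.Module):
    def __init__(self, dimension=1):
        super().__init__()
        self.d = dimension

    def forward(self, x):
        return torch.cat(x, self.d)


class SomeModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(3, 64, 3, padding=1, bias=False)
        self.conc = Concat(1)
        self.linear = nn.Linear(8192, 1)

    def forward(self, x):
        out1 = F.relu(self.bn1(self.conv1(x)))
        out2 = F.relu(self.conv2(x))
        out = self.conc([out1, out2])
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        return self.linear(out)


def op_name(x):
    # grad_fn is e.g. ReluBackward0 -> strip the suffix to get 'Relu'
    return type(x.grad_fn).__name__.replace('Backward0', '')


collected = []


def forward_pre_hook(module, ins):
    # ins is the tuple of positional arguments passed to forward;
    # ins[0] is the list of tensors given to Concat
    collected.extend(op_name(x) for x in ins[0])


model = SomeModel()
handle = model.conc.register_forward_pre_hook(forward_pre_hook)
model(torch.randn(1, 3, 32, 32))  # one inference to trigger the hook
handle.remove()
print(collected)
```

Note that grad_fn is populated here because the model's parameters require gradients; running under torch.no_grad() would leave it as None.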

7 Comments

I need the names of the inputs for this layer. I know how to print the name of the current layer.
@ZFTurbo What do you mean by "name"? What would be the expected output in your example?
Example: ['module.conv1', 'module.conv2']
@ZFTurbo why is the expected result not ['module.bn1', 'module.conv2'], since bn1 is applied after conv1 before being passed to Concat? Do you want to return the closest previous Conv2d layer used on each input of Concat?
@ZFTurbo, I have edited my answer.

How can I print names of layers (or IDs) which connected to layer's input.

You can't do it based on the module itself, because "connected to" doesn't exist as a static property of an nn.Module:

Instead of

out = self.conc([out1, out2])

you could just as easily have written

out = self.conc([out1, random.choice([out1, out2])])

What's connected to self.conc now?

However, you can inspect the autograd graph of a value through its grad_fn attribute (if torchviz.make_dot doesn't directly do what you want, you may want to check its implementation).
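A minimal sketch of such a graph walk, using a toy nn.Sequential model (the `graph_ops` helper name is illustrative, and the exact `Backward` node names vary between PyTorch versions):

```python
import torch
import torch.nn as nn

# Tiny stand-in model; the traversal works for any differentiable output.
model = nn.Sequential(
    nn.Conv2d(3, 4, 3),      # 10x10 -> 8x8
    nn.ReLU(),
    nn.Flatten(),            # 4 * 8 * 8 = 256 features
    nn.Linear(256, 1),
)
y = model(torch.randn(1, 3, 10, 10))


def graph_ops(fn, seen=None):
    """Collect the operator node names in a value's autograd graph."""
    seen = seen if seen is not None else set()
    if fn is None or fn in seen:
        return []
    seen.add(fn)
    names = [type(fn).__name__]
    for next_fn, _ in fn.next_functions:
        names += graph_ops(next_fn, seen)
    return names


print(graph_ops(y.grad_fn))
```

The result includes one Backward node per recorded operation plus AccumulateGrad leaves for the parameters; mapping those nodes back to module attribute names is exactly what the graph does not store, as the first answer explains.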

