How can I extract the features from a specific layer of a pre-trained PyTorch model (such as ResNet or VGG) without running another forward pass?

3 Answers

New answer

Edit: torchvision v0.11.0 added the torchvision.models.feature_extraction module, which returns intermediate outputs directly.

For example, to extract features from the node layer4.2.relu_2, you can do:

import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import create_feature_extractor

x = torch.rand(1, 3, 224, 224)

model = resnet50()

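# Maps a graph node name to the key used in the returned dict.
return_nodes = {
    "layer4.2.relu_2": "layer4"
}
model2 = create_feature_extractor(model, return_nodes=return_nodes)
intermediate_outputs = model2(x)  # a dict; intermediate_outputs["layer4"] is the tensor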

Old answer

You can register a forward hook on the specific layer you want. Something like:

def some_specific_layer_hook(module, input_, output):
    pass  # the value is in 'output'

model.some_specific_layer.register_forward_hook(some_specific_layer_hook)
    
model(some_input)
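
To make the hook arguments concrete, here is a small runnable sketch. The three-layer nn.Sequential model and the `captured` dict are stand-ins I made up for illustration, not part of any real network.

```python
import torch
from torch import nn

# Toy stand-in for a real network (hypothetical, for illustration only).
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

captured = {}

def hook(module, input_, output):
    # input_ is a tuple of the layer's positional inputs;
    # output is whatever the layer's forward() returned.
    captured["in_shape"] = tuple(input_[0].shape)
    captured["out"] = output.detach()

handle = model[1].register_forward_hook(hook)  # hook the ReLU
with torch.no_grad():
    final = model(torch.rand(3, 8))
handle.remove()  # otherwise the hook fires on every later forward pass

print(captured["in_shape"], tuple(captured["out"].shape), tuple(final.shape))
```

The hook runs during the normal forward pass, so the intermediate value and the final output both come from the same call.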

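For example, to obtain the res5c output in ResNet, store it in a variable outside the hook. At module level that needs a global declaration; nonlocal only works when the hook is defined inside another function:

res5c_output = None

def res5c_hook(module, input_, output):
    global res5c_output  # use `nonlocal` instead if this is nested in a function
    res5c_output = output

resnet.layer4.register_forward_hook(res5c_hook)

resnet(some_input)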
    
# Then, use `res5c_output`.
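
Note that `nonlocal` is a syntax error at module level, so that form only works when the hook is nested in a function. A minimal sketch of that variant follows; `run_and_capture` and the toy model are hypothetical names of mine, standing in for a real ResNet.

```python
import torch
from torch import nn

def run_and_capture(model, layer, x):
    """Run one forward pass and return (final output, output of `layer`)."""
    layer_output = None

    def hook(module, input_, output):
        nonlocal layer_output  # legal here: the variable lives in run_and_capture
        layer_output = output

    handle = layer.register_forward_hook(hook)
    try:
        with torch.no_grad():
            final = model(x)
    finally:
        handle.remove()  # always detach the hook, even if forward() raises
    return final, layer_output

# Toy model standing in for a ResNet (hypothetical).
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
)
final, conv_out = run_and_capture(toy, toy[0], torch.rand(2, 3, 16, 16))
print(tuple(conv_out.shape), tuple(final.shape))
```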

2 Comments

How does the value of output get returned here?? hook functions are not allowed to have a return value, so I don't see how fc1000_output in your code will get the value of output assigned to it. How does the value res5c_output get passed to fc1000_output?
I added a missing nonlocal declaration. It's not that the value of res5c_output gets passed to fc1000_output, it's that the former variable is bound to the outer context.

The accepted answer is very helpful! I'm posting a complete example here (using a registered hook as described by @bryant1410) for the lazy ones looking for a working solution:

import torch 
import torchvision.models as models
from torchvision import transforms
from PIL import Image

def get_feat_vector(path_img, model):
    '''
    Input: 
        path_img: string, /path/to/image
        model: a pretrained torch model
    Output:
        my_output: torch.tensor, output of avgpool layer
    '''
    input_image = Image.open(path_img).convert("RGB")  # force 3 channels so Normalize works on grayscale/RGBA files
    
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0)

    with torch.no_grad():
        my_output = None
        
        def my_hook(module_, input_, output_):
            nonlocal my_output
            my_output = output_

        a_hook = model.avgpool.register_forward_hook(my_hook)        
        model(input_batch)
        a_hook.remove()
        return my_output

That is the feature extraction function. Call it with the snippet below to obtain features from the avgpool layer of resnet18:

model = models.resnet18(pretrained=True)  # torchvision >= 0.13: use weights=models.ResNet18_Weights.DEFAULT instead
model.eval()
path_ = '/path/to/image'
my_feature = get_feat_vector(path_, model)


An alternative way to use register_forward_hook, with a class rather than a global variable.

Simple example:

class FeatureExtractor:
    def __init__(self):
        self.extracted_features = None

    def __call__(self, module, input_, output):
        self.extracted_features = output

extractor = FeatureExtractor()
model.some_specific_layer.register_forward_hook(extractor)
model(some_input)
extractor.extracted_features

Extracting from multiple layers (storing in a dictionary):

import functools

class FeatureExtractor:
    def __init__(self):
        self.extracted_features = dict()

    def extract_features(self, module, input_, output, name):
        self.extracted_features[name] = output

    def get_forward_hook(self, name):
        # Bind `name` so the result has the (module, input, output) hook signature.
        return functools.partial(self.extract_features, name=name)

extractor = FeatureExtractor()
layer_name = "some_specific_layer"  # any key you like for the dictionary
model.some_specific_layer.register_forward_hook(extractor.get_forward_hook(layer_name))
model(some_input)
extractor.extracted_features[layer_name]

functools.partial creates a callable that forwards to FeatureExtractor.extract_features with the name argument already bound.
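
Here is a runnable sketch of the multi-layer version, registering hooks by name via named_modules(). The toy nn.Sequential model and the layer names "0" and "2" are assumptions for illustration; with a real model you would use names such as "layer3" or "avgpool".

```python
import functools

import torch
from torch import nn

class FeatureExtractor:
    def __init__(self):
        self.extracted_features = {}

    def extract_features(self, module, input_, output, name):
        self.extracted_features[name] = output.detach()

    def get_forward_hook(self, name):
        # partial() fixes `name`, leaving the (module, input, output) signature
        return functools.partial(self.extract_features, name=name)

# Toy model (hypothetical); nn.Sequential names its children "0", "1", "2".
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
extractor = FeatureExtractor()

modules = dict(model.named_modules())
handles = [
    modules[name].register_forward_hook(extractor.get_forward_hook(name))
    for name in ("0", "2")
]

with torch.no_grad():
    model(torch.rand(5, 8))  # one forward pass fills every entry

for handle in handles:
    handle.remove()

for name, feat in extractor.extracted_features.items():
    print(name, tuple(feat.shape))
```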
