
Currently, I load a pretrained torchvision model using the following code:

import torchvision
torchvision.models.resnet101(pretrained=True)

However, I'd like to pass the model name as a string parameter and load the pretrained model from that string. Pseudocode for this would look something like:

model_name = 'resnet101'
torchvision.models.get(model_name)(pretrained=True)

Is there a simple way to accomplish this?


2 Answers


You can use torch.hub:

import torch

model_str = 'resnet50'
model = torch.hub.load('pytorch/vision', model_str, pretrained=True)

The names of all available models can be listed with:

torch.hub.list('pytorch/vision', force_reload=True)

Output (the exact list depends on the revision of pytorch/vision that is fetched):

['alexnet',
 'deeplabv3_mobilenet_v3_large',
 'deeplabv3_resnet101',
 'deeplabv3_resnet50',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'fcn_resnet101',
 'fcn_resnet50',
 'googlenet',
 'inception_v3',
 'lraspp_mobilenet_v3_large',
 'mnasnet0_5',
 'mnasnet0_75',
 'mnasnet1_0',
 'mnasnet1_3',
 'mobilenet_v2',
 'mobilenet_v3_large',
 'mobilenet_v3_small',
 'resnet101',
 'resnet152',
 'resnet18',
 'resnet34',
 'resnet50',
 'resnext101_32x8d',
 'resnext50_32x4d',
 'shufflenet_v2_x0_5',
 'shufflenet_v2_x1_0',
 'squeezenet1_0',
 'squeezenet1_1',
 'vgg11',
 'vgg11_bn',
 'vgg13',
 'vgg13_bn',
 'vgg16',
 'vgg16_bn',
 'vgg19',
 'vgg19_bn',
 'wide_resnet101_2',
 'wide_resnet50_2']
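Because the model is now selected by a string, a typo only surfaces when loading fails. A minimal sketch, assuming the list returned by torch.hub.list('pytorch/vision') is passed in as available; resolve_model_name is a hypothetical helper built on the standard-library difflib module, not part of torch.hub:

```python
import difflib
from typing import Sequence


def resolve_model_name(name: str, available: Sequence[str]) -> str:
    """Return `name` unchanged if it is a known model, else raise with suggestions.

    Hypothetical helper: `available` is assumed to be the list returned by
    torch.hub.list('pytorch/vision').
    """
    if name in available:
        return name
    # difflib ranks candidates by similarity ratio (default cutoff 0.6)
    # and returns at most n of the closest ones.
    suggestions = difflib.get_close_matches(name, available, n=3)
    raise ValueError(f"Unknown model {name!r}; did you mean one of {suggestions}?")
```

With the list shown above, resolve_model_name('resnet51', available) would raise a ValueError whose suggestions include 'resnet50'; a name that passes the check can then be handed to torch.hub.load as before.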

You can use getattr to look up the model-builder function by name:

import torchvision
model = getattr(torchvision.models, 'resnet101')(pretrained=True)
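A slightly more defensive version of the getattr lookup, sketched under the assumption that the builders you want are the public lowercase functions of the module (true for names like resnet101 or vgg16, but only a heuristic). build_model is a hypothetical helper; taking the module as a parameter keeps it testable without torchvision installed:

```python
from typing import Any


def build_model(models_module: Any, name: str, **kwargs: Any) -> Any:
    """Look up a model-builder function by name and call it with kwargs.

    Hypothetical helper: `models_module` would normally be torchvision.models,
    but any object whose attributes are builder functions works.
    """
    builder = getattr(models_module, name, None)
    # Heuristic: in torchvision.models the builders are lowercase functions
    # (resnet101, vgg16, ...), whereas classes such as ResNet are capitalized.
    # Other lowercase callables would also pass this check.
    if name.startswith("_") or not name.islower() or not callable(builder):
        available = sorted(
            attr
            for attr in dir(models_module)
            if not attr.startswith("_")
            and attr.islower()
            and callable(getattr(models_module, attr))
        )
        raise ValueError(f"Unknown model {name!r}; available: {available}")
    return builder(**kwargs)
```

With torchvision installed, build_model(torchvision.models, 'resnet101', pretrained=True) would behave like the one-liner above, except that a misspelled name raises a ValueError listing the candidates instead of an AttributeError. On torchvision 0.14 and later, torchvision.models.get_model(name, weights=...) together with torchvision.models.list_models() provides a built-in lookup by name, and since 0.13 pretrained=True is deprecated in favor of the weights argument.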
