torchvision加载ResNet除全连接层的权重
简单贴下如何加载torchvision中预训练权重,要不总是忘。
import torch
import torch.nn as nn
import torchvision
class ResNet(nn.Module):
def __init__(self):
super(ResNet, self).__init__()
pass
# 往ResNet里面添加权重
def init_weights(self, pretrained = True):
"""
Args:
self: 模型本身
pretrained (bool)
"""
if pretrained == True:
# 获取ResNet34的预训练权重
共有 0 条评论