网站链接: element-ui dtcms
当前位置: 首页 > 技术博文  > 技术博文

Pytorch 模型转化为torchscript 模型——用于C++部署

2021/6/27 8:12:24 人评论

import torch import torchvision import torch.nn as nn #model torchvision.models.resnet50(pretrainedTrue) modeltorchvision.models.resnet18() # 加载模型 num_ftrs model.fc.in_features model.fc nn.Linear(num_ftrs, 196) # make the change model.load_state_di…

import torch
import torchvision
import torch.nn as nn
#model = torchvision.models.resnet50(pretrained=True)
model=torchvision.models.resnet18() # 加载模型
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 196)  # make the change
model.load_state_dict(torch.load("resnet18.pth"))

model.eval()  # 注意,将模型设置为eval模式,再保存

example = torch.rand(1, 3, 224, 224) # 输入模型的尺寸
traced_script_module = torch.jit.trace(model, example)
# 保存模型
traced_script_module.save("resnet18.pt")

相关资讯

    暂无相关的数据...

共有条评论 网友评论

验证码: 看不清楚?