Skip to content

保存和加载模型

保存并加载模型

  • 保存并加载模型权重
  • 带形状保存和加载模型

保存并加载模型权重

  • 在本节中,我们将了解如何通过保存、加载和运行模型预测来保持模型状态。

保存并加载模型权重——需要掌握3个函数

  • torch.save: 将一个序列化的对象保存到磁盘。这个函数使用 Pythonpickle工具进行序列化。模型 (model)、张量 (tensor) 和各种对象的字典 (dict) 都可以用这个函数保存。
  • torch.load:将 pickled 对象文件反序列化到内存,也便于将数据加载到设备中。
  • torch.nn.Module.load_state_dict():加载模型的参数。

保存并加载模型权重——state_dict

  • Pytorch模型将学习到的参数存储在名为state_dict的内部状态字典中。这些可以通过torch.save方法保存。

保存并加载模型权重——state_dict

  • PyTorch 中,torch.nn.Module里面的可学习的参数 都放在model.parameters()里面
  • state_dict 是一个 Python dictionary object,将每一层映射到它的 parameter tensor 上
  • 只有含有可学习参数的层 ,或者含有registered buffers 的层才有state_dict
  • 优化器的对象 (torch.optim) 也有 state_dict,存储了优化器的状态和它的超参数

保存并加载模型权重

import torch
import torchvision.models as models

model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')

# 打印模型的state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

保存和加载模型权重

  • 要加载模型权重,首先需要创建同一个模型的实例,然后使用load_state_dict()方法加载参数。
  • 在推理之前调用model.eval()方法,以将丢弃和批量归一化图层设置为评估模式。不这样做将产生不一致的推理结果。
model = models.vgg16() #不指定pretrained=True,即不加载默认权重
model.load_state_dict(torch.load('model_weights.pth'))
model.eval() #评估模式
print(model)

带形状保存和加载模型

  • 当加载模型权重时,我们需要首先实例化模型类,因为该类定义了网络的结构。我们可能希望将该类的结构与模型一起保存,在这种情况下,我们可以将model传递给保存函数:
#保存整个模型
torch.save(model, 'model.pth')
#加载模型
model = torch.load('model.pth')
print(model)

带形状保存和加载模型

  • 这种方法在序列化模型时使用Python pickle模块,因此它依赖于加载模型时可用的实际类定义。