Skip to content

第4章:模型保存和加载

在PyTorch中,保存和加载模型是一个重要的步骤,它允许你保存训练好的模型并在以后重新使用或部署。

4.1 保存PyTorch模型

4.1.1 保存整个模型

保存整个模型,包括模型结构和参数。

python
torch.save(model, 'model.pth')

4.1.2 仅保存模型参数

仅保存模型参数,不包括模型结构。

python
torch.save(model.state_dict(), 'model_parameters.pth')

4.2 加载预训练模型

4.2.1 加载整个模型

加载整个模型,包括模型结构和参数。

python
model = torch.load('model.pth')

4.2.2 加载模型参数

加载模型参数到指定的模型结构中。

python
model = CustomModel()
model.load_state_dict(torch.load('model_parameters.pth'))

4.3 模型转换和部署

4.3.1 模型转换

将模型转换为其他格式,如ONNX,以便在不同的平台上部署。

python
torch.onnx.export(model, dummy_input, 'model.onnx')

4.3.2 模型部署

部署模型到生产环境,如使用TorchScript或ONNX。

python
scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, 'scripted_model.pth')

4.4 本章小结

本章介绍了如何在PyTorch中保存和加载模型,包括保存整个模型和仅保存模型参数,以及模型的转换和部署。这些技能对于模型的持久化和应用至关重要。