Pytorch两种保存加载模型方法

在项目中,加载模型时遇到了命名空间相关的问题,如类是在__main__中的,但在导入的时候,类在某个命名空间中。因此在加载模型时,从__main__中找不到相应的类。

  File "xxx\python\python39\lib\site-packages\torch\serialization.py", line 595, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "xxx\python\python39\lib\site-packages\torch\serialization.py", line 774, in _legacy_load
    result = unpickler.load()
AttributeError: Can't get attribute 'SomeClass' on <module '__main__' from 'xxx\\python\\Python39\\Scripts\\celery.exe\\__main__.py'>
AttributeError: Can't get attribute 'SomeClass' on <module '__main__' from 'xxx\\django\\proj\\manage.py'>

有两种方法保存 PyTorch 训练后的模型,一种是用 torch.save(the_model, PATH) 保存模型然后使用 the_model = torch.load(PATH) 加载模型。另一种分别使用 torch.save(the_model.state_dict(), PATH) the_model=SomeClass().load_state_dict(torch.load(PATH)) 保存和加载模型。

需要注意的是,前者仅保存和加载模型参数,而后者会保存和加载整个模型。而如果采用后者,当在其他工程中使用模型时,可能会破坏原有的结构树,从而产生类似的错误。

最后,使用第一种方法单独加载模型,再用第二种方法保存模型后,再在其他地方使用。

参考链接

Best way to save a trained model in PyTorch?

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注

Back to Top