在项目中,加载模型时遇到了命名空间相关的问题,如类是在__main__中的,但在导入的时候,类在某个命名空间中。因此在加载模型时,从__main__中找不到相应的类。
File "xxx\python\python39\lib\site-packages\torch\serialization.py", line 595, in load
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File "xxx\python\python39\lib\site-packages\torch\serialization.py", line 774, in _legacy_load
result = unpickler.load()
AttributeError: Can't get attribute 'SomeClass' on <module '__main__' from 'xxx\\python\\Python39\\Scripts\\celery.exe\\__main__.py'>
AttributeError: Can't get attribute 'SomeClass' on <module '__main__' from 'xxx\\django\\proj\\manage.py'>
有两种方法保存 PyTorch 训练后的模型,一种是用 torch.save(the_model, PATH) 保存模型然后使用 the_model = torch.load(PATH) 加载模型。另一种分别使用 torch.save(the_model.state_dict(), PATH) , the_model=SomeClass().load_state_dict(torch.load(PATH)) 和 the_model.eval() 保存和加载模型。
需要注意的是,前者仅保存和加载模型参数,而后者会保存和加载整个模型。而如果采用后者,当在其他工程中使用模型时,可能会破坏原有的结构树,从而产生类似的错误。
最后,使用第一种方法单独加载模型,再用第二种方法保存模型后,再在其他地方使用。