PytorchでModuleDictをフィールドに持つnn.Moduleをロードできない
Pytorchで以下のようなクラスを作成しました。
class AModule(nn.Module):
def __init__(self):
self.modules = nn.ModuleDict()
これをAModule.save_state_dict
で保存すると、self.modules
の中にあるモジュールのパラメータが保存できます。
次にAModule.load_state_dict()
でロードすると
Unexpected key(s) in state_dict
というエラーでロードできません。解決方法などありますでしょうか。クラスごと保存するしかないでしょうか。