Pytorchで以下のようなクラスを作成しました。

class AModule(nn.Module):
    def __init__(self):
        self.modules = nn.ModuleDict()

これをAModule.save_state_dictで保存すると、self.modulesの中にあるモジュールのパラメータが保存できます。
次にAModule.load_state_dict()でロードすると

Unexpected key(s) in state_dict

というエラーでロードできません。解決方法などありますでしょうか。クラスごと保存するしかないでしょうか。