PytorchからONNXへの変換時に、Pytorch0.3系で学習したモデルの場合、BatchNorm2dもしくはInstanceNorm2dにおいて
object has no attribute ‘track_running_stats’
というエラーが発生します。
この問題を解決するには、BatchNorm2dもしくはInstanceNorm2dを探索し、track_running_statsに値を設定します。
上記対処をすると、ONNXへの書き出しに成功します。
参考:‘BatchNorm2d’ object has no attribute ‘track_running_stats’
object has no attribute ‘track_running_stats’
というエラーが発生します。
この問題を解決するには、BatchNorm2dもしくはInstanceNorm2dを探索し、track_running_statsに値を設定します。
def recursion_change_bn(module): if isinstance(module, torch.nn.BatchNorm2d) or isinstance(module, torch.nn.InstanceNorm2d): module.track_running_stats = 1 else: for i, (name, module1) in enumerate(module._modules.items()): module1 = recursion_change_bn(module1) return module for i, (name, module) in enumerate(model._modules.items()): module = recursion_change_bn(model) model.eval()
上記対処をすると、ONNXへの書き出しに成功します。
x = Variable(torch.randn(1, 3, 64, 64)) torch.onnx.export(model, x, 'output.onnx', verbose=True)
参考:‘BatchNorm2d’ object has no attribute ‘track_running_stats’
コメント