TensorflowからONNXにエクスポートするには、tf2onnxを使用します。
インストール
まず、tf.graph_util.convert_variables_to_constantsを実行することで、Tensorflowのgraphのvariableをconstantに変換しておきます。これをfrozenと呼びます。
ただし、convert_variables_to_constantsにはBatchNormalizationでmoving_varianceのvariableをfrozenできない問題があるため、事前に問題の起こるノードを修正しておきます。(Unable to import frozen graph with batchnorm)
次に、tf2onnx.tfonnx.process_tf_graphを実行することで、ONNXに変換します。引数には入力と出力に対応するノードを指定します。
出力されたONNXが正しいかどうかを、ONNXRuntimeを使用して検証します。
なお、Placeholderがtf.contrib.layers.batch_normのis_trainingに接続されているなどすると、Ifを含むONNXが出力されるため、必要に応じてFalseなどの直値を設定しておきます。
インストール
pip3 install tf2onnx
まず、tf.graph_util.convert_variables_to_constantsを実行することで、Tensorflowのgraphのvariableをconstantに変換しておきます。これをfrozenと呼びます。
ただし、convert_variables_to_constantsにはBatchNormalizationでmoving_varianceのvariableをfrozenできない問題があるため、事前に問題の起こるノードを修正しておきます。(Unable to import frozen graph with batchnorm)
# fix batch norm nodes gd = sess.graph.as_graph_def() for node in gd.node: if node.op == 'RefSwitch': node.op = 'Switch' for index in range(len(node.input)): if 'moving_' in node.input[index]: node.input[index] = node.input[index] + '/read' elif node.op == 'AssignSub': node.op = 'Sub' if 'use_locking' in node.attr: del node.attr['use_locking'] # Freeze the graph output_node_names=["upscale/mul","hourglass/hg_2/after/hmap/conv/BiasAdd","radius/out/fc/BiasAdd"] frozen_graph_def = tf.graph_util.convert_variables_to_constants( sess, gd, output_node_names )
次に、tf2onnx.tfonnx.process_tf_graphを実行することで、ONNXに変換します。引数には入力と出力に対応するノードを指定します。
# Convert to onnx input_names=["import/eye:0"] output_names=["import/upscale/mul:0","import/hourglass/hg_2/after/hmap/conv/BiasAdd:0","import/radius/out/fc/BiasAdd:0"] graph1 = tf.Graph() with graph1.as_default(): tf.import_graph_def(frozen_graph_def) onnx_graph = tf2onnx.tfonnx.process_tf_graph(graph1, input_names=input_names, output_names=output_names) model_proto = onnx_graph.make_model("sample") with open("sample.onnx", "wb") as f: f.write(model_proto.SerializeToString())
出力されたONNXが正しいかどうかを、ONNXRuntimeを使用して検証します。
# Inference import numpy import onnxruntime as rt onnx_sess = rt.InferenceSession("sample.onnx") for node in onnx_sess.get_inputs(): print(node.name) print(node.shape) print(node.type) X = numpy.random.random((2, 36, 60, 1)).astype(numpy.float32) pred_onnx = onnx_sess.run(None, {"import/eye:0":X}) print(pred_onnx)
なお、Placeholderがtf.contrib.layers.batch_normのis_trainingに接続されているなどすると、Ifを含むONNXが出力されるため、必要に応じてFalseなどの直値を設定しておきます。
コメント
コメント一覧 (2)
Tegs: душевая кабина river nara 90 26 сантехлидер https://santehlider.com/catalog/santekhnika/dushevye_kabiny_i_ograzhdeniya/dushevye_kabiny/48710/
<u>душевая кабина 120х80 santehlider </u>
<i>душевая кабина 120х80 deto santehlider </i>
<b>душевая кабина 120 90 santehlider </b>