ONNX是开放式神经网络(Open Neural Network Exchange)的简称,主要由微软和合作伙伴社区创建和维护。很多深度学习训练框架(如Tensorflow, PyTorch, Scikit-learn, MXNet等)的模型都可以导出或转换为标准的ONNX格式,采用ONNX格式作为统一的界面,各种嵌入式平台就可以只需要解析ONNX格式的模型而不用支持多种多样的训练框架,本文主要介绍如何通过代码或JSON文件的形式来构造一个ONNX单算子模型或者整个graph,以及使用ONNX Runtime进行推理得到算子或模型的计算结果。

一. ONNX文件格式

        ONNX文件是基于Protobuf进行序列化。了解Protobuf协议的同学应该知道,Protobuf都会有一个*.proto的文件定义协议,ONNX的该协议定义在onnx/onnx.proto3 at main · onnx/onnx · GitHub 文件中。

        从onnx.proto3协议中我们需要重点知道的数据结构如下:

  • ModelProto:模型的定义,包含版本信息,生产者和GraphProto。
  • GraphProto: 包含很多重复的NodeProto, initializer, ValueInfoProto等,这些元素共同构成一个计算图,在GraphProto中,这些元素都是以列表的方式存储,连接关系是通过Node之间的输入输出进行表达的。
  • NodeProto: onnx的计算图是一个有向无环图(DAG),NodeProto定义算子类型,节点的输入输出,还包含属性。
  • ValueInforProto: 定义输入输出这类变量的类型。
  • TensorProto: 序列化的权重数据,包含数据的数据类型,shape等。
  • AttributeProto: 具有名字的属性,可以存储基本的数据类型(int, float, string, vector等)也可以存储onnx定义的数据结构(TENSOR, GRAPH等)。
二. Python API

2.1 搭建ONNX模型

        ONNX是用DAG来描述网络结构的,也就是一个网络(Graph)由节点(Node)和边(Tensor)组成,ONNX提供的helper类中有很多API可以用来构建一个ONNX网络模型,比如make_node, make_graph, make_tensor等,下面是一个单个Conv2d的网络构造示例:

import onnx
from onnx import helper
from onnx import TensorProto
import numpy as np
weight = np.random.randn(36)
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 2, 4, 4])
W = helper.make_tensor('W', TensorProto.FLOAT, [2, 2, 3, 3], weight)
B = helper.make_tensor('B', TensorProto.FLOAT, [2], [1.0, 2.0])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 2, 2, 2])
node_def = helper.make_node(
'Conv', # node name
['X', 'W', 'B'],
['Y'], # outputs
# attributes
strides=[2,2],
)
graph_def = helper.make_graph(
[node_def],
'test_conv_mode',
[X], # graph inputs
[Y], # graph outputs
initializer=[W, B],
)
mode_def = helper.make_model(graph_def, producer_name='onnx-example')
onnx.checker.check_model(mode_def)
onnx.save(mode_def, "./Conv.onnx")

        搭建的这个Conv算子模型使用netron可视化如下图所示:

         这个示例演示了如何使用helper的make_tensor_value_info, make_mode, make_graph, make_model等方法来搭建一个onnx模型。

        相比于PyTorch或其它框架,这些API看起来仍然显得比较繁琐,一般我们也不会用ONNX来搭建一个大型的网络模型,而是通过其它框架转换得到一个ONNX模型。

2.2 Shape Inference

        很多时候我们从pytorch, tensorflow或其他框架转换过来的onnx模型中间节点并没有shape信息,如下图所示:

 

         我们经常希望能直接看到网络中某些node的shape信息,shape_inference模块可以推导出所有node的shape信息,这样可视化模型时将会更友好:

import onnx
from onnx import shape_inference
onnx_model = onnx.load("./test_data/mobilenetv2-1.0.onnx")
onnx_model = shape_inference.infer_shapes(onnx_model)
onnx.save(onnx_model, "./test_data/mobilenetv2-1.0_shaped.onnx")

        可视化经过shape_inference之后的模型如下图:

2.3 ONNX Optimizer

        ONNX的optimizer模块提供部分图优化的功能,例如最常用的:fuse_bn_into_conv,fuse_pad_into_conv等等。

        查看onnx支持的优化方法:

from onnx import optimizer
all_passes = optimizer.get_available_passes()
print("Available optimization passes:")
for p in all_passes:
print(p)
print()

        应用图优化到onnx模型上进行变换:

passes = ['fuse_bn_into_conv']
# Apply the optimization on the original model
optimized_model = optimizer.optimize(onnx_model, passes)

        将mobile net v2应用fuse_bn_into_conv之后,BatchNormalization的参数合并到了Conv的weight和bias参数中,如下图所示:

 

三. ONNX Runtime计算ONNX模型

        onnx本身只是一个协议,定义算子与模型结构等,不涉及具体的计算。onnx runtime是类似JVM一样将ONNX格式的模型运行起来的解释器,包括对模型的解析、图优化、后端运行等。

        安装onnx runtime:

python3 -m pip install onnxruntime

        推理:

import onnx
import onnxruntime as ort
import numpy as np
import cv2
def preprocess(img_data):
mean_vec = np.array([0.485, 0.456, 0.406])
stddev_vec = np.array([0.229, 0.224, 0.225])
norm_img_data = np.zeros(img_data.shape).astype('float32')
for i in range(img_data.shape[0]):
# for each pixel in each channel, divide the value by 255 to get value between [0, 1] and then normalize
norm_img_data[i,:,:] = (img_data[i,:,:]/255 - mean_vec[i]) / stddev_vec[i]
return norm_img_data
img = cv2.imread("test_data/dog.jpeg")
img = cv2.resize(img, (224,224), interpolation=cv2.INTER_AREA)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
input_data = np.transpose(img, (2, 0, 1))
input_data = preprocess(input_data)
input_data = input_data.reshape([1, 3, 224, 224])
sess = ort.InferenceSession("test_data/mobilenetv2-1.0.onnx")
input_name = sess.get_inputs()[0].name
result = sess.run([], {input_name: input_data})
result = np.reshape(result, [1, -1])
index = np.argmax(result)
print("max index:", index)