本文是txtai,一个人工智能驱动的语义搜索平台系列教程的一部分。

ONNX 运行时为机器学习模型提供了一种通用的序列化格式。 ONNX 支持多种不同平台/语言并具有内置功能以帮助减少推理时间。

PyTorch 为将 Torch 模型导出到 ONNX 提供了强大的支持。这可以将 Hugging Face Transformer 和/或其他下游模型直接导出到 ONNX。

ONNX 开辟了使用多种语言和平台进行直接推理的途径。例如,模型可以直接在 Android 上运行,以限制发送到第三方服务的数据。 ONNX 是一个令人兴奋的发展,充满希望。微软还发布了Hummingbird,它可以将传统模型(sklearn、决策树、逻辑回归......)导出到 ONNX。

本文将介绍如何使用 txtai 将模型导出到 ONNX。然后这些模型将直接在 Python、JavaScript、Java 和 Rust 中运行。目前,txtai 通过它的 API 支持所有这些语言,这仍然是推荐的方法。

安装依赖
txtai
pip install txtai[pipeline] datasets
使用 ONNX 运行模型

让我们开始吧!以下示例将情绪分析模型导出到 ONNX 并运行推理会话。

import numpy as np

from onnxruntime import InferenceSession, SessionOptions
from transformers import AutoTokenizer
from txtai.pipeline import HFOnnx

# Normalize logits using sigmoid function
sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))

# Export to ONNX
onnx = HFOnnx()
model = onnx("distilbert-base-uncased-finetuned-sst-2-english", "text-classification")

# Start inference session
options = SessionOptions()
session = InferenceSession(model, options)

# Tokenize
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
tokens = tokenizer(["I am happy", "I am mad"], return_tensors="np")

# Print results
outputs = session.run(None, dict(tokens))
print(sigmoid(outputs[0]))
[[0.01295124 0.9909526 ]
 [0.9874723  0.0297817 ]]

就这样,有结果!文本分类模型使用两个标签来判断情绪,0 表示负面,1 表示正面。上面的结果显示了每个文本片段中每个标签的概率。

ONNX 管道加载模型,将图形转换为 ONNX 并返回。请注意,没有提供输出文件,在这种情况下,ONNX 模型作为字节数组返回。如果提供了输出文件,则此方法返回输出路径。

为文本分类训练和导出模型

接下来,我们将结合 ONNX 管道和 Trainer 管道来创建“训练并导出到 ONNX”工作流。

from datasets import load_dataset
from txtai.pipeline import HFTrainer

trainer = HFTrainer()

# Hugging Face dataset
ds = load_dataset("glue", "sst2")
data = ds["train"].select(range(5000)).flatten_indices()

# Train new model using 5,000 SST2 records (in-memory)
model, tokenizer = trainer("google/electra-base-discriminator", data, columns=("sentence", "label"))

# Export model trained in-memory to ONNX (still in-memory)
output = onnx((model, tokenizer), "text-classification", quantize=True)

# Start inference session
options = SessionOptions()
session = InferenceSession(output, options)

# Tokenize
tokens = tokenizer(["I am happy", "I am mad"], return_tensors="np")

# Print results
outputs = session.run(None, dict(tokens))
print(sigmoid(outputs[0]))
[[0.02424305 0.9557785 ]
 [0.95884305 0.05541185]]

结果与上一步类似,尽管此模型仅在 sst2 数据集的一小部分上进行了训练。让我们保存这个模型以备后用。

text = onnx((model, tokenizer), "text-classification", "text-classify.onnx", quantize=True)
导出一个句子嵌入模型

ONNX 管道还支持导出使用句子转换器包训练的句子嵌入模型。

embeddings = onnx("sentence-transformers/paraphrase-MiniLM-L6-v2", "pooling", "embeddings.onnx", quantize=True)

现在让我们使用 ONNX 运行模型。

from sklearn.metrics.pairwise import cosine_similarity

options = SessionOptions()
session = InferenceSession(embeddings, options)

tokens = tokenizer(["I am happy", "I am glad"], return_tensors="np")

outputs = session.run(None, dict(tokens))[0]

print(cosine_similarity(outputs))
[[1.0000002 0.8430618]
 [0.8430618 1.       ]]

上面的代码标记了两个单独的文本片段(“我很高兴”和“我很高兴”)并通过 ONNX 模型运行它。

这会输出两个嵌入数组,并使用余弦相似度比较这些数组。正如我们所看到的,这两个文本片段具有密切的语义含义。

用txtai加载一个ONNX模型

txtai 内置了对 ONNX 模型的支持。加载 ONNX 模型是无缝的,并且嵌入和管道支持它。以下部分展示了如何加载由 ONNX 支持的分类管道和嵌入模型。

from txtai.embeddings import Embeddings
from txtai.pipeline import Labels

labels = Labels(("text-classify.onnx", "google/electra-base-discriminator"), dynamic=False)
print(labels(["I am happy", "I am mad"]))

embeddings = Embeddings({"path": "embeddings.onnx", "tokenizer": "sentence-transformers/paraphrase-MiniLM-L6-v2"})
print(embeddings.similarity("I am happy", ["I am glad"]))
[[(1, 0.9988517761230469), (0, 0.0011482156114652753)], [(0, 0.997488260269165), (1, 0.0025116782635450363)]]
[(0, 0.8581848740577698)]
JavaScript

到目前为止,我们已经将模型导出到 ONNX 并通过 Python 运行它们。这已经具有很多优势,包括快速推理时间、量化和更少的软件依赖性。但是当我们在其他语言/平台上运行一个用 Python 训练的模型时,ONNX 真的很出色。

让我们尝试在 JavaScript 中运行上面训练的模型。第一步是获取 Node.js 环境和依赖项设置。

import os

!mkdir js
os.chdir("/content/js")
{
  "name": "onnx-test",
  "private": true,
  "version": "1.0.0",
  "description": "ONNX Runtime Node.js test",
  "main": "index.js",
  "dependencies": {
    "onnxruntime-node": ">=1.8.0",
    "tokenizers": "file:tokenizers/bindings/node"
  }
}
# Copy ONNX models
!cp ../text-classify.onnx .
!cp ../embeddings.onnx .

# Save copy of Bert Tokenizer
tokenizer.save_pretrained("bert")

# Get tokenizers project
!git clone https://github.com/huggingface/tokenizers.git

os.chdir("/content/js/tokenizers/bindings/node")

# Install Rust
!apt-get install rustc

# Build tokenizers project locally as version on NPM isn't working properly for latest version of Node.js
!npm install --also=dev
!npm run dev

# Install all dependencies
os.chdir("/content/js")
!npm install

接下来,我们将用 JavaScript 将推理代码写入 index.js 文件。

const ort = require('onnxruntime-node');
const { promisify } = require('util');
const { Tokenizer } = require("tokenizers/dist/bindings/tokenizer");

function sigmoid(data) {
    return data.map(x => 1 / (1 + Math.exp(-x)))
}

function softmax(data) { 
    return data.map(x => Math.exp(x) / (data.map(y => Math.exp(y))).reduce((a,b) => a+b)) 
}

function similarity(v1, v2) {
    let dot = 0.0;
    let norm1 = 0.0;
    let norm2 = 0.0;

    for (let x = 0; x < v1.length; x++) {
        dot += v1[x] * v2[x];
        norm1 += Math.pow(v1[x], 2);
        norm2 += Math.pow(v2[x], 2);
    }

    return dot / (Math.sqrt(norm1) * Math.sqrt(norm2));
}

function tokenizer(path) {
    let tokenizer = Tokenizer.fromFile(path);
    return promisify(tokenizer.encode.bind(tokenizer));
}

async function predict(session, text) {
    try {
        // Tokenize input
        let encode = tokenizer("bert/tokenizer.json");
        let output = await encode(text);

        let ids = output.getIds().map(x => BigInt(x))
        let mask = output.getAttentionMask().map(x => BigInt(x))
        let tids = output.getTypeIds().map(x => BigInt(x))

        // Convert inputs to tensors    
        let tensorIds = new ort.Tensor('int64', BigInt64Array.from(ids), [1, ids.length]);
        let tensorMask = new ort.Tensor('int64', BigInt64Array.from(mask), [1, mask.length]);
        let tensorTids = new ort.Tensor('int64', BigInt64Array.from(tids), [1, tids.length]);

        let inputs = null;
        if (session.inputNames.length > 2) {
            inputs = { input_ids: tensorIds, attention_mask: tensorMask, token_type_ids: tensorTids};
        }
        else {
            inputs = { input_ids: tensorIds, attention_mask: tensorMask};
        }

        return await session.run(inputs);
    } catch (e) {
        console.error(`failed to inference ONNX model: ${e}.`);
    }
}

async function main() {
    let args = process.argv.slice(2);
    if (args.length > 1) {
        // Run sentence embeddings
        const session = await ort.InferenceSession.create('./embeddings.onnx');

        let v1 = await predict(session, args[0]);
        let v2 = await predict(session, args[1]);

        // Unpack results
        v1 = v1.embeddings.data;
        v2 = v2.embeddings.data;

        // Print similarity
        console.log(similarity(Array.from(v1), Array.from(v2)));
    }
    else {
        // Run text classifier
        const session = await ort.InferenceSession.create('./text-classify.onnx');
        let results = await predict(session, args[0]);

        // Normalize results using softmax and print
        console.log(softmax(results.logits.data));
    }
}

main();

使用 ONNX 在 JavaScript 中运行文本分类

!node . "I am happy"
!node . "I am mad"
Float32Array(2) [ 0.001104647060856223, 0.9988954067230225 ]
Float32Array(2) [ 0.9976443648338318, 0.00235558208078146 ]

首先,不得不说这是🔥🔥🔥!令人惊讶的是,这个模型可以完全在 JavaScript 中运行。现在是进入 NLP 的好时机!

上述步骤安装了一个带有依赖项的 JavaScript 环境,以运行 ONNX 并在 JavaScript 中标记数据。之前创建的文本分类模型被加载到 JavaScript ONNX 运行时并运行推理。

提醒一下,文本分类模型使用两个标签来判断情绪,0 表示负面,1 表示正面。上面的结果显示了每个文本片段中每个标签的概率。

构建句子嵌入并比较 JavaScript 与 ONNX 中的相似性

!node . "I am happy", "I am glad"
0.8414919420066624

再一次......哇!!句子嵌入模型生成可用于比较语义相似度的向量,-1 表示最不相似,1 表示最相似。

虽然结果与导出的模型不完全匹配,但非常接近。再次值得一提的是,这是 100% JavaScript,没有 API 或远程调用,都在 node.js 中。

爪哇

让我们用 Java 尝试同样的事情。以下部分初始化 Java 构建环境并写出运行 ONNX 推理所需的代码。

import os

os.chdir("/content")
!mkdir java
os.chdir("/content/java")

# Copy ONNX models
!cp ../text-classify.onnx .
!cp ../embeddings.onnx .

# Save copy of Bert Tokenizer
tokenizer.save_pretrained("bert")

!mkdir -p src/main/java

# Install gradle
!wget https://services.gradle.org/distributions/gradle-7.2-bin.zip
!unzip -o gradle-7.2-bin.zip
!gradle-7.2/bin/gradle wrapper
apply plugin: "java"

repositories {
    mavenCentral()
}

dependencies {
    implementation "com.robrua.nlp:easy-bert:1.0.3"
    implementation "com.microsoft.onnxruntime:onnxruntime:1.8.1"
}

java {
    toolchain {
        languageVersion = JavaLanguageVersion.of(8)
    }
}

jar {
    archiveBaseName = "onnxjava"
}

task onnx(type: JavaExec) {
    description = "Runs ONNX demo"
    classpath = sourceSets.main.runtimeClasspath
    main = "OnnxDemo"
}
import java.io.File;

import java.nio.LongBuffer;

import java.util.Arrays;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.OrtSession.Result;

import com.robrua.nlp.bert.FullTokenizer;

class Tokens {
    public long[] ids;
    public long[] mask;
    public long[] types;
}

class Tokenizer {
    private FullTokenizer tokenizer;

    public Tokenizer(String path) {
        File vocab = new File(path);
        this.tokenizer = new FullTokenizer(vocab, true);
    }

    public Tokens tokenize(String text) {
        // Build list of tokens
        List<String> tokensList = new ArrayList();
        tokensList.add("[CLS]"); 
        tokensList.addAll(Arrays.asList(tokenizer.tokenize(text)));
        tokensList.add("[SEP]");

        int[] ids = tokenizer.convert(tokensList.toArray(new String[0]));

        Tokens tokens = new Tokens();

        // input ids    
        tokens.ids = Arrays.stream(ids).mapToLong(i -> i).toArray();

        // attention mask
        tokens.mask = new long[ids.length];
        Arrays.fill(tokens.mask, 1);

        // token type ids
        tokens.types = new long[ids.length];
        Arrays.fill(tokens.types, 0);

        return tokens;
    }
}

class Inference {
    private Tokenizer tokenizer;
    private OrtEnvironment env;
    private OrtSession session;

    public Inference(String model) throws Exception {
        this.tokenizer = new Tokenizer("bert/vocab.txt");
        this.env = OrtEnvironment.getEnvironment();
        this.session = env.createSession(model, new OrtSession.SessionOptions());
    }

    public float[][] predict(String text) throws Exception {
        Tokens tokens = this.tokenizer.tokenize(text);

        Map<String, OnnxTensor> inputs = new HashMap<String, OnnxTensor>();
        inputs.put("input_ids", OnnxTensor.createTensor(env, LongBuffer.wrap(tokens.ids),  new long[]{1, tokens.ids.length}));
        inputs.put("attention_mask", OnnxTensor.createTensor(env, LongBuffer.wrap(tokens.mask),  new long[]{1, tokens.mask.length}));
        inputs.put("token_type_ids", OnnxTensor.createTensor(env, LongBuffer.wrap(tokens.types),  new long[]{1, tokens.types.length}));

        return (float[][])session.run(inputs).get(0).getValue();
    }
}

class Vectors {
    public static double similarity(float[] v1, float[] v2) {
        double dot = 0.0;
        double norm1 = 0.0;
        double norm2 = 0.0;

        for (int x = 0; x < v1.length; x++) {
            dot += v1[x] * v2[x];
            norm1 += Math.pow(v1[x], 2);
            norm2 += Math.pow(v2[x], 2);
        }

        return dot / (Math.sqrt(norm1) * Math.sqrt(norm2));
    }

    public static float[] softmax(float[] input) {
        double[] t = new double[input.length];
        double sum = 0.0;

        for (int x = 0; x < input.length; x++) {
            double val = Math.exp(input[x]);
            sum += val;
            t[x] = val;
        }

        float[] output = new float[input.length];
        for (int x = 0; x < output.length; x++) {
            output[x] = (float) (t[x] / sum);
        }

        return output;
    }
}

public class OnnxDemo {
    public static void main(String[] args) {
        try {
            if (args.length < 2) {
              Inference inference = new Inference("text-classify.onnx");

              float[][] v1 = inference.predict(args[0]);

              System.out.println(Arrays.toString(Vectors.softmax(v1[0])));
            }
            else {
              Inference inference = new Inference("embeddings.onnx");
              float[][] v1 = inference.predict(args[0]);
              float[][] v2 = inference.predict(args[1]);

              System.out.println(Vectors.similarity(v1[0], v2[0]));
            }
        }
        catch (Exception ex) {
            ex.printStackTrace();
        }
    }
}

使用 ONNX 在 Java 中运行文本分类

!./gradlew -q --console=plain onnx --args='"I am happy"' 2> /dev/null
!./gradlew -q --console=plain onnx --args='"I am mad"' 2> /dev/null
[0.0011046471, 0.99889535]
[0.9976444, 0.002355582]

上面的命令标记输入并使用先前使用 Java ONNX 推理会话创建的文本分类模型运行推理。

提醒一下,文本分类模型使用两个标签来判断情绪,0 表示负面,1 表示正面。上面的结果显示了每个文本片段中每个标签的概率。

构建句子嵌入并比较 Java 和 ONNX 中的相似性

!./gradlew -q --console=plain onnx --args='"I am happy" "I am glad"' 2> /dev/null
0.8581848568615768

句子嵌入模型生成可用于比较语义相似度的向量,-1 表示最不相似,1 表示最相似。

这是 100% Java,没有 API 或远程调用,全部在 JVM 中。还是觉得很神奇!

生锈

最后但同样重要的是,让我们试试 Rust。以下部分初始化 Rust 构建环境并写出运行 ONNX 推理所需的代码。

import os

os.chdir("/content")
!mkdir rust
os.chdir("/content/rust")

# Copy ONNX models
!cp ../text-classify.onnx .
!cp ../embeddings.onnx .

# Save copy of Bert Tokenizer
tokenizer.save_pretrained("bert")

# Install Rust
!apt-get install rustc

!mkdir -p src
[package]
name = "onnx-test"
version = "1.0.0"
description = """
ONNX Runtime Rust test
"""
edition = "2018"

[dependencies]
onnxruntime = { version = "0.0.14"}
tokenizers = { version = "0.10.1"}
use onnxruntime::environment::Environment;
use onnxruntime::GraphOptimizationLevel;
use onnxruntime::ndarray::{Array2, Axis};
use onnxruntime::tensor::OrtOwnedTensor;

use std::env;

use tokenizers::decoders::wordpiece::WordPiece as WordPieceDecoder;
use tokenizers::models::wordpiece::WordPiece;
use tokenizers::normalizers::bert::BertNormalizer;
use tokenizers::pre_tokenizers::bert::BertPreTokenizer;
use tokenizers::processors::bert::BertProcessing;
use tokenizers::tokenizer::{Result, Tokenizer, EncodeInput};

fn tokenize(text: String, inputs: usize) -> Vec<Array2<i64>> {
    // Load tokenizer
    let mut tokenizer = Tokenizer::new(Box::new(
        WordPiece::from_files("bert/vocab.txt")
            .build()
            .expect("Vocab file not found"),
    ));

    tokenizer.with_normalizer(Box::new(BertNormalizer::default()));
    tokenizer.with_pre_tokenizer(Box::new(BertPreTokenizer));
    tokenizer.with_decoder(Box::new(WordPieceDecoder::default()));
    tokenizer.with_post_processor(Box::new(BertProcessing::new(
        (
            String::from("[SEP]"),
            tokenizer.get_model().token_to_id("[SEP]").unwrap(),
        ),
        (
            String::from("[CLS]"),
            tokenizer.get_model().token_to_id("[CLS]").unwrap(),
        ),
    )));

    // Encode input text
    let encoding = tokenizer.encode(EncodeInput::Single(text), true).unwrap();

    let v1: Vec<i64> = encoding.get_ids().to_vec().into_iter().map(|x| x as i64).collect();
    let v2: Vec<i64> = encoding.get_attention_mask().to_vec().into_iter().map(|x| x as i64).collect();
    let v3: Vec<i64> = encoding.get_type_ids().to_vec().into_iter().map(|x| x as i64).collect();

    let ids = Array2::from_shape_vec((1, v1.len()), v1).unwrap();
    let mask = Array2::from_shape_vec((1, v2.len()), v2).unwrap();
    let tids = Array2::from_shape_vec((1, v3.len()), v3).unwrap();

    return if inputs > 2 { vec![ids, mask, tids] } else { vec![ids, mask] };
}

fn predict(text: String, softmax: bool) -> Vec<f32> {
    // Start onnx session
    let environment = Environment::builder()
        .with_name("test")
        .build().unwrap();

    // Derive model path
    let model = if softmax { "text-classify.onnx" } else { "embeddings.onnx" };

    let mut session = environment
        .new_session_builder().unwrap()
        .with_optimization_level(GraphOptimizationLevel::Basic).unwrap()
        .with_number_threads(1).unwrap()
        .with_model_from_file(model).unwrap();

    let inputs = tokenize(text, session.inputs.len());

    // Run inference and print result
    let outputs: Vec<OrtOwnedTensor<f32, _>> = session.run(inputs).unwrap();
    let output: &OrtOwnedTensor<f32, _> = &outputs[0];

    let probabilities: Vec<f32>;
    if softmax {
        probabilities = output
            .softmax(Axis(1))
            .iter()
            .copied()
            .collect::<Vec<_>>();
    }
    else {
        probabilities= output
            .iter()
            .copied()
            .collect::<Vec<_>>();
    }

    return probabilities;
}

fn similarity(v1: &Vec<f32>, v2: &Vec<f32>) -> f64 {
    let mut dot = 0.0;
    let mut norm1 = 0.0;
    let mut norm2 = 0.0;

    for x in 0..v1.len() {
        dot += v1[x] * v2[x];
        norm1 += v1[x].powf(2.0);
        norm2 += v2[x].powf(2.0);
    }

    return dot as f64 / (norm1.sqrt() * norm2.sqrt()) as f64
}

fn main() -> Result<()> {
    // Tokenize input string
    let args: Vec<String> = env::args().collect();

    if args.len() <= 2 {
      let v1 = predict(args[1].to_string(), true);
      println!("{:?}", v1);
    }
    else {
      let v1 = predict(args[1].to_string(), false);
      let v2 = predict(args[2].to_string(), false);
      println!("{:?}", similarity(&v1, &v2));
    }

    Ok(())
}

使用 ONNX 在 Rust 中运行文本分类

!cargo run "I am happy" 2> /dev/null
!cargo run "I am mad" 2> /dev/null
[0.0011003953, 0.99889964]
[0.9976444, 0.0023555849]

上面的命令对输入进行标记,并使用之前使用 Rust ONNX 推理会话创建的文本分类模型运行推理。

提醒一下,文本分类模型使用两个标签来判断情绪,0 表示负面,1 表示正面。上面的结果显示了每个文本片段中每个标签的概率。

构建句子嵌入并比较 Rust 与 ONNX 中的相似性

!cargo run "I am happy" "I am glad" 2> /dev/null
0.8583641740656903

句子嵌入模型生成可用于比较语义相似度的向量,-1 表示最不相似,1 表示最相似。

再一次,这是 100% Rust,没有 API 或远程调用。是的,仍然认为这很神奇!

收尾

本笔记本介绍了如何使用 txtai 将模型导出到 ONNX。然后这些模型在 Python、JavaScript、Java 和 Rust 中运行。 Golang 也进行了评估,但目前似乎没有足够稳定的 ONNX 运行时可用。

此方法提供了一种在多个平台上使用多种编程语言来训练和运行机器学习模型的方法。

以下是用例的非详尽列表。

  • 为移动/边缘设备构建本地执行模型

  • 当团队不想将 Python 添加到混合中时,使用 Java/JavaScript/Rust 开发堆栈运行模型

  • 将模型导出到 ONNX 以进行 Python 推理以提高 CPU 性能和/或减少软件依赖项的数量