diff --git a/Cargo.lock b/Cargo.lock index c156f405b4fb..eeec65a975d2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -371,7 +371,7 @@ dependencies = [ "cap-primitives", "cap-std", "rustix 1.0.8", - "smallvec", + "smallvec 1.15.1", ] [[package]] @@ -789,7 +789,7 @@ dependencies = [ "serde_derive", "sha2", "similar", - "smallvec", + "smallvec 1.15.1", "souper-ir", "target-lexicon", "wasmtime-internal-math", @@ -850,7 +850,7 @@ dependencies = [ "serde", "serde_derive", "similar", - "smallvec", + "smallvec 1.15.1", "target-lexicon", "thiserror 2.0.17", "toml", @@ -867,7 +867,7 @@ dependencies = [ "hashbrown 0.15.2", "log", "similar", - "smallvec", + "smallvec 1.15.1", "target-lexicon", ] @@ -893,7 +893,7 @@ dependencies = [ "cranelift-reader", "libm", "log", - "smallvec", + "smallvec 1.15.1", "thiserror 2.0.17", ] @@ -970,7 +970,7 @@ version = "0.127.0" dependencies = [ "anyhow", "cranelift-codegen", - "smallvec", + "smallvec 1.15.1", "target-lexicon", ] @@ -1121,6 +1121,16 @@ dependencies = [ "uuid", ] +[[package]] +name = "der" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" +dependencies = [ + "pem-rfc7468", + "zeroize", +] + [[package]] name = "deranged" version = "0.3.11" @@ -1213,7 +1223,7 @@ dependencies = [ "instant", "log", "once_cell", - "smallvec", + "smallvec 1.15.1", "symbolic_expressions", ] @@ -1821,7 +1831,7 @@ dependencies = [ "itoa", "pin-project-lite", "pin-utils", - "smallvec", + "smallvec 1.15.1", "tokio", "want", ] @@ -1905,7 +1915,7 @@ dependencies = [ "icu_normalizer_data", "icu_properties", "icu_provider", - "smallvec", + "smallvec 1.15.1", "utf16_iter", "utf8_iter", "write16", @@ -1980,7 +1990,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" dependencies = [ "idna_adapter", - "smallvec", + "smallvec 1.15.1", "utf8_iter", ] @@ -2670,24 +2680,22 @@ dependencies = [ [[package]] name = "ort" -version = "2.0.0-rc.2" +version = "2.0.0-rc.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bc80894094c6a875bfac64415ed456fa661081a278a035e22be661305c87e14" +checksum = "1fa7e49bd669d32d7bc2a15ec540a527e7764aec722a45467814005725bcd721" dependencies = [ - "js-sys", "ort-sys", - "thiserror 1.0.65", - "tracing", - "web-sys", + "smallvec 2.0.0-alpha.10", ] [[package]] name = "ort-sys" -version = "2.0.0-rc.2" +version = "2.0.0-rc.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3d9c1373fc813d3f024d394f621f4c6dde0734c79b1c17113c3bb5bf0084bbe" +checksum = "e2aba9f5c7c479925205799216e7e5d07cc1d4fa76ea8058c60a9a30f6a4e890" dependencies = [ "flate2", + "pkg-config", "sha2", "tar", "ureq", @@ -2722,6 +2730,15 @@ dependencies = [ "sha2", ] +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.2" @@ -2752,9 +2769,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.29" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2900ede94e305130c13ddd391e0ab7cbaeb783945ae07a279c268cb05109c6cb" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" 
[[package]] name = "postcard" @@ -3019,7 +3036,7 @@ dependencies = [ "log", "rustc-hash", "serde", - "smallvec", + "smallvec 1.15.1", ] [[package]] @@ -3177,26 +3194,14 @@ dependencies = [ ] [[package]] -name = "rustls" -version = "0.23.7" +name = "rustls-pki-types" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebbbdb961df0ad3f2652da8f3fdc4b36122f568f968f45ad3316f26c025c677b" +checksum = "94182ad936a0c91c324cd46c6511b9510ed16af436d7b5bab34beab0afd55f7a" dependencies = [ - "log", - "once_cell", - "ring", - "rustls-pki-types", - "rustls-webpki", - "subtle", "zeroize", ] -[[package]] -name = "rustls-pki-types" -version = "1.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ede67b28608b4c60685c7d54122d4400d90f62b40caee7700e700380a390fa8" - [[package]] name = "rustls-webpki" version = "0.102.2" @@ -3452,6 +3457,12 @@ dependencies = [ "serde", ] +[[package]] +name = "smallvec" +version = "2.0.0-alpha.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d44cfb396c3caf6fbfd0ab422af02631b69ddd96d2eff0b0f0724f9024051b" + [[package]] name = "socket2" version = "0.6.1" @@ -3462,6 +3473,17 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "socks" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" +dependencies = [ + "byteorder", + "libc", + "winapi", +] + [[package]] name = "souper-ir" version = "2.1.0" @@ -3835,7 +3857,7 @@ version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f" dependencies = [ - "rustls 0.22.4", + "rustls", "rustls-pki-types", "tokio", ] @@ -4030,17 +4052,32 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "ureq" -version = "2.10.0" +version = "3.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72139d247e5f97a3eff96229a7ae85ead5328a39efe76f8bf5a06313d505b6ea" +checksum = "d39cb1dbab692d82a977c0392ffac19e188bd9186a9f32806f0aaa859d75585a" dependencies = [ "base64", + "der", "log", - "once_cell", - "rustls 0.23.7", + "native-tls", + "percent-encoding", "rustls-pki-types", - "url", - "webpki-roots", + "socks", + "ureq-proto", + "utf-8", + "webpki-root-certs", +] + +[[package]] +name = "ureq-proto" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b4531c118335662134346048ddb0e54cc86bd7e81866757873055f0e38f5d2" +dependencies = [ + "base64", + "http", + "httparse", + "log", ] [[package]] @@ -4055,6 +4092,12 @@ dependencies = [ "serde", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf16_iter" version = "1.0.5" @@ -4325,7 +4368,7 @@ dependencies = [ "serde", "serde_derive", "serde_yaml", - "smallvec", + "smallvec 1.15.1", "wasm-encoder", "wasmparser 0.241.2", "wat", @@ -4514,7 +4557,7 @@ dependencies = [ "serde", "serde_derive", "serde_json", - "smallvec", + "smallvec 1.15.1", "target-lexicon", "tempfile", "tokio", @@ -4695,7 +4738,7 @@ dependencies = [ "semver", "serde", "serde_derive", - "smallvec", + "smallvec 1.15.1", "target-lexicon", "wasm-encoder", "wasmparser 0.241.2", @@ -4742,7 +4785,7 @@ dependencies = [ "quote", "rand 0.8.5", "rand 
0.9.2", - "smallvec", + "smallvec 1.15.1", "target-lexicon", "wasmparser 0.241.2", "wasmtime", @@ -4764,7 +4807,7 @@ dependencies = [ "rayon", "serde", "serde_json", - "smallvec", + "smallvec 1.15.1", "target-lexicon", "tempfile", "v8", @@ -4851,7 +4894,7 @@ dependencies = [ "log", "object 0.37.3", "pulley-interpreter", - "smallvec", + "smallvec 1.15.1", "target-lexicon", "thiserror 2.0.17", "wasmparser 0.241.2", @@ -5062,7 +5105,7 @@ dependencies = [ "http-body", "http-body-util", "hyper", - "rustls 0.22.4", + "rustls", "sha2", "tempfile", "test-log", @@ -5141,7 +5184,7 @@ dependencies = [ "anyhow", "bytes", "futures", - "rustls 0.22.4", + "rustls", "test-programs-artifacts", "tokio", "tokio-rustls", @@ -5232,13 +5275,12 @@ dependencies = [ ] [[package]] -name = "web-sys" -version = "0.3.57" +name = "webpki-root-certs" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b17e741662c70c8bd24ac5c5b18de314a2c26c32bf8346ee1e6f53de919c283" +checksum = "ee3e3b5f5e80bc89f30ce8d0343bf4e5f12341c51f3e26cbeecbc7c85443e85b" dependencies = [ - "js-sys", - "wasm-bindgen", + "rustls-pki-types", ] [[package]] @@ -5354,7 +5396,7 @@ dependencies = [ "cranelift-codegen", "gimli 0.32.3", "regalloc2", - "smallvec", + "smallvec 1.15.1", "target-lexicon", "thiserror 2.0.17", "wasmparser 0.241.2", diff --git a/crates/wasi-nn/Cargo.toml b/crates/wasi-nn/Cargo.toml index 93bc6c965fcf..43e1e6275545 100644 --- a/crates/wasi-nn/Cargo.toml +++ b/crates/wasi-nn/Cargo.toml @@ -31,7 +31,7 @@ wasmtime = { workspace = true, features = [ tracing = { workspace = true } thiserror = { workspace = true } -ort = { version = "2.0.0-rc.2", default-features = false, features = [ +ort = { version = "2.0.0-rc.10", default-features = false, features = [ "copy-dylibs", ], optional = true } tch = { version = "0.17.0", default-features = false, optional = true} @@ -71,6 +71,8 @@ openvino = ["dep:openvino"] onnx = ["dep:ort"] # Use prebuilt ONNX Runtime binaries from ort. onnx-download = ["onnx", "ort/download-binaries"] +# CUDA execution provider for NVIDIA GPU support (requires CUDA toolkit) +onnx-cuda = ["onnx", "ort/cuda"] # WinML is only available on Windows 10 1809 and later. winml = ["dep:windows"] # PyTorch is available on all platforms; requires Libtorch to be installed diff --git a/crates/wasi-nn/examples/classification-component-onnx/README.md b/crates/wasi-nn/examples/classification-component-onnx/README.md index 9105aa96793f..119d82aa8131 100644 --- a/crates/wasi-nn/examples/classification-component-onnx/README.md +++ b/crates/wasi-nn/examples/classification-component-onnx/README.md @@ -3,35 +3,82 @@ This example demonstrates how to use the `wasi-nn` crate to run a classification using the [ONNX Runtime](https://onnxruntime.ai/) backend from a WebAssembly component. +It supports CPU and GPU (Nvidia CUDA) execution targets. + +**Note:** +GPU execution target only supports Nvidia CUDA (onnx-cuda) as execution provider (EP) for now. 
+ ## Build + In this directory, run the following command to build the WebAssembly component: ```console cargo component build ``` -In the Wasmtime root directory, run the following command to build the Wasmtime CLI and run the WebAssembly component: +## Running the Example + +### Building Wasmtime + +#### For CPU-only execution: ```sh -# build wasmtime with component-model and WASI-NN with ONNX runtime support cargo build --features component-model,wasi-nn,wasmtime-wasi-nn/onnx-download +``` + +#### For GPU (Nvidia CUDA) support: +```sh +# This will automatically download the onnxruntime dynamic shared library from cdn.pyke.io +cargo build --features component-model,wasi-nn,wasmtime-wasi-nn/onnx-cuda,wasmtime-wasi-nn/onnx-download +``` + +### Running with Different Execution Targets + +The execution target is controlled by passing a single argument to the WASM module. + +Arguments: +- No argument or `cpu` - Use CPU execution +- `gpu` or `cuda` - Use GPU/CUDA execution -# run the component with wasmtime +#### CPU Execution (default): +```sh ./target/debug/wasmtime run \ -Snn \ --dir ./crates/wasi-nn/examples/classification-component-onnx/fixture/::fixture \ ./crates/wasi-nn/examples/classification-component-onnx/target/wasm32-wasip1/debug/classification-component-onnx.wasm ``` -You should get the following output: +#### GPU (CUDA) Execution: +```sh +# path to `libonnxruntime_providers_cuda.so` downloaded by `ort-sys` +export LD_LIBRARY_PATH={wasmtime_workspace}/target/debug + +./target/debug/wasmtime run \ + -Snn \ + --dir ./crates/wasi-nn/examples/classification-component-onnx/fixture/::fixture \ + ./crates/wasi-nn/examples/classification-component-onnx/target/wasm32-wasip1/debug/classification-component-onnx.wasm \ + gpu + +``` + +## Expected Output + +You should get output similar to: ```txt +No execution target specified, defaulting to CPU Read ONNX model, size in bytes: 4956208 -Loaded graph into wasi-nn +Loaded graph into wasi-nn with Cpu target Created wasi-nn execution context. Read ONNX Labels, # of labels: 1000 -Set input tensor Executed graph inference -Getting inferencing output Retrieved output data with length: 4000 Index: n02099601 golden retriever - Probability: 0.9948673 Index: n02088094 Afghan hound, Afghan - Probability: 0.002528982 Index: n02102318 cocker spaniel, English cocker spaniel, cocker - Probability: 0.0010986356 ``` + +When using the GPU target, the first line will indicate the selected execution target. +You can monitor GPU usage with `watch -n 1 nvidia-smi`.
+ +## Prerequisites for GPU (CUDA) Support +- NVIDIA GPU with CUDA support +- CUDA Toolkit 12.x with cuDNN 9.x +- Build wasmtime with the `wasmtime-wasi-nn/onnx-cuda` feature diff --git a/crates/wasi-nn/examples/classification-component-onnx/src/main.rs b/crates/wasi-nn/examples/classification-component-onnx/src/main.rs index c02fc1ed8da2..affa61681557 100644 --- a/crates/wasi-nn/examples/classification-component-onnx/src/main.rs +++ b/crates/wasi-nn/examples/classification-component-onnx/src/main.rs @@ -17,14 +17,46 @@ use self::wasi::nn::{ tensor::{Tensor, TensorData, TensorDimensions, TensorType}, }; +/// Determine execution target from command-line argument +/// Usage: wasm_module [cpu|gpu|cuda] +fn get_execution_target() -> ExecutionTarget { + let args: Vec<String> = std::env::args().collect(); + + // First argument (index 0) is the program name, second (index 1) is the target + // Ignore any arguments after index 1 + if args.len() >= 2 { + match args[1].to_lowercase().as_str() { + "gpu" | "cuda" => { + println!("Using GPU (CUDA) execution target from argument"); + return ExecutionTarget::Gpu; + } + "cpu" => { + println!("Using CPU execution target from argument"); + return ExecutionTarget::Cpu; + } + _ => { + println!("Unknown execution target '{}', defaulting to CPU", args[1]); + } + } + } else { + println!("No execution target specified, defaulting to CPU"); + println!("Usage: [cpu|gpu|cuda]"); + } + + ExecutionTarget::Cpu +} + fn main() { // Load the ONNX model - SqueezeNet 1.1-7 // Full details: https://github.com/onnx/models/tree/main/vision/classification/squeezenet let model: GraphBuilder = fs::read("fixture/models/squeezenet1.1-7.onnx").unwrap(); println!("Read ONNX model, size in bytes: {}", model.len()); - let graph = load(&[model], GraphEncoding::Onnx, ExecutionTarget::Cpu).unwrap(); - println!("Loaded graph into wasi-nn"); + // Determine execution target + let execution_target = get_execution_target(); + + let graph = load(&[model], GraphEncoding::Onnx, execution_target).unwrap(); + println!("Loaded graph into wasi-nn with {:?} target", execution_target); let exec_context = Graph::init_execution_context(&graph).unwrap(); println!("Created wasi-nn execution context."); diff --git a/crates/wasi-nn/src/backend/onnx.rs b/crates/wasi-nn/src/backend/onnx.rs index aa033cfca16d..41c9592c7919 100644 --- a/crates/wasi-nn/src/backend/onnx.rs +++ b/crates/wasi-nn/src/backend/onnx.rs @@ -6,11 +6,17 @@ use super::{ use crate::backend::{Id, read}; use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType}; use crate::{ExecutionContext, Graph}; -use anyhow::Context; -use ort::{GraphOptimizationLevel, Session, inputs}; +use ort::execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch}; +use ort::session::builder::GraphOptimizationLevel; +use ort::session::{Input, Output, Session, SessionInputValue}; +use ort::tensor::TensorElementType; +use ort::value::{Tensor as OrtTensor, ValueType}; use std::path::Path; use std::sync::{Arc, Mutex}; +#[cfg(feature = "onnx-cuda")] +use ort::execution_providers::CUDAExecutionProvider; + #[derive(Default)] pub struct OnnxBackend(); unsafe impl Send for OnnxBackend {} @@ -26,7 +32,11 @@ impl BackendInner for OnnxBackend { return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()).into()); } + // Configure execution providers based on target + let execution_providers = configure_execution_providers(target)?; + let session = Session::builder()? + .with_execution_providers(execution_providers)?
.with_optimization_level(GraphOptimizationLevel::Level3)? .commit_from_memory(builders[0])?; @@ -40,6 +50,36 @@ impl BackendInner for OnnxBackend { } } +/// Configure execution providers based on the target +fn configure_execution_providers( + target: ExecutionTarget, +) -> Result<Vec<ExecutionProviderDispatch>, BackendError> { + match target { + ExecutionTarget::Cpu => { + // Use CPU execution provider with default configuration + tracing::debug!("Using CPU execution provider"); + Ok(vec![CPUExecutionProvider::default().build()]) + } + ExecutionTarget::Gpu => { + #[cfg(feature = "onnx-cuda")] + { + // Use CUDA execution provider for GPU acceleration + tracing::debug!("Configuring ONNX Nvidia CUDA execution provider for GPU target"); + Ok(vec![CUDAExecutionProvider::default().build()]) + } + #[cfg(not(feature = "onnx-cuda"))] + { + Err(BackendError::BackendAccess(anyhow::anyhow!( + "GPU execution target requested, but the 'onnx-cuda' feature is not enabled" + ))) + } + } + ExecutionTarget::Tpu => { + unimplemented!("TPU execution target is not supported for ONNX backend yet"); + } + } +} + impl BackendFromDir for OnnxBackend { fn load_from_dir( &mut self, @@ -177,18 +217,19 @@ impl BackendExecutionContext for OnnxExecutionContext { input_slot.tensor.replace(input.tensor.clone()); } - let mut session_inputs: Vec<ort::SessionInputValue<'_>> = vec![]; - for i in &self.inputs { - session_inputs.extend(to_input_value(i)?); - } - let session = self.session.lock().unwrap(); - let session_outputs = session.run(session_inputs.as_slice())?; + let session_inputs: Vec<SessionInputValue<'_>> = self + .inputs + .iter() + .map(|i| to_input_value(i)) + .collect::<Result<Vec<_>, _>>()?; + let mut session = self.session.lock().unwrap(); + let session_outputs = session.run(&session_inputs[..])?; let mut output_tensors = Vec::new(); for i in 0..self.outputs.len() { // TODO: fix preexisting gap--this only handles f32 tensors. - let raw: (Vec<i64>, &[f32]) = session_outputs[i].try_extract_raw_tensor()?; - let f32s = raw.1.to_vec(); + let (_shape, data) = session_outputs[i].try_extract_tensor::<f32>()?; + let f32s = data.to_vec(); let output = &mut self.outputs[i]; let tensor = Tensor { dimensions: output.shape.dimensions_as_u32()?, @@ -206,16 +247,17 @@ // WITX None => { - let mut session_inputs: Vec<ort::SessionInputValue<'_>> = vec![]; - for i in &self.inputs { - session_inputs.extend(to_input_value(i)?); - } - let session = self.session.lock().unwrap(); - let session_outputs = session.run(session_inputs.as_slice())?; + let session_inputs: Vec<SessionInputValue<'_>> = self + .inputs + .iter() + .map(|i| to_input_value(i)) + .collect::<Result<Vec<_>, _>>()?; + let mut session = self.session.lock().unwrap(); + let session_outputs = session.run(&session_inputs[..])?; for i in 0..self.outputs.len() { // TODO: fix preexisting gap--this only handles f32 tensors.
- let raw: (Vec<i64>, &[f32]) = session_outputs[i].try_extract_raw_tensor()?; - let f32s = raw.1.to_vec(); + let (_shape, data) = session_outputs[i].try_extract_tensor::<f32>()?; + let f32s = data.to_vec(); let output = &mut self.outputs[i]; output.tensor.replace(Tensor { dimensions: output.shape.dimensions_as_u32()?, @@ -244,7 +286,7 @@ impl BackendExecutionContext for OnnxExecutionContext { impl From<ort::Error> for BackendError { fn from(e: ort::Error) -> Self { - BackendError::BackendAccess(e.into()) + BackendError::BackendAccess(anyhow::anyhow!("{}", e)) } } @@ -265,7 +307,7 @@ struct Shape { } impl Shape { - fn from_onnx_input(input: &ort::Input) -> Result<Self, BackendError> { + fn from_onnx_input(input: &Input) -> Result<Self, BackendError> { let name = input.name.clone(); let (dimensions, ty) = convert_value_type(&input.input_type)?; Ok(Self { @@ -275,7 +317,7 @@ impl Shape { }) } - fn from_onnx_output(output: &ort::Output) -> Result<Self, BackendError> { + fn from_onnx_output(output: &Output) -> Result<Self, BackendError> { let name = output.name.clone(); let (dimensions, ty) = convert_value_type(&output.output_type)?; Ok(Self { @@ -322,10 +364,10 @@ impl Shape { } } -fn convert_value_type(vt: &ort::ValueType) -> Result<(Vec<i64>, TensorType), BackendError> { +fn convert_value_type(vt: &ValueType) -> Result<(Vec<i64>, TensorType), BackendError> { match vt { - ort::ValueType::Tensor { ty, dimensions } => { - let dims = dimensions.clone(); + ValueType::Tensor { ty, shape, .. } => { + let dims = shape.to_vec(); let ty = (*ty).try_into()?; Ok((dims, ty)) } @@ -341,15 +383,15 @@ fn convert_i64(i: &i64) -> Result<u32, BackendError> { }) } -impl TryFrom<ort::TensorElementType> for TensorType { +impl TryFrom<TensorElementType> for TensorType { type Error = BackendError; - fn try_from(ty: ort::TensorElementType) -> Result<Self, Self::Error> { + fn try_from(ty: TensorElementType) -> Result<Self, Self::Error> { match ty { - ort::TensorElementType::Float32 => Ok(TensorType::Fp32), - ort::TensorElementType::Float64 => Ok(TensorType::Fp64), - ort::TensorElementType::Uint8 => Ok(TensorType::U8), - ort::TensorElementType::Int32 => Ok(TensorType::I32), - ort::TensorElementType::Int64 => Ok(TensorType::I64), + TensorElementType::Float32 => Ok(TensorType::Fp32), + TensorElementType::Float64 => Ok(TensorType::Fp64), + TensorElementType::Uint8 => Ok(TensorType::U8), + TensorElementType::Int32 => Ok(TensorType::I32), + TensorElementType::Int64 => Ok(TensorType::I64), _ => Err(BackendError::BackendAccess(anyhow::anyhow!( "unsupported tensor type: {ty:?}" ))), @@ -357,18 +399,16 @@ impl TryFrom<ort::TensorElementType> for TensorType { } } -fn to_input_value(slot: &TensorSlot) -> Result<[ort::SessionInputValue<'_>; 1], BackendError> { +fn to_input_value(slot: &TensorSlot) -> Result<SessionInputValue<'_>, BackendError> { match &slot.tensor { Some(tensor) => match tensor.ty { TensorType::Fp32 => { let data = bytes_to_f32_vec(tensor.data.to_vec()); - let dimensions = tensor - .dimensions - .iter() - .map(|d| *d as i64) // TODO: fewer conversions - .collect::<Vec<i64>>(); - Ok(inputs![(dimensions, Arc::new(data.into_boxed_slice()))] - .context("failed to create ONNX session input")?) + let dimensions: Vec<usize> = + tensor.dimensions.iter().map(|d| *d as usize).collect(); + // Create an ort::Tensor and convert to SessionInputValue + let ort_tensor = OrtTensor::from_array((&dimensions[..], data.into_boxed_slice()))?; + Ok(ort_tensor.into()) } _ => { unimplemented!("{:?} not supported by ONNX", tensor.ty);
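
For anyone migrating other code off `ort` 2.0.0-rc.2, the rc.10 API surface exercised by this patch condenses to the short sketch below. It is illustrative only, not part of the change: it reuses only calls that appear in the diff above (`Session::builder`, `with_execution_providers`, `commit_from_memory`, `Tensor::from_array`, `run`, `try_extract_tensor::<f32>`), and the model path and the 1x3x224x224 input shape are hypothetical placeholders.

```rust
use ort::execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch};
use ort::session::builder::GraphOptimizationLevel;
use ort::session::{Session, SessionInputValue};
use ort::value::Tensor as OrtTensor;

fn main() -> Result<(), ort::Error> {
    // Explicit execution-provider list, as in `configure_execution_providers`
    // above; a `CUDAExecutionProvider` would be used here instead when the
    // `onnx-cuda` feature is enabled.
    let providers: Vec<ExecutionProviderDispatch> = vec![CPUExecutionProvider::default().build()];

    // Session construction mirrors `OnnxBackend::load` in the patch.
    let model = std::fs::read("model.onnx").expect("hypothetical model path");
    let mut session = Session::builder()?
        .with_execution_providers(providers)?
        .with_optimization_level(GraphOptimizationLevel::Level3)?
        .commit_from_memory(&model)?;

    // rc.10 drops the `inputs!` macro: build an owned tensor from a
    // `(shape, data)` pair and convert it into a `SessionInputValue`.
    let data = vec![0.0f32; 1 * 3 * 224 * 224]; // hypothetical input shape
    let dims = vec![1usize, 3, 224, 224];
    let input = OrtTensor::from_array((&dims[..], data.into_boxed_slice()))?;
    let inputs: Vec<SessionInputValue<'_>> = vec![input.into()];

    // `Session::run` now takes `&mut self` (hence the `let mut session`
    // bindings in the patch); output extraction moves from
    // `try_extract_raw_tensor` to the typed `try_extract_tensor::<T>()`,
    // which yields the shape alongside a borrowed `&[T]`.
    let outputs = session.run(&inputs[..])?;
    let (_shape, values) = outputs[0].try_extract_tensor::<f32>()?;
    println!("first output has {} f32 values", values.len());
    Ok(())
}
```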