@@ -322,154 +322,6 @@ impl Into<sys::GraphOptimizationLevel> for GraphOptimizationLevel {
322322 }
323323}
324324
325- // FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum
326- // FIXME: Add tests to cover the commented out types
327- /// Enum mapping ONNX Runtime's supported tensor types
328- #[ derive( Debug ) ]
329- #[ cfg_attr( not( windows) , repr( u32 ) ) ]
330- #[ cfg_attr( windows, repr( i32 ) ) ]
331- pub enum TensorElementDataType {
332- /// 32-bit floating point, equivalent to Rust's `f32`
333- Float = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT as OnnxEnumInt ,
334- /// Unsigned 8-bit int, equivalent to Rust's `u8`
335- Uint8 = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 as OnnxEnumInt ,
336- /// Signed 8-bit int, equivalent to Rust's `i8`
337- Int8 = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 as OnnxEnumInt ,
338- /// Unsigned 16-bit int, equivalent to Rust's `u16`
339- Uint16 = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 as OnnxEnumInt ,
340- /// Signed 16-bit int, equivalent to Rust's `i16`
341- Int16 = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 as OnnxEnumInt ,
342- /// Signed 32-bit int, equivalent to Rust's `i32`
343- Int32 = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt ,
344- /// Signed 64-bit int, equivalent to Rust's `i64`
345- Int64 = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt ,
346- /// String, equivalent to Rust's `String`
347- String = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt ,
348- // /// Boolean, equivalent to Rust's `bool`
349- // Bool = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL as OnnxEnumInt,
350- // /// 16-bit floating point, equivalent to Rust's `f16`
351- // Float16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 as OnnxEnumInt,
352- /// 64-bit floating point, equivalent to Rust's `f64`
353- Double = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE as OnnxEnumInt ,
354- /// Unsigned 32-bit int, equivalent to Rust's `u32`
355- Uint32 = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 as OnnxEnumInt ,
356- /// Unsigned 64-bit int, equivalent to Rust's `u64`
357- Uint64 = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 as OnnxEnumInt ,
358- // /// Complex 64-bit floating point, equivalent to Rust's `???`
359- // Complex64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 as OnnxEnumInt,
360- // /// Complex 128-bit floating point, equivalent to Rust's `???`
361- // Complex128 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 as OnnxEnumInt,
362- // /// Brain 16-bit floating point
363- // Bfloat16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 as OnnxEnumInt,
364- }
365-
366- impl Into < sys:: ONNXTensorElementDataType > for TensorElementDataType {
367- fn into ( self ) -> sys:: ONNXTensorElementDataType {
368- use TensorElementDataType :: * ;
369- match self {
370- Float => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT ,
371- Uint8 => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 ,
372- Int8 => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 ,
373- Uint16 => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 ,
374- Int16 => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 ,
375- Int32 => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 ,
376- Int64 => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 ,
377- String => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING ,
378- // Bool => {
379- // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL
380- // }
381- // Float16 => {
382- // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16
383- // }
384- Double => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE ,
385- Uint32 => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 ,
386- Uint64 => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 ,
387- // Complex64 => {
388- // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64
389- // }
390- // Complex128 => {
391- // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128
392- // }
393- // Bfloat16 => {
394- // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16
395- // }
396- }
397- }
398- }
399-
400- /// Trait used to map Rust types (for example `f32`) to ONNX types (for example `Float`)
401- pub trait TypeToTensorElementDataType {
402- /// Return the ONNX type for a Rust type
403- fn tensor_element_data_type ( ) -> TensorElementDataType ;
404-
405- /// If the type is `String`, returns `Some` with utf8 contents, else `None`.
406- fn try_utf8_bytes ( & self ) -> Option < & [ u8 ] > ;
407- }
408-
409- macro_rules! impl_type_trait {
410- ( $type_: ty, $variant: ident) => {
411- impl TypeToTensorElementDataType for $type_ {
412- fn tensor_element_data_type( ) -> TensorElementDataType {
413- // unsafe { std::mem::transmute(TensorElementDataType::$variant) }
414- TensorElementDataType :: $variant
415- }
416-
417- fn try_utf8_bytes( & self ) -> Option <& [ u8 ] > {
418- None
419- }
420- }
421- } ;
422- }
423-
424- impl_type_trait ! ( f32 , Float ) ;
425- impl_type_trait ! ( u8 , Uint8 ) ;
426- impl_type_trait ! ( i8 , Int8 ) ;
427- impl_type_trait ! ( u16 , Uint16 ) ;
428- impl_type_trait ! ( i16 , Int16 ) ;
429- impl_type_trait ! ( i32 , Int32 ) ;
430- impl_type_trait ! ( i64 , Int64 ) ;
431- // impl_type_trait!(bool, Bool);
432- // impl_type_trait!(f16, Float16);
433- impl_type_trait ! ( f64 , Double ) ;
434- impl_type_trait ! ( u32 , Uint32 ) ;
435- impl_type_trait ! ( u64 , Uint64 ) ;
436- // impl_type_trait!(, Complex64);
437- // impl_type_trait!(, Complex128);
438- // impl_type_trait!(, Bfloat16);
439-
440- /// Adapter for common Rust string types to Onnx strings.
441- ///
442- /// It should be easy to use both `String` and `&str` as [TensorElementDataType::String] data, but
443- /// we can't define an automatic implementation for anything that implements `AsRef<str>` as it
444- /// would conflict with the implementations of [TypeToTensorElementDataType] for primitive numeric
445- /// types (which might implement `AsRef<str>` at some point in the future).
446- pub trait Utf8Data {
447- /// Returns the utf8 contents.
448- fn utf8_bytes ( & self ) -> & [ u8 ] ;
449- }
450-
451- impl Utf8Data for String {
452- fn utf8_bytes ( & self ) -> & [ u8 ] {
453- self . as_bytes ( )
454- }
455- }
456-
457- impl < ' a > Utf8Data for & ' a str {
458- fn utf8_bytes ( & self ) -> & [ u8 ] {
459- self . as_bytes ( )
460- }
461- }
462-
463- impl < T : Utf8Data > TypeToTensorElementDataType for T {
464- fn tensor_element_data_type ( ) -> TensorElementDataType {
465- TensorElementDataType :: String
466- }
467-
468- fn try_utf8_bytes ( & self ) -> Option < & [ u8 ] > {
469- Some ( self . utf8_bytes ( ) )
470- }
471- }
472-
473325/// Allocator type
474326#[ derive( Debug , Clone ) ]
475327#[ repr( i32 ) ]
0 commit comments