// tfledge/tensor.rs

use core::alloc::Layout;
use core::marker::PhantomData;
use core::mem::size_of;

use alloc::alloc::{alloc, alloc_zeroed, dealloc, handle_alloc_error};
use alloc::vec::Vec;

use crate::ffi::*;
use crate::Error;
/// Marker trait distinguishing input tensors from output tensors.
///
/// Only [`Input`] and [`Output`] implement this; it is used as the `IO`
/// type parameter of `Tensor` so that read/write capability is enforced
/// at compile time.
pub(crate) trait TensorInputOrOutput {}

/// Marker struct for input tensors
pub struct Input;

/// Marker struct for output tensors
pub struct Output;

impl TensorInputOrOutput for Input {}
impl TensorInputOrOutput for Output {}
21
22/// A TensorFlow Lite tensor
23///
24/// This structure represents a tensor in a TensorFlow Lite model. It can be either an input
25/// tensor or an output tensor, depending on the type parameter `IO`.
26///
27/// # Examples
28///
29/// ```
30/// # use tfledge::{Interpreter, Model, Error, list_devices, Tensor};
31/// # fn main() -> Result<(), Error> {
32/// # let model = Model::from_file("model.tflite")?;
33/// # let device = list_devices().next().unwrap();
34/// # let mut interpreter = Interpreter::new(model, device)?;
35/// // Get the first input tensor as a tensor of f32 values
36/// let input_tensor: Tensor<Input, f32> = interpreter.input_tensor(0);
37///
38/// // Check the data type of the tensor
39/// assert_eq!(input_tensor.kind(), crate::ffi::TfLiteType::kTfLiteFloat32);
40/// # Ok(())
41/// # }
42/// ```
43pub struct Tensor<IO, T>
44where
45    IO: TensorInputOrOutput,
46    T: InnerTensorData,
47{
48    pub(crate) ptr: *mut TfLiteTensor,
49    pub(crate) _marker: PhantomData<(IO, T)>,
50}
51impl<IO, T> Tensor<IO, T>
52where
53    IO: TensorInputOrOutput,
54    T: InnerTensorData,
55{
56    /// Data type of tensor
57    pub fn kind(&self) -> TfLiteType {
58        unsafe { TfLiteTensorType(self.ptr) }
59    }
60    /// Number of dimensions ([None] if opaque)
61    pub fn num_dims(&self) -> Option<u32> {
62        let i = unsafe { TfLiteTensorNumDims(self.ptr) };
63
64        if i == -1 {
65            return None;
66        }
67
68        Some(i as u32)
69    }
70    /// Length of tensor for a given dimension
71    pub fn dim(&self, id: u32) -> i32 {
72        unsafe { TfLiteTensorDim(self.ptr, id as i32) }
73    }
74}
75// read is lower cost than write
76// TODO: Need to specify that here somewhere
77impl<T: InnerTensorData> Tensor<Input, T> {
78    /// Write data to the tensor.
79    ///
80    /// # Arguments
81    ///
82    /// * `data` - The data to write to the tensor.
83    ///
84    /// # Errors
85    ///
86    /// Returns an error if the data cannot be written to the tensor.
87    pub fn write(&mut self, data: &[u8]) -> Result<(), Error> {
88        unsafe {
89            let ret = TfLiteTensorCopyFromBuffer(self.ptr, data.as_ptr() as *const _, data.len());
90
91            Error::from(ret)
92        }
93    }
94}
95impl<T: InnerTensorData> Tensor<Output, T> {
96    /// Read data from the tensor.
97    ///
98    /// # Type Parameters
99    ///
100    /// * `const N: usize` - The number of elements to read from the tensor at a time. For
101    ///   example, if `N` is 4, then the method will return a vector of 4-element arrays.
102    pub fn read<const N: usize>(&self) -> Vec<[T; N]> {
103        unsafe {
104            // Calculate the number of chunks to read from the tensor
105            let ct = TfLiteTensorByteSize(self.ptr) / size_of::<[T; N]>();
106
107            // Get a pointer to the tensor data
108            let ptr = TfLiteTensorData(self.ptr);
109
110            // Read the data from the tensor
111            core::slice::from_raw_parts::<[T; N]>(ptr as *const _, ct).to_vec()
112        }
113    }
114}
115
/// Scalar types that can back a TensorFlow Lite tensor.
///
/// Maps a Rust numeric type to its TensorFlow Lite type tag. `Copy` is
/// required so tensor contents can be duplicated out of FFI-owned memory
/// without drop glue (see `Tensor::read`).
pub(crate) trait InnerTensorData: Clone + Copy + Sized {
    /// The `TfLiteType` tag corresponding to `Self`.
    const TFLITE_KIND: TfLiteType;
}
119
120impl InnerTensorData for i8 {
121    const TFLITE_KIND: TfLiteType = TfLiteType::kTfLiteInt8;
122}
123impl InnerTensorData for i16 {
124    const TFLITE_KIND: TfLiteType = TfLiteType::kTfLiteInt16;
125}
126impl InnerTensorData for i32 {
127    const TFLITE_KIND: TfLiteType = TfLiteType::kTfLiteInt32;
128}
129impl InnerTensorData for i64 {
130    const TFLITE_KIND: TfLiteType = TfLiteType::kTfLiteInt64;
131}
132
133impl InnerTensorData for u8 {
134    const TFLITE_KIND: TfLiteType = TfLiteType::kTfLiteUInt8;
135}
136impl InnerTensorData for u16 {
137    const TFLITE_KIND: TfLiteType = TfLiteType::kTfLiteUInt16;
138}
139impl InnerTensorData for u32 {
140    const TFLITE_KIND: TfLiteType = TfLiteType::kTfLiteUInt32;
141}
142impl InnerTensorData for u64 {
143    const TFLITE_KIND: TfLiteType = TfLiteType::kTfLiteUInt64;
144}
145
146impl InnerTensorData for f32 {
147    const TFLITE_KIND: TfLiteType = TfLiteType::kTfLiteFloat32;
148}
149impl InnerTensorData for f64 {
150    const TFLITE_KIND: TfLiteType = TfLiteType::kTfLiteFloat64;
151}
152
153pub struct TensorData<T: InnerTensorData> {
154    size: usize,
155    ptr: *mut T,
156}
157impl<T: InnerTensorData> TensorData<T> {
158    /// Create a new [`TensorData`] buffer
159    ///
160    /// # Arguments
161    ///
162    /// * `size` - The size of the buffer in bytes.
163    pub fn new(size: usize) -> Self {
164        unsafe {
165            let ptr = alloc(Layout::array::<T>(size).unwrap()) as *mut T;
166
167            Self { ptr, size }
168        }
169    }
170    /// Tensor data as a raw slice
171    pub fn as_slice(&self) -> &[T] {
172        unsafe { core::slice::from_raw_parts(self.ptr.cast_const(), self.size) }
173    }
174}