tfledge/
interpreter.rs

1use core::ptr::null_mut;
2
3use crate::error::error_reporter;
4use crate::tensor::InnerTensorData;
5use crate::{CoralDevice, Error, Input, Model, Output, Tensor};
6use core::marker::PhantomData;
7
8use crate::ffi::*;
9
10/// The core structure for making inferences with TFLite
11///
12/// # Examples
13///
14/// ```
15/// # use tfledge::{Interpreter, Model, Error, list_devices};
16/// # fn main() -> Result<(), Error> {
17/// // Load the TFLite model
18/// let model = Model::from_file("model.tflite")?;
19///
20/// // Get a Coral device
21/// let device = list_devices().next().unwrap();
22///
23/// // Create a new interpreter
24/// let mut interpreter = Interpreter::new(model, device)?;
25///
26/// // ... perform inference ...
27///
28/// # Ok(())
29/// # }
30/// ```
31pub struct Interpreter {
32    ptr: *mut TfLiteInterpreter,
33}
34impl Interpreter {
35    /// Create a new [Interpreter] and allocate tensors for a given [Model]
36    ///
37    /// # Arguments
38    ///
39    /// * `model` - The TensorFlow Lite model to use for inference.
40    /// * `dev` - The Coral device to use for acceleration.
41    ///
42    /// # Errors
43    ///
44    /// Returns an error if the interpreter cannot be created or if the tensors cannot be
45    /// allocated.
46    pub fn new(model: Model, dev: CoralDevice) -> Result<Self, Error> {
47        unsafe {
48            // Build the interpreter options
49            let opts: *mut TfLiteInterpreterOptions = TfLiteInterpreterOptionsCreate();
50            TfLiteInterpreterOptionsSetErrorReporter(opts, Some(error_reporter), null_mut());
51            TfLiteInterpreterOptionsAddDelegate(opts, dev.create_delegate());
52
53            // Create the interpreter
54            let ptr = TfLiteInterpreterCreate(model.ptr, opts);
55            if ptr.is_null() {
56                return Err(Error::FailedToCreateInterpreter);
57            }
58
59            // Allocate tensors
60            Self::allocate_tensors(&mut Self { ptr })?;
61
62            Ok(Self { ptr })
63        }
64    }
65
66    /// Allocate tensors for the interpreter
67    ///
68    /// This method is called by [`Interpreter::new`] to allocate the tensors required by the
69    /// model.
70    fn allocate_tensors(&mut self) -> Result<(), Error> {
71        unsafe {
72            let ret = TfLiteInterpreterAllocateTensors(self.ptr);
73            Error::from(ret)
74        }
75    }
76
77    /// Get an input tensor
78    ///
79    /// # Arguments
80    ///
81    /// * `id` - The index of the input tensor to get.
82    ///
83    /// # Type Parameters
84    ///
85    /// * `T` - The data type of the tensor. Must implement [`InnerTensorData`].
86    pub fn input_tensor<T: InnerTensorData>(&self, id: u32) -> Tensor<Input, T> {
87        unsafe {
88            let ptr = TfLiteInterpreterGetInputTensor(self.ptr, id as i32);
89
90            Tensor::<Input, T> {
91                ptr,
92                _marker: PhantomData,
93            }
94        }
95    }
96
97    /// Get an output tensor
98    ///
99    /// # Arguments
100    ///
101    /// * `id` - The index of the output tensor to get.
102    ///
103    /// # Type Parameters
104    ///
105    /// * `T` - The data type of the tensor. Must implement [`InnerTensorData`].
106    pub fn output_tensor<T: InnerTensorData>(&self, id: u32) -> Tensor<Output, T> {
107        unsafe {
108            let ptr = TfLiteInterpreterGetOutputTensor(self.ptr, id as i32);
109
110            Tensor::<Output, T> {
111                ptr: ptr as *mut _,
112                _marker: PhantomData,
113            }
114        }
115    }
116
117    /// Run inference
118    ///
119    /// This basically just processes data from the input tensors, using the model, into the ourput
120    /// tensors.
121    ///
122    /// # Errors
123    ///
124    /// Returns an error if inference fails.
125    pub fn invoke(&mut self) -> Result<(), Error> {
126        unsafe {
127            let ret = TfLiteInterpreterInvoke(self.ptr);
128
129            Error::from(ret)
130        }
131    }
132}
133impl Drop for Interpreter {
134    fn drop(&mut self) {
135        unsafe {
136            TfLiteInterpreterDelete(self.ptr);
137        }
138    }
139}