tfledge/interpreter.rs
use core::ptr::null_mut;

use crate::error::error_reporter;
use crate::tensor::InnerTensorData;
use crate::{CoralDevice, Error, Input, Model, Output, Tensor};
use core::marker::PhantomData;

use crate::ffi::*;

/// The core structure for making inferences with TFLite
///
/// # Examples
///
/// ```
/// # use tfledge::{Interpreter, Model, Error, list_devices};
/// # fn main() -> Result<(), Error> {
/// // Load the TFLite model
/// let model = Model::from_file("model.tflite")?;
///
/// // Get a Coral device
/// let device = list_devices().next().unwrap();
///
/// // Create a new interpreter
/// let mut interpreter = Interpreter::new(model, device)?;
///
/// // ... perform inference ...
///
/// # Ok(())
/// # }
/// ```
pub struct Interpreter {
    ptr: *mut TfLiteInterpreter,
}
impl Interpreter {
    /// Create a new [Interpreter] and allocate tensors for a given [Model]
    ///
    /// # Arguments
    ///
    /// * `model` - The TensorFlow Lite model to use for inference.
    /// * `dev` - The Coral device to use for acceleration.
    ///
    /// # Errors
    ///
    /// Returns an error if the interpreter cannot be created or if the tensors cannot be
    /// allocated.
    pub fn new(model: Model, dev: CoralDevice) -> Result<Self, Error> {
        unsafe {
            // Build the interpreter options
            let opts: *mut TfLiteInterpreterOptions = TfLiteInterpreterOptionsCreate();
            TfLiteInterpreterOptionsSetErrorReporter(opts, Some(error_reporter), null_mut());
            TfLiteInterpreterOptionsAddDelegate(opts, dev.create_delegate());

            // Create the interpreter
            let ptr = TfLiteInterpreterCreate(model.ptr, opts);
            if ptr.is_null() {
                return Err(Error::FailedToCreateInterpreter);
            }

            // Wrap the pointer and allocate tensors on the same value that is
            // returned to the caller, so `Drop` only ever runs on that value
            let mut interpreter = Self { ptr };
            interpreter.allocate_tensors()?;

            Ok(interpreter)
        }
    }

    /// Allocate tensors for the interpreter
    ///
    /// This method is called by [`Interpreter::new`] to allocate the tensors required by the
    /// model.
    fn allocate_tensors(&mut self) -> Result<(), Error> {
        unsafe {
            let ret = TfLiteInterpreterAllocateTensors(self.ptr);
            Error::from(ret)
        }
    }

    /// Get an input tensor
    ///
    /// # Arguments
    ///
    /// * `id` - The index of the input tensor to get.
    ///
    /// # Type Parameters
    ///
    /// * `T` - The data type of the tensor. Must implement [`InnerTensorData`].
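    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming the model's first input holds `f32` data
    /// (i.e. that `f32` implements [`InnerTensorData`]):
    ///
    /// ```
    /// # use tfledge::{Interpreter, Model, Error, Input, Tensor, list_devices};
    /// # fn main() -> Result<(), Error> {
    /// # let model = Model::from_file("model.tflite")?;
    /// # let device = list_devices().next().unwrap();
    /// # let interpreter = Interpreter::new(model, device)?;
    /// // Fetch the first input tensor, typed as `f32`
    /// let input: Tensor<Input, f32> = interpreter.input_tensor(0);
    /// # Ok(())
    /// # }
    /// ```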
    pub fn input_tensor<T: InnerTensorData>(&self, id: u32) -> Tensor<Input, T> {
        unsafe {
            let ptr = TfLiteInterpreterGetInputTensor(self.ptr, id as i32);

            Tensor::<Input, T> {
                ptr,
                _marker: PhantomData,
            }
        }
    }

    /// Get an output tensor
    ///
    /// # Arguments
    ///
    /// * `id` - The index of the output tensor to get.
    ///
    /// # Type Parameters
    ///
    /// * `T` - The data type of the tensor. Must implement [`InnerTensorData`].
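    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming the model's first output holds `f32` data
    /// (i.e. that `f32` implements [`InnerTensorData`]):
    ///
    /// ```
    /// # use tfledge::{Interpreter, Model, Error, Output, Tensor, list_devices};
    /// # fn main() -> Result<(), Error> {
    /// # let model = Model::from_file("model.tflite")?;
    /// # let device = list_devices().next().unwrap();
    /// # let interpreter = Interpreter::new(model, device)?;
    /// // Fetch the first output tensor, typed as `f32`
    /// let output: Tensor<Output, f32> = interpreter.output_tensor(0);
    /// # Ok(())
    /// # }
    /// ```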
    pub fn output_tensor<T: InnerTensorData>(&self, id: u32) -> Tensor<Output, T> {
        unsafe {
            // The C API hands back a const pointer for output tensors, so cast
            // it to share the same representation as input tensors
            let ptr = TfLiteInterpreterGetOutputTensor(self.ptr, id as i32);

            Tensor::<Output, T> {
                ptr: ptr as *mut _,
                _marker: PhantomData,
            }
        }
    }

    /// Run inference
    ///
    /// Processes data from the input tensors, through the model, into the output
    /// tensors.
    ///
    /// # Errors
    ///
    /// Returns an error if inference fails.
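    ///
    /// # Examples
    ///
    /// A minimal sketch of the inference loop (filling and reading the tensors is
    /// omitted):
    ///
    /// ```
    /// # use tfledge::{Interpreter, Model, Error, list_devices};
    /// # fn main() -> Result<(), Error> {
    /// # let model = Model::from_file("model.tflite")?;
    /// # let device = list_devices().next().unwrap();
    /// # let mut interpreter = Interpreter::new(model, device)?;
    /// // ... write data into the input tensors ...
    /// interpreter.invoke()?;
    /// // ... read results from the output tensors ...
    /// # Ok(())
    /// # }
    /// ```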
    pub fn invoke(&mut self) -> Result<(), Error> {
        unsafe {
            let ret = TfLiteInterpreterInvoke(self.ptr);

            Error::from(ret)
        }
    }
}
impl Drop for Interpreter {
    fn drop(&mut self) {
        unsafe {
            TfLiteInterpreterDelete(self.ptr);
        }
    }
}