tfledge/
model.rs

1use core::ptr::null_mut;
2
3use crate::error::error_reporter;
4use crate::ffi::*;
5use crate::Error;
6
7use alloc::string::String;
8
9/// A machine learning model which can be loaded onto a device to make inferences
10pub struct Model {
11    pub(crate) ptr: *mut TfLiteModel,
12}
13impl Model {
14    /// Load a model from a byte slice
15    ///
16    /// # Arguments
17    ///
18    /// * `bytes` - A byte slice containing the raw model data.
19    ///
20    /// # Errors
21    ///
22    /// Returns an error if the model cannot be loaded from the provided byte slice. This
23    /// usually means the data is invalid or corrupted.
24    ///
25    /// # Examples
26    ///
27    /// ```
28    /// # use tfledge::{Model, Error};
29    /// # fn main() -> Result<(), Error> {
30    /// let model_data: &[u8] = include_bytes!("model.tflite"); // Replace with your model path
31    /// let model = Model::from_bytes(model_data)?;
32    /// # Ok(())
33    /// # }
34    /// ```
35    pub fn from_bytes(bytes: &[u8]) -> Result<Self, Error> {
36        unsafe {
37            // Create the model, passing in the error reporter
38            let ptr = TfLiteModelCreateWithErrorReporter(
39                bytes.as_ptr() as *const _,
40                bytes.len(),
41                Some(error_reporter),
42                null_mut(),
43            );
44
45            // Check if the model creation was successful
46            if ptr.is_null() {
47                return Err(Error::FailedToLoadModel);
48            }
49
50            Ok(Self { ptr })
51        }
52    }
53    /// Load a model from a file
54    ///
55    /// # Arguments
56    ///
57    /// * `path` - A string slice representing the path to the model file.
58    ///
59    /// # Errors
60    ///
61    /// Returns an error if the model cannot be loaded from the provided file path. This could
62    /// be due to a file not being found, invalid permissions, or a corrupted model file.
63    ///
64    /// # Examples
65    ///
66    /// ```
67    /// # use tfledge::{Model, Error};
68    /// # fn main() -> Result<(), Error> {
69    /// let model = Model::from_file("model.tflite")?;
70    /// # Ok(())
71    /// # }
72    /// ```
73    pub fn from_file(path: impl Into<String>) -> Result<Self, Error> {
74        unsafe {
75            // Convert the path to a C string
76            let path = path.into();
77            let cpath = [path.as_bytes(), b"\0"].concat();
78
79            // Create the model, passing in the error reporter
80            let ptr = TfLiteModelCreateFromFileWithErrorReporter(
81                cpath.as_ptr() as *const _,
82                Some(error_reporter),
83                null_mut(),
84            );
85
86            // Check if the model creation was successful
87            if ptr.is_null() {
88                return Err(Error::FailedToLoadModel);
89            }
90
91            Ok(Self { ptr })
92        }
93    }
94}
95impl Drop for Model {
96    fn drop(&mut self) {
97        unsafe {
98            TfLiteModelDelete(self.ptr);
99        }
100    }
101}