// tfledge/model.rs
use alloc::string::String;
use alloc::vec::Vec;
use core::ptr::null_mut;

use crate::error::error_reporter;
use crate::ffi::*;
use crate::Error;

9/// A machine learning model which can be loaded onto a device to make inferences
10pub struct Model {
11 pub(crate) ptr: *mut TfLiteModel,
12}
13impl Model {
14 /// Load a model from a byte slice
15 ///
16 /// # Arguments
17 ///
18 /// * `bytes` - A byte slice containing the raw model data.
19 ///
20 /// # Errors
21 ///
22 /// Returns an error if the model cannot be loaded from the provided byte slice. This
23 /// usually means the data is invalid or corrupted.
24 ///
25 /// # Examples
26 ///
27 /// ```
28 /// # use tfledge::{Model, Error};
29 /// # fn main() -> Result<(), Error> {
30 /// let model_data: &[u8] = include_bytes!("model.tflite"); // Replace with your model path
31 /// let model = Model::from_bytes(model_data)?;
32 /// # Ok(())
33 /// # }
34 /// ```
35 pub fn from_bytes(bytes: &[u8]) -> Result<Self, Error> {
36 unsafe {
37 // Create the model, passing in the error reporter
38 let ptr = TfLiteModelCreateWithErrorReporter(
39 bytes.as_ptr() as *const _,
40 bytes.len(),
41 Some(error_reporter),
42 null_mut(),
43 );
44
45 // Check if the model creation was successful
46 if ptr.is_null() {
47 return Err(Error::FailedToLoadModel);
48 }
49
50 Ok(Self { ptr })
51 }
52 }
53 /// Load a model from a file
54 ///
55 /// # Arguments
56 ///
57 /// * `path` - A string slice representing the path to the model file.
58 ///
59 /// # Errors
60 ///
61 /// Returns an error if the model cannot be loaded from the provided file path. This could
62 /// be due to a file not being found, invalid permissions, or a corrupted model file.
63 ///
64 /// # Examples
65 ///
66 /// ```
67 /// # use tfledge::{Model, Error};
68 /// # fn main() -> Result<(), Error> {
69 /// let model = Model::from_file("model.tflite")?;
70 /// # Ok(())
71 /// # }
72 /// ```
73 pub fn from_file(path: impl Into<String>) -> Result<Self, Error> {
74 unsafe {
75 // Convert the path to a C string
76 let path = path.into();
77 let cpath = [path.as_bytes(), b"\0"].concat();
78
79 // Create the model, passing in the error reporter
80 let ptr = TfLiteModelCreateFromFileWithErrorReporter(
81 cpath.as_ptr() as *const _,
82 Some(error_reporter),
83 null_mut(),
84 );
85
86 // Check if the model creation was successful
87 if ptr.is_null() {
88 return Err(Error::FailedToLoadModel);
89 }
90
91 Ok(Self { ptr })
92 }
93 }
94}
95impl Drop for Model {
96 fn drop(&mut self) {
97 unsafe {
98 TfLiteModelDelete(self.ptr);
99 }
100 }
101}