use alloc::alloc::{alloc, dealloc, handle_alloc_error};
use alloc::vec::Vec;
use core::alloc::Layout;
use core::marker::PhantomData;
use core::mem::size_of;

use crate::ffi::*;
use crate::Error;
10
/// Marker trait for the direction parameter `IO` of `Tensor`.
///
/// In this module it is implemented by [`Input`] and [`Output`].
pub(crate) trait TensorInputOrOutput {}

/// Direction marker for tensors whose data is written by the caller (see `Tensor::write`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct Input;

impl TensorInputOrOutput for Input {}

/// Direction marker for tensors whose data is read by the caller (see `Tensor::read`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct Output;

impl TensorInputOrOutput for Output {}
21
22pub struct Tensor<IO, T>
44where
45 IO: TensorInputOrOutput,
46 T: InnerTensorData,
47{
48 pub(crate) ptr: *mut TfLiteTensor,
49 pub(crate) _marker: PhantomData<(IO, T)>,
50}
51impl<IO, T> Tensor<IO, T>
52where
53 IO: TensorInputOrOutput,
54 T: InnerTensorData,
55{
56 pub fn kind(&self) -> TfLiteType {
58 unsafe { TfLiteTensorType(self.ptr) }
59 }
60 pub fn num_dims(&self) -> Option<u32> {
62 let i = unsafe { TfLiteTensorNumDims(self.ptr) };
63
64 if i == -1 {
65 return None;
66 }
67
68 Some(i as u32)
69 }
70 pub fn dim(&self, id: u32) -> i32 {
72 unsafe { TfLiteTensorDim(self.ptr, id as i32) }
73 }
74}
75impl<T: InnerTensorData> Tensor<Input, T> {
78 pub fn write(&mut self, data: &[u8]) -> Result<(), Error> {
88 unsafe {
89 let ret = TfLiteTensorCopyFromBuffer(self.ptr, data.as_ptr() as *const _, data.len());
90
91 Error::from(ret)
92 }
93 }
94}
95impl<T: InnerTensorData> Tensor<Output, T> {
96 pub fn read<const N: usize>(&self) -> Vec<[T; N]> {
103 unsafe {
104 let ct = TfLiteTensorByteSize(self.ptr) / size_of::<[T; N]>();
106
107 let ptr = TfLiteTensorData(self.ptr);
109
110 core::slice::from_raw_parts::<[T; N]>(ptr as *const _, ct).to_vec()
112 }
113 }
114}
115
116pub(crate) trait InnerTensorData: Clone + Copy + Sized {
117 const TFLITE_KIND: TfLiteType;
118}
119
120impl InnerTensorData for i8 {
121 const TFLITE_KIND: TfLiteType = TfLiteType::kTfLiteInt8;
122}
123impl InnerTensorData for i16 {
124 const TFLITE_KIND: TfLiteType = TfLiteType::kTfLiteInt16;
125}
126impl InnerTensorData for i32 {
127 const TFLITE_KIND: TfLiteType = TfLiteType::kTfLiteInt32;
128}
129impl InnerTensorData for i64 {
130 const TFLITE_KIND: TfLiteType = TfLiteType::kTfLiteInt64;
131}
132
133impl InnerTensorData for u8 {
134 const TFLITE_KIND: TfLiteType = TfLiteType::kTfLiteUInt8;
135}
136impl InnerTensorData for u16 {
137 const TFLITE_KIND: TfLiteType = TfLiteType::kTfLiteUInt16;
138}
139impl InnerTensorData for u32 {
140 const TFLITE_KIND: TfLiteType = TfLiteType::kTfLiteUInt32;
141}
142impl InnerTensorData for u64 {
143 const TFLITE_KIND: TfLiteType = TfLiteType::kTfLiteUInt64;
144}
145
146impl InnerTensorData for f32 {
147 const TFLITE_KIND: TfLiteType = TfLiteType::kTfLiteFloat32;
148}
149impl InnerTensorData for f64 {
150 const TFLITE_KIND: TfLiteType = TfLiteType::kTfLiteFloat64;
151}
152
153pub struct TensorData<T: InnerTensorData> {
154 size: usize,
155 ptr: *mut T,
156}
157impl<T: InnerTensorData> TensorData<T> {
158 pub fn new(size: usize) -> Self {
164 unsafe {
165 let ptr = alloc(Layout::array::<T>(size).unwrap()) as *mut T;
166
167 Self { ptr, size }
168 }
169 }
170 pub fn as_slice(&self) -> &[T] {
172 unsafe { core::slice::from_raw_parts(self.ptr.cast_const(), self.size) }
173 }
174}