chalkydri_apriltags/
lib.rs

1#![feature(
2    portable_simd,
3    alloc_layout_extra,
4    slice_as_chunks,
5    sync_unsafe_cell,
6    array_chunks
7)]
8#![warn(clippy::infinite_loop)]
9
10#[macro_use]
11extern crate statrs;
12#[cfg(feature = "multi-thread")]
13extern crate rayon;
14
15//mod decode;
16// mod pose_estimation;
17pub mod utils;
18
19use libblur::{BlurImage, BlurImageMut};
20use nalgebra::Vector;
21use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
22// use pose_estimation::pose_estimation;
23use ril::{Line, Rgb};
24use statrs::statistics::{Data, Distribution, Max, Median, Min, OrderStatistics};
25// TODO: ideally we'd use alloc here and only pull in libstd for sync::atomic when the multi-thread feature is enabled
26use std::{
27    alloc::{alloc, alloc_zeroed, dealloc, Layout},
28    sync::{
29        atomic::{AtomicUsize, Ordering},
30        Arc, Mutex,
31    },
32    time::Instant,
33};
34
35#[cfg(feature = "multi-thread")]
36use rayon::iter::{ParallelBridge, ParallelIterator};
37
38use crate::utils::*;
39
/// Union-Find (disjoint-set) data structure for connected components
///
/// Elements are identified by indices in `0..len`. Sets are merged by size
/// (the smaller tree is attached under the larger one) and [`find`](Self::find)
/// compresses paths, so both operations are effectively constant time.
///
/// Storage is owned by `Vec`s, which makes `Clone` a deep copy and lets the
/// compiler generate a correct destructor.
#[derive(Debug, Clone)]
pub struct UnionFind {
    /// `parent[i]` is the parent of element `i`; a root is its own parent.
    parent: Vec<usize>,
    /// `cluster_sizes[r]` is the number of elements in the set rooted at `r`.
    /// Entries for non-root elements are stale and must not be relied upon.
    cluster_sizes: Vec<usize>,
}

impl UnionFind {
    /// Create `len` singleton sets, one per index in `0..len`.
    pub fn new(len: usize) -> Self {
        Self {
            parent: (0..len).collect(),
            cluster_sizes: vec![1; len],
        }
    }

    /// Return the representative (root) of the set containing `id`.
    ///
    /// Every element visited on the way to the root is re-pointed directly at
    /// the root (path compression).
    ///
    /// # Panics
    /// Panics if `id` is not in `0..len`.
    #[inline]
    pub fn find(&mut self, id: usize) -> usize {
        // First pass: walk up to the root.
        let mut root = id;
        while self.parent[root] != root {
            root = self.parent[root];
        }

        // Second pass: point every node on the path straight at the root.
        // Done iteratively so deep chains cannot exhaust the stack.
        let mut curr = id;
        while curr != root {
            curr = std::mem::replace(&mut self.parent[curr], root);
        }

        root
    }

    /// Merge the sets containing `id1` and `id2`.
    ///
    /// The smaller set is attached under the root of the larger one; on a tie
    /// the root of `id1` becomes the root of the merged set.
    ///
    /// # Panics
    /// Panics if either id is not in `0..len`.
    #[inline]
    pub fn union(&mut self, id1: usize, id2: usize) {
        let root1 = self.find(id1);
        let root2 = self.find(id2);

        if root1 == root2 {
            return;
        }

        let (size1, size2) = (self.get_size(root1), self.get_size(root2));
        if size1 < size2 {
            self.parent[root1] = root2;
            self.cluster_sizes[root2] += size1;
        } else {
            self.parent[root2] = root1;
            self.cluster_sizes[root1] += size2;
        }
    }

    /// Return the size recorded for `id`.
    ///
    /// This is the size of the whole set only when `id` is a root, i.e. a
    /// value returned by [`find`](Self::find).
    ///
    /// # Panics
    /// Panics if `id` is not in `0..len`.
    #[inline]
    pub fn get_size(&self, id: usize) -> usize {
        self.cluster_sizes[id]
    }
}
114
/// A hashed cluster entry: an id, its hash, and the points collected for it.
///
/// NOTE(review): nothing in the active code of this file constructs or reads
/// this type yet; the field descriptions below are inferred from the
/// commented-out clustering code and should be confirmed.
#[derive(Debug, Default, Clone)]
struct ClusterHash {
    /// Hash of `id` (presumably produced by `u64hash_2` — confirm).
    hash: u32,
    /// Identifier of the cluster.
    id: u64,
    /// Points belonging to the cluster.
    data: Vec<(usize, usize)>,
}
121
/// Fold a 64-bit value into 32 bits by XOR-ing its upper and lower halves.
fn u64hash_2(x: u64) -> u32 {
    let upper = (x >> 32) as u32;
    let lower = x as u32;
    upper ^ lower
}
125
126/// Raw buffers used by a [`detector`](Detector)
127///
128/// We need a separate struct for this so the compiler will treat them as thread-safe.
129/// Interacting with raw buffers is typically lower overhead, but unsafe.
130struct DetectorBufs {
131    /// The thresholded image buffer
132    buf: *mut Color,
133    /// Detected corners
134    points: *mut (usize, usize),
135}
136unsafe impl Send for DetectorBufs {}
137unsafe impl Sync for DetectorBufs {}
138
139/// An AprilTag detector
140///
141/// This is the main entrypoint.
142pub struct Detector {
143    /// Raw buffers used by the detector
144    bufs: DetectorBufs,
145    valid_tags: &'static [usize],
146    points_len: AtomicUsize,
147    /// Checked edges (x1, y1, x2, y2)
148    lines: Vec<(usize, usize, usize, usize)>,
149    /// Width of input frames
150    width: usize,
151    /// Height of input frames
152    height: usize,
153}
154impl Detector {
155    /// Initialize a new detector for the specified dimensions
156    ///
157    /// `valid_tags` is required for optimization and error resistance.
158    pub fn new(
159        width: usize,
160        height: usize,
161        valid_tags: &'static [usize],
162        //intrinsics: IntrinsicParametersPerspective<f32>,
163    ) -> Self {
164        unsafe {
165            // Allocate raw buffers
166            let buf: *mut Color =
167                alloc_zeroed(Layout::array::<Color>(width * height).unwrap()).cast();
168            let points: *mut (usize, usize) =
169                alloc_zeroed(Layout::array::<(usize, usize)>(width * height).unwrap()).cast();
170            let points_len = AtomicUsize::new(0);
171
172            Self {
173                bufs: DetectorBufs { buf, points },
174                valid_tags,
175                points_len,
176                lines: Vec::new(),
177                width,
178                height,
179            }
180        }
181    }
182
183    /// Calculate otsu value
184    ///
185    /// [Otsu's method](https://en.wikipedia.org/wiki/Otsu%27s_method) is an adaptive thresholding
186    /// algorithm. In English: it turns a grayscale image into binary (foreground/background,
187    /// black/white).
188    ///
189    /// We should investigate combining the variations for unbalanced images and triclass
190    /// thresholding.
191    pub fn calc_otsu(&mut self, input: &mut [u8]) {
192        //libblur::fast_gaussian(&mut BlurImageMut::borrow(input, self.width as u32, self.height as u32, libblur::FastBlurChannels::Channels3), 7, libblur::ThreadingPolicy::Adaptive, libblur::EdgeMode::Clamp).unwrap();
193
194        // Calculate histogram
195        for y in 0..self.height {
196            for x in 0..self.width {
197                let mut pixels = Vec::new();
198
199                const BLOCK_SIZE: usize = 5;
200                const C: u8 = 4;
201
202                unsafe {
203                    let x_min = x.saturating_sub(BLOCK_SIZE.saturating_div(2));
204                    let x_max = x
205                        .unchecked_add(BLOCK_SIZE.saturating_div(2))
206                        .min(self.width - 1);
207                    let y_min = y.saturating_sub(BLOCK_SIZE.saturating_div(2));
208                    let y_max = y
209                        .saturating_add(BLOCK_SIZE.saturating_div(2))
210                        .min(self.height - 1);
211
212                    //let mut total = 0u16;
213                    for x in x_min..=x_max {
214                        for y in y_min..=y_max {
215                            // Red, green, and blue are each represent with 1 byte
216                            let i = px(x, y, self.width);
217                            let gray = grayscale(input.get_unchecked((i * 3)..(i * 3) + 3));
218
219                            pixels.push(gray as f64);
220                            //total = total.unchecked_add(gray as u16);
221                        }
222                    }
223                    //let total_pixels = y_max.unchecked_sub(y_min) * x_max.saturating_sub(x_min);
224                    //let mean = (total / (total_pixels as u16)) as u8;
225                    let mut data = Data::new(pixels);
226                    //let mean = data.mean().unwrap();
227                    let i = px(x, y, self.width);
228                    let p = grayscale(input.get_unchecked((i * 3)..(i * 3) + 3));
229
230                    //if p > mean + C {
231                    //    *self.bufs.buf.add(i) = Color::White;
232                    //} else if p < mean - C {
233                    //    *self.bufs.buf.add(i) = Color::Black;
234                    //} else {
235                    //    *self.bufs.buf.add(i) = Color::Other;
236                    //}
237                    if (y > 0 && x > 0) && (data.max() - data.min()) < 5.0 {
238                        let gray = data.median();
239                        if gray < 60.0 {
240                            *self.bufs.buf.add(i) = Color::Black;
241                        } else if gray > 160.0 {
242                            *self.bufs.buf.add(i) = Color::White;
243                        } else {
244                            *self.bufs.buf.add(i) = Color::Other;
245                        }
246                        //*self.bufs.buf.add(i) = *self.bufs.buf.add(px(x - 1, y - 1, self.width));
247                    } else {
248                        if p >= data.upper_quartile() as u8 {
249                            *self.bufs.buf.add(i) = Color::White;
250                        } else if p <= data.lower_quartile() as u8 {
251                            *self.bufs.buf.add(i) = Color::Black;
252                        } else {
253                            *self.bufs.buf.add(i) = Color::Other;
254                        }
255                    }
256                }
257            }
258        }
259    }
260
261    /// Process an RGB frame
262    ///
263    /// FAST needs a 3x3 circle around each pixel, so we only process pixels within a 3x3 pixel
264    /// padding.
265    pub fn process_frame(&mut self, input: &[u8]) {
266        // Check that the input is RGB
267        assert_eq!(input.len(), self.width * self.height * 3);
268
269        let mut copy = input.to_vec();
270        //let adaptive_thresh = Instant::now();
271        self.calc_otsu(&mut copy);
272        //dbg!(adaptive_thresh.elapsed());
273
274        //unsafe {
275        //    self.thresh(input);
276        //}
277        // Reset points_len to 0
278        self.points_len.store(0, Ordering::SeqCst);
279        // Clear the lines Vec
280        self.lines.clear();
281
282        self.detect_corners();
283
284        self.check_edges();
285
286        //pose_estimation(intrinsics);
287    }
288
289    /// Run corner detection
290    #[inline(always)]
291    pub fn detect_corners(&mut self) {
292        #[cfg(not(feature = "multi-thread"))]
293        for x in 3..=self.width - 3 {
294            for y in 3..=self.height - 3 {
295                unsafe {
296                    self.process_pixel(x, y);
297                }
298            }
299        }
300
301        #[cfg(feature = "multi-thread")]
302        (3..=self.width - 3).par_bridge().for_each(|x| {
303            for y in 3..=self.height - 3 {
304                unsafe {
305                    self.process_pixel(x, y);
306                }
307            }
308        });
309    }
310
311    /// Threshold an input RGB buffer
312    ///
313    /// TODO: This needs to use [Self::calc_otsu].
314    ///
315    /// # Safety
316    /// `input` is treated as an RGB buffer, even if it isn't.
317    /// The caller should check that `input` is an RGB buffer.
318    #[inline(always)]
319    pub unsafe fn thresh(&self, input: &[u8]) {
320        // This is mainly memory-bound, so multi-threading probably isn't worth it.
321        for i in 0..self.width * self.height {
322            // Red, green, and blue are each represent with 1 byte
323            let gray = grayscale(input.get_unchecked((i * 3)..(i * 3) + 3));
324
325            // 60 is a "kinda works" value because I haven't implemented the algorithm
326            if gray < 60 {
327                *self.bufs.buf.add(i) = Color::Black;
328            } else if gray > 160 {
329                *self.bufs.buf.add(i) = Color::White;
330            } else {
331                *self.bufs.buf.add(i) = Color::Other;
332            }
333        }
334    }
335
336    /// Process a pixel
337    ///
338    /// This should have as little overhead as possible, as it must be run hundreds of thousands of
339    /// times for each frame.
340    ///
341    /// # Safety
342    /// (`x`, `y`) is assumed to be a valid pixel coord.
343    /// The caller must make sure of this.
344    #[inline(always)]
345    unsafe fn process_pixel(&self, x: usize, y: usize) {
346        // Pull out frame width and frame buffer for cleaner looking code
347        // TODO: is this optimized down into a noop?
348        let width = self.width;
349        let buf = self.bufs.buf;
350
351        // Get binary value of pixel at (x,y)
352        let p = *buf.add(px(x, y, width));
353
354        if p.is_black() {
355            // Get pixels that are diagonal neighbors of p
356            let (up_left, up_right, down_left, down_right) = (
357                *buf.add(px(x - 1, y - 1, width)),
358                *buf.add(px(x + 1, y - 1, width)),
359                *buf.add(px(x - 1, y + 1, width)),
360                *buf.add(px(x + 1, y + 1, width)),
361            );
362
363            // Only one can be black
364            // The carrot is Rust's exclusive or (XOR) operation
365            let clean = up_left.is_black()
366                ^ up_right.is_black()
367                ^ down_left.is_black()
368                ^ down_right.is_black();
369
370            if clean {
371                // Furthest top right
372                let p3 = *buf.add(px(x + 3, y - 3, width));
373                // Furthest bottom right
374                let p7 = *buf.add(px(x + 3, y + 3, width));
375                // Furthest bottom left
376                let p11 = *buf.add(px(x - 3, y + 3, width));
377                // Furthest top left
378                let p15 = *buf.add(px(x - 3, y - 3, width));
379
380                if (p3.is_good() && p7.is_good() && p11.is_good() && p15.is_good())
381                    && (p3.is_black() ^ p7.is_black() ^ p11.is_black() ^ p15.is_black())
382                {
383                    // Furthest top center
384                    let p1 = *buf.add(px(x, y - 3, width));
385                    // Furthest middle right
386                    let p5 = *buf.add(px(x + 3, y, width));
387                    // Furthest bottom center
388                    let p9 = *buf.add(px(x, y + 3, width));
389                    // Furthest middle left
390                    let p13 = *buf.add(px(x - 3, y, width));
391
392                    // Add p to the corner buffer
393                    *self
394                        .bufs
395                        .points
396                        .add(self.points_len.fetch_add(1, Ordering::SeqCst)) = (x, y);
397                }
398            }
399        }
400    }
401
402    /// Check a single edge (imaginary line between two corners)
403    ///
404    /// See [Self::check_edges].
405    ///
406    /// # Safety
407    /// (`x1`, `y1`) and (`x2`, `y2`) are assumed to be a valid pixel coords.
408    /// The caller must make sure of this.
409    unsafe fn check_edge(&mut self, x1: usize, y1: usize, x2: usize, y2: usize) {
410        // idk how to describe this one
411        const CHECK_OFFSET: usize = 5;
412        let width = self.width;
413        let buf = self.bufs.buf;
414
415        // calculate & store midpoint
416        let midpoint_x = (x1 + x2) / 2;
417        let midpoint_y = (y1 + y2) / 2;
418
419        // Figure out if edge is closer to horizontal/vertical
420        let (xdiff, ydiff) = (x1.max(x2) - x1.min(x2), y1.max(y2) - y1.min(y2));
421        let is_vertical_line = x1 == x2 || xdiff < ydiff;
422        let is_horizontal_line = y1 == y2 || ydiff < xdiff;
423
424        // Calculate and store the coords for the midway points
425        let (mw1x, mw1y) = ((midpoint_x + x1) / 2, (midpoint_y + y1) / 2);
426        let (mw2x, mw2y) = ((midpoint_x + x2) / 2, (midpoint_y + y2) / 2);
427
428        if is_vertical_line {
429            // edge is closer to a vertical line instead of a diagonal
430            let mw1right = *buf.add(px(mw1x + CHECK_OFFSET, mw1y, width));
431            let mw2right = *buf.add(px(mw2x + CHECK_OFFSET, mw2y, width));
432
433            let mw1left = *buf.add(px(mw1x - CHECK_OFFSET, mw1y, width));
434            let mw2left = *buf.add(px(mw2x - CHECK_OFFSET, mw2y, width));
435
436            // Check that all of the checking points are valid
437            if mw1left.is_good() && mw2left.is_good() && mw1right.is_good() && mw2right.is_good() {
438                // Check that only one side of the edge is black (the other should be white)
439                if (mw1left.is_black() ^ mw2right.is_black())
440                    && (mw2left.is_black() ^ mw1right.is_black())
441                    && (mw1left == mw2left)
442                {
443                    // midway one has black pixels on both sides
444                    self.lines.push((x1, y1, x2, y2));
445                }
446            }
447        }
448
449        if is_horizontal_line {
450            // edge is closer to a horizontal line instead of a diagonal
451
452            // XXX: Checking midway 1 then midway 2 *might* have marginally better performance,
453            // but likely not worth it for the more complex code.
454
455            // create the point to the right of the two midways
456            let mw1top = *buf.add(px(mw1x, mw1y - CHECK_OFFSET, width));
457            let mw2top = *buf.add(px(mw2x, mw2y - CHECK_OFFSET, width));
458
459            // create the point ot the left of the two midways
460            let mw1bottom = *buf.add(px(mw1x, mw1y + CHECK_OFFSET, width));
461            let mw2bottom = *buf.add(px(mw2x, mw2y + CHECK_OFFSET, width));
462
463            // check if the midways are black pixels,
464            // and if the pixels to the right and left of these midways are black pixels as well.
465
466            if mw1top.is_good() && mw2top.is_good() && mw1bottom.is_good() && mw2bottom.is_good() {
467                if (mw1top.is_black() ^ mw2bottom.is_black())
468                    && (mw2top.is_black() ^ mw1bottom.is_black())
469                    && (mw1top == mw2top)
470                {
471                    // midway one has black pixels on both sides
472                    self.lines.push((x1, y1, x2, y2));
473                }
474            }
475        }
476    }
477
478    /// Perform edge checking on all detected corners
479    #[inline(always)]
480    pub fn check_edges(&mut self) {
481        // Turn the raw buffer into a Rust slice
482        let points = unsafe {
483            core::slice::from_raw_parts(
484                self.bufs.points as *const _,
485                self.points_len.load(Ordering::SeqCst),
486            )
487        };
488
489        // Iterate over every detected corner
490        // TODO: this might benefit from multi-threading
491        for &(x1, y1) in points.iter() {
492            // Iterate over every detected corner in reverse, checking for edges
493            for &(x2, y2) in points.iter().rev() {
494                unsafe {
495                    self.check_edge(x1, y1, x2, y2);
496                }
497            }
498        }
499    }
500
501    pub fn connected_components(&self) -> UnionFind {
502        let width = self.width;
503        let buf = self.bufs.buf;
504
505        let mut uf = UnionFind::new(self.width * self.height);
506        for y in 0..self.height {
507            for x in 1..width - 1 {
508                unsafe {
509                    let i = px(x, y, self.width);
510
511                    let p = *buf.add(i);
512                    if !p.is_good() {
513                        continue;
514                    }
515
516                    if x > 0 {
517                        let left_i = px(x - 1, y, width);
518                        if *buf.add(left_i) == p {
519                            uf.union(i, left_i);
520                        }
521                    }
522
523                    if y > 0 {
524                        let top_i = px(x, y - 1, width);
525                        if *buf.add(top_i) == p {
526                            uf.union(i, top_i);
527                        }
528
529                        if p.is_white() {
530                            if x > 0 {
531                                let top_left_i = px(x - 1, y - 1, width);
532                                if *buf.add(top_left_i) == p {
533                                    uf.union(i, top_left_i);
534                                }
535                            }
536                            if x < width - 1 {
537                                let top_right_i = px(x + 1, y - 1, width);
538                                if *buf.add(top_right_i) == p {
539                                    uf.union(i, top_right_i);
540                                }
541                            }
542                        }
543                    }
544                }
545            }
546        }
547
548        uf
549    }
550
551    //pub fn do_cluster(&self, cluster_map_size: usize) {
552    //    //
553    //}
554
555    //pub fn cluster(&self, threads: usize) -> Vec<Vec<(usize, usize)>> {
556    //    let mut uf = self.connected_components();
557    //    let cluster_map_size = (0.2 * self.width as f64 * self.height as f64) as usize;
558
559    //    if self.height <= 2 {
560    //        let mut uf = uf.clone();
561    //        return self.do_cluster(cluster_map_size);
562    //    }
563
564    //    let sz = self.height - 1;
565    //    let pool = rayon::ThreadPoolBuilder::new().num_threads(threads).build().unwrap();
566    //    let chunk_sz = (1 + sz / (8 * pool.current_num_threads())) as usize;
567
568    //    let clusters = Arc::new(Mutex::new(Vec::new()));
569    //
570    //    (1..sz).into_par_iter().chunks(chunk_sz).for_each(|i| {
571    //        let local_cluster_map_sz = cluster_map_size / (sz / chunk_sz + 1);
572    //        let local_clusters = self.do_cluster(cluster_map_size);
573    //        let mut clusters = clusters.lock().unwrap();
574    //        clusters.push(local_clusters);
575    //    });
576
577    //    // Merge results from all threads
578    //let mut clusters_list = clusters.lock().unwrap().clone();
579    //if clusters_list.is_empty() {
580    //    return Vec::new();
581    //}
582
583    //let mut length = clusters_list.len();
584    //// Combine clusters in a tree-like manner for efficiency
585    //while length > 1 {
586    //    let mut write = 0;
587    //    for i in (0..length - 1).step_by(2) {
588    //        clusters_list[write] = self.merge_clusters(
589    //            clusters_list[i].clone(),
590    //            clusters_list[i + 1].clone()
591    //        );
592    //        write += 1;
593    //    }
594
595    //    if length % 2 != 0 {
596    //        clusters_list[write] = clusters_list[length - 1].clone();
597    //        write += 1;
598    //    }
599
600    //    length = write;
601    //}
602
603    //// Safety check to prevent out-of-bounds access
604    //if clusters_list.is_empty() {
605    //    return Vec::new();
606    //}
607
608    //// Convert cluster hashes to vector of points
609    //clusters_list.remove(0)
610    //    .into_iter()
611    //    .map(|cluster_hash| cluster_hash.data)
612    //    .collect()
613    //}
614
615    pub fn draw(&self) {
616        let mut img = ril::Image::new(self.width as u32, self.height as u32, Rgb::black());
617        let mut conn_comp = self.connected_components();
618        for x in 0..self.width {
619            for y in 0..self.height {
620                img.set_pixel(
621                    x as u32,
622                    y as u32,
623                    match unsafe { *self.bufs.buf.add(px(x, y, self.width)) } {
624                        Color::Black => Rgb::black(),
625                        Color::White => Rgb::white(),
626                        Color::Other => Rgb::from_hex("777777").unwrap(),
627                    },
628                );
629            }
630        }
631        for (x1, y1, x2, y2) in self.lines.clone() {
632            //img.draw(
633            //    &ril::draw::Ellipse::circle(x1 as u32, y1 as u32, 2)
634            //        .with_fill(Rgb::from_hex("ffa500").unwrap()),
635            //);
636            //img.draw(
637            //    &ril::draw::Ellipse::circle(x2 as u32, y2 as u32, 2)
638            //        .with_fill(Rgb::from_hex("ff0000").unwrap()),
639            //);
640            //img.draw(&Line::new(
641            //    (x1 as u32, y1 as u32),
642            //    (x2 as u32, y2 as u32),
643            //    Rgb::from_hex("ff0000").unwrap(),
644            //));
645            if conn_comp.find(unsafe { px(x1, y1, self.width) })
646                == conn_comp.find(unsafe { px(x2, y2, self.width) })
647            {
648                img.draw(&Line::new(
649                    (x1 as u32, y1 as u32),
650                    (x2 as u32, y2 as u32),
651                    Rgb::from_hex("00ff00").unwrap(),
652                ));
653            }
654            //            let (parent_x, parent_y) = (parent_id % self.width, parent_id / self.width);
655            //            img.draw(
656            //                &ril::draw::Ellipse::circle(parent_x as u32, parent_y as u32, 2)
657            //                    .with_fill(Rgb::from_hex("0000ff").unwrap()),
658            //            );
659        }
660        img.save(ril::ImageFormat::Png, "lines.png").unwrap();
661    }
662}
663impl Clone for Detector {
664    fn clone(&self) -> Self {
665        Self::new(self.width, self.height, &[])
666    }
667}
668impl Drop for Detector {
669    fn drop(&mut self) {
670        unsafe {
671            dealloc(
672                self.bufs.buf as *mut _,
673                Layout::array::<bool>(self.width * self.height).unwrap(),
674            );
675            dealloc(
676                self.bufs.points as *mut _,
677                Layout::array::<(usize, usize)>(self.width * self.height).unwrap(),
678            );
679        }
680    }
681}