gstreamer_analytics/
tensor_meta.rs

1// Take a look at the license at the top of the repository in the LICENSE file.
2
3use glib::translate::*;
4use gst::prelude::*;
5
6use crate::ffi;
7use crate::Tensor;
8
9#[repr(transparent)]
10#[doc(alias = "GstTensorMeta")]
11pub struct TensorMeta(ffi::GstTensorMeta);
12
13unsafe impl Send for TensorMeta {}
14unsafe impl Sync for TensorMeta {}
15
16impl TensorMeta {
17    #[doc(alias = "gst_buffer_add_tensor_meta")]
18    pub fn add(buffer: &mut gst::BufferRef) -> gst::MetaRefMut<Self, gst::meta::Standalone> {
19        skip_assert_initialized!();
20
21        unsafe {
22            let meta_ptr = ffi::gst_buffer_add_tensor_meta(buffer.as_mut_ptr());
23            Self::from_mut_ptr(buffer, meta_ptr)
24        }
25    }
26
27    #[doc(alias = "gst_tensor_meta_set")]
28    pub fn set(&mut self, tensors: glib::Slice<Tensor>) {
29        unsafe {
30            ffi::gst_tensor_meta_set(self.as_mut_ptr(), tensors.len() as u32, tensors.into_raw());
31        }
32    }
33
34    #[doc(alias = "gst_tensor_meta_get_index_from_id")]
35    pub fn index_from_id(&self, id: glib::Quark) -> i32 {
36        unsafe { ffi::gst_tensor_meta_get_index_from_id(self.as_mut_ptr(), id.into_glib()) }
37    }
38
39    pub fn as_slice(&self) -> &[Tensor] {
40        unsafe { glib::Slice::from_glib_borrow_num(self.0.tensors, self.0.num_tensors) }
41    }
42
43    pub fn as_mut_slice(&mut self) -> &mut [Tensor] {
44        unsafe { glib::Slice::from_glib_borrow_num_mut(self.0.tensors, self.0.num_tensors) }
45    }
46
47    unsafe fn as_mut_ptr(&self) -> *mut ffi::GstTensorMeta {
48        mut_override(&self.0)
49    }
50}
51
52unsafe impl MetaAPI for TensorMeta {
53    type GstType = ffi::GstTensorMeta;
54
55    #[doc(alias = "gst_tensor_meta_api_get_type")]
56    #[inline]
57    fn meta_api() -> glib::Type {
58        unsafe { from_glib(ffi::gst_tensor_meta_api_get_type()) }
59    }
60}
61
#[cfg(test)]
mod tests {
    use crate::*;

    /// Builds a `TensorMeta` holding one tensor and checks the shared view, the
    /// mutable view and the id lookup all see that tensor.
    #[test]
    fn build_tensor_meta() {
        gst::init().unwrap();

        let mut buf = gst::Buffer::new();
        let mut tmeta = TensorMeta::add(buf.make_mut());

        let id = glib::Quark::from_str("me");
        let tensor = Tensor::new_simple(
            id,
            TensorDataType::Int16,
            // 2 bytes per Int16 element times 3 * 4 * 5 elements.
            gst::Buffer::with_size(2 * 3 * 4 * 5).unwrap(),
            TensorDimOrder::RowMajor,
            &[3, 4, 5],
        );

        // Remember the underlying pointer so we can verify that `set` stores the
        // very same tensor rather than a copy.
        let tptr = tensor.as_ptr();

        tmeta.set([tensor].into());

        let tensors = tmeta.as_slice();
        assert_eq!(tensors.len(), 1);
        assert_eq!(tptr, tensors[0].as_ptr());
        assert_eq!(tensors[0].dims_order(), TensorDimOrder::RowMajor);
        assert_eq!(tensors[0].dims().len(), 3);
        assert_eq!(tensors[0].dims()[0], 3);

        // The single tensor can be found by the id it was created with.
        assert_eq!(tmeta.index_from_id(id), 0);

        // The mutable view must expose the same single tensor.
        let tensors_mut = tmeta.as_mut_slice();
        assert_eq!(tensors_mut.len(), 1);
        assert_eq!(tptr, tensors_mut[0].as_ptr());
    }
}