gstreamer_analytics/
tensor.rs

1// Take a look at the license at the top of the repository in the LICENSE file.
2
3use crate::ffi;
4use crate::*;
5use glib::translate::*;
6
7glib::wrapper! {
8    /// Hold tensor data
9    #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
10    #[doc(alias = "GstTensor")]
11    pub struct Tensor(Boxed<ffi::GstTensor>);
12
13    match fn {
14        copy => |ptr| ffi::gst_tensor_copy(ptr),
15        free => |ptr| ffi::gst_tensor_free(ptr),
16        type_ => || ffi::gst_tensor_get_type(),
17    }
18}
19
20unsafe impl Send for Tensor {}
21unsafe impl Sync for Tensor {}
22
23impl Tensor {
24    /// Allocates a new [`Tensor`][crate::Tensor] of `dims_order` ROW_MAJOR or COLUMN_MAJOR and
25    /// with an interleaved layout
26    /// ## `id`
27    /// semantically identify the contents of the tensor
28    /// ## `data_type`
29    /// [`TensorDataType`][crate::TensorDataType] of tensor data
30    /// ## `data`
31    /// [`gst::Buffer`][crate::gst::Buffer] holding tensor data
32    /// ## `dims_order`
33    /// Indicate tensor dimension indexing order
34    /// ## `dims`
35    /// tensor dimensions. Value of 0 mean the
36    /// dimension is dynamic.
37    ///
38    /// # Returns
39    ///
40    /// A newly allocated [`Tensor`][crate::Tensor]
41    #[doc(alias = "gst_tensor_new_simple")]
42    pub fn new_simple(
43        id: glib::Quark,
44        data_type: TensorDataType,
45        data: gst::Buffer,
46        dims_order: TensorDimOrder,
47        dims: &[usize],
48    ) -> Tensor {
49        skip_assert_initialized!();
50        unsafe {
51            from_glib_full(ffi::gst_tensor_new_simple(
52                id.into_glib(),
53                data_type.into_glib(),
54                data.into_glib_ptr(),
55                dims_order.into_glib(),
56                dims.len(),
57                dims.as_ptr() as *mut _,
58            ))
59        }
60    }
61
62    /// Gets the dimensions of the tensor.
63    ///
64    /// # Returns
65    ///
66    /// The dims array form the tensor
67    #[doc(alias = "gst_tensor_get_dims")]
68    #[doc(alias = "get_dims")]
69    pub fn dims(&self) -> &[usize] {
70        let mut num_dims: usize = 0;
71        unsafe {
72            let dims = ffi::gst_tensor_get_dims(self.as_ptr(), &mut num_dims);
73            std::slice::from_raw_parts(dims as *const _, num_dims)
74        }
75    }
76
77    #[inline]
78    pub fn id(&self) -> glib::Quark {
79        unsafe { from_glib(self.inner.id) }
80    }
81
82    #[inline]
83    pub fn data_type(&self) -> TensorDataType {
84        unsafe { from_glib(self.inner.data_type) }
85    }
86
87    #[inline]
88    pub fn data(&self) -> &gst::BufferRef {
89        unsafe { gst::BufferRef::from_ptr(self.inner.data) }
90    }
91
92    #[inline]
93    pub fn data_mut(&mut self) -> &mut gst::BufferRef {
94        unsafe {
95            self.inner.data = gst::ffi::gst_mini_object_make_writable(self.inner.data as _) as _;
96            gst::BufferRef::from_mut_ptr(self.inner.data)
97        }
98    }
99
100    #[inline]
101    pub fn dims_order(&self) -> TensorDimOrder {
102        unsafe { from_glib(self.inner.dims_order) }
103    }
104}
105
#[cfg(test)]
mod tests {
    use crate::*;

    /// Builds a tensor, checks every getter, and checks that writes made through
    /// `data_mut()` are visible through `data()`.
    #[test]
    fn create_tensor() {
        gst::init().unwrap();

        // 3 * 4 * 5 Int16 elements => 2 * 3 * 4 * 5 bytes.
        let buf = gst::Buffer::with_size(2 * 3 * 4 * 5).unwrap();
        assert_eq!(buf.size(), 2 * 3 * 4 * 5);

        let mut tensor = Tensor::new_simple(
            glib::Quark::from_str("me"),
            TensorDataType::Int16,
            buf,
            TensorDimOrder::RowMajor,
            &[3, 4, 5],
        );

        assert_eq!(tensor.id(), glib::Quark::from_str("me"));
        assert_eq!(tensor.data_type(), TensorDataType::Int16);
        assert_eq!(tensor.dims_order(), TensorDimOrder::RowMajor);
        assert_eq!(tensor.dims(), &[3, 4, 5]);
        assert_eq!(tensor.data().size(), 2 * 3 * 4 * 5);

        // Write through the mutable accessor; the map is dropped at the end of the scope.
        {
            let mut map = tensor.data_mut().map_writable().unwrap();
            map[0] = 0xab;
        }

        // The write must be visible through the shared accessor.
        let map = tensor.data().map_readable().unwrap();
        assert_eq!(map[0], 0xab);
    }
}
136}