Skip to main content

gstreamer_analytics/
group.rs

1// Take a look at the license at the top of the repository in the LICENSE file.
2
3use glib::translate::*;
4use std::marker::PhantomData;
5
6use crate::{AnalyticsKeypointDimensions, AnalyticsKeypointPosition, ffi, relation_meta::*};
7
8#[derive(Debug)]
9pub enum AnalyticsGroupMtd {}
10
11mod sealed {
12    pub trait Sealed {}
13    impl<T: super::AnalyticsRelationMetaGroupExt> Sealed for T {}
14}
15
16pub trait AnalyticsRelationMetaGroupExt: sealed::Sealed {
17    fn add_group_mtd(
18        &mut self,
19        pre_alloc_size: usize,
20    ) -> Result<AnalyticsMtdRef<'_, AnalyticsGroupMtd>, glib::BoolError>;
21
22    fn add_group_mtd_with_size(
23        &mut self,
24        group_size: usize,
25    ) -> Result<AnalyticsMtdRef<'_, AnalyticsGroupMtd>, glib::BoolError>;
26
27    fn add_keypoints_group(
28        &mut self,
29        semantic_tag: &str,
30        dimension: AnalyticsKeypointDimensions,
31        positions: &[i32],
32        confidences: Option<&[f32]>,
33        visibilities: Option<&[u8]>,
34        skeleton_pairs: &[i32],
35    ) -> Result<AnalyticsMtdRef<'_, AnalyticsGroupMtd>, glib::BoolError>;
36
37    fn add_keypoints_group_from_positions(
38        &mut self,
39        semantic_tag: &str,
40        positions: &[AnalyticsKeypointPosition],
41        confidences: Option<&[f32]>,
42        visibilities: Option<&[u8]>,
43        skeleton_pairs: &[i32],
44    ) -> Result<AnalyticsMtdRef<'_, AnalyticsGroupMtd>, glib::BoolError> {
45        if positions.is_empty() {
46            return Err(glib::bool_error!("No keypoint positions provided"));
47        }
48
49        let dimension = positions[0].dimension;
50
51        if positions
52            .iter()
53            .any(|position| position.dimension != dimension)
54        {
55            return Err(glib::bool_error!(
56                "All keypoint positions must use the same dimension"
57            ));
58        }
59
60        let coords_per_keypoint = match dimension {
61            AnalyticsKeypointDimensions::_2d => 2,
62            AnalyticsKeypointDimensions::_3d => 3,
63            _ => {
64                return Err(glib::bool_error!(
65                    "Unsupported keypoint dimension for positions"
66                ));
67            }
68        };
69
70        let mut flattened_positions = Vec::with_capacity(positions.len() * coords_per_keypoint);
71        for position in positions {
72            flattened_positions.push(position.x);
73            flattened_positions.push(position.y);
74            if coords_per_keypoint == 3 {
75                flattened_positions.push(position.z);
76            }
77        }
78
79        self.add_keypoints_group(
80            semantic_tag,
81            dimension,
82            &flattened_positions,
83            confidences,
84            visibilities,
85            skeleton_pairs,
86        )
87    }
88}
89
90impl AnalyticsRelationMetaGroupExt
91    for gst::MetaRefMut<'_, AnalyticsRelationMeta, gst::meta::Standalone>
92{
93    #[doc(alias = "gst_analytics_relation_meta_add_group_mtd")]
94    fn add_group_mtd(
95        &mut self,
96        pre_alloc_size: usize,
97    ) -> Result<AnalyticsMtdRef<'_, AnalyticsGroupMtd>, glib::BoolError> {
98        unsafe {
99            let mut mtd = std::mem::MaybeUninit::uninit();
100            let ret = from_glib(ffi::gst_analytics_relation_meta_add_group_mtd(
101                self.as_mut_ptr(),
102                pre_alloc_size,
103                mtd.as_mut_ptr(),
104            ));
105            let id = mtd.assume_init().id;
106
107            if ret {
108                Ok(AnalyticsMtdRef::from_meta(self.as_ref(), id))
109            } else {
110                Err(glib::bool_error!("Couldn't add group metadata"))
111            }
112        }
113    }
114
115    #[doc(alias = "gst_analytics_relation_meta_add_group_mtd_with_size")]
116    fn add_group_mtd_with_size(
117        &mut self,
118        group_size: usize,
119    ) -> Result<AnalyticsMtdRef<'_, AnalyticsGroupMtd>, glib::BoolError> {
120        unsafe {
121            let mut mtd = std::mem::MaybeUninit::uninit();
122            let ret = from_glib(ffi::gst_analytics_relation_meta_add_group_mtd_with_size(
123                self.as_mut_ptr(),
124                group_size,
125                mtd.as_mut_ptr(),
126            ));
127            let id = mtd.assume_init().id;
128
129            if ret {
130                Ok(AnalyticsMtdRef::from_meta(self.as_ref(), id))
131            } else {
132                Err(glib::bool_error!("Couldn't add group metadata"))
133            }
134        }
135    }
136
137    #[doc(alias = "gst_analytics_relation_meta_add_keypoints_group")]
138    fn add_keypoints_group(
139        &mut self,
140        semantic_tag: &str,
141        dimension: AnalyticsKeypointDimensions,
142        positions: &[i32],
143        confidences: Option<&[f32]>,
144        visibilities: Option<&[u8]>,
145        skeleton_pairs: &[i32],
146    ) -> Result<AnalyticsMtdRef<'_, AnalyticsGroupMtd>, glib::BoolError> {
147        let coords_per_keypoint = match dimension {
148            AnalyticsKeypointDimensions::_2d => 2,
149            AnalyticsKeypointDimensions::_3d => 3,
150            _ => {
151                return Err(glib::bool_error!(
152                    "Unsupported keypoint dimension for positions"
153                ));
154            }
155        };
156
157        if positions.is_empty() {
158            return Err(glib::bool_error!("No keypoint positions provided"));
159        }
160
161        if !positions.len().is_multiple_of(coords_per_keypoint) {
162            return Err(glib::bool_error!(
163                "Positions length must match the keypoint dimension"
164            ));
165        }
166
167        let keypoint_count = positions.len() / coords_per_keypoint;
168
169        if let Some(confidences) = confidences
170            && confidences.len() != keypoint_count
171        {
172            return Err(glib::bool_error!(
173                "Confidences length must match keypoint count"
174            ));
175        }
176
177        if let Some(visibilities) = visibilities
178            && visibilities.len() != keypoint_count
179        {
180            return Err(glib::bool_error!(
181                "Visibilities length must match keypoint count"
182            ));
183        }
184
185        unsafe {
186            let mut mtd = std::mem::MaybeUninit::uninit();
187            let ret = from_glib(ffi::gst_analytics_relation_meta_add_keypoints_group(
188                self.as_mut_ptr(),
189                semantic_tag.to_glib_none().0,
190                dimension.into_glib(),
191                positions.len(),
192                positions.as_ptr(),
193                keypoint_count,
194                confidences.map_or(std::ptr::null(), |confidences| confidences.as_ptr()),
195                visibilities.map_or(std::ptr::null(), |visibilities| visibilities.as_ptr()),
196                skeleton_pairs.len(),
197                skeleton_pairs.as_ptr(),
198                mtd.as_mut_ptr(),
199            ));
200            let id = mtd.assume_init().id;
201
202            if ret {
203                Ok(AnalyticsMtdRef::from_meta(self.as_ref(), id))
204            } else {
205                Err(glib::bool_error!("Couldn't add keypoints group metadata"))
206            }
207        }
208    }
209}
210
211impl AnalyticsMtdRef<'_, AnalyticsGroupMtd> {
212    #[doc(alias = "gst_analytics_group_mtd_has_semantic_tag")]
213    pub fn has_semantic_tag(&self, tag: &str) -> Result<bool, glib::BoolError> {
214        unsafe {
215            let mtd = ffi::GstAnalyticsMtd::unsafe_from(self);
216            Ok(from_glib(ffi::gst_analytics_group_mtd_has_semantic_tag(
217                &mtd as *const _ as *const ffi::GstAnalyticsGroupMtd,
218                tag.to_glib_none().0,
219            )))
220        }
221    }
222
223    #[doc(alias = "gst_analytics_group_mtd_semantic_tag_has_prefix")]
224    pub fn semantic_tag_has_prefix(&self, prefix: &str) -> Result<bool, glib::BoolError> {
225        unsafe {
226            let mtd = ffi::GstAnalyticsMtd::unsafe_from(self);
227            Ok(from_glib(
228                ffi::gst_analytics_group_mtd_semantic_tag_has_prefix(
229                    &mtd as *const _ as *const ffi::GstAnalyticsGroupMtd,
230                    prefix.to_glib_none().0,
231                ),
232            ))
233        }
234    }
235
236    #[doc(alias = "gst_analytics_group_mtd_get_member_count")]
237    pub fn member_count(&self) -> usize {
238        unsafe {
239            let mtd = ffi::GstAnalyticsMtd::unsafe_from(self);
240            ffi::gst_analytics_group_mtd_get_member_count(
241                &mtd as *const _ as *const ffi::GstAnalyticsGroupMtd,
242            ) as usize
243        }
244    }
245
246    #[doc(alias = "gst_analytics_group_mtd_get_member")]
247    pub fn member(&self, index: usize) -> Option<AnalyticsMtdRef<'_, AnalyticsAnyMtd>> {
248        if index >= self.member_count() {
249            return None;
250        }
251
252        unsafe {
253            let mtd = ffi::GstAnalyticsMtd::unsafe_from(self);
254            let mut member = std::mem::MaybeUninit::uninit();
255            let ret = from_glib(ffi::gst_analytics_group_mtd_get_member(
256                &mtd as *const _ as *const ffi::GstAnalyticsGroupMtd,
257                index,
258                member.as_mut_ptr(),
259            ));
260
261            if ret {
262                let member = member.assume_init();
263                let id = ffi::gst_analytics_mtd_get_id(&member);
264                Some(AnalyticsMtdRef::from_meta(self.meta_ref(), id))
265            } else {
266                None
267            }
268        }
269    }
270
271    pub fn member_typed<T: AnalyticsMtd>(&self, index: usize) -> Option<AnalyticsMtdRef<'_, T>> {
272        self.member(index)
273            .and_then(|member| member.downcast::<T>().ok())
274    }
275
276    #[doc(alias = "gst_analytics_group_mtd_iterate")]
277    pub fn iter<T: AnalyticsMtd>(&self) -> AnalyticsGroupMtdIter<'_, T> {
278        AnalyticsGroupMtdIter::new(self)
279    }
280}
281
282#[must_use = "iterators are lazy and do nothing unless consumed"]
283pub struct AnalyticsGroupMtdIter<'a, T: AnalyticsMtd> {
284    group: &'a AnalyticsMtdRef<'a, AnalyticsGroupMtd>,
285    state: glib::ffi::gpointer,
286    phantom: PhantomData<T>,
287}
288
289impl<'a, T: AnalyticsMtd> AnalyticsGroupMtdIter<'a, T> {
290    fn new(group: &'a AnalyticsMtdRef<'a, AnalyticsGroupMtd>) -> Self {
291        skip_assert_initialized!();
292        AnalyticsGroupMtdIter {
293            group,
294            state: std::ptr::null_mut(),
295            phantom: PhantomData,
296        }
297    }
298}
299
300impl<'a, T: AnalyticsMtd + 'a> Iterator for AnalyticsGroupMtdIter<'a, T> {
301    type Item = AnalyticsMtdRef<'a, T>;
302
303    fn next(&mut self) -> Option<Self::Item> {
304        unsafe {
305            let mtd = ffi::GstAnalyticsMtd::unsafe_from(self.group);
306            let mut member = std::mem::MaybeUninit::uninit();
307            let ret = from_glib(ffi::gst_analytics_group_mtd_iterate(
308                &mtd as *const _ as *const ffi::GstAnalyticsGroupMtd,
309                &mut self.state,
310                T::mtd_type(),
311                member.as_mut_ptr(),
312            ));
313
314            if ret {
315                let member = member.assume_init();
316                let id = ffi::gst_analytics_mtd_get_id(&member);
317                Some(AnalyticsMtdRef::from_meta(self.group.meta_ref(), id))
318            } else {
319                None
320            }
321        }
322    }
323}
324
325unsafe impl AnalyticsMtd for AnalyticsGroupMtd {
326    #[doc(alias = "gst_analytics_group_mtd_get_mtd_type")]
327    fn mtd_type() -> ffi::GstAnalyticsMtdType {
328        unsafe { ffi::gst_analytics_group_mtd_get_mtd_type() }
329    }
330}
331
332impl AnalyticsMtdRefMut<'_, AnalyticsGroupMtd> {
333    #[doc(alias = "gst_analytics_group_mtd_add_member")]
334    pub fn add_member(&mut self, an_meta_id: u32) -> Result<(), glib::BoolError> {
335        let ret = unsafe {
336            let mut mtd = ffi::GstAnalyticsMtd::unsafe_from(self);
337            from_glib(ffi::gst_analytics_group_mtd_add_member(
338                &mut mtd as *mut _ as *mut ffi::GstAnalyticsGroupMtd,
339                an_meta_id,
340            ))
341        };
342
343        if ret {
344            Ok(())
345        } else {
346            Err(glib::bool_error!("Couldn't add group member"))
347        }
348    }
349
350    #[doc(alias = "gst_analytics_group_mtd_set_semantic_tag")]
351    pub fn set_semantic_tag(&mut self, tag: &str) -> Result<(), glib::BoolError> {
352        let ret = unsafe {
353            let mut mtd = ffi::GstAnalyticsMtd::unsafe_from(self);
354            from_glib(ffi::gst_analytics_group_mtd_set_semantic_tag(
355                &mut mtd as *mut _ as *mut ffi::GstAnalyticsGroupMtd,
356                tag.to_glib_none().0,
357            ))
358        };
359
360        if ret {
361            Ok(())
362        } else {
363            Err(glib::bool_error!("Couldn't set semantic tag"))
364        }
365    }
366}
367
#[cfg(test)]
mod tests {
    use crate::*;

    /// Builds a 2D keypoint position (`z` is ignored for 2D and set to 0).
    fn pos_2d(x: i32, y: i32) -> AnalyticsKeypointPosition {
        AnalyticsKeypointPosition {
            x,
            y,
            z: 0,
            dimension: AnalyticsKeypointDimensions::_2d,
        }
    }

    #[test]
    fn group_members() {
        gst::init().unwrap();

        let type_name = AnalyticsGroupMtd::type_name();
        assert_eq!(type_name, "grouping-mtd");

        let mut buf = gst::Buffer::new();
        let mut meta = AnalyticsRelationMeta::add(buf.make_mut());

        let keypoint_id = {
            let keypoint = meta
                .add_keypoint_mtd_from_position(
                    pos_2d(1, 2),
                    AnalyticsKeypointVisibility::VISIBLE,
                    0.5,
                )
                .unwrap();
            keypoint.id()
        };

        let group = meta.add_group_mtd_with_size(1).unwrap();
        let group_id = group.id();

        let mut group_mut = meta.mtd_mut::<AnalyticsGroupMtd>(group_id).unwrap();
        group_mut.set_semantic_tag("pose").unwrap();
        group_mut.add_member(keypoint_id).unwrap();

        let group = AnalyticsMtdRef::from(group_mut);
        assert!(group.has_semantic_tag("pose").unwrap());
        assert!(group.semantic_tag_has_prefix("po").unwrap());
        assert_eq!(group.member_count(), 1);

        // Out-of-range indices must be rejected without touching C.
        assert!(group.member(1).is_none());

        let member = group.member_typed::<AnalyticsKeypointMtd>(0).unwrap();
        let position = member.position().unwrap();
        assert_eq!(position.x, 1);
        assert_eq!(position.y, 2);

        assert_eq!(group.iter::<AnalyticsKeypointMtd>().count(), 1);
    }

    #[test]
    fn keypoints_group() {
        gst::init().unwrap();

        let mut buf = gst::Buffer::new();
        let mut meta = AnalyticsRelationMeta::add(buf.make_mut());

        let positions = [pos_2d(10, 20), pos_2d(30, 40)];
        let confidences = [0.9, 0.8];
        let visibilities = [1, 0];

        let group = meta
            .add_keypoints_group_from_positions(
                "pose",
                &positions,
                Some(&confidences),
                Some(&visibilities),
                &[],
            )
            .unwrap();

        assert!(group.has_semantic_tag("pose").unwrap());
        assert!(group.semantic_tag_has_prefix("po").unwrap());
        assert_eq!(group.member_count(), 2);
    }

    #[test]
    fn keypoints_group_rejects_mismatched_confidences() {
        gst::init().unwrap();

        let mut buf = gst::Buffer::new();
        let mut meta = AnalyticsRelationMeta::add(buf.make_mut());

        let positions = [pos_2d(10, 20), pos_2d(30, 40)];
        let confidences = [0.9];

        let result = meta.add_keypoints_group_from_positions(
            "pose",
            &positions,
            Some(&confidences),
            None,
            &[],
        );

        assert!(result.is_err());
    }

    #[test]
    fn keypoints_group_rejects_mismatched_visibilities() {
        gst::init().unwrap();

        let mut buf = gst::Buffer::new();
        let mut meta = AnalyticsRelationMeta::add(buf.make_mut());

        let positions = [pos_2d(10, 20), pos_2d(30, 40)];
        let visibilities = [1];

        let result = meta.add_keypoints_group_from_positions(
            "pose",
            &positions,
            None,
            Some(&visibilities),
            &[],
        );

        assert!(result.is_err());
    }

    #[test]
    fn keypoints_group_rejects_empty_positions() {
        gst::init().unwrap();

        let mut buf = gst::Buffer::new();
        let mut meta = AnalyticsRelationMeta::add(buf.make_mut());

        let result = meta.add_keypoints_group_from_positions("pose", &[], None, None, &[]);

        assert!(result.is_err());
    }

    #[test]
    fn keypoints_group_rejects_mixed_dimensions() {
        gst::init().unwrap();

        let mut buf = gst::Buffer::new();
        let mut meta = AnalyticsRelationMeta::add(buf.make_mut());

        let positions = [
            pos_2d(10, 20),
            AnalyticsKeypointPosition {
                x: 30,
                y: 40,
                z: 50,
                dimension: AnalyticsKeypointDimensions::_3d,
            },
        ];

        let result = meta.add_keypoints_group_from_positions("pose", &positions, None, None, &[]);

        assert!(result.is_err());
    }

    #[test]
    fn keypoints_group_rejects_partial_coordinates() {
        gst::init().unwrap();

        let mut buf = gst::Buffer::new();
        let mut meta = AnalyticsRelationMeta::add(buf.make_mut());

        // Three values cannot be split into whole 2D keypoints.
        let result = meta.add_keypoints_group(
            "pose",
            AnalyticsKeypointDimensions::_2d,
            &[10, 20, 30],
            None,
            None,
            &[],
        );

        assert!(result.is_err());
    }
}