Skip to main content

gstreamer_analytics/
model_info.rs

1// Take a look at the license at the top of the repository in the LICENSE file.
2
3use crate::ffi;
4use crate::*;
5use glib::translate::*;
6use std::ptr;
7
8use crate::{ModelInfo, ModelInfoTensorDirection, TensorDataType};
9
10impl ModelInfo {
11    /// Find the name of a tensor in the modelinfo that matches the given criteria.
12    ///
13    /// The function performs the following checks in order:
14    /// 1. If `in_tensor_name` is provided and exists in modelinfo, validate it matches
15    /// 2. Search by index for the specified direction and validate
16    /// 3. Search by dimensions and data type
17    /// ## `dir`
18    /// The tensor direction (input or output)
19    /// ## `index`
20    /// The tensor index within the specified direction
21    /// ## `in_tensor_name`
22    /// An optional tensor name hint to check first
23    /// ## `data_type`
24    /// The tensor data type to match
25    /// ## `dims`
26    /// The dimension sizes. Use -1 for dynamic dimensions.
27    ///
28    /// # Returns
29    ///
30    /// The tensor name if found, or [`None`] otherwise.
31    ///  The caller must free this with `g_free()` when done.
32    #[doc(alias = "gst_analytics_modelinfo_find_tensor_name")]
33    pub fn find_tensor_name(
34        &self,
35        dir: ModelInfoTensorDirection,
36        index: usize,
37        in_tensor_name: Option<&str>,
38        data_type: TensorDataType,
39        dims: &[usize],
40    ) -> Option<glib::GString> {
41        unsafe {
42            from_glib_full(ffi::gst_analytics_modelinfo_find_tensor_name(
43                self.to_glib_none().0,
44                dir.into_glib(),
45                index,
46                in_tensor_name.to_glib_none().0,
47                data_type.into_glib(),
48                dims.len(),
49                dims.as_ptr(),
50            ))
51        }
52    }
53
54    /// Calculate normalization scales and offsets to transform input data to the target range.
55    ///
56    /// This function calculates transformation parameters to convert from the actual input data range
57    /// [input_min, input_max] to the target range expected by the model [target_min, target_max]:
58    ///  `normalized_value[i] = input[i] * output_scale[i] + output_offset[i]`
59    ///
60    /// The target ranges are read from the modelinfo `ranges` field: Semicolon-separated list of
61    /// comma-separated pairs (min,max) for per-channel target ranges
62    /// (e.g., "0.0,255.0;-1.0,1.0;0.0,1.0" for RGB channels with different target ranges).
63    ///
64    /// Common input ranges:
65    /// - [0.0, 255.0]: 8-bit unsigned (uint8)
66    /// - [-128.0, 127.0]: 8-bit signed (int8)
67    /// - [0.0, 65535.0]: 16-bit unsigned (uint16)
68    /// - [-32768.0, 32767.0]: 16-bit signed (int16)
69    /// - [0.0, 1.0]: Normalized float
70    /// - [-1.0, 1.0]: Normalized signed float
71    ///
72    /// The number of input ranges (`num_input_ranges`) must equal the number of target ranges
73    /// in the modelinfo. The function will return FALSE if they don't match.
74    ///
75    /// The caller must free `output_scales` and `output_offsets` with `g_free()` when done.
76    /// ## `tensor_name`
77    /// The name of the tensor
78    /// ## `input_mins`
79    /// The minimum values of the actual input data for each channel
80    /// ## `input_maxs`
81    /// The maximum values of the actual input data for each channel
82    ///
83    /// # Returns
84    ///
85    /// [`true`] on success, [`false`] on error, if ranges field is not found, or if `num_input_ranges`
86    ///  doesn't match the number of target ranges in the modelinfo
87    ///
88    /// ## `output_scales`
89    /// The scale values for normalization
90    ///
91    /// ## `output_offsets`
92    /// The offset values for normalization
93    #[doc(alias = "gst_analytics_modelinfo_get_input_scales_offsets")]
94    #[doc(alias = "get_input_scales_offsets")]
95    pub fn input_scales_offsets(
96        &self,
97        tensor_name: &str,
98        input_mins: &[f64],
99        input_maxs: &[f64],
100    ) -> Option<(glib::Slice<f64>, glib::Slice<f64>)> {
101        unsafe {
102            assert_eq!(input_mins.len(), input_maxs.len());
103
104            let mut num_output_ranges = 0;
105            let mut output_scales = ptr::null_mut();
106            let mut output_offsets = ptr::null_mut();
107            let res = from_glib(ffi::gst_analytics_modelinfo_get_input_scales_offsets(
108                self.to_glib_none().0,
109                tensor_name.to_glib_none().0,
110                input_mins.len(),
111                input_mins.as_ptr(),
112                input_maxs.as_ptr(),
113                &mut num_output_ranges,
114                &mut output_scales,
115                &mut output_offsets,
116            ));
117            if res {
118                Some((
119                    glib::Slice::from_glib_full_num(output_scales, num_output_ranges),
120                    glib::Slice::from_glib_full_num(output_offsets, num_output_ranges),
121                ))
122            } else {
123                None
124            }
125        }
126    }
127
128    /// Retrieve all target ranges (min/max pairs) expected by the model for a given tensor.
129    ///
130    /// This function retrieves all target ranges from the `ranges` field in the modelinfo.
131    /// Each range represents the expected input range for a channel or dimension that the
132    /// model requires.
133    ///
134    /// The function reads from the `ranges` field: Semicolon-separated list of
135    /// comma-separated pairs (min,max) for per-channel target ranges
136    /// (e.g., "0.0,1.0;-1.0,1.0;0.0,1.0" for RGB channels with different normalization targets).
137    ///
138    /// The caller must free `mins` and `maxs` with `g_free()` when done.
139    /// ## `tensor_name`
140    /// The name of the tensor
141    ///
142    /// # Returns
143    ///
144    /// [`true`] if range information was found and valid, [`false`] otherwise
145    ///
146    /// ## `mins`
147    /// The minimum values for each target range
148    ///
149    /// ## `maxs`
150    /// The maximum values for each target range
151    #[doc(alias = "gst_analytics_modelinfo_get_target_ranges")]
152    #[doc(alias = "get_target_ranges")]
153    pub fn target_ranges(&self, tensor_name: &str) -> Option<(glib::Slice<f64>, glib::Slice<f64>)> {
154        unsafe {
155            let mut num_ranges = 0;
156            let mut mins = ptr::null_mut();
157            let mut maxs = ptr::null_mut();
158            let res = from_glib(ffi::gst_analytics_modelinfo_get_target_ranges(
159                mut_override(self.to_glib_none().0),
160                tensor_name.to_glib_none().0,
161                &mut num_ranges,
162                &mut mins,
163                &mut maxs,
164            ));
165            if res {
166                Some((
167                    glib::Slice::from_glib_full_num(mins, num_ranges),
168                    glib::Slice::from_glib_full_num(maxs, num_ranges),
169                ))
170            } else {
171                None
172            }
173        }
174    }
175}