gstreamer_analytics/model_info.rs
1// Take a look at the license at the top of the repository in the LICENSE file.
2
3use crate::ffi;
4use crate::*;
5use glib::translate::*;
6use std::ptr;
7
8use crate::{ModelInfo, ModelInfoTensorDirection, TensorDataType};
9
10impl ModelInfo {
11 /// Find the name of a tensor in the modelinfo that matches the given criteria.
12 ///
13 /// The function performs the following checks in order:
14 /// 1. If `in_tensor_name` is provided and exists in modelinfo, validate it matches
15 /// 2. Search by index for the specified direction and validate
16 /// 3. Search by dimensions and data type
17 /// ## `dir`
18 /// The tensor direction (input or output)
19 /// ## `index`
20 /// The tensor index within the specified direction
21 /// ## `in_tensor_name`
22 /// An optional tensor name hint to check first
23 /// ## `data_type`
24 /// The tensor data type to match
25 /// ## `dims`
26 /// The dimension sizes. Use -1 for dynamic dimensions.
27 ///
28 /// # Returns
29 ///
30 /// The tensor name if found, or [`None`] otherwise.
31 /// The caller must free this with `g_free()` when done.
32 #[doc(alias = "gst_analytics_modelinfo_find_tensor_name")]
33 pub fn find_tensor_name(
34 &self,
35 dir: ModelInfoTensorDirection,
36 index: usize,
37 in_tensor_name: Option<&str>,
38 data_type: TensorDataType,
39 dims: &[usize],
40 ) -> Option<glib::GString> {
41 unsafe {
42 from_glib_full(ffi::gst_analytics_modelinfo_find_tensor_name(
43 self.to_glib_none().0,
44 dir.into_glib(),
45 index,
46 in_tensor_name.to_glib_none().0,
47 data_type.into_glib(),
48 dims.len(),
49 dims.as_ptr(),
50 ))
51 }
52 }
53
54 /// Calculate normalization scales and offsets to transform input data to the target range.
55 ///
56 /// This function calculates transformation parameters to convert from the actual input data range
57 /// [input_min, input_max] to the target range expected by the model [target_min, target_max]:
58 /// `normalized_value[i] = input[i] * output_scale[i] + output_offset[i]`
59 ///
60 /// The target ranges are read from the modelinfo `ranges` field: Semicolon-separated list of
61 /// comma-separated pairs (min,max) for per-channel target ranges
62 /// (e.g., "0.0,255.0;-1.0,1.0;0.0,1.0" for RGB channels with different target ranges).
63 ///
64 /// Common input ranges:
65 /// - [0.0, 255.0]: 8-bit unsigned (uint8)
66 /// - [-128.0, 127.0]: 8-bit signed (int8)
67 /// - [0.0, 65535.0]: 16-bit unsigned (uint16)
68 /// - [-32768.0, 32767.0]: 16-bit signed (int16)
69 /// - [0.0, 1.0]: Normalized float
70 /// - [-1.0, 1.0]: Normalized signed float
71 ///
72 /// The number of input ranges (`num_input_ranges`) must equal the number of target ranges
73 /// in the modelinfo. The function will return FALSE if they don't match.
74 ///
75 /// The caller must free `output_scales` and `output_offsets` with `g_free()` when done.
76 /// ## `tensor_name`
77 /// The name of the tensor
78 /// ## `input_mins`
79 /// The minimum values of the actual input data for each channel
80 /// ## `input_maxs`
81 /// The maximum values of the actual input data for each channel
82 ///
83 /// # Returns
84 ///
85 /// [`true`] on success, [`false`] on error, if ranges field is not found, or if `num_input_ranges`
86 /// doesn't match the number of target ranges in the modelinfo
87 ///
88 /// ## `output_scales`
89 /// The scale values for normalization
90 ///
91 /// ## `output_offsets`
92 /// The offset values for normalization
93 #[doc(alias = "gst_analytics_modelinfo_get_input_scales_offsets")]
94 #[doc(alias = "get_input_scales_offsets")]
95 pub fn input_scales_offsets(
96 &self,
97 tensor_name: &str,
98 input_mins: &[f64],
99 input_maxs: &[f64],
100 ) -> Option<(glib::Slice<f64>, glib::Slice<f64>)> {
101 unsafe {
102 assert_eq!(input_mins.len(), input_maxs.len());
103
104 let mut num_output_ranges = 0;
105 let mut output_scales = ptr::null_mut();
106 let mut output_offsets = ptr::null_mut();
107 let res = from_glib(ffi::gst_analytics_modelinfo_get_input_scales_offsets(
108 self.to_glib_none().0,
109 tensor_name.to_glib_none().0,
110 input_mins.len(),
111 input_mins.as_ptr(),
112 input_maxs.as_ptr(),
113 &mut num_output_ranges,
114 &mut output_scales,
115 &mut output_offsets,
116 ));
117 if res {
118 Some((
119 glib::Slice::from_glib_full_num(output_scales, num_output_ranges),
120 glib::Slice::from_glib_full_num(output_offsets, num_output_ranges),
121 ))
122 } else {
123 None
124 }
125 }
126 }
127
128 /// Retrieve all target ranges (min/max pairs) expected by the model for a given tensor.
129 ///
130 /// This function retrieves all target ranges from the `ranges` field in the modelinfo.
131 /// Each range represents the expected input range for a channel or dimension that the
132 /// model requires.
133 ///
134 /// The function reads from the `ranges` field: Semicolon-separated list of
135 /// comma-separated pairs (min,max) for per-channel target ranges
136 /// (e.g., "0.0,1.0;-1.0,1.0;0.0,1.0" for RGB channels with different normalization targets).
137 ///
138 /// The caller must free `mins` and `maxs` with `g_free()` when done.
139 /// ## `tensor_name`
140 /// The name of the tensor
141 ///
142 /// # Returns
143 ///
144 /// [`true`] if range information was found and valid, [`false`] otherwise
145 ///
146 /// ## `mins`
147 /// The minimum values for each target range
148 ///
149 /// ## `maxs`
150 /// The maximum values for each target range
151 #[doc(alias = "gst_analytics_modelinfo_get_target_ranges")]
152 #[doc(alias = "get_target_ranges")]
153 pub fn target_ranges(&self, tensor_name: &str) -> Option<(glib::Slice<f64>, glib::Slice<f64>)> {
154 unsafe {
155 let mut num_ranges = 0;
156 let mut mins = ptr::null_mut();
157 let mut maxs = ptr::null_mut();
158 let res = from_glib(ffi::gst_analytics_modelinfo_get_target_ranges(
159 mut_override(self.to_glib_none().0),
160 tensor_name.to_glib_none().0,
161 &mut num_ranges,
162 &mut mins,
163 &mut maxs,
164 ));
165 if res {
166 Some((
167 glib::Slice::from_glib_full_num(mins, num_ranges),
168 glib::Slice::from_glib_full_num(maxs, num_ranges),
169 ))
170 } else {
171 None
172 }
173 }
174 }
175}