1use std::collections::BTreeMap;
2
3use renderbox_dsl::{OpNode, ParamValue};
4
5use crate::ops::params;
6use crate::sorts::{Audio, Density, Depth, Detection, Pose, Scene, Segmentation, Sort, Text, Transcript, Video};
7use crate::graph::Stream;
8
9fn add_fanout<In: Sort, Out: Sort>(
14 input: &Stream<In>,
15 op: &str,
16 params: BTreeMap<String, ParamValue>,
17) -> Stream<Out> {
18 let node_id = input.graph.borrow_mut().add_node(OpNode::Fanout {
19 op: op.into(),
20 params,
21 input: input.node_id,
22 });
23 Stream::new(node_id, input.graph.clone())
24}
25
26fn add_transform<S: Sort>(
27 input: Stream<S>,
28 op: &str,
29 params: BTreeMap<String, ParamValue>,
30) -> Stream<S> {
31 let node_id = input.graph.borrow_mut().add_node(OpNode::Transform {
32 op: op.into(),
33 params,
34 input: input.node_id,
35 });
36 Stream::new(node_id, input.graph)
37}
38
39fn add_combine<A: Sort, B: Sort, Out: Sort>(
40 a: Stream<A>,
41 b: &Stream<B>,
42 op: &str,
43 params: BTreeMap<String, ParamValue>,
44) -> Stream<Out> {
45 assert!(
46 std::rc::Rc::ptr_eq(&a.graph, &b.graph),
47 "streams must belong to the same graph"
48 );
49 let node_id = a.graph.borrow_mut().add_node(OpNode::Combine2 {
50 op: op.into(),
51 a: a.node_id,
52 b: b.node_id,
53 params,
54 });
55 Stream::new(node_id, a.graph)
56}
57
58pub fn detect(v: &Stream<Video>, model: &str) -> Stream<Detection> {
63 add_fanout(v, "detect", params! { "model" => model })
64}
65
/// Tracking algorithm selector passed to [`track`].
///
/// Each variant is written into the graph as a lowercase identifier (see
/// `as_str`). The type is a plain field-less enum, so it also derives
/// `PartialEq`, `Eq` and `Hash` to let callers compare selections and use
/// them as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TrackAlgorithm {
    /// Emitted as `"bytetrack"`.
    ByteTrack,
    /// Emitted as `"sort"`.
    Sort,
    /// Emitted as `"deepsort"`.
    DeepSort,
}

impl TrackAlgorithm {
    /// Returns the identifier stored in the `"algorithm"` parameter of a
    /// `"track"` node.
    fn as_str(&self) -> &'static str {
        match self {
            Self::ByteTrack => "bytetrack",
            Self::Sort => "sort",
            Self::DeepSort => "deepsort",
        }
    }
}
82
83pub fn track(d: Stream<Detection>, algo: TrackAlgorithm) -> Stream<Detection> {
84 add_transform(d, "track", params! { "algorithm" => algo.as_str() })
85}
86
/// How [`redact`] should obscure the detected regions.
///
/// All payloads are plain `u32` values, so the type derives `PartialEq`,
/// `Eq` and `Hash` in addition to `Debug` and `Clone`; this lets callers and
/// tests compare modes directly.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RedactMode {
    /// Blur mode, parameterised by `radius`.
    Blur { radius: u32 },
    /// Pixelate mode, parameterised by `size`.
    Pixelate { size: u32 },
    /// Black mode; carries no extra parameter.
    Black,
}
93
94impl RedactMode {
95 fn to_params(&self) -> BTreeMap<String, ParamValue> {
96 match self {
97 RedactMode::Blur { radius } => {
98 params! { "mode" => "blur", "radius" => *radius }
99 }
100 RedactMode::Pixelate { size } => {
101 params! { "mode" => "pixelate", "radius" => *size }
102 }
103 RedactMode::Black => {
104 params! { "mode" => "black" }
105 }
106 }
107 }
108}
109
110pub fn redact(
111 d: Stream<Detection>,
112 mode: RedactMode,
113) -> impl FnOnce(Stream<Video>) -> Stream<Video> {
114 move |input| add_combine(input, &d, "redact", mode.to_params())
115}
116
117pub fn annotate(d: Stream<Detection>) -> impl FnOnce(Stream<Video>) -> Stream<Video> {
118 move |input| add_combine(input, &d, "annotate", params! {})
119}
120
121pub fn filter_detections(
122 d: Stream<Detection>,
123 classes: &str,
124) -> Stream<Detection> {
125 let classes = classes.to_string();
126 add_transform(d, "filter_detections", params! { "classes" => classes.as_str() })
127}
128
129pub fn segment(v: &Stream<Video>, model: &str) -> Stream<Segmentation> {
134 add_fanout(v, "segment", params! { "model" => model })
135}
136
137pub fn mask_composite(
138 seg: Stream<Segmentation>,
139) -> impl FnOnce(Stream<Video>) -> Stream<Video> {
140 move |input| add_combine(input, &seg, "mask_composite", params! {})
141}
142
143pub fn mask_blur(
144 seg: Stream<Segmentation>,
145 radius: f64,
146) -> impl FnOnce(Stream<Video>) -> Stream<Video> {
147 move |input| add_combine(input, &seg, "mask_blur", params! { "radius" => radius })
148}
149
150pub fn estimate_pose(v: &Stream<Video>, model: &str) -> Stream<Pose> {
155 add_fanout(v, "estimate_pose", params! { "model" => model })
156}
157
158pub fn draw_skeleton(p: Stream<Pose>) -> impl FnOnce(Stream<Video>) -> Stream<Video> {
159 move |input| add_combine(input, &p, "draw_skeleton", params! {})
160}
161
162pub fn estimate_depth(v: &Stream<Video>, model: &str) -> Stream<Depth> {
167 add_fanout(v, "estimate_depth", params! { "model" => model })
168}
169
170pub fn depth_blur(
171 d: Stream<Depth>,
172 sigma: f64,
173) -> impl FnOnce(Stream<Video>) -> Stream<Video> {
174 move |input| add_combine(input, &d, "depth_blur", params! { "sigma" => sigma })
175}
176
177pub fn depth_fog(
178 d: Stream<Depth>,
179 density: f64,
180) -> impl FnOnce(Stream<Video>) -> Stream<Video> {
181 move |input| add_combine(input, &d, "depth_fog", params! { "density" => density })
182}
183
184pub fn transcribe(a: &Stream<Audio>, model: &str) -> Stream<Transcript> {
189 add_fanout(a, "transcribe", params! { "model" => model })
190}
191
192pub fn diarize(t: Stream<Transcript>) -> Stream<Transcript> {
193 add_transform(t, "diarize", params! {})
194}
195
196pub fn burn_subtitles(t: Stream<Transcript>) -> impl FnOnce(Stream<Video>) -> Stream<Video> {
197 move |input| add_combine(input, &t, "burn_subtitles", params! {})
198}
199
200pub fn recognize_text(v: &Stream<Video>, model: &str) -> Stream<Text> {
205 add_fanout(v, "recognize_text", params! { "model" => model })
206}
207
208pub fn redact_text(t: Stream<Text>) -> impl FnOnce(Stream<Video>) -> Stream<Video> {
209 move |input| add_combine(input, &t, "redact_text", params! {})
210}
211
212pub fn classify_scene(v: &Stream<Video>, model: &str) -> Stream<Scene> {
217 add_fanout(v, "classify_scene", params! { "model" => model })
218}
219
220pub fn estimate_density(v: &Stream<Video>, model: &str) -> Stream<Density> {
221 add_fanout(v, "estimate_density", params! { "model" => model })
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::input;

    /// `detect` records a `Fanout` node carrying the op name and the model.
    #[test]
    fn detect_creates_fanout_node() {
        let (video, _audio) = input("test.mp4");
        let detections = detect(&video, "yolov8");

        let graph = detections.graph.borrow();
        let node = graph.arena.get(detections.node_id);
        match node {
            OpNode::Fanout { op, params, .. } => {
                let expected_model = ParamValue::String("yolov8".into());
                assert_eq!(op, "detect");
                assert_eq!(params["model"], expected_model);
            }
            other => panic!("expected Fanout, got {:?}", other),
        }
    }

    /// `track` records a `Transform` node carrying the algorithm identifier.
    #[test]
    fn track_creates_transform_node() {
        let (video, _audio) = input("test.mp4");
        let tracked = track(detect(&video, "yolov8"), TrackAlgorithm::ByteTrack);

        let graph = tracked.graph.borrow();
        let node = graph.arena.get(tracked.node_id);
        match node {
            OpNode::Transform { op, params, .. } => {
                let expected_algorithm = ParamValue::String("bytetrack".into());
                assert_eq!(op, "track");
                assert_eq!(params["algorithm"], expected_algorithm);
            }
            other => panic!("expected Transform, got {:?}", other),
        }
    }

    /// Piping a video through `redact` records a `Combine2` node with the
    /// blur parameters.
    #[test]
    fn redact_creates_combine2_node() {
        let (video, _audio) = input("test.mp4");
        let detections = detect(&video, "yolov8");
        let blur_stage = redact(detections, RedactMode::Blur { radius: 20 });
        let redacted = video.pipe(blur_stage);

        let graph = redacted.graph.borrow();
        let node = graph.arena.get(redacted.node_id);
        match node {
            OpNode::Combine2 { op, params, .. } => {
                assert_eq!(op, "redact");
                assert_eq!(params["mode"], ParamValue::String("blur".into()));
                assert_eq!(params["radius"], ParamValue::Number(20.0));
            }
            other => panic!("expected Combine2, got {:?}", other),
        }
    }

    /// Fanning out from a video leaves the video stream pointing at its
    /// original node and gives the detections a different node.
    #[test]
    fn detection_pipeline_original_video_preserved() {
        let (video, _audio) = input("test.mp4");
        let id_before = video.node_id;

        let detections = detect(&video, "yolov8");

        assert_eq!(video.node_id, id_before);
        assert_ne!(detections.node_id, video.node_id);
    }

    /// detect -> track -> redact -> scale builds to a non-empty output.
    #[test]
    fn full_detection_pipeline() {
        let (video, audio) = input("interview.mp4");
        let faces = track(detect(&video, "yolov8-face"), TrackAlgorithm::ByteTrack);

        let video = video
            .pipe(redact(faces, RedactMode::Blur { radius: 25 }))
            .pipe(crate::ops::video::scale(1920, 1080));

        let bytes = crate::graph::output("out.mp4", video, audio)
            .build()
            .unwrap();
        assert!(!bytes.is_empty());
    }

    /// estimate_depth -> depth_blur builds to a non-empty output.
    #[test]
    fn depth_pipeline() {
        let (video, audio) = input("test.mp4");
        let depth = estimate_depth(&video, "midas-v2");

        let video = video.pipe(depth_blur(depth, 5.0));

        let bytes = crate::graph::output("out.mp4", video, audio)
            .build()
            .unwrap();
        assert!(!bytes.is_empty());
    }

    /// transcribe -> burn_subtitles builds to a non-empty output.
    #[test]
    fn transcription_pipeline() {
        let (video, audio) = input("test.mp4");
        let transcript = transcribe(&audio, "whisper-base");

        let video = video.pipe(burn_subtitles(transcript));

        let bytes = crate::graph::output("out.mp4", video, audio)
            .build()
            .unwrap();
        assert!(!bytes.is_empty());
    }
}