Skip to main content

renderbox_sdk/ops/
ai.rs

1use std::collections::BTreeMap;
2
3use renderbox_dsl::{OpNode, ParamValue};
4
5use crate::ops::params;
6use crate::sorts::{Audio, Density, Depth, Detection, Pose, Scene, Segmentation, Sort, Text, Transcript, Video};
7use crate::graph::Stream;
8
9// ---------------------------------------------------------------------------
10// Helpers
11// ---------------------------------------------------------------------------
12
13fn add_fanout<In: Sort, Out: Sort>(
14    input: &Stream<In>,
15    op: &str,
16    params: BTreeMap<String, ParamValue>,
17) -> Stream<Out> {
18    let node_id = input.graph.borrow_mut().add_node(OpNode::Fanout {
19        op: op.into(),
20        params,
21        input: input.node_id,
22    });
23    Stream::new(node_id, input.graph.clone())
24}
25
26fn add_transform<S: Sort>(
27    input: Stream<S>,
28    op: &str,
29    params: BTreeMap<String, ParamValue>,
30) -> Stream<S> {
31    let node_id = input.graph.borrow_mut().add_node(OpNode::Transform {
32        op: op.into(),
33        params,
34        input: input.node_id,
35    });
36    Stream::new(node_id, input.graph)
37}
38
39fn add_combine<A: Sort, B: Sort, Out: Sort>(
40    a: Stream<A>,
41    b: &Stream<B>,
42    op: &str,
43    params: BTreeMap<String, ParamValue>,
44) -> Stream<Out> {
45    assert!(
46        std::rc::Rc::ptr_eq(&a.graph, &b.graph),
47        "streams must belong to the same graph"
48    );
49    let node_id = a.graph.borrow_mut().add_node(OpNode::Combine2 {
50        op: op.into(),
51        a: a.node_id,
52        b: b.node_id,
53        params,
54    });
55    Stream::new(node_id, a.graph)
56}
57
58// ---------------------------------------------------------------------------
59// Detection
60// ---------------------------------------------------------------------------
61
62pub fn detect(v: &Stream<Video>, model: &str) -> Stream<Detection> {
63    add_fanout(v, "detect", params! { "model" => model })
64}
65
/// Tracking algorithm selectable when calling [`track`].
///
/// Derives `PartialEq`, `Eq` and `Hash` in addition to `Debug`/`Clone`/`Copy`
/// so values can be compared in assertions and used as set or map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TrackAlgorithm {
    /// Passed to the `track` op as `"bytetrack"`.
    ByteTrack,
    /// Passed to the `track` op as `"sort"`.
    Sort,
    /// Passed to the `track` op as `"deepsort"`.
    DeepSort,
}
72
73impl TrackAlgorithm {
74    fn as_str(&self) -> &'static str {
75        match self {
76            TrackAlgorithm::ByteTrack => "bytetrack",
77            TrackAlgorithm::Sort => "sort",
78            TrackAlgorithm::DeepSort => "deepsort",
79        }
80    }
81}
82
83pub fn track(d: Stream<Detection>, algo: TrackAlgorithm) -> Stream<Detection> {
84    add_transform(d, "track", params! { "algorithm" => algo.as_str() })
85}
86
/// How detected regions are obscured by [`redact`].
///
/// Derives `PartialEq`, `Eq` and `Hash` in addition to `Debug`/`Clone` so
/// modes can be compared in assertions and used as set or map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RedactMode {
    /// Blur mode; `radius` is forwarded to the `redact` op.
    Blur { radius: u32 },
    /// Pixelate mode; `size` is forwarded to the `redact` op.
    Pixelate { size: u32 },
    /// Black mode; carries no extra parameter.
    Black,
}
93
94impl RedactMode {
95    fn to_params(&self) -> BTreeMap<String, ParamValue> {
96        match self {
97            RedactMode::Blur { radius } => {
98                params! { "mode" => "blur", "radius" => *radius }
99            }
100            RedactMode::Pixelate { size } => {
101                params! { "mode" => "pixelate", "radius" => *size }
102            }
103            RedactMode::Black => {
104                params! { "mode" => "black" }
105            }
106        }
107    }
108}
109
110pub fn redact(
111    d: Stream<Detection>,
112    mode: RedactMode,
113) -> impl FnOnce(Stream<Video>) -> Stream<Video> {
114    move |input| add_combine(input, &d, "redact", mode.to_params())
115}
116
117pub fn annotate(d: Stream<Detection>) -> impl FnOnce(Stream<Video>) -> Stream<Video> {
118    move |input| add_combine(input, &d, "annotate", params! {})
119}
120
121pub fn filter_detections(
122    d: Stream<Detection>,
123    classes: &str,
124) -> Stream<Detection> {
125    let classes = classes.to_string();
126    add_transform(d, "filter_detections", params! { "classes" => classes.as_str() })
127}
128
129// ---------------------------------------------------------------------------
130// Segmentation
131// ---------------------------------------------------------------------------
132
133pub fn segment(v: &Stream<Video>, model: &str) -> Stream<Segmentation> {
134    add_fanout(v, "segment", params! { "model" => model })
135}
136
137pub fn mask_composite(
138    seg: Stream<Segmentation>,
139) -> impl FnOnce(Stream<Video>) -> Stream<Video> {
140    move |input| add_combine(input, &seg, "mask_composite", params! {})
141}
142
143pub fn mask_blur(
144    seg: Stream<Segmentation>,
145    radius: f64,
146) -> impl FnOnce(Stream<Video>) -> Stream<Video> {
147    move |input| add_combine(input, &seg, "mask_blur", params! { "radius" => radius })
148}
149
150// ---------------------------------------------------------------------------
151// Pose
152// ---------------------------------------------------------------------------
153
154pub fn estimate_pose(v: &Stream<Video>, model: &str) -> Stream<Pose> {
155    add_fanout(v, "estimate_pose", params! { "model" => model })
156}
157
158pub fn draw_skeleton(p: Stream<Pose>) -> impl FnOnce(Stream<Video>) -> Stream<Video> {
159    move |input| add_combine(input, &p, "draw_skeleton", params! {})
160}
161
162// ---------------------------------------------------------------------------
163// Depth
164// ---------------------------------------------------------------------------
165
166pub fn estimate_depth(v: &Stream<Video>, model: &str) -> Stream<Depth> {
167    add_fanout(v, "estimate_depth", params! { "model" => model })
168}
169
170pub fn depth_blur(
171    d: Stream<Depth>,
172    sigma: f64,
173) -> impl FnOnce(Stream<Video>) -> Stream<Video> {
174    move |input| add_combine(input, &d, "depth_blur", params! { "sigma" => sigma })
175}
176
177pub fn depth_fog(
178    d: Stream<Depth>,
179    density: f64,
180) -> impl FnOnce(Stream<Video>) -> Stream<Video> {
181    move |input| add_combine(input, &d, "depth_fog", params! { "density" => density })
182}
183
184// ---------------------------------------------------------------------------
185// Transcription
186// ---------------------------------------------------------------------------
187
188pub fn transcribe(a: &Stream<Audio>, model: &str) -> Stream<Transcript> {
189    add_fanout(a, "transcribe", params! { "model" => model })
190}
191
192pub fn diarize(t: Stream<Transcript>) -> Stream<Transcript> {
193    add_transform(t, "diarize", params! {})
194}
195
196pub fn burn_subtitles(t: Stream<Transcript>) -> impl FnOnce(Stream<Video>) -> Stream<Video> {
197    move |input| add_combine(input, &t, "burn_subtitles", params! {})
198}
199
200// ---------------------------------------------------------------------------
201// OCR
202// ---------------------------------------------------------------------------
203
204pub fn recognize_text(v: &Stream<Video>, model: &str) -> Stream<Text> {
205    add_fanout(v, "recognize_text", params! { "model" => model })
206}
207
208pub fn redact_text(t: Stream<Text>) -> impl FnOnce(Stream<Video>) -> Stream<Video> {
209    move |input| add_combine(input, &t, "redact_text", params! {})
210}
211
212// ---------------------------------------------------------------------------
213// Scene / Density
214// ---------------------------------------------------------------------------
215
216pub fn classify_scene(v: &Stream<Video>, model: &str) -> Stream<Scene> {
217    add_fanout(v, "classify_scene", params! { "model" => model })
218}
219
220pub fn estimate_density(v: &Stream<Video>, model: &str) -> Stream<Density> {
221    add_fanout(v, "estimate_density", params! { "model" => model })
222}
223
224// ---------------------------------------------------------------------------
225// Tests
226// ---------------------------------------------------------------------------
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::input;

    /// `detect` appends a `Fanout` node carrying the op name and the model.
    #[test]
    fn detect_creates_fanout_node() {
        let (video, _audio) = input("test.mp4");
        let detections = detect(&video, "yolov8");

        let graph = detections.graph.borrow();
        let node = graph.arena.get(detections.node_id);
        match node {
            OpNode::Fanout { op, params, .. } => {
                assert_eq!(op, "detect");
                assert_eq!(params["model"], ParamValue::String("yolov8".into()));
            }
            other => panic!("expected Fanout, got {:?}", other),
        }
    }

    /// `track` appends a `Transform` node carrying the algorithm name.
    #[test]
    fn track_creates_transform_node() {
        let (video, _audio) = input("test.mp4");
        let tracked = track(detect(&video, "yolov8"), TrackAlgorithm::ByteTrack);

        let graph = tracked.graph.borrow();
        let node = graph.arena.get(tracked.node_id);
        match node {
            OpNode::Transform { op, params, .. } => {
                assert_eq!(op, "track");
                let expected = ParamValue::String("bytetrack".into());
                assert_eq!(params["algorithm"], expected);
            }
            other => panic!("expected Transform, got {:?}", other),
        }
    }

    /// Piping a video through `redact` appends a `Combine2` node with the
    /// blur parameters.
    #[test]
    fn redact_creates_combine2_node() {
        let (video, _audio) = input("test.mp4");
        let detections = detect(&video, "yolov8");
        let redacted = video.pipe(redact(detections, RedactMode::Blur { radius: 20 }));

        let graph = redacted.graph.borrow();
        let node = graph.arena.get(redacted.node_id);
        match node {
            OpNode::Combine2 { op, params, .. } => {
                assert_eq!(op, "redact");
                assert_eq!(params["mode"], ParamValue::String("blur".into()));
                assert_eq!(params["radius"], ParamValue::Number(20.0));
            }
            other => panic!("expected Combine2, got {:?}", other),
        }
    }

    /// Fanning out from a video stream leaves the original handle untouched.
    #[test]
    fn detection_pipeline_original_video_preserved() {
        let (video, _audio) = input("test.mp4");
        let original_node = video.node_id;
        let detections = detect(&video, "yolov8");

        // The borrowed video stream still points at the original input node.
        assert_eq!(video.node_id, original_node);
        // The detection stream points at the newly created fan-out node.
        assert_ne!(detections.node_id, video.node_id);
    }

    /// Detect, track, redact and scale, then build a non-empty output.
    #[test]
    fn full_detection_pipeline() {
        let (video, audio) = input("interview.mp4");
        let faces = track(detect(&video, "yolov8-face"), TrackAlgorithm::ByteTrack);
        let redacted = video.pipe(redact(faces, RedactMode::Blur { radius: 25 }));
        let scaled = redacted.pipe(crate::ops::video::scale(1920, 1080));

        let bytes = crate::graph::output("out.mp4", scaled, audio).build().unwrap();
        assert!(!bytes.is_empty());
    }

    /// Depth estimation feeding `depth_blur` builds a non-empty output.
    #[test]
    fn depth_pipeline() {
        let (video, audio) = input("test.mp4");
        let depth = estimate_depth(&video, "midas-v2");
        let blurred = video.pipe(depth_blur(depth, 5.0));

        let bytes = crate::graph::output("out.mp4", blurred, audio).build().unwrap();
        assert!(!bytes.is_empty());
    }

    /// Transcription feeding `burn_subtitles` builds a non-empty output.
    #[test]
    fn transcription_pipeline() {
        let (video, audio) = input("test.mp4");
        let transcript = transcribe(&audio, "whisper-base");
        let subtitled = video.pipe(burn_subtitles(transcript));

        let bytes = crate::graph::output("out.mp4", subtitled, audio).build().unwrap();
        assert!(!bytes.is_empty());
    }
}