Skip to main content

renderbox_sdk/
expr.rs

1use std::fmt;
2use std::ops;
3
/// Named variables that may appear in an expression.
///
/// Each variant renders as its own upper-case name (see the `Display` impl).
/// NOTE(review): these look like per-frame/stream variables of the target
/// expression evaluator (timestamps, dimensions, ...) — confirm their exact
/// meaning against that evaluator before documenting individual variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Var {
    PTS,
    N,
    T,
    W,
    H,
    SW,
    SH,
    A,
}

impl fmt::Display for Var {
    /// Writes the variable's token.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let s = match self {
            Var::PTS => "PTS",
            Var::N => "N",
            Var::T => "T",
            Var::W => "W",
            Var::H => "H",
            Var::SW => "SW",
            Var::SH => "SH",
            Var::A => "A",
        };
        // `pad` (unlike `write_str`) honours width/fill/alignment flags such as `{:>5}`;
        // with a plain `{}` the output is unchanged.
        f.pad(s)
    }
}
31
/// Two-operand operators.
///
/// Arithmetic operators render as symbols (`+ - * /`); comparisons render as
/// function names (`gt`, `gte`, `lt`, `lte`, `eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Op {
    Add,
    Sub,
    Mul,
    Div,
    Gt,
    Gte,
    Lt,
    Lte,
    Eq,
}

impl fmt::Display for Op {
    /// Writes the operator's token.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let s = match self {
            Op::Add => "+",
            Op::Sub => "-",
            Op::Mul => "*",
            Op::Div => "/",
            Op::Gt => "gt",
            Op::Gte => "gte",
            Op::Lt => "lt",
            Op::Lte => "lte",
            Op::Eq => "eq",
        };
        // `pad` respects width/fill/alignment flags; plain `{}` output is unchanged.
        f.pad(s)
    }
}
61
/// Single-argument functions; each renders as its lower-case name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    Sin,
    Cos,
    Tan,
    Abs,
    Sqrt,
    Ceil,
    Floor,
    Round,
    Log,
    Exp,
    Not,
}

impl fmt::Display for UnaryOp {
    /// Writes the function name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let s = match self {
            UnaryOp::Sin => "sin",
            UnaryOp::Cos => "cos",
            UnaryOp::Tan => "tan",
            UnaryOp::Abs => "abs",
            UnaryOp::Sqrt => "sqrt",
            UnaryOp::Ceil => "ceil",
            UnaryOp::Floor => "floor",
            UnaryOp::Round => "round",
            UnaryOp::Log => "log",
            UnaryOp::Exp => "exp",
            UnaryOp::Not => "not",
        };
        // `pad` respects width/fill/alignment flags; plain `{}` output is unchanged.
        f.pad(s)
    }
}
95
96#[derive(Debug, Clone)]
97pub enum Expr {
98    Lit(f64),
99    Var(Var),
100    BinOp {
101        op: Op,
102        left: Box<Expr>,
103        right: Box<Expr>,
104    },
105    UnaryOp {
106        op: UnaryOp,
107        arg: Box<Expr>,
108    },
109    If {
110        cond: Box<Expr>,
111        then: Box<Expr>,
112        else_: Box<Expr>,
113    },
114    Fn {
115        name: String,
116        args: Vec<Expr>,
117    },
118}
119
/// Formats a number for inclusion in expression text.
///
/// Integral values with magnitude below `1e15` are printed through `i64`, so
/// they carry no fractional part and `-0.0` becomes `0`. All other values
/// (fractional, huge, infinite, NaN) use `f64`'s standard `Display` output.
fn format_number(n: f64) -> String {
    let is_small_integral = n.fract() == 0.0 && n.abs() < 1e15;
    if !is_small_integral {
        return n.to_string();
    }
    (n as i64).to_string()
}
127
128impl Expr {
129    pub fn lit(n: f64) -> Expr {
130        Expr::Lit(n)
131    }
132
133    pub fn var(v: Var) -> Expr {
134        Expr::Var(v)
135    }
136
137    pub fn gt(self, other: impl Into<Expr>) -> Expr {
138        Expr::BinOp {
139            op: Op::Gt,
140            left: Box::new(self),
141            right: Box::new(other.into()),
142        }
143    }
144
145    pub fn gte(self, other: impl Into<Expr>) -> Expr {
146        Expr::BinOp {
147            op: Op::Gte,
148            left: Box::new(self),
149            right: Box::new(other.into()),
150        }
151    }
152
153    pub fn lt(self, other: impl Into<Expr>) -> Expr {
154        Expr::BinOp {
155            op: Op::Lt,
156            left: Box::new(self),
157            right: Box::new(other.into()),
158        }
159    }
160
161    pub fn lte(self, other: impl Into<Expr>) -> Expr {
162        Expr::BinOp {
163            op: Op::Lte,
164            left: Box::new(self),
165            right: Box::new(other.into()),
166        }
167    }
168
169    pub fn eq(self, other: impl Into<Expr>) -> Expr {
170        Expr::BinOp {
171            op: Op::Eq,
172            left: Box::new(self),
173            right: Box::new(other.into()),
174        }
175    }
176
177    pub fn if_then_else(self, then: impl Into<Expr>, else_: impl Into<Expr>) -> Expr {
178        Expr::If {
179            cond: Box::new(self),
180            then: Box::new(then.into()),
181            else_: Box::new(else_.into()),
182        }
183    }
184
185    pub fn sin(self) -> Expr {
186        Expr::UnaryOp {
187            op: UnaryOp::Sin,
188            arg: Box::new(self),
189        }
190    }
191
192    pub fn cos(self) -> Expr {
193        Expr::UnaryOp {
194            op: UnaryOp::Cos,
195            arg: Box::new(self),
196        }
197    }
198
199    pub fn tan(self) -> Expr {
200        Expr::UnaryOp {
201            op: UnaryOp::Tan,
202            arg: Box::new(self),
203        }
204    }
205
206    pub fn abs(self) -> Expr {
207        Expr::UnaryOp {
208            op: UnaryOp::Abs,
209            arg: Box::new(self),
210        }
211    }
212
213    pub fn sqrt(self) -> Expr {
214        Expr::UnaryOp {
215            op: UnaryOp::Sqrt,
216            arg: Box::new(self),
217        }
218    }
219
220    pub fn ceil(self) -> Expr {
221        Expr::UnaryOp {
222            op: UnaryOp::Ceil,
223            arg: Box::new(self),
224        }
225    }
226
227    pub fn floor(self) -> Expr {
228        Expr::UnaryOp {
229            op: UnaryOp::Floor,
230            arg: Box::new(self),
231        }
232    }
233
234    pub fn round(self) -> Expr {
235        Expr::UnaryOp {
236            op: UnaryOp::Round,
237            arg: Box::new(self),
238        }
239    }
240
241    pub fn log(self) -> Expr {
242        Expr::UnaryOp {
243            op: UnaryOp::Log,
244            arg: Box::new(self),
245        }
246    }
247
248    pub fn exp(self) -> Expr {
249        Expr::UnaryOp {
250            op: UnaryOp::Exp,
251            arg: Box::new(self),
252        }
253    }
254
255    pub fn not(self) -> Expr {
256        Expr::UnaryOp {
257            op: UnaryOp::Not,
258            arg: Box::new(self),
259        }
260    }
261
262    pub fn call(name: impl Into<String>, args: Vec<Expr>) -> Expr {
263        Expr::Fn {
264            name: name.into(),
265            args,
266        }
267    }
268
269    pub fn compile(&self) -> String {
270        match self {
271            Expr::Lit(n) => format_number(*n),
272            Expr::Var(v) => v.to_string(),
273            Expr::BinOp { op, left, right } => match op {
274                Op::Gt | Op::Gte | Op::Lt | Op::Lte | Op::Eq => {
275                    format!("{}({},{})", op, left.compile(), right.compile())
276                }
277                _ => {
278                    format!("({}{}{})", left.compile(), op, right.compile())
279                }
280            },
281            Expr::UnaryOp { op, arg } => {
282                format!("{}({})", op, arg.compile())
283            }
284            Expr::If { cond, then, else_ } => {
285                format!("if({},{},{})", cond.compile(), then.compile(), else_.compile())
286            }
287            Expr::Fn { name, args } => {
288                let args_str: Vec<String> = args.iter().map(|a| a.compile()).collect();
289                format!("{}({})", name, args_str.join(","))
290            }
291        }
292    }
293}
294
295impl From<f64> for Expr {
296    fn from(n: f64) -> Self {
297        Expr::Lit(n)
298    }
299}
300
301impl From<i32> for Expr {
302    fn from(n: i32) -> Self {
303        Expr::Lit(n as f64)
304    }
305}
306
307impl From<u32> for Expr {
308    fn from(n: u32) -> Self {
309        Expr::Lit(n as f64)
310    }
311}
312
313impl ops::Add for Expr {
314    type Output = Expr;
315    fn add(self, rhs: Expr) -> Expr {
316        Expr::BinOp {
317            op: Op::Add,
318            left: Box::new(self),
319            right: Box::new(rhs),
320        }
321    }
322}
323
324impl ops::Add<f64> for Expr {
325    type Output = Expr;
326    fn add(self, rhs: f64) -> Expr {
327        self + Expr::Lit(rhs)
328    }
329}
330
331impl ops::Add<Expr> for f64 {
332    type Output = Expr;
333    fn add(self, rhs: Expr) -> Expr {
334        Expr::Lit(self) + rhs
335    }
336}
337
338impl ops::Sub for Expr {
339    type Output = Expr;
340    fn sub(self, rhs: Expr) -> Expr {
341        Expr::BinOp {
342            op: Op::Sub,
343            left: Box::new(self),
344            right: Box::new(rhs),
345        }
346    }
347}
348
349impl ops::Sub<f64> for Expr {
350    type Output = Expr;
351    fn sub(self, rhs: f64) -> Expr {
352        self - Expr::Lit(rhs)
353    }
354}
355
356impl ops::Sub<Expr> for f64 {
357    type Output = Expr;
358    fn sub(self, rhs: Expr) -> Expr {
359        Expr::Lit(self) - rhs
360    }
361}
362
363impl ops::Mul for Expr {
364    type Output = Expr;
365    fn mul(self, rhs: Expr) -> Expr {
366        Expr::BinOp {
367            op: Op::Mul,
368            left: Box::new(self),
369            right: Box::new(rhs),
370        }
371    }
372}
373
374impl ops::Mul<f64> for Expr {
375    type Output = Expr;
376    fn mul(self, rhs: f64) -> Expr {
377        self * Expr::Lit(rhs)
378    }
379}
380
381impl ops::Mul<Expr> for f64 {
382    type Output = Expr;
383    fn mul(self, rhs: Expr) -> Expr {
384        Expr::Lit(self) * rhs
385    }
386}
387
388impl ops::Div for Expr {
389    type Output = Expr;
390    fn div(self, rhs: Expr) -> Expr {
391        Expr::BinOp {
392            op: Op::Div,
393            left: Box::new(self),
394            right: Box::new(rhs),
395        }
396    }
397}
398
399impl ops::Div<f64> for Expr {
400    type Output = Expr;
401    fn div(self, rhs: f64) -> Expr {
402        self / Expr::Lit(rhs)
403    }
404}
405
406impl ops::Div<Expr> for f64 {
407    type Output = Expr;
408    fn div(self, rhs: Expr) -> Expr {
409        Expr::Lit(self) / rhs
410    }
411}
412
413#[derive(Debug, Clone)]
414pub enum NumberOrExpr {
415    Number(f64),
416    Expr(Expr),
417}
418
419impl NumberOrExpr {
420    pub fn compile(&self) -> String {
421        match self {
422            NumberOrExpr::Number(n) => format_number(*n),
423            NumberOrExpr::Expr(e) => e.compile(),
424        }
425    }
426}
427
428impl From<f64> for NumberOrExpr {
429    fn from(n: f64) -> Self {
430        NumberOrExpr::Number(n)
431    }
432}
433
434impl From<i32> for NumberOrExpr {
435    fn from(n: i32) -> Self {
436        NumberOrExpr::Number(n as f64)
437    }
438}
439
440impl From<u32> for NumberOrExpr {
441    fn from(n: u32) -> Self {
442        NumberOrExpr::Number(n as f64)
443    }
444}
445
446impl From<Expr> for NumberOrExpr {
447    fn from(e: Expr) -> Self {
448        NumberOrExpr::Expr(e)
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    // Float literals below deliberately avoid approximations of well-known
    // constants (e.g. 3.14 ~ PI): clippy's deny-by-default `approx_constant`
    // lint rejects them under `cargo clippy --all-targets`.

    #[test]
    fn literal_integer() {
        assert_eq!(Expr::lit(42.0).compile(), "42");
    }

    #[test]
    fn literal_float() {
        assert_eq!(Expr::lit(1.5).compile(), "1.5");
    }

    #[test]
    fn variable() {
        assert_eq!(Expr::var(Var::PTS).compile(), "PTS");
        assert_eq!(Expr::var(Var::T).compile(), "T");
        assert_eq!(Expr::var(Var::N).compile(), "N");
        assert_eq!(Expr::var(Var::W).compile(), "W");
        assert_eq!(Expr::var(Var::H).compile(), "H");
        assert_eq!(Expr::var(Var::SW).compile(), "SW");
        assert_eq!(Expr::var(Var::SH).compile(), "SH");
        assert_eq!(Expr::var(Var::A).compile(), "A");
    }

    #[test]
    fn binop_mul() {
        let e = Expr::var(Var::PTS) * 2.0;
        assert_eq!(e.compile(), "(PTS*2)");
    }

    #[test]
    fn binop_add() {
        let e = Expr::var(Var::T) + 1.0;
        assert_eq!(e.compile(), "(T+1)");
    }

    #[test]
    fn binop_sub() {
        let e = Expr::var(Var::N) - 10.0;
        assert_eq!(e.compile(), "(N-10)");
    }

    #[test]
    fn binop_div() {
        let e = Expr::var(Var::W) / 2.0;
        assert_eq!(e.compile(), "(W/2)");
    }

    #[test]
    fn nested_expression() {
        let e = Expr::var(Var::PTS) * 2.0 + 1.0;
        assert_eq!(e.compile(), "((PTS*2)+1)");
    }

    #[test]
    fn reverse_add() {
        let e = 1.0 + Expr::var(Var::T);
        assert_eq!(e.compile(), "(1+T)");
    }

    #[test]
    fn reverse_mul() {
        let e = 2.0 * Expr::var(Var::T);
        assert_eq!(e.compile(), "(2*T)");
    }

    #[test]
    fn reverse_sub() {
        let e = 10.0 - Expr::var(Var::N);
        assert_eq!(e.compile(), "(10-N)");
    }

    #[test]
    fn reverse_div() {
        let e = 1.0 / Expr::var(Var::T);
        assert_eq!(e.compile(), "(1/T)");
    }

    #[test]
    fn comparison_gt() {
        let e = Expr::var(Var::N).gt(100.0);
        assert_eq!(e.compile(), "gt(N,100)");
    }

    #[test]
    fn comparison_gte() {
        let e = Expr::var(Var::T).gte(5.0);
        assert_eq!(e.compile(), "gte(T,5)");
    }

    #[test]
    fn comparison_lt() {
        let e = Expr::var(Var::W).lt(1920.0);
        assert_eq!(e.compile(), "lt(W,1920)");
    }

    #[test]
    fn comparison_lte() {
        let e = Expr::var(Var::H).lte(1080.0);
        assert_eq!(e.compile(), "lte(H,1080)");
    }

    #[test]
    fn comparison_eq() {
        let e = Expr::var(Var::N).eq(0.0);
        assert_eq!(e.compile(), "eq(N,0)");
    }

    #[test]
    fn if_then_else() {
        let cond = Expr::var(Var::N).gt(100.0);
        let e = cond.if_then_else(1.0, 0.0);
        assert_eq!(e.compile(), "if(gt(N,100),1,0)");
    }

    #[test]
    fn nested_if() {
        let e = Expr::var(Var::T)
            .gt(5.0)
            .if_then_else(Expr::var(Var::PTS) * 2.0, Expr::var(Var::PTS));
        assert_eq!(e.compile(), "if(gt(T,5),(PTS*2),PTS)");
    }

    #[test]
    fn unary_sin() {
        let e = Expr::var(Var::T).sin();
        assert_eq!(e.compile(), "sin(T)");
    }

    #[test]
    fn unary_cos() {
        let e = (Expr::var(Var::T) * 2.0).cos();
        assert_eq!(e.compile(), "cos((T*2))");
    }

    #[test]
    fn unary_abs() {
        let e = (Expr::var(Var::N) - 50.0).abs();
        assert_eq!(e.compile(), "abs((N-50))");
    }

    #[test]
    fn unary_not() {
        let e = Expr::var(Var::N).gt(1.0).not();
        assert_eq!(e.compile(), "not(gt(N,1))");
    }

    #[test]
    fn fn_between() {
        let e = Expr::call(
            "between",
            vec![Expr::var(Var::T), Expr::lit(5.0), Expr::lit(10.0)],
        );
        assert_eq!(e.compile(), "between(T,5,10)");
    }

    #[test]
    fn fn_clip() {
        let e = Expr::call(
            "clip",
            vec![Expr::var(Var::N), Expr::lit(0.0), Expr::lit(255.0)],
        );
        assert_eq!(e.compile(), "clip(N,0,255)");
    }

    #[test]
    fn fn_without_args() {
        let e = Expr::call("f", Vec::new());
        assert_eq!(e.compile(), "f()");
    }

    #[test]
    fn complex_expression() {
        // if(gt(T,5), sin(T*2.5)*W/2, 0)
        let e = Expr::var(Var::T).gt(5.0).if_then_else(
            (Expr::var(Var::T) * 2.5).sin() * Expr::var(Var::W) / 2.0,
            0.0,
        );
        assert_eq!(e.compile(), "if(gt(T,5),((sin((T*2.5))*W)/2),0)");
    }

    #[test]
    fn number_or_expr_from_f64() {
        let n: NumberOrExpr = 42.0.into();
        assert_eq!(n.compile(), "42");
    }

    #[test]
    fn number_or_expr_from_i32() {
        let n: NumberOrExpr = 42i32.into();
        assert_eq!(n.compile(), "42");
    }

    #[test]
    fn number_or_expr_from_u32() {
        let n: NumberOrExpr = 42u32.into();
        assert_eq!(n.compile(), "42");
    }

    #[test]
    fn number_or_expr_from_expr() {
        let n: NumberOrExpr = Expr::var(Var::PTS).into();
        assert_eq!(n.compile(), "PTS");
    }

    #[test]
    fn number_or_expr_float() {
        let n: NumberOrExpr = 2.75.into();
        assert_eq!(n.compile(), "2.75");
    }

    #[test]
    fn expr_expr_add() {
        let e = Expr::var(Var::W) + Expr::var(Var::H);
        assert_eq!(e.compile(), "(W+H)");
    }

    #[test]
    fn var_display() {
        assert_eq!(format!("{}", Var::PTS), "PTS");
        assert_eq!(format!("{}", Var::SW), "SW");
    }

    #[test]
    fn negative_literal() {
        assert_eq!(Expr::lit(-1.0).compile(), "-1");
    }

    #[test]
    fn zero_literal() {
        assert_eq!(Expr::lit(0.0).compile(), "0");
    }

    #[test]
    fn negative_zero_literal() {
        // The integral path goes through i64, so the sign of zero is dropped.
        assert_eq!(Expr::lit(-0.0).compile(), "0");
    }

    #[test]
    fn large_integral_literal() {
        assert_eq!(Expr::lit(1e20).compile(), "100000000000000000000");
    }

    #[test]
    fn from_i32_into_expr() {
        let e: Expr = 42i32.into();
        assert_eq!(e.compile(), "42");
    }

    #[test]
    fn from_u32_into_expr() {
        let e: Expr = 10u32.into();
        assert_eq!(e.compile(), "10");
    }

    #[test]
    fn unary_chain() {
        let e = Expr::var(Var::T).sin().abs();
        assert_eq!(e.compile(), "abs(sin(T))");
    }
}