1use std::fmt;
2use std::ops;
3
/// A named variable that can appear in a compiled expression.
///
/// Each variant is rendered as its own identifier (`Var::PTS` becomes `PTS`).
/// NOTE(review): what each variable means is defined by the evaluator that consumes
/// the compiled string (presumably FFmpeg filter expressions), not by this module — confirm.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Var {
    PTS,
    N,
    T,
    W,
    H,
    SW,
    SH,
    A,
}
15
16impl fmt::Display for Var {
17 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
18 let s = match self {
19 Var::PTS => "PTS",
20 Var::N => "N",
21 Var::T => "T",
22 Var::W => "W",
23 Var::H => "H",
24 Var::SW => "SW",
25 Var::SH => "SH",
26 Var::A => "A",
27 };
28 f.write_str(s)
29 }
30}
31
/// Binary operator carried by an `Expr::BinOp` node.
///
/// Arithmetic operators compile to parenthesised infix form (`(a+b)`). Comparisons
/// compile to two-argument calls (`gt(a,b)`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Op {
    /// Rendered infix as `+`.
    Add,
    /// Rendered infix as `-`.
    Sub,
    /// Rendered infix as `*`.
    Mul,
    /// Rendered infix as `/`.
    Div,
    /// Rendered as `gt(a,b)`.
    Gt,
    /// Rendered as `gte(a,b)`.
    Gte,
    /// Rendered as `lt(a,b)`.
    Lt,
    /// Rendered as `lte(a,b)`.
    Lte,
    /// Rendered as `eq(a,b)`.
    Eq,
}
44
45impl fmt::Display for Op {
46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47 let s = match self {
48 Op::Add => "+",
49 Op::Sub => "-",
50 Op::Mul => "*",
51 Op::Div => "/",
52 Op::Gt => "gt",
53 Op::Gte => "gte",
54 Op::Lt => "lt",
55 Op::Lte => "lte",
56 Op::Eq => "eq",
57 };
58 f.write_str(s)
59 }
60}
61
/// Single-argument function carried by an `Expr::UnaryOp` node.
///
/// Each variant compiles to `name(arg)`, where `name` is the lowercase variant name
/// (for example, `UnaryOp::Sqrt` becomes `sqrt(...)`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum UnaryOp {
    Sin,
    Cos,
    Tan,
    Abs,
    Sqrt,
    Ceil,
    Floor,
    Round,
    Log,
    Exp,
    Not,
}
76
77impl fmt::Display for UnaryOp {
78 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79 let s = match self {
80 UnaryOp::Sin => "sin",
81 UnaryOp::Cos => "cos",
82 UnaryOp::Tan => "tan",
83 UnaryOp::Abs => "abs",
84 UnaryOp::Sqrt => "sqrt",
85 UnaryOp::Ceil => "ceil",
86 UnaryOp::Floor => "floor",
87 UnaryOp::Round => "round",
88 UnaryOp::Log => "log",
89 UnaryOp::Exp => "exp",
90 UnaryOp::Not => "not",
91 };
92 f.write_str(s)
93 }
94}
95
96#[derive(Debug, Clone)]
97pub enum Expr {
98 Lit(f64),
99 Var(Var),
100 BinOp {
101 op: Op,
102 left: Box<Expr>,
103 right: Box<Expr>,
104 },
105 UnaryOp {
106 op: UnaryOp,
107 arg: Box<Expr>,
108 },
109 If {
110 cond: Box<Expr>,
111 then: Box<Expr>,
112 else_: Box<Expr>,
113 },
114 Fn {
115 name: String,
116 args: Vec<Expr>,
117 },
118}
119
/// Renders a number for use inside a compiled expression.
///
/// Whole numbers below `1e15` in magnitude go through an `i64` cast. This keeps the cast
/// in range and normalises `-0.0` to `"0"`. Every other value uses `f64`'s `Display`
/// output.
fn format_number(n: f64) -> String {
    let integral = n.fract() == 0.0 && n.abs() < 1e15;
    if integral {
        (n as i64).to_string()
    } else {
        n.to_string()
    }
}
127
128impl Expr {
129 pub fn lit(n: f64) -> Expr {
130 Expr::Lit(n)
131 }
132
133 pub fn var(v: Var) -> Expr {
134 Expr::Var(v)
135 }
136
137 pub fn gt(self, other: impl Into<Expr>) -> Expr {
138 Expr::BinOp {
139 op: Op::Gt,
140 left: Box::new(self),
141 right: Box::new(other.into()),
142 }
143 }
144
145 pub fn gte(self, other: impl Into<Expr>) -> Expr {
146 Expr::BinOp {
147 op: Op::Gte,
148 left: Box::new(self),
149 right: Box::new(other.into()),
150 }
151 }
152
153 pub fn lt(self, other: impl Into<Expr>) -> Expr {
154 Expr::BinOp {
155 op: Op::Lt,
156 left: Box::new(self),
157 right: Box::new(other.into()),
158 }
159 }
160
161 pub fn lte(self, other: impl Into<Expr>) -> Expr {
162 Expr::BinOp {
163 op: Op::Lte,
164 left: Box::new(self),
165 right: Box::new(other.into()),
166 }
167 }
168
169 pub fn eq(self, other: impl Into<Expr>) -> Expr {
170 Expr::BinOp {
171 op: Op::Eq,
172 left: Box::new(self),
173 right: Box::new(other.into()),
174 }
175 }
176
177 pub fn if_then_else(self, then: impl Into<Expr>, else_: impl Into<Expr>) -> Expr {
178 Expr::If {
179 cond: Box::new(self),
180 then: Box::new(then.into()),
181 else_: Box::new(else_.into()),
182 }
183 }
184
185 pub fn sin(self) -> Expr {
186 Expr::UnaryOp {
187 op: UnaryOp::Sin,
188 arg: Box::new(self),
189 }
190 }
191
192 pub fn cos(self) -> Expr {
193 Expr::UnaryOp {
194 op: UnaryOp::Cos,
195 arg: Box::new(self),
196 }
197 }
198
199 pub fn tan(self) -> Expr {
200 Expr::UnaryOp {
201 op: UnaryOp::Tan,
202 arg: Box::new(self),
203 }
204 }
205
206 pub fn abs(self) -> Expr {
207 Expr::UnaryOp {
208 op: UnaryOp::Abs,
209 arg: Box::new(self),
210 }
211 }
212
213 pub fn sqrt(self) -> Expr {
214 Expr::UnaryOp {
215 op: UnaryOp::Sqrt,
216 arg: Box::new(self),
217 }
218 }
219
220 pub fn ceil(self) -> Expr {
221 Expr::UnaryOp {
222 op: UnaryOp::Ceil,
223 arg: Box::new(self),
224 }
225 }
226
227 pub fn floor(self) -> Expr {
228 Expr::UnaryOp {
229 op: UnaryOp::Floor,
230 arg: Box::new(self),
231 }
232 }
233
234 pub fn round(self) -> Expr {
235 Expr::UnaryOp {
236 op: UnaryOp::Round,
237 arg: Box::new(self),
238 }
239 }
240
241 pub fn log(self) -> Expr {
242 Expr::UnaryOp {
243 op: UnaryOp::Log,
244 arg: Box::new(self),
245 }
246 }
247
248 pub fn exp(self) -> Expr {
249 Expr::UnaryOp {
250 op: UnaryOp::Exp,
251 arg: Box::new(self),
252 }
253 }
254
255 pub fn not(self) -> Expr {
256 Expr::UnaryOp {
257 op: UnaryOp::Not,
258 arg: Box::new(self),
259 }
260 }
261
262 pub fn call(name: impl Into<String>, args: Vec<Expr>) -> Expr {
263 Expr::Fn {
264 name: name.into(),
265 args,
266 }
267 }
268
269 pub fn compile(&self) -> String {
270 match self {
271 Expr::Lit(n) => format_number(*n),
272 Expr::Var(v) => v.to_string(),
273 Expr::BinOp { op, left, right } => match op {
274 Op::Gt | Op::Gte | Op::Lt | Op::Lte | Op::Eq => {
275 format!("{}({},{})", op, left.compile(), right.compile())
276 }
277 _ => {
278 format!("({}{}{})", left.compile(), op, right.compile())
279 }
280 },
281 Expr::UnaryOp { op, arg } => {
282 format!("{}({})", op, arg.compile())
283 }
284 Expr::If { cond, then, else_ } => {
285 format!("if({},{},{})", cond.compile(), then.compile(), else_.compile())
286 }
287 Expr::Fn { name, args } => {
288 let args_str: Vec<String> = args.iter().map(|a| a.compile()).collect();
289 format!("{}({})", name, args_str.join(","))
290 }
291 }
292 }
293}
294
295impl From<f64> for Expr {
296 fn from(n: f64) -> Self {
297 Expr::Lit(n)
298 }
299}
300
301impl From<i32> for Expr {
302 fn from(n: i32) -> Self {
303 Expr::Lit(n as f64)
304 }
305}
306
307impl From<u32> for Expr {
308 fn from(n: u32) -> Self {
309 Expr::Lit(n as f64)
310 }
311}
312
313impl ops::Add for Expr {
314 type Output = Expr;
315 fn add(self, rhs: Expr) -> Expr {
316 Expr::BinOp {
317 op: Op::Add,
318 left: Box::new(self),
319 right: Box::new(rhs),
320 }
321 }
322}
323
324impl ops::Add<f64> for Expr {
325 type Output = Expr;
326 fn add(self, rhs: f64) -> Expr {
327 self + Expr::Lit(rhs)
328 }
329}
330
331impl ops::Add<Expr> for f64 {
332 type Output = Expr;
333 fn add(self, rhs: Expr) -> Expr {
334 Expr::Lit(self) + rhs
335 }
336}
337
338impl ops::Sub for Expr {
339 type Output = Expr;
340 fn sub(self, rhs: Expr) -> Expr {
341 Expr::BinOp {
342 op: Op::Sub,
343 left: Box::new(self),
344 right: Box::new(rhs),
345 }
346 }
347}
348
349impl ops::Sub<f64> for Expr {
350 type Output = Expr;
351 fn sub(self, rhs: f64) -> Expr {
352 self - Expr::Lit(rhs)
353 }
354}
355
356impl ops::Sub<Expr> for f64 {
357 type Output = Expr;
358 fn sub(self, rhs: Expr) -> Expr {
359 Expr::Lit(self) - rhs
360 }
361}
362
363impl ops::Mul for Expr {
364 type Output = Expr;
365 fn mul(self, rhs: Expr) -> Expr {
366 Expr::BinOp {
367 op: Op::Mul,
368 left: Box::new(self),
369 right: Box::new(rhs),
370 }
371 }
372}
373
374impl ops::Mul<f64> for Expr {
375 type Output = Expr;
376 fn mul(self, rhs: f64) -> Expr {
377 self * Expr::Lit(rhs)
378 }
379}
380
381impl ops::Mul<Expr> for f64 {
382 type Output = Expr;
383 fn mul(self, rhs: Expr) -> Expr {
384 Expr::Lit(self) * rhs
385 }
386}
387
388impl ops::Div for Expr {
389 type Output = Expr;
390 fn div(self, rhs: Expr) -> Expr {
391 Expr::BinOp {
392 op: Op::Div,
393 left: Box::new(self),
394 right: Box::new(rhs),
395 }
396 }
397}
398
399impl ops::Div<f64> for Expr {
400 type Output = Expr;
401 fn div(self, rhs: f64) -> Expr {
402 self / Expr::Lit(rhs)
403 }
404}
405
406impl ops::Div<Expr> for f64 {
407 type Output = Expr;
408 fn div(self, rhs: Expr) -> Expr {
409 Expr::Lit(self) / rhs
410 }
411}
412
413#[derive(Debug, Clone)]
414pub enum NumberOrExpr {
415 Number(f64),
416 Expr(Expr),
417}
418
419impl NumberOrExpr {
420 pub fn compile(&self) -> String {
421 match self {
422 NumberOrExpr::Number(n) => format_number(*n),
423 NumberOrExpr::Expr(e) => e.compile(),
424 }
425 }
426}
427
428impl From<f64> for NumberOrExpr {
429 fn from(n: f64) -> Self {
430 NumberOrExpr::Number(n)
431 }
432}
433
434impl From<i32> for NumberOrExpr {
435 fn from(n: i32) -> Self {
436 NumberOrExpr::Number(n as f64)
437 }
438}
439
440impl From<u32> for NumberOrExpr {
441 fn from(n: u32) -> Self {
442 NumberOrExpr::Number(n as f64)
443 }
444}
445
446impl From<Expr> for NumberOrExpr {
447 fn from(e: Expr) -> Self {
448 NumberOrExpr::Expr(e)
449 }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn literal_integer() {
        assert_eq!(Expr::lit(42.0).compile(), "42");
    }

    #[test]
    fn literal_float() {
        assert_eq!(Expr::lit(1.5).compile(), "1.5");
    }

    #[test]
    fn variable() {
        assert_eq!(Expr::var(Var::PTS).compile(), "PTS");
        assert_eq!(Expr::var(Var::T).compile(), "T");
        assert_eq!(Expr::var(Var::N).compile(), "N");
        assert_eq!(Expr::var(Var::W).compile(), "W");
        assert_eq!(Expr::var(Var::H).compile(), "H");
        assert_eq!(Expr::var(Var::SW).compile(), "SW");
        assert_eq!(Expr::var(Var::SH).compile(), "SH");
        assert_eq!(Expr::var(Var::A).compile(), "A");
    }

    #[test]
    fn binop_mul() {
        let e = Expr::var(Var::PTS) * 2.0;
        assert_eq!(e.compile(), "(PTS*2)");
    }

    #[test]
    fn binop_add() {
        let e = Expr::var(Var::T) + 1.0;
        assert_eq!(e.compile(), "(T+1)");
    }

    #[test]
    fn binop_sub() {
        let e = Expr::var(Var::N) - 10.0;
        assert_eq!(e.compile(), "(N-10)");
    }

    #[test]
    fn binop_div() {
        let e = Expr::var(Var::W) / 2.0;
        assert_eq!(e.compile(), "(W/2)");
    }

    #[test]
    fn nested_expression() {
        let e = Expr::var(Var::PTS) * 2.0 + 1.0;
        assert_eq!(e.compile(), "((PTS*2)+1)");
    }

    #[test]
    fn reverse_add() {
        let e = 1.0 + Expr::var(Var::T);
        assert_eq!(e.compile(), "(1+T)");
    }

    #[test]
    fn reverse_mul() {
        let e = 2.0 * Expr::var(Var::T);
        assert_eq!(e.compile(), "(2*T)");
    }

    #[test]
    fn reverse_sub() {
        let e = 10.0 - Expr::var(Var::N);
        assert_eq!(e.compile(), "(10-N)");
    }

    #[test]
    fn reverse_div() {
        let e = 1.0 / Expr::var(Var::T);
        assert_eq!(e.compile(), "(1/T)");
    }

    #[test]
    fn comparison_gt() {
        let e = Expr::var(Var::N).gt(100.0);
        assert_eq!(e.compile(), "gt(N,100)");
    }

    #[test]
    fn comparison_gte() {
        let e = Expr::var(Var::T).gte(5.0);
        assert_eq!(e.compile(), "gte(T,5)");
    }

    #[test]
    fn comparison_lt() {
        let e = Expr::var(Var::W).lt(1920.0);
        assert_eq!(e.compile(), "lt(W,1920)");
    }

    #[test]
    fn comparison_lte() {
        let e = Expr::var(Var::H).lte(1080.0);
        assert_eq!(e.compile(), "lte(H,1080)");
    }

    #[test]
    fn comparison_eq() {
        let e = Expr::var(Var::N).eq(0.0);
        assert_eq!(e.compile(), "eq(N,0)");
    }

    #[test]
    fn comparison_with_expression_operand() {
        let e = Expr::var(Var::W).gt(Expr::var(Var::H) * 2.0);
        assert_eq!(e.compile(), "gt(W,(H*2))");
    }

    #[test]
    fn if_then_else() {
        let cond = Expr::var(Var::N).gt(100.0);
        let e = cond.if_then_else(1.0, 0.0);
        assert_eq!(e.compile(), "if(gt(N,100),1,0)");
    }

    #[test]
    fn nested_if() {
        let e = Expr::var(Var::T)
            .gt(5.0)
            .if_then_else(Expr::var(Var::PTS) * 2.0, Expr::var(Var::PTS));
        assert_eq!(e.compile(), "if(gt(T,5),(PTS*2),PTS)");
    }

    #[test]
    fn unary_sin() {
        let e = Expr::var(Var::T).sin();
        assert_eq!(e.compile(), "sin(T)");
    }

    #[test]
    fn unary_cos() {
        let e = (Expr::var(Var::T) * 2.0).cos();
        assert_eq!(e.compile(), "cos((T*2))");
    }

    #[test]
    fn unary_abs() {
        let e = (Expr::var(Var::N) - 50.0).abs();
        assert_eq!(e.compile(), "abs((N-50))");
    }

    #[test]
    fn remaining_unary_ops() {
        let t = || Expr::var(Var::T);
        assert_eq!(t().tan().compile(), "tan(T)");
        assert_eq!(t().sqrt().compile(), "sqrt(T)");
        assert_eq!(t().ceil().compile(), "ceil(T)");
        assert_eq!(t().floor().compile(), "floor(T)");
        assert_eq!(t().round().compile(), "round(T)");
        assert_eq!(t().log().compile(), "log(T)");
        assert_eq!(t().exp().compile(), "exp(T)");
    }

    #[test]
    fn unary_not_of_comparison() {
        let e = Expr::var(Var::N).eq(0.0).not();
        assert_eq!(e.compile(), "not(eq(N,0))");
    }

    #[test]
    fn fn_between() {
        let e = Expr::call(
            "between",
            vec![Expr::var(Var::T), Expr::lit(5.0), Expr::lit(10.0)],
        );
        assert_eq!(e.compile(), "between(T,5,10)");
    }

    #[test]
    fn fn_clip() {
        let e = Expr::call(
            "clip",
            vec![Expr::var(Var::N), Expr::lit(0.0), Expr::lit(255.0)],
        );
        assert_eq!(e.compile(), "clip(N,0,255)");
    }

    #[test]
    fn fn_without_args() {
        let e = Expr::call("f", vec![]);
        assert_eq!(e.compile(), "f()");
    }

    #[test]
    fn complex_expression() {
        // 2.5 rather than 3.14: clippy's deny-by-default `approx_constant` lint rejects
        // literals that approximate `std::f64::consts::PI`.
        let e = Expr::var(Var::T).gt(5.0).if_then_else(
            (Expr::var(Var::T) * 2.5).sin() * Expr::var(Var::W) / 2.0,
            0.0,
        );
        assert_eq!(e.compile(), "if(gt(T,5),((sin((T*2.5))*W)/2),0)");
    }

    #[test]
    fn number_or_expr_from_f64() {
        let n: NumberOrExpr = 42.0.into();
        assert_eq!(n.compile(), "42");
    }

    #[test]
    fn number_or_expr_from_i32() {
        let n: NumberOrExpr = 42i32.into();
        assert_eq!(n.compile(), "42");
    }

    #[test]
    fn number_or_expr_from_u32() {
        let n: NumberOrExpr = 42u32.into();
        assert_eq!(n.compile(), "42");
    }

    #[test]
    fn number_or_expr_from_expr() {
        let n: NumberOrExpr = Expr::var(Var::PTS).into();
        assert_eq!(n.compile(), "PTS");
    }

    #[test]
    fn number_or_expr_float() {
        // Not 3.14: see `complex_expression` (clippy `approx_constant`).
        let n: NumberOrExpr = 2.75.into();
        assert_eq!(n.compile(), "2.75");
    }

    #[test]
    fn number_or_expr_negative_float() {
        assert_eq!(NumberOrExpr::from(-2.5).compile(), "-2.5");
    }

    #[test]
    fn expr_expr_add() {
        let e = Expr::var(Var::W) + Expr::var(Var::H);
        assert_eq!(e.compile(), "(W+H)");
    }

    #[test]
    fn var_display() {
        assert_eq!(format!("{}", Var::PTS), "PTS");
        assert_eq!(format!("{}", Var::SW), "SW");
    }

    #[test]
    fn op_display() {
        assert_eq!(Op::Add.to_string(), "+");
        assert_eq!(Op::Gte.to_string(), "gte");
    }

    #[test]
    fn negative_literal() {
        assert_eq!(Expr::lit(-1.0).compile(), "-1");
    }

    #[test]
    fn negative_zero_literal_is_normalised() {
        assert_eq!(Expr::lit(-0.0).compile(), "0");
    }

    #[test]
    fn large_integer_literal_has_no_exponent() {
        assert_eq!(Expr::lit(1e20).compile(), "100000000000000000000");
    }

    #[test]
    fn zero_literal() {
        assert_eq!(Expr::lit(0.0).compile(), "0");
    }

    #[test]
    fn from_i32_into_expr() {
        let e: Expr = 42i32.into();
        assert_eq!(e.compile(), "42");
    }

    #[test]
    fn from_u32_into_expr() {
        let e: Expr = 10u32.into();
        assert_eq!(e.compile(), "10");
    }

    #[test]
    fn unary_chain() {
        let e = Expr::var(Var::T).sin().abs();
        assert_eq!(e.compile(), "abs(sin(T))");
    }
}