rapx/analysis/core/range_analysis/domain/
SymbolicExpr.rs

1#![allow(unused_imports)]
2#![allow(unused_variables)]
3#![allow(dead_code)]
4#![allow(unused_assignments)]
5#![allow(unused_parens)]
6#![allow(non_snake_case)]
7use rust_intervals::NothingBetween;
8
9use crate::analysis::core::range_analysis::domain::ConstraintGraph::ConstraintGraph;
10use crate::analysis::core::range_analysis::domain::domain::{
11    ConstConvert, IntervalArithmetic, VarNode, VarNodes,
12};
13use crate::analysis::core::range_analysis::{Range, RangeType};
14use crate::{rap_debug, rap_trace};
15use num_traits::{Bounded, CheckedAdd, CheckedSub, One, ToPrimitive, Zero, ops};
16use rustc_abi::Size;
17use rustc_data_structures::fx::FxHashMap;
18use rustc_hir::def_id::DefId;
19use rustc_middle::mir::coverage::Op;
20use rustc_middle::mir::{
21    BasicBlock, BinOp, BorrowKind, CastKind, Const, Local, LocalDecl, Operand, Place, Rvalue,
22    Statement, StatementKind, Terminator, UnOp,
23};
24use rustc_middle::ty::{ScalarInt, Ty};
25use rustc_span::sym::no_default_passes;
26use std::cell::RefCell;
27use std::cmp::PartialEq;
28use std::collections::{HashMap, HashSet};
29use std::fmt::Debug;
30use std::hash::Hash;
31use std::ops::{Add, Mul, Sub};
32use std::rc::Rc;
33use std::{fmt, mem};
/// Selects which end of a variable's interval is substituted while resolving
/// a symbolic expression.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BoundMode {
    /// Substitute the lower end of each interval.
    Lower,
    /// Substitute the upper end of each interval.
    Upper,
}

impl BoundMode {
    /// Returns the opposite mode (`Lower` <-> `Upper`).
    ///
    /// Needed when descending into an operand the enclosing expression decreases
    /// in, such as the right-hand side of a subtraction or the operand of a negation.
    fn flip(self) -> Self {
        if self == Self::Lower {
            Self::Upper
        } else {
            Self::Lower
        }
    }
}
48#[derive(Debug, Clone, PartialEq, Eq)]
49pub enum SymbExpr<'tcx> {
50    Constant(Const<'tcx>),
51
52    Place(&'tcx Place<'tcx>),
53
54    Binary(BinOp, Box<SymbExpr<'tcx>>, Box<SymbExpr<'tcx>>),
55
56    Unary(UnOp, Box<SymbExpr<'tcx>>),
57
58    Cast(CastKind, Box<SymbExpr<'tcx>>, Ty<'tcx>),
59
60    Unknown,
61}
62impl<'tcx> fmt::Display for SymbExpr<'tcx> {
63    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64        write!(f, "{:?}", self)
65    }
66}
67impl<'tcx> SymbExpr<'tcx> {
68    pub fn from_operand(op: &'tcx Operand<'tcx>, place_ctx: &Vec<&'tcx Place<'tcx>>) -> Self {
69        match op {
70            Operand::Copy(place) | Operand::Move(place) => {
71                let found_base = place_ctx
72                    .iter()
73                    .find(|&&p| p.local == place.local && p.projection.is_empty());
74
75                match found_base {
76                    Some(&base_place) => SymbExpr::Place(base_place),
77
78                    None => SymbExpr::Place(place),
79                }
80            }
81            Operand::Constant(c) => SymbExpr::Constant(c.const_),
82        }
83    }
84
85    pub fn from_rvalue(rvalue: &'tcx Rvalue<'tcx>, place_ctx: Vec<&'tcx Place<'tcx>>) -> Self {
86        match rvalue {
87            Rvalue::Use(op) => Self::from_operand(op, &place_ctx),
88            Rvalue::BinaryOp(bin_op, box (lhs, rhs)) => {
89                let left = Self::from_operand(lhs, &place_ctx);
90                let right = Self::from_operand(rhs, &place_ctx);
91
92                if matches!(left, SymbExpr::Unknown) || matches!(right, SymbExpr::Unknown) {
93                    return SymbExpr::Unknown;
94                }
95
96                SymbExpr::Binary(*bin_op, Box::new(left), Box::new(right))
97            }
98            Rvalue::UnaryOp(un_op, op) => {
99                let expr = Self::from_operand(op, &place_ctx);
100                if matches!(expr, SymbExpr::Unknown) {
101                    return SymbExpr::Unknown;
102                }
103                SymbExpr::Unary(*un_op, Box::new(expr))
104            }
105            Rvalue::Cast(kind, op, ty) => {
106                let expr = Self::from_operand(op, &place_ctx);
107                if matches!(expr, SymbExpr::Unknown) {
108                    return SymbExpr::Unknown;
109                }
110                SymbExpr::Cast(*kind, Box::new(expr), *ty)
111            }
112            Rvalue::Ref(..)
113            | Rvalue::ThreadLocalRef(..)
114            | Rvalue::Aggregate(..)
115            | Rvalue::Repeat(..)
116            | Rvalue::ShallowInitBox(..)
117            | Rvalue::NullaryOp(..)
118            | Rvalue::Discriminant(..)
119            | Rvalue::CopyForDeref(..) => SymbExpr::Unknown,
120            Rvalue::RawPtr(raw_ptr_kind, place) => todo!(),
121            Rvalue::WrapUnsafeBinder(operand, ty) => todo!(),
122        }
123    }
124
125    // pub fn eval<T: IntervalArithmetic + ConstConvert + Debug>(
126    //     &self,
127    //     vars: &VarNodes<'tcx, T>,
128    // ) -> Range<T> {
129    //     match self {
130    //         SymbExpr::Unknown => Range::new(T::min_value(), T::max_value(), RangeType::Regular),
131
132    //         SymbExpr::Constant(c) => {
133    //             if let Some(val) = T::from_const(c) {
134    //                 Range::new(val, val, RangeType::Regular)
135    //             } else {
136    //                 Range::new(T::min_value(), T::max_value(), RangeType::Regular)
137    //             }
138    //         }
139
140    //         SymbExpr::Place(place) => {
141    //             if let Some(node) = vars.get(place) {
142    //                 node.get_range().clone()
143    //             } else {
144    //                 Range::new(T::min_value(), T::max_value(), RangeType::Regular)
145    //             }
146    //         }
147
148    //         SymbExpr::Binary(op, lhs, rhs) => {
149    //             let l_range = lhs.eval(vars);
150    //             let r_range = rhs.eval(vars);
151
152    //             match op {
153    //                 BinOp::Add | BinOp::AddUnchecked | BinOp::AddWithOverflow => {
154    //                     l_range.add(&r_range)
155    //                 }
156    //                 BinOp::Sub | BinOp::SubUnchecked | BinOp::SubWithOverflow => {
157    //                     l_range.sub(&r_range)
158    //                 }
159    //                 BinOp::Mul | BinOp::MulUnchecked | BinOp::MulWithOverflow => {
160    //                     l_range.mul(&r_range)
161    //                 }
162
163    //                 _ => Range::new(T::min_value(), T::max_value(), RangeType::Regular),
164    //             }
165    //         }
166
167    //         SymbExpr::Unary(op, inner) => {
168    //             let _inner_range = inner.eval(vars);
169    //             match op {
170    //                 UnOp::Neg => Range::new(T::min_value(), T::max_value(), RangeType::Regular),
171    //                 UnOp::Not | UnOp::PtrMetadata => {
172    //                     Range::new(T::min_value(), T::max_value(), RangeType::Regular)
173    //                 }
174    //             }
175    //         }
176
177    //         SymbExpr::Cast(kind, inner, _target_ty) => {
178    //             let inner_range = inner.eval(vars);
179    //             match kind {
180    //                 CastKind::IntToInt => inner_range,
181
182    //                 _ => Range::new(T::min_value(), T::max_value(), RangeType::Regular),
183    //             }
184    //         }
185    //     }
186    // }
187    pub fn resolve_upper_bound<T: IntervalArithmetic + ConstConvert + Debug + Clone + PartialEq>(
188        &mut self,
189        vars: &VarNodes<'tcx, T>,
190    ) {
191        self.resolve_recursive(vars, 0, BoundMode::Upper);
192    }
193    pub fn resolve_lower_bound<T: IntervalArithmetic + ConstConvert + Debug + Clone + PartialEq>(
194        &mut self,
195        vars: &VarNodes<'tcx, T>,
196    ) {
197        self.resolve_recursive(vars, 0, BoundMode::Lower);
198    }
199
200    fn resolve_recursive<T: IntervalArithmetic + ConstConvert + Debug + Clone + PartialEq>(
201        &mut self,
202        vars: &VarNodes<'tcx, T>,
203        depth: usize,
204        mode: BoundMode,
205    ) {
206        const MAX_DEPTH: usize = 10;
207        if depth > MAX_DEPTH {
208            *self = SymbExpr::Unknown;
209            return;
210        }
211
212        match self {
213            SymbExpr::Binary(op, lhs, rhs) => {
214                lhs.resolve_recursive(vars, depth + 1, mode);
215
216                match op {
217                    BinOp::Add | BinOp::AddUnchecked | BinOp::AddWithOverflow => {
218                        rhs.resolve_recursive(vars, depth + 1, mode);
219                    }
220                    BinOp::Sub | BinOp::SubUnchecked | BinOp::SubWithOverflow => {
221                        rhs.resolve_recursive(vars, depth + 1, mode.flip());
222                    }
223                    _ => rhs.resolve_recursive(vars, depth + 1, mode),
224                }
225            }
226            SymbExpr::Unary(op, inner) => match op {
227                UnOp::Neg => {
228                    inner.resolve_recursive(vars, depth + 1, mode.flip());
229                }
230                _ => inner.resolve_recursive(vars, depth + 1, mode),
231            },
232            SymbExpr::Cast(_, inner, _) => {
233                inner.resolve_recursive(vars, depth + 1, mode);
234            }
235            _ => {}
236        }
237
238        // self.try_fold_constants::<T>();
239        rap_trace!("symexpr {}", self);
240        if let SymbExpr::Place(place) = self {
241            if let Some(node) = vars.get(place) {
242                if let IntervalType::Basic(basic) = &node.interval {
243                    rap_trace!("node {:?}", *node);
244
245                    let target_expr = if basic.lower == basic.upper {
246                        &basic.upper
247                    } else {
248                        match mode {
249                            BoundMode::Upper => &basic.upper,
250                            BoundMode::Lower => &basic.lower,
251                        }
252                    };
253
254                    match target_expr {
255                        SymbExpr::Unknown => *self = SymbExpr::Unknown,
256                        SymbExpr::Constant(c) => *self = SymbExpr::Constant(c.clone()),
257                        expr => {
258                            if let SymbExpr::Place(target_place) = expr {
259                                if target_place == place {
260                                    return;
261                                }
262                            }
263
264                            *self = expr.clone();
265                            self.resolve_recursive(vars, depth + 1, mode);
266                        }
267                    }
268                }
269            }
270        }
271    }
272    pub fn simplify(&mut self) {
273        match self {
274            SymbExpr::Binary(_, lhs, rhs) => {
275                lhs.simplify();
276                rhs.simplify();
277            }
278            SymbExpr::Unary(_, inner) => {
279                inner.simplify();
280            }
281            SymbExpr::Cast(_, inner, _) => {
282                inner.simplify();
283            }
284            _ => {}
285        }
286
287        if let SymbExpr::Binary(op, lhs, rhs) = self {
288            match op {
289                BinOp::Sub | BinOp::SubUnchecked | BinOp::SubWithOverflow => {
290                    if let SymbExpr::Binary(inner_op, inner_lhs, inner_rhs) = lhs.as_ref() {
291                        match inner_op {
292                            BinOp::Add | BinOp::AddUnchecked | BinOp::AddWithOverflow => {
293                                if inner_lhs == rhs {
294                                    *self = *inner_rhs.clone();
295                                } else if inner_rhs == rhs {
296                                    *self = *inner_lhs.clone();
297                                }
298                            }
299                            _ => {}
300                        }
301                    }
302                }
303                BinOp::Add | BinOp::AddUnchecked | BinOp::AddWithOverflow => {
304                    if let SymbExpr::Binary(inner_op, inner_lhs, inner_rhs) = lhs.as_ref() {
305                        match inner_op {
306                            BinOp::Sub | BinOp::SubUnchecked | BinOp::SubWithOverflow => {
307                                if inner_rhs == rhs {
308                                    *self = *inner_lhs.clone();
309                                }
310                            }
311                            _ => {}
312                        }
313                    }
314                }
315                _ => {}
316            }
317        }
318    }
319}
320#[derive(Debug, Clone)]
321pub enum IntervalType<'tcx, T: IntervalArithmetic + ConstConvert + Debug> {
322    Basic(BasicInterval<'tcx, T>),
323    Symb(SymbInterval<'tcx, T>),
324}
325
326impl<'tcx, T: IntervalArithmetic + ConstConvert + Debug> fmt::Display for IntervalType<'tcx, T> {
327    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328        match self {
329            IntervalType::Basic(b) => write!(
330                f,
331                "BasicInterval: {:?} {:?} {:?} ",
332                b.get_range(),
333                b.lower,
334                b.upper
335            ),
336            IntervalType::Symb(b) => write!(
337                f,
338                "SymbInterval: {:?} {:?} {:?} ",
339                b.get_range(),
340                b.lower,
341                b.upper
342            ),
343        }
344    }
345}
346pub trait IntervalTypeTrait<'tcx, T: IntervalArithmetic + ConstConvert + Debug> {
347    fn get_range(&self) -> &Range<T>;
348    fn set_range(&mut self, new_range: Range<T>);
349    fn get_lower_expr(&self) -> &SymbExpr<'tcx>;
350    fn get_upper_expr(&self) -> &SymbExpr<'tcx>;
351}
352impl<'tcx, T: IntervalArithmetic + ConstConvert + Debug> IntervalTypeTrait<'tcx, T>
353    for IntervalType<'tcx, T>
354{
355    fn get_range(&self) -> &Range<T> {
356        match self {
357            IntervalType::Basic(b) => b.get_range(),
358            IntervalType::Symb(s) => s.get_range(),
359        }
360    }
361
362    fn set_range(&mut self, new_range: Range<T>) {
363        match self {
364            IntervalType::Basic(b) => b.set_range(new_range),
365            IntervalType::Symb(s) => s.set_range(new_range),
366        }
367    }
368    fn get_lower_expr(&self) -> &SymbExpr<'tcx> {
369        match self {
370            IntervalType::Basic(b) => b.get_lower_expr(),
371            IntervalType::Symb(s) => s.get_lower_expr(),
372        }
373    }
374
375    fn get_upper_expr(&self) -> &SymbExpr<'tcx> {
376        match self {
377            IntervalType::Basic(b) => b.get_upper_expr(),
378            IntervalType::Symb(s) => s.get_upper_expr(),
379        }
380    }
381}
382#[derive(Debug, Clone)]
383
384pub struct BasicInterval<'tcx, T: IntervalArithmetic + ConstConvert + Debug> {
385    pub range: Range<T>,
386    pub lower: SymbExpr<'tcx>,
387    pub upper: SymbExpr<'tcx>,
388}
389
390impl<'tcx, T: IntervalArithmetic + ConstConvert + Debug> BasicInterval<'tcx, T> {
391    pub fn new(range: Range<T>) -> Self {
392        Self {
393            range,
394            lower: SymbExpr::Unknown,
395            upper: SymbExpr::Unknown,
396        }
397    }
398    pub fn new_symb(range: Range<T>, lower: SymbExpr<'tcx>, upper: SymbExpr<'tcx>) -> Self {
399        Self {
400            range,
401            lower,
402            upper,
403        }
404    }
405    pub fn default() -> Self {
406        Self {
407            range: Range::default(T::min_value()),
408            lower: SymbExpr::Unknown,
409            upper: SymbExpr::Unknown,
410        }
411    }
412}
413
414impl<'tcx, T: IntervalArithmetic + ConstConvert + Debug> IntervalTypeTrait<'tcx, T>
415    for BasicInterval<'tcx, T>
416{
417    // fn get_value_id(&self) -> IntervalId {
418    //     IntervalId::BasicIntervalId
419    // }
420
421    fn get_range(&self) -> &Range<T> {
422        &self.range
423    }
424
425    fn set_range(&mut self, new_range: Range<T>) {
426        self.range = new_range;
427        if self.range.get_lower() > self.range.get_upper() {
428            self.range.set_empty();
429        }
430    }
431    fn get_lower_expr(&self) -> &SymbExpr<'tcx> {
432        &self.lower
433    }
434
435    fn get_upper_expr(&self) -> &SymbExpr<'tcx> {
436        &self.upper
437    }
438}
439
440#[derive(Debug, Clone)]
441
442pub struct SymbInterval<'tcx, T: IntervalArithmetic + ConstConvert + Debug> {
443    range: Range<T>,
444    symbound: &'tcx Place<'tcx>,
445    predicate: BinOp,
446    lower: SymbExpr<'tcx>,
447    upper: SymbExpr<'tcx>,
448}
449
450impl<'tcx, T: IntervalArithmetic + ConstConvert + Debug> SymbInterval<'tcx, T> {
451    pub fn new(range: Range<T>, symbound: &'tcx Place<'tcx>, predicate: BinOp) -> Self {
452        Self {
453            range,
454            symbound,
455            predicate,
456            lower: SymbExpr::Unknown,
457            upper: SymbExpr::Unknown,
458        }
459    }
460
461    // pub fn refine(&mut self, vars: &VarNodes<'tcx, T>) {
462    //     if let SymbExpr::Unknown = self.lower {
463    //     } else {
464    //         let low_range = self.lower.eval(vars);
465    //         if low_range.get_lower() > self.range.get_lower() {
466    //             let new_range = Range::new(
467    //                 low_range.get_lower(),
468    //                 self.range.get_upper(),
469    //                 RangeType::Regular,
470    //             );
471    //             self.range = new_range;
472    //         }
473    //     }
474
475    //     if let SymbExpr::Unknown = self.upper {
476    //         // Do nothing
477    //     } else {
478    //         let high_range = self.upper.eval(vars);
479    //         if high_range.get_upper() < self.range.get_upper() {
480    //             let new_range = Range::new(
481    //                 self.range.get_lower(),
482    //                 high_range.get_upper(),
483    //                 RangeType::Regular,
484    //             );
485    //             self.range = new_range;
486    //         }
487    //     }
488    // }
489
490    pub fn get_operation(&self) -> BinOp {
491        self.predicate
492    }
493
494    pub fn get_bound(&self) -> &'tcx Place<'tcx> {
495        self.symbound
496    }
497
498    pub fn sym_fix_intersects(
499        &self,
500        bound: &VarNode<'tcx, T>,
501        sink: &VarNode<'tcx, T>,
502    ) -> Range<T> {
503        let l = bound.get_range().get_lower().clone();
504        let u = bound.get_range().get_upper().clone();
505
506        let lower = sink.get_range().get_lower().clone();
507        let upper = sink.get_range().get_upper().clone();
508
509        match self.predicate {
510            BinOp::Eq => Range::new(l, u, RangeType::Regular),
511
512            BinOp::Le => Range::new(lower, u, RangeType::Regular),
513
514            BinOp::Lt => {
515                if u != T::max_value() {
516                    let u_minus_1 = u.checked_sub(&T::one()).unwrap_or(u);
517                    Range::new(lower, u_minus_1, RangeType::Regular)
518                } else {
519                    Range::new(lower, u, RangeType::Regular)
520                }
521            }
522
523            BinOp::Ge => Range::new(l, upper, RangeType::Regular),
524
525            BinOp::Gt => {
526                if l != T::min_value() {
527                    let l_plus_1 = l.checked_add(&T::one()).unwrap_or(l);
528                    Range::new(l_plus_1, upper, RangeType::Regular)
529                } else {
530                    Range::new(l, upper, RangeType::Regular)
531                }
532            }
533
534            BinOp::Ne => Range::new(T::min_value(), T::max_value(), RangeType::Regular),
535
536            _ => Range::new(T::min_value(), T::max_value(), RangeType::Regular),
537        }
538    }
539}
540
541impl<'tcx, T: IntervalArithmetic + ConstConvert + Debug> IntervalTypeTrait<'tcx, T>
542    for SymbInterval<'tcx, T>
543{
544    // fn get_value_id(&self) -> IntervalId {
545    //     IntervalId::SymbIntervalId
546    // }
547
548    fn get_range(&self) -> &Range<T> {
549        &self.range
550    }
551
552    fn set_range(&mut self, new_range: Range<T>) {
553        self.range = new_range;
554    }
555    fn get_lower_expr(&self) -> &SymbExpr<'tcx> {
556        &self.lower
557    }
558
559    fn get_upper_expr(&self) -> &SymbExpr<'tcx> {
560        &self.upper
561    }
562}