1#![allow(unused_imports)]
2#![allow(unused_variables)]
3#![allow(dead_code)]
4#![allow(unused_assignments)]
5#![allow(unused_parens)]
6#![allow(non_snake_case)]
7use rust_intervals::NothingBetween;
8
9use crate::analysis::core::range_analysis::domain::ConstraintGraph::ConstraintGraph;
10use crate::analysis::core::range_analysis::domain::domain::{
11 ConstConvert, IntervalArithmetic, VarNode, VarNodes,
12};
13use crate::analysis::core::range_analysis::{Range, RangeType};
14use crate::{rap_debug, rap_trace};
15use num_traits::{Bounded, CheckedAdd, CheckedSub, One, ToPrimitive, Zero, ops};
16use rustc_abi::Size;
17use rustc_data_structures::fx::FxHashMap;
18use rustc_hir::def_id::DefId;
19use rustc_middle::mir::coverage::Op;
20use rustc_middle::mir::{
21 BasicBlock, BinOp, BorrowKind, CastKind, Const, Local, LocalDecl, Operand, Place, Rvalue,
22 Statement, StatementKind, Terminator, UnOp,
23};
24use rustc_middle::ty::{ScalarInt, Ty};
25use rustc_span::sym::no_default_passes;
26use std::cell::RefCell;
27use std::cmp::PartialEq;
28use std::collections::{HashMap, HashSet};
29use std::fmt::Debug;
30use std::hash::Hash;
31use std::ops::{Add, Mul, Sub};
32use std::rc::Rc;
33use std::{fmt, mem};
/// Selects which end of an interval is being resolved.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum BoundMode {
    /// Resolve towards the lower bound.
    Lower,
    /// Resolve towards the upper bound.
    Upper,
}

impl BoundMode {
    /// Returns the opposite mode: `Lower` becomes `Upper` and vice versa.
    fn flip(self) -> Self {
        if self == Self::Lower {
            Self::Upper
        } else {
            Self::Lower
        }
    }
}
48#[derive(Debug, Clone, PartialEq, Eq)]
49pub enum SymbExpr<'tcx> {
50 Constant(Const<'tcx>),
51
52 Place(&'tcx Place<'tcx>),
53
54 Binary(BinOp, Box<SymbExpr<'tcx>>, Box<SymbExpr<'tcx>>),
55
56 Unary(UnOp, Box<SymbExpr<'tcx>>),
57
58 Cast(CastKind, Box<SymbExpr<'tcx>>, Ty<'tcx>),
59
60 Unknown,
61}
62impl<'tcx> fmt::Display for SymbExpr<'tcx> {
63 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64 write!(f, "{:?}", self)
65 }
66}
67impl<'tcx> SymbExpr<'tcx> {
68 pub fn from_operand(op: &'tcx Operand<'tcx>, place_ctx: &Vec<&'tcx Place<'tcx>>) -> Self {
69 match op {
70 Operand::Copy(place) | Operand::Move(place) => {
71 let found_base = place_ctx
72 .iter()
73 .find(|&&p| p.local == place.local && p.projection.is_empty());
74
75 match found_base {
76 Some(&base_place) => SymbExpr::Place(base_place),
77
78 None => SymbExpr::Place(place),
79 }
80 }
81 Operand::Constant(c) => SymbExpr::Constant(c.const_),
82 }
83 }
84
85 pub fn from_rvalue(rvalue: &'tcx Rvalue<'tcx>, place_ctx: Vec<&'tcx Place<'tcx>>) -> Self {
86 match rvalue {
87 Rvalue::Use(op) => Self::from_operand(op, &place_ctx),
88 Rvalue::BinaryOp(bin_op, box (lhs, rhs)) => {
89 let left = Self::from_operand(lhs, &place_ctx);
90 let right = Self::from_operand(rhs, &place_ctx);
91
92 if matches!(left, SymbExpr::Unknown) || matches!(right, SymbExpr::Unknown) {
93 return SymbExpr::Unknown;
94 }
95
96 SymbExpr::Binary(*bin_op, Box::new(left), Box::new(right))
97 }
98 Rvalue::UnaryOp(un_op, op) => {
99 let expr = Self::from_operand(op, &place_ctx);
100 if matches!(expr, SymbExpr::Unknown) {
101 return SymbExpr::Unknown;
102 }
103 SymbExpr::Unary(*un_op, Box::new(expr))
104 }
105 Rvalue::Cast(kind, op, ty) => {
106 let expr = Self::from_operand(op, &place_ctx);
107 if matches!(expr, SymbExpr::Unknown) {
108 return SymbExpr::Unknown;
109 }
110 SymbExpr::Cast(*kind, Box::new(expr), *ty)
111 }
112 Rvalue::Ref(..)
113 | Rvalue::ThreadLocalRef(..)
114 | Rvalue::Aggregate(..)
115 | Rvalue::Repeat(..)
116 | Rvalue::ShallowInitBox(..)
117 | Rvalue::NullaryOp(..)
118 | Rvalue::Discriminant(..)
119 | Rvalue::CopyForDeref(..) => SymbExpr::Unknown,
120 Rvalue::RawPtr(raw_ptr_kind, place) => todo!(),
121 Rvalue::WrapUnsafeBinder(operand, ty) => todo!(),
122 }
123 }
124
125 pub fn resolve_upper_bound<T: IntervalArithmetic + ConstConvert + Debug + Clone + PartialEq>(
188 &mut self,
189 vars: &VarNodes<'tcx, T>,
190 ) {
191 self.resolve_recursive(vars, 0, BoundMode::Upper);
192 }
193 pub fn resolve_lower_bound<T: IntervalArithmetic + ConstConvert + Debug + Clone + PartialEq>(
194 &mut self,
195 vars: &VarNodes<'tcx, T>,
196 ) {
197 self.resolve_recursive(vars, 0, BoundMode::Lower);
198 }
199
200 fn resolve_recursive<T: IntervalArithmetic + ConstConvert + Debug + Clone + PartialEq>(
201 &mut self,
202 vars: &VarNodes<'tcx, T>,
203 depth: usize,
204 mode: BoundMode,
205 ) {
206 const MAX_DEPTH: usize = 10;
207 if depth > MAX_DEPTH {
208 *self = SymbExpr::Unknown;
209 return;
210 }
211
212 match self {
213 SymbExpr::Binary(op, lhs, rhs) => {
214 lhs.resolve_recursive(vars, depth + 1, mode);
215
216 match op {
217 BinOp::Add | BinOp::AddUnchecked | BinOp::AddWithOverflow => {
218 rhs.resolve_recursive(vars, depth + 1, mode);
219 }
220 BinOp::Sub | BinOp::SubUnchecked | BinOp::SubWithOverflow => {
221 rhs.resolve_recursive(vars, depth + 1, mode.flip());
222 }
223 _ => rhs.resolve_recursive(vars, depth + 1, mode),
224 }
225 }
226 SymbExpr::Unary(op, inner) => match op {
227 UnOp::Neg => {
228 inner.resolve_recursive(vars, depth + 1, mode.flip());
229 }
230 _ => inner.resolve_recursive(vars, depth + 1, mode),
231 },
232 SymbExpr::Cast(_, inner, _) => {
233 inner.resolve_recursive(vars, depth + 1, mode);
234 }
235 _ => {}
236 }
237
238 rap_trace!("symexpr {}", self);
240 if let SymbExpr::Place(place) = self {
241 if let Some(node) = vars.get(place) {
242 if let IntervalType::Basic(basic) = &node.interval {
243 rap_trace!("node {:?}", *node);
244
245 let target_expr = if basic.lower == basic.upper {
246 &basic.upper
247 } else {
248 match mode {
249 BoundMode::Upper => &basic.upper,
250 BoundMode::Lower => &basic.lower,
251 }
252 };
253
254 match target_expr {
255 SymbExpr::Unknown => *self = SymbExpr::Unknown,
256 SymbExpr::Constant(c) => *self = SymbExpr::Constant(c.clone()),
257 expr => {
258 if let SymbExpr::Place(target_place) = expr {
259 if target_place == place {
260 return;
261 }
262 }
263
264 *self = expr.clone();
265 self.resolve_recursive(vars, depth + 1, mode);
266 }
267 }
268 }
269 }
270 }
271 }
272 pub fn simplify(&mut self) {
273 match self {
274 SymbExpr::Binary(_, lhs, rhs) => {
275 lhs.simplify();
276 rhs.simplify();
277 }
278 SymbExpr::Unary(_, inner) => {
279 inner.simplify();
280 }
281 SymbExpr::Cast(_, inner, _) => {
282 inner.simplify();
283 }
284 _ => {}
285 }
286
287 if let SymbExpr::Binary(op, lhs, rhs) = self {
288 match op {
289 BinOp::Sub | BinOp::SubUnchecked | BinOp::SubWithOverflow => {
290 if let SymbExpr::Binary(inner_op, inner_lhs, inner_rhs) = lhs.as_ref() {
291 match inner_op {
292 BinOp::Add | BinOp::AddUnchecked | BinOp::AddWithOverflow => {
293 if inner_lhs == rhs {
294 *self = *inner_rhs.clone();
295 } else if inner_rhs == rhs {
296 *self = *inner_lhs.clone();
297 }
298 }
299 _ => {}
300 }
301 }
302 }
303 BinOp::Add | BinOp::AddUnchecked | BinOp::AddWithOverflow => {
304 if let SymbExpr::Binary(inner_op, inner_lhs, inner_rhs) = lhs.as_ref() {
305 match inner_op {
306 BinOp::Sub | BinOp::SubUnchecked | BinOp::SubWithOverflow => {
307 if inner_rhs == rhs {
308 *self = *inner_lhs.clone();
309 }
310 }
311 _ => {}
312 }
313 }
314 }
315 _ => {}
316 }
317 }
318 }
319}
320#[derive(Debug, Clone)]
321pub enum IntervalType<'tcx, T: IntervalArithmetic + ConstConvert + Debug> {
322 Basic(BasicInterval<'tcx, T>),
323 Symb(SymbInterval<'tcx, T>),
324}
325
326impl<'tcx, T: IntervalArithmetic + ConstConvert + Debug> fmt::Display for IntervalType<'tcx, T> {
327 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328 match self {
329 IntervalType::Basic(b) => write!(
330 f,
331 "BasicInterval: {:?} {:?} {:?} ",
332 b.get_range(),
333 b.lower,
334 b.upper
335 ),
336 IntervalType::Symb(b) => write!(
337 f,
338 "SymbInterval: {:?} {:?} {:?} ",
339 b.get_range(),
340 b.lower,
341 b.upper
342 ),
343 }
344 }
345}
346pub trait IntervalTypeTrait<'tcx, T: IntervalArithmetic + ConstConvert + Debug> {
347 fn get_range(&self) -> &Range<T>;
348 fn set_range(&mut self, new_range: Range<T>);
349 fn get_lower_expr(&self) -> &SymbExpr<'tcx>;
350 fn get_upper_expr(&self) -> &SymbExpr<'tcx>;
351}
352impl<'tcx, T: IntervalArithmetic + ConstConvert + Debug> IntervalTypeTrait<'tcx, T>
353 for IntervalType<'tcx, T>
354{
355 fn get_range(&self) -> &Range<T> {
356 match self {
357 IntervalType::Basic(b) => b.get_range(),
358 IntervalType::Symb(s) => s.get_range(),
359 }
360 }
361
362 fn set_range(&mut self, new_range: Range<T>) {
363 match self {
364 IntervalType::Basic(b) => b.set_range(new_range),
365 IntervalType::Symb(s) => s.set_range(new_range),
366 }
367 }
368 fn get_lower_expr(&self) -> &SymbExpr<'tcx> {
369 match self {
370 IntervalType::Basic(b) => b.get_lower_expr(),
371 IntervalType::Symb(s) => s.get_lower_expr(),
372 }
373 }
374
375 fn get_upper_expr(&self) -> &SymbExpr<'tcx> {
376 match self {
377 IntervalType::Basic(b) => b.get_upper_expr(),
378 IntervalType::Symb(s) => s.get_upper_expr(),
379 }
380 }
381}
382#[derive(Debug, Clone)]
383
384pub struct BasicInterval<'tcx, T: IntervalArithmetic + ConstConvert + Debug> {
385 pub range: Range<T>,
386 pub lower: SymbExpr<'tcx>,
387 pub upper: SymbExpr<'tcx>,
388}
389
390impl<'tcx, T: IntervalArithmetic + ConstConvert + Debug> BasicInterval<'tcx, T> {
391 pub fn new(range: Range<T>) -> Self {
392 Self {
393 range,
394 lower: SymbExpr::Unknown,
395 upper: SymbExpr::Unknown,
396 }
397 }
398 pub fn new_symb(range: Range<T>, lower: SymbExpr<'tcx>, upper: SymbExpr<'tcx>) -> Self {
399 Self {
400 range,
401 lower,
402 upper,
403 }
404 }
405 pub fn default() -> Self {
406 Self {
407 range: Range::default(T::min_value()),
408 lower: SymbExpr::Unknown,
409 upper: SymbExpr::Unknown,
410 }
411 }
412}
413
414impl<'tcx, T: IntervalArithmetic + ConstConvert + Debug> IntervalTypeTrait<'tcx, T>
415 for BasicInterval<'tcx, T>
416{
417 fn get_range(&self) -> &Range<T> {
422 &self.range
423 }
424
425 fn set_range(&mut self, new_range: Range<T>) {
426 self.range = new_range;
427 if self.range.get_lower() > self.range.get_upper() {
428 self.range.set_empty();
429 }
430 }
431 fn get_lower_expr(&self) -> &SymbExpr<'tcx> {
432 &self.lower
433 }
434
435 fn get_upper_expr(&self) -> &SymbExpr<'tcx> {
436 &self.upper
437 }
438}
439
440#[derive(Debug, Clone)]
441
442pub struct SymbInterval<'tcx, T: IntervalArithmetic + ConstConvert + Debug> {
443 range: Range<T>,
444 symbound: &'tcx Place<'tcx>,
445 predicate: BinOp,
446 lower: SymbExpr<'tcx>,
447 upper: SymbExpr<'tcx>,
448}
449
450impl<'tcx, T: IntervalArithmetic + ConstConvert + Debug> SymbInterval<'tcx, T> {
451 pub fn new(range: Range<T>, symbound: &'tcx Place<'tcx>, predicate: BinOp) -> Self {
452 Self {
453 range,
454 symbound,
455 predicate,
456 lower: SymbExpr::Unknown,
457 upper: SymbExpr::Unknown,
458 }
459 }
460
461 pub fn get_operation(&self) -> BinOp {
491 self.predicate
492 }
493
494 pub fn get_bound(&self) -> &'tcx Place<'tcx> {
495 self.symbound
496 }
497
498 pub fn sym_fix_intersects(
499 &self,
500 bound: &VarNode<'tcx, T>,
501 sink: &VarNode<'tcx, T>,
502 ) -> Range<T> {
503 let l = bound.get_range().get_lower().clone();
504 let u = bound.get_range().get_upper().clone();
505
506 let lower = sink.get_range().get_lower().clone();
507 let upper = sink.get_range().get_upper().clone();
508
509 match self.predicate {
510 BinOp::Eq => Range::new(l, u, RangeType::Regular),
511
512 BinOp::Le => Range::new(lower, u, RangeType::Regular),
513
514 BinOp::Lt => {
515 if u != T::max_value() {
516 let u_minus_1 = u.checked_sub(&T::one()).unwrap_or(u);
517 Range::new(lower, u_minus_1, RangeType::Regular)
518 } else {
519 Range::new(lower, u, RangeType::Regular)
520 }
521 }
522
523 BinOp::Ge => Range::new(l, upper, RangeType::Regular),
524
525 BinOp::Gt => {
526 if l != T::min_value() {
527 let l_plus_1 = l.checked_add(&T::one()).unwrap_or(l);
528 Range::new(l_plus_1, upper, RangeType::Regular)
529 } else {
530 Range::new(l, upper, RangeType::Regular)
531 }
532 }
533
534 BinOp::Ne => Range::new(T::min_value(), T::max_value(), RangeType::Regular),
535
536 _ => Range::new(T::min_value(), T::max_value(), RangeType::Regular),
537 }
538 }
539}
540
541impl<'tcx, T: IntervalArithmetic + ConstConvert + Debug> IntervalTypeTrait<'tcx, T>
542 for SymbInterval<'tcx, T>
543{
544 fn get_range(&self) -> &Range<T> {
549 &self.range
550 }
551
552 fn set_range(&mut self, new_range: Range<T>) {
553 self.range = new_range;
554 }
555 fn get_lower_expr(&self) -> &SymbExpr<'tcx> {
556 &self.lower
557 }
558
559 fn get_upper_expr(&self) -> &SymbExpr<'tcx> {
560 &self.upper
561 }
562}