rapx/analysis/opt/checking/bounds_checking/
bounds_len.rs

1use std::collections::HashSet;
2
3use once_cell::sync::OnceCell;
4
5use rustc_ast::BinOpKind;
6use rustc_hir::{intravisit, Expr, ExprKind};
7use rustc_middle::mir::Local;
8use rustc_middle::ty::TyCtxt;
9use rustc_span::Span;
10
11use crate::analysis::core::dataflow::graph::{
12    AggKind, DFSStatus, Direction, EdgeOp, Graph, GraphNode, NodeOp,
13};
14use crate::analysis::utils::def_path::DefPath;
15use crate::utils::log::{
16    relative_pos_range, span_to_filename, span_to_line_number, span_to_source_code,
17};
18use annotate_snippets::{Level, Renderer, Snippet};
19
20use super::super::super::NO_STD;
21
22static DEFPATHS: OnceCell<DefPaths> = OnceCell::new();
23
24struct DefPaths {
25    ops_range: DefPath,
26    vec_len: DefPath,
27    slice_len: DefPath,
28    ops_index: DefPath,
29    ops_index_mut: DefPath,
30}
31
32impl DefPaths {
33    pub fn new(tcx: &TyCtxt<'_>) -> Self {
34        let no_std = NO_STD.lock().unwrap();
35        if *no_std {
36            Self {
37                ops_range: DefPath::new("core::ops::Range", tcx),
38                vec_len: DefPath::new("alloc::vec::Vec::len", tcx),
39                slice_len: DefPath::new("core::slice::len", tcx),
40                ops_index: DefPath::new("core::ops::Index::index", tcx),
41                ops_index_mut: DefPath::new("core::ops::IndexMut::index_mut", tcx),
42            }
43        } else {
44            Self {
45                ops_range: DefPath::new("std::ops::Range", tcx),
46                vec_len: DefPath::new("std::vec::Vec::len", tcx),
47                slice_len: DefPath::new("slice::len", tcx),
48                ops_index: DefPath::new("std::ops::Index::index", tcx),
49                ops_index_mut: DefPath::new("std::ops::IndexMut::index_mut", tcx),
50            }
51        }
52    }
53}
54
55use crate::analysis::opt::OptCheck;
56
57pub struct BoundsLenCheck {
58    pub record: Vec<(Local, Vec<Local>)>,
59}
60
61struct IfFinder {
62    record: Vec<(Span, Vec<Span>)>,
63}
64struct LtFinder {
65    record: Vec<Span>,
66}
67struct IndexFinder {
68    record: Vec<Span>,
69}
70
71impl intravisit::Visitor<'_> for LtFinder {
72    fn visit_expr(&mut self, ex: &Expr) {
73        if let ExprKind::Binary(op, ..) = ex.kind {
74            if op.node == BinOpKind::Lt {
75                self.record.push(ex.span);
76            }
77        }
78        intravisit::walk_expr(self, ex);
79    }
80}
81
82impl<'tcx> intravisit::Visitor<'tcx> for IfFinder {
83    fn visit_expr(&mut self, ex: &'tcx Expr<'tcx>) {
84        if let ExprKind::If(cond, e1, _) = ex.kind {
85            let mut lt_finder = LtFinder { record: vec![] };
86            intravisit::walk_expr(&mut lt_finder, cond);
87            if !lt_finder.record.is_empty() {
88                let mut index_finder = IndexFinder { record: vec![] };
89                intravisit::walk_expr(&mut index_finder, e1);
90                if !index_finder.record.is_empty() {
91                    self.record.push((lt_finder.record[0], index_finder.record));
92                }
93            }
94        }
95        intravisit::walk_expr(self, ex);
96    }
97}
98
99impl<'tcx> intravisit::Visitor<'tcx> for IndexFinder {
100    fn visit_expr(&mut self, ex: &'tcx Expr<'tcx>) {
101        if let ExprKind::Index(_, ex2, _) = ex.kind {
102            self.record.push(ex2.span);
103        }
104        intravisit::walk_expr(self, ex);
105    }
106}
107
108impl OptCheck for BoundsLenCheck {
109    fn new() -> Self {
110        Self { record: vec![] }
111    }
112
113    fn check(&mut self, graph: &Graph, tcx: &TyCtxt) {
114        let _ = &DEFPATHS.get_or_init(|| DefPaths::new(tcx));
115        for (node_idx, node) in graph.nodes.iter_enumerated() {
116            if let Some(upperbound_node_idx) = extract_upperbound_node_if_ops_range(graph, node) {
117                if let Some(vec_len_node_idx) = find_upside_len_node(graph, upperbound_node_idx) {
118                    let maybe_vec_node_idx = graph.get_upside_idx(vec_len_node_idx, 0).unwrap();
119                    let maybe_vec_node_idxs =
120                        graph.collect_equivalent_locals(maybe_vec_node_idx, true);
121                    let mut index_record = vec![];
122                    for index_node_idx in find_downside_index_node(graph, node_idx).into_iter() {
123                        let maybe_vec_node_idx = graph.get_upside_idx(index_node_idx, 0).unwrap();
124                        if maybe_vec_node_idxs.contains(&maybe_vec_node_idx) {
125                            index_record.push(index_node_idx);
126                        }
127                    }
128                    if !index_record.is_empty() {
129                        self.record.push((upperbound_node_idx, index_record));
130                    }
131                }
132            }
133        }
134        let def_id = graph.def_id;
135        let body = tcx.hir_body_owned_by(def_id.as_local().unwrap());
136        let mut if_finder = IfFinder { record: vec![] };
137        intravisit::walk_body(&mut if_finder, body);
138        for (cond, slice_index_record) in if_finder.record.iter() {
139            if let Some((node_idx, node)) = graph.query_node_by_span(*cond, true) {
140                let left_arm = graph.edges[node.in_edges[0]].src;
141                let right_arm = graph.edges[node.in_edges[1]].src;
142                if find_upside_len_node(graph, right_arm).is_some() {
143                    let index_set = graph.collect_ancestor_locals(left_arm, true);
144                    let len_set = graph.collect_ancestor_locals(right_arm, true);
145                    let mut slice_node_indice = vec![];
146                    for slice_index_idx in slice_index_record {
147                        if let Some((index_node_idx, _)) =
148                            graph.query_node_by_span(*slice_index_idx, true)
149                        {
150                            let index_ancestors =
151                                graph.collect_ancestor_locals(index_node_idx, true);
152                            let indexed_node_idx =
153                                find_indexed_node_from_index(graph, index_node_idx);
154                            if let Some(indexed_node_idx) = indexed_node_idx {
155                                let indexed_ancestors =
156                                    graph.collect_ancestor_locals(indexed_node_idx, true);
157                                // Warning: We only checks index without checking the indexed value
158                                if index_ancestors.intersection(&index_set).next().is_some()
159                                    && indexed_ancestors.intersection(&len_set).next().is_some()
160                                {
161                                    slice_node_indice.push(index_node_idx);
162                                }
163                            }
164                        }
165                    }
166                    self.record.push((node_idx, slice_node_indice));
167                }
168            }
169        }
170    }
171
172    fn report(&self, graph: &Graph) {
173        for (upperbound_node_idx, index_record) in self.record.iter() {
174            report_upperbound_bug(graph, *upperbound_node_idx, index_record);
175        }
176    }
177
178    fn cnt(&self) -> usize {
179        self.record.iter().map(|(_, spans)| spans.len()).sum()
180    }
181}
182
183fn find_indexed_node_from_index(graph: &Graph, index_node_idx: Local) -> Option<Local> {
184    let def_paths = &DEFPATHS.get().unwrap();
185    let index_node = &graph.nodes[index_node_idx];
186    for edge_idx in index_node.out_edges.iter() {
187        let dst_node_idx = graph.edges[*edge_idx].dst;
188        let dst_node = &graph.nodes[dst_node_idx];
189        for op in dst_node.ops.iter() {
190            if let NodeOp::Call(def_id) = op {
191                if *def_id == def_paths.ops_index.last_def_id()
192                    || *def_id == def_paths.ops_index_mut.last_def_id()
193                {
194                    let index_operator_node =
195                        &graph.nodes[graph.edges[index_node.out_edges[0]].dst];
196
197                    return Some(graph.edges[index_operator_node.in_edges[0]].src);
198                }
199            }
200            if graph.is_marker(dst_node_idx) {
201                for edge_idx_ in dst_node.in_edges.iter() {
202                    let edge = &graph.edges[*edge_idx_];
203                    if let EdgeOp::Index = edge.op {
204                        return Some(edge.src);
205                    }
206                }
207            }
208        }
209    }
210    None
211}
212
213fn extract_upperbound_node_if_ops_range(graph: &Graph, node: &GraphNode) -> Option<Local> {
214    let def_paths = &DEFPATHS.get().unwrap();
215    let target_def_id = def_paths.ops_range.last_def_id();
216    for op in node.ops.iter() {
217        if let NodeOp::Aggregate(AggKind::Adt(def_id)) = op {
218            if *def_id == target_def_id {
219                let upperbound_edge = &graph.edges[node.in_edges[1]]; // the second field
220                return Some(upperbound_edge.src);
221            }
222        }
223    }
224    None
225}
226
227fn find_upside_len_node(graph: &Graph, node_idx: Local) -> Option<Local> {
228    let mut len_node_idx = None;
229    let def_paths = &DEFPATHS.get().unwrap();
230    // Warning: may traverse all upside nodes and the new result will overwrite on the previous result
231    let mut node_operator = |graph: &Graph, idx: Local| -> DFSStatus {
232        let node = &graph.nodes[idx];
233        for op in node.ops.iter() {
234            if let NodeOp::Call(def_id) = op {
235                if *def_id == def_paths.vec_len.last_def_id()
236                    || *def_id == def_paths.slice_len.last_def_id()
237                {
238                    len_node_idx = Some(idx);
239                    return DFSStatus::Stop;
240                }
241            }
242        }
243        DFSStatus::Continue
244    };
245    let mut seen = HashSet::new();
246    graph.dfs(
247        node_idx,
248        Direction::Upside,
249        &mut node_operator,
250        &mut Graph::equivalent_edge_validator,
251        false,
252        &mut seen,
253    );
254    len_node_idx
255}
256
257fn find_downside_index_node(graph: &Graph, node_idx: Local) -> Vec<Local> {
258    let mut index_node_idxs: Vec<Local> = vec![];
259    let def_paths = &DEFPATHS.get().unwrap();
260    // Warning: traverse all downside nodes
261    let mut node_operator = |graph: &Graph, idx: Local| -> DFSStatus {
262        let node = &graph.nodes[idx];
263        for op in node.ops.iter() {
264            if let NodeOp::Call(def_id) = op {
265                if *def_id == def_paths.ops_index.last_def_id()
266                    || *def_id == def_paths.ops_index_mut.last_def_id()
267                {
268                    index_node_idxs.push(idx);
269                    break;
270                }
271            }
272        }
273        DFSStatus::Continue
274    };
275    let mut seen = HashSet::new();
276    graph.dfs(
277        node_idx,
278        Direction::Downside,
279        &mut node_operator,
280        &mut Graph::always_true_edge_validator,
281        true,
282        &mut seen,
283    );
284    index_node_idxs
285}
286
287fn report_upperbound_bug(graph: &Graph, upperbound_node_idx: Local, index_record: &Vec<Local>) {
288    let upperbound_span = graph.nodes[upperbound_node_idx].span;
289    let code_source = span_to_source_code(graph.span);
290    let filename = span_to_filename(upperbound_span);
291    let mut snippet = Snippet::source(&code_source)
292        .line_start(span_to_line_number(graph.span))
293        .origin(&filename)
294        .fold(true)
295        .annotation(
296            Level::Info
297                .span(relative_pos_range(graph.span, upperbound_span))
298                .label("Index is upperbounded."),
299        );
300    for node_idx in index_record {
301        let index_span = graph.nodes[*node_idx].span;
302        snippet = snippet.annotation(
303            Level::Error
304                .span(relative_pos_range(graph.span, index_span))
305                .label("Checked here."),
306        );
307    }
308    let message = Level::Warning
309        .title("Unnecessary bounds checkings detected")
310        .snippet(snippet)
311        .footer(Level::Help.title("Use unsafe APIs instead."));
312    let renderer = Renderer::styled();
313    println!("{}", renderer.render(message));
314}