rapx/analysis/opt/checking/bounds_checking/
bounds_len.rs

1use std::collections::HashSet;
2
3use once_cell::sync::OnceCell;
4
5use rustc_ast::BinOpKind;
6use rustc_hir::{intravisit, Expr, ExprKind};
7use rustc_middle::{mir::Local, ty::TyCtxt};
8use rustc_span::Span;
9
10use crate::{
11    analysis::{
12        core::dataflow::{graph::*, *},
13        utils::def_path::DefPath,
14    },
15    utils::log::{relative_pos_range, span_to_filename, span_to_line_number, span_to_source_code},
16};
17use annotate_snippets::{Level, Renderer, Snippet};
18
19use super::super::super::NO_STD;
20
21static DEFPATHS: OnceCell<DefPaths> = OnceCell::new();
22
23struct DefPaths {
24    ops_range: DefPath,
25    vec_len: DefPath,
26    slice_len: DefPath,
27    ops_index: DefPath,
28    ops_index_mut: DefPath,
29}
30
31impl DefPaths {
32    pub fn new(tcx: &TyCtxt<'_>) -> Self {
33        let no_std = NO_STD.lock().unwrap();
34        if *no_std {
35            Self {
36                ops_range: DefPath::new("core::ops::Range", tcx),
37                vec_len: DefPath::new("alloc::vec::Vec::len", tcx),
38                slice_len: DefPath::new("core::slice::len", tcx),
39                ops_index: DefPath::new("core::ops::Index::index", tcx),
40                ops_index_mut: DefPath::new("core::ops::IndexMut::index_mut", tcx),
41            }
42        } else {
43            Self {
44                ops_range: DefPath::new("std::ops::Range", tcx),
45                vec_len: DefPath::new("std::vec::Vec::len", tcx),
46                slice_len: DefPath::new("slice::len", tcx),
47                ops_index: DefPath::new("std::ops::Index::index", tcx),
48                ops_index_mut: DefPath::new("std::ops::IndexMut::index_mut", tcx),
49            }
50        }
51    }
52}
53
54use crate::analysis::opt::OptCheck;
55
56pub struct BoundsLenCheck {
57    pub record: Vec<(Local, Vec<Local>)>,
58}
59
60struct IfFinder {
61    record: Vec<(Span, Vec<Span>)>,
62}
63struct LtFinder {
64    record: Vec<Span>,
65}
66struct IndexFinder {
67    record: Vec<Span>,
68}
69
70impl intravisit::Visitor<'_> for LtFinder {
71    fn visit_expr(&mut self, ex: &Expr) {
72        if let ExprKind::Binary(op, ..) = ex.kind {
73            if op.node == BinOpKind::Lt {
74                self.record.push(ex.span);
75            }
76        }
77        intravisit::walk_expr(self, ex);
78    }
79}
80
81impl<'tcx> intravisit::Visitor<'tcx> for IfFinder {
82    fn visit_expr(&mut self, ex: &'tcx Expr<'tcx>) {
83        if let ExprKind::If(cond, e1, _) = ex.kind {
84            let mut lt_finder = LtFinder { record: vec![] };
85            intravisit::walk_expr(&mut lt_finder, cond);
86            if !lt_finder.record.is_empty() {
87                let mut index_finder = IndexFinder { record: vec![] };
88                intravisit::walk_expr(&mut index_finder, e1);
89                if !index_finder.record.is_empty() {
90                    self.record.push((lt_finder.record[0], index_finder.record));
91                }
92            }
93        }
94        intravisit::walk_expr(self, ex);
95    }
96}
97
98impl<'tcx> intravisit::Visitor<'tcx> for IndexFinder {
99    fn visit_expr(&mut self, ex: &'tcx Expr<'tcx>) {
100        if let ExprKind::Index(_, ex2, _) = ex.kind {
101            self.record.push(ex2.span);
102        }
103        intravisit::walk_expr(self, ex);
104    }
105}
106
107impl OptCheck for BoundsLenCheck {
108    fn new() -> Self {
109        Self { record: vec![] }
110    }
111
112    fn check(&mut self, graph: &Graph, tcx: &TyCtxt) {
113        let _ = &DEFPATHS.get_or_init(|| DefPaths::new(tcx));
114        for (node_idx, node) in graph.nodes.iter_enumerated() {
115            if let Some(upperbound_node_idx) = extract_upperbound_node_if_ops_range(graph, node) {
116                if let Some(vec_len_node_idx) = find_upside_len_node(graph, upperbound_node_idx) {
117                    let maybe_vec_node_idx = graph.get_upside_idx(vec_len_node_idx, 0).unwrap();
118                    let maybe_vec_node_idxs =
119                        graph.collect_equivalent_locals(maybe_vec_node_idx, true);
120                    let mut index_record = vec![];
121                    for index_node_idx in find_downside_index_node(graph, node_idx).into_iter() {
122                        let maybe_vec_node_idx = graph.get_upside_idx(index_node_idx, 0).unwrap();
123                        if maybe_vec_node_idxs.contains(&maybe_vec_node_idx) {
124                            index_record.push(index_node_idx);
125                        }
126                    }
127                    if !index_record.is_empty() {
128                        self.record.push((upperbound_node_idx, index_record));
129                    }
130                }
131            }
132        }
133        let def_id = graph.def_id;
134        let body = tcx.hir_body_owned_by(def_id.as_local().unwrap());
135        let mut if_finder = IfFinder { record: vec![] };
136        intravisit::walk_body(&mut if_finder, body);
137        for (cond, slice_index_record) in if_finder.record.iter() {
138            if let Some((node_idx, node)) = graph.query_node_by_span(*cond, true) {
139                let left_arm = graph.edges[node.in_edges[0]].src;
140                let right_arm = graph.edges[node.in_edges[1]].src;
141                if find_upside_len_node(graph, right_arm).is_some() {
142                    let index_set = graph.collect_ancestor_locals(left_arm, true);
143                    let len_set = graph.collect_ancestor_locals(right_arm, true);
144                    let mut slice_node_indice = vec![];
145                    for slice_index_idx in slice_index_record {
146                        if let Some((index_node_idx, _)) =
147                            graph.query_node_by_span(*slice_index_idx, true)
148                        {
149                            let index_ancestors =
150                                graph.collect_ancestor_locals(index_node_idx, true);
151                            let indexed_node_idx =
152                                find_indexed_node_from_index(graph, index_node_idx);
153                            if let Some(indexed_node_idx) = indexed_node_idx {
154                                let indexed_ancestors =
155                                    graph.collect_ancestor_locals(indexed_node_idx, true);
156                                // Warning: We only checks index without checking the indexed value
157                                if index_ancestors.intersection(&index_set).next().is_some()
158                                    && indexed_ancestors.intersection(&len_set).next().is_some()
159                                {
160                                    slice_node_indice.push(index_node_idx);
161                                }
162                            }
163                        }
164                    }
165                    self.record.push((node_idx, slice_node_indice));
166                }
167            }
168        }
169    }
170
171    fn report(&self, graph: &Graph) {
172        for (upperbound_node_idx, index_record) in self.record.iter() {
173            report_upperbound_bug(graph, *upperbound_node_idx, index_record);
174        }
175    }
176
177    fn cnt(&self) -> usize {
178        self.record.iter().map(|(_, spans)| spans.len()).sum()
179    }
180}
181
182fn find_indexed_node_from_index(graph: &Graph, index_node_idx: Local) -> Option<Local> {
183    let def_paths = &DEFPATHS.get().unwrap();
184    let index_node = &graph.nodes[index_node_idx];
185    for edge_idx in index_node.out_edges.iter() {
186        let dst_node_idx = graph.edges[*edge_idx].dst;
187        let dst_node = &graph.nodes[dst_node_idx];
188        for op in dst_node.ops.iter() {
189            if let NodeOp::Call(def_id) = op {
190                if *def_id == def_paths.ops_index.last_def_id()
191                    || *def_id == def_paths.ops_index_mut.last_def_id()
192                {
193                    let index_operator_node =
194                        &graph.nodes[graph.edges[index_node.out_edges[0]].dst];
195
196                    return Some(graph.edges[index_operator_node.in_edges[0]].src);
197                }
198            }
199            if graph.is_marker(dst_node_idx) {
200                for edge_idx_ in dst_node.in_edges.iter() {
201                    let edge = &graph.edges[*edge_idx_];
202                    if let EdgeOp::Index = edge.op {
203                        return Some(edge.src);
204                    }
205                }
206            }
207        }
208    }
209    None
210}
211
212fn extract_upperbound_node_if_ops_range(graph: &Graph, node: &GraphNode) -> Option<Local> {
213    let def_paths = &DEFPATHS.get().unwrap();
214    let target_def_id = def_paths.ops_range.last_def_id();
215    for op in node.ops.iter() {
216        if let NodeOp::Aggregate(AggKind::Adt(def_id)) = op {
217            if *def_id == target_def_id {
218                let upperbound_edge = &graph.edges[node.in_edges[1]]; // the second field
219                return Some(upperbound_edge.src);
220            }
221        }
222    }
223    None
224}
225
226fn find_upside_len_node(graph: &Graph, node_idx: Local) -> Option<Local> {
227    let mut len_node_idx = None;
228    let def_paths = &DEFPATHS.get().unwrap();
229    // Warning: may traverse all upside nodes and the new result will overwrite on the previous result
230    let mut node_operator = |graph: &Graph, idx: Local| -> DFSStatus {
231        let node = &graph.nodes[idx];
232        for op in node.ops.iter() {
233            if let NodeOp::Call(def_id) = op {
234                if *def_id == def_paths.vec_len.last_def_id()
235                    || *def_id == def_paths.slice_len.last_def_id()
236                {
237                    len_node_idx = Some(idx);
238                    return DFSStatus::Stop;
239                }
240            }
241        }
242        DFSStatus::Continue
243    };
244    let mut seen = HashSet::new();
245    graph.dfs(
246        node_idx,
247        Direction::Upside,
248        &mut node_operator,
249        &mut Graph::equivalent_edge_validator,
250        false,
251        &mut seen,
252    );
253    len_node_idx
254}
255
256fn find_downside_index_node(graph: &Graph, node_idx: Local) -> Vec<Local> {
257    let mut index_node_idxs: Vec<Local> = vec![];
258    let def_paths = &DEFPATHS.get().unwrap();
259    // Warning: traverse all downside nodes
260    let mut node_operator = |graph: &Graph, idx: Local| -> DFSStatus {
261        let node = &graph.nodes[idx];
262        for op in node.ops.iter() {
263            if let NodeOp::Call(def_id) = op {
264                if *def_id == def_paths.ops_index.last_def_id()
265                    || *def_id == def_paths.ops_index_mut.last_def_id()
266                {
267                    index_node_idxs.push(idx);
268                    break;
269                }
270            }
271        }
272        DFSStatus::Continue
273    };
274    let mut seen = HashSet::new();
275    graph.dfs(
276        node_idx,
277        Direction::Downside,
278        &mut node_operator,
279        &mut Graph::always_true_edge_validator,
280        true,
281        &mut seen,
282    );
283    index_node_idxs
284}
285
286fn report_upperbound_bug(graph: &Graph, upperbound_node_idx: Local, index_record: &Vec<Local>) {
287    let upperbound_span = graph.nodes[upperbound_node_idx].span;
288    let code_source = span_to_source_code(graph.span);
289    let filename = span_to_filename(upperbound_span);
290    let mut snippet = Snippet::source(&code_source)
291        .line_start(span_to_line_number(graph.span))
292        .origin(&filename)
293        .fold(true)
294        .annotation(
295            Level::Info
296                .span(relative_pos_range(graph.span, upperbound_span))
297                .label("Index is upperbounded."),
298        );
299    for node_idx in index_record {
300        let index_span = graph.nodes[*node_idx].span;
301        snippet = snippet.annotation(
302            Level::Error
303                .span(relative_pos_range(graph.span, index_span))
304                .label("Checked here."),
305        );
306    }
307    let message = Level::Warning
308        .title("Unnecessary bounds checkings detected")
309        .snippet(snippet)
310        .footer(Level::Help.title("Use unsafe APIs instead."));
311    let renderer = Renderer::styled();
312    println!("{}", renderer.render(message));
313}