rapx/analysis/opt/data_collection/reallocation/flatten_collect.rs

1use once_cell::sync::OnceCell;
2
3use rustc_middle::ty::TyCtxt;
4
5use crate::{
6    analysis::{
7        core::dataflow::{graph::Graph, *},
8        opt::OptCheck,
9        utils::def_path::DefPath,
10    },
11    utils::log::{relative_pos_range, span_to_filename, span_to_line_number, span_to_source_code},
12};
13use annotate_snippets::{Level, Renderer, Snippet};
14use rustc_span::Span;
15
16static DEFPATHS: OnceCell<DefPaths> = OnceCell::new();
17
18struct DefPaths {
19    flat_map: DefPath,
20    flatten: DefPath,
21    collect: DefPath,
22}
23
24impl DefPaths {
25    fn new(tcx: &TyCtxt<'_>) -> Self {
26        Self {
27            flat_map: DefPath::new("std::iter::Iterator::flat_map", tcx),
28            flatten: DefPath::new("std::iter::Iterator::flatten", tcx),
29            collect: DefPath::new("std::iter::Iterator::collect", tcx),
30        }
31    }
32}
33
34pub struct FlattenCollectCheck {
35    record: Vec<Span>,
36}
37
38fn is_flatten_node(node: &GraphNode) -> bool {
39    let def_paths = &DEFPATHS.get().unwrap();
40    for op in node.ops.iter() {
41        if let NodeOp::Call(def_id) = op {
42            if *def_id == def_paths.flat_map.last_def_id()
43                || *def_id == def_paths.flatten.last_def_id()
44            {
45                return true;
46            }
47        }
48    }
49    false
50}
51
52fn is_collect_node(node: &GraphNode) -> bool {
53    let def_paths = &DEFPATHS.get().unwrap();
54    for op in node.ops.iter() {
55        if let NodeOp::Call(def_id) = op {
56            if *def_id == def_paths.collect.last_def_id() {
57                return true;
58            }
59        }
60    }
61    false
62}
63
64impl OptCheck for FlattenCollectCheck {
65    fn new() -> Self {
66        Self { record: Vec::new() }
67    }
68
69    fn check(&mut self, graph: &Graph, tcx: &TyCtxt) {
70        let _ = &DEFPATHS.get_or_init(|| DefPaths::new(tcx));
71        for node in graph.nodes.iter() {
72            if is_flatten_node(node) {
73                for edge_idx in node.out_edges.iter() {
74                    let dst_idx = graph.edges[*edge_idx].dst;
75                    let dst_node = &graph.nodes[dst_idx];
76                    if is_collect_node(dst_node) {
77                        self.record.push(dst_node.span);
78                    }
79                }
80            }
81        }
82    }
83
84    fn report(&self, graph: &Graph) {
85        for span in self.record.iter() {
86            report_flatten_collect(graph, *span);
87        }
88    }
89
90    fn cnt(&self) -> usize {
91        self.record.len()
92    }
93}
94
95fn report_flatten_collect(graph: &Graph, span: Span) {
96    let code_source = span_to_source_code(graph.span);
97    let filename = span_to_filename(span);
98    let snippet: Snippet<'_> = Snippet::source(&code_source)
99        .line_start(span_to_line_number(graph.span))
100        .origin(&filename)
101        .fold(true)
102        .annotation(
103            Level::Error
104                .span(relative_pos_range(graph.span, span))
105                .label("Flatten then collect."),
106        );
107
108    let message = Level::Error
109        .title("Data collection inefficiency detected")
110        .snippet(snippet)
111        .footer(Level::Help.title("Use extend manually."));
112    let renderer = Renderer::styled();
113    println!("{}", renderer.render(message));
114}