rapx/analysis/opt/data_collection/reallocation/
flatten_collect.rs1use once_cell::sync::OnceCell;
2
3use rustc_middle::ty::TyCtxt;
4
5use crate::{
6 analysis::{
7 core::dataflow::{graph::Graph, *},
8 opt::OptCheck,
9 utils::def_path::DefPath,
10 },
11 utils::log::{relative_pos_range, span_to_filename, span_to_line_number, span_to_source_code},
12};
13use annotate_snippets::{Level, Renderer, Snippet};
14use rustc_span::Span;
15
16static DEFPATHS: OnceCell<DefPaths> = OnceCell::new();
17
18struct DefPaths {
19 flat_map: DefPath,
20 flatten: DefPath,
21 collect: DefPath,
22}
23
24impl DefPaths {
25 fn new(tcx: &TyCtxt<'_>) -> Self {
26 Self {
27 flat_map: DefPath::new("std::iter::Iterator::flat_map", tcx),
28 flatten: DefPath::new("std::iter::Iterator::flatten", tcx),
29 collect: DefPath::new("std::iter::Iterator::collect", tcx),
30 }
31 }
32}
33
34pub struct FlattenCollectCheck {
35 record: Vec<Span>,
36}
37
38fn is_flatten_node(node: &GraphNode) -> bool {
39 let def_paths = &DEFPATHS.get().unwrap();
40 for op in node.ops.iter() {
41 if let NodeOp::Call(def_id) = op {
42 if *def_id == def_paths.flat_map.last_def_id()
43 || *def_id == def_paths.flatten.last_def_id()
44 {
45 return true;
46 }
47 }
48 }
49 false
50}
51
52fn is_collect_node(node: &GraphNode) -> bool {
53 let def_paths = &DEFPATHS.get().unwrap();
54 for op in node.ops.iter() {
55 if let NodeOp::Call(def_id) = op {
56 if *def_id == def_paths.collect.last_def_id() {
57 return true;
58 }
59 }
60 }
61 false
62}
63
64impl OptCheck for FlattenCollectCheck {
65 fn new() -> Self {
66 Self { record: Vec::new() }
67 }
68
69 fn check(&mut self, graph: &Graph, tcx: &TyCtxt) {
70 let _ = &DEFPATHS.get_or_init(|| DefPaths::new(tcx));
71 for node in graph.nodes.iter() {
72 if is_flatten_node(node) {
73 for edge_idx in node.out_edges.iter() {
74 let dst_idx = graph.edges[*edge_idx].dst;
75 let dst_node = &graph.nodes[dst_idx];
76 if is_collect_node(dst_node) {
77 self.record.push(dst_node.span);
78 }
79 }
80 }
81 }
82 }
83
84 fn report(&self, graph: &Graph) {
85 for span in self.record.iter() {
86 report_flatten_collect(graph, *span);
87 }
88 }
89
90 fn cnt(&self) -> usize {
91 self.record.len()
92 }
93}
94
95fn report_flatten_collect(graph: &Graph, span: Span) {
96 let code_source = span_to_source_code(graph.span);
97 let filename = span_to_filename(span);
98 let snippet: Snippet<'_> = Snippet::source(&code_source)
99 .line_start(span_to_line_number(graph.span))
100 .origin(&filename)
101 .fold(true)
102 .annotation(
103 Level::Error
104 .span(relative_pos_range(graph.span, span))
105 .label("Flatten then collect."),
106 );
107
108 let message = Level::Error
109 .title("Data collection inefficiency detected")
110 .snippet(snippet)
111 .footer(Level::Help.title("Use extend manually."));
112 let renderer = Renderer::styled();
113 println!("{}", renderer.render(message));
114}