// rapx/analysis/unsafety_isolation/render_module_dot.rs

1use std::collections::{HashMap, HashSet};
2use std::fmt::Write;
3
4use crate::analysis::unsafety_isolation::UnsafetyIsolationCheck;
5use crate::analysis::unsafety_isolation::draw_dot::render_dot_graphs;
6use crate::analysis::unsafety_isolation::generate_dot::{NodeType, UigEdge, UigNode, UigUnit};
7use crate::analysis::utils::fn_info::{check_safety, get_type};
8use rustc_hir::def_id::DefId;
9use rustc_middle::ty::TyCtxt;
10
11impl<'tcx> UnsafetyIsolationCheck<'tcx> {
12    /// Main function to aggregate data and render DOT graphs per module.
13    pub fn render_module_dot(&self) {
14        let mut modules_data: HashMap<String, ModuleGraphData> = HashMap::new();
15
16        let mut collect_unit = |unit: &UigUnit| {
17            let caller_id = unit.caller.0;
18            let module_name = self.get_module_name(caller_id);
19
20            let module_data = modules_data
21                .entry(module_name)
22                .or_insert_with(ModuleGraphData::new);
23
24            module_data.add_node(self.tcx, unit.caller);
25
26            // Edge from associated item (constructor) to the method.
27            for cons in &unit.caller_cons {
28                module_data.add_node(self.tcx, *cons);
29                module_data.add_edge(cons.0, unit.caller.0, UigEdge::ConsToMethod);
30            }
31
32            // Edge for mutable access to the caller.
33            for mut_method_id in &unit.mut_methods {
34                let node_type = get_type(self.tcx, *mut_method_id);
35                let is_unsafe = check_safety(self.tcx, *mut_method_id);
36                let node = (*mut_method_id, is_unsafe, node_type);
37
38                module_data.add_node(self.tcx, node);
39                module_data.add_edge(*mut_method_id, unit.caller.0, UigEdge::MutToCaller);
40            }
41
42            // Edge representing a call from caller to callee.
43            for (callee, callee_cons_vec) in &unit.callee_cons_pair {
44                module_data.add_node(self.tcx, *callee);
45                module_data.add_edge(unit.caller.0, callee.0, UigEdge::CallerToCallee);
46
47                for callee_cons in callee_cons_vec {
48                    module_data.add_node(self.tcx, *callee_cons);
49                    module_data.add_edge(callee_cons.0, callee.0, UigEdge::ConsToMethod);
50                }
51            }
52        };
53
54        // Aggregate all Units
55        for uig in &self.uigs {
56            collect_unit(uig);
57        }
58        for uig in &self.single {
59            collect_unit(uig);
60        }
61
62        // Generate string of dot
63        let mut final_dots = Vec::new();
64        for (mod_name, data) in modules_data {
65            let dot = data.generate_dot_string(&mod_name);
66            final_dots.push((mod_name, dot));
67        }
68
69        render_dot_graphs(final_dots);
70    }
71
72    /// get module of DefId
73    fn get_module_name(&self, def_id: DefId) -> String {
74        let tcx = self.tcx;
75        let parent_mod = tcx.parent_module_from_def_id(def_id.expect_local());
76        let mod_def_id = parent_mod.to_def_id();
77
78        let path = tcx.def_path_str(mod_def_id);
79        if path.is_empty() {
80            "root_module".to_string()
81        } else {
82            path
83        }
84    }
85}
86
87/// Holds graph data for a single module before DOT generation.
88struct ModuleGraphData {
89    // Nodes grouped by their associated struct/type name.
90    struct_clusters: HashMap<String, HashSet<NodeType>>,
91    // Edges between DefIds with their type.
92    edges: HashSet<(DefId, DefId, UigEdge)>,
93    // Pre-generated DOT attribute strings for each node (DefId).
94    node_styles: HashMap<DefId, String>,
95}
96
97impl ModuleGraphData {
98    fn new() -> Self {
99        Self {
100            struct_clusters: HashMap::new(),
101            edges: HashSet::new(),
102            node_styles: HashMap::new(),
103        }
104    }
105
106    fn add_node(&mut self, tcx: TyCtxt<'_>, node: NodeType) {
107        let (def_id, _, _) = node;
108        let struct_name = self.get_struct_group_name(tcx, def_id);
109        self.struct_clusters
110            .entry(struct_name)
111            .or_default()
112            .insert(node);
113
114        if !self.node_styles.contains_key(&def_id) {
115            let uig_node = UigUnit::get_node_ty(node);
116            let attr = self.node_to_dot_attr(tcx, &uig_node);
117            self.node_styles.insert(def_id, attr);
118        }
119    }
120
121    fn add_edge(&mut self, from: DefId, to: DefId, edge_type: UigEdge) {
122        if from == to {
123            return;
124        }
125        self.edges.insert((from, to, edge_type));
126    }
127
128    fn get_struct_group_name(&self, tcx: TyCtxt<'_>, def_id: DefId) -> String {
129        if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
130            if let Some(impl_id) = assoc_item.impl_container(tcx) {
131                let ty = tcx.type_of(impl_id).skip_binder();
132                let raw_name = ty.to_string();
133                let clean_name = raw_name
134                    .split('<')
135                    .next()
136                    .unwrap_or(&raw_name)
137                    .trim()
138                    .to_string();
139                return clean_name;
140            }
141        }
142        "Free_Functions".to_string()
143    }
144
145    fn node_to_dot_attr(&self, _tcx: TyCtxt<'_>, node: &UigNode) -> String {
146        match node {
147            UigNode::Safe(def_id, shape) => {
148                format!("label=\"{:?}\", color=black, shape={:?}", def_id, shape)
149            }
150            UigNode::Unsafe(def_id, shape) => {
151                let label = format!("{:?}", def_id);
152                format!("label=\"{}\", shape={:?}, color=red", label, shape)
153            }
154            _ => "label=\"Unknown\"".to_string(),
155        }
156    }
157
158    fn generate_dot_string(&self, module_name: &str) -> String {
159        let mut dot = String::new();
160        let graph_id = module_name
161            .replace("::", "_")
162            .replace(|c: char| !c.is_alphanumeric(), "_");
163
164        writeln!(dot, "digraph {} {{", graph_id).unwrap();
165        writeln!(dot, "    compound=true;").unwrap();
166        writeln!(dot, "    rankdir=LR;").unwrap();
167
168        for (struct_name, nodes) in &self.struct_clusters {
169            let cluster_id = format!(
170                "cluster_{}",
171                struct_name.replace(|c: char| !c.is_alphanumeric(), "_")
172            );
173
174            writeln!(dot, "    subgraph {} {{", cluster_id).unwrap();
175            writeln!(dot, "        label=\"{}\";", struct_name).unwrap();
176            writeln!(dot, "        style=dashed;").unwrap();
177            writeln!(dot, "        color=gray;").unwrap();
178
179            for node in nodes {
180                let def_id = node.0;
181                let node_id =
182                    format!("n_{:?}", def_id).replace(|c: char| !c.is_alphanumeric(), "_");
183
184                if let Some(attr) = self.node_styles.get(&def_id) {
185                    writeln!(dot, "        {} [{}];", node_id, attr).unwrap();
186                }
187            }
188            writeln!(dot, "    }}").unwrap();
189        }
190
191        for (from, to, edge_type) in &self.edges {
192            let from_id = format!("n_{:?}", from).replace(|c: char| !c.is_alphanumeric(), "_");
193            let to_id = format!("n_{:?}", to).replace(|c: char| !c.is_alphanumeric(), "_");
194
195            let attr = match edge_type {
196                UigEdge::CallerToCallee => "color=black, style=solid",
197                UigEdge::ConsToMethod => "color=black, style=dotted",
198                UigEdge::MutToCaller => "color=blue, style=dashed",
199            };
200
201            writeln!(dot, "    {} -> {} [{}];", from_id, to_id, attr).unwrap();
202        }
203
204        writeln!(dot, "}}").unwrap();
205        dot
206    }
207}