rapx/analysis/unsafety_isolation/
mod.rs

1pub mod draw_dot;
2pub mod generate_dot;
3pub mod hir_visitor;
4pub mod isolation_graph;
5pub mod render_module_dot;
6pub mod std_unsafety_isolation;
7
8use crate::analysis::unsafety_isolation::draw_dot::render_dot_graphs;
9// use crate::analysis::unsafety_isolation::draw_dot::render_dot_graphs;
10use crate::analysis::unsafety_isolation::generate_dot::UigUnit;
11use crate::analysis::unsafety_isolation::hir_visitor::{ContainsUnsafe, RelatedFnCollector};
12use crate::analysis::unsafety_isolation::isolation_graph::*;
13use crate::analysis::utils::fn_info::*;
14use rustc_hir::def_id::DefId;
15use rustc_middle::{
16    mir::{Operand, TerminatorKind},
17    ty,
18    ty::TyCtxt,
19};
20use std::collections::VecDeque;
21
/// Selects which analysis `UnsafetyIsolationCheck::start` runs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UigInstruction {
    /// Audit the crate currently being compiled.
    Audit,
    /// Run the dedicated audit path (`audit_std_unsafe`) instead of the
    /// regular per-crate audit.
    StdAudit,
}
27
/// Working state of the unsafety-isolation analysis.
pub struct UnsafetyIsolationCheck<'tcx> {
    /// Compiler type context; every query issued by the analysis goes through it.
    pub tcx: TyCtxt<'tcx>,
    /// Isolation-graph nodes discovered so far; lookups match on `node_id`.
    pub nodes: Vec<IsolationGraphNode>,
    /// `DefId`s pushed by `filter_and_extend_unsafe` for collected bodies that are
    /// unsafe functions or contain unsafe blocks. `is_crate_api_node` tests
    /// membership in this list.
    pub related_func_def_id: Vec<DefId>,
    /// Graph units rendered first by `render_dot`.
    /// NOTE(review): populated outside this section (presumably by `insert_uig`) — confirm.
    pub uigs: Vec<UigUnit>,
    /// Graph units rendered after `uigs` by `render_dot`.
    /// NOTE(review): what distinguishes a "single" unit is not visible here — confirm.
    pub single: Vec<UigUnit>,
}
35
36impl<'tcx> UnsafetyIsolationCheck<'tcx> {
37    pub fn new(tcx: TyCtxt<'tcx>) -> Self {
38        Self {
39            tcx,
40            nodes: Vec::new(),
41            related_func_def_id: Vec::new(),
42            uigs: Vec::new(),
43            single: Vec::new(),
44        }
45    }
46
47    pub fn start(&mut self, ins: UigInstruction) {
48        if ins == UigInstruction::StdAudit {
49            self.audit_std_unsafe();
50            return;
51        }
52        let related_items = RelatedFnCollector::collect(self.tcx);
53        for vec in related_items.values() {
54            for (body_id, _span) in vec {
55                let (function_unsafe, _block_unsafe) =
56                    ContainsUnsafe::contains_unsafe(self.tcx, *body_id);
57                let def_id = self.tcx.hir_body_owner_def_id(*body_id).to_def_id();
58                if function_unsafe {
59                    self.insert_uig(
60                        def_id,
61                        get_callees(self.tcx, def_id),
62                        get_cons(self.tcx, def_id),
63                    );
64                }
65            }
66        }
67        // self.render_dot();
68        self.render_module_dot();
69        // let file_name = format!("re.dot");
70        // let mut file = std::fs::File::create(&file_name).expect("Unable to create file");
71        // file.write_all(dot.as_bytes())
72        //     .expect("Unable to write data");
73
74        // std::process::Command::new("sfdp")
75        //     .args(["-Tsvg", &file_name, "-o", &format!("UPG.png")])
76        //     .output()
77        //     .expect("Failed to execute Graphviz dot command");
78    }
79
80    pub fn render_dot(&mut self) {
81        let mut dot_strs = Vec::new();
82        for uig in &self.uigs {
83            let dot_str = uig.generate_dot_str();
84            let caller_name = get_cleaned_def_path_name(self.tcx, uig.caller.0);
85            dot_strs.push((caller_name, dot_str));
86        }
87        for uig in &self.single {
88            let dot_str = uig.generate_dot_str();
89            let caller_name = get_cleaned_def_path_name(self.tcx, uig.caller.0);
90            dot_strs.push((caller_name, dot_str));
91        }
92        render_dot_graphs(dot_strs);
93    }
94
95    pub fn filter_and_extend_unsafe(&mut self) {
96        let related_items = RelatedFnCollector::collect(self.tcx);
97        let mut queue = VecDeque::new();
98        let mut visited = std::collections::HashSet::new();
99
100        //'related_items' is used for recording whether this api is in crate or not
101        //then init the queue, including all unsafe func and interior unsafe func
102        for vec in related_items.values() {
103            for (body_id, _) in vec {
104                let (function_unsafe, block_unsafe) =
105                    ContainsUnsafe::contains_unsafe(self.tcx, *body_id);
106                let body_did = self.tcx.hir_body_owner_def_id(*body_id).to_def_id();
107                if function_unsafe || block_unsafe {
108                    let node_type = get_type(self.tcx, body_did);
109                    let name = self.get_name(body_did);
110                    let mut new_node =
111                        IsolationGraphNode::new(body_did, node_type, name, function_unsafe, true);
112                    if node_type == 1 {
113                        new_node.constructors = self.search_constructor(body_did);
114                    }
115                    self.nodes.push(new_node);
116                    self.related_func_def_id.push(body_did);
117                    if visited.insert(body_did) {
118                        queue.push_back(body_did);
119                    }
120                }
121            }
122        }
123
124        // BFS handling the queue
125        while let Some(body_did) = queue.pop_front() {
126            if !self.is_crate_api_node(body_did) {
127                continue;
128            }
129            // get all unsafe callees in current crate api and insert to queue
130            let callees = self.visit_node_callees(body_did);
131            for &callee_id in &callees {
132                if visited.insert(callee_id) {
133                    queue.push_back(callee_id);
134                }
135            }
136        }
137    }
138
139    pub fn check_if_node_exists(&self, body_did: DefId) -> bool {
140        if let Some(_node) = self.nodes.iter().find(|n| n.node_id == body_did) {
141            return true;
142        }
143        false
144    }
145
146    pub fn get_name(&self, body_did: DefId) -> String {
147        let tcx = self.tcx;
148        let mut name = String::new();
149        if let Some(assoc_item) = tcx.opt_associated_item(body_did) {
150            if let Some(impl_id) = assoc_item.impl_container(tcx) {
151                // get struct name
152                let ty = tcx.type_of(impl_id).skip_binder();
153                let type_name = ty.to_string();
154                let type_name = type_name.split('<').next().unwrap_or("").trim();
155                // get method name
156                let method_name = tcx.def_path(body_did).to_string_no_crate_verbose();
157                let method_name = method_name.split("::").last().unwrap_or("");
158                name = format!("{}.{}", type_name, method_name);
159            }
160        } else {
161            let verbose_name = tcx.def_path(body_did).to_string_no_crate_verbose();
162            name = verbose_name.split("::").last().unwrap_or("").to_string();
163        }
164        name
165    }
166
167    pub fn search_constructor(&mut self, def_id: DefId) -> Vec<DefId> {
168        let tcx = self.tcx;
169        let mut constructors = Vec::new();
170        if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
171            if let Some(impl_id) = assoc_item.impl_container(tcx) {
172                // get struct ty
173                let ty = tcx.type_of(impl_id).skip_binder();
174                if let Some(adt_def) = ty.ty_adt_def() {
175                    let adt_def_id = adt_def.did();
176                    let impl_vec = get_impls_for_struct(self.tcx, adt_def_id);
177                    for impl_id in impl_vec {
178                        let associated_items = tcx.associated_items(impl_id);
179                        for item in associated_items.in_definition_order() {
180                            if let ty::AssocKind::Fn {
181                                name: _,
182                                has_self: _,
183                            } = item.kind
184                            {
185                                let item_def_id = item.def_id;
186                                if get_type(self.tcx, item_def_id) == 0 {
187                                    constructors.push(item_def_id);
188                                    self.check_and_insert_node(item_def_id);
189                                    self.set_method_for_constructor(item_def_id, def_id);
190                                }
191                            }
192                        }
193                    }
194                }
195            }
196        }
197        constructors
198    }
199
200    pub fn get_cons_counts(&self, def_id: DefId) -> Vec<DefId> {
201        let tcx = self.tcx;
202        let mut constructors = Vec::new();
203        let mut methods = Vec::new();
204        if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
205            if let Some(impl_id) = assoc_item.impl_container(tcx) {
206                // get struct ty
207                let ty = tcx.type_of(impl_id).skip_binder();
208                if let Some(adt_def) = ty.ty_adt_def() {
209                    let adt_def_id = adt_def.did();
210                    let impl_vec = get_impls_for_struct(self.tcx, adt_def_id);
211                    for impl_id in impl_vec {
212                        let associated_items = tcx.associated_items(impl_id);
213                        for item in associated_items.in_definition_order() {
214                            if let ty::AssocKind::Fn {
215                                name: _,
216                                has_self: _,
217                            } = item.kind
218                            {
219                                let item_def_id = item.def_id;
220                                if get_type(self.tcx, item_def_id) == 0 {
221                                    constructors.push(item_def_id);
222                                } else if get_type(self.tcx, item_def_id) == 1 {
223                                    methods.push(item_def_id);
224                                }
225                            }
226                        }
227                    }
228                }
229                print!("struct:{:?}", ty);
230            }
231        }
232        println!("--------methods:{:?}", methods.len());
233        constructors
234    }
235
236    // visit the func body and record all its unsafe callees and modify visited_tag
237    pub fn visit_node_callees(&mut self, def_id: DefId) -> Vec<DefId> {
238        let mut callees = Vec::new();
239        let tcx = self.tcx;
240        if tcx.is_mir_available(def_id) {
241            let body = tcx.optimized_mir(def_id);
242            for bb in body.basic_blocks.iter() {
243                if let TerminatorKind::Call { func, .. } = &bb.terminator().kind {
244                    if let Operand::Constant(func_constant) = func {
245                        if let ty::FnDef(callee_def_id, _) = func_constant.const_.ty().kind() {
246                            if check_safety(self.tcx, *callee_def_id) {
247                                if !callees.contains(callee_def_id) {
248                                    callees.push(*callee_def_id);
249                                    if !self.check_if_node_exists(*callee_def_id) {
250                                        self.check_and_insert_node(*callee_def_id);
251                                        self.set_caller_for_callee(def_id, *callee_def_id);
252                                    }
253                                }
254                            }
255                        }
256                    }
257                }
258            }
259        }
260        if let Some(node) = self.nodes.iter_mut().find(|n| n.node_id == def_id) {
261            node.callees = callees.clone();
262            node.visited_tag = true;
263        }
264        callees
265    }
266
267    pub fn is_crate_api_node(&self, body_did: DefId) -> bool {
268        self.related_func_def_id.contains(&body_did)
269    }
270
271    pub fn check_and_insert_node(&mut self, body_did: DefId) {
272        if self.check_if_node_exists(body_did) {
273            return;
274        }
275        let node_type = get_type(self.tcx, body_did);
276        let name = self.get_name(body_did);
277        let is_crate_api = self.is_crate_api_node(body_did);
278        let node_safety = check_safety(self.tcx, body_did);
279        let mut new_node =
280            IsolationGraphNode::new(body_did, node_type, name, node_safety, is_crate_api);
281        if node_type == 1 {
282            new_node.constructors = self.search_constructor(body_did);
283        }
284        new_node.visited_tag = false;
285        self.nodes.push(new_node);
286    }
287
288    pub fn set_method_for_constructor(&mut self, constructor_did: DefId, method_did: DefId) {
289        if let Some(node) = self
290            .nodes
291            .iter_mut()
292            .find(|node| node.node_id == constructor_did)
293        {
294            if !node.methods.contains(&method_did) {
295                node.methods.push(method_did);
296            }
297        }
298    }
299
300    pub fn set_caller_for_callee(&mut self, caller_did: DefId, callee_did: DefId) {
301        if let Some(node) = self
302            .nodes
303            .iter_mut()
304            .find(|node| node.node_id == callee_did)
305        {
306            if !node.callers.contains(&caller_did) {
307                node.callers.push(caller_did);
308            }
309        }
310    }
311
312    pub fn show_nodes(&self) {
313        for node in &self.nodes {
314            println!(
315                "name:{:?},safety:{:?},calles:{:?}",
316                node.node_name, node.node_unsafety, node.callees
317            );
318        }
319    }
320}