// rapx/analysis/senryx/mod.rs

1#[allow(unused)]
2pub mod contracts;
3#[allow(unused)]
4pub mod dominated_graph;
5pub mod generic_check;
6pub mod inter_record;
7pub mod matcher;
8#[allow(unused)]
9pub mod visitor;
10#[allow(unused)]
11pub mod visitor_check;
12use dominated_graph::InterResultNode;
13use inter_record::InterAnalysisRecord;
14use rustc_data_structures::fx::FxHashMap;
15use rustc_hir::def_id::DefId;
16use rustc_middle::{
17    mir::{BasicBlock, Operand, TerminatorKind},
18    ty::{self, TyCtxt},
19};
20use rustc_span::Symbol;
21use std::collections::{HashMap, HashSet};
22use visitor::{BodyVisitor, CheckResult};
23
24use crate::{
25    analysis::{
26        core::alias_analysis::{default::AliasAnalyzer, AAResult, AliasAnalysis},
27        unsafety_isolation::{
28            draw_dot::render_dot_graphs,
29            hir_visitor::{ContainsUnsafe, RelatedFnCollector},
30            UnsafetyIsolationCheck,
31        },
32        utils::fn_info::*,
33        Analysis,
34    },
35    rap_info, rap_warn,
36};
37
/// Emits a log line through `rap_warn!` when `$cond` evaluates to `true`,
/// and through `rap_info!` otherwise. All tokens after the condition are
/// forwarded unchanged as the logging macro's format string and arguments.
macro_rules! cond_print {
    ($cond:expr, $($t:tt)*) => {
        if $cond {
            rap_warn!($($t)*)
        } else {
            rap_info!($($t)*)
        }
    };
}
41
/// How aggressively candidate functions are filtered before checking.
///
/// Interpreted by `SenryxCheck::filter_by_check_level`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CheckLevel {
    /// Only functions that pass `check_visibility` are checked.
    High,
    /// Currently applies no filtering (same as `Low`).
    Medium,
    /// Applies no filtering: every candidate function is checked.
    Low,
}
47
48pub struct SenryxCheck<'tcx> {
49    pub tcx: TyCtxt<'tcx>,
50    pub threshhold: usize,
51    pub global_recorder: HashMap<DefId, InterAnalysisRecord<'tcx>>,
52}
53
54impl<'tcx> SenryxCheck<'tcx> {
55    pub fn new(tcx: TyCtxt<'tcx>, threshhold: usize) -> Self {
56        Self {
57            tcx,
58            threshhold,
59            global_recorder: HashMap::new(),
60        }
61    }
62
63    pub fn start(&mut self, check_level: CheckLevel, is_verify: bool) {
64        let tcx = self.tcx;
65        let mut analyzer = AliasAnalyzer::new(self.tcx);
66        analyzer.run();
67        let fn_map = &analyzer.get_all_fn_alias();
68        let related_items = RelatedFnCollector::collect(tcx);
69        for vec in related_items.clone().values() {
70            for (body_id, _span) in vec {
71                let (function_unsafe, block_unsafe) =
72                    ContainsUnsafe::contains_unsafe(tcx, *body_id);
73                let def_id = tcx.hir_body_owner_def_id(*body_id).to_def_id();
74                let std_unsafe_callee = get_all_std_unsafe_callees(self.tcx, def_id);
75                if !Self::filter_by_check_level(tcx, &check_level, def_id) {
76                    continue;
77                }
78                if block_unsafe && is_verify && !std_unsafe_callee.is_empty() {
79                    self.check_soundness(def_id, fn_map);
80                }
81                if function_unsafe && !is_verify && !std_unsafe_callee.is_empty() {
82                    // self.annotate_safety(def_id);
83                    // let mutable_methods = get_all_mutable_methods(self.tcx, def_id);
84                    // println!("mutable_methods: {:?}", mutable_methods);
85                }
86            }
87        }
88    }
89
90    pub fn start_analyze_std_func(&mut self) {
91        let v_fn_def: Vec<_> = rustc_public::find_crates("alloc")
92            .iter()
93            .flat_map(|krate| krate.fn_defs())
94            .collect();
95        for fn_def in &v_fn_def {
96            let def_id = crate::def_id::to_internal(fn_def, self.tcx);
97            if is_verify_target_func(self.tcx, def_id) {
98                rap_info!(
99                    "Begin verification process for: {:?}",
100                    get_cleaned_def_path_name(self.tcx, def_id)
101                );
102                let check_results = self.body_visit_and_check(def_id, &FxHashMap::default());
103                if !check_results.is_empty() {
104                    Self::show_check_results(self.tcx, def_id, check_results);
105                }
106            }
107        }
108    }
109
110    pub fn generate_uig_by_def_id(&mut self) {
111        let all_std_fn_def = get_all_std_fns_by_rustc_public(self.tcx);
112        let symbol = Symbol::intern("Vec");
113        let vec_def_id = self.tcx.get_diagnostic_item(symbol).unwrap();
114        println!("vec_def_id {:?}", vec_def_id);
115        let mut uig_entrance = UnsafetyIsolationCheck::new(self.tcx);
116        for &def_id in &all_std_fn_def {
117            let adt_def = get_adt_def_id_by_adt_method(self.tcx, def_id);
118            if adt_def.is_some() && adt_def.unwrap() == vec_def_id {
119                println!("def_id {:?}", def_id);
120                uig_entrance.insert_uig(
121                    def_id,
122                    get_callees(self.tcx, def_id),
123                    get_cons(self.tcx, def_id),
124                );
125            }
126        }
127        let mut dot_strs = Vec::new();
128        for uig in &uig_entrance.uigs {
129            let dot_str = uig.generate_dot_str();
130            dot_strs.push(dot_str);
131        }
132        for uig in &uig_entrance.single {
133            let dot_str = uig.generate_dot_str();
134            dot_strs.push(dot_str);
135        }
136        render_dot_graphs(dot_strs);
137    }
138
139    pub fn start_analyze_std_func_chains(&mut self) {
140        let all_std_fn_def = get_all_std_fns_by_rustc_public(self.tcx);
141        let mut last_nodes = HashSet::new();
142        for &def_id in &all_std_fn_def {
143            if !check_visibility(self.tcx, def_id) {
144                continue;
145            }
146            let chains = get_all_std_unsafe_chains(self.tcx, def_id);
147            let valid_chains: Vec<Vec<String>> = chains
148                .into_iter()
149                .filter(|chain| {
150                    if chain.len() > 1 {
151                        return true;
152                    }
153                    if chain.len() == 1 {
154                        let is_unsafe = check_safety(self.tcx, def_id);
155                        return is_unsafe;
156                    }
157                    false
158                })
159                .collect();
160
161            let mut last = true;
162            for chain in &valid_chains {
163                if let Some(last_node) = chain.last() {
164                    if !last_node.contains("intrinsic") && !last_node.contains("aarch64") {
165                        last_nodes.insert(last_node.clone());
166                        last = false;
167                    }
168                }
169            }
170            if last {
171                continue;
172            }
173            // print_unsafe_chains(&valid_chains);
174        }
175        Self::print_last_nodes(&last_nodes);
176    }
177
178    pub fn print_last_nodes(last_nodes: &HashSet<String>) {
179        if last_nodes.is_empty() {
180            println!("No unsafe call chain last nodes found.");
181            return;
182        }
183
184        println!(
185            "Found {} unique unsafe call chain last nodes:",
186            last_nodes.len()
187        );
188        for (i, node) in last_nodes.iter().enumerate() {
189            println!("{}. {}", i + 1, node);
190        }
191    }
192
193    pub fn filter_by_check_level(
194        tcx: TyCtxt<'tcx>,
195        check_level: &CheckLevel,
196        def_id: DefId,
197    ) -> bool {
198        match *check_level {
199            CheckLevel::High => check_visibility(tcx, def_id),
200            _ => true,
201        }
202    }
203
204    pub fn check_soundness(&mut self, def_id: DefId, fn_map: &FxHashMap<DefId, AAResult>) {
205        let check_results = self.body_visit_and_check(def_id, fn_map);
206        let tcx = self.tcx;
207        if !check_results.is_empty() {
208            Self::show_check_results(tcx, def_id, check_results);
209        }
210    }
211
212    pub fn annotate_safety(&self, def_id: DefId) {
213        let annotation_results = self.get_annotation(def_id);
214        if annotation_results.is_empty() {
215            return;
216        }
217        Self::show_annotate_results(self.tcx, def_id, annotation_results);
218    }
219
220    pub fn body_visit_and_check(
221        &mut self,
222        def_id: DefId,
223        fn_map: &FxHashMap<DefId, AAResult>,
224    ) -> Vec<CheckResult> {
225        let mut body_visitor = BodyVisitor::new(self.tcx, def_id, self.global_recorder.clone(), 0);
226        let target_name = get_cleaned_def_path_name(self.tcx, def_id);
227        if !target_name.contains("into_raw_parts_with_alloc") {
228            return body_visitor.check_results;
229        }
230        rap_info!("Begin verification process for: {:?}", target_name);
231        if get_type(self.tcx, def_id) == 1 {
232            let func_cons = get_cons(self.tcx, def_id);
233            let mut base_inter_result = InterResultNode::new_default(get_adt_ty(self.tcx, def_id));
234            for func_con in func_cons {
235                let mut cons_body_visitor =
236                    BodyVisitor::new(self.tcx, func_con.0, self.global_recorder.clone(), 0);
237                let cons_fields_result = cons_body_visitor.path_forward_check(fn_map);
238                // cache and merge fields' states
239                let cons_name = get_cleaned_def_path_name(self.tcx, func_con.0);
240                println!(
241                    "cons {cons_name} state results {:?}",
242                    cons_fields_result.clone()
243                );
244                base_inter_result.merge(cons_fields_result);
245            }
246            // update method body's states by constructors' states
247            body_visitor.update_fields_states(base_inter_result);
248            // get mutable methods and TODO: update target method's states
249            let mutable_methods = get_all_mutable_methods(self.tcx, def_id);
250            for mm in mutable_methods {
251                println!("mut method {:?}", get_cleaned_def_path_name(self.tcx, mm.0));
252                // has_tainted_fields(self.tcx, mm.0, 1);
253            }
254            // analyze body's states
255            body_visitor.path_forward_check(fn_map);
256        } else {
257            body_visitor.path_forward_check(fn_map);
258        }
259        body_visitor.check_results
260    }
261
262    pub fn body_visit_and_check_uig(&self, def_id: DefId) {
263        let mut uig_checker = UnsafetyIsolationCheck::new(self.tcx);
264        let func_type = get_type(self.tcx, def_id);
265        if func_type == 1 && !self.get_annotation(def_id).is_empty() {
266            let func_cons = uig_checker.search_constructor(def_id);
267            for func_con in func_cons {
268                if check_safety(self.tcx, func_con) {
269                    Self::show_annotate_results(self.tcx, func_con, self.get_annotation(def_id));
270                    // uphold safety to unsafe constructor
271                }
272            }
273        }
274    }
275
276    pub fn get_annotation(&self, def_id: DefId) -> HashSet<String> {
277        let mut results = HashSet::new();
278        if !self.tcx.is_mir_available(def_id) {
279            return results;
280        }
281        let body = self.tcx.optimized_mir(def_id);
282        let basicblocks = &body.basic_blocks;
283        for i in 0..basicblocks.len() {
284            let iter = BasicBlock::from(i);
285            let terminator = basicblocks[iter].terminator.clone().unwrap();
286            if let TerminatorKind::Call {
287                ref func,
288                args: _,
289                destination: _,
290                target: _,
291                unwind: _,
292                call_source: _,
293                fn_span: _,
294            } = terminator.kind
295            {
296                match func {
297                    Operand::Constant(c) => {
298                        if let ty::FnDef(id, ..) = c.ty().kind() {
299                            if !get_sp(self.tcx, *id).is_empty() {
300                                results.extend(get_sp(self.tcx, *id));
301                            } else {
302                                results.extend(self.get_annotation(*id));
303                            }
304                        }
305                    }
306                    _ => {}
307                }
308            }
309        }
310        results
311    }
312
313    pub fn show_check_results(tcx: TyCtxt<'tcx>, def_id: DefId, check_results: Vec<CheckResult>) {
314        rap_info!(
315            "--------In safe function {:?}---------",
316            get_cleaned_def_path_name(tcx, def_id)
317        );
318        for check_result in &check_results {
319            cond_print!(
320                !check_result.failed_contracts.is_empty(),
321                "  Use unsafe api {:?}.",
322                check_result.func_name
323            );
324            for failed_contract in &check_result.failed_contracts {
325                cond_print!(
326                    true,
327                    "      Argument {}'s failed Sps: {:?}",
328                    failed_contract.0,
329                    failed_contract.1
330                );
331            }
332            for passed_contract in &check_result.passed_contracts {
333                cond_print!(
334                    false,
335                    "      Argument {}'s passed Sps: {:?}",
336                    passed_contract.0,
337                    passed_contract.1
338                );
339            }
340        }
341    }
342
343    pub fn show_annotate_results(
344        tcx: TyCtxt<'tcx>,
345        def_id: DefId,
346        annotation_results: HashSet<String>,
347    ) {
348        rap_info!(
349            "--------In unsafe function {:?}---------",
350            get_cleaned_def_path_name(tcx, def_id)
351        );
352        rap_warn!("Lack safety annotations: {:?}.", annotation_results);
353    }
354}