rapx/analysis/senryx/
mod.rs

1#[allow(unused)]
2pub mod contracts;
3#[allow(unused)]
4pub mod dominated_graph;
5pub mod generic_check;
6// pub mod inter_record;
7pub mod matcher;
8pub mod symbolic_analysis;
9#[allow(unused)]
10pub mod visitor;
11#[allow(unused)]
12pub mod visitor_check;
13use dominated_graph::InterResultNode;
14use rustc_data_structures::fx::FxHashMap;
15use rustc_hir::{Safety, def_id::DefId};
16use rustc_middle::{
17    mir::{BasicBlock, Operand, TerminatorKind},
18    ty::{self, TyCtxt},
19};
20use std::collections::HashSet;
21use visitor::{BodyVisitor, CheckResult};
22
23use crate::analysis::{
24    Analysis,
25    core::alias_analysis::{AAResult, AliasAnalysis, default::AliasAnalyzer},
26    upg::{fn_collector::FnCollector, hir_visitor::ContainsUnsafe},
27    utils::fn_info::*,
28};
29
/// Emit a log line whose severity is chosen at runtime.
///
/// The first argument is a boolean expression; every following token is
/// forwarded unchanged as the logging arguments. When the condition is true the
/// message is sent through `rap_warn!`, otherwise through `rap_info!`.
macro_rules! cond_print {
    ($is_warning:expr, $($log_args:tt)*) => {
        if $is_warning {
            rap_warn!($($log_args)*)
        } else {
            rap_info!($($log_args)*)
        }
    };
}
33
/// Selects which functions `SenryxCheck::start` is allowed to analyze.
///
/// In this file only `High` changes behavior: it restricts the analysis to
/// functions accepted by `check_visibility`. `Medium` and `Low` currently accept
/// every function.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CheckLevel {
    /// Analyze only functions that pass the visibility check.
    High,
    /// Analyze every collected function.
    Medium,
    /// Analyze every collected function (currently identical to `Medium`).
    Low,
}
39
/// Entry point for the Senryx checks implemented in this module.
///
/// Holds the compiler context that every check queries, plus a numeric
/// threshold supplied by the caller.
pub struct SenryxCheck<'tcx> {
    /// Compiler type context used for all type/definition queries.
    pub tcx: TyCtxt<'tcx>,
    /// Numeric threshold passed to `new`.
    /// NOTE(review): this field is only stored, never read, in the visible part
    /// of this file — confirm where it is consumed.
    pub threshhold: usize,
}
44
45impl<'tcx> SenryxCheck<'tcx> {
46    /// Create a new SenryxCheck instance.
47    ///
48    /// Parameters:
49    /// - `tcx`: compiler TyCtxt for querying types/definitions.
50    /// - `threshhold`: a numeric threshold used by checks.
51    pub fn new(tcx: TyCtxt<'tcx>, threshhold: usize) -> Self {
52        Self { tcx, threshhold }
53    }
54
55    /// Start the checking pass over the collected functions.
56    ///
57    /// - `check_level` controls filtering of which functions to analyze.
58    /// - `is_verify` toggles verification mode (vs. annotation mode).
59    pub fn start(&mut self, check_level: CheckLevel, is_verify: bool) {
60        let tcx = self.tcx;
61        // Build alias information for all functions first.
62        let mut analyzer = AliasAnalyzer::new(self.tcx);
63        analyzer.run(); // populate alias results
64        let fn_map = &analyzer.get_all_fn_alias();
65
66        // Collect functions of interest (e.g. from UPG/collector)
67        let related_items = FnCollector::collect(tcx);
68        for vec in related_items.clone().values() {
69            for (body_id, _span) in vec {
70                // Check whether the function/block contains unsafe code
71                let (function_unsafe, block_unsafe) =
72                    ContainsUnsafe::contains_unsafe(tcx, *body_id);
73
74                let def_id = tcx.hir_body_owner_def_id(*body_id).to_def_id();
75
76                // Gather std unsafe callees used by this function
77                let std_unsafe_callee = get_all_std_unsafe_callees(self.tcx, def_id);
78
79                // Apply filtering by configured check level
80                if !Self::filter_by_check_level(tcx, &check_level, def_id) {
81                    continue;
82                }
83
84                // If the body-level contains unsafe ops and we are verifying, run soundness checks
85                if block_unsafe && is_verify && !std_unsafe_callee.is_empty() {
86                    self.check_soundness(def_id, fn_map);
87                }
88
89                // In non-verify mode we might annotate or produce diagnostics (disabled here)
90                if function_unsafe && !is_verify && !std_unsafe_callee.is_empty() {
91                    // annotation or other non-verification actions can be placed here
92                }
93            }
94        }
95    }
96
97    /// Iterate standard library `alloc` functions and run verification for those
98    /// that match the verification target predicate.
99    pub fn start_analyze_std_func(&mut self) {
100        // Gather function definitions from the `alloc` crate
101        let v_fn_def: Vec<_> = rustc_public::find_crates("alloc")
102            .iter()
103            .flat_map(|krate| krate.fn_defs())
104            .collect();
105        for fn_def in &v_fn_def {
106            let def_id = crate::def_id::to_internal(fn_def, self.tcx);
107            if is_verify_target_func(self.tcx, def_id) {
108                rap_info!(
109                    "Begin verification process for: {:?}",
110                    get_cleaned_def_path_name(self.tcx, def_id)
111                );
112
113                // Run main body visitor/check for this def_id
114                let check_results = self.body_visit_and_check(def_id, &FxHashMap::default());
115                if !check_results.is_empty() {
116                    Self::show_check_results(self.tcx, def_id, check_results);
117                }
118            }
119        }
120    }
121
122    /// Analyze unsafe call chains across standard library functions and print
123    /// the last non-intrinsic nodes for manual inspection.
124    pub fn start_analyze_std_func_chains(&mut self) {
125        let all_std_fn_def = get_all_std_fns_by_rustc_public(self.tcx);
126        let mut last_nodes = HashSet::new();
127        for &def_id in &all_std_fn_def {
128            // Skip non-public functions based on visibility filter
129            if !check_visibility(self.tcx, def_id) {
130                continue;
131            }
132
133            // Get unsafe call chains for the function
134            let chains = get_all_std_unsafe_chains(self.tcx, def_id);
135
136            // Filter out trivial chains unless the function is explicitly unsafe
137            let valid_chains: Vec<Vec<String>> = chains
138                .into_iter()
139                .filter(|chain| {
140                    if chain.len() > 1 {
141                        return true;
142                    }
143                    if chain.len() == 1 {
144                        if check_safety(self.tcx, def_id) == Safety::Unsafe {
145                            return true;
146                        }
147                    }
148                    false
149                })
150                .collect();
151
152            // Collect last nodes that are relevant for further inspection
153            let mut last = true;
154            for chain in &valid_chains {
155                if let Some(last_node) = chain.last() {
156                    if !last_node.contains("intrinsic") && !last_node.contains("aarch64") {
157                        last_nodes.insert(last_node.clone());
158                        last = false;
159                    }
160                }
161            }
162            if last {
163                continue;
164            }
165        }
166        Self::print_last_nodes(&last_nodes);
167    }
168
169    /// Pretty-print a set of last nodes discovered in unsafe call chains.
170    pub fn print_last_nodes(last_nodes: &HashSet<String>) {
171        if last_nodes.is_empty() {
172            println!("No unsafe call chain last nodes found.");
173            return;
174        }
175
176        println!(
177            "Found {} unique unsafe call chain last nodes:",
178            last_nodes.len()
179        );
180        for (i, node) in last_nodes.iter().enumerate() {
181            println!("{}. {}", i + 1, node);
182        }
183    }
184
185    /// Filter functions by configured check level.
186    /// - High: only publicly visible functions are considered.
187    /// - Medium/Low: accept all functions.
188    pub fn filter_by_check_level(
189        tcx: TyCtxt<'tcx>,
190        check_level: &CheckLevel,
191        def_id: DefId,
192    ) -> bool {
193        match *check_level {
194            CheckLevel::High => check_visibility(tcx, def_id),
195            _ => true,
196        }
197    }
198
199    /// Run soundness checks on a single function identified by `def_id` using
200    /// the provided alias analysis map `fn_map`.
201    pub fn check_soundness(&mut self, def_id: DefId, fn_map: &FxHashMap<DefId, AAResult>) {
202        let check_results = self.body_visit_and_check(def_id, fn_map);
203        let tcx = self.tcx;
204        if !check_results.is_empty() {
205            // Display aggregated results for this function
206            Self::show_check_results(tcx, def_id, check_results);
207        }
208    }
209
210    /// Collect safety annotations for `def_id` and display them if present.
211    pub fn annotate_safety(&self, def_id: DefId) {
212        let annotation_results = self.get_annotation(def_id);
213        if annotation_results.is_empty() {
214            return;
215        }
216        Self::show_annotate_results(self.tcx, def_id, annotation_results);
217    }
218
219    /// Visit the function body and run path-sensitive checks, returning
220    /// a list of `CheckResult`s summarizing passed/failed contracts.
221    ///
222    /// If the function is a method, constructor results are merged into the
223    /// method's initial state before analyzing the method body.
224    pub fn body_visit_and_check(
225        &mut self,
226        def_id: DefId,
227        fn_map: &FxHashMap<DefId, AAResult>,
228    ) -> Vec<CheckResult> {
229        // Create a body visitor for the target function
230        let mut body_visitor = BodyVisitor::new(self.tcx, def_id, 0);
231        let target_name = get_cleaned_def_path_name(self.tcx, def_id);
232        rap_info!("Begin verification process for: {:?}", target_name);
233
234        // If this is a method, gather constructor-derived state first
235        if get_type(self.tcx, def_id) == FnKind::Method {
236            let cons = get_cons(self.tcx, def_id);
237            // Start with a default inter-result node for ADT fields
238            let mut base_inter_result = InterResultNode::new_default(get_adt_ty(self.tcx, def_id));
239            for con in cons {
240                let mut cons_body_visitor = BodyVisitor::new(self.tcx, con, 0);
241                // Analyze constructor and merge its field states
242                let cons_fields_result = cons_body_visitor.path_forward_check(fn_map);
243                // cache and merge fields' states
244                let cons_name = get_cleaned_def_path_name(self.tcx, con);
245                println!(
246                    "cons {cons_name} state results {:?}",
247                    cons_fields_result.clone()
248                );
249                base_inter_result.merge(cons_fields_result);
250            }
251
252            // Seed the method visitor with constructor-derived field states
253            body_visitor.update_fields_states(base_inter_result);
254
255            // Optionally inspect mutable methods - diagnostic only
256            let mutable_methods = get_all_mutable_methods(self.tcx, def_id);
257            for mm in mutable_methods {
258                println!("mut method {:?}", get_cleaned_def_path_name(self.tcx, mm.0));
259            }
260
261            // Analyze the method body
262            body_visitor.path_forward_check(fn_map);
263        } else {
264            // Non-method functions: just analyze body directly
265            body_visitor.path_forward_check(fn_map);
266        }
267        body_visitor.check_results
268    }
269
270    /// Variant of `body_visit_and_check` used for UI-guided annotation flows.
271    pub fn body_visit_and_check_uig(&self, def_id: DefId) {
272        let func_type = get_type(self.tcx, def_id);
273        if func_type == FnKind::Method && !self.get_annotation(def_id).is_empty() {
274            let func_cons = search_constructor(self.tcx, def_id);
275            for func_con in func_cons {
276                if check_safety(self.tcx, func_con) == Safety::Unsafe {
277                    // Display annotations for unsafe constructors
278                    Self::show_annotate_results(self.tcx, func_con, self.get_annotation(def_id));
279                }
280            }
281        }
282    }
283
284    /// Collect annotation strings for a function by scanning calls in MIR.
285    /// For each call, if the callee has a safety annotation it is aggregated; otherwise
286    /// the callee's annotations (recursively) are collected.
287    pub fn get_annotation(&self, def_id: DefId) -> HashSet<String> {
288        let mut results = HashSet::new();
289        if !self.tcx.is_mir_available(def_id) {
290            return results;
291        }
292        let body = self.tcx.optimized_mir(def_id);
293        let basicblocks = &body.basic_blocks;
294        for i in 0..basicblocks.len() {
295            let iter = BasicBlock::from(i);
296            let terminator = basicblocks[iter].terminator.clone().unwrap();
297            if let TerminatorKind::Call {
298                ref func,
299                args: _,
300                destination: _,
301                target: _,
302                unwind: _,
303                call_source: _,
304                fn_span: _,
305            } = terminator.kind
306            {
307                match func {
308                    Operand::Constant(c) => {
309                        if let ty::FnDef(id, ..) = c.ty().kind() {
310                            // If the callee has direct annotations, extend results.
311                            if !get_sp(self.tcx, *id).is_empty() {
312                                results.extend(get_sp(self.tcx, *id));
313                            } else {
314                                // Otherwise, recurse into callee's annotations.
315                                results.extend(self.get_annotation(*id));
316                            }
317                        }
318                    }
319                    _ => {}
320                }
321            }
322        }
323        results
324    }
325
326    /// Pretty-print aggregated check results for a function.
327    /// Shows succeeded and failed contracts grouped across all arguments.
328    pub fn show_check_results(tcx: TyCtxt<'tcx>, def_id: DefId, check_results: Vec<CheckResult>) {
329        rap_info!(
330            "--------In safe function {:?}---------",
331            get_cleaned_def_path_name(tcx, def_id)
332        );
333        for check_result in &check_results {
334            // Aggregate all failed contracts from all arguments
335            let mut all_failed = HashSet::new();
336            for set in check_result.failed_contracts.values() {
337                for sp in set {
338                    all_failed.insert(sp);
339                }
340            }
341
342            // Aggregate all passed contracts from all arguments
343            let mut all_passed = HashSet::new();
344            for set in check_result.passed_contracts.values() {
345                for sp in set {
346                    all_passed.insert(sp);
347                }
348            }
349
350            // Print the API name with conditional coloring
351            cond_print!(
352                !all_failed.is_empty(),
353                "  Use unsafe api {:?}.",
354                check_result.func_name
355            );
356
357            // Print aggregated Failed set
358            if !all_failed.is_empty() {
359                let mut failed_sorted: Vec<&String> = all_failed.into_iter().collect();
360                failed_sorted.sort();
361                cond_print!(true, "      Failed: {:?}", failed_sorted);
362            }
363
364            // Print aggregated Passed set
365            if !all_passed.is_empty() {
366                let mut passed_sorted: Vec<&String> = all_passed.into_iter().collect();
367                passed_sorted.sort();
368                cond_print!(false, "      Passed: {:?}", passed_sorted);
369            }
370        }
371    }
372
373    /// Show annotation results for unsafe functions (diagnostic output).
374    pub fn show_annotate_results(
375        tcx: TyCtxt<'tcx>,
376        def_id: DefId,
377        annotation_results: HashSet<String>,
378    ) {
379        rap_info!(
380            "--------In unsafe function {:?}---------",
381            get_cleaned_def_path_name(tcx, def_id)
382        );
383        rap_warn!("Lack safety annotations: {:?}.", annotation_results);
384    }
385}