rapx/analysis/core/api_dependency/graph/
resolve.rs

1use super::dep_edge::DepEdge;
2use super::dep_node::{desc_str, DepNode};
3use super::transform::TransformKind;
4use super::ty_wrapper::TyWrapper;
5use super::Config;
6use crate::analysis::core::api_dependency::mono::Mono;
7use crate::analysis::core::api_dependency::utils::{is_fuzzable_ty, ty_complexity};
8use crate::analysis::core::api_dependency::visitor::FnVisitor;
9use crate::analysis::core::api_dependency::ApiDependencyGraph;
10use crate::analysis::core::api_dependency::{mono, utils};
11use crate::utils::fs::rap_create_file;
12use crate::{rap_debug, rap_info, rap_trace};
13use petgraph::dot;
14use petgraph::graph::NodeIndex;
15use petgraph::visit::{NodeIndexable, Visitable};
16use petgraph::Direction::{self, Incoming};
17use petgraph::Graph;
18use rand::Rng;
19use rustc_hir::def_id::DefId;
20use rustc_middle::ty::{self, GenericArgsRef, TraitRef, Ty, TyCtxt};
21use rustc_span::sym::require;
22use std::collections::HashMap;
23use std::collections::HashSet;
24use std::collections::VecDeque;
25use std::hash::Hash;
26use std::io::Write;
27use std::path::Path;
28use std::time;
29
30fn add_return_type_if_reachable<'tcx>(
31    fn_did: DefId,
32    args: &[ty::GenericArg<'tcx>],
33    reachable_tys: &HashSet<TyWrapper<'tcx>>,
34    new_tys: &mut HashSet<Ty<'tcx>>,
35    tcx: TyCtxt<'tcx>,
36) -> bool {
37    let fn_sig = utils::fn_sig_with_generic_args(fn_did, args, tcx);
38    let inputs = fn_sig.inputs();
39    for input_ty in inputs {
40        if !is_fuzzable_ty(*input_ty, tcx) && !reachable_tys.contains(&TyWrapper::from(*input_ty)) {
41            return false;
42        }
43    }
44    let output_ty = fn_sig.output();
45    if !output_ty.is_unit() {
46        new_tys.insert(output_ty);
47    }
48    true
49}
50
51#[derive(Clone)]
52struct TypeCandidates<'tcx> {
53    tcx: TyCtxt<'tcx>,
54    candidates: HashSet<TyWrapper<'tcx>>,
55    max_complexity: usize,
56}
57
58impl<'tcx> TypeCandidates<'tcx> {
59    pub fn new(tcx: TyCtxt<'tcx>, max_complexity: usize) -> Self {
60        TypeCandidates {
61            tcx,
62            candidates: HashSet::new(),
63            max_complexity,
64        }
65    }
66
67    pub fn insert(&mut self, ty: Ty<'tcx>) -> bool {
68        if ty_complexity(ty) <= self.max_complexity {
69            self.candidates.insert(ty.into())
70        } else {
71            false
72        }
73    }
74
75    pub fn insert_all(&mut self, ty: Ty<'tcx>) -> bool {
76        let complexity = ty_complexity(ty);
77        self.insert_all_with_complexity(ty, complexity)
78    }
79
80    pub fn insert_all_with_complexity(&mut self, ty: Ty<'tcx>, current_cmplx: usize) -> bool {
81        if current_cmplx > self.max_complexity {
82            return false;
83        }
84
85        // add T
86        let mut changed = self.candidates.insert(ty.into());
87
88        // add &T
89        changed |= self.insert_all_with_complexity(
90            Ty::new_ref(
91                self.tcx,
92                self.tcx.lifetimes.re_erased,
93                ty,
94                ty::Mutability::Not,
95            ),
96            current_cmplx + 1,
97        );
98
99        // add &mut T
100        changed |= self.insert_all_with_complexity(
101            Ty::new_ref(
102                self.tcx,
103                self.tcx.lifetimes.re_erased,
104                ty,
105                ty::Mutability::Mut,
106            ),
107            current_cmplx + 1,
108        );
109
110        // add &[T]
111        changed |= self.insert_all_with_complexity(
112            Ty::new_ref(
113                self.tcx,
114                self.tcx.lifetimes.re_erased,
115                Ty::new_slice(self.tcx, ty),
116                ty::Mutability::Not,
117            ),
118            current_cmplx + 2,
119        );
120
121        // add &mut [T]
122        changed |= self.insert_all_with_complexity(
123            Ty::new_ref(
124                self.tcx,
125                self.tcx.lifetimes.re_erased,
126                Ty::new_slice(self.tcx, ty),
127                ty::Mutability::Mut,
128            ),
129            current_cmplx + 2,
130        );
131
132        changed
133    }
134
135    pub fn add_prelude_tys(&mut self) {
136        let tcx = self.tcx;
137        let prelude_tys = [
138            tcx.types.bool,
139            tcx.types.char,
140            tcx.types.f32,
141            tcx.types.f64,
142            tcx.types.i8,
143            tcx.types.i16,
144            tcx.types.i32,
145            tcx.types.i64,
146            tcx.types.isize,
147            tcx.types.u8,
148            tcx.types.u16,
149            tcx.types.u32,
150            tcx.types.u64,
151            tcx.types.usize,
152            Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, tcx.types.str_),
153        ];
154        prelude_tys.into_iter().for_each(|ty| {
155            self.insert_all(ty);
156        });
157    }
158
159    pub fn candidates(&self) -> &HashSet<TyWrapper<'tcx>> {
160        &self.candidates
161    }
162}
163
164pub fn partion_generic_api<'tcx>(
165    all_apis: &HashSet<DefId>,
166    tcx: TyCtxt<'tcx>,
167) -> (HashSet<DefId>, HashSet<DefId>) {
168    let mut generic_api = HashSet::new();
169    let mut non_generic_api = HashSet::new();
170    for api_id in all_apis.iter() {
171        if tcx.generics_of(*api_id).requires_monomorphization(tcx) {
172            generic_api.insert(*api_id);
173        } else {
174            non_generic_api.insert(*api_id);
175        }
176    }
177    (non_generic_api, generic_api)
178}
179
180impl<'tcx> ApiDependencyGraph<'tcx> {
181    pub fn resolve_generic_api(&mut self) {
182        rap_info!("start resolving generic APIs");
183        let generic_map = self.search_reachable_apis();
184        self.prune_by_similarity(generic_map);
185    }
186
187    pub fn search_reachable_apis(&mut self) -> HashMap<DefId, HashSet<Mono<'tcx>>> {
188        let tcx = self.tcx;
189        let max_ty_complexity = 6;
190        let mut type_candidates = TypeCandidates::new(self.tcx, max_ty_complexity);
191
192        type_candidates.add_prelude_tys();
193
194        // let mut num_reachable = 0;
195        let mut generic_map: HashMap<DefId, HashSet<Mono>> = HashMap::new();
196
197        // initialize unreachable non generic API
198        let (mut unreachable_non_generic_api, generic_apis) =
199            partion_generic_api(&self.all_apis, tcx);
200
201        rap_debug!("[resolve_generic] non_generic_api = {unreachable_non_generic_api:?}");
202        rap_debug!("[resolve_generic] generic_api = {generic_apis:?}");
203
204        let mut num_iter = 0;
205        let max_iteration = 10;
206
207        loop {
208            num_iter += 1;
209            let all_reachable_tys = type_candidates.candidates();
210            rap_info!(
211                "start iter #{num_iter}, # of reachble types = {}",
212                all_reachable_tys.len()
213            );
214
215            // dump all reachable types to files, each line output a type
216            let mut file = rap_create_file(Path::new("reachable_types.txt"), "create file fail");
217            for ty in all_reachable_tys.iter() {
218                writeln!(file, "{}", ty.ty()).unwrap();
219            }
220
221            let mut current_tys = HashSet::new();
222            // check whether there is any non-generic api that is reachable
223            // if the api is reachable, add output type to reachble_tys,
224            // and remove fn_did from the set.
225            unreachable_non_generic_api.retain(|fn_did| {
226                !add_return_type_if_reachable(
227                    *fn_did,
228                    ty::GenericArgs::identity_for_item(tcx, *fn_did),
229                    all_reachable_tys,
230                    &mut current_tys,
231                    tcx,
232                )
233            });
234
235            // check each generic API for new monomorphic API
236            for fn_did in generic_apis.iter() {
237                let mono_set = mono::resolve_mono_apis(*fn_did, &all_reachable_tys, tcx);
238                rap_debug!(
239                    "[search_reachable_apis] {} -> {:?}",
240                    tcx.def_path_str(*fn_did),
241                    mono_set
242                );
243                for mono in mono_set.monos {
244                    let fn_sig = utils::fn_sig_with_generic_args(*fn_did, &mono.value, tcx);
245                    let output_ty = fn_sig.output();
246                    if generic_map.entry(*fn_did).or_default().insert(mono) {
247                        if !output_ty.is_unit() && ty_complexity(output_ty) <= max_ty_complexity {
248                            current_tys.insert(output_ty);
249                        }
250                    }
251                }
252            }
253
254            let mut changed = false;
255            for ty in current_tys {
256                changed = changed | type_candidates.insert_all(ty);
257            }
258
259            if !changed {
260                rap_info!("Terminate. Reachable types unchange in this iteration.");
261                break;
262            }
263            if num_iter >= max_iteration {
264                rap_info!("Terminate. Max iteration reached.");
265                break;
266            }
267        }
268
269        let mono_cnt = generic_map.values().fold(0, |acc, monos| acc + monos.len());
270
271        rap_debug!(
272            "# of reachable types: {}",
273            type_candidates.candidates().len()
274        );
275        rap_debug!("# of mono APIs: {}", mono_cnt);
276
277        generic_map
278    }
279
280    pub fn prune_by_similarity(&mut self, generic_map: HashMap<DefId, HashSet<Mono<'tcx>>>) {
281        let mut rng = rand::rng();
282        let mut reserved_map: HashMap<DefId, Vec<(GenericArgsRef<'tcx>, bool)>> = HashMap::new();
283
284        // transform into reserved map
285        for (fn_did, mono_set) in generic_map {
286            let entry = reserved_map.entry(fn_did).or_default();
287            mono_set.into_iter().for_each(|mono| {
288                let args = self.tcx.mk_args(&mono.value);
289                self.add_api(fn_did, args);
290                entry.push((args, false));
291            });
292        }
293        // add transform edges
294        self.update_transform_edges();
295
296        self.dump_to_dot(Path::new("api_graph_unpruned.dot"), self.tcx);
297        let (estimate, total) = self.estimate_coverage_distinct();
298        rap_info!(
299            "estimate API coverage before pruning: {:.2} ({}/{})",
300            estimate as f64 / total as f64,
301            estimate,
302            total
303        );
304
305        let mut visited = vec![false; self.graph.node_count()];
306        let mut reserved = vec![false; self.graph.node_count()];
307
308        // initialize reserved
309        // all non-generic API should be reserved
310        for idx in self.graph.node_indices() {
311            if let DepNode::Api(fn_did, _) = self.graph[idx] {
312                if !utils::fn_requires_monomorphization(fn_did, self.tcx) {
313                    reserved[idx.index()] = true;
314                }
315            }
316        }
317
318        // add all monomorphic APIs to API Graph, but select minimal set cover to be reserved
319        for (fn_did, monos) in &mut reserved_map {
320            select_minimal_set_cover(self.tcx, *fn_did, monos, &mut rng);
321            for (args, r) in monos {
322                if *r {
323                    let idx = self.get_index(DepNode::Api(*fn_did, args)).unwrap();
324                    reserved[idx.index()] = true;
325                }
326            }
327        }
328
329        // traverse from start node, if a node can achieve a reserved node,
330        // this node should be reserved as well
331        for node in self.graph.node_indices() {
332            if !visited[node.index()] && self.is_start_node_index(node) {
333                rap_trace!("start propagate from {:?}", self.graph[node]);
334                self.propagate_reserved(node, &mut visited, &mut reserved);
335            }
336        }
337
338        for node in self.graph.node_indices() {
339            if !visited[node.index()] {
340                rap_trace!("{:?} is unvisited", self.graph[node]);
341                self.propagate_reserved(node, &mut visited, &mut reserved);
342            }
343        }
344
345        let mut count = 0;
346        for idx in (0..self.graph.node_count()).rev() {
347            if !reserved[idx] {
348                self.graph
349                    .remove_node(NodeIndex::new(idx))
350                    .expect("remove should not fail");
351                count += 1;
352            }
353        }
354        self.recache();
355        rap_info!("remove {} nodes by pruning", count);
356        let (estimate, total) = self.estimate_coverage_distinct();
357        rap_info!(
358            "estimate API coverage after pruning: {:.2} ({}/{})",
359            estimate as f64 / total as f64,
360            estimate,
361            total
362        );
363    }
364
365    fn recache(&mut self) {
366        self.node_indices.clear();
367        self.ty_nodes.clear();
368        self.api_nodes.clear();
369        for idx in self.graph.node_indices() {
370            let node = &self.graph[idx];
371            self.node_indices.insert(node.clone(), idx);
372            match node {
373                DepNode::Api(..) => self.api_nodes.push(idx),
374                DepNode::Ty(..) => self.ty_nodes.push(idx),
375                _ => {}
376            }
377        }
378    }
379
380    pub fn propagate_reserved(
381        &self,
382        node: NodeIndex,
383        visited: &mut [bool],
384        reserved: &mut [bool],
385    ) -> bool {
386        visited[node.index()] = true;
387
388        match self.graph[node] {
389            DepNode::Api(..) => {
390                for neighbor in self.graph.neighbors(node) {
391                    if !visited[neighbor.index()] {
392                        reserved[node.index()] |=
393                            self.propagate_reserved(neighbor, visited, reserved);
394                    }
395                }
396            }
397            DepNode::Ty(..) => {
398                for neighbor in self.graph.neighbors(node) {
399                    if !visited[neighbor.index()] {
400                        self.propagate_reserved(neighbor, visited, reserved);
401                    }
402                    reserved[node.index()] |= reserved[neighbor.index()]
403                }
404            }
405        }
406
407        if reserved[node.index()] {
408            rap_trace!(
409                "[propagate_reserved] reserve: {:?}",
410                self.graph.node_weight(node).unwrap()
411            );
412        }
413        reserved[node.index()]
414    }
415}
416
417fn select_minimal_set_cover<'tcx, 'a>(
418    tcx: TyCtxt<'tcx>,
419    fn_did: DefId,
420    monos: &'a mut Vec<(ty::GenericArgsRef<'tcx>, bool)>,
421    rng: &mut impl Rng,
422) {
423    rap_debug!("select minimal set for: {}", tcx.def_path_str(fn_did));
424    let mut impl_vec = Vec::new();
425    for (args, _) in monos.iter() {
426        let impls = mono::get_impls(tcx, fn_did, args);
427        impl_vec.push(impls);
428    }
429
430    let mut selected_cnt = 0;
431    let mut complete = HashSet::new();
432    loop {
433        let mut current_max = 0;
434        let mut idx = 0;
435        for i in 0..impl_vec.len() {
436            let impls = &impl_vec[i];
437            let size = impls.iter().fold(0, |cnt, did| {
438                if !complete.contains(did) {
439                    cnt + 1
440                } else {
441                    cnt
442                }
443            });
444            if size > current_max {
445                current_max = size;
446                idx = i;
447            }
448        }
449        if current_max == 0 {
450            break;
451        }
452        selected_cnt += 1;
453        monos[idx].1 = true;
454        rap_debug!("select: {:?}", monos[idx].0);
455        impl_vec[idx].iter().for_each(|did| {
456            complete.insert(*did);
457        });
458    }
459
460    if selected_cnt == 0 {
461        let idx = rng.random_range(0..impl_vec.len());
462        rap_debug!("random select: {:?}", monos[idx].0);
463        monos[idx].1 = true;
464    }
465}