rapx/analysis/core/api_dependency/visitor.rs

1use std::io::Write;
2
3use super::extract::extract_constraints;
4use crate::{
5    analysis::core::api_dependency::{ApiDependencyGraph, Edge, Node},
6    rap_debug,
7};
8use rustc_hir::{
9    def_id::{DefId, LocalDefId},
10    intravisit::{FnKind, Visitor},
11    BodyId, FnDecl,
12};
13
14use rustc_middle::ty::{self, TyCtxt};
15use rustc_span::Span;
16
17pub struct FnVisitor<'tcx, 'a> {
18    fn_cnt: usize,
19    tcx: TyCtxt<'tcx>,
20    funcs: Vec<DefId>,
21    current_fn_did: Option<DefId>,
22    graph: &'a mut ApiDependencyGraph<'tcx>,
23}
24
25impl<'tcx, 'a> FnVisitor<'tcx, 'a> {
26    pub fn new(tcx: TyCtxt<'tcx>, graph: &'a mut ApiDependencyGraph<'tcx>) -> FnVisitor<'tcx, 'a> {
27        let fn_cnt = 0;
28        let funcs = Vec::new();
29        FnVisitor {
30            fn_cnt,
31            tcx,
32            graph,
33            funcs,
34            current_fn_did: None,
35        }
36    }
37    pub fn fn_cnt(&self) -> usize {
38        self.fn_cnt
39    }
40    pub fn write_funcs<T: Write>(&self, f: &mut T) {
41        for id in &self.funcs {
42            write!(f, "{}\n", self.tcx.def_path_str(id)).expect("fail when write funcs");
43        }
44    }
45}
46
47fn get_bound_var_attr(var: ty::BoundVariableKind) -> (String, bool) {
48    let name: String;
49    let is_lifetime;
50    match var {
51        ty::BoundVariableKind::Ty(bound_ty_kind) => {
52            is_lifetime = false;
53            name = match bound_ty_kind {
54                ty::BoundTyKind::Param(_, sym) => sym.to_string(),
55                _ => "anon".to_string(),
56            }
57        }
58        ty::BoundVariableKind::Region(bound_region_kind) => {
59            is_lifetime = true;
60            name = match bound_region_kind {
61                ty::BoundRegionKind::Named(_, name) => name.to_string(),
62                _ => "anon".to_string(),
63            }
64        }
65        ty::BoundVariableKind::Const => {
66            is_lifetime = false;
67            name = "anon const".to_string();
68        }
69    }
70    (name, is_lifetime)
71}
72
73impl<'tcx, 'a> Visitor<'tcx> for FnVisitor<'tcx, 'a> {
74    fn visit_fn<'v>(
75        &mut self,
76        _fk: FnKind<'v>,
77        _fd: &'v FnDecl<'v>,
78        _b: BodyId,
79        _span: Span,
80        id: LocalDefId,
81    ) -> Self::Result {
82        let fn_def_id = id.to_def_id();
83        self.fn_cnt += 1;
84        self.funcs.push(fn_def_id);
85        let api_node = self.graph.get_node(Node::api(id));
86
87        let early_fn_sig = self.tcx.fn_sig(fn_def_id);
88        let binder_fn_sig = early_fn_sig.instantiate_identity();
89        let fn_sig = self
90            .tcx
91            .liberate_late_bound_regions(fn_def_id, binder_fn_sig);
92        rap_debug!("visit {}", fn_sig);
93
94        // add generic param def to graph
95        // NOTE: generics_of query only return early bound generics
96        let generics = self.tcx.generics_of(fn_def_id);
97        let early_generic_count = generics.count();
98        rap_debug!("early bound generic count = {}", early_generic_count);
99        for i in 0..early_generic_count {
100            let generic_param_def = generics.param_at(i, self.tcx);
101            rap_debug!("early bound generic#{i}: {:?}", generic_param_def);
102            let node_index = self.graph.get_node(Node::generic_param_def(
103                fn_def_id,
104                i,
105                generic_param_def.name,
106                !generic_param_def.kind.is_ty_or_const(),
107            ));
108            self.graph
109                .add_edge_once(api_node, node_index, Edge::fn2generic());
110        }
111
112        // add late bound generic
113        rap_debug!(
114            "late bound generic count = {}",
115            binder_fn_sig.bound_vars().len()
116        );
117        for (i, var) in binder_fn_sig.bound_vars().iter().enumerate() {
118            rap_debug!("bound var#{i}: {var:?}");
119            let (name, is_lifetime) = get_bound_var_attr(var);
120            let node_index = self.graph.get_node(Node::generic_param_def(
121                fn_def_id,
122                early_generic_count + i,
123                name,
124                is_lifetime,
125            ));
126            self.graph
127                .add_edge_once(api_node, node_index, Edge::fn2generic());
128        }
129
130        extract_constraints(fn_def_id, self.tcx);
131
132        // add inputs/output to graph, and compute constraints based on subtyping
133        for (no, input_ty) in fn_sig.inputs().iter().enumerate() {
134            // let free_input_ty = input_ty.fold_with(folder)
135            let input_node = self.graph.get_node(Node::ty(*input_ty));
136            self.graph.add_edge(input_node, api_node, Edge::arg(no));
137        }
138
139        let output_ty = fn_sig.output();
140        let output_node = self.graph.get_node(Node::ty(output_ty));
141        self.graph.add_edge(api_node, output_node, Edge::ret());
142        rap_debug!("exit visit_fn");
143    }
144}