use crate::{
analysis::utils::fn_info::{display_hashmap, get_pointee, is_ptr, is_ref},
rap_warn,
};
use rustc_hir::def_id::DefId;
use rustc_middle::mir::Local;
use rustc_middle::ty::{Ty, TyCtxt};
use std::collections::{HashMap, HashSet, VecDeque};
/// Per-node properties tracked by [`DominatedGraph`].
///
/// [`States::new`] sets every flag to `true` and [`States::new_unknown`] sets every flag
/// to `false`. When argument nodes are initialized, `new_unknown` is used for the object
/// behind an `is_ptr` argument and `new` for the object behind an `is_ref` argument.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct States {
    /// NOTE(review): presumably "the pointer is known to be non-null" — confirm with the
    /// checks that read this flag.
    pub nonnull: bool,
    /// NOTE(review): presumably "memory comes from the expected allocator" — confirm.
    pub allocator_consistency: bool,
}

impl States {
    /// Returns a state with every flag set to `true`.
    pub fn new() -> Self {
        Self {
            nonnull: true,
            allocator_consistency: true,
        }
    }

    /// Returns a state with every flag set to `false`.
    pub fn new_unknown() -> Self {
        Self {
            nonnull: false,
            allocator_consistency: false,
        }
    }
}

/// Same as [`States::new`]. This is implemented by hand rather than derived, because a
/// derived `Default` would produce the all-`false` state of [`States::new_unknown`].
impl Default for States {
    fn default() -> Self {
        Self::new()
    }
}
/// A node of [`DominatedGraph`]. This is either a MIR local, a node created for a field of
/// another node, or an object node synthesized behind a pointer/reference argument.
#[derive(Debug, Clone)]
pub struct VariableNode<'tcx> {
/// Id of this node. The constructors in this file use the same value as the node's key in
/// `DominatedGraph::variables`.
pub id: usize,
/// Ids that resolve to this node. Graph lookups scan every node's `alias_set`, so after
/// `merge` an id can be served by a node whose `id` differs. Starts as `{id}`.
pub alias_set: HashSet<usize>,
/// Id this node points to, if any. It is stored as the raw id passed to `point`, so it
/// should be resolved through the alias sets rather than used directly as a map key.
points_to: Option<usize>,
/// Raw ids of the nodes that point to this node (the reverse of `points_to`).
pointed_by: HashSet<usize>,
/// Field index -> id of the node created for that field.
pub field: HashMap<usize, usize>,
/// Type of the value, or `None` when it is not known (field nodes created through
/// `find_var_id_with_fields_seq` get `None`).
pub ty: Option<Ty<'tcx>>,
/// Set by `DominatedGraph::set_drop`; starts as `false`.
pub is_dropped: bool,
/// Properties tracked for this node.
pub states: States,
}
impl<'tcx> VariableNode<'tcx> {
    /// Creates a node with the given edges, type and states.
    ///
    /// The alias set starts as `{id}`. The node has no field nodes and is not dropped.
    pub fn new(
        id: usize,
        points_to: Option<usize>,
        pointed_by: HashSet<usize>,
        ty: Option<Ty<'tcx>>,
        states: States,
    ) -> Self {
        VariableNode {
            id,
            alias_set: HashSet::from([id]),
            points_to,
            pointed_by,
            field: HashMap::new(),
            ty,
            is_dropped: false,
            states,
        }
    }

    /// Creates an unconnected node (no `points_to`, empty `pointed_by`) with
    /// [`States::new`].
    pub fn new_default(id: usize, ty: Option<Ty<'tcx>>) -> Self {
        Self::new_with_states(id, ty, States::new())
    }

    /// Creates an unconnected node (no `points_to`, empty `pointed_by`) carrying `states`.
    pub fn new_with_states(id: usize, ty: Option<Ty<'tcx>>, states: States) -> Self {
        // Delegate so that all constructors share a single field initialization.
        Self::new(id, None, HashSet::new(), ty, states)
    }
}
/// Points-to / alias graph over the locals of one function's optimized MIR, plus the
/// field and pointee nodes synthesized during analysis.
pub struct DominatedGraph<'tcx> {
/// Type context used to query the function's MIR and signature.
pub tcx: TyCtxt<'tcx>,
/// The function whose `optimized_mir` is analyzed.
pub def_id: DefId,
/// Number of MIR locals. Ids `0..local_len` are the locals; synthesized nodes get larger ids.
pub local_len: usize,
/// All nodes, keyed by node id. Use `get_var_node`/`get_var_node_mut` to resolve an id
/// so that alias sets are honored.
pub variables: HashMap<usize, VariableNode<'tcx>>,
}
impl<'tcx> DominatedGraph<'tcx> {
    /// Builds a graph for `def_id` with one default node per local of its optimized MIR.
    ///
    /// Node `i` stands for MIR local `_i` and carries that local's declared type.
    /// `local_len` records the number of locals, so that synthesized nodes get ids past them.
    pub fn new(tcx: TyCtxt<'tcx>, def_id: DefId) -> Self {
        // `optimized_mir` hands out a `'tcx` borrow, so the declarations need no clone.
        let local_decls = &tcx.optimized_mir(def_id).local_decls;
        let variables: HashMap<usize, VariableNode<'tcx>> = local_decls
            .iter()
            .enumerate()
            .map(|(idx, decl)| (idx, VariableNode::new_default(idx, Some(decl.ty))))
            .collect();
        Self {
            tcx,
            def_id,
            local_len: local_decls.len(),
            variables,
        }
    }

    /// Creates a pointee object node for every pointer- or reference-typed argument.
    ///
    /// Arguments are taken as locals `_1..=_n`, where `n` is the number of inputs in the
    /// function signature.
    pub fn init_arg(&mut self) {
        let body = self.tcx.optimized_mir(self.def_id);
        let fn_sig = self.tcx.fn_sig(self.def_id).skip_binder();
        let param_len = fn_sig.inputs().skip_binder().len();
        for idx in 1..=param_len {
            let arg_ty = body.local_decls[Local::from(idx)].ty;
            self.generate_ptr_with_obj_node(arg_ty, idx);
        }
    }

    /// If `is_ptr(local_ty)` or `is_ref(local_ty)` holds, synthesizes a node for the
    /// pointee and makes node `idx` point to it. Any other type leaves the graph untouched.
    ///
    /// The pointee node has type `get_pointee(local_ty)` and `pointed_by = {idx}`. Its
    /// states are [`States::new_unknown`] for `is_ptr` types and [`States::new`] for
    /// `is_ref` types.
    ///
    /// # Panics
    /// Panics if no node resolves `idx`.
    pub fn generate_ptr_with_obj_node(&mut self, local_ty: Ty<'tcx>, idx: usize) {
        let pointee_states = if is_ptr(local_ty) {
            States::new_unknown()
        } else if is_ref(local_ty) {
            States::new()
        } else {
            return;
        };
        let new_id = self.generate_node_id();
        // NOTE(review): a previous `points_to` of `idx` is overwritten without updating the
        // old target's `pointed_by`. This is harmless on a fresh graph, as in `init_arg`.
        self.get_var_node_mut(idx).unwrap().points_to = Some(new_id);
        self.insert_node(new_id, Some(get_pointee(local_ty)), idx, None, pointee_states);
    }

    /// Returns the type of `arg`. For MIR locals this is the declared type; for synthesized
    /// nodes it is the type stored on the node.
    ///
    /// # Panics
    /// Panics if `arg` is not a local and no node resolves it.
    pub fn get_local_ty_by_place(&self, arg: usize) -> Option<Ty<'tcx>> {
        let local_decls = &self.tcx.optimized_mir(self.def_id).local_decls;
        if arg < local_decls.len() {
            Some(local_decls[Local::from(arg)].ty)
        } else {
            self.get_var_node(arg).unwrap().ty
        }
    }

    /// Returns the type of the node `arg` points to, or `arg`'s own type if it points
    /// nowhere.
    ///
    /// # Panics
    /// Panics if `arg` or its target cannot be resolved.
    pub fn get_obj_ty_through_chain(&self, arg: usize) -> Option<Ty<'tcx>> {
        let var = self.get_var_node(arg).unwrap();
        match var.points_to {
            Some(pointee) => self.get_var_node(pointee).unwrap().ty,
            None => var.ty,
        }
    }

    /// Returns the id `arg` points to, or `arg` itself if it points nowhere.
    ///
    /// # Panics
    /// Panics if no node resolves `arg`.
    pub fn get_point_to_id(&self, arg: usize) -> usize {
        self.get_var_node(arg).unwrap().points_to.unwrap_or(arg)
    }

    /// Returns `true` if `node_id` is a MIR local rather than a synthesized node.
    pub fn is_local(&self, node_id: usize) -> bool {
        self.local_len > node_id
    }
}
impl<'tcx> DominatedGraph<'tcx> {
    /// Returns an id not used by any MIR local or existing node.
    ///
    /// The result is one past the largest key in `variables`, but never less than
    /// `local_len`, because ids below `local_len` belong to MIR locals.
    pub fn generate_node_id(&self) -> usize {
        self.variables
            .keys()
            .max()
            .map_or(self.local_len, |&max_id| (max_id + 1).max(self.local_len))
    }

    /// Returns the id of the node for field `field_idx` of `local`, creating it with type
    /// `ty` if it does not exist yet. An existing field node keeps its own type.
    ///
    /// # Panics
    /// Panics if no node resolves `local`.
    pub fn get_field_node_id(
        &mut self,
        local: usize,
        field_idx: usize,
        ty: Option<Ty<'tcx>>,
    ) -> usize {
        let existing = self
            .get_var_node(local)
            .unwrap()
            .field
            .get(&field_idx)
            .copied();
        match existing {
            Some(field_id) => field_id,
            None => self.insert_field_node(local, field_idx, ty),
        }
    }

    /// Creates a default node of type `ty`, registers it as field `field_idx` of the node
    /// resolving `local` (replacing any previous entry for that index), and returns its id.
    ///
    /// # Panics
    /// Panics if no node resolves `local`.
    pub fn insert_field_node(
        &mut self,
        local: usize,
        field_idx: usize,
        ty: Option<Ty<'tcx>>,
    ) -> usize {
        let new_id = self.generate_node_id();
        self.variables
            .insert(new_id, VariableNode::new_default(new_id, ty));
        self.get_var_node_mut(local)
            .unwrap()
            .field
            .insert(field_idx, new_id);
        new_id
    }

    /// Walks the field path `fields` starting at `local`, creating untyped field nodes as
    /// needed. Returns the id of the final node, or `local` itself when `fields` is empty.
    pub fn find_var_id_with_fields_seq(&mut self, local: usize, fields: Vec<usize>) -> usize {
        fields
            .into_iter()
            .fold(local, |cur, field| self.get_field_node_id(cur, field, None))
    }

    /// Makes `lv` point to `rv` and keeps both directions of the edge consistent.
    ///
    /// `lv` is removed from its previous target's `pointed_by` *before* being added to
    /// `rv`'s. This way, re-pointing at the same target (or one of its aliases) keeps the
    /// back edge instead of deleting it.
    ///
    /// # Panics
    /// Panics if `lv`, `rv`, or `lv`'s previous target cannot be resolved.
    pub fn point(&mut self, lv: usize, rv: usize) {
        let previous = self.get_var_node_mut(lv).unwrap().points_to.replace(rv);
        if let Some(old_target) = previous {
            self.get_var_node_mut(old_target)
                .unwrap()
                .pointed_by
                .remove(&lv);
        }
        self.get_var_node_mut(rv).unwrap().pointed_by.insert(lv);
    }

    /// Returns the `id` of the node that `local_id` resolves to.
    ///
    /// # Panics
    /// Panics if no node's alias set contains `local_id`.
    pub fn get_var_nod_id(&self, local_id: usize) -> usize {
        self.get_var_node(local_id).unwrap().id
    }

    /// Resolves `local_id` to the first node, in map iteration order, whose `alias_set`
    /// contains it.
    ///
    /// On failure it logs a warning, dumps the graph, and returns `None`.
    pub fn get_var_node(&self, local_id: usize) -> Option<&VariableNode<'tcx>> {
        let found = self
            .variables
            .values()
            .find(|node| node.alias_set.contains(&local_id));
        if found.is_none() {
            rap_warn!("def id:{:?}, local_id: {local_id}", self.def_id);
            display_hashmap(&self.variables, 1);
        }
        found
    }

    /// Mutable counterpart of [`Self::get_var_node`], using the same resolution order and
    /// the same failure logging.
    pub fn get_var_node_mut(&mut self, local_id: usize) -> Option<&mut VariableNode<'tcx>> {
        // Find the key with a shared borrow first, then borrow mutably by key. Returning
        // straight out of an `iter_mut` loop would keep `self.variables` mutably borrowed
        // on the failure path. That is why the graph used to be cloned on *every* call,
        // even successful ones.
        let key = self
            .variables
            .iter()
            .find(|(_, node)| node.alias_set.contains(&local_id))
            .map(|(&key, _)| key);
        match key {
            Some(key) => self.variables.get_mut(&key),
            None => {
                rap_warn!("def id:{:?}, local_id: {local_id}", self.def_id);
                display_hashmap(&self.variables, 1);
                None
            }
        }
    }

    /// Makes `lv` an alias of `rv`'s node. This is a no-op if `rv` already resolves to
    /// `lv`'s node.
    ///
    /// Every pointer into `lv`'s node is redirected to `rv`. Then `lv` is moved from its
    /// node's alias set into `rv`'s node's alias set.
    ///
    /// # Panics
    /// Panics if `lv` or `rv` cannot be resolved.
    pub fn merge(&mut self, lv: usize, rv: usize) {
        let lv_node = self.get_var_node(lv).unwrap();
        if lv_node.alias_set.contains(&rv) {
            return;
        }
        let pointers: Vec<usize> = lv_node.pointed_by.iter().copied().collect();
        for pointer in pointers {
            self.point(pointer, rv);
        }
        self.get_var_node_mut(lv).unwrap().alias_set.remove(&lv);
        self.get_var_node_mut(rv).unwrap().alias_set.insert(lv);
    }

    /// Copies `states` and `is_dropped` from `rv`'s node onto `lv`'s node. If `rv` points
    /// somewhere, `lv` is made to point to the same id; otherwise `lv`'s existing
    /// `points_to` is kept.
    ///
    /// # Panics
    /// Panics if `lv` or `rv` cannot be resolved.
    pub fn copy_node(&mut self, lv: usize, rv: usize) {
        let (states, is_dropped, target) = {
            let src = self.get_var_node(rv).unwrap();
            (src.states.clone(), src.is_dropped, src.points_to)
        };
        let dst = self.get_var_node_mut(lv).unwrap();
        dst.states = states;
        dst.is_dropped = is_dropped;
        if let Some(to) = target {
            self.point(lv, to);
        }
    }

    /// Removes `lv` from `rv`'s `pointed_by` and clears `lv`'s `points_to`.
    ///
    /// NOTE(review): this does not check that `lv` actually pointed to `rv`.
    pub fn break_node_connection(&mut self, lv: usize, rv: usize) {
        self.get_var_node_mut(rv).unwrap().pointed_by.remove(&lv);
        self.get_var_node_mut(lv).unwrap().points_to = None;
    }

    /// Inserts (or replaces) node `dv` with type `ty`, `pointed_by = {parent_id}`,
    /// `points_to = child_id`, and states `state`.
    pub fn insert_node(
        &mut self,
        dv: usize,
        ty: Option<Ty<'tcx>>,
        parent_id: usize,
        child_id: Option<usize>,
        state: States,
    ) {
        self.variables.insert(
            dv,
            VariableNode::new(dv, child_id, HashSet::from([parent_id]), ty, state),
        );
    }

    /// Disconnects the node that `idx` resolves to from its neighbors, then removes the
    /// map entry keyed by `idx`.
    ///
    /// Every node pointing at it has its `points_to` cleared, and `idx` is removed from
    /// its target's `pointed_by`.
    ///
    /// # Panics
    /// Panics if `idx`, one of its pointers, or its target cannot be resolved.
    pub fn delete_node(&mut self, idx: usize) {
        // NOTE(review): the edges cleared belong to the node *resolved* from `idx`, but the
        // entry removed is the one *keyed* by `idx`. These differ once `idx` has been merged
        // into another node; confirm which one callers intend to delete.
        let (pointers, target) = {
            let node = self.get_var_node(idx).unwrap();
            let pointers: Vec<usize> = node.pointed_by.iter().copied().collect();
            (pointers, node.points_to)
        };
        for pointer in pointers {
            self.get_var_node_mut(pointer).unwrap().points_to = None;
        }
        if let Some(to) = target {
            self.get_var_node_mut(to).unwrap().pointed_by.remove(&idx);
        }
        self.variables.remove(&idx);
    }

    /// Marks the node that `idx` resolves to as dropped.
    ///
    /// Returns `false` if it was already marked dropped. Otherwise returns `true`, which
    /// includes the case where no node resolves `idx` (a warning is logged).
    pub fn set_drop(&mut self, idx: usize) -> bool {
        let Some(node) = self.get_var_node_mut(idx) else {
            return true;
        };
        if node.is_dropped {
            return false;
        }
        node.is_dropped = true;
        true
    }

    /// Prints every connected component of the points-to graph (edges taken in both
    /// directions) to stdout, with each member's `points_to`.
    pub fn print_graph(&self) {
        let mut visited = HashSet::new();
        let mut subgraphs = Vec::new();
        for &node_id in self.variables.keys() {
            // `insert` returns false when an earlier component already reached this node.
            if !visited.insert(node_id) {
                continue;
            }
            let mut queue = VecDeque::from([node_id]);
            let mut subgraph = Vec::new();
            while let Some(current_id) = queue.pop_front() {
                subgraph.push(current_id);
                let Some(node) = self.get_var_node(current_id) else {
                    continue;
                };
                // Outgoing edge first, then incoming edges.
                let neighbours = node
                    .points_to
                    .into_iter()
                    .chain(node.pointed_by.iter().copied());
                for next_id in neighbours {
                    if visited.insert(next_id) {
                        queue.push_back(next_id);
                    }
                }
            }
            subgraphs.push(subgraph);
        }
        for (i, mut subgraph) in subgraphs.into_iter().enumerate() {
            subgraph.sort_unstable();
            println!("Connected Subgraph {}: {:?}", i + 1, subgraph);
            for node_id in subgraph {
                if let Some(node) = self.get_var_node(node_id) {
                    println!(" Node {} → {:?}", node_id, node.points_to);
                }
            }
            println!();
        }
    }
}