1use std::cell::Cell;
2use std::collections::HashSet;
3
4use rustc_hir::def_id::DefId;
5use rustc_index::IndexVec;
6use rustc_middle::{
7 mir::{
8 AggregateKind, BorrowKind, Const, Local, Operand, Place, PlaceElem, Rvalue, Statement,
9 StatementKind, Terminator, TerminatorKind,
10 },
11 ty::TyKind,
12};
13use rustc_span::{DUMMY_SP, Span};
14
15use crate::{analysis::core::dataflow::*, utils::log::relative_pos_range};
16
17impl GraphNode {
18 pub fn new() -> Self {
19 Self {
20 ops: vec![NodeOp::Nop],
21 span: DUMMY_SP,
22 seq: 0,
23 out_edges: vec![],
24 in_edges: vec![],
25 }
26 }
27}
28
29#[derive(Clone)]
30pub struct Graph {
31 pub def_id: DefId,
32 pub span: Span,
33 pub argc: usize,
34 pub nodes: GraphNodes, pub edges: GraphEdges,
36 pub n_locals: usize,
37 pub closures: HashSet<DefId>,
38}
39
40impl From<Graph> for DataFlowGraph {
41 fn from(graph: Graph) -> Self {
42 let param_ret_deps = graph.param_return_deps();
43 DataFlowGraph {
44 nodes: graph.nodes,
45 edges: graph.edges,
46 param_ret_deps: param_ret_deps,
47 }
48 }
49}
50
51impl Graph {
52 pub fn new(def_id: DefId, span: Span, argc: usize, n_locals: usize) -> Self {
53 Self {
54 def_id,
55 span,
56 argc,
57 nodes: GraphNodes::from_elem_n(GraphNode::new(), n_locals),
58 edges: GraphEdges::new(),
59 n_locals,
60 closures: HashSet::new(),
61 }
62 }
63
64 pub fn add_node_edge(&mut self, src: Local, dst: Local, op: EdgeOp) -> EdgeIdx {
66 let seq = self.nodes[dst].seq;
67 let edge_idx = self.edges.push(GraphEdge { src, dst, op, seq });
68 self.nodes[dst].in_edges.push(edge_idx);
69 self.nodes[src].out_edges.push(edge_idx);
70 edge_idx
71 }
72
73 pub fn add_const_edge(
75 &mut self,
76 src_desc: String,
77 src_ty: String,
78 dst: Local,
79 op: EdgeOp,
80 ) -> EdgeIdx {
81 let seq = self.nodes[dst].seq;
82 let mut const_node = GraphNode::new();
83 const_node.ops[0] = NodeOp::Const(src_desc, src_ty);
84 let src = self.nodes.push(const_node);
85 let edge_idx = self.edges.push(GraphEdge { src, dst, op, seq });
86 self.nodes[dst].in_edges.push(edge_idx);
87 edge_idx
88 }
89
90 pub fn add_operand(&mut self, operand: &Operand, dst: Local) {
91 match operand {
92 Operand::Copy(place) => {
93 let src = self.parse_place(place);
94 self.add_node_edge(src, dst, EdgeOp::Copy);
95 }
96 Operand::Move(place) => {
97 let src = self.parse_place(place);
98 self.add_node_edge(src, dst, EdgeOp::Move);
99 }
100 Operand::Constant(boxed_const_op) => {
101 let src_desc = boxed_const_op.const_.to_string();
102 let src_ty = match boxed_const_op.const_ {
103 Const::Val(_, ty) => ty.to_string(),
104 Const::Unevaluated(_, ty) => ty.to_string(),
105 Const::Ty(ty, _) => ty.to_string(),
106 };
107 self.add_const_edge(src_desc, src_ty, dst, EdgeOp::Const);
108 }
109 }
110 }
111
112 pub fn parse_place(&mut self, place: &Place) -> Local {
113 fn parse_one_step(graph: &mut Graph, src: Local, place_elem: PlaceElem) -> Local {
114 let dst = graph.nodes.push(GraphNode::new());
115 match place_elem {
116 PlaceElem::Deref => {
117 graph.add_node_edge(src, dst, EdgeOp::Deref);
118 }
119 PlaceElem::Field(field_idx, _) => {
120 graph.add_node_edge(src, dst, EdgeOp::Field(field_idx.as_usize()));
121 }
122 PlaceElem::Downcast(symbol, _) => {
123 graph.add_node_edge(src, dst, EdgeOp::Downcast(symbol.unwrap().to_string()));
124 }
125 PlaceElem::Index(idx) => {
126 graph.add_node_edge(src, dst, EdgeOp::Index);
127 graph.add_node_edge(idx, dst, EdgeOp::Nop);
128 }
129 PlaceElem::ConstantIndex { .. } => {
130 graph.add_node_edge(src, dst, EdgeOp::ConstIndex);
131 }
132 PlaceElem::Subslice { .. } => {
133 graph.add_node_edge(src, dst, EdgeOp::SubSlice);
134 }
135 _ => {
136 println!("{:?}", place_elem);
137 todo!()
138 }
139 }
140 dst
141 }
142 let mut ret = place.local;
143 for place_elem in place.projection {
144 ret = parse_one_step(self, ret, place_elem);
146 }
147 ret
148 }
149
150 pub fn add_statm_to_graph(&mut self, statement: &Statement) {
151 if let StatementKind::Assign(boxed_statm) = &statement.kind {
152 let place = boxed_statm.0;
153 let dst = self.parse_place(&place);
154 self.nodes[dst].span = statement.source_info.span;
155 let rvalue = &boxed_statm.1;
156 let seq = self.nodes[dst].seq;
157 if seq == self.nodes[dst].ops.len() {
158 self.nodes[dst].ops.push(NodeOp::Nop);
160 }
161 match rvalue {
162 Rvalue::Use(op) => {
163 self.add_operand(op, dst);
164 self.nodes[dst].ops[seq] = NodeOp::Use;
165 }
166 Rvalue::Repeat(op, _) => {
167 self.add_operand(op, dst);
168 self.nodes[dst].ops[seq] = NodeOp::Repeat;
169 }
170 Rvalue::Ref(_, borrow_kind, place) => {
171 let op = match borrow_kind {
172 BorrowKind::Shared => EdgeOp::Immut,
173 BorrowKind::Mut { .. } => EdgeOp::Mut,
174 BorrowKind::Fake(_) => EdgeOp::Nop, };
176 let src = self.parse_place(place);
177 self.add_node_edge(src, dst, op);
178 self.nodes[dst].ops[seq] = NodeOp::Ref;
179 }
180 Rvalue::Cast(_cast_kind, operand, _) => {
181 self.add_operand(operand, dst);
182 self.nodes[dst].ops[seq] = NodeOp::Cast;
183 }
184 Rvalue::BinaryOp(_, operands) => {
185 self.add_operand(&operands.0, dst);
186 self.add_operand(&operands.1, dst);
187 self.nodes[dst].ops[seq] = NodeOp::CheckedBinaryOp;
188 }
189 Rvalue::Aggregate(boxed_kind, operands) => {
190 for operand in operands.iter() {
191 self.add_operand(operand, dst);
192 }
193 match **boxed_kind {
194 AggregateKind::Array(_) => {
195 self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Array)
196 }
197 AggregateKind::Tuple => {
198 self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Tuple)
199 }
200 AggregateKind::Adt(def_id, ..) => {
201 self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Adt(def_id))
202 }
203 AggregateKind::Closure(def_id, ..) => {
204 self.closures.insert(def_id);
205 self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Closure(def_id))
206 }
207 AggregateKind::Coroutine(def_id, ..) => {
208 self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Coroutine(def_id))
209 }
210 AggregateKind::RawPtr(_, _mutability) => {
211 self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::RawPtr)
212 }
214 _ => {
215 println!("{:?}", boxed_kind);
216 todo!()
217 }
218 }
219 }
220 Rvalue::UnaryOp(_, operand) => {
221 self.add_operand(operand, dst);
222 self.nodes[dst].ops[seq] = NodeOp::UnaryOp;
223 }
224 Rvalue::NullaryOp(_) => {
225 self.nodes[dst].ops[seq] = NodeOp::NullaryOp;
226 }
227 Rvalue::ThreadLocalRef(_) => {
228 }
230 Rvalue::Discriminant(place) => {
231 let src = self.parse_place(place);
232 self.add_node_edge(src, dst, EdgeOp::Nop);
233 self.nodes[dst].ops[seq] = NodeOp::Discriminant;
234 }
235 Rvalue::ShallowInitBox(operand, _) => {
236 self.add_operand(operand, dst);
237 self.nodes[dst].ops[seq] = NodeOp::ShallowInitBox;
238 }
239 Rvalue::CopyForDeref(place) => {
240 let src = self.parse_place(place);
241 self.add_node_edge(src, dst, EdgeOp::Nop);
242 self.nodes[dst].ops[seq] = NodeOp::CopyForDeref;
243 }
244 Rvalue::RawPtr(_, place) => {
245 let src = self.parse_place(place);
246 self.add_node_edge(src, dst, EdgeOp::Nop); self.nodes[dst].ops[seq] = NodeOp::RawPtr;
248 }
249 _ => todo!(),
250 };
251 self.nodes[dst].seq = seq + 1;
252 }
253 }
254
255 pub fn add_terminator_to_graph(&mut self, terminator: &Terminator) {
256 if let TerminatorKind::Call {
257 func,
258 args,
259 destination,
260 ..
261 } = &terminator.kind
262 {
263 let dst = destination.local;
264 let seq = self.nodes[dst].seq;
265 if seq == self.nodes[dst].ops.len() {
266 self.nodes[dst].ops.push(NodeOp::Nop);
267 }
268 match func {
269 Operand::Constant(boxed_cnst) => {
270 if let Const::Val(_, ty) = boxed_cnst.const_ {
271 if let TyKind::FnDef(def_id, _) = ty.kind() {
272 for op in args.iter() {
273 self.add_operand(&op.node, dst);
275 }
276 self.nodes[dst].ops[seq] = NodeOp::Call(*def_id);
277 }
278 }
279 }
280 Operand::Move(_) => {
281 self.add_operand(func, dst); for op in args.iter() {
283 self.add_operand(&op.node, dst);
285 }
286 self.nodes[dst].ops[seq] = NodeOp::CallOperand;
287 }
288 _ => {
289 println!("{:?}", func);
290 todo!();
291 }
292 }
293 self.nodes[dst].span = terminator.source_info.span;
294 self.nodes[dst].seq = seq + 1;
295 }
296 }
297
    /// Collects the nodes considered equivalent to `local`.
    ///
    /// Runs two exhaustive DFS passes (`traverse_all = true`), both following only
    /// the edges accepted by `equivalent_edge_validator` (copy, move, `Mut`/`Immut`
    /// borrow and deref edges):
    ///
    /// 1. **Upward from `local`** to find a root. Every visited node having a
    ///    `Nop`/`Use`/`Ref` or `Call` operation overwrites `root`, so the final
    ///    root is the last such node visited. `Nop`/`Use`/`Ref` operations allow
    ///    the walk to continue; `Call` and any other operation stop it.
    /// 2. **Downward from that root**, inserting every visited node that has a
    ///    `Nop`/`Use`/`Ref` operation, or a `Call` operation when the node is the
    ///    root itself. Traversal continues past the same operations.
    ///
    /// A node can hold several operations (one per assignment). `strict` selects
    /// how their per-operation verdicts combine: `true` uses `DFSStatus::and`
    /// (continue only if every operation allows it), `false` uses `DFSStatus::or`
    /// (continue if any operation allows it). Nodes are inserted whenever one of
    /// their operations qualifies, even if the combined verdict is `Stop`.
    ///
    /// NOTE(review): `local` itself is only in the result if it qualifies under
    /// the rules above; callers should not assume it is always present.
    pub fn collect_equivalent_locals(&self, local: Local, strict: bool) -> HashSet<Local> {
        let mut set = HashSet::new();
        // Interior mutability lets both closures read/update the root while
        // `dfs` holds them as `&mut F`.
        let root = Cell::new(local);
        let reduce_func = if strict {
            DFSStatus::and
        } else {
            DFSStatus::or
        };
        // Pass 1 node operator: track the upward-most candidate root.
        let mut find_root_operator = |graph: &Graph, idx: Local| -> DFSStatus {
            let node = &graph.nodes[idx];
            // `reduce` consumes the whole iterator, so the side effects in `map`
            // run for every operation of the node.
            node.ops
                .iter()
                .map(|op| {
                    match op {
                        NodeOp::Nop | NodeOp::Use | NodeOp::Ref => {
                            root.set(idx);
                            DFSStatus::Continue
                        }
                        NodeOp::Call(_) => {
                            // A call result can be a root, but we do not look past it.
                            root.set(idx);
                            DFSStatus::Stop
                        }
                        _ => DFSStatus::Stop,
                    }
                })
                .reduce(reduce_func)
                // `ops` is never empty: `GraphNode::new` starts it with one `Nop`.
                .unwrap()
        };
        // Pass 2 node operator: gather equivalent nodes below the root.
        let mut find_equivalent_operator = |graph: &Graph, idx: Local| -> DFSStatus {
            let node = &graph.nodes[idx];
            if set.contains(&idx) {
                return DFSStatus::Stop;
            }
            node.ops
                .iter()
                .map(|op| match op {
                    NodeOp::Nop | NodeOp::Use | NodeOp::Ref => {
                        set.insert(idx);
                        DFSStatus::Continue
                    }
                    NodeOp::Call(_) => {
                        // Only the root's own call is accepted; other calls end the walk.
                        if idx == root.get() {
                            set.insert(idx);
                            DFSStatus::Continue
                        } else {
                            DFSStatus::Stop
                        }
                    }
                    _ => DFSStatus::Stop,
                })
                .reduce(reduce_func)
                .unwrap()
        };
        let mut seen = HashSet::new();
        self.dfs(
            local,
            Direction::Upside,
            &mut find_root_operator,
            &mut Self::equivalent_edge_validator,
            true,
            &mut seen,
        );
        // The second pass must be able to revisit nodes seen in the first.
        seen.clear();
        self.dfs(
            root.get(),
            Direction::Downside,
            &mut find_equivalent_operator,
            &mut Self::equivalent_edge_validator,
            true,
            &mut seen,
        );
        set
    }
379
380 pub fn collect_ancestor_locals(&self, local: Local, self_included: bool) -> HashSet<Local> {
381 let mut ret = HashSet::new();
382 let mut node_operator = |_: &Graph, idx: Local| -> DFSStatus {
383 ret.insert(idx);
384 DFSStatus::Continue
385 };
386 let mut seen = HashSet::new();
387 self.dfs(
388 local,
389 Direction::Upside,
390 &mut node_operator,
391 &mut Graph::always_true_edge_validator,
392 true,
393 &mut seen,
394 );
395 if !self_included {
396 ret.remove(&local);
397 }
398 ret
399 }
400
401 pub fn collect_descending_locals(&self, local: Local, self_included: bool) -> HashSet<Local> {
402 let mut ret = HashSet::new();
403 let mut node_operator = |_: &Graph, idx: Local| -> DFSStatus {
404 ret.insert(idx);
405 DFSStatus::Continue
406 };
407 let mut seen = HashSet::new();
408 self.dfs(
409 local,
410 Direction::Downside,
411 &mut node_operator,
412 &mut Graph::always_true_edge_validator,
413 true,
414 &mut seen,
415 );
416 if !self_included {
417 ret.remove(&local);
418 }
419 ret
420 }
421
422 pub fn get_field_sequence(&self, local: Local) -> Option<(Local, Vec<usize>)> {
423 let mut fields = vec![];
424 let var = Cell::new(local);
425 let mut node_operator = |graph: &Graph, idx: Local| -> DFSStatus {
426 if graph.is_marker(idx) {
427 DFSStatus::Continue
428 } else {
429 var.set(idx);
430 DFSStatus::Stop
431 }
432 };
433 let mut edge_validator = |graph: &Graph, idx: EdgeIdx| -> DFSStatus {
434 if let EdgeOp::Field(field) = graph.edges[idx].op {
435 fields.insert(0, field);
436 DFSStatus::Continue
437 } else {
438 DFSStatus::Stop
439 }
440 };
441 let mut seen = HashSet::new();
442 self.dfs(
443 local,
444 Direction::Upside,
445 &mut node_operator,
446 &mut edge_validator,
447 false,
448 &mut seen,
449 );
450 if fields.is_empty() {
451 None
452 } else {
453 Some((var.get(), fields))
454 }
455 }
456
457 pub fn is_connected(&self, idx_1: Local, idx_2: Local) -> bool {
458 let target = idx_2;
459 let find = Cell::new(false);
460 let mut node_operator = |_: &Graph, idx: Local| -> DFSStatus {
461 find.set(idx == target);
462 if find.get() {
463 DFSStatus::Stop
464 } else {
465 DFSStatus::Continue
467 }
468 };
469 let mut seen = HashSet::new();
470 self.dfs(
471 idx_1,
472 Direction::Downside,
473 &mut node_operator,
474 &mut Self::always_true_edge_validator,
475 false,
476 &mut seen,
477 );
478 seen.clear();
479 if !find.get() {
480 self.dfs(
481 idx_1,
482 Direction::Upside,
483 &mut node_operator,
484 &mut Self::always_true_edge_validator,
485 false,
486 &mut seen,
487 );
488 }
489 find.get()
490 }
491
492 pub fn param_return_deps(&self) -> IndexVec<Local, bool> {
494 let _0 = Local::from_usize(0);
495 let deps = (0..self.argc + 1) .map(|i| {
497 let _i = Local::from_usize(i);
498 self.is_connected(_i, _0)
499 })
500 .collect();
501 deps
502 }
503
    /// Recursive depth-first traversal starting at `now`.
    ///
    /// * `seen` — nodes already visited. Reaching one again returns
    ///   `(Stop, false)` without calling `node_operator`.
    /// * `node_operator` — called once per newly reached node. If it returns
    ///   `Stop`, the node's edges are not followed and `(Stop, true)` is returned.
    /// * `edge_validator` — called for each edge of a node whose operator
    ///   returned `Continue`; only edges for which it returns `Continue` are
    ///   followed. It is invoked even if the edge leads to an already-seen node.
    /// * `direction` — `Upside` follows `in_edges` to their `src`, `Downside`
    ///   follows `out_edges` to their `dst`, `Both` does in-edges then out-edges.
    /// * `traverse_all` — when `false`, the first child that reports
    ///   `(Stop, true)` (a node operator stopped somewhere below) aborts the
    ///   whole traversal, and `(Stop, true)` propagates to the top. When `true`,
    ///   every reachable node is explored regardless.
    ///
    /// Returns `(Continue, false)` when this node's edges were explored to completion.
    ///
    /// NOTE(review): recursion depth grows with path length; very deep graphs
    /// may exhaust the stack.
    pub fn dfs<F, G>(
        &self,
        now: Local,
        direction: Direction,
        node_operator: &mut F,
        edge_validator: &mut G,
        traverse_all: bool,
        seen: &mut HashSet<Local>,
    ) -> (DFSStatus, bool)
    where
        F: FnMut(&Graph, Local) -> DFSStatus,
        G: FnMut(&Graph, EdgeIdx) -> DFSStatus,
    {
        if seen.contains(&now) {
            return (DFSStatus::Stop, false);
        }
        seen.insert(now);
        // Expands to a loop over `self.nodes[now].$edges`, recursing into the
        // `$field` endpoint of every edge accepted by `edge_validator`. The
        // `return` inside exits `dfs` itself, not just the loop.
        macro_rules! traverse {
            ($edges: ident, $field: ident) => {
                for edge_idx in self.nodes[now].$edges.iter() {
                    let edge = &self.edges[*edge_idx];
                    if matches!(edge_validator(self, *edge_idx), DFSStatus::Continue) {
                        let (dfs_status, result) = self.dfs(
                            edge.$field,
                            direction,
                            node_operator,
                            edge_validator,
                            traverse_all,
                            seen,
                        );

                        // Early exit: a node operator below asked to stop and we
                        // are not required to visit everything.
                        if matches!(dfs_status, DFSStatus::Stop) && result && !traverse_all {
                            return (DFSStatus::Stop, true);
                        }
                    }
                }
            };
        }

        if matches!(node_operator(self, now), DFSStatus::Continue) {
            match direction {
                Direction::Upside => {
                    traverse!(in_edges, src);
                }
                Direction::Downside => {
                    traverse!(out_edges, dst);
                }
                Direction::Both => {
                    traverse!(in_edges, src);
                    traverse!(out_edges, dst);
                }
            };
            (DFSStatus::Continue, false)
        } else {
            (DFSStatus::Stop, true)
        }
    }
566
567 pub fn get_upside_idx(&self, node_idx: Local, order: usize) -> Option<Local> {
568 if let Some(edge_idx) = self.nodes[node_idx].in_edges.get(order) {
569 Some(self.edges[*edge_idx].src)
570 } else {
571 None
572 }
573 }
574
575 pub fn get_downside_idx(&self, node_idx: Local, order: usize) -> Option<Local> {
576 if let Some(edge_idx) = self.nodes[node_idx].out_edges.get(order) {
577 Some(self.edges[*edge_idx].dst)
578 } else {
579 None
580 }
581 }
582
583 pub fn query_node_by_span(&self, span: Span, strict: bool) -> Option<(Local, &GraphNode)> {
585 for (node_idx, node) in self.nodes.iter_enumerated() {
586 if strict {
587 if node.span == span {
588 return Some((node_idx, node));
589 }
590 } else {
591 if !relative_pos_range(node.span, span).eq(0..0)
592 && (node.span.lo() == span.lo() || node.span.hi() == span.hi())
593 {
594 return Some((node_idx, node));
595 }
596 }
597 }
598 None
599 }
600
601 pub fn is_marker(&self, idx: Local) -> bool {
602 idx >= Local::from_usize(self.n_locals)
603 }
604}
605
606impl Graph {
607 pub fn equivalent_edge_validator(graph: &Graph, idx: EdgeIdx) -> DFSStatus {
608 match graph.edges[idx].op {
609 EdgeOp::Copy | EdgeOp::Move | EdgeOp::Mut | EdgeOp::Immut | EdgeOp::Deref => {
610 DFSStatus::Continue
611 }
612 EdgeOp::Nop
613 | EdgeOp::Const
614 | EdgeOp::Downcast(_)
615 | EdgeOp::Field(_)
616 | EdgeOp::Index
617 | EdgeOp::ConstIndex
618 | EdgeOp::SubSlice => DFSStatus::Stop,
619 }
620 }
621
622 pub fn always_true_edge_validator(_: &Graph, _: EdgeIdx) -> DFSStatus {
623 DFSStatus::Continue
624 }
625}
626
/// Which edges `Graph::dfs` follows out of a node.
///
/// Also derives `Debug` and `PartialEq`/`Eq` so the public type can be
/// printed and compared directly.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Direction {
    /// Follow incoming edges, towards the nodes that flow into this one.
    Upside,
    /// Follow outgoing edges, towards the nodes this one flows into.
    Downside,
    /// Follow incoming edges first, then outgoing edges.
    Both,
}
633
/// Verdict returned by node operators and edge validators during `Graph::dfs`.
///
/// Derives `Clone`/`Copy`/`Debug`/`PartialEq`/`Eq` so the public type can be
/// copied, printed and compared directly instead of only via `matches!`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DFSStatus {
    /// Keep traversing past this node / along this edge.
    Continue,
    /// Do not traverse past this node / along this edge.
    Stop,
}

impl DFSStatus {
    /// Conjunction: `Continue` only if both inputs are `Continue`.
    pub fn and(s1: DFSStatus, s2: DFSStatus) -> DFSStatus {
        match (s1, s2) {
            (DFSStatus::Continue, DFSStatus::Continue) => DFSStatus::Continue,
            _ => DFSStatus::Stop,
        }
    }

    /// Disjunction: `Continue` if at least one input is `Continue`.
    pub fn or(s1: DFSStatus, s2: DFSStatus) -> DFSStatus {
        match (s1, s2) {
            (DFSStatus::Stop, DFSStatus::Stop) => DFSStatus::Stop,
            _ => DFSStatus::Continue,
        }
    }
}