rustc_index/
interval.rs

1use std::iter::Step;
2use std::marker::PhantomData;
3use std::ops::{Bound, Range, RangeBounds};
4
5use smallvec::SmallVec;
6
7use crate::idx::Idx;
8use crate::vec::IndexVec;
9
10#[cfg(test)]
11mod tests;
12
13/// Stores a set of intervals on the indices.
14///
15/// The elements in `map` are sorted and non-adjacent, which means
16/// the second value of the previous element is *greater* than the
17/// first value of the following element.
18#[derive(Debug, Clone)]
19pub struct IntervalSet<I> {
20    // Start, end (both inclusive)
21    map: SmallVec<[(u32, u32); 2]>,
22    domain: usize,
23    _data: PhantomData<I>,
24}
25
26#[inline]
27fn inclusive_start<T: Idx>(range: impl RangeBounds<T>) -> u32 {
28    match range.start_bound() {
29        Bound::Included(start) => start.index() as u32,
30        Bound::Excluded(start) => start.index() as u32 + 1,
31        Bound::Unbounded => 0,
32    }
33}
34
35#[inline]
36fn inclusive_end<T: Idx>(domain: usize, range: impl RangeBounds<T>) -> Option<u32> {
37    let end = match range.end_bound() {
38        Bound::Included(end) => end.index() as u32,
39        Bound::Excluded(end) => end.index().checked_sub(1)? as u32,
40        Bound::Unbounded => domain.checked_sub(1)? as u32,
41    };
42    Some(end)
43}
44
45impl<I: Idx> IntervalSet<I> {
46    pub fn new(domain: usize) -> IntervalSet<I> {
47        IntervalSet { map: SmallVec::new(), domain, _data: PhantomData }
48    }
49
50    pub fn clear(&mut self) {
51        self.map.clear();
52    }
53
54    pub fn iter(&self) -> impl Iterator<Item = I>
55    where
56        I: Step,
57    {
58        self.iter_intervals().flatten()
59    }
60
61    /// Iterates through intervals stored in the set, in order.
62    pub fn iter_intervals(&self) -> impl Iterator<Item = std::ops::Range<I>>
63    where
64        I: Step,
65    {
66        self.map.iter().map(|&(start, end)| I::new(start as usize)..I::new(end as usize + 1))
67    }
68
69    /// Returns true if we increased the number of elements present.
70    pub fn insert(&mut self, point: I) -> bool {
71        self.insert_range(point..=point)
72    }
73
74    /// Returns true if we increased the number of elements present.
75    pub fn insert_range(&mut self, range: impl RangeBounds<I> + Clone) -> bool {
76        let start = inclusive_start(range.clone());
77        let Some(end) = inclusive_end(self.domain, range) else {
78            // empty range
79            return false;
80        };
81        if start > end {
82            return false;
83        }
84
85        // This condition looks a bit weird, but actually makes sense.
86        //
87        // if r.0 == end + 1, then we're actually adjacent, so we want to
88        // continue to the next range. We're looking here for the first
89        // range which starts *non-adjacently* to our end.
90        let next = self.map.partition_point(|r| r.0 <= end + 1);
91        let result = if let Some(right) = next.checked_sub(1) {
92            let (prev_start, prev_end) = self.map[right];
93            if prev_end + 1 >= start {
94                // If the start for the inserted range is adjacent to the
95                // end of the previous, we can extend the previous range.
96                if start < prev_start {
97                    // The first range which ends *non-adjacently* to our start.
98                    // And we can ensure that left <= right.
99                    let left = self.map.partition_point(|l| l.1 + 1 < start);
100                    let min = std::cmp::min(self.map[left].0, start);
101                    let max = std::cmp::max(prev_end, end);
102                    self.map[right] = (min, max);
103                    if left != right {
104                        self.map.drain(left..right);
105                    }
106                    true
107                } else {
108                    // We overlap with the previous range, increase it to
109                    // include us.
110                    //
111                    // Make sure we're actually going to *increase* it though --
112                    // it may be that end is just inside the previously existing
113                    // set.
114                    if end > prev_end {
115                        self.map[right].1 = end;
116                        true
117                    } else {
118                        false
119                    }
120                }
121            } else {
122                // Otherwise, we don't overlap, so just insert
123                self.map.insert(right + 1, (start, end));
124                true
125            }
126        } else {
127            if self.map.is_empty() {
128                // Quite common in practice, and expensive to call memcpy
129                // with length zero.
130                self.map.push((start, end));
131            } else {
132                self.map.insert(next, (start, end));
133            }
134            true
135        };
136        debug_assert!(
137            self.check_invariants(),
138            "wrong intervals after insert {start:?}..={end:?} to {self:?}"
139        );
140        result
141    }
142
143    /// Specialized version of `insert` when we know that the inserted point is *after* any
144    /// contained.
145    pub fn append(&mut self, point: I) {
146        let point = point.index() as u32;
147
148        if let Some((_, last_end)) = self.map.last_mut() {
149            assert!(*last_end <= point);
150            if point == *last_end {
151                // The point is already in the set.
152            } else if point == *last_end + 1 {
153                *last_end = point;
154            } else {
155                self.map.push((point, point));
156            }
157        } else {
158            self.map.push((point, point));
159        }
160
161        debug_assert!(
162            self.check_invariants(),
163            "wrong intervals after append {point:?} to {self:?}"
164        );
165    }
166
167    pub fn contains(&self, needle: I) -> bool {
168        let needle = needle.index() as u32;
169        let Some(last) = self.map.partition_point(|r| r.0 <= needle).checked_sub(1) else {
170            // All ranges in the map start after the new range's end
171            return false;
172        };
173        let (_, prev_end) = &self.map[last];
174        needle <= *prev_end
175    }
176
177    pub fn superset(&self, other: &IntervalSet<I>) -> bool
178    where
179        I: Step,
180    {
181        let mut sup_iter = self.iter_intervals();
182        let mut current = None;
183        let contains = |sup: Range<I>, sub: Range<I>, current: &mut Option<Range<I>>| {
184            if sup.end < sub.start {
185                // if `sup.end == sub.start`, the next sup doesn't contain `sub.start`
186                None // continue to the next sup
187            } else if sup.end >= sub.end && sup.start <= sub.start {
188                *current = Some(sup); // save the current sup
189                Some(true)
190            } else {
191                Some(false)
192            }
193        };
194        other.iter_intervals().all(|sub| {
195            current
196                .take()
197                .and_then(|sup| contains(sup, sub.clone(), &mut current))
198                .or_else(|| sup_iter.find_map(|sup| contains(sup, sub.clone(), &mut current)))
199                .unwrap_or(false)
200        })
201    }
202
203    pub fn disjoint(&self, other: &IntervalSet<I>) -> bool
204    where
205        I: Step,
206    {
207        let helper = move || {
208            let mut self_iter = self.iter_intervals();
209            let mut other_iter = other.iter_intervals();
210
211            let mut self_current = self_iter.next()?;
212            let mut other_current = other_iter.next()?;
213
214            loop {
215                if self_current.end <= other_current.start {
216                    self_current = self_iter.next()?;
217                    continue;
218                }
219                if other_current.end <= self_current.start {
220                    other_current = other_iter.next()?;
221                    continue;
222                }
223                return Some(false);
224            }
225        };
226        helper().unwrap_or(true)
227    }
228
229    pub fn is_empty(&self) -> bool {
230        self.map.is_empty()
231    }
232
233    /// Equivalent to `range.iter().find(|i| !self.contains(i))`.
234    pub fn first_unset_in(&self, range: impl RangeBounds<I> + Clone) -> Option<I> {
235        let start = inclusive_start(range.clone());
236        let Some(end) = inclusive_end(self.domain, range) else {
237            // empty range
238            return None;
239        };
240        if start > end {
241            return None;
242        }
243        let Some(last) = self.map.partition_point(|r| r.0 <= start).checked_sub(1) else {
244            // All ranges in the map start after the new range's end
245            return Some(I::new(start as usize));
246        };
247        let (_, prev_end) = self.map[last];
248        if start > prev_end {
249            Some(I::new(start as usize))
250        } else if prev_end < end {
251            Some(I::new(prev_end as usize + 1))
252        } else {
253            None
254        }
255    }
256
257    /// Returns the maximum (last) element present in the set from `range`.
258    pub fn last_set_in(&self, range: impl RangeBounds<I> + Clone) -> Option<I> {
259        let start = inclusive_start(range.clone());
260        let Some(end) = inclusive_end(self.domain, range) else {
261            // empty range
262            return None;
263        };
264        if start > end {
265            return None;
266        }
267        let Some(last) = self.map.partition_point(|r| r.0 <= end).checked_sub(1) else {
268            // All ranges in the map start after the new range's end
269            return None;
270        };
271        let (_, prev_end) = &self.map[last];
272        if start <= *prev_end { Some(I::new(std::cmp::min(*prev_end, end) as usize)) } else { None }
273    }
274
275    pub fn insert_all(&mut self) {
276        self.clear();
277        if let Some(end) = self.domain.checked_sub(1) {
278            self.map.push((0, end.try_into().unwrap()));
279        }
280        debug_assert!(self.check_invariants());
281    }
282
283    pub fn union(&mut self, other: &IntervalSet<I>) -> bool
284    where
285        I: Step,
286    {
287        assert_eq!(self.domain, other.domain);
288        if self.map.len() < other.map.len() {
289            let backup = self.clone();
290            self.map.clone_from(&other.map);
291            return self.union(&backup);
292        }
293
294        let mut did_insert = false;
295        for range in other.iter_intervals() {
296            did_insert |= self.insert_range(range);
297        }
298        debug_assert!(self.check_invariants());
299        did_insert
300    }
301
302    // Check the intervals are valid, sorted and non-adjacent
303    fn check_invariants(&self) -> bool {
304        let mut current: Option<u32> = None;
305        for (start, end) in &self.map {
306            if start > end || current.is_some_and(|x| x + 1 >= *start) {
307                return false;
308            }
309            current = Some(*end);
310        }
311        current.is_none_or(|x| x < self.domain as u32)
312    }
313}
314
315/// This data structure optimizes for cases where the stored bits in each row
316/// are expected to be highly contiguous (long ranges of 1s or 0s), in contrast
317/// to BitMatrix and SparseBitMatrix which are optimized for
318/// "random"/non-contiguous bits and cheap(er) point queries at the expense of
319/// memory usage.
320#[derive(Clone)]
321pub struct SparseIntervalMatrix<R, C>
322where
323    R: Idx,
324    C: Idx,
325{
326    rows: IndexVec<R, IntervalSet<C>>,
327    column_size: usize,
328}
329
330impl<R: Idx, C: Step + Idx> SparseIntervalMatrix<R, C> {
331    pub fn new(column_size: usize) -> SparseIntervalMatrix<R, C> {
332        SparseIntervalMatrix { rows: IndexVec::new(), column_size }
333    }
334
335    pub fn rows(&self) -> impl Iterator<Item = R> {
336        self.rows.indices()
337    }
338
339    pub fn row(&self, row: R) -> Option<&IntervalSet<C>> {
340        self.rows.get(row)
341    }
342
343    fn ensure_row(&mut self, row: R) -> &mut IntervalSet<C> {
344        self.rows.ensure_contains_elem(row, || IntervalSet::new(self.column_size))
345    }
346
347    pub fn union_row(&mut self, row: R, from: &IntervalSet<C>) -> bool
348    where
349        C: Step,
350    {
351        self.ensure_row(row).union(from)
352    }
353
354    pub fn union_rows(&mut self, read: R, write: R) -> bool
355    where
356        C: Step,
357    {
358        if read == write || self.rows.get(read).is_none() {
359            return false;
360        }
361        self.ensure_row(write);
362        let (read_row, write_row) = self.rows.pick2_mut(read, write);
363        write_row.union(read_row)
364    }
365
366    pub fn insert_all_into_row(&mut self, row: R) {
367        self.ensure_row(row).insert_all();
368    }
369
370    pub fn insert_range(&mut self, row: R, range: impl RangeBounds<C> + Clone) {
371        self.ensure_row(row).insert_range(range);
372    }
373
374    pub fn insert(&mut self, row: R, point: C) -> bool {
375        self.ensure_row(row).insert(point)
376    }
377
378    pub fn append(&mut self, row: R, point: C) {
379        self.ensure_row(row).append(point)
380    }
381
382    pub fn contains(&self, row: R, point: C) -> bool {
383        self.row(row).is_some_and(|r| r.contains(point))
384    }
385}