1use std::iter::Step;
2use std::marker::PhantomData;
3use std::ops::{Bound, Range, RangeBounds};
4
5use smallvec::SmallVec;
6
7use crate::idx::Idx;
8use crate::vec::IndexVec;
9
10#[cfg(test)]
11mod tests;
12
/// A set of indices of type `I`, stored as a sorted list of inclusive
/// `(start, end)` intervals.
///
/// The intervals in `map` are kept in a canonical form:
/// - each interval is non-empty (`start <= end`);
/// - intervals are sorted in ascending order;
/// - consecutive intervals are non-adjacent: each starts at least two past
///   the previous interval's end, so there is always a missing index between them;
/// - every element is below `domain`.
///
/// Because of this, a given set has exactly one representation. The
/// invariants are verified by `check_invariants` in debug assertions.
#[derive(Debug, Clone)]
pub struct IntervalSet<I> {
    /// Inclusive `(start, end)` pairs in ascending order.
    map: SmallVec<[(u32, u32); 2]>,
    /// Size of the universe of indices; every element must be `< domain`.
    domain: usize,
    _data: PhantomData<I>,
}
25
26#[inline]
27fn inclusive_start<T: Idx>(range: impl RangeBounds<T>) -> u32 {
28 match range.start_bound() {
29 Bound::Included(start) => start.index() as u32,
30 Bound::Excluded(start) => start.index() as u32 + 1,
31 Bound::Unbounded => 0,
32 }
33}
34
35#[inline]
36fn inclusive_end<T: Idx>(domain: usize, range: impl RangeBounds<T>) -> Option<u32> {
37 let end = match range.end_bound() {
38 Bound::Included(end) => end.index() as u32,
39 Bound::Excluded(end) => end.index().checked_sub(1)? as u32,
40 Bound::Unbounded => domain.checked_sub(1)? as u32,
41 };
42 Some(end)
43}
44
45impl<I: Idx> IntervalSet<I> {
46 pub fn new(domain: usize) -> IntervalSet<I> {
47 IntervalSet { map: SmallVec::new(), domain, _data: PhantomData }
48 }
49
50 pub fn clear(&mut self) {
51 self.map.clear();
52 }
53
54 pub fn iter(&self) -> impl Iterator<Item = I>
55 where
56 I: Step,
57 {
58 self.iter_intervals().flatten()
59 }
60
61 pub fn iter_intervals(&self) -> impl Iterator<Item = std::ops::Range<I>>
63 where
64 I: Step,
65 {
66 self.map.iter().map(|&(start, end)| I::new(start as usize)..I::new(end as usize + 1))
67 }
68
69 pub fn insert(&mut self, point: I) -> bool {
71 self.insert_range(point..=point)
72 }
73
74 pub fn insert_range(&mut self, range: impl RangeBounds<I> + Clone) -> bool {
76 let start = inclusive_start(range.clone());
77 let Some(end) = inclusive_end(self.domain, range) else {
78 return false;
80 };
81 if start > end {
82 return false;
83 }
84
85 let next = self.map.partition_point(|r| r.0 <= end + 1);
91 let result = if let Some(right) = next.checked_sub(1) {
92 let (prev_start, prev_end) = self.map[right];
93 if prev_end + 1 >= start {
94 if start < prev_start {
97 let left = self.map.partition_point(|l| l.1 + 1 < start);
100 let min = std::cmp::min(self.map[left].0, start);
101 let max = std::cmp::max(prev_end, end);
102 self.map[right] = (min, max);
103 if left != right {
104 self.map.drain(left..right);
105 }
106 true
107 } else {
108 if end > prev_end {
115 self.map[right].1 = end;
116 true
117 } else {
118 false
119 }
120 }
121 } else {
122 self.map.insert(right + 1, (start, end));
124 true
125 }
126 } else {
127 if self.map.is_empty() {
128 self.map.push((start, end));
131 } else {
132 self.map.insert(next, (start, end));
133 }
134 true
135 };
136 debug_assert!(
137 self.check_invariants(),
138 "wrong intervals after insert {start:?}..={end:?} to {self:?}"
139 );
140 result
141 }
142
143 pub fn append(&mut self, point: I) {
146 let point = point.index() as u32;
147
148 if let Some((_, last_end)) = self.map.last_mut() {
149 assert!(*last_end <= point);
150 if point == *last_end {
151 } else if point == *last_end + 1 {
153 *last_end = point;
154 } else {
155 self.map.push((point, point));
156 }
157 } else {
158 self.map.push((point, point));
159 }
160
161 debug_assert!(
162 self.check_invariants(),
163 "wrong intervals after append {point:?} to {self:?}"
164 );
165 }
166
167 pub fn contains(&self, needle: I) -> bool {
168 let needle = needle.index() as u32;
169 let Some(last) = self.map.partition_point(|r| r.0 <= needle).checked_sub(1) else {
170 return false;
172 };
173 let (_, prev_end) = &self.map[last];
174 needle <= *prev_end
175 }
176
177 pub fn superset(&self, other: &IntervalSet<I>) -> bool
178 where
179 I: Step,
180 {
181 let mut sup_iter = self.iter_intervals();
182 let mut current = None;
183 let contains = |sup: Range<I>, sub: Range<I>, current: &mut Option<Range<I>>| {
184 if sup.end < sub.start {
185 None } else if sup.end >= sub.end && sup.start <= sub.start {
188 *current = Some(sup); Some(true)
190 } else {
191 Some(false)
192 }
193 };
194 other.iter_intervals().all(|sub| {
195 current
196 .take()
197 .and_then(|sup| contains(sup, sub.clone(), &mut current))
198 .or_else(|| sup_iter.find_map(|sup| contains(sup, sub.clone(), &mut current)))
199 .unwrap_or(false)
200 })
201 }
202
203 pub fn disjoint(&self, other: &IntervalSet<I>) -> bool
204 where
205 I: Step,
206 {
207 let helper = move || {
208 let mut self_iter = self.iter_intervals();
209 let mut other_iter = other.iter_intervals();
210
211 let mut self_current = self_iter.next()?;
212 let mut other_current = other_iter.next()?;
213
214 loop {
215 if self_current.end <= other_current.start {
216 self_current = self_iter.next()?;
217 continue;
218 }
219 if other_current.end <= self_current.start {
220 other_current = other_iter.next()?;
221 continue;
222 }
223 return Some(false);
224 }
225 };
226 helper().unwrap_or(true)
227 }
228
229 pub fn is_empty(&self) -> bool {
230 self.map.is_empty()
231 }
232
233 pub fn first_unset_in(&self, range: impl RangeBounds<I> + Clone) -> Option<I> {
235 let start = inclusive_start(range.clone());
236 let Some(end) = inclusive_end(self.domain, range) else {
237 return None;
239 };
240 if start > end {
241 return None;
242 }
243 let Some(last) = self.map.partition_point(|r| r.0 <= start).checked_sub(1) else {
244 return Some(I::new(start as usize));
246 };
247 let (_, prev_end) = self.map[last];
248 if start > prev_end {
249 Some(I::new(start as usize))
250 } else if prev_end < end {
251 Some(I::new(prev_end as usize + 1))
252 } else {
253 None
254 }
255 }
256
257 pub fn last_set_in(&self, range: impl RangeBounds<I> + Clone) -> Option<I> {
259 let start = inclusive_start(range.clone());
260 let Some(end) = inclusive_end(self.domain, range) else {
261 return None;
263 };
264 if start > end {
265 return None;
266 }
267 let Some(last) = self.map.partition_point(|r| r.0 <= end).checked_sub(1) else {
268 return None;
270 };
271 let (_, prev_end) = &self.map[last];
272 if start <= *prev_end { Some(I::new(std::cmp::min(*prev_end, end) as usize)) } else { None }
273 }
274
275 pub fn insert_all(&mut self) {
276 self.clear();
277 if let Some(end) = self.domain.checked_sub(1) {
278 self.map.push((0, end.try_into().unwrap()));
279 }
280 debug_assert!(self.check_invariants());
281 }
282
283 pub fn union(&mut self, other: &IntervalSet<I>) -> bool
284 where
285 I: Step,
286 {
287 assert_eq!(self.domain, other.domain);
288 if self.map.len() < other.map.len() {
289 let backup = self.clone();
290 self.map.clone_from(&other.map);
291 return self.union(&backup);
292 }
293
294 let mut did_insert = false;
295 for range in other.iter_intervals() {
296 did_insert |= self.insert_range(range);
297 }
298 debug_assert!(self.check_invariants());
299 did_insert
300 }
301
302 fn check_invariants(&self) -> bool {
304 let mut current: Option<u32> = None;
305 for (start, end) in &self.map {
306 if start > end || current.is_some_and(|x| x + 1 >= *start) {
307 return false;
308 }
309 current = Some(*end);
310 }
311 current.is_none_or(|x| x < self.domain as u32)
312 }
313}
314
/// A sparse matrix with rows indexed by `R`, where each row is an
/// [`IntervalSet`] of column indices of type `C`.
///
/// Rows are created on demand by the mutating methods. `row` returns `None`
/// and `contains` returns `false` for rows that have not been created yet.
#[derive(Clone)]
pub struct SparseIntervalMatrix<R, C>
where
    R: Idx,
    C: Idx,
{
    /// One interval set per created row.
    rows: IndexVec<R, IntervalSet<C>>,
    /// Domain size given to every newly created row set.
    column_size: usize,
}
329
330impl<R: Idx, C: Step + Idx> SparseIntervalMatrix<R, C> {
331 pub fn new(column_size: usize) -> SparseIntervalMatrix<R, C> {
332 SparseIntervalMatrix { rows: IndexVec::new(), column_size }
333 }
334
335 pub fn rows(&self) -> impl Iterator<Item = R> {
336 self.rows.indices()
337 }
338
339 pub fn row(&self, row: R) -> Option<&IntervalSet<C>> {
340 self.rows.get(row)
341 }
342
343 fn ensure_row(&mut self, row: R) -> &mut IntervalSet<C> {
344 self.rows.ensure_contains_elem(row, || IntervalSet::new(self.column_size))
345 }
346
347 pub fn union_row(&mut self, row: R, from: &IntervalSet<C>) -> bool
348 where
349 C: Step,
350 {
351 self.ensure_row(row).union(from)
352 }
353
354 pub fn union_rows(&mut self, read: R, write: R) -> bool
355 where
356 C: Step,
357 {
358 if read == write || self.rows.get(read).is_none() {
359 return false;
360 }
361 self.ensure_row(write);
362 let (read_row, write_row) = self.rows.pick2_mut(read, write);
363 write_row.union(read_row)
364 }
365
366 pub fn insert_all_into_row(&mut self, row: R) {
367 self.ensure_row(row).insert_all();
368 }
369
370 pub fn insert_range(&mut self, row: R, range: impl RangeBounds<C> + Clone) {
371 self.ensure_row(row).insert_range(range);
372 }
373
374 pub fn insert(&mut self, row: R, point: C) -> bool {
375 self.ensure_row(row).insert(point)
376 }
377
378 pub fn append(&mut self, row: R, point: C) {
379 self.ensure_row(row).append(point)
380 }
381
382 pub fn contains(&self, row: R, point: C) -> bool {
383 self.row(row).is_some_and(|r| r.contains(point))
384 }
385}