sparse_bitset.h 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892
  1. //===- llvm/ADT/SparseBitVector.h - Efficient Sparse BitVector --*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file defines the SparseBitVector class. See the doxygen comment for
  10. // SparseBitVector for more details on the algorithm used.
  11. //
  12. //===----------------------------------------------------------------------===//
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/llvmMathExtras.h>
#include <algorithm>
#include <array>
#include <cassert>
#include <climits>
#include <cstddef>
#include <iterator>
#include <list>
#include <ostream>
#include <stdexcept>
#include <utility>
  22. namespace c10 {
  /// SparseBitVector is an implementation of a bitvector that is sparse by only
  /// storing the elements that have non-zero bits set. In order to make this
  /// fast for the most common cases, SparseBitVector is implemented as a linked
  /// list of SparseBitVectorElements. We maintain a pointer to the last
  /// SparseBitVectorElement accessed (in the form of a list iterator), so that
  /// after the first test/set, subsequent in-order test/set operations run in
  /// constant time. Note that using vectors to store SparseBitVectorElements
  /// does not work out very well because it causes insertion in the middle to
  /// take enormous amounts of time with a large number of bits. Other
  /// structures that have better worst cases for insertion in the middle
  /// (various balanced trees, etc.) do not perform as well in practice as a
  /// linked list with this iterator kept up to date. They are also
  /// significantly more memory intensive.
  35. template <unsigned ElementSize = 128>
  36. struct SparseBitVectorElement {
  37. public:
  38. using BitWord = unsigned long;
  39. using size_type = unsigned;
  40. enum {
  41. BITWORD_SIZE = sizeof(BitWord) * CHAR_BIT,
  42. BITWORDS_PER_ELEMENT = (ElementSize + BITWORD_SIZE - 1) / BITWORD_SIZE,
  43. BITS_PER_ELEMENT = ElementSize
  44. };
  45. private:
  46. // Index of Element in terms of where first bit starts.
  47. unsigned ElementIndex;
  48. std::array<BitWord, BITWORDS_PER_ELEMENT> Bits{};
  49. SparseBitVectorElement() : ElementIndex(~0U) {}
  50. public:
  51. explicit SparseBitVectorElement(unsigned Idx) : ElementIndex(Idx) {}
  52. // Comparison.
  53. bool operator==(const SparseBitVectorElement& RHS) const {
  54. if (ElementIndex != RHS.ElementIndex)
  55. return false;
  56. for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
  57. if (Bits[i] != RHS.Bits[i])
  58. return false;
  59. return true;
  60. }
  61. bool operator!=(const SparseBitVectorElement& RHS) const {
  62. return !(*this == RHS);
  63. }
  64. // Return the bits that make up word Idx in our element.
  65. BitWord word(unsigned Idx) const {
  66. assert(Idx < BITWORDS_PER_ELEMENT);
  67. return Bits[Idx];
  68. }
  69. unsigned index() const {
  70. return ElementIndex;
  71. }
  72. bool empty() const {
  73. for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
  74. if (Bits[i])
  75. return false;
  76. return true;
  77. }
  78. void set(unsigned Idx) {
  79. Bits[Idx / BITWORD_SIZE] |= 1L << (Idx % BITWORD_SIZE);
  80. }
  81. bool test_and_set(unsigned Idx) {
  82. bool old = test(Idx);
  83. if (!old) {
  84. set(Idx);
  85. return true;
  86. }
  87. return false;
  88. }
  89. void reset(unsigned Idx) {
  90. Bits[Idx / BITWORD_SIZE] &= ~(1L << (Idx % BITWORD_SIZE));
  91. }
  92. bool test(unsigned Idx) const {
  93. return Bits[Idx / BITWORD_SIZE] & (1L << (Idx % BITWORD_SIZE));
  94. }
  95. size_type count() const {
  96. unsigned NumBits = 0;
  97. for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
  98. NumBits += llvm::countPopulation(Bits[i]);
  99. return NumBits;
  100. }
  101. /// find_first - Returns the index of the first set bit.
  102. int find_first() const {
  103. for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
  104. if (Bits[i] != 0)
  105. return i * BITWORD_SIZE + llvm::countTrailingZeros(Bits[i]);
  106. throw std::runtime_error("Illegal empty element");
  107. }
  108. /// find_last - Returns the index of the last set bit.
  109. int find_last() const {
  110. for (unsigned I = 0; I < BITWORDS_PER_ELEMENT; ++I) {
  111. unsigned Idx = BITWORDS_PER_ELEMENT - I - 1;
  112. if (Bits[Idx] != 0)
  113. return Idx * BITWORD_SIZE + BITWORD_SIZE -
  114. llvm::countLeadingZeros(Bits[Idx]);
  115. }
  116. throw std::runtime_error("Illegal empty element");
  117. }
  118. /// find_next - Returns the index of the next set bit starting from the
  119. /// "Curr" bit. Returns -1 if the next set bit is not found.
  120. int find_next(unsigned Curr) const {
  121. if (Curr >= BITS_PER_ELEMENT)
  122. return -1;
  123. unsigned WordPos = Curr / BITWORD_SIZE;
  124. unsigned BitPos = Curr % BITWORD_SIZE;
  125. BitWord Copy = Bits[WordPos];
  126. assert(
  127. WordPos <= BITWORDS_PER_ELEMENT && "Word Position outside of element");
  128. // Mask off previous bits.
  129. Copy &= ~0UL << BitPos;
  130. if (Copy != 0)
  131. return WordPos * BITWORD_SIZE + llvm::countTrailingZeros(Copy);
  132. // Check subsequent words.
  133. for (unsigned i = WordPos + 1; i < BITWORDS_PER_ELEMENT; ++i)
  134. if (Bits[i] != 0)
  135. return i * BITWORD_SIZE + llvm::countTrailingZeros(Bits[i]);
  136. return -1;
  137. }
  138. // Union this element with RHS and return true if this one changed.
  139. bool unionWith(const SparseBitVectorElement& RHS) {
  140. bool changed = false;
  141. for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
  142. BitWord old = changed ? 0 : Bits[i];
  143. Bits[i] |= RHS.Bits[i];
  144. if (!changed && old != Bits[i])
  145. changed = true;
  146. }
  147. return changed;
  148. }
  149. // Return true if we have any bits in common with RHS
  150. bool intersects(const SparseBitVectorElement& RHS) const {
  151. for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
  152. if (RHS.Bits[i] & Bits[i])
  153. return true;
  154. }
  155. return false;
  156. }
  157. // Intersect this Element with RHS and return true if this one changed.
  158. // BecameZero is set to true if this element became all-zero bits.
  159. bool intersectWith(const SparseBitVectorElement& RHS, bool& BecameZero) {
  160. bool changed = false;
  161. bool allzero = true;
  162. for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
  163. BitWord old = changed ? 0 : Bits[i];
  164. Bits[i] &= RHS.Bits[i];
  165. if (Bits[i] != 0)
  166. allzero = false;
  167. if (!changed && old != Bits[i])
  168. changed = true;
  169. }
  170. BecameZero = allzero;
  171. return changed;
  172. }
  173. // Intersect this Element with the complement of RHS and return true if this
  174. // one changed. BecameZero is set to true if this element became all-zero
  175. // bits.
  176. bool intersectWithComplement(
  177. const SparseBitVectorElement& RHS,
  178. bool& BecameZero) {
  179. bool changed = false;
  180. bool allzero = true;
  181. for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
  182. BitWord old = changed ? 0 : Bits[i];
  183. Bits[i] &= ~RHS.Bits[i];
  184. if (Bits[i] != 0)
  185. allzero = false;
  186. if (!changed && old != Bits[i])
  187. changed = true;
  188. }
  189. BecameZero = allzero;
  190. return changed;
  191. }
  192. // Three argument version of intersectWithComplement that intersects
  193. // RHS1 & ~RHS2 into this element
  194. void intersectWithComplement(
  195. const SparseBitVectorElement& RHS1,
  196. const SparseBitVectorElement& RHS2,
  197. bool& BecameZero) {
  198. bool allzero = true;
  199. for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
  200. Bits[i] = RHS1.Bits[i] & ~RHS2.Bits[i];
  201. if (Bits[i] != 0)
  202. allzero = false;
  203. }
  204. BecameZero = allzero;
  205. }
  206. };
  207. template <unsigned ElementSize = 128>
  208. class SparseBitVector {
  209. using ElementList = std::list<SparseBitVectorElement<ElementSize>>;
  210. using ElementListIter = typename ElementList::iterator;
  211. using ElementListConstIter = typename ElementList::const_iterator;
  212. enum { BITWORD_SIZE = SparseBitVectorElement<ElementSize>::BITWORD_SIZE };
  213. ElementList Elements;
  214. // Pointer to our current Element. This has no visible effect on the external
  215. // state of a SparseBitVector, it's just used to improve performance in the
  216. // common case of testing/modifying bits with similar indices.
  217. mutable ElementListIter CurrElementIter;
  218. // This is like std::lower_bound, except we do linear searching from the
  219. // current position.
  220. ElementListIter FindLowerBoundImpl(unsigned ElementIndex) const {
  221. // We cache a non-const iterator so we're forced to resort to const_cast to
  222. // get the begin/end in the case where 'this' is const. To avoid duplication
  223. // of code with the only difference being whether the const cast is present
  224. // 'this' is always const in this particular function and we sort out the
  225. // difference in FindLowerBound and FindLowerBoundConst.
  226. ElementListIter Begin =
  227. // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  228. const_cast<SparseBitVector<ElementSize>*>(this)->Elements.begin();
  229. ElementListIter End =
  230. // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  231. const_cast<SparseBitVector<ElementSize>*>(this)->Elements.end();
  232. if (Elements.empty()) {
  233. CurrElementIter = Begin;
  234. return CurrElementIter;
  235. }
  236. // Make sure our current iterator is valid.
  237. if (CurrElementIter == End)
  238. --CurrElementIter;
  239. // Search from our current iterator, either backwards or forwards,
  240. // depending on what element we are looking for.
  241. ElementListIter ElementIter = CurrElementIter;
  242. if (CurrElementIter->index() == ElementIndex) {
  243. return ElementIter;
  244. } else if (CurrElementIter->index() > ElementIndex) {
  245. while (ElementIter != Begin && ElementIter->index() > ElementIndex)
  246. --ElementIter;
  247. } else {
  248. while (ElementIter != End && ElementIter->index() < ElementIndex)
  249. ++ElementIter;
  250. }
  251. CurrElementIter = ElementIter;
  252. return ElementIter;
  253. }
  254. ElementListConstIter FindLowerBoundConst(unsigned ElementIndex) const {
  255. return FindLowerBoundImpl(ElementIndex);
  256. }
  257. ElementListIter FindLowerBound(unsigned ElementIndex) {
  258. return FindLowerBoundImpl(ElementIndex);
  259. }
  260. // Iterator to walk set bits in the bitmap. This iterator is a lot uglier
  261. // than it would be, in order to be efficient.
  262. class SparseBitVectorIterator {
  263. private:
  264. bool AtEnd{false};
  265. const SparseBitVector<ElementSize>* BitVector = nullptr;
  266. // Current element inside of bitmap.
  267. ElementListConstIter Iter;
  268. // Current bit number inside of our bitmap.
  269. unsigned BitNumber{0};
  270. // Current word number inside of our element.
  271. unsigned WordNumber{0};
  272. // Current bits from the element.
  273. typename SparseBitVectorElement<ElementSize>::BitWord Bits{0};
  274. // Move our iterator to the first non-zero bit in the bitmap.
  275. void AdvanceToFirstNonZero() {
  276. if (AtEnd)
  277. return;
  278. if (BitVector->Elements.empty()) {
  279. AtEnd = true;
  280. return;
  281. }
  282. Iter = BitVector->Elements.begin();
  283. BitNumber = Iter->index() * ElementSize;
  284. unsigned BitPos = Iter->find_first();
  285. BitNumber += BitPos;
  286. WordNumber = (BitNumber % ElementSize) / BITWORD_SIZE;
  287. Bits = Iter->word(WordNumber);
  288. Bits >>= BitPos % BITWORD_SIZE;
  289. }
  290. // Move our iterator to the next non-zero bit.
  291. void AdvanceToNextNonZero() {
  292. if (AtEnd)
  293. return;
  294. while (Bits && !(Bits & 1)) {
  295. Bits >>= 1;
  296. BitNumber += 1;
  297. }
  298. // See if we ran out of Bits in this word.
  299. if (!Bits) {
  300. int NextSetBitNumber = Iter->find_next(BitNumber % ElementSize);
  301. // If we ran out of set bits in this element, move to next element.
  302. if (NextSetBitNumber == -1 || (BitNumber % ElementSize == 0)) {
  303. ++Iter;
  304. WordNumber = 0;
  305. // We may run out of elements in the bitmap.
  306. if (Iter == BitVector->Elements.end()) {
  307. AtEnd = true;
  308. return;
  309. }
  310. // Set up for next non-zero word in bitmap.
  311. BitNumber = Iter->index() * ElementSize;
  312. NextSetBitNumber = Iter->find_first();
  313. BitNumber += NextSetBitNumber;
  314. WordNumber = (BitNumber % ElementSize) / BITWORD_SIZE;
  315. Bits = Iter->word(WordNumber);
  316. Bits >>= NextSetBitNumber % BITWORD_SIZE;
  317. } else {
  318. WordNumber = (NextSetBitNumber % ElementSize) / BITWORD_SIZE;
  319. Bits = Iter->word(WordNumber);
  320. Bits >>= NextSetBitNumber % BITWORD_SIZE;
  321. BitNumber = Iter->index() * ElementSize;
  322. BitNumber += NextSetBitNumber;
  323. }
  324. }
  325. }
  326. public:
  327. SparseBitVectorIterator() = default;
  328. SparseBitVectorIterator(
  329. const SparseBitVector<ElementSize>* RHS,
  330. bool end = false)
  331. : AtEnd(end),
  332. BitVector(RHS),
  333. Iter(BitVector->Elements.begin()),
  334. WordNumber(~0) {
  335. AdvanceToFirstNonZero();
  336. }
  337. // Preincrement.
  338. inline SparseBitVectorIterator& operator++() {
  339. ++BitNumber;
  340. Bits >>= 1;
  341. AdvanceToNextNonZero();
  342. return *this;
  343. }
  344. // Postincrement.
  345. inline SparseBitVectorIterator operator++(int) {
  346. SparseBitVectorIterator tmp = *this;
  347. ++*this;
  348. return tmp;
  349. }
  350. // Return the current set bit number.
  351. unsigned operator*() const {
  352. return BitNumber;
  353. }
  354. bool operator==(const SparseBitVectorIterator& RHS) const {
  355. // If they are both at the end, ignore the rest of the fields.
  356. if (AtEnd && RHS.AtEnd)
  357. return true;
  358. // Otherwise they are the same if they have the same bit number and
  359. // bitmap.
  360. return AtEnd == RHS.AtEnd && RHS.BitNumber == BitNumber;
  361. }
  362. bool operator!=(const SparseBitVectorIterator& RHS) const {
  363. return !(*this == RHS);
  364. }
  365. };
  366. public:
  367. using iterator = SparseBitVectorIterator;
  368. SparseBitVector() : Elements(), CurrElementIter(Elements.begin()) {}
  369. SparseBitVector(const SparseBitVector& RHS)
  370. : Elements(RHS.Elements), CurrElementIter(Elements.begin()) {}
  371. SparseBitVector(SparseBitVector&& RHS) noexcept
  372. : Elements(std::move(RHS.Elements)), CurrElementIter(Elements.begin()) {}
  373. // Clear.
  374. void clear() {
  375. Elements.clear();
  376. }
  377. // Assignment
  378. SparseBitVector& operator=(const SparseBitVector& RHS) {
  379. if (this == &RHS)
  380. return *this;
  381. Elements = RHS.Elements;
  382. CurrElementIter = Elements.begin();
  383. return *this;
  384. }
  385. SparseBitVector& operator=(SparseBitVector&& RHS) noexcept {
  386. Elements = std::move(RHS.Elements);
  387. CurrElementIter = Elements.begin();
  388. return *this;
  389. }
  390. // Test, Reset, and Set a bit in the bitmap.
  391. bool test(unsigned Idx) const {
  392. if (Elements.empty())
  393. return false;
  394. unsigned ElementIndex = Idx / ElementSize;
  395. ElementListConstIter ElementIter = FindLowerBoundConst(ElementIndex);
  396. // If we can't find an element that is supposed to contain this bit, there
  397. // is nothing more to do.
  398. if (ElementIter == Elements.end() || ElementIter->index() != ElementIndex)
  399. return false;
  400. return ElementIter->test(Idx % ElementSize);
  401. }
  402. void reset(unsigned Idx) {
  403. if (Elements.empty())
  404. return;
  405. unsigned ElementIndex = Idx / ElementSize;
  406. ElementListIter ElementIter = FindLowerBound(ElementIndex);
  407. // If we can't find an element that is supposed to contain this bit, there
  408. // is nothing more to do.
  409. if (ElementIter == Elements.end() || ElementIter->index() != ElementIndex)
  410. return;
  411. ElementIter->reset(Idx % ElementSize);
  412. // When the element is zeroed out, delete it.
  413. if (ElementIter->empty()) {
  414. ++CurrElementIter;
  415. Elements.erase(ElementIter);
  416. }
  417. }
  418. void set(unsigned Idx) {
  419. unsigned ElementIndex = Idx / ElementSize;
  420. ElementListIter ElementIter;
  421. if (Elements.empty()) {
  422. ElementIter = Elements.emplace(Elements.end(), ElementIndex);
  423. } else {
  424. ElementIter = FindLowerBound(ElementIndex);
  425. if (ElementIter == Elements.end() ||
  426. ElementIter->index() != ElementIndex) {
  427. // We may have hit the beginning of our SparseBitVector, in which case,
  428. // we may need to insert right after this element, which requires moving
  429. // the current iterator forward one, because insert does insert before.
  430. if (ElementIter != Elements.end() &&
  431. ElementIter->index() < ElementIndex)
  432. ++ElementIter;
  433. ElementIter = Elements.emplace(ElementIter, ElementIndex);
  434. }
  435. }
  436. CurrElementIter = ElementIter;
  437. ElementIter->set(Idx % ElementSize);
  438. }
  439. bool test_and_set(unsigned Idx) {
  440. bool old = test(Idx);
  441. if (!old) {
  442. set(Idx);
  443. return true;
  444. }
  445. return false;
  446. }
  447. bool operator!=(const SparseBitVector& RHS) const {
  448. return !(*this == RHS);
  449. }
  450. bool operator==(const SparseBitVector& RHS) const {
  451. ElementListConstIter Iter1 = Elements.begin();
  452. ElementListConstIter Iter2 = RHS.Elements.begin();
  453. for (; Iter1 != Elements.end() && Iter2 != RHS.Elements.end();
  454. ++Iter1, ++Iter2) {
  455. if (*Iter1 != *Iter2)
  456. return false;
  457. }
  458. return Iter1 == Elements.end() && Iter2 == RHS.Elements.end();
  459. }
  460. // Union our bitmap with the RHS and return true if we changed.
  461. bool operator|=(const SparseBitVector& RHS) {
  462. if (this == &RHS)
  463. return false;
  464. if (empty()) {
  465. *this = RHS;
  466. return true;
  467. }
  468. bool changed = false;
  469. ElementListIter Iter1 = Elements.begin();
  470. ElementListConstIter Iter2 = RHS.Elements.begin();
  471. // If RHS is empty, we are done
  472. if (RHS.Elements.empty())
  473. return false;
  474. while (Iter2 != RHS.Elements.end()) {
  475. if (Iter1 == Elements.end() || Iter1->index() > Iter2->index()) {
  476. Elements.insert(Iter1, *Iter2);
  477. ++Iter2;
  478. changed = true;
  479. } else if (Iter1->index() == Iter2->index()) {
  480. changed |= Iter1->unionWith(*Iter2);
  481. ++Iter1;
  482. ++Iter2;
  483. } else {
  484. ++Iter1;
  485. }
  486. }
  487. CurrElementIter = Elements.begin();
  488. return changed;
  489. }
  490. // Intersect our bitmap with the RHS and return true if ours changed.
  491. bool operator-=(const SparseBitVector& RHS) {
  492. return intersectWithComplement(RHS);
  493. }
  494. // Intersect our bitmap with the RHS and return true if ours changed.
  495. bool operator&=(const SparseBitVector& RHS) {
  496. if (this == &RHS)
  497. return false;
  498. bool changed = false;
  499. ElementListIter Iter1 = Elements.begin();
  500. ElementListConstIter Iter2 = RHS.Elements.begin();
  501. // Check if both bitmaps are empty.
  502. if (Elements.empty() && RHS.Elements.empty())
  503. return false;
  504. // Loop through, intersecting as we go, erasing elements when necessary.
  505. while (Iter2 != RHS.Elements.end()) {
  506. if (Iter1 == Elements.end()) {
  507. CurrElementIter = Elements.begin();
  508. return changed;
  509. }
  510. if (Iter1->index() > Iter2->index()) {
  511. ++Iter2;
  512. } else if (Iter1->index() == Iter2->index()) {
  513. bool BecameZero = false;
  514. changed |= Iter1->intersectWith(*Iter2, BecameZero);
  515. if (BecameZero) {
  516. ElementListIter IterTmp = Iter1;
  517. ++Iter1;
  518. Elements.erase(IterTmp);
  519. } else {
  520. ++Iter1;
  521. }
  522. ++Iter2;
  523. } else {
  524. ElementListIter IterTmp = Iter1;
  525. ++Iter1;
  526. Elements.erase(IterTmp);
  527. changed = true;
  528. }
  529. }
  530. if (Iter1 != Elements.end()) {
  531. Elements.erase(Iter1, Elements.end());
  532. changed = true;
  533. }
  534. CurrElementIter = Elements.begin();
  535. return changed;
  536. }
  537. // Intersect our bitmap with the complement of the RHS and return true
  538. // if ours changed.
  539. bool intersectWithComplement(const SparseBitVector& RHS) {
  540. if (this == &RHS) {
  541. if (!empty()) {
  542. clear();
  543. return true;
  544. }
  545. return false;
  546. }
  547. bool changed = false;
  548. ElementListIter Iter1 = Elements.begin();
  549. ElementListConstIter Iter2 = RHS.Elements.begin();
  550. // If either our bitmap or RHS is empty, we are done
  551. if (Elements.empty() || RHS.Elements.empty())
  552. return false;
  553. // Loop through, intersecting as we go, erasing elements when necessary.
  554. while (Iter2 != RHS.Elements.end()) {
  555. if (Iter1 == Elements.end()) {
  556. CurrElementIter = Elements.begin();
  557. return changed;
  558. }
  559. if (Iter1->index() > Iter2->index()) {
  560. ++Iter2;
  561. } else if (Iter1->index() == Iter2->index()) {
  562. bool BecameZero = false;
  563. changed |= Iter1->intersectWithComplement(*Iter2, BecameZero);
  564. if (BecameZero) {
  565. ElementListIter IterTmp = Iter1;
  566. ++Iter1;
  567. Elements.erase(IterTmp);
  568. } else {
  569. ++Iter1;
  570. }
  571. ++Iter2;
  572. } else {
  573. ++Iter1;
  574. }
  575. }
  576. CurrElementIter = Elements.begin();
  577. return changed;
  578. }
  579. bool intersectWithComplement(const SparseBitVector<ElementSize>* RHS) const {
  580. return intersectWithComplement(*RHS);
  581. }
  582. // Three argument version of intersectWithComplement.
  583. // Result of RHS1 & ~RHS2 is stored into this bitmap.
  584. void intersectWithComplement(
  585. const SparseBitVector<ElementSize>& RHS1,
  586. const SparseBitVector<ElementSize>& RHS2) {
  587. if (this == &RHS1) {
  588. intersectWithComplement(RHS2);
  589. return;
  590. } else if (this == &RHS2) {
  591. SparseBitVector RHS2Copy(RHS2);
  592. intersectWithComplement(RHS1, RHS2Copy);
  593. return;
  594. }
  595. Elements.clear();
  596. CurrElementIter = Elements.begin();
  597. ElementListConstIter Iter1 = RHS1.Elements.begin();
  598. ElementListConstIter Iter2 = RHS2.Elements.begin();
  599. // If RHS1 is empty, we are done
  600. // If RHS2 is empty, we still have to copy RHS1
  601. if (RHS1.Elements.empty())
  602. return;
  603. // Loop through, intersecting as we go, erasing elements when necessary.
  604. while (Iter2 != RHS2.Elements.end()) {
  605. if (Iter1 == RHS1.Elements.end())
  606. return;
  607. if (Iter1->index() > Iter2->index()) {
  608. ++Iter2;
  609. } else if (Iter1->index() == Iter2->index()) {
  610. bool BecameZero = false;
  611. Elements.emplace_back(Iter1->index());
  612. Elements.back().intersectWithComplement(*Iter1, *Iter2, BecameZero);
  613. if (BecameZero)
  614. Elements.pop_back();
  615. ++Iter1;
  616. ++Iter2;
  617. } else {
  618. Elements.push_back(*Iter1++);
  619. }
  620. }
  621. // copy the remaining elements
  622. std::copy(Iter1, RHS1.Elements.end(), std::back_inserter(Elements));
  623. }
  624. void intersectWithComplement(
  625. const SparseBitVector<ElementSize>* RHS1,
  626. const SparseBitVector<ElementSize>* RHS2) {
  627. intersectWithComplement(*RHS1, *RHS2);
  628. }
  629. bool intersects(const SparseBitVector<ElementSize>* RHS) const {
  630. return intersects(*RHS);
  631. }
  632. // Return true if we share any bits in common with RHS
  633. bool intersects(const SparseBitVector<ElementSize>& RHS) const {
  634. ElementListConstIter Iter1 = Elements.begin();
  635. ElementListConstIter Iter2 = RHS.Elements.begin();
  636. // Check if both bitmaps are empty.
  637. if (Elements.empty() && RHS.Elements.empty())
  638. return false;
  639. // Loop through, intersecting stopping when we hit bits in common.
  640. while (Iter2 != RHS.Elements.end()) {
  641. if (Iter1 == Elements.end())
  642. return false;
  643. if (Iter1->index() > Iter2->index()) {
  644. ++Iter2;
  645. } else if (Iter1->index() == Iter2->index()) {
  646. if (Iter1->intersects(*Iter2))
  647. return true;
  648. ++Iter1;
  649. ++Iter2;
  650. } else {
  651. ++Iter1;
  652. }
  653. }
  654. return false;
  655. }
  656. // Return true iff all bits set in this SparseBitVector are
  657. // also set in RHS.
  658. bool contains(const SparseBitVector<ElementSize>& RHS) const {
  659. SparseBitVector<ElementSize> Result(*this);
  660. Result &= RHS;
  661. return (Result == RHS);
  662. }
  663. // Return the first set bit in the bitmap. Return -1 if no bits are set.
  664. int find_first() const {
  665. if (Elements.empty())
  666. return -1;
  667. const SparseBitVectorElement<ElementSize>& First = *(Elements.begin());
  668. return (First.index() * ElementSize) + First.find_first();
  669. }
  670. // Return the last set bit in the bitmap. Return -1 if no bits are set.
  671. int find_last() const {
  672. if (Elements.empty())
  673. return -1;
  674. const SparseBitVectorElement<ElementSize>& Last = *(Elements.rbegin());
  675. return (Last.index() * ElementSize) + Last.find_last();
  676. }
  677. // Return true if the SparseBitVector is empty
  678. bool empty() const {
  679. return Elements.empty();
  680. }
  681. unsigned count() const {
  682. unsigned BitCount = 0;
  683. for (ElementListConstIter Iter = Elements.begin(); Iter != Elements.end();
  684. ++Iter)
  685. BitCount += Iter->count();
  686. return BitCount;
  687. }
  688. iterator begin() const {
  689. return iterator(this);
  690. }
  691. iterator end() const {
  692. return iterator(this, true);
  693. }
  694. };
  695. // Convenience functions to allow Or and And without dereferencing in the user
  696. // code.
  697. template <unsigned ElementSize>
  698. inline bool operator|=(
  699. SparseBitVector<ElementSize>& LHS,
  700. const SparseBitVector<ElementSize>* RHS) {
  701. return LHS |= *RHS;
  702. }
  703. template <unsigned ElementSize>
  704. inline bool operator|=(
  705. SparseBitVector<ElementSize>* LHS,
  706. const SparseBitVector<ElementSize>& RHS) {
  707. return LHS->operator|=(RHS);
  708. }
  709. template <unsigned ElementSize>
  710. inline bool operator&=(
  711. SparseBitVector<ElementSize>* LHS,
  712. const SparseBitVector<ElementSize>& RHS) {
  713. return LHS->operator&=(RHS);
  714. }
  715. template <unsigned ElementSize>
  716. inline bool operator&=(
  717. SparseBitVector<ElementSize>& LHS,
  718. const SparseBitVector<ElementSize>* RHS) {
  719. return LHS &= *RHS;
  720. }
  721. // Convenience functions for infix union, intersection, difference operators.
  722. template <unsigned ElementSize>
  723. inline SparseBitVector<ElementSize> operator|(
  724. const SparseBitVector<ElementSize>& LHS,
  725. const SparseBitVector<ElementSize>& RHS) {
  726. SparseBitVector<ElementSize> Result(LHS);
  727. Result |= RHS;
  728. return Result;
  729. }
  730. template <unsigned ElementSize>
  731. inline SparseBitVector<ElementSize> operator&(
  732. const SparseBitVector<ElementSize>& LHS,
  733. const SparseBitVector<ElementSize>& RHS) {
  734. SparseBitVector<ElementSize> Result(LHS);
  735. Result &= RHS;
  736. return Result;
  737. }
  738. template <unsigned ElementSize>
  739. inline SparseBitVector<ElementSize> operator-(
  740. const SparseBitVector<ElementSize>& LHS,
  741. const SparseBitVector<ElementSize>& RHS) {
  742. SparseBitVector<ElementSize> Result;
  743. Result.intersectWithComplement(LHS, RHS);
  744. return Result;
  745. }
  746. template <unsigned ElementSize>
  747. std::ostream& operator<<(
  748. std::ostream& stream,
  749. const SparseBitVector<ElementSize>& vec) {
  750. bool first = true;
  751. stream << "{";
  752. for (auto el : vec) {
  753. if (first) {
  754. first = false;
  755. } else {
  756. stream << ", ";
  757. }
  758. stream << el;
  759. }
  760. stream << "}";
  761. return stream;
  762. }
  763. } // end namespace c10