caching.py

from __future__ import annotations

import collections
import functools
import logging
import math
import os
import threading
import warnings
from concurrent.futures import Future, ThreadPoolExecutor
from itertools import groupby
from operator import itemgetter
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    ClassVar,
    Generic,
    NamedTuple,
    Optional,
    OrderedDict,
    TypeVar,
)

if TYPE_CHECKING:
    import mmap

    from typing_extensions import ParamSpec

    P = ParamSpec("P")
else:
    P = TypeVar("P")

T = TypeVar("T")


logger = logging.getLogger("fsspec")

Fetcher = Callable[[int, int], bytes]  # Maps (start, end) to bytes

class BaseCache:
    """Pass-through cache: doesn't keep anything, calls every time

    Acts as base class for other cachers

    Parameters
    ----------
    blocksize: int
        How far to read ahead in numbers of bytes
    fetcher: func
        Function of the form f(start, end) which gets bytes from remote as
        specified
    size: int
        How big this file is
    """

    name: ClassVar[str] = "none"

    def __init__(self, blocksize: int, fetcher: Fetcher, size: int) -> None:
        self.blocksize = blocksize
        self.nblocks = 0
        self.fetcher = fetcher
        self.size = size
        self.hit_count = 0
        self.miss_count = 0
        # the bytes that we actually requested
        self.total_requested_bytes = 0

    def _fetch(self, start: int | None, stop: int | None) -> bytes:
        if start is None:
            start = 0
        if stop is None:
            stop = self.size
        if start >= self.size or start >= stop:
            return b""
        return self.fetcher(start, stop)

    def _reset_stats(self) -> None:
        """Reset hit and miss counts for a more granular report, e.g. by file."""
        self.hit_count = 0
        self.miss_count = 0
        self.total_requested_bytes = 0

    def _log_stats(self) -> str:
        """Return a formatted string of the cache statistics."""
        if self.hit_count == 0 and self.miss_count == 0:
            # a cache that does nothing, this is for logs only
            return ""
        return f" , {self.name}: {self.hit_count} hits, {self.miss_count} misses, {self.total_requested_bytes} total requested bytes"

    def __repr__(self) -> str:
        # TODO: use rich for better formatting
        return f"""
        <{self.__class__.__name__}:
            block size  :   {self.blocksize}
            block count :   {self.nblocks}
            file size   :   {self.size}
            cache hits  :   {self.hit_count}
            cache misses:   {self.miss_count}
            total requested bytes: {self.total_requested_bytes}>
        """

class MMapCache(BaseCache):
    """memory-mapped sparse file cache

    Opens temporary file, which is filled block-wise when data is requested.
    Ensure there is enough disc space in the temporary location.

    This cache method might only work on posix
    """

    name = "mmap"

    def __init__(
        self,
        blocksize: int,
        fetcher: Fetcher,
        size: int,
        location: str | None = None,
        blocks: set[int] | None = None,
    ) -> None:
        super().__init__(blocksize, fetcher, size)
        self.blocks = set() if blocks is None else blocks
        self.location = location
        self.cache = self._makefile()

    def _makefile(self) -> mmap.mmap | bytearray:
        import mmap
        import tempfile

        if self.size == 0:
            return bytearray()

        # posix version
        if self.location is None or not os.path.exists(self.location):
            if self.location is None:
                fd = tempfile.TemporaryFile()
                self.blocks = set()
            else:
                fd = open(self.location, "wb+")
            fd.seek(self.size - 1)
            fd.write(b"1")
            fd.flush()
        else:
            fd = open(self.location, "r+b")

        return mmap.mmap(fd.fileno(), self.size)

    def _fetch(self, start: int | None, end: int | None) -> bytes:
        logger.debug(f"MMap cache fetching {start}-{end}")
        if start is None:
            start = 0
        if end is None:
            end = self.size
        if start >= self.size or start >= end:
            return b""

        start_block = start // self.blocksize
        end_block = end // self.blocksize
        block_range = range(start_block, end_block + 1)
        # Determine which blocks need to be fetched. This sequence is sorted by construction.
        need = (i for i in block_range if i not in self.blocks)
        # Count the number of blocks already cached
        self.hit_count += sum(1 for i in block_range if i in self.blocks)

        # Consolidate needed blocks.
        # Algorithm adapted from Python 2.x itertools documentation.
        # We are grouping an enumerated sequence of blocks. By taking the difference
        # between an ascending range (provided by enumerate) and the needed block
        # numbers, we can detect when the block numbers skip values: whenever the
        # difference changes, the block(s) in between were previously cached, so a
        # new group is started. In other words, this algorithm neatly groups runs
        # of consecutive block numbers so they can be fetched together.
        for _, _blocks in groupby(enumerate(need), key=lambda x: x[0] - x[1]):
            # Extract the blocks from the enumerated sequence
            _blocks = tuple(map(itemgetter(1), _blocks))
            # Compute start of first block
            sstart = _blocks[0] * self.blocksize
            # Compute the end of the last block. Last block may not be full size.
            send = min(_blocks[-1] * self.blocksize + self.blocksize, self.size)

            # Fetch bytes (could be multiple consecutive blocks)
            self.total_requested_bytes += send - sstart
            logger.debug(
                f"MMap get blocks {_blocks[0]}-{_blocks[-1]} ({sstart}-{send})"
            )
            self.cache[sstart:send] = self.fetcher(sstart, send)

            # Update set of cached blocks
            self.blocks.update(_blocks)
            # Update cache statistics with number of blocks we had to cache
            self.miss_count += len(_blocks)

        return self.cache[start:end]

    def __getstate__(self) -> dict[str, Any]:
        state = self.__dict__.copy()
        # Remove the unpicklable entries.
        del state["cache"]
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        # Restore instance attributes
        self.__dict__.update(state)
        self.cache = self._makefile()

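
# Usage sketch (illustrative, hypothetical helper): MMapCache backs the cache
# with a sparse temporary file, so this likely only runs on POSIX systems.
# Overlapping reads should hit blocks already recorded in `cache.blocks`.
def _demo_mmap_cache() -> None:
    data = bytes(range(256)) * 16  # 4 KiB stand-in for a remote file
    cache = MMapCache(blocksize=1024, fetcher=lambda s, e: data[s:e], size=len(data))
    assert cache._fetch(0, 100) == data[:100]   # fetches block 0 (miss)
    assert cache._fetch(50, 90) == data[50:90]  # served from the mmap (hit)
    assert cache.miss_count == 1 and cache.hit_count >= 1
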
class ReadAheadCache(BaseCache):
    """Cache which reads only when we get beyond a block of data

    This is a much simpler version of BytesCache, and does not attempt to
    fill holes in the cache or keep fragments alive. It is best suited to
    many small reads in a sequential order (e.g., reading lines from a file).
    """

    name = "readahead"

    def __init__(self, blocksize: int, fetcher: Fetcher, size: int) -> None:
        super().__init__(blocksize, fetcher, size)
        self.cache = b""
        self.start = 0
        self.end = 0

    def _fetch(self, start: int | None, end: int | None) -> bytes:
        if start is None:
            start = 0
        if end is None or end > self.size:
            end = self.size
        if start >= self.size or start >= end:
            return b""
        l = end - start
        if start >= self.start and end <= self.end:
            # cache hit
            self.hit_count += 1
            return self.cache[start - self.start : end - self.start]
        elif self.start <= start < self.end:
            # partial hit
            self.miss_count += 1
            part = self.cache[start - self.start :]
            l -= len(part)
            start = self.end
        else:
            # miss
            self.miss_count += 1
            part = b""
        end = min(self.size, end + self.blocksize)
        self.total_requested_bytes += end - start
        self.cache = self.fetcher(start, end)  # new block replaces old
        self.start = start
        self.end = self.start + len(self.cache)
        return part + self.cache[:l]

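
# Usage sketch (illustrative, hypothetical helper): sequential small reads are
# ReadAheadCache's sweet spot; the first read over-fetches by `blocksize`, so
# the following reads come out of the buffer.
def _demo_readahead_cache() -> None:
    data = b"line1\nline2\nline3\n" * 100
    cache = ReadAheadCache(blocksize=256, fetcher=lambda s, e: data[s:e], size=len(data))
    assert cache._fetch(0, 6) == data[:6]     # miss: reads 6 + 256 bytes ahead
    assert cache._fetch(6, 12) == data[6:12]  # hit: already buffered
    assert cache.hit_count == 1 and cache.miss_count == 1
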
class FirstChunkCache(BaseCache):
    """Caches the first block of a file only

    This may be useful for file types where the metadata is stored in the header,
    but is randomly accessed.
    """

    name = "first"

    def __init__(self, blocksize: int, fetcher: Fetcher, size: int) -> None:
        if blocksize > size:
            # this will buffer the whole thing
            blocksize = size
        super().__init__(blocksize, fetcher, size)
        self.cache: bytes | None = None

    def _fetch(self, start: int | None, end: int | None) -> bytes:
        start = start or 0
        if start > self.size:
            logger.debug("FirstChunkCache: requested start > file size")
            return b""
        # `end` may be None when reading to EOF
        end = self.size if end is None else min(end, self.size)
        if start < self.blocksize:
            if self.cache is None:
                self.miss_count += 1
                if end > self.blocksize:
                    self.total_requested_bytes += end
                    data = self.fetcher(0, end)
                    self.cache = data[: self.blocksize]
                    return data[start:]
                self.cache = self.fetcher(0, self.blocksize)
                self.total_requested_bytes += self.blocksize
            part = self.cache[start:end]
            if end > self.blocksize:
                self.total_requested_bytes += end - self.blocksize
                part += self.fetcher(self.blocksize, end)
            self.hit_count += 1
            return part
        else:
            self.miss_count += 1
            self.total_requested_bytes += end - start
            return self.fetcher(start, end)

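
# Usage sketch (illustrative, hypothetical helper): only the header block is
# retained, which suits formats whose metadata lives at the start of the file.
def _demo_first_chunk_cache() -> None:
    data = b"HEADER" + b"x" * 1000
    cache = FirstChunkCache(blocksize=64, fetcher=lambda s, e: data[s:e], size=len(data))
    assert cache._fetch(0, 6) == b"HEADER"          # miss: caches bytes 0-64
    assert cache._fetch(0, 6) == b"HEADER"          # hit: header comes from cache
    assert cache._fetch(500, 510) == data[500:510]  # beyond the header: not cached
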
class BlockCache(BaseCache):
    """
    Cache holding memory as a set of blocks.

    Requests are only ever made ``blocksize`` at a time, and are
    stored in an LRU cache. The least recently accessed block is
    discarded when more than ``maxblocks`` are stored.

    Parameters
    ----------
    blocksize : int
        The number of bytes to store in each block.
        Requests are only ever made for ``blocksize``, so this
        should balance the overhead of making a request against
        the granularity of the blocks.
    fetcher : Callable
    size : int
        The total size of the file being cached.
    maxblocks : int
        The maximum number of blocks to cache for. The maximum memory
        use for this cache is then ``blocksize * maxblocks``.
    """

    name = "blockcache"

    def __init__(
        self, blocksize: int, fetcher: Fetcher, size: int, maxblocks: int = 32
    ) -> None:
        super().__init__(blocksize, fetcher, size)
        self.nblocks = math.ceil(size / blocksize)
        self.maxblocks = maxblocks
        self._fetch_block_cached = functools.lru_cache(maxblocks)(self._fetch_block)

    def cache_info(self):
        """
        The statistics on the block cache.

        Returns
        -------
        NamedTuple
            Returned directly from the LRU Cache used internally.
        """
        return self._fetch_block_cached.cache_info()

    def __getstate__(self) -> dict[str, Any]:
        # copy, so removing the unpicklable entry doesn't break the live object
        state = self.__dict__.copy()
        del state["_fetch_block_cached"]
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        self.__dict__.update(state)
        self._fetch_block_cached = functools.lru_cache(state["maxblocks"])(
            self._fetch_block
        )

    def _fetch(self, start: int | None, end: int | None) -> bytes:
        if start is None:
            start = 0
        if end is None:
            end = self.size
        if start >= self.size or start >= end:
            return b""

        # byte position -> block numbers
        start_block_number = start // self.blocksize
        end_block_number = end // self.blocksize

        # these are cached, so safe to do multiple calls for the same start and end.
        for block_number in range(start_block_number, end_block_number + 1):
            self._fetch_block_cached(block_number)

        return self._read_cache(
            start,
            end,
            start_block_number=start_block_number,
            end_block_number=end_block_number,
        )

    def _fetch_block(self, block_number: int) -> bytes:
        """
        Fetch the block of data for `block_number`.
        """
        if block_number > self.nblocks:
            raise ValueError(
                f"'block_number={block_number}' is greater than "
                f"the number of blocks ({self.nblocks})"
            )

        start = block_number * self.blocksize
        end = start + self.blocksize
        self.total_requested_bytes += end - start
        self.miss_count += 1
        logger.info("BlockCache fetching block %d", block_number)
        block_contents = super()._fetch(start, end)
        return block_contents

    def _read_cache(
        self, start: int, end: int, start_block_number: int, end_block_number: int
    ) -> bytes:
        """
        Read from our block cache.

        Parameters
        ----------
        start, end : int
            The start and end byte positions.
        start_block_number, end_block_number : int
            The start and end block numbers.
        """
        start_pos = start % self.blocksize
        end_pos = end % self.blocksize

        self.hit_count += 1
        if start_block_number == end_block_number:
            block: bytes = self._fetch_block_cached(start_block_number)
            return block[start_pos:end_pos]
        else:
            # read from the initial block
            out = [self._fetch_block_cached(start_block_number)[start_pos:]]

            # intermediate blocks
            # Note: it'd be nice to combine these into one big request. However
            # that doesn't play nicely with our LRU cache.
            out.extend(
                map(
                    self._fetch_block_cached,
                    range(start_block_number + 1, end_block_number),
                )
            )

            # final block
            out.append(self._fetch_block_cached(end_block_number)[:end_pos])

            return b"".join(out)

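
# Usage sketch (illustrative, hypothetical helper): blocks are fetched whole
# and retained by functools.lru_cache, so cache_info() exposes the standard
# hits/misses/maxsize/currsize tuple.
def _demo_block_cache() -> None:
    data = bytes(1024)
    cache = BlockCache(
        blocksize=128, fetcher=lambda s, e: data[s:e], size=len(data), maxblocks=4
    )
    cache._fetch(0, 256)  # touches blocks 0-2
    cache._fetch(0, 100)  # block 0 is an LRU hit now
    info = cache.cache_info()
    assert info.currsize <= 4 and info.hits >= 1
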
class BytesCache(BaseCache):
    """Cache which holds data in an in-memory bytes object

    Implements read-ahead by the block size, for semi-random reads progressing
    through the file.

    Parameters
    ----------
    trim: bool
        As we read more data, whether to discard the start of the buffer when
        we are more than a blocksize ahead of it.
    """

    name: ClassVar[str] = "bytes"

    def __init__(
        self, blocksize: int, fetcher: Fetcher, size: int, trim: bool = True
    ) -> None:
        super().__init__(blocksize, fetcher, size)
        self.cache = b""
        self.start: int | None = None
        self.end: int | None = None
        self.trim = trim

    def _fetch(self, start: int | None, end: int | None) -> bytes:
        # TODO: only set start/end after fetch, in case it fails?
        # is this where retry logic might go?
        if start is None:
            start = 0
        if end is None:
            end = self.size
        if start >= self.size or start >= end:
            return b""
        if (
            self.start is not None
            and start >= self.start
            and self.end is not None
            and end < self.end
        ):
            # cache hit: we have all the required data
            offset = start - self.start
            self.hit_count += 1
            return self.cache[offset : offset + end - start]

        if self.blocksize:
            bend = min(self.size, end + self.blocksize)
        else:
            bend = end

        if bend == start or start > self.size:
            return b""

        if (self.start is None or start < self.start) and (
            self.end is None or end > self.end
        ):
            # First read, or extending both before and after
            self.total_requested_bytes += bend - start
            self.miss_count += 1
            self.cache = self.fetcher(start, bend)
            self.start = start
        else:
            assert self.start is not None
            assert self.end is not None
            self.miss_count += 1

            if start < self.start:
                if self.end is None or self.end - end > self.blocksize:
                    self.total_requested_bytes += bend - start
                    self.cache = self.fetcher(start, bend)
                    self.start = start
                else:
                    self.total_requested_bytes += self.start - start
                    new = self.fetcher(start, self.start)
                    self.start = start
                    self.cache = new + self.cache
            elif self.end is not None and bend > self.end:
                if self.end > self.size:
                    pass
                elif end - self.end > self.blocksize:
                    self.total_requested_bytes += bend - start
                    self.cache = self.fetcher(start, bend)
                    self.start = start
                else:
                    self.total_requested_bytes += bend - self.end
                    new = self.fetcher(self.end, bend)
                    self.cache = self.cache + new

        self.end = self.start + len(self.cache)
        offset = start - self.start
        out = self.cache[offset : offset + end - start]
        if self.trim:
            num = (self.end - self.start) // (self.blocksize + 1)
            if num > 1:
                self.start += self.blocksize * num
                self.cache = self.cache[self.blocksize * num :]
        return out

    def __len__(self) -> int:
        return len(self.cache)

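
# Usage sketch (illustrative, hypothetical helper): BytesCache grows (and,
# with trim=True, prunes) a single contiguous buffer, which suits semi-random
# reads that generally move forward through the file.
def _demo_bytes_cache() -> None:
    data = bytes(range(256)) * 8
    cache = BytesCache(blocksize=64, fetcher=lambda s, e: data[s:e], size=len(data))
    assert cache._fetch(0, 32) == data[:32]     # miss: buffers 0-96 (read-ahead)
    assert cache._fetch(16, 48) == data[16:48]  # hit: inside the buffer
    assert cache.hit_count == 1
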
class AllBytes(BaseCache):
    """Cache entire contents of the file"""

    name: ClassVar[str] = "all"

    def __init__(
        self,
        blocksize: int | None = None,
        fetcher: Fetcher | None = None,
        size: int | None = None,
        data: bytes | None = None,
    ) -> None:
        super().__init__(blocksize, fetcher, size)  # type: ignore[arg-type]
        if data is None:
            self.miss_count += 1
            self.total_requested_bytes += self.size
            data = self.fetcher(0, self.size)
        self.data = data

    def _fetch(self, start: int | None, stop: int | None) -> bytes:
        self.hit_count += 1
        return self.data[start:stop]

class KnownPartsOfAFile(BaseCache):
    """
    Cache holding known file parts.

    Parameters
    ----------
    blocksize: int
        How far to read ahead in numbers of bytes
    fetcher: func
        Function of the form f(start, end) which gets bytes from remote as
        specified
    size: int
        How big this file is
    data: dict
        A dictionary mapping explicit `(start, stop)` file-offset tuples
        to known bytes.
    strict: bool, default True
        Whether to fetch reads that go beyond a known byte-range boundary.
        If `False`, any read that ends outside a known part will be zero
        padded. Note that zero padding will not be used for reads that
        begin outside a known byte-range.
    """

    name: ClassVar[str] = "parts"

    def __init__(
        self,
        blocksize: int,
        fetcher: Fetcher,
        size: int,
        data: Optional[dict[tuple[int, int], bytes]] = None,
        strict: bool = True,
        **_: Any,
    ):
        super().__init__(blocksize, fetcher, size)
        self.strict = strict

        # simple consolidation of contiguous blocks
        if data:
            old_offsets = sorted(data.keys())
            offsets = [old_offsets[0]]
            blocks = [data.pop(old_offsets[0])]
            for start, stop in old_offsets[1:]:
                start0, stop0 = offsets[-1]
                if start == stop0:
                    offsets[-1] = (start0, stop)
                    blocks[-1] += data.pop((start, stop))
                else:
                    offsets.append((start, stop))
                    blocks.append(data.pop((start, stop)))

            self.data = dict(zip(offsets, blocks))
        else:
            self.data = {}

    def _fetch(self, start: int | None, stop: int | None) -> bytes:
        if start is None:
            start = 0
        if stop is None:
            stop = self.size

        out = b""
        for (loc0, loc1), data in self.data.items():
            # If self.strict=False, use zero-padded data
            # for reads beyond the end of a "known" buffer
            if loc0 <= start < loc1:
                off = start - loc0
                out = data[off : off + stop - start]
                if not self.strict or loc0 <= stop <= loc1:
                    # The request is within a known range, or
                    # it begins within a known range, and we
                    # are allowed to pad reads beyond the
                    # buffer with zero
                    out += b"\x00" * (stop - start - len(out))
                    self.hit_count += 1
                    return out
                else:
                    # The request ends outside a known range,
                    # and we are being "strict" about reads
                    # beyond the buffer
                    start = loc1
                    break

        # We only get here if there is a request outside the
        # known parts of the file. In an ideal world, this
        # should never happen
        if self.fetcher is None:
            # We cannot fetch the data, so raise an error
            raise ValueError(f"Read is outside the known file parts: {(start, stop)}. ")

        # We can fetch the data, but should warn the user
        # that this may be slow
        warnings.warn(
            f"Read is outside the known file parts: {(start, stop)}. "
            f"IO/caching performance may be poor!"
        )
        logger.debug(f"KnownPartsOfAFile cache fetching {start}-{stop}")
        self.total_requested_bytes += stop - start
        self.miss_count += 1
        return out + super()._fetch(start, stop)

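
# Usage sketch (illustrative, hypothetical helper): contiguous parts given at
# construction are consolidated, and with strict=False a read running past a
# known part is zero-padded rather than fetched.
def _demo_known_parts() -> None:
    parts = {(0, 4): b"abcd", (4, 8): b"efgh"}  # contiguous: merged to (0, 8)
    cache = KnownPartsOfAFile(
        blocksize=16, fetcher=None, size=100, data=parts, strict=False
    )
    assert list(cache.data) == [(0, 8)]
    assert cache._fetch(2, 6) == b"cdef"                # fully inside the known part
    assert cache._fetch(6, 12) == b"gh" + b"\x00" * 4   # padded past the boundary
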
class UpdatableLRU(Generic[P, T]):
    """
    Custom implementation of LRU cache that allows updating keys

    Used by BackgroundBlockCache
    """

    class CacheInfo(NamedTuple):
        hits: int
        misses: int
        maxsize: int
        currsize: int

    def __init__(self, func: Callable[P, T], max_size: int = 128) -> None:
        self._cache: OrderedDict[Any, T] = collections.OrderedDict()
        self._func = func
        self._max_size = max_size
        self._hits = 0
        self._misses = 0
        self._lock = threading.Lock()

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        if kwargs:
            raise TypeError(f"Got unexpected keyword argument {kwargs.keys()}")
        with self._lock:
            if args in self._cache:
                self._cache.move_to_end(args)
                self._hits += 1
                return self._cache[args]

        result = self._func(*args, **kwargs)

        with self._lock:
            self._cache[args] = result
            self._misses += 1
            if len(self._cache) > self._max_size:
                self._cache.popitem(last=False)

        return result

    def is_key_cached(self, *args: Any) -> bool:
        with self._lock:
            return args in self._cache

    def add_key(self, result: T, *args: Any) -> None:
        with self._lock:
            self._cache[args] = result
            if len(self._cache) > self._max_size:
                self._cache.popitem(last=False)

    def cache_info(self) -> UpdatableLRU.CacheInfo:
        with self._lock:
            return self.CacheInfo(
                maxsize=self._max_size,
                currsize=len(self._cache),
                hits=self._hits,
                misses=self._misses,
            )

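
# Usage sketch (illustrative, hypothetical helper): unlike functools.lru_cache,
# UpdatableLRU lets a background producer insert a result for a key directly
# via add_key, so a later call is served without invoking the wrapped function.
def _demo_updatable_lru() -> None:
    calls: list[int] = []

    def slow_square(x: int) -> int:
        calls.append(x)
        return x * x

    lru = UpdatableLRU(slow_square, max_size=2)
    assert lru(3) == 9 and calls == [3]   # miss: computed
    lru.add_key(16, 4)                    # pre-seed result for key (4,)
    assert lru.is_key_cached(4)
    assert lru(4) == 16 and calls == [3]  # hit: slow_square never ran for 4
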
class BackgroundBlockCache(BaseCache):
    """
    Cache holding memory as a set of blocks with pre-loading of
    the next block in the background.

    Requests are only ever made ``blocksize`` at a time, and are
    stored in an LRU cache. The least recently accessed block is
    discarded when more than ``maxblocks`` are stored. If the
    next block is not in cache, it is loaded in a separate thread
    in a non-blocking way.

    Parameters
    ----------
    blocksize : int
        The number of bytes to store in each block.
        Requests are only ever made for ``blocksize``, so this
        should balance the overhead of making a request against
        the granularity of the blocks.
    fetcher : Callable
    size : int
        The total size of the file being cached.
    maxblocks : int
        The maximum number of blocks to cache for. The maximum memory
        use for this cache is then ``blocksize * maxblocks``.
    """

    name: ClassVar[str] = "background"

    def __init__(
        self, blocksize: int, fetcher: Fetcher, size: int, maxblocks: int = 32
    ) -> None:
        super().__init__(blocksize, fetcher, size)
        self.nblocks = math.ceil(size / blocksize)
        self.maxblocks = maxblocks
        self._fetch_block_cached = UpdatableLRU(self._fetch_block, maxblocks)

        self._thread_executor = ThreadPoolExecutor(max_workers=1)
        self._fetch_future_block_number: int | None = None
        self._fetch_future: Future[bytes] | None = None
        self._fetch_future_lock = threading.Lock()

    def cache_info(self) -> UpdatableLRU.CacheInfo:
        """
        The statistics on the block cache.

        Returns
        -------
        NamedTuple
            Returned directly from the LRU Cache used internally.
        """
        return self._fetch_block_cached.cache_info()

    def __getstate__(self) -> dict[str, Any]:
        # copy, so removing the unpicklable entries doesn't break the live object
        state = self.__dict__.copy()
        del state["_fetch_block_cached"]
        del state["_thread_executor"]
        del state["_fetch_future_block_number"]
        del state["_fetch_future"]
        del state["_fetch_future_lock"]
        return state

    def __setstate__(self, state) -> None:
        self.__dict__.update(state)
        self._fetch_block_cached = UpdatableLRU(self._fetch_block, state["maxblocks"])
        self._thread_executor = ThreadPoolExecutor(max_workers=1)
        self._fetch_future_block_number = None
        self._fetch_future = None
        self._fetch_future_lock = threading.Lock()

    def _fetch(self, start: int | None, end: int | None) -> bytes:
        if start is None:
            start = 0
        if end is None:
            end = self.size
        if start >= self.size or start >= end:
            return b""

        # byte position -> block numbers
        start_block_number = start // self.blocksize
        end_block_number = end // self.blocksize

        fetch_future_block_number = None
        fetch_future = None
        with self._fetch_future_lock:
            # Background thread is running. Check whether we can or must join it.
            if self._fetch_future is not None:
                assert self._fetch_future_block_number is not None
                if self._fetch_future.done():
                    logger.info("BlockCache joined background fetch without waiting.")
                    self._fetch_block_cached.add_key(
                        self._fetch_future.result(), self._fetch_future_block_number
                    )

                    # Cleanup the fetch variables. Done with fetching the block.
                    self._fetch_future_block_number = None
                    self._fetch_future = None
                else:
                    # Must join if we need the block for the current fetch
                    must_join = bool(
                        start_block_number
                        <= self._fetch_future_block_number
                        <= end_block_number
                    )
                    if must_join:
                        # Copy to the local variables to release lock
                        # before waiting for result
                        fetch_future_block_number = self._fetch_future_block_number
                        fetch_future = self._fetch_future

                        # Cleanup the fetch variables. Have a local copy.
                        self._fetch_future_block_number = None
                        self._fetch_future = None

        # Need to wait for the future for the current read
        if fetch_future is not None:
            logger.info("BlockCache waiting for background fetch.")
            # Wait until result and put it in cache
            self._fetch_block_cached.add_key(
                fetch_future.result(), fetch_future_block_number
            )

        # these are cached, so safe to do multiple calls for the same start and end.
        for block_number in range(start_block_number, end_block_number + 1):
            self._fetch_block_cached(block_number)

        # fetch next block in the background if nothing is running in the background,
        # the block is within the file and it is not already cached
        end_block_plus_1 = end_block_number + 1
        with self._fetch_future_lock:
            if (
                self._fetch_future is None
                and end_block_plus_1 <= self.nblocks
                and not self._fetch_block_cached.is_key_cached(end_block_plus_1)
            ):
                self._fetch_future_block_number = end_block_plus_1
                self._fetch_future = self._thread_executor.submit(
                    self._fetch_block, end_block_plus_1, "async"
                )

        return self._read_cache(
            start,
            end,
            start_block_number=start_block_number,
            end_block_number=end_block_number,
        )

    def _fetch_block(self, block_number: int, log_info: str = "sync") -> bytes:
        """
        Fetch the block of data for `block_number`.
        """
        if block_number > self.nblocks:
            raise ValueError(
                f"'block_number={block_number}' is greater than "
                f"the number of blocks ({self.nblocks})"
            )

        start = block_number * self.blocksize
        end = start + self.blocksize
        logger.info("BlockCache fetching block (%s) %d", log_info, block_number)
        self.total_requested_bytes += end - start
        self.miss_count += 1
        block_contents = super()._fetch(start, end)
        return block_contents

    def _read_cache(
        self, start: int, end: int, start_block_number: int, end_block_number: int
    ) -> bytes:
        """
        Read from our block cache.

        Parameters
        ----------
        start, end : int
            The start and end byte positions.
        start_block_number, end_block_number : int
            The start and end block numbers.
        """
        start_pos = start % self.blocksize
        end_pos = end % self.blocksize

        # kind of pointless to count this as a hit, but it is
        self.hit_count += 1
        if start_block_number == end_block_number:
            block = self._fetch_block_cached(start_block_number)
            return block[start_pos:end_pos]
        else:
            # read from the initial block
            out = [self._fetch_block_cached(start_block_number)[start_pos:]]

            # intermediate blocks
            # Note: it'd be nice to combine these into one big request. However
            # that doesn't play nicely with our LRU cache.
            out.extend(
                map(
                    self._fetch_block_cached,
                    range(start_block_number + 1, end_block_number),
                )
            )

            # final block
            out.append(self._fetch_block_cached(end_block_number)[:end_pos])

            return b"".join(out)

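
# Usage sketch (illustrative, hypothetical helper): after a read, the next
# block is usually submitted to the single worker thread, so
# `_fetch_future_block_number` is populated until a later call joins the
# future. Timing-sensitive details are deliberately not asserted here.
def _demo_background_block_cache() -> None:
    data = bytes(4096)
    cache = BackgroundBlockCache(
        blocksize=512, fetcher=lambda s, e: data[s:e], size=len(data)
    )
    assert cache._fetch(0, 256) == data[:256]
    # block 1 should have been scheduled for background pre-load
    assert cache._fetch_future_block_number in (None, 1)
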
caches: dict[str | None, type[BaseCache]] = {
    # one custom case
    None: BaseCache,
}


def register_cache(cls: type[BaseCache], clobber: bool = False) -> None:
    """'Register' cache implementation.

    Parameters
    ----------
    clobber: bool, optional
        If set to True (default is False), allow overwriting an existing
        entry.

    Raises
    ------
    ValueError
    """
    name = cls.name
    if not clobber and name in caches:
        raise ValueError(f"Cache with name {name!r} is already known: {caches[name]}")
    caches[name] = cls


for c in (
    BaseCache,
    MMapCache,
    BytesCache,
    ReadAheadCache,
    BlockCache,
    FirstChunkCache,
    AllBytes,
    KnownPartsOfAFile,
    BackgroundBlockCache,
):
    register_cache(c)
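
# Usage sketch (illustrative, hypothetical subclass): third-party caches are
# exposed by subclassing BaseCache with a unique ``name`` and registering it;
# re-registering an existing name requires clobber=True.
class _DemoNullCache(BaseCache):
    name: ClassVar[str] = "demo-null"


def _demo_register_cache() -> None:
    register_cache(_DemoNullCache)
    assert caches["demo-null"] is _DemoNullCache
    register_cache(_DemoNullCache, clobber=True)  # without clobber: ValueError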