# staticfiles.py (8.2 KB)

  1. from __future__ import annotations
  2. import errno
  3. import importlib.util
  4. import os
  5. import stat
  6. import typing
  7. from email.utils import parsedate
  8. import anyio
  9. import anyio.to_thread
  10. from starlette._utils import get_route_path
  11. from starlette.datastructures import URL, Headers
  12. from starlette.exceptions import HTTPException
  13. from starlette.responses import FileResponse, RedirectResponse, Response
  14. from starlette.types import Receive, Scope, Send
  15. PathLike = typing.Union[str, "os.PathLike[str]"]
  16. class NotModifiedResponse(Response):
  17. NOT_MODIFIED_HEADERS = (
  18. "cache-control",
  19. "content-location",
  20. "date",
  21. "etag",
  22. "expires",
  23. "vary",
  24. )
  25. def __init__(self, headers: Headers):
  26. super().__init__(
  27. status_code=304,
  28. headers={name: value for name, value in headers.items() if name in self.NOT_MODIFIED_HEADERS},
  29. )
  30. class StaticFiles:
  31. def __init__(
  32. self,
  33. *,
  34. directory: PathLike | None = None,
  35. packages: list[str | tuple[str, str]] | None = None,
  36. html: bool = False,
  37. check_dir: bool = True,
  38. follow_symlink: bool = False,
  39. ) -> None:
  40. self.directory = directory
  41. self.packages = packages
  42. self.all_directories = self.get_directories(directory, packages)
  43. self.html = html
  44. self.config_checked = False
  45. self.follow_symlink = follow_symlink
  46. if check_dir and directory is not None and not os.path.isdir(directory):
  47. raise RuntimeError(f"Directory '{directory}' does not exist")
  48. def get_directories(
  49. self,
  50. directory: PathLike | None = None,
  51. packages: list[str | tuple[str, str]] | None = None,
  52. ) -> list[PathLike]:
  53. """
  54. Given `directory` and `packages` arguments, return a list of all the
  55. directories that should be used for serving static files from.
  56. """
  57. directories = []
  58. if directory is not None:
  59. directories.append(directory)
  60. for package in packages or []:
  61. if isinstance(package, tuple):
  62. package, statics_dir = package
  63. else:
  64. statics_dir = "statics"
  65. spec = importlib.util.find_spec(package)
  66. assert spec is not None, f"Package {package!r} could not be found."
  67. assert spec.origin is not None, f"Package {package!r} could not be found."
  68. package_directory = os.path.normpath(os.path.join(spec.origin, "..", statics_dir))
  69. assert os.path.isdir(
  70. package_directory
  71. ), f"Directory '{statics_dir!r}' in package {package!r} could not be found."
  72. directories.append(package_directory)
  73. return directories
  74. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  75. """
  76. The ASGI entry point.
  77. """
  78. assert scope["type"] == "http"
  79. if not self.config_checked:
  80. await self.check_config()
  81. self.config_checked = True
  82. path = self.get_path(scope)
  83. response = await self.get_response(path, scope)
  84. await response(scope, receive, send)
  85. def get_path(self, scope: Scope) -> str:
  86. """
  87. Given the ASGI scope, return the `path` string to serve up,
  88. with OS specific path separators, and any '..', '.' components removed.
  89. """
  90. route_path = get_route_path(scope)
  91. return os.path.normpath(os.path.join(*route_path.split("/")))
  92. async def get_response(self, path: str, scope: Scope) -> Response:
  93. """
  94. Returns an HTTP response, given the incoming path, method and request headers.
  95. """
  96. if scope["method"] not in ("GET", "HEAD"):
  97. raise HTTPException(status_code=405)
  98. try:
  99. full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, path)
  100. except PermissionError:
  101. raise HTTPException(status_code=401)
  102. except OSError as exc:
  103. # Filename is too long, so it can't be a valid static file.
  104. if exc.errno == errno.ENAMETOOLONG:
  105. raise HTTPException(status_code=404)
  106. raise exc
  107. if stat_result and stat.S_ISREG(stat_result.st_mode):
  108. # We have a static file to serve.
  109. return self.file_response(full_path, stat_result, scope)
  110. elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html:
  111. # We're in HTML mode, and have got a directory URL.
  112. # Check if we have 'index.html' file to serve.
  113. index_path = os.path.join(path, "index.html")
  114. full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, index_path)
  115. if stat_result is not None and stat.S_ISREG(stat_result.st_mode):
  116. if not scope["path"].endswith("/"):
  117. # Directory URLs should redirect to always end in "/".
  118. url = URL(scope=scope)
  119. url = url.replace(path=url.path + "/")
  120. return RedirectResponse(url=url)
  121. return self.file_response(full_path, stat_result, scope)
  122. if self.html:
  123. # Check for '404.html' if we're in HTML mode.
  124. full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, "404.html")
  125. if stat_result and stat.S_ISREG(stat_result.st_mode):
  126. return FileResponse(full_path, stat_result=stat_result, status_code=404)
  127. raise HTTPException(status_code=404)
  128. def lookup_path(self, path: str) -> tuple[str, os.stat_result | None]:
  129. for directory in self.all_directories:
  130. joined_path = os.path.join(directory, path)
  131. if self.follow_symlink:
  132. full_path = os.path.abspath(joined_path)
  133. else:
  134. full_path = os.path.realpath(joined_path)
  135. directory = os.path.realpath(directory)
  136. if os.path.commonpath([full_path, directory]) != directory:
  137. # Don't allow misbehaving clients to break out of the static files
  138. # directory.
  139. continue
  140. try:
  141. return full_path, os.stat(full_path)
  142. except (FileNotFoundError, NotADirectoryError):
  143. continue
  144. return "", None
  145. def file_response(
  146. self,
  147. full_path: PathLike,
  148. stat_result: os.stat_result,
  149. scope: Scope,
  150. status_code: int = 200,
  151. ) -> Response:
  152. request_headers = Headers(scope=scope)
  153. response = FileResponse(full_path, status_code=status_code, stat_result=stat_result)
  154. if self.is_not_modified(response.headers, request_headers):
  155. return NotModifiedResponse(response.headers)
  156. return response
  157. async def check_config(self) -> None:
  158. """
  159. Perform a one-off configuration check that StaticFiles is actually
  160. pointed at a directory, so that we can raise loud errors rather than
  161. just returning 404 responses.
  162. """
  163. if self.directory is None:
  164. return
  165. try:
  166. stat_result = await anyio.to_thread.run_sync(os.stat, self.directory)
  167. except FileNotFoundError:
  168. raise RuntimeError(f"StaticFiles directory '{self.directory}' does not exist.")
  169. if not (stat.S_ISDIR(stat_result.st_mode) or stat.S_ISLNK(stat_result.st_mode)):
  170. raise RuntimeError(f"StaticFiles path '{self.directory}' is not a directory.")
  171. def is_not_modified(self, response_headers: Headers, request_headers: Headers) -> bool:
  172. """
  173. Given the request and response headers, return `True` if an HTTP
  174. "Not Modified" response could be returned instead.
  175. """
  176. try:
  177. if_none_match = request_headers["if-none-match"]
  178. etag = response_headers["etag"]
  179. if etag in [tag.strip(" W/") for tag in if_none_match.split(",")]:
  180. return True
  181. except KeyError:
  182. pass
  183. try:
  184. if_modified_since = parsedate(request_headers["if-modified-since"])
  185. last_modified = parsedate(response_headers["last-modified"])
  186. if if_modified_since is not None and last_modified is not None and if_modified_since >= last_modified:
  187. return True
  188. except KeyError:
  189. pass
  190. return False