trustedhost.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. from __future__ import annotations
  2. import typing
  3. from starlette.datastructures import URL, Headers
  4. from starlette.responses import PlainTextResponse, RedirectResponse, Response
  5. from starlette.types import ASGIApp, Receive, Scope, Send
  6. ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'."
  7. class TrustedHostMiddleware:
  8. def __init__(
  9. self,
  10. app: ASGIApp,
  11. allowed_hosts: typing.Sequence[str] | None = None,
  12. www_redirect: bool = True,
  13. ) -> None:
  14. if allowed_hosts is None:
  15. allowed_hosts = ["*"]
  16. for pattern in allowed_hosts:
  17. assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD
  18. if pattern.startswith("*") and pattern != "*":
  19. assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD
  20. self.app = app
  21. self.allowed_hosts = list(allowed_hosts)
  22. self.allow_any = "*" in allowed_hosts
  23. self.www_redirect = www_redirect
  24. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  25. if self.allow_any or scope["type"] not in (
  26. "http",
  27. "websocket",
  28. ): # pragma: no cover
  29. await self.app(scope, receive, send)
  30. return
  31. headers = Headers(scope=scope)
  32. host = headers.get("host", "").split(":")[0]
  33. is_valid_host = False
  34. found_www_redirect = False
  35. for pattern in self.allowed_hosts:
  36. if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])):
  37. is_valid_host = True
  38. break
  39. elif "www." + host == pattern:
  40. found_www_redirect = True
  41. if is_valid_host:
  42. await self.app(scope, receive, send)
  43. else:
  44. response: Response
  45. if found_www_redirect and self.www_redirect:
  46. url = URL(scope=scope)
  47. redirect_url = url.replace(netloc="www." + url.netloc)
  48. response = RedirectResponse(url=str(redirect_url))
  49. else:
  50. response = PlainTextResponse("Invalid host header", status_code=400)
  51. await response(scope, receive, send)