_trace.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. from __future__ import annotations
  2. import inspect
  3. import logging
  4. import types
  5. import typing
  6. from ._models import Request
  7. class Trace:
  8. def __init__(
  9. self,
  10. name: str,
  11. logger: logging.Logger,
  12. request: Request | None = None,
  13. kwargs: dict[str, typing.Any] | None = None,
  14. ) -> None:
  15. self.name = name
  16. self.logger = logger
  17. self.trace_extension = (
  18. None if request is None else request.extensions.get("trace")
  19. )
  20. self.debug = self.logger.isEnabledFor(logging.DEBUG)
  21. self.kwargs = kwargs or {}
  22. self.return_value: typing.Any = None
  23. self.should_trace = self.debug or self.trace_extension is not None
  24. self.prefix = self.logger.name.split(".")[-1]
  25. def trace(self, name: str, info: dict[str, typing.Any]) -> None:
  26. if self.trace_extension is not None:
  27. prefix_and_name = f"{self.prefix}.{name}"
  28. ret = self.trace_extension(prefix_and_name, info)
  29. if inspect.iscoroutine(ret): # pragma: no cover
  30. raise TypeError(
  31. "If you are using a synchronous interface, "
  32. "the callback of the `trace` extension should "
  33. "be a normal function instead of an asynchronous function."
  34. )
  35. if self.debug:
  36. if not info or "return_value" in info and info["return_value"] is None:
  37. message = name
  38. else:
  39. args = " ".join([f"{key}={value!r}" for key, value in info.items()])
  40. message = f"{name} {args}"
  41. self.logger.debug(message)
  42. def __enter__(self) -> Trace:
  43. if self.should_trace:
  44. info = self.kwargs
  45. self.trace(f"{self.name}.started", info)
  46. return self
  47. def __exit__(
  48. self,
  49. exc_type: type[BaseException] | None = None,
  50. exc_value: BaseException | None = None,
  51. traceback: types.TracebackType | None = None,
  52. ) -> None:
  53. if self.should_trace:
  54. if exc_value is None:
  55. info = {"return_value": self.return_value}
  56. self.trace(f"{self.name}.complete", info)
  57. else:
  58. info = {"exception": exc_value}
  59. self.trace(f"{self.name}.failed", info)
  60. async def atrace(self, name: str, info: dict[str, typing.Any]) -> None:
  61. if self.trace_extension is not None:
  62. prefix_and_name = f"{self.prefix}.{name}"
  63. coro = self.trace_extension(prefix_and_name, info)
  64. if not inspect.iscoroutine(coro): # pragma: no cover
  65. raise TypeError(
  66. "If you're using an asynchronous interface, "
  67. "the callback of the `trace` extension should "
  68. "be an asynchronous function rather than a normal function."
  69. )
  70. await coro
  71. if self.debug:
  72. if not info or "return_value" in info and info["return_value"] is None:
  73. message = name
  74. else:
  75. args = " ".join([f"{key}={value!r}" for key, value in info.items()])
  76. message = f"{name} {args}"
  77. self.logger.debug(message)
  78. async def __aenter__(self) -> Trace:
  79. if self.should_trace:
  80. info = self.kwargs
  81. await self.atrace(f"{self.name}.started", info)
  82. return self
  83. async def __aexit__(
  84. self,
  85. exc_type: type[BaseException] | None = None,
  86. exc_value: BaseException | None = None,
  87. traceback: types.TracebackType | None = None,
  88. ) -> None:
  89. if self.should_trace:
  90. if exc_value is None:
  91. info = {"return_value": self.return_value}
  92. await self.atrace(f"{self.name}.complete", info)
  93. else:
  94. info = {"exception": exc_value}
  95. await self.atrace(f"{self.name}.failed", info)