_estimator_html_repr.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423
  1. import html
  2. from contextlib import closing
  3. from inspect import isclass
  4. from io import StringIO
  5. from string import Template
  6. from .. import config_context
  7. class _IDCounter:
  8. """Generate sequential ids with a prefix."""
  9. def __init__(self, prefix):
  10. self.prefix = prefix
  11. self.count = 0
  12. def get_id(self):
  13. self.count += 1
  14. return f"{self.prefix}-{self.count}"
  15. _CONTAINER_ID_COUNTER = _IDCounter("sk-container-id")
  16. _ESTIMATOR_ID_COUNTER = _IDCounter("sk-estimator-id")
  17. class _VisualBlock:
  18. """HTML Representation of Estimator
  19. Parameters
  20. ----------
  21. kind : {'serial', 'parallel', 'single'}
  22. kind of HTML block
  23. estimators : list of estimators or `_VisualBlock`s or a single estimator
  24. If kind != 'single', then `estimators` is a list of
  25. estimators.
  26. If kind == 'single', then `estimators` is a single estimator.
  27. names : list of str, default=None
  28. If kind != 'single', then `names` corresponds to estimators.
  29. If kind == 'single', then `names` is a single string corresponding to
  30. the single estimator.
  31. name_details : list of str, str, or None, default=None
  32. If kind != 'single', then `name_details` corresponds to `names`.
  33. If kind == 'single', then `name_details` is a single string
  34. corresponding to the single estimator.
  35. dash_wrapped : bool, default=True
  36. If true, wrapped HTML element will be wrapped with a dashed border.
  37. Only active when kind != 'single'.
  38. """
  39. def __init__(
  40. self, kind, estimators, *, names=None, name_details=None, dash_wrapped=True
  41. ):
  42. self.kind = kind
  43. self.estimators = estimators
  44. self.dash_wrapped = dash_wrapped
  45. if self.kind in ("parallel", "serial"):
  46. if names is None:
  47. names = (None,) * len(estimators)
  48. if name_details is None:
  49. name_details = (None,) * len(estimators)
  50. self.names = names
  51. self.name_details = name_details
  52. def _sk_visual_block_(self):
  53. return self
  54. def _write_label_html(
  55. out,
  56. name,
  57. name_details,
  58. outer_class="sk-label-container",
  59. inner_class="sk-label",
  60. checked=False,
  61. ):
  62. """Write labeled html with or without a dropdown with named details"""
  63. out.write(f'<div class="{outer_class}"><div class="{inner_class} sk-toggleable">')
  64. name = html.escape(name)
  65. if name_details is not None:
  66. name_details = html.escape(str(name_details))
  67. label_class = "sk-toggleable__label sk-toggleable__label-arrow"
  68. checked_str = "checked" if checked else ""
  69. est_id = _ESTIMATOR_ID_COUNTER.get_id()
  70. out.write(
  71. '<input class="sk-toggleable__control sk-hidden--visually" '
  72. f'id="{est_id}" type="checkbox" {checked_str}>'
  73. f'<label for="{est_id}" class="{label_class}">{name}</label>'
  74. f'<div class="sk-toggleable__content"><pre>{name_details}'
  75. "</pre></div>"
  76. )
  77. else:
  78. out.write(f"<label>{name}</label>")
  79. out.write("</div></div>") # outer_class inner_class
  80. def _get_visual_block(estimator):
  81. """Generate information about how to display an estimator."""
  82. if hasattr(estimator, "_sk_visual_block_"):
  83. try:
  84. return estimator._sk_visual_block_()
  85. except Exception:
  86. return _VisualBlock(
  87. "single",
  88. estimator,
  89. names=estimator.__class__.__name__,
  90. name_details=str(estimator),
  91. )
  92. if isinstance(estimator, str):
  93. return _VisualBlock(
  94. "single", estimator, names=estimator, name_details=estimator
  95. )
  96. elif estimator is None:
  97. return _VisualBlock("single", estimator, names="None", name_details="None")
  98. # check if estimator looks like a meta estimator wraps estimators
  99. if hasattr(estimator, "get_params") and not isclass(estimator):
  100. estimators = [
  101. (key, est)
  102. for key, est in estimator.get_params(deep=False).items()
  103. if hasattr(est, "get_params") and hasattr(est, "fit") and not isclass(est)
  104. ]
  105. if estimators:
  106. return _VisualBlock(
  107. "parallel",
  108. [est for _, est in estimators],
  109. names=[f"{key}: {est.__class__.__name__}" for key, est in estimators],
  110. name_details=[str(est) for _, est in estimators],
  111. )
  112. return _VisualBlock(
  113. "single",
  114. estimator,
  115. names=estimator.__class__.__name__,
  116. name_details=str(estimator),
  117. )
  118. def _write_estimator_html(
  119. out, estimator, estimator_label, estimator_label_details, first_call=False
  120. ):
  121. """Write estimator to html in serial, parallel, or by itself (single)."""
  122. if first_call:
  123. est_block = _get_visual_block(estimator)
  124. else:
  125. with config_context(print_changed_only=True):
  126. est_block = _get_visual_block(estimator)
  127. if est_block.kind in ("serial", "parallel"):
  128. dashed_wrapped = first_call or est_block.dash_wrapped
  129. dash_cls = " sk-dashed-wrapped" if dashed_wrapped else ""
  130. out.write(f'<div class="sk-item{dash_cls}">')
  131. if estimator_label:
  132. _write_label_html(out, estimator_label, estimator_label_details)
  133. kind = est_block.kind
  134. out.write(f'<div class="sk-{kind}">')
  135. est_infos = zip(est_block.estimators, est_block.names, est_block.name_details)
  136. for est, name, name_details in est_infos:
  137. if kind == "serial":
  138. _write_estimator_html(out, est, name, name_details)
  139. else: # parallel
  140. out.write('<div class="sk-parallel-item">')
  141. # wrap element in a serial visualblock
  142. serial_block = _VisualBlock("serial", [est], dash_wrapped=False)
  143. _write_estimator_html(out, serial_block, name, name_details)
  144. out.write("</div>") # sk-parallel-item
  145. out.write("</div></div>")
  146. elif est_block.kind == "single":
  147. _write_label_html(
  148. out,
  149. est_block.names,
  150. est_block.name_details,
  151. outer_class="sk-item",
  152. inner_class="sk-estimator",
  153. checked=first_call,
  154. )
  155. _STYLE = """
  156. #$id {
  157. color: black;
  158. }
  159. #$id pre{
  160. padding: 0;
  161. }
  162. #$id div.sk-toggleable {
  163. background-color: white;
  164. }
  165. #$id label.sk-toggleable__label {
  166. cursor: pointer;
  167. display: block;
  168. width: 100%;
  169. margin-bottom: 0;
  170. padding: 0.3em;
  171. box-sizing: border-box;
  172. text-align: center;
  173. }
  174. #$id label.sk-toggleable__label-arrow:before {
  175. content: "▸";
  176. float: left;
  177. margin-right: 0.25em;
  178. color: #696969;
  179. }
  180. #$id label.sk-toggleable__label-arrow:hover:before {
  181. color: black;
  182. }
  183. #$id div.sk-estimator:hover label.sk-toggleable__label-arrow:before {
  184. color: black;
  185. }
  186. #$id div.sk-toggleable__content {
  187. max-height: 0;
  188. max-width: 0;
  189. overflow: hidden;
  190. text-align: left;
  191. background-color: #f0f8ff;
  192. }
  193. #$id div.sk-toggleable__content pre {
  194. margin: 0.2em;
  195. color: black;
  196. border-radius: 0.25em;
  197. background-color: #f0f8ff;
  198. }
  199. #$id input.sk-toggleable__control:checked~div.sk-toggleable__content {
  200. max-height: 200px;
  201. max-width: 100%;
  202. overflow: auto;
  203. }
  204. #$id input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {
  205. content: "▾";
  206. }
  207. #$id div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {
  208. background-color: #d4ebff;
  209. }
  210. #$id div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {
  211. background-color: #d4ebff;
  212. }
  213. #$id input.sk-hidden--visually {
  214. border: 0;
  215. clip: rect(1px 1px 1px 1px);
  216. clip: rect(1px, 1px, 1px, 1px);
  217. height: 1px;
  218. margin: -1px;
  219. overflow: hidden;
  220. padding: 0;
  221. position: absolute;
  222. width: 1px;
  223. }
  224. #$id div.sk-estimator {
  225. font-family: monospace;
  226. background-color: #f0f8ff;
  227. border: 1px dotted black;
  228. border-radius: 0.25em;
  229. box-sizing: border-box;
  230. margin-bottom: 0.5em;
  231. }
  232. #$id div.sk-estimator:hover {
  233. background-color: #d4ebff;
  234. }
  235. #$id div.sk-parallel-item::after {
  236. content: "";
  237. width: 100%;
  238. border-bottom: 1px solid gray;
  239. flex-grow: 1;
  240. }
  241. #$id div.sk-label:hover label.sk-toggleable__label {
  242. background-color: #d4ebff;
  243. }
  244. #$id div.sk-serial::before {
  245. content: "";
  246. position: absolute;
  247. border-left: 1px solid gray;
  248. box-sizing: border-box;
  249. top: 0;
  250. bottom: 0;
  251. left: 50%;
  252. z-index: 0;
  253. }
  254. #$id div.sk-serial {
  255. display: flex;
  256. flex-direction: column;
  257. align-items: center;
  258. background-color: white;
  259. padding-right: 0.2em;
  260. padding-left: 0.2em;
  261. position: relative;
  262. }
  263. #$id div.sk-item {
  264. position: relative;
  265. z-index: 1;
  266. }
  267. #$id div.sk-parallel {
  268. display: flex;
  269. align-items: stretch;
  270. justify-content: center;
  271. background-color: white;
  272. position: relative;
  273. }
  274. #$id div.sk-item::before, #$id div.sk-parallel-item::before {
  275. content: "";
  276. position: absolute;
  277. border-left: 1px solid gray;
  278. box-sizing: border-box;
  279. top: 0;
  280. bottom: 0;
  281. left: 50%;
  282. z-index: -1;
  283. }
  284. #$id div.sk-parallel-item {
  285. display: flex;
  286. flex-direction: column;
  287. z-index: 1;
  288. position: relative;
  289. background-color: white;
  290. }
  291. #$id div.sk-parallel-item:first-child::after {
  292. align-self: flex-end;
  293. width: 50%;
  294. }
  295. #$id div.sk-parallel-item:last-child::after {
  296. align-self: flex-start;
  297. width: 50%;
  298. }
  299. #$id div.sk-parallel-item:only-child::after {
  300. width: 0;
  301. }
  302. #$id div.sk-dashed-wrapped {
  303. border: 1px dashed gray;
  304. margin: 0 0.4em 0.5em 0.4em;
  305. box-sizing: border-box;
  306. padding-bottom: 0.4em;
  307. background-color: white;
  308. }
  309. #$id div.sk-label label {
  310. font-family: monospace;
  311. font-weight: bold;
  312. display: inline-block;
  313. line-height: 1.2em;
  314. }
  315. #$id div.sk-label-container {
  316. text-align: center;
  317. }
  318. #$id div.sk-container {
  319. /* jupyter's `normalize.less` sets `[hidden] { display: none; }`
  320. but bootstrap.min.css set `[hidden] { display: none !important; }`
  321. so we also need the `!important` here to be able to override the
  322. default hidden behavior on the sphinx rendered scikit-learn.org.
  323. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */
  324. display: inline-block !important;
  325. position: relative;
  326. }
  327. #$id div.sk-text-repr-fallback {
  328. display: none;
  329. }
  330. """.replace(" ", "").replace("\n", "") # noqa
  331. def estimator_html_repr(estimator):
  332. """Build a HTML representation of an estimator.
  333. Read more in the :ref:`User Guide <visualizing_composite_estimators>`.
  334. Parameters
  335. ----------
  336. estimator : estimator object
  337. The estimator to visualize.
  338. Returns
  339. -------
  340. html: str
  341. HTML representation of estimator.
  342. """
  343. with closing(StringIO()) as out:
  344. container_id = _CONTAINER_ID_COUNTER.get_id()
  345. style_template = Template(_STYLE)
  346. style_with_id = style_template.substitute(id=container_id)
  347. estimator_str = str(estimator)
  348. # The fallback message is shown by default and loading the CSS sets
  349. # div.sk-text-repr-fallback to display: none to hide the fallback message.
  350. #
  351. # If the notebook is trusted, the CSS is loaded which hides the fallback
  352. # message. If the notebook is not trusted, then the CSS is not loaded and the
  353. # fallback message is shown by default.
  354. #
  355. # The reverse logic applies to HTML repr div.sk-container.
  356. # div.sk-container is hidden by default and the loading the CSS displays it.
  357. fallback_msg = (
  358. "In a Jupyter environment, please rerun this cell to show the HTML"
  359. " representation or trust the notebook. <br />On GitHub, the"
  360. " HTML representation is unable to render, please try loading this page"
  361. " with nbviewer.org."
  362. )
  363. out.write(
  364. f"<style>{style_with_id}</style>"
  365. f'<div id="{container_id}" class="sk-top-container">'
  366. '<div class="sk-text-repr-fallback">'
  367. f"<pre>{html.escape(estimator_str)}</pre><b>{fallback_msg}</b>"
  368. "</div>"
  369. '<div class="sk-container" hidden>'
  370. )
  371. _write_estimator_html(
  372. out,
  373. estimator,
  374. estimator.__class__.__name__,
  375. estimator_str,
  376. first_call=True,
  377. )
  378. out.write("</div></div>")
  379. html_output = out.getvalue()
  380. return html_output