token_classification.py
import types
import warnings
from typing import List, Optional, Tuple, Union

import numpy as np

from ..models.bert.tokenization_bert import BasicTokenizer
from ..utils import (
    ExplicitEnum,
    add_end_docstrings,
    is_tf_available,
    is_torch_available,
)
from .base import ArgumentHandler, ChunkPipeline, Dataset, build_pipeline_init_args


if is_tf_available():
    import tensorflow as tf

    from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES

if is_torch_available():
    import torch

    from ..models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES


class TokenClassificationArgumentHandler(ArgumentHandler):
    """
    Handles arguments for token classification.
    """

    def __call__(self, inputs: Union[str, List[str]], **kwargs):
        if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0:
            inputs = list(inputs)
            batch_size = len(inputs)
        elif isinstance(inputs, str):
            inputs = [inputs]
            batch_size = 1
        elif Dataset is not None and isinstance(inputs, Dataset) or isinstance(inputs, types.GeneratorType):
            return inputs, None
        else:
            raise ValueError("At least one input is required.")

        offset_mapping = kwargs.get("offset_mapping")
        if offset_mapping:
            if isinstance(offset_mapping, list) and isinstance(offset_mapping[0], tuple):
                offset_mapping = [offset_mapping]
            if len(offset_mapping) != batch_size:
                raise ValueError("offset_mapping should have the same batch size as the input")
        return inputs, offset_mapping


class AggregationStrategy(ExplicitEnum):
    """All the valid aggregation strategies for TokenClassificationPipeline"""

    NONE = "none"
    SIMPLE = "simple"
    FIRST = "first"
    AVERAGE = "average"
    MAX = "max"


@add_end_docstrings(
    build_pipeline_init_args(has_tokenizer=True),
    r"""
        ignore_labels (`List[str]`, defaults to `["O"]`):
            A list of labels to ignore.
        grouped_entities (`bool`, *optional*, defaults to `False`):
            DEPRECATED, use `aggregation_strategy` instead. Whether or not to group the tokens corresponding to the
            same entity together in the predictions.
        stride (`int`, *optional*):
            If stride is provided, the pipeline is applied on all the text. The text is split into chunks of size
            model_max_length. Works only with fast tokenizers and `aggregation_strategy` different from `NONE`. The
            value of this argument defines the number of overlapping tokens between chunks. In other words, the model
            will shift forward by `tokenizer.model_max_length - stride` tokens each step.
        aggregation_strategy (`str`, *optional*, defaults to `"none"`):
            The strategy to fuse (or not) tokens based on the model prediction.

                - "none" : Will not do any aggregation and simply return the raw results from the model.
                - "simple" : Will attempt to group entities following the default schema. (A, B-TAG), (B, I-TAG), (C,
                  I-TAG), (D, B-TAG2), (E, B-TAG2) will end up being [{"word": "ABC", "entity": "TAG"}, {"word": "D",
                  "entity": "TAG2"}, {"word": "E", "entity": "TAG2"}]. Notice that two consecutive B tags will end up
                  as different entities. On word-based languages, we might end up splitting words undesirably: imagine
                  Microsoft being tagged as [{"word": "Micro", "entity": "ENTERPRISE"}, {"word": "soft", "entity":
                  "NAME"}]. Look at the FIRST, MAX and AVERAGE strategies for ways to mitigate this and disambiguate
                  words (on languages that support that meaning, which is basically tokens separated by a space).
                  These mitigations will only work on real words; "New york" might still be tagged with two different
                  entities.
                - "first" : (works only on word based models) Will use the `SIMPLE` strategy except that words cannot
                  end up with different tags. Words will simply use the tag of the first token of the word when there
                  is ambiguity.
                - "average" : (works only on word based models) Will use the `SIMPLE` strategy except that words
                  cannot end up with different tags. Scores will be averaged first across tokens, and then the maximum
                  label is applied.
                - "max" : (works only on word based models) Will use the `SIMPLE` strategy except that words cannot
                  end up with different tags. The word entity will simply be the token with the maximum score.""",
)
class TokenClassificationPipeline(ChunkPipeline):
    """
    Named Entity Recognition pipeline using any `ModelForTokenClassification`. See the [named entity recognition
    examples](../task_summary#named-entity-recognition) for more information.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> token_classifier = pipeline(model="Jean-Baptiste/camembert-ner", aggregation_strategy="simple")
    >>> sentence = "Je m'appelle jean-baptiste et je vis à montréal"
    >>> tokens = token_classifier(sentence)
    >>> tokens
    [{'entity_group': 'PER', 'score': 0.9931, 'word': 'jean-baptiste', 'start': 12, 'end': 26}, {'entity_group': 'LOC', 'score': 0.998, 'word': 'montréal', 'start': 38, 'end': 47}]

    >>> token = tokens[0]
    >>> # Start and end provide an easy way to highlight words in the original text.
    >>> sentence[token["start"] : token["end"]]
    ' jean-baptiste'

    >>> # Some models use the same idea to do part of speech.
    >>> syntaxer = pipeline(model="vblagoje/bert-english-uncased-finetuned-pos", aggregation_strategy="simple")
    >>> syntaxer("My name is Sarah and I live in London")
    [{'entity_group': 'PRON', 'score': 0.999, 'word': 'my', 'start': 0, 'end': 2}, {'entity_group': 'NOUN', 'score': 0.997, 'word': 'name', 'start': 3, 'end': 7}, {'entity_group': 'AUX', 'score': 0.994, 'word': 'is', 'start': 8, 'end': 10}, {'entity_group': 'PROPN', 'score': 0.999, 'word': 'sarah', 'start': 11, 'end': 16}, {'entity_group': 'CCONJ', 'score': 0.999, 'word': 'and', 'start': 17, 'end': 20}, {'entity_group': 'PRON', 'score': 0.999, 'word': 'i', 'start': 21, 'end': 22}, {'entity_group': 'VERB', 'score': 0.998, 'word': 'live', 'start': 23, 'end': 27}, {'entity_group': 'ADP', 'score': 0.999, 'word': 'in', 'start': 28, 'end': 30}, {'entity_group': 'PROPN', 'score': 0.999, 'word': 'london', 'start': 31, 'end': 37}]
    ```
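
    A further, illustrative sketch (outputs are omitted here because the exact predictions and scores depend on the
    model checkpoint): the same pipeline can return raw, per-token predictions with `aggregation_strategy="none"`,
    and long inputs can be processed in overlapping chunks with `stride` (fast tokenizer required, together with an
    aggregation strategy other than `"none"`):

    ```python
    >>> # Raw per-token predictions: one dict per token, including its `index` in the sequence.
    >>> raw_classifier = pipeline(model="Jean-Baptiste/camembert-ner", aggregation_strategy="none")
    >>> raw_tokens = raw_classifier("Je m'appelle jean-baptiste et je vis à montréal")

    >>> # Overlapping chunks: `stride` is the number of tokens shared between consecutive chunks.
    >>> chunked_classifier = pipeline(
    ...     model="Jean-Baptiste/camembert-ner", aggregation_strategy="simple", stride=64
    ... )
    >>> entities = chunked_classifier("Je m'appelle jean-baptiste et je vis à montréal")
    ```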

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This token recognition pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"ner"` (for predicting the classes of tokens in a sequence: person, organisation, location or miscellaneous).

    The models that this pipeline can use are models that have been fine-tuned on a token classification task. See the
    up-to-date list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=token-classification).
    """

    default_input_names = "sequences"

    def __init__(self, args_parser=TokenClassificationArgumentHandler(), *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.check_model_type(
            TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
            if self.framework == "tf"
            else MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
        )

        self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
        self._args_parser = args_parser

    def _sanitize_parameters(
        self,
        ignore_labels=None,
        grouped_entities: Optional[bool] = None,
        ignore_subwords: Optional[bool] = None,
        aggregation_strategy: Optional[AggregationStrategy] = None,
        offset_mapping: Optional[List[Tuple[int, int]]] = None,
        stride: Optional[int] = None,
    ):
        preprocess_params = {}
        if offset_mapping is not None:
            preprocess_params["offset_mapping"] = offset_mapping

        postprocess_params = {}
        if grouped_entities is not None or ignore_subwords is not None:
            if grouped_entities and ignore_subwords:
                aggregation_strategy = AggregationStrategy.FIRST
            elif grouped_entities and not ignore_subwords:
                aggregation_strategy = AggregationStrategy.SIMPLE
            else:
                aggregation_strategy = AggregationStrategy.NONE

            if grouped_entities is not None:
                warnings.warn(
                    "`grouped_entities` is deprecated and will be removed in version v5.0.0, defaulted to"
                    f' `aggregation_strategy="{aggregation_strategy}"` instead.'
                )
            if ignore_subwords is not None:
                warnings.warn(
                    "`ignore_subwords` is deprecated and will be removed in version v5.0.0, defaulted to"
                    f' `aggregation_strategy="{aggregation_strategy}"` instead.'
                )

        if aggregation_strategy is not None:
            if isinstance(aggregation_strategy, str):
                aggregation_strategy = AggregationStrategy[aggregation_strategy.upper()]
            if (
                aggregation_strategy
                in {AggregationStrategy.FIRST, AggregationStrategy.MAX, AggregationStrategy.AVERAGE}
                and not self.tokenizer.is_fast
            ):
                raise ValueError(
                    "Slow tokenizers cannot handle subwords. Please set the `aggregation_strategy` option"
                    ' to `"simple"` or use a fast tokenizer.'
                )
            postprocess_params["aggregation_strategy"] = aggregation_strategy
        if ignore_labels is not None:
            postprocess_params["ignore_labels"] = ignore_labels
        if stride is not None:
            if stride >= self.tokenizer.model_max_length:
                raise ValueError(
                    "`stride` must be less than `tokenizer.model_max_length` (or even lower if the tokenizer adds special tokens)"
                )
            if aggregation_strategy == AggregationStrategy.NONE:
                raise ValueError(
                    "`stride` was provided to process all the text but `aggregation_strategy="
                    f'"{aggregation_strategy}"`, please select another one instead.'
                )
            else:
                if self.tokenizer.is_fast:
                    tokenizer_params = {
                        "return_overflowing_tokens": True,
                        "padding": True,
                        "stride": stride,
                    }
                    preprocess_params["tokenizer_params"] = tokenizer_params
                else:
                    raise ValueError(
                        "`stride` was provided to process all the text but you're using a slow tokenizer."
                        " Please use a fast tokenizer."
                    )
        return preprocess_params, {}, postprocess_params

    def __call__(self, inputs: Union[str, List[str]], **kwargs):
        """
        Classify each token of the text(s) given as inputs.

        Args:
            inputs (`str` or `List[str]`):
                One or several texts (or one list of texts) for token classification.

        Return:
            A list or a list of lists of `dict`: Each result comes as a list of dictionaries (one for each token in
            the corresponding input, or each entity if this pipeline was instantiated with an `aggregation_strategy`)
            with the following keys:

            - **word** (`str`) -- The token/word classified. This is obtained by decoding the selected tokens. If you
              want to have the exact string in the original sentence, use `start` and `end`.
            - **score** (`float`) -- The corresponding probability for `entity`.
            - **entity** (`str`) -- The entity predicted for that token/word (it is named *entity_group* when
              *aggregation_strategy* is not `"none"`).
            - **index** (`int`, only present when `aggregation_strategy="none"`) -- The index of the corresponding
              token in the sentence.
            - **start** (`int`, *optional*) -- The index of the start of the corresponding entity in the sentence.
              Only exists if the offsets are available within the tokenizer.
            - **end** (`int`, *optional*) -- The index of the end of the corresponding entity in the sentence. Only
              exists if the offsets are available within the tokenizer.
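
        As a small usage sketch (model name re-used from the class-level example; the entity dicts themselves are not
        shown because they depend on the checkpoint): passing a list of texts returns one list of dicts per text.

        ```python
        >>> from transformers import pipeline

        >>> token_classifier = pipeline(model="Jean-Baptiste/camembert-ner", aggregation_strategy="simple")
        >>> results = token_classifier(["Je m'appelle jean-baptiste", "Je vis à montréal"])
        >>> len(results)  # one list of entities per input text
        2
        ```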
  209. """
  210. _inputs, offset_mapping = self._args_parser(inputs, **kwargs)
  211. if offset_mapping:
  212. kwargs["offset_mapping"] = offset_mapping
  213. return super().__call__(inputs, **kwargs)

    def preprocess(self, sentence, offset_mapping=None, **preprocess_params):
        tokenizer_params = preprocess_params.pop("tokenizer_params", {})
        truncation = True if self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0 else False
        inputs = self.tokenizer(
            sentence,
            return_tensors=self.framework,
            truncation=truncation,
            return_special_tokens_mask=True,
            return_offsets_mapping=self.tokenizer.is_fast,
            **tokenizer_params,
        )
        inputs.pop("overflow_to_sample_mapping", None)
        num_chunks = len(inputs["input_ids"])

        for i in range(num_chunks):
            if self.framework == "tf":
                model_inputs = {k: tf.expand_dims(v[i], 0) for k, v in inputs.items()}
            else:
                model_inputs = {k: v[i].unsqueeze(0) for k, v in inputs.items()}
            if offset_mapping is not None:
                model_inputs["offset_mapping"] = offset_mapping
            model_inputs["sentence"] = sentence if i == 0 else None
            model_inputs["is_last"] = i == num_chunks - 1

            yield model_inputs

    def _forward(self, model_inputs):
        # Forward
        special_tokens_mask = model_inputs.pop("special_tokens_mask")
        offset_mapping = model_inputs.pop("offset_mapping", None)
        sentence = model_inputs.pop("sentence")
        is_last = model_inputs.pop("is_last")
        if self.framework == "tf":
            logits = self.model(**model_inputs)[0]
        else:
            output = self.model(**model_inputs)
            logits = output["logits"] if isinstance(output, dict) else output[0]

        return {
            "logits": logits,
            "special_tokens_mask": special_tokens_mask,
            "offset_mapping": offset_mapping,
            "sentence": sentence,
            "is_last": is_last,
            **model_inputs,
        }

    def postprocess(self, all_outputs, aggregation_strategy=AggregationStrategy.NONE, ignore_labels=None):
        if ignore_labels is None:
            ignore_labels = ["O"]
        all_entities = []
        for model_outputs in all_outputs:
            if self.framework == "pt" and model_outputs["logits"][0].dtype in (torch.bfloat16, torch.float16):
                logits = model_outputs["logits"][0].to(torch.float32).numpy()
            else:
                logits = model_outputs["logits"][0].numpy()
            sentence = all_outputs[0]["sentence"]
            input_ids = model_outputs["input_ids"][0]
            offset_mapping = (
                model_outputs["offset_mapping"][0] if model_outputs["offset_mapping"] is not None else None
            )
            special_tokens_mask = model_outputs["special_tokens_mask"][0].numpy()
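            # Numerically stable softmax over the label dimension (subtract the per-token max before exponentiating).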
            maxes = np.max(logits, axis=-1, keepdims=True)
            shifted_exp = np.exp(logits - maxes)
            scores = shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)

            if self.framework == "tf":
                input_ids = input_ids.numpy()
                offset_mapping = offset_mapping.numpy() if offset_mapping is not None else None

            pre_entities = self.gather_pre_entities(
                sentence, input_ids, scores, offset_mapping, special_tokens_mask, aggregation_strategy
            )
            grouped_entities = self.aggregate(pre_entities, aggregation_strategy)
            # Filter anything that is in self.ignore_labels
            entities = [
                entity
                for entity in grouped_entities
                if entity.get("entity", None) not in ignore_labels
                and entity.get("entity_group", None) not in ignore_labels
            ]
            all_entities.extend(entities)
        num_chunks = len(all_outputs)
        if num_chunks > 1:
            all_entities = self.aggregate_overlapping_entities(all_entities)
        return all_entities

    def aggregate_overlapping_entities(self, entities):
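        # When `stride` is used, overlapping chunks can predict overlapping spans over the same text:
        # keep the longer entity, breaking ties by the higher score.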
        if len(entities) == 0:
            return entities
        entities = sorted(entities, key=lambda x: x["start"])
        aggregated_entities = []
        previous_entity = entities[0]
        for entity in entities:
            if previous_entity["start"] <= entity["start"] < previous_entity["end"]:
                current_length = entity["end"] - entity["start"]
                previous_length = previous_entity["end"] - previous_entity["start"]
                if current_length > previous_length:
                    previous_entity = entity
                elif current_length == previous_length and entity["score"] > previous_entity["score"]:
                    previous_entity = entity
            else:
                aggregated_entities.append(previous_entity)
                previous_entity = entity
        aggregated_entities.append(previous_entity)
        return aggregated_entities

    def gather_pre_entities(
        self,
        sentence: str,
        input_ids: np.ndarray,
        scores: np.ndarray,
        offset_mapping: Optional[List[Tuple[int, int]]],
        special_tokens_mask: np.ndarray,
        aggregation_strategy: AggregationStrategy,
    ) -> List[dict]:
        """Fuse various numpy arrays into dicts with all the information needed for aggregation"""
        pre_entities = []
        for idx, token_scores in enumerate(scores):
            # Filter special_tokens
            if special_tokens_mask[idx]:
                continue

            word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))
            if offset_mapping is not None:
                start_ind, end_ind = offset_mapping[idx]
                if not isinstance(start_ind, int):
                    if self.framework == "pt":
                        start_ind = start_ind.item()
                        end_ind = end_ind.item()
                word_ref = sentence[start_ind:end_ind]
                if getattr(self.tokenizer, "_tokenizer", None) and getattr(
                    self.tokenizer._tokenizer.model, "continuing_subword_prefix", None
                ):
                    # This is a BPE, word-aware tokenizer, so there is a correct way to fuse tokens.
                    is_subword = len(word) != len(word_ref)
                else:
                    # This is a fallback heuristic. It will most likely fail on any kind of text + punctuation
                    # mixture that gets considered a "word". Non word-aware models cannot do better than this,
                    # unfortunately.
                    if aggregation_strategy in {
                        AggregationStrategy.FIRST,
                        AggregationStrategy.AVERAGE,
                        AggregationStrategy.MAX,
                    }:
                        warnings.warn(
                            "Tokenizer does not support real words, using fallback heuristic",
                            UserWarning,
                        )
                    is_subword = start_ind > 0 and " " not in sentence[start_ind - 1 : start_ind + 1]

                if int(input_ids[idx]) == self.tokenizer.unk_token_id:
                    word = word_ref
                    is_subword = False
            else:
                start_ind = None
                end_ind = None
                is_subword = False

            pre_entity = {
                "word": word,
                "scores": token_scores,
                "start": start_ind,
                "end": end_ind,
                "index": idx,
                "is_subword": is_subword,
            }
            pre_entities.append(pre_entity)
        return pre_entities

    def aggregate(self, pre_entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]:
        if aggregation_strategy in {AggregationStrategy.NONE, AggregationStrategy.SIMPLE}:
            entities = []
            for pre_entity in pre_entities:
                entity_idx = pre_entity["scores"].argmax()
                score = pre_entity["scores"][entity_idx]
                entity = {
                    "entity": self.model.config.id2label[entity_idx],
                    "score": score,
                    "index": pre_entity["index"],
                    "word": pre_entity["word"],
                    "start": pre_entity["start"],
                    "end": pre_entity["end"],
                }
                entities.append(entity)
        else:
            entities = self.aggregate_words(pre_entities, aggregation_strategy)

        if aggregation_strategy == AggregationStrategy.NONE:
            return entities
        return self.group_entities(entities)

    def aggregate_word(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> dict:
        word = self.tokenizer.convert_tokens_to_string([entity["word"] for entity in entities])
        if aggregation_strategy == AggregationStrategy.FIRST:
            scores = entities[0]["scores"]
            idx = scores.argmax()
            score = scores[idx]
            entity = self.model.config.id2label[idx]
        elif aggregation_strategy == AggregationStrategy.MAX:
            max_entity = max(entities, key=lambda entity: entity["scores"].max())
            scores = max_entity["scores"]
            idx = scores.argmax()
            score = scores[idx]
            entity = self.model.config.id2label[idx]
        elif aggregation_strategy == AggregationStrategy.AVERAGE:
            scores = np.stack([entity["scores"] for entity in entities])
            average_scores = np.nanmean(scores, axis=0)
            entity_idx = average_scores.argmax()
            entity = self.model.config.id2label[entity_idx]
            score = average_scores[entity_idx]
        else:
            raise ValueError("Invalid aggregation_strategy")
        new_entity = {
            "entity": entity,
            "score": score,
            "word": word,
            "start": entities[0]["start"],
            "end": entities[-1]["end"],
        }
        return new_entity

    def aggregate_words(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]:
        """
        Override tokens from a given word that disagree to force agreement on word boundaries.

        Example: micro|soft| com|pany| B-ENT I-NAME I-ENT I-ENT will be rewritten with first strategy as microsoft|
        company| B-ENT I-ENT
        """
        if aggregation_strategy in {
            AggregationStrategy.NONE,
            AggregationStrategy.SIMPLE,
        }:
            raise ValueError("NONE and SIMPLE strategies are invalid for word aggregation")

        word_entities = []
        word_group = None
        for entity in entities:
            if word_group is None:
                word_group = [entity]
            elif entity["is_subword"]:
                word_group.append(entity)
            else:
                word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
                word_group = [entity]
        # Last item
        if word_group is not None:
            word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
        return word_entities

    def group_sub_entities(self, entities: List[dict]) -> dict:
        """
        Group together the adjacent tokens with the same entity predicted.

        Args:
            entities (`List[dict]`): The entities predicted by the pipeline.
        """
        # Get the first entity in the entity group
        entity = entities[0]["entity"].split("-", 1)[-1]
        scores = np.nanmean([entity["score"] for entity in entities])
        tokens = [entity["word"] for entity in entities]

        entity_group = {
            "entity_group": entity,
            "score": np.mean(scores),
            "word": self.tokenizer.convert_tokens_to_string(tokens),
            "start": entities[0]["start"],
            "end": entities[-1]["end"],
        }
        return entity_group

    def get_tag(self, entity_name: str) -> Tuple[str, str]:
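        # Split an IOB-style label into its prefix and tag, e.g. "B-PER" -> ("B", "PER");
        # labels without a prefix default to a continuation ("I") tag.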
        if entity_name.startswith("B-"):
            bi = "B"
            tag = entity_name[2:]
        elif entity_name.startswith("I-"):
            bi = "I"
            tag = entity_name[2:]
        else:
            # It's not in B-, I- format
            # Default to I- for continuation.
            bi = "I"
            tag = entity_name
        return bi, tag

    def group_entities(self, entities: List[dict]) -> List[dict]:
        """
        Find and group together the adjacent tokens with the same entity predicted.

        Args:
            entities (`List[dict]`): The entities predicted by the pipeline.
        """

        entity_groups = []
        entity_group_disagg = []

        for entity in entities:
            if not entity_group_disagg:
                entity_group_disagg.append(entity)
                continue

            # If the current entity is similar and adjacent to the previous entity,
            # append it to the disaggregated entity group.
            # The split is meant to account for the "B" and "I" prefixes.
            # Two B-type entities should not be merged.
            bi, tag = self.get_tag(entity["entity"])
            last_bi, last_tag = self.get_tag(entity_group_disagg[-1]["entity"])

            if tag == last_tag and bi != "B":
                # Modify subword type to be previous_type
                entity_group_disagg.append(entity)
            else:
                # If the current entity is different from the previous entity,
                # aggregate the disaggregated entity group
                entity_groups.append(self.group_sub_entities(entity_group_disagg))
                entity_group_disagg = [entity]
        if entity_group_disagg:
            # it's the last entity, add it to the entity groups
            entity_groups.append(self.group_sub_entities(entity_group_disagg))

        return entity_groups


NerPipeline = TokenClassificationPipeline