configuration_superpoint.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from typing import List
  15. from ...configuration_utils import PretrainedConfig
  16. from ...utils import logging
# Module-level logger, obtained from the package's logging utility and keyed by this module's name.
logger = logging.get_logger(__name__)
  18. class SuperPointConfig(PretrainedConfig):
  19. r"""
  20. This is the configuration class to store the configuration of a [`SuperPointForKeypointDetection`]. It is used to instantiate a
  21. SuperPoint model according to the specified arguments, defining the model architecture. Instantiating a
  22. configuration with the defaults will yield a similar configuration to that of the SuperPoint
  23. [magic-leap-community/superpoint](https://huggingface.co/magic-leap-community/superpoint) architecture.
  24. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  25. documentation from [`PretrainedConfig`] for more information.
  26. Args:
  27. encoder_hidden_sizes (`List`, *optional*, defaults to `[64, 64, 128, 128]`):
  28. The number of channels in each convolutional layer in the encoder.
  29. decoder_hidden_size (`int`, *optional*, defaults to 256): The hidden size of the decoder.
  30. keypoint_decoder_dim (`int`, *optional*, defaults to 65): The output dimension of the keypoint decoder.
  31. descriptor_decoder_dim (`int`, *optional*, defaults to 256): The output dimension of the descriptor decoder.
  32. keypoint_threshold (`float`, *optional*, defaults to 0.005):
  33. The threshold to use for extracting keypoints.
  34. max_keypoints (`int`, *optional*, defaults to -1):
  35. The maximum number of keypoints to extract. If `-1`, will extract all keypoints.
  36. nms_radius (`int`, *optional*, defaults to 4):
  37. The radius for non-maximum suppression.
  38. border_removal_distance (`int`, *optional*, defaults to 4):
  39. The distance from the border to remove keypoints.
  40. initializer_range (`float`, *optional*, defaults to 0.02):
  41. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  42. Example:
  43. ```python
  44. >>> from transformers import SuperPointConfig, SuperPointForKeypointDetection
  45. >>> # Initializing a SuperPoint superpoint style configuration
  46. >>> configuration = SuperPointConfig()
  47. >>> # Initializing a model from the superpoint style configuration
  48. >>> model = SuperPointForKeypointDetection(configuration)
  49. >>> # Accessing the model configuration
  50. >>> configuration = model.config
  51. ```"""
  52. model_type = "superpoint"
  53. def __init__(
  54. self,
  55. encoder_hidden_sizes: List[int] = [64, 64, 128, 128],
  56. decoder_hidden_size: int = 256,
  57. keypoint_decoder_dim: int = 65,
  58. descriptor_decoder_dim: int = 256,
  59. keypoint_threshold: float = 0.005,
  60. max_keypoints: int = -1,
  61. nms_radius: int = 4,
  62. border_removal_distance: int = 4,
  63. initializer_range=0.02,
  64. **kwargs,
  65. ):
  66. self.encoder_hidden_sizes = encoder_hidden_sizes
  67. self.decoder_hidden_size = decoder_hidden_size
  68. self.keypoint_decoder_dim = keypoint_decoder_dim
  69. self.descriptor_decoder_dim = descriptor_decoder_dim
  70. self.keypoint_threshold = keypoint_threshold
  71. self.max_keypoints = max_keypoints
  72. self.nms_radius = nms_radius
  73. self.border_removal_distance = border_removal_distance
  74. self.initializer_range = initializer_range
  75. super().__init__(**kwargs)