hp_naming.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. # Copyright 2020 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import re
  16. class TrialShortNamer:
  17. PREFIX = "hp"
  18. DEFAULTS = {}
  19. NAMING_INFO = None
  20. @classmethod
  21. def set_defaults(cls, prefix, defaults):
  22. cls.PREFIX = prefix
  23. cls.DEFAULTS = defaults
  24. cls.build_naming_info()
  25. @staticmethod
  26. def shortname_for_word(info, word):
  27. if len(word) == 0:
  28. return ""
  29. short_word = None
  30. if any(char.isdigit() for char in word):
  31. raise Exception(f"Parameters should not contain numbers: '{word}' contains a number")
  32. if word in info["short_word"]:
  33. return info["short_word"][word]
  34. for prefix_len in range(1, len(word) + 1):
  35. prefix = word[:prefix_len]
  36. if prefix in info["reverse_short_word"]:
  37. continue
  38. else:
  39. short_word = prefix
  40. break
  41. if short_word is None:
  42. # Paranoid fallback
  43. def int_to_alphabetic(integer):
  44. s = ""
  45. while integer != 0:
  46. s = chr(ord("A") + integer % 10) + s
  47. integer //= 10
  48. return s
  49. i = 0
  50. while True:
  51. sword = word + "#" + int_to_alphabetic(i)
  52. if sword in info["reverse_short_word"]:
  53. continue
  54. else:
  55. short_word = sword
  56. break
  57. info["short_word"][word] = short_word
  58. info["reverse_short_word"][short_word] = word
  59. return short_word
  60. @staticmethod
  61. def shortname_for_key(info, param_name):
  62. words = param_name.split("_")
  63. shortname_parts = [TrialShortNamer.shortname_for_word(info, word) for word in words]
  64. # We try to create a separatorless short name, but if there is a collision we have to fallback
  65. # to a separated short name
  66. separators = ["", "_"]
  67. for separator in separators:
  68. shortname = separator.join(shortname_parts)
  69. if shortname not in info["reverse_short_param"]:
  70. info["short_param"][param_name] = shortname
  71. info["reverse_short_param"][shortname] = param_name
  72. return shortname
  73. return param_name
  74. @staticmethod
  75. def add_new_param_name(info, param_name):
  76. short_name = TrialShortNamer.shortname_for_key(info, param_name)
  77. info["short_param"][param_name] = short_name
  78. info["reverse_short_param"][short_name] = param_name
  79. @classmethod
  80. def build_naming_info(cls):
  81. if cls.NAMING_INFO is not None:
  82. return
  83. info = {
  84. "short_word": {},
  85. "reverse_short_word": {},
  86. "short_param": {},
  87. "reverse_short_param": {},
  88. }
  89. field_keys = list(cls.DEFAULTS.keys())
  90. for k in field_keys:
  91. cls.add_new_param_name(info, k)
  92. cls.NAMING_INFO = info
  93. @classmethod
  94. def shortname(cls, params):
  95. cls.build_naming_info()
  96. assert cls.PREFIX is not None
  97. name = [copy.copy(cls.PREFIX)]
  98. for k, v in params.items():
  99. if k not in cls.DEFAULTS:
  100. raise Exception(f"You should provide a default value for the param name {k} with value {v}")
  101. if v == cls.DEFAULTS[k]:
  102. # The default value is not added to the name
  103. continue
  104. key = cls.NAMING_INFO["short_param"][k]
  105. if isinstance(v, bool):
  106. v = 1 if v else 0
  107. sep = "" if isinstance(v, (int, float)) else "-"
  108. e = f"{key}{sep}{v}"
  109. name.append(e)
  110. return "_".join(name)
  111. @classmethod
  112. def parse_repr(cls, repr):
  113. repr = repr[len(cls.PREFIX) + 1 :]
  114. if repr == "":
  115. values = []
  116. else:
  117. values = repr.split("_")
  118. parameters = {}
  119. for value in values:
  120. if "-" in value:
  121. p_k, p_v = value.split("-")
  122. else:
  123. p_k = re.sub("[0-9.]", "", value)
  124. p_v = float(re.sub("[^0-9.]", "", value))
  125. key = cls.NAMING_INFO["reverse_short_param"][p_k]
  126. parameters[key] = p_v
  127. for k in cls.DEFAULTS:
  128. if k not in parameters:
  129. parameters[k] = cls.DEFAULTS[k]
  130. return parameters