code_template.h 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. #pragma once
#include <c10/util/irange.h>
#include <cctype>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
  7. namespace at::jit {
// A template environment is a mapping from template variable names, e.g.,
// identifier (corresponding to $identifier), to their expansions.
//
// This template environment supports storing strings, numbers and lists
// of strings, and can be chained together (so that lookup proceeds in
// the top level environment, and then recurses into a parent
// environment if the key is not found).
  15. struct TemplateEnv {
  16. TemplateEnv() = default;
  17. TemplateEnv(TemplateEnv& parent) : parent(&parent) {}
  18. using string_list = std::vector<std::string>;
  19. // Add a string 'v' to the map at key 'k'.
  20. void s(const std::string& k, const std::string& v) {
  21. strings_[k] = v;
  22. lists_.erase(k);
  23. }
  24. // Add a number 'v' to the map at key 'k'
  25. template <typename T>
  26. void d(const std::string& k, const T& v) {
  27. strings_[k] = c10::to_string(v);
  28. lists_.erase(k);
  29. }
  30. // Retrieve the string representation of the value stored at 'k' from the map.
  31. // Raises an exception if the key is not found.
  32. const std::string& s(const std::string& k) const {
  33. if (strings_.count(k) == 0) {
  34. if (parent) {
  35. return parent->s(k);
  36. }
  37. notFound(k);
  38. }
  39. return strings_.at(k);
  40. }
  41. // Store a list of strings 'v' in the map at 'k'.
  42. void v(const std::string& k, const string_list& v) {
  43. lists_[k] = v;
  44. strings_.erase(k);
  45. }
  46. // Retrieve a list of strings stored at 'k' from the map.
  47. // Raises an exception if the key is not found.
  48. const string_list& v(const std::string& k) const {
  49. if (lists_.count(k) == 0) {
  50. if (parent) {
  51. return parent->v(k);
  52. }
  53. notFound(k);
  54. }
  55. return lists_.at(k);
  56. }
  57. // Test if a string 'k' is a string (as opposed to a list.)
  58. bool keyIsString(const std::string& k) const {
  59. if (strings_.count(k) > 0)
  60. return true;
  61. if (lists_.count(k) > 0)
  62. return false;
  63. if (parent)
  64. return parent->keyIsString(k);
  65. notFound(k);
  66. }
  67. private:
  68. [[noreturn]] void notFound(const std::string& k) const {
  69. std::stringstream ss;
  70. ss << "key not found: " << k;
  71. throw std::logic_error(ss.str());
  72. }
  73. std::unordered_map<std::string, std::string> strings_;
  74. std::unordered_map<std::string, string_list> lists_;
  75. TemplateEnv* parent{nullptr};
  76. };
/*
# Match $identifier or ${identifier} and replace with the value in env.
# If the identifier is preceded only by whitespace on its line
# and its value is a list, then it is treated as a
# block substitution: the elements are emitted one per line, with every
# line of every element indented to the identifier's column.
# If the identifier is on a line starting with non-whitespace and its value
# is a list, then the list is emitted comma separated. ${,foo} will insert a
# comma before the list if the list is not empty and ${foo,} will insert one
# after.
*/
  86. struct CodeTemplate {
  87. /* implicit */ CodeTemplate(std::string t) : template_text(std::move(t)) {}
  88. std::string format(const TemplateEnv& env) const {
  89. std::stringstream out;
  90. size_t pos = 0;
  91. size_t indent = 0;
  92. bool all_whitespace = true;
  93. while (pos < template_text.size()) {
  94. char c = template_text[pos];
  95. if (c == '$') {
  96. std::stringstream kss;
  97. // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  98. bool comma_before;
  99. // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  100. bool comma_after;
  101. size_t new_pos = parseKey(pos, kss, comma_before, comma_after);
  102. std::string k = kss.str();
  103. bool is_string = env.keyIsString(k);
  104. if (all_whitespace) {
  105. if (is_string)
  106. emitStringWithIndents(out, indent, env.s(k));
  107. else
  108. emitLinesIndented(out, indent, env.v(k));
  109. } else {
  110. if (is_string)
  111. out << env.s(k);
  112. else
  113. emitCommaSeparatedList(out, env.v(k), comma_before, comma_after);
  114. }
  115. all_whitespace = false;
  116. pos = new_pos;
  117. } else {
  118. out << c;
  119. if (!isspace(c))
  120. all_whitespace = false;
  121. indent++;
  122. if (c == '\n') {
  123. indent = 0;
  124. all_whitespace = true;
  125. }
  126. pos++;
  127. }
  128. }
  129. return out.str();
  130. }
  131. private:
  132. using string_list = std::vector<std::string>;
  133. char charAt(size_t p) const {
  134. if (p >= template_text.size())
  135. throw std::logic_error("EOS found in key");
  136. return template_text[p];
  137. }
  138. size_t parseKey(
  139. size_t pos,
  140. std::ostream& k,
  141. bool& comma_before,
  142. bool& comma_after) const {
  143. comma_before = false;
  144. comma_after = false;
  145. pos++;
  146. if (charAt(pos) == '{') {
  147. pos++;
  148. if (charAt(pos) == ',') {
  149. comma_before = true;
  150. pos++;
  151. }
  152. pos = parseIdent(pos, k);
  153. if (charAt(pos) == ',') {
  154. comma_after = true;
  155. pos++;
  156. }
  157. if (charAt(pos) != '}')
  158. throw std::logic_error("missing terminating '}'");
  159. pos++;
  160. return pos;
  161. } else {
  162. return parseIdent(pos, k);
  163. }
  164. }
  165. size_t parseIdent(size_t pos, std::ostream& k) const {
  166. while (pos < template_text.size() &&
  167. (isalnum(template_text[pos]) || template_text[pos] == '_')) {
  168. k << template_text[pos];
  169. pos++;
  170. }
  171. return pos;
  172. }
  173. void emitCommaSeparatedList(
  174. std::ostream& out,
  175. const string_list& strings,
  176. bool comma_before,
  177. bool comma_after) const {
  178. if (comma_before && !strings.empty())
  179. out << ", ";
  180. for (const auto i : c10::irange(strings.size())) {
  181. if (i > 0)
  182. out << ", ";
  183. out << strings[i];
  184. }
  185. if (comma_after && !strings.empty())
  186. out << ", ";
  187. }
  188. // These indentation functions follow the convention that they never emit
  189. // leading or trailing newlines when the input string does not have leading
  190. // or trailing newlines. It's the responsibility of the calling function
  191. // to indent correctly in the context.
  192. void emitIndent(std::ostream& out, size_t indent) const {
  193. for (C10_UNUSED const auto i : c10::irange(indent)) {
  194. out << " ";
  195. }
  196. }
  197. void emitStringWithIndents(
  198. std::ostream& out,
  199. size_t indent,
  200. const std::string& str) const {
  201. for (auto c : str) {
  202. out << c;
  203. if (c == '\n') {
  204. emitIndent(out, indent);
  205. }
  206. }
  207. }
  208. void emitLinesIndented(
  209. std::stringstream& out,
  210. size_t indent,
  211. const string_list& strings) const {
  212. for (const auto i : c10::irange(strings.size())) {
  213. if (i > 0)
  214. emitIndent(out, indent);
  215. emitStringWithIndents(out, indent, strings[i]);
  216. if (i + 1 != strings.size())
  217. out << "\n";
  218. }
  219. }
  220. std::string template_text;
  221. };
  222. static inline std::string format(const std::string& fmt, TemplateEnv& env) {
  223. return CodeTemplate(fmt).format(env);
  224. }
  225. } // namespace at::jit