ParamUtils.h 1.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  #pragma once

  #include <c10/util/ArrayRef.h>
  #include <c10/util/Exception.h>

  #include <cstddef>
  #include <cstdint>
  #include <vector>
  4. namespace at {
  5. namespace native {
  6. template <typename T>
  7. inline std::vector<T> _expand_param_if_needed(
  8. ArrayRef<T> list_param,
  9. const char* param_name,
  10. int64_t expected_dim) {
  11. if (list_param.size() == 1) {
  12. return std::vector<T>(expected_dim, list_param[0]);
  13. } else if ((int64_t)list_param.size() != expected_dim) {
  14. std::ostringstream ss;
  15. ss << "expected " << param_name << " to be a single integer value or a "
  16. << "list of " << expected_dim << " values to match the convolution "
  17. << "dimensions, but got " << param_name << "=" << list_param;
  18. AT_ERROR(ss.str());
  19. } else {
  20. return list_param.vec();
  21. }
  22. }
  23. inline std::vector<int64_t> expand_param_if_needed(
  24. IntArrayRef list_param,
  25. const char* param_name,
  26. int64_t expected_dim) {
  27. return _expand_param_if_needed(list_param, param_name, expected_dim);
  28. }
  29. inline std::vector<c10::SymInt> expand_param_if_needed(
  30. SymIntArrayRef list_param,
  31. const char* param_name,
  32. int64_t expected_dim) {
  33. return _expand_param_if_needed(list_param, param_name, expected_dim);
  34. }
  35. } // namespace native
  36. } // namespace at