// vec.h (1.3 KB)
  #pragma once
  #include <cstring>

  #if defined(CPU_CAPABILITY_AVX512)
  #include <ATen/cpu/vec/vec512/vec512.h>
  #else
  #include <ATen/cpu/vec/vec256/vec256.h>
  #endif
  7. namespace at::vec {
  8. // See Note [CPU_CAPABILITY namespace]
  9. inline namespace CPU_CAPABILITY {
  10. inline Vectorized<bool> convert_to_bool(Vectorized<int8_t> x) {
  11. __at_align__ bool buffer[x.size()];
  12. x.ne(Vectorized<int8_t>(0)).store(buffer);
  13. Vectorized<bool> ret;
  14. static_assert(x.size() == ret.size(), "");
  15. std::memcpy(ret, buffer, ret.size() * sizeof(bool));
  16. return ret;
  17. }
  18. template <>
  19. inline Vectorized<bool> Vectorized<bool>::loadu(const void* ptr) {
  20. // See NOTE [Loading boolean values]
  21. return convert_to_bool(Vectorized<int8_t>::loadu(ptr));
  22. }
  23. template <>
  24. inline Vectorized<bool> Vectorized<bool>::loadu(const void* ptr, int64_t count) {
  25. // See NOTE [Loading boolean values]
  26. return convert_to_bool(Vectorized<int8_t>::loadu(ptr, count));
  27. }
  28. template <typename VT>
  29. struct VecHoldType { using hold_type = typename VT::value_type; };
  30. template <>
  31. struct VecHoldType<Vectorized<BFloat16>> { using hold_type = BFloat16; };
  32. template <>
  33. struct VecHoldType<Vectorized<Half>> {using hold_type = Half; };
  34. template <typename VT>
  35. using vechold_type = typename VecHoldType<VT>::hold_type;
  36. }} // namespace at::vec::CPU_CAPABILITY