tensor.h 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517
  1. /*
  2. pybind11/eigen/tensor.h: Transparent conversion for Eigen tensors
  3. All rights reserved. Use of this source code is governed by a
  4. BSD-style license that can be found in the LICENSE file.
  5. */
  6. #pragma once
  7. #include "../numpy.h"
  8. #include "common.h"
  9. #if defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER)
  10. static_assert(__GNUC__ > 5, "Eigen Tensor support in pybind11 requires GCC > 5.0");
  11. #endif
  12. // Disable warnings for Eigen
  13. PYBIND11_WARNING_PUSH
  14. PYBIND11_WARNING_DISABLE_MSVC(4554)
  15. PYBIND11_WARNING_DISABLE_MSVC(4127)
  16. #if defined(__MINGW32__)
  17. PYBIND11_WARNING_DISABLE_GCC("-Wmaybe-uninitialized")
  18. #endif
  19. #include <unsupported/Eigen/CXX11/Tensor>
  20. PYBIND11_WARNING_POP
  21. static_assert(EIGEN_VERSION_AT_LEAST(3, 3, 0),
  22. "Eigen Tensor support in pybind11 requires Eigen >= 3.3.0");
  23. PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
  24. PYBIND11_WARNING_DISABLE_MSVC(4127)
  25. PYBIND11_NAMESPACE_BEGIN(detail)
  26. inline bool is_tensor_aligned(const void *data) {
  27. return (reinterpret_cast<std::size_t>(data) % EIGEN_DEFAULT_ALIGN_BYTES) == 0;
  28. }
  29. template <typename T>
  30. constexpr int compute_array_flag_from_tensor() {
  31. static_assert((static_cast<int>(T::Layout) == static_cast<int>(Eigen::RowMajor))
  32. || (static_cast<int>(T::Layout) == static_cast<int>(Eigen::ColMajor)),
  33. "Layout must be row or column major");
  34. return (static_cast<int>(T::Layout) == static_cast<int>(Eigen::RowMajor)) ? array::c_style
  35. : array::f_style;
  36. }
  37. template <typename T>
  38. struct eigen_tensor_helper {};
  39. template <typename Scalar_, int NumIndices_, int Options_, typename IndexType>
  40. struct eigen_tensor_helper<Eigen::Tensor<Scalar_, NumIndices_, Options_, IndexType>> {
  41. using Type = Eigen::Tensor<Scalar_, NumIndices_, Options_, IndexType>;
  42. using ValidType = void;
  43. static Eigen::DSizes<typename Type::Index, Type::NumIndices> get_shape(const Type &f) {
  44. return f.dimensions();
  45. }
  46. static constexpr bool
  47. is_correct_shape(const Eigen::DSizes<typename Type::Index, Type::NumIndices> & /*shape*/) {
  48. return true;
  49. }
  50. template <typename T>
  51. struct helper {};
  52. template <size_t... Is>
  53. struct helper<index_sequence<Is...>> {
  54. static constexpr auto value = ::pybind11::detail::concat(const_name(((void) Is, "?"))...);
  55. };
  56. static constexpr auto dimensions_descriptor
  57. = helper<decltype(make_index_sequence<Type::NumIndices>())>::value;
  58. template <typename... Args>
  59. static Type *alloc(Args &&...args) {
  60. return new Type(std::forward<Args>(args)...);
  61. }
  62. static void free(Type *tensor) { delete tensor; }
  63. };
  64. template <typename Scalar_, typename std::ptrdiff_t... Indices, int Options_, typename IndexType>
  65. struct eigen_tensor_helper<
  66. Eigen::TensorFixedSize<Scalar_, Eigen::Sizes<Indices...>, Options_, IndexType>> {
  67. using Type = Eigen::TensorFixedSize<Scalar_, Eigen::Sizes<Indices...>, Options_, IndexType>;
  68. using ValidType = void;
  69. static constexpr Eigen::DSizes<typename Type::Index, Type::NumIndices>
  70. get_shape(const Type & /*f*/) {
  71. return get_shape();
  72. }
  73. static constexpr Eigen::DSizes<typename Type::Index, Type::NumIndices> get_shape() {
  74. return Eigen::DSizes<typename Type::Index, Type::NumIndices>(Indices...);
  75. }
  76. static bool
  77. is_correct_shape(const Eigen::DSizes<typename Type::Index, Type::NumIndices> &shape) {
  78. return get_shape() == shape;
  79. }
  80. static constexpr auto dimensions_descriptor
  81. = ::pybind11::detail::concat(const_name<Indices>()...);
  82. template <typename... Args>
  83. static Type *alloc(Args &&...args) {
  84. Eigen::aligned_allocator<Type> allocator;
  85. return ::new (allocator.allocate(1)) Type(std::forward<Args>(args)...);
  86. }
  87. static void free(Type *tensor) {
  88. Eigen::aligned_allocator<Type> allocator;
  89. tensor->~Type();
  90. allocator.deallocate(tensor, 1);
  91. }
  92. };
  93. template <typename Type, bool ShowDetails, bool NeedsWriteable = false>
  94. struct get_tensor_descriptor {
  95. static constexpr auto details
  96. = const_name<NeedsWriteable>(", flags.writeable", "")
  97. + const_name<static_cast<int>(Type::Layout) == static_cast<int>(Eigen::RowMajor)>(
  98. ", flags.c_contiguous", ", flags.f_contiguous");
  99. static constexpr auto value
  100. = const_name("numpy.ndarray[") + npy_format_descriptor<typename Type::Scalar>::name
  101. + const_name("[") + eigen_tensor_helper<remove_cv_t<Type>>::dimensions_descriptor
  102. + const_name("]") + const_name<ShowDetails>(details, const_name("")) + const_name("]");
  103. };
  104. // When EIGEN_AVOID_STL_ARRAY is defined, Eigen::DSizes<T, 0> does not have the begin() member
  105. // function. Falling back to a simple loop works around this issue.
  106. //
  107. // We need to disable the type-limits warning for the inner loop when size = 0.
  108. PYBIND11_WARNING_PUSH
  109. PYBIND11_WARNING_DISABLE_GCC("-Wtype-limits")
  110. template <typename T, int size>
  111. std::vector<T> convert_dsizes_to_vector(const Eigen::DSizes<T, size> &arr) {
  112. std::vector<T> result(size);
  113. for (size_t i = 0; i < size; i++) {
  114. result[i] = arr[i];
  115. }
  116. return result;
  117. }
  118. template <typename T, int size>
  119. Eigen::DSizes<T, size> get_shape_for_array(const array &arr) {
  120. Eigen::DSizes<T, size> result;
  121. const T *shape = arr.shape();
  122. for (size_t i = 0; i < size; i++) {
  123. result[i] = shape[i];
  124. }
  125. return result;
  126. }
  127. PYBIND11_WARNING_POP
  128. template <typename Type>
  129. struct type_caster<Type, typename eigen_tensor_helper<Type>::ValidType> {
  130. static_assert(!std::is_pointer<typename Type::Scalar>::value,
  131. PYBIND11_EIGEN_MESSAGE_POINTER_TYPES_ARE_NOT_SUPPORTED);
  132. using Helper = eigen_tensor_helper<Type>;
  133. static constexpr auto temp_name = get_tensor_descriptor<Type, false>::value;
  134. PYBIND11_TYPE_CASTER(Type, temp_name);
  135. bool load(handle src, bool convert) {
  136. if (!convert) {
  137. if (!isinstance<array>(src)) {
  138. return false;
  139. }
  140. array temp = array::ensure(src);
  141. if (!temp) {
  142. return false;
  143. }
  144. if (!temp.dtype().is(dtype::of<typename Type::Scalar>())) {
  145. return false;
  146. }
  147. }
  148. array_t<typename Type::Scalar, compute_array_flag_from_tensor<Type>()> arr(
  149. reinterpret_borrow<object>(src));
  150. if (arr.ndim() != Type::NumIndices) {
  151. return false;
  152. }
  153. auto shape = get_shape_for_array<typename Type::Index, Type::NumIndices>(arr);
  154. if (!Helper::is_correct_shape(shape)) {
  155. return false;
  156. }
  157. #if EIGEN_VERSION_AT_LEAST(3, 4, 0)
  158. auto data_pointer = arr.data();
  159. #else
  160. // Handle Eigen bug
  161. auto data_pointer = const_cast<typename Type::Scalar *>(arr.data());
  162. #endif
  163. if (is_tensor_aligned(arr.data())) {
  164. value = Eigen::TensorMap<const Type, Eigen::Aligned>(data_pointer, shape);
  165. } else {
  166. value = Eigen::TensorMap<const Type>(data_pointer, shape);
  167. }
  168. return true;
  169. }
  170. static handle cast(Type &&src, return_value_policy policy, handle parent) {
  171. if (policy == return_value_policy::reference
  172. || policy == return_value_policy::reference_internal) {
  173. pybind11_fail("Cannot use a reference return value policy for an rvalue");
  174. }
  175. return cast_impl(&src, return_value_policy::move, parent);
  176. }
  177. static handle cast(const Type &&src, return_value_policy policy, handle parent) {
  178. if (policy == return_value_policy::reference
  179. || policy == return_value_policy::reference_internal) {
  180. pybind11_fail("Cannot use a reference return value policy for an rvalue");
  181. }
  182. return cast_impl(&src, return_value_policy::move, parent);
  183. }
  184. static handle cast(Type &src, return_value_policy policy, handle parent) {
  185. if (policy == return_value_policy::automatic
  186. || policy == return_value_policy::automatic_reference) {
  187. policy = return_value_policy::copy;
  188. }
  189. return cast_impl(&src, policy, parent);
  190. }
  191. static handle cast(const Type &src, return_value_policy policy, handle parent) {
  192. if (policy == return_value_policy::automatic
  193. || policy == return_value_policy::automatic_reference) {
  194. policy = return_value_policy::copy;
  195. }
  196. return cast(&src, policy, parent);
  197. }
  198. static handle cast(Type *src, return_value_policy policy, handle parent) {
  199. if (policy == return_value_policy::automatic) {
  200. policy = return_value_policy::take_ownership;
  201. } else if (policy == return_value_policy::automatic_reference) {
  202. policy = return_value_policy::reference;
  203. }
  204. return cast_impl(src, policy, parent);
  205. }
  206. static handle cast(const Type *src, return_value_policy policy, handle parent) {
  207. if (policy == return_value_policy::automatic) {
  208. policy = return_value_policy::take_ownership;
  209. } else if (policy == return_value_policy::automatic_reference) {
  210. policy = return_value_policy::reference;
  211. }
  212. return cast_impl(src, policy, parent);
  213. }
  214. template <typename C>
  215. static handle cast_impl(C *src, return_value_policy policy, handle parent) {
  216. object parent_object;
  217. bool writeable = false;
  218. switch (policy) {
  219. case return_value_policy::move:
  220. if (std::is_const<C>::value) {
  221. pybind11_fail("Cannot move from a constant reference");
  222. }
  223. src = Helper::alloc(std::move(*src));
  224. parent_object
  225. = capsule(src, [](void *ptr) { Helper::free(reinterpret_cast<Type *>(ptr)); });
  226. writeable = true;
  227. break;
  228. case return_value_policy::take_ownership:
  229. if (std::is_const<C>::value) {
  230. // This cast is ugly, and might be UB in some cases, but we don't have an
  231. // alternative here as we must free that memory
  232. Helper::free(const_cast<Type *>(src));
  233. pybind11_fail("Cannot take ownership of a const reference");
  234. }
  235. parent_object
  236. = capsule(src, [](void *ptr) { Helper::free(reinterpret_cast<Type *>(ptr)); });
  237. writeable = true;
  238. break;
  239. case return_value_policy::copy:
  240. writeable = true;
  241. break;
  242. case return_value_policy::reference:
  243. parent_object = none();
  244. writeable = !std::is_const<C>::value;
  245. break;
  246. case return_value_policy::reference_internal:
  247. // Default should do the right thing
  248. if (!parent) {
  249. pybind11_fail("Cannot use reference internal when there is no parent");
  250. }
  251. parent_object = reinterpret_borrow<object>(parent);
  252. writeable = !std::is_const<C>::value;
  253. break;
  254. default:
  255. pybind11_fail("pybind11 bug in eigen.h, please file a bug report");
  256. }
  257. auto result = array_t<typename Type::Scalar, compute_array_flag_from_tensor<Type>()>(
  258. convert_dsizes_to_vector(Helper::get_shape(*src)), src->data(), parent_object);
  259. if (!writeable) {
  260. array_proxy(result.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
  261. }
  262. return result.release();
  263. }
  264. };
  265. template <typename StoragePointerType,
  266. bool needs_writeable,
  267. enable_if_t<!needs_writeable, bool> = true>
  268. StoragePointerType get_array_data_for_type(array &arr) {
  269. #if EIGEN_VERSION_AT_LEAST(3, 4, 0)
  270. return reinterpret_cast<StoragePointerType>(arr.data());
  271. #else
  272. // Handle Eigen bug
  273. return reinterpret_cast<StoragePointerType>(const_cast<void *>(arr.data()));
  274. #endif
  275. }
  276. template <typename StoragePointerType,
  277. bool needs_writeable,
  278. enable_if_t<needs_writeable, bool> = true>
  279. StoragePointerType get_array_data_for_type(array &arr) {
  280. return reinterpret_cast<StoragePointerType>(arr.mutable_data());
  281. }
  282. template <typename T, typename = void>
  283. struct get_storage_pointer_type;
  284. template <typename MapType>
  285. struct get_storage_pointer_type<MapType, void_t<typename MapType::StoragePointerType>> {
  286. using SPT = typename MapType::StoragePointerType;
  287. };
  288. template <typename MapType>
  289. struct get_storage_pointer_type<MapType, void_t<typename MapType::PointerArgType>> {
  290. using SPT = typename MapType::PointerArgType;
  291. };
  292. template <typename Type, int Options>
  293. struct type_caster<Eigen::TensorMap<Type, Options>,
  294. typename eigen_tensor_helper<remove_cv_t<Type>>::ValidType> {
  295. static_assert(!std::is_pointer<typename Type::Scalar>::value,
  296. PYBIND11_EIGEN_MESSAGE_POINTER_TYPES_ARE_NOT_SUPPORTED);
  297. using MapType = Eigen::TensorMap<Type, Options>;
  298. using Helper = eigen_tensor_helper<remove_cv_t<Type>>;
  299. bool load(handle src, bool /*convert*/) {
  300. // Note that we have a lot more checks here as we want to make sure to avoid copies
  301. if (!isinstance<array>(src)) {
  302. return false;
  303. }
  304. auto arr = reinterpret_borrow<array>(src);
  305. if ((arr.flags() & compute_array_flag_from_tensor<Type>()) == 0) {
  306. return false;
  307. }
  308. if (!arr.dtype().is(dtype::of<typename Type::Scalar>())) {
  309. return false;
  310. }
  311. if (arr.ndim() != Type::NumIndices) {
  312. return false;
  313. }
  314. constexpr bool is_aligned = (Options & Eigen::Aligned) != 0;
  315. if (is_aligned && !is_tensor_aligned(arr.data())) {
  316. return false;
  317. }
  318. auto shape = get_shape_for_array<typename Type::Index, Type::NumIndices>(arr);
  319. if (!Helper::is_correct_shape(shape)) {
  320. return false;
  321. }
  322. if (needs_writeable && !arr.writeable()) {
  323. return false;
  324. }
  325. auto result = get_array_data_for_type<typename get_storage_pointer_type<MapType>::SPT,
  326. needs_writeable>(arr);
  327. value.reset(new MapType(std::move(result), std::move(shape)));
  328. return true;
  329. }
  330. static handle cast(MapType &&src, return_value_policy policy, handle parent) {
  331. return cast_impl(&src, policy, parent);
  332. }
  333. static handle cast(const MapType &&src, return_value_policy policy, handle parent) {
  334. return cast_impl(&src, policy, parent);
  335. }
  336. static handle cast(MapType &src, return_value_policy policy, handle parent) {
  337. if (policy == return_value_policy::automatic
  338. || policy == return_value_policy::automatic_reference) {
  339. policy = return_value_policy::copy;
  340. }
  341. return cast_impl(&src, policy, parent);
  342. }
  343. static handle cast(const MapType &src, return_value_policy policy, handle parent) {
  344. if (policy == return_value_policy::automatic
  345. || policy == return_value_policy::automatic_reference) {
  346. policy = return_value_policy::copy;
  347. }
  348. return cast(&src, policy, parent);
  349. }
  350. static handle cast(MapType *src, return_value_policy policy, handle parent) {
  351. if (policy == return_value_policy::automatic) {
  352. policy = return_value_policy::take_ownership;
  353. } else if (policy == return_value_policy::automatic_reference) {
  354. policy = return_value_policy::reference;
  355. }
  356. return cast_impl(src, policy, parent);
  357. }
  358. static handle cast(const MapType *src, return_value_policy policy, handle parent) {
  359. if (policy == return_value_policy::automatic) {
  360. policy = return_value_policy::take_ownership;
  361. } else if (policy == return_value_policy::automatic_reference) {
  362. policy = return_value_policy::reference;
  363. }
  364. return cast_impl(src, policy, parent);
  365. }
  366. template <typename C>
  367. static handle cast_impl(C *src, return_value_policy policy, handle parent) {
  368. object parent_object;
  369. constexpr bool writeable = !std::is_const<C>::value;
  370. switch (policy) {
  371. case return_value_policy::reference:
  372. parent_object = none();
  373. break;
  374. case return_value_policy::reference_internal:
  375. // Default should do the right thing
  376. if (!parent) {
  377. pybind11_fail("Cannot use reference internal when there is no parent");
  378. }
  379. parent_object = reinterpret_borrow<object>(parent);
  380. break;
  381. case return_value_policy::take_ownership:
  382. delete src;
  383. // fallthrough
  384. default:
  385. // move, take_ownership don't make any sense for a ref/map:
  386. pybind11_fail("Invalid return_value_policy for Eigen Map type, must be either "
  387. "reference or reference_internal");
  388. }
  389. auto result = array_t<typename Type::Scalar, compute_array_flag_from_tensor<Type>()>(
  390. convert_dsizes_to_vector(Helper::get_shape(*src)),
  391. src->data(),
  392. std::move(parent_object));
  393. if (!writeable) {
  394. array_proxy(result.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
  395. }
  396. return result.release();
  397. }
  398. #if EIGEN_VERSION_AT_LEAST(3, 4, 0)
  399. static constexpr bool needs_writeable = !std::is_const<typename std::remove_pointer<
  400. typename get_storage_pointer_type<MapType>::SPT>::type>::value;
  401. #else
  402. // Handle Eigen bug
  403. static constexpr bool needs_writeable = !std::is_const<Type>::value;
  404. #endif
  405. protected:
  406. // TODO: Move to std::optional once std::optional has more support
  407. std::unique_ptr<MapType> value;
  408. public:
  409. static constexpr auto name = get_tensor_descriptor<Type, true, needs_writeable>::value;
  410. explicit operator MapType *() { return value.get(); }
  411. explicit operator MapType &() { return *value; }
  412. explicit operator MapType &&() && { return std::move(*value); }
  413. template <typename T_>
  414. using cast_op_type = ::pybind11::detail::movable_cast_op_type<T_>;
  415. };
  416. PYBIND11_NAMESPACE_END(detail)
  417. PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)