LossMulti.h 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. #pragma once
  2. #include <ATen/core/Tensor.h>
  3. #include <ATen/AccumulateType.h>
  4. #include <ATen/Dispatch.h>
  5. #include <ATen/TensorUtils.h>
  6. namespace at::native {
  7. namespace {
  8. static C10_UNUSED void multilabel_margin_loss_shape_check(
  9. int64_t& nframe,
  10. int64_t& dim,
  11. const int64_t& ndims,
  12. const Tensor& input,
  13. const Tensor& target) {
  14. TORCH_CHECK(
  15. (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
  16. "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
  17. input.sizes());
  18. if (ndims <= 1) {
  19. nframe = 1;
  20. dim = ndims == 0 ? 1 : input.size(0);
  21. TORCH_CHECK(
  22. target.dim() <= 1 && target.numel() == dim,
  23. "inconsistent target size: ", target.sizes(), " for input of size: ",
  24. input.sizes());
  25. } else {
  26. nframe = input.size(0);
  27. dim = input.size(1);
  28. TORCH_CHECK(
  29. target.dim() == 2 && target.size(0) == nframe &&
  30. target.size(1) == dim,
  31. "inconsistent target size: ", target.sizes(), " for input of size: ",
  32. input.sizes());
  33. }
  34. }
  35. static C10_UNUSED void multi_margin_loss_shape_check(
  36. int64_t& nframe,
  37. int64_t& dim,
  38. const int64_t& ndims,
  39. const Tensor& input,
  40. const Tensor& target,
  41. const std::optional<Tensor>& weight) {
  42. TORCH_CHECK(
  43. (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
  44. "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
  45. input.sizes());
  46. if (ndims <= 1) {
  47. nframe = 1;
  48. dim = ndims == 0 ? 1 : input.size(0);
  49. } else {
  50. nframe = input.size(0);
  51. dim = input.size(1);
  52. }
  53. TORCH_CHECK(
  54. target.dim() <= 1 && target.numel() == nframe,
  55. "inconsistent target size, expected ", nframe, " but got ",
  56. target.sizes());
  57. if (weight && weight->defined()) {
  58. TORCH_CHECK(
  59. weight->dim() <= 1 && weight->numel() == dim,
  60. "inconsistent weight size, expected ", dim, " but got ",
  61. weight->sizes());
  62. }
  63. }
  64. } // anonymous namespace
  65. } // namespace at::native