Skip to content

Commit

Permalink
support half and bf16 in to_dim_order_copy (#7693)
Browse files Browse the repository at this point in the history
Summary:

make dim order copy support half and bf16

Reviewed By: digantdesai, Jack-Khuu

Differential Revision: D68245619
  • Loading branch information
Gasoonjia authored and facebook-github-bot committed Jan 16, 2025
1 parent fc6b83e commit 786c2ff
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
8 changes: 6 additions & 2 deletions kernels/portable/cpu/op__to_dim_order_copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,17 @@ Tensor& _to_dim_order_copy_out(
InvalidArgument,
out);

ET_SWITCH_REALHB_TYPES(
if (self.numel() == 0) {
return out;
}

ET_SWITCH_REALHBBF16_TYPES(
self.scalar_type(),
ctx,
"dim_order_ops::_to_dim_order_copy.out",
CTYPE_IN,
[&] {
ET_SWITCH_REALHB_TYPES(
ET_SWITCH_REALHBBF16_TYPES(
out.scalar_type(),
ctx,
"dim_order_ops::_to_dim_order_copy.out",
Expand Down
19 changes: 15 additions & 4 deletions kernels/test/op__to_dim_order_copy_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ typedef std::map<
std::type_index,
std::variant<
std::vector<float>,
std::vector<double>>>
std::vector<double>,
std::vector<exec_aten::Half>,
std::vector<exec_aten::BFloat16>>>
FloatingTypeToDataMap;

typedef std::map<
Expand Down Expand Up @@ -381,9 +383,9 @@ TEST_F(OpToDimOrderCopyTest, NanInfSupported) {
ScalarType::OUTPUT_DTYPE>(test_cases);

#define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \
ET_FORALL_FLOAT_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
ET_FORALL_FLOATHBF16_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);

ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);

#undef TEST_ENTRY
#undef TEST_KERNEL
Expand Down Expand Up @@ -413,6 +415,13 @@ TEST_F(OpToDimOrderCopyTest, HardcodeFloatConvertInt) {
-0.30919688936285893988};
// clang-format on

std::vector<exec_aten::Half> half_data;
std::vector<exec_aten::BFloat16> bf16_data;
for (auto d : double_data) {
half_data.emplace_back(d);
bf16_data.emplace_back(d);
}

std::vector<int64_t> int64_data = {
-1, -4, 2, -2, 3, 3, -3, -4, 3, 3, 0, 2, 0, -1, 0};
std::vector<int32_t> int32_data = {
Expand All @@ -426,6 +435,8 @@ TEST_F(OpToDimOrderCopyTest, HardcodeFloatConvertInt) {
FloatingTypeToDataMap floating_point_data;
floating_point_data[typeid(float)] = float_data;
floating_point_data[typeid(double)] = double_data;
floating_point_data[typeid(exec_aten::Half)] = half_data;
floating_point_data[typeid(exec_aten::BFloat16)] = bf16_data;

// Gathering all int data together for better traversial
IntTypeToDataMap int_data;
Expand All @@ -444,7 +455,7 @@ TEST_F(OpToDimOrderCopyTest, HardcodeFloatConvertInt) {
#define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \
ET_FORALL_INT_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);

ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
}

TEST_F(OpToDimOrderCopyTest, MismatchedSizesDie) {
Expand Down

0 comments on commit 786c2ff

Please sign in to comment.