Skip to content

Commit

Permalink
add missed files; fix clang format
Browse files Browse the repository at this point in the history
  • Loading branch information
coderfeli committed Dec 27, 2024
1 parent a721edb commit a6fca1c
Show file tree
Hide file tree
Showing 8 changed files with 351 additions and 2 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_instances<GemmDefault>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_instances<GemmKPadding>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_instances<Intrawave,
GemmDefault>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_instances<Intrawave,
GemmKPadding>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_instances<Interwave,
GemmDefault>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_instances<Interwave,
GemmKPadding>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
4 changes: 2 additions & 2 deletions profiler/src/profile_gemm_multiply_multiply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ enum struct GemmDataType
F16_F16_F16_F8, // 6
F8_F8_BF16, // 7
INT8_INT8_BF16, // 8
F8_F8_F16, // 9
F8_F8_F16, // 9
};

#define OP_NAME "gemm_multiply_multiply"
Expand Down Expand Up @@ -90,7 +90,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])

using F32 = float;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F16 = ck::half_t;
using F8 = ck::f8_t;
using I8 = int8_t;
using I32 = int;
Expand Down

0 comments on commit a6fca1c

Please sign in to comment.