From 71e037cf677b9d4d36245f2f6cab39951c2d5e42 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Thu, 19 Dec 2024 01:10:53 +0000 Subject: [PATCH 1/3] Update API to break out fp8 quantization functionality. --- src/api/api.cpp | 81 +++++++++++++++++++++++++++ src/api/include/migraphx/migraphx.h | 19 +++++++ src/api/include/migraphx/migraphx.hpp | 25 +++++++++ src/api/migraphx.py | 17 ++++++ src/py/migraphx_py.cpp | 5 ++ test/api/test_cpu.cpp | 20 +++++++ test/gpu/quantization.cpp | 72 ++++++++++++++++++++++++ tools/api/api.cpp | 15 +++++ 8 files changed, 254 insertions(+) diff --git a/src/api/api.cpp b/src/api/api.cpp index 4ecd0763225..e40eb774087 100644 --- a/src/api/api.cpp +++ b/src/api/api.cpp @@ -260,6 +260,21 @@ void quantize_int8_wrap(program& prog, const target& t, quantize_int8_options& o migraphx::quantize_int8(prog, t, options.calibration, options.op_names); } +struct quantize_fp8_options +{ + std::vector calibration = {}; +}; + +void add_calibration_data(quantize_fp8_options& options, parameter_map& data) +{ + options.calibration.push_back(data); +} + +void quantize_fp8_wrap(program& prog, const target& t, quantize_fp8_options& options) +{ + migraphx::quantize_fp8(prog, t, options.calibration); +} + #ifdef __clang__ #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wformat-nonliteral" @@ -691,6 +706,17 @@ struct migraphx_quantize_int8_options migraphx::quantize_int8_options object; }; +extern "C" struct migraphx_quantize_fp8_options; +struct migraphx_quantize_fp8_options +{ + template + migraphx_quantize_fp8_options(Ts&&... xs) + : object(std::forward(xs)...) // NOLINT(readability-redundant-member-init) + { + } + migraphx::quantize_fp8_options object; +}; + extern "C" struct migraphx_context; struct migraphx_context { @@ -2267,6 +2293,61 @@ extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog, return api_error_result; } +extern "C" migraphx_status +migraphx_quantize_fp8_options_destroy(migraphx_quantize_fp8_options_t quantize_fp8_options) +{ + auto api_error_result = migraphx::try_([&] { destroy((quantize_fp8_options)); }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_quantize_fp8_options_assign_to(migraphx_quantize_fp8_options_t output, + const_migraphx_quantize_fp8_options_t input) +{ + auto api_error_result = migraphx::try_([&] { *output = *input; }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_quantize_fp8_options_create(migraphx_quantize_fp8_options_t* quantize_fp8_options) +{ + auto api_error_result = migraphx::try_([&] { + *quantize_fp8_options = object_cast( + allocate()); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_quantize_fp8_options_add_calibration_data( + migraphx_quantize_fp8_options_t quantize_fp8_options, migraphx_program_parameters_t data) +{ + auto api_error_result = migraphx::try_([&] { + if(quantize_fp8_options == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, + "Bad parameter quantize_fp8_options: Null pointer"); + if(data == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter data: Null pointer"); + migraphx::add_calibration_data((quantize_fp8_options->object), (data->object)); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_quantize_fp8(migraphx_program_t prog, + migraphx_target_t target, + migraphx_quantize_fp8_options_t options) +{ + auto api_error_result = migraphx::try_([&] { + if(prog == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer"); + if(target == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter target: Null pointer"); + if(options == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer"); + migraphx::quantize_fp8_wrap((prog->object), (target->object), (options->object)); + }); + return api_error_result; +} + extern "C" migraphx_status migraphx_context_finish(const_migraphx_context_t context) { auto api_error_result = migraphx::try_([&] { diff --git a/src/api/include/migraphx/migraphx.h b/src/api/include/migraphx/migraphx.h index 1f1c05bb215..bde311d64a3 100644 --- a/src/api/include/migraphx/migraphx.h +++ b/src/api/include/migraphx/migraphx.h @@ -141,6 +141,9 @@ typedef const struct migraphx_quantize_op_names* const_migraphx_quantize_op_name typedef struct migraphx_quantize_int8_options* migraphx_quantize_int8_options_t; typedef const struct migraphx_quantize_int8_options* const_migraphx_quantize_int8_options_t; +typedef struct migraphx_quantize_fp8_options* migraphx_quantize_fp8_options_t; +typedef const struct migraphx_quantize_fp8_options* const_migraphx_quantize_fp8_options_t; + typedef struct migraphx_context* migraphx_context_t; typedef const struct migraphx_context* const_migraphx_context_t; @@ -623,6 +626,22 @@ MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_int8(migraphx_program_t prog migraphx_target_t target, migraphx_quantize_int8_options_t options); +MIGRAPHX_C_EXPORT migraphx_status +migraphx_quantize_fp8_options_destroy(migraphx_quantize_fp8_options_t quantize_fp8_options); + +MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_fp8_options_assign_to( + migraphx_quantize_fp8_options_t output, const_migraphx_quantize_fp8_options_t input); + +MIGRAPHX_C_EXPORT migraphx_status +migraphx_quantize_fp8_options_create(migraphx_quantize_fp8_options_t* quantize_fp8_options); + +MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_fp8_options_add_calibration_data( + migraphx_quantize_fp8_options_t quantize_fp8_options, migraphx_program_parameters_t data); + +MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_fp8(migraphx_program_t prog, + migraphx_target_t target, + migraphx_quantize_fp8_options_t options); + MIGRAPHX_C_EXPORT migraphx_status migraphx_context_finish(const_migraphx_context_t context); MIGRAPHX_C_EXPORT migraphx_status migraphx_context_get_queue(void** out, diff --git a/src/api/include/migraphx/migraphx.hpp b/src/api/include/migraphx/migraphx.hpp index fa6339b4389..d2001ac2141 100644 --- a/src/api/include/migraphx/migraphx.hpp +++ b/src/api/include/migraphx/migraphx.hpp @@ -1516,6 +1516,31 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op options.get_handle_ptr()); } +/// Options to be passed when quantizing for int8 +struct quantize_fp8_options : MIGRAPHX_HANDLE_BASE(quantize_fp8_options) +{ + quantize_fp8_options() { this->make_handle(&migraphx_quantize_fp8_options_create); } + + MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_fp8_options) + /// Add calibrartion data to be used for quantizing + void add_calibration_data(const program_parameters& pp) + { + call(&migraphx_quantize_fp8_options_add_calibration_data, + this->get_handle_ptr(), + pp.get_handle_ptr()); + } +}; + +/// Quantize program to use fp8 +inline void +quantize_fp8(const program& prog, const target& ptarget, const quantize_fp8_options& options) +{ + call(&migraphx_quantize_fp8, + prog.get_handle_ptr(), + ptarget.get_handle_ptr(), + options.get_handle_ptr()); +} + struct experimental_custom_op_base { experimental_custom_op_base() = default; diff --git a/src/api/migraphx.py b/src/api/migraphx.py index 1d75bc81f2d..c36cb8cde89 100755 --- a/src/api/migraphx.py +++ b/src/api/migraphx.py @@ -467,6 +467,23 @@ def quantize_int8_options(h): fname='migraphx::quantize_int8_wrap') +@auto_handle() +def quantize_fp8_options(h): + h.constructor('create') + h.method( + 'add_calibration_data', + api.params(data='std::unordered_map'), + invoke='migraphx::add_calibration_data($@)', + ) + + +api.add_function('migraphx_quantize_fp8', + api.params(prog='migraphx::program&', + target='migraphx::target', + options='migraphx::quantize_fp8_options'), + fname='migraphx::quantize_fp8_wrap') + + @auto_handle(ref=True) def context(h): h.method('finish', const=True) diff --git a/src/py/migraphx_py.cpp b/src/py/migraphx_py.cpp index 75f7fab09d9..a49ef8f80e9 100644 --- a/src/py/migraphx_py.cpp +++ b/src/py/migraphx_py.cpp @@ -657,6 +657,11 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) py::arg("t"), py::arg("calibration") = std::vector{}, py::arg("ins_names") = std::unordered_set{"dot", "convolution"}); + m.def("quantize_fp8", + &migraphx::quantize_fp8, + py::arg("prog"), + py::arg("t"), + py::arg("calibration") = std::vector{}); m.def( "autocast_fp8", [](migraphx::program& prog) { diff --git a/test/api/test_cpu.cpp b/test/api/test_cpu.cpp index 1c11a68989e..e57e8bb7fce 100644 --- a/test/api/test_cpu.cpp +++ b/test/api/test_cpu.cpp @@ -102,6 +102,26 @@ TEST_CASE(quantize_int8) CHECK(bool{p1 == p2}); } +TEST_CASE(quantize_fp8) +{ + auto p1 = migraphx::parse_onnx("gemm_test.onnx"); + const auto& p2 = p1; + auto t = migraphx::target("ref"); + migraphx::quantize_fp8_options options; + migraphx::quantize_fp8(p1, t, options); + + migraphx::program_parameters pp; + auto param_shapes = p1.get_parameter_shapes(); + for(auto&& name : param_shapes.names()) + { + pp.add(name, migraphx::argument::generate(param_shapes[name])); + } + options.add_calibration_data(pp); + + migraphx::quantize_fp8(p2, t, options); + CHECK(bool{p1 == p2}); +} + TEST_CASE(load_and_run_user_input_shape) { migraphx::onnx_options options; diff --git a/test/gpu/quantization.cpp b/test/gpu/quantization.cpp index 208ca76ddad..6044444fbaf 100644 --- a/test/gpu/quantization.cpp +++ b/test/gpu/quantization.cpp @@ -126,4 +126,76 @@ TEST_CASE(int8_quantization) } } +TEST_CASE(fp8_quantization) +{ + auto run_prog = [](migraphx::program p, + const migraphx::target& t, + migraphx::parameter_map& m_in, + std::vector& res) { + std::vector cali_data; + cali_data.push_back(m_in); + migraphx::quantize_fp8(p, t, cali_data); + p.compile(t); + migraphx::parameter_map m; + for(auto&& x : p.get_parameter_shapes()) + { + if(m_in.count(x.first) > 0) + { + m[x.first] = t.copy_to(m_in[x.first]); + } + else + { + m[x.first] = t.allocate(x.second); + } + } + + auto result = t.copy_from(p.eval(m).back()); + result.visit([&](auto v) { res.assign(v.begin(), v.end()); }); + }; + + auto create_program = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sa{migraphx::shape::float_type, {5, 16}}; + migraphx::shape sb{migraphx::shape::float_type, {16, 8}}; + migraphx::shape sc{migraphx::shape::float_type, {5, 8}}; + auto pa = mm->add_parameter("a", sa); + auto pb = mm->add_parameter("b", sb); + mm->add_instruction(migraphx::make_op("dot"), pa, pb); + + return p; + }; + + { + auto p = create_program(); + migraphx::parameter_map m; + migraphx::shape sa{migraphx::shape::float_type, {5, 16}}; + migraphx::shape sb{migraphx::shape::float_type, {16, 8}}; + migraphx::shape sc{migraphx::shape::float_type, {5, 8}}; + m["a"] = migraphx::generate_argument(sa); + m["b"] = migraphx::generate_argument(sb); + std::vector ref_result; + migraphx::target ref_t = migraphx::make_target("ref"); + run_prog(p, ref_t, m, ref_result); + + std::vector gpu_result; + migraphx::target gpu_t = migraphx::make_target("gpu"); + run_prog(p, gpu_t, m, gpu_result); + + // Note: the tolerance for mlir_enabled result is temporarily bumped + // higher because the lowering pipeline between mlir fallback and + // regular non-mlir pipeline diverged. MLIR fallback uses the + // rewrite_quantization at the very end of the pipeline, whereas + // the regular pipeline uses the rewrite_quantization in the much + // earlier stage. + if(migraphx::gpu::mlir_enabled()) + EXPECT(migraphx::verify::verify_range_with_tolerance( + gpu_result, + migraphx::verify::expected{ref_result}, + migraphx::verify::tolerance{0.01})); + else + EXPECT(migraphx::verify::verify_rms_range(gpu_result, ref_result)); + } +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/tools/api/api.cpp b/tools/api/api.cpp index 9049fc5116e..7e998c2f78c 100644 --- a/tools/api/api.cpp +++ b/tools/api/api.cpp @@ -260,6 +260,21 @@ void quantize_int8_wrap(program& prog, const target& t, quantize_int8_options& o migraphx::quantize_int8(prog, t, options.calibration, options.op_names); } +struct quantize_fp8_options +{ + std::vector calibration = {}; +}; + +void add_calibration_data(quantize_fp8_options& options, parameter_map& data) +{ + options.calibration.push_back(data); +} + +void quantize_fp8_wrap(program& prog, const target& t, quantize_fp8_options& options) +{ + migraphx::quantize_fp8(prog, t, options.calibration); +} + #ifdef __clang__ #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wformat-nonliteral" From a9d6d30f509b410170b5d344043b308f4aa3d6ef Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Thu, 19 Dec 2024 22:10:00 +0000 Subject: [PATCH 2/3] Fix format --- tools/api/api.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/api/api.cpp b/tools/api/api.cpp index 7e998c2f78c..e6246271039 100644 --- a/tools/api/api.cpp +++ b/tools/api/api.cpp @@ -262,7 +262,7 @@ void quantize_int8_wrap(program& prog, const target& t, quantize_int8_options& o struct quantize_fp8_options { - std::vector calibration = {}; + std::vector calibration = {}; }; void add_calibration_data(quantize_fp8_options& options, parameter_map& data) From a2a285fdb14db5927251c3de79314ddb634751a0 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Mon, 6 Jan 2025 18:27:48 -0600 Subject: [PATCH 3/3] Reduce tolerace to 0.02 for fp8 --- test/gpu/quantization.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/gpu/quantization.cpp b/test/gpu/quantization.cpp index 6044444fbaf..c8287988657 100644 --- a/test/gpu/quantization.cpp +++ b/test/gpu/quantization.cpp @@ -192,7 +192,7 @@ TEST_CASE(fp8_quantization) EXPECT(migraphx::verify::verify_range_with_tolerance( gpu_result, migraphx::verify::expected{ref_result}, - migraphx::verify::tolerance{0.01})); + migraphx::verify::tolerance{0.02})); else EXPECT(migraphx::verify::verify_rms_range(gpu_result, ref_result)); }