Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update API to break out fp8 quantization functionality. #3724

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions src/api/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,21 @@ void quantize_int8_wrap(program& prog, const target& t, quantize_int8_options& o
migraphx::quantize_int8(prog, t, options.calibration, options.op_names);
}

/// Options consumed by quantize_fp8_wrap: the calibration parameter maps
/// collected through add_calibration_data.
struct quantize_fp8_options
{
    std::vector<parameter_map> calibration{};
};

/// Appends a copy of `data` to the calibration list stored in `options`.
void add_calibration_data(quantize_fp8_options& options, parameter_map& data)
{
    auto& samples = options.calibration;
    samples.push_back(data);
}

/// Adapter that unpacks `options` and forwards its calibration data to
/// migraphx::quantize_fp8 together with the program and target.
void quantize_fp8_wrap(program& prog, const target& t, quantize_fp8_options& options)
{
    auto& calibration = options.calibration;
    migraphx::quantize_fp8(prog, t, calibration);
}

#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
Expand Down Expand Up @@ -691,6 +706,17 @@ struct migraphx_quantize_int8_options
migraphx::quantize_int8_options object;
};

// Handle struct behind migraphx_quantize_fp8_options_t; it holds a
// migraphx::quantize_fp8_options value in `object`.
extern "C" struct migraphx_quantize_fp8_options;
struct migraphx_quantize_fp8_options
{
    // Forwards all constructor arguments to the wrapped `object`.
    template <class... Args>
    migraphx_quantize_fp8_options(Args&&... args)
        : object(std::forward<Args>(args)...) // NOLINT(readability-redundant-member-init)
    {
    }
    migraphx::quantize_fp8_options object;
};

extern "C" struct migraphx_context;
struct migraphx_context
{
Expand Down Expand Up @@ -2267,6 +2293,61 @@ extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
return api_error_result;
}

// Calls destroy() on the options handle inside migraphx::try_ and returns the
// status that try_ produces.
extern "C" migraphx_status
migraphx_quantize_fp8_options_destroy(migraphx_quantize_fp8_options_t quantize_fp8_options)
{
    return migraphx::try_([&] { destroy(quantize_fp8_options); });
}

// Assigns the object behind `input` to the object behind `output`. Neither
// pointer is checked for null before being dereferenced.
extern "C" migraphx_status
migraphx_quantize_fp8_options_assign_to(migraphx_quantize_fp8_options_t output,
                                        const_migraphx_quantize_fp8_options_t input)
{
    return migraphx::try_([&] { *output = *input; });
}

// Allocates a migraphx::quantize_fp8_options and stores it, cast to the handle
// type, in *quantize_fp8_options. Returns the status produced by try_.
extern "C" migraphx_status
migraphx_quantize_fp8_options_create(migraphx_quantize_fp8_options_t* quantize_fp8_options)
{
    return migraphx::try_([&] {
        *quantize_fp8_options = object_cast<migraphx_quantize_fp8_options_t>(
            allocate<migraphx::quantize_fp8_options>());
    });
}

// Adds the parameters behind `data` to the calibration list of the given fp8
// options handle. A null handle is reported through MIGRAPHX_THROW with
// migraphx_status_bad_param.
extern "C" migraphx_status migraphx_quantize_fp8_options_add_calibration_data(
    migraphx_quantize_fp8_options_t quantize_fp8_options, migraphx_program_parameters_t data)
{
    return migraphx::try_([&] {
        // Both handles are dereferenced below, so reject null pointers first.
        if(quantize_fp8_options == nullptr)
        {
            MIGRAPHX_THROW(migraphx_status_bad_param,
                           "Bad parameter quantize_fp8_options: Null pointer");
        }
        if(data == nullptr)
        {
            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter data: Null pointer");
        }
        migraphx::add_calibration_data(quantize_fp8_options->object, data->object);
    });
}

// C entry point for fp8 quantization: validates the three handles and forwards
// their wrapped objects to migraphx::quantize_fp8_wrap.
extern "C" migraphx_status migraphx_quantize_fp8(migraphx_program_t prog,
                                                 migraphx_target_t target,
                                                 migraphx_quantize_fp8_options_t options)
{
    return migraphx::try_([&] {
        // Every handle is dereferenced in the final call, so check each one.
        if(prog == nullptr)
        {
            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
        }
        if(target == nullptr)
        {
            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter target: Null pointer");
        }
        if(options == nullptr)
        {
            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
        }
        migraphx::quantize_fp8_wrap(prog->object, target->object, options->object);
    });
}

extern "C" migraphx_status migraphx_context_finish(const_migraphx_context_t context)
{
auto api_error_result = migraphx::try_([&] {
Expand Down
19 changes: 19 additions & 0 deletions src/api/include/migraphx/migraphx.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ typedef const struct migraphx_quantize_op_names* const_migraphx_quantize_op_name
typedef struct migraphx_quantize_int8_options* migraphx_quantize_int8_options_t;
typedef const struct migraphx_quantize_int8_options* const_migraphx_quantize_int8_options_t;

// Handle types for fp8 quantization options: a pointer and a pointer-to-const
// to struct migraphx_quantize_fp8_options.
typedef struct migraphx_quantize_fp8_options* migraphx_quantize_fp8_options_t;
typedef const struct migraphx_quantize_fp8_options* const_migraphx_quantize_fp8_options_t;

typedef struct migraphx_context* migraphx_context_t;
typedef const struct migraphx_context* const_migraphx_context_t;

Expand Down Expand Up @@ -623,6 +626,22 @@ MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_int8(migraphx_program_t prog
migraphx_target_t target,
migraphx_quantize_int8_options_t options);

// Destroys an fp8 quantization options handle.
MIGRAPHX_C_EXPORT migraphx_status
migraphx_quantize_fp8_options_destroy(migraphx_quantize_fp8_options_t quantize_fp8_options);

// Assigns the options behind `input` to the options behind `output`.
MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_fp8_options_assign_to(
migraphx_quantize_fp8_options_t output, const_migraphx_quantize_fp8_options_t input);

// Creates a new fp8 quantization options handle and stores it in
// *quantize_fp8_options.
MIGRAPHX_C_EXPORT migraphx_status
migraphx_quantize_fp8_options_create(migraphx_quantize_fp8_options_t* quantize_fp8_options);

// Adds one set of program parameters as calibration data to the options.
MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_fp8_options_add_calibration_data(
migraphx_quantize_fp8_options_t quantize_fp8_options, migraphx_program_parameters_t data);

// Quantizes `prog` to fp8 for `target` using the calibration data in `options`.
MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_fp8(migraphx_program_t prog,
migraphx_target_t target,
migraphx_quantize_fp8_options_t options);

MIGRAPHX_C_EXPORT migraphx_status migraphx_context_finish(const_migraphx_context_t context);

MIGRAPHX_C_EXPORT migraphx_status migraphx_context_get_queue(void** out,
Expand Down
25 changes: 25 additions & 0 deletions src/api/include/migraphx/migraphx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1516,6 +1516,31 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op
options.get_handle_ptr());
}

/// Options to be passed when quantizing for fp8
struct quantize_fp8_options : MIGRAPHX_HANDLE_BASE(quantize_fp8_options)
{
/// Creates the underlying handle through migraphx_quantize_fp8_options_create
quantize_fp8_options() { this->make_handle(&migraphx_quantize_fp8_options_create); }

MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_fp8_options)
/// Add calibration data to be used for quantizing
void add_calibration_data(const program_parameters& pp)
{
call(&migraphx_quantize_fp8_options_add_calibration_data,
this->get_handle_ptr(),
pp.get_handle_ptr());
}
};

/// Quantize a program to fp8 for the given target, passing the handles of the
/// program, target and options to the C function migraphx_quantize_fp8.
inline void
quantize_fp8(const program& prog, const target& ptarget, const quantize_fp8_options& options)
{
    call(&migraphx_quantize_fp8,
         prog.get_handle_ptr(),
         ptarget.get_handle_ptr(),
         options.get_handle_ptr());
}

struct experimental_custom_op_base
{
experimental_custom_op_base() = default;
Expand Down
17 changes: 17 additions & 0 deletions src/api/migraphx.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,23 @@ def quantize_int8_options(h):
fname='migraphx::quantize_int8_wrap')


@auto_handle()
def quantize_fp8_options(h):
    # Registers a 'create' constructor and an 'add_calibration_data' method
    # that invokes migraphx::add_calibration_data.
    h.constructor('create')
    calibration_params = api.params(
        data='std::unordered_map<std::string, migraphx::argument>')
    h.method(
        'add_calibration_data',
        calibration_params,
        invoke='migraphx::add_calibration_data($@)',
    )


# Registers the migraphx_quantize_fp8 API function (program, target, fp8
# options) with migraphx::quantize_fp8_wrap given as its fname.
api.add_function('migraphx_quantize_fp8',
api.params(prog='migraphx::program&',
target='migraphx::target',
options='migraphx::quantize_fp8_options'),
fname='migraphx::quantize_fp8_wrap')


@auto_handle(ref=True)
def context(h):
h.method('finish', const=True)
Expand Down
5 changes: 5 additions & 0 deletions src/py/migraphx_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,11 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::arg("t"),
py::arg("calibration") = std::vector<migraphx::parameter_map>{},
py::arg("ins_names") = std::unordered_set<std::string>{"dot", "convolution"});
m.def("quantize_fp8",
&migraphx::quantize_fp8,
py::arg("prog"),
py::arg("t"),
py::arg("calibration") = std::vector<migraphx::parameter_map>{});
m.def(
"autocast_fp8",
[](migraphx::program& prog) {
Expand Down
20 changes: 20 additions & 0 deletions test/api/test_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,26 @@ TEST_CASE(quantize_int8)
CHECK(bool{p1 == p2});
}

// Exercises the fp8 quantization API on gemm_test.onnx for the "ref" target:
// once with empty options and once after adding generated calibration data.
TEST_CASE(quantize_fp8)
{
    auto p1        = migraphx::parse_onnx("gemm_test.onnx");
    const auto& p2 = p1;
    auto t         = migraphx::target("ref");
    migraphx::quantize_fp8_options options;
    // First pass: no calibration data has been added yet.
    migraphx::quantize_fp8(p1, t, options);

    // Generate one argument per program parameter to use as calibration data.
    migraphx::program_parameters calibration;
    auto shapes = p1.get_parameter_shapes();
    for(auto&& param_name : shapes.names())
    {
        calibration.add(param_name, migraphx::argument::generate(shapes[param_name]));
    }
    options.add_calibration_data(calibration);

    // Second pass: quantize through p2 with the calibration data attached.
    migraphx::quantize_fp8(p2, t, options);
    // NOTE(review): p2 is a reference to p1, so this compares p1 with itself.
    CHECK(bool{p1 == p2});
}

TEST_CASE(load_and_run_user_input_shape)
{
migraphx::onnx_options options;
Expand Down
72 changes: 72 additions & 0 deletions test/gpu/quantization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,76 @@ TEST_CASE(int8_quantization)
}
}

// Quantizes a single-dot program to fp8 and checks that the result computed on
// the gpu target matches the result computed on the ref target.
TEST_CASE(fp8_quantization)
{
    // Takes the program by value, quantizes it to fp8 with m_in as the only
    // calibration entry, compiles it for target t, evaluates it and writes the
    // last output into res.
    auto run_prog = [](migraphx::program p,
                       const migraphx::target& t,
                       migraphx::parameter_map& m_in,
                       std::vector<float>& res) {
        std::vector<migraphx::parameter_map> cali_data;
        cali_data.push_back(m_in);
        migraphx::quantize_fp8(p, t, cali_data);
        p.compile(t);
        migraphx::parameter_map m;
        for(auto&& x : p.get_parameter_shapes())
        {
            // Parameters supplied by the caller are copied to the target; any
            // other parameter of the compiled program is allocated on the target.
            if(m_in.count(x.first) > 0)
            {
                m[x.first] = t.copy_to(m_in[x.first]);
            }
            else
            {
                m[x.first] = t.allocate(x.second);
            }
        }

        auto result = t.copy_from(p.eval(m).back());
        result.visit([&](auto v) { res.assign(v.begin(), v.end()); });
    };

    // Builds a program whose main module computes dot(a, b) with
    // a: float {5, 16} and b: float {16, 8}.
    auto create_program = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape sa{migraphx::shape::float_type, {5, 16}};
        migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
        auto pa = mm->add_parameter("a", sa);
        auto pb = mm->add_parameter("b", sb);
        mm->add_instruction(migraphx::make_op("dot"), pa, pb);

        return p;
    };

    {
        auto p = create_program();
        migraphx::parameter_map m;
        migraphx::shape sa{migraphx::shape::float_type, {5, 16}};
        migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
        m["a"] = migraphx::generate_argument(sa);
        m["b"] = migraphx::generate_argument(sb);
        std::vector<float> ref_result;
        migraphx::target ref_t = migraphx::make_target("ref");
        run_prog(p, ref_t, m, ref_result);

        std::vector<float> gpu_result;
        migraphx::target gpu_t = migraphx::make_target("gpu");
        run_prog(p, gpu_t, m, gpu_result);

        // Note: the tolerance for mlir_enabled result is temporarily bumped
        // higher because the lowering pipeline between mlir fallback and
        // regular non-mlir pipeline diverged. MLIR fallback uses the
        // rewrite_quantization at the very end of the pipeline, whereas
        // the regular pipeline uses the rewrite_quantization in the much
        // earlier stage.
        if(migraphx::gpu::mlir_enabled())
            EXPECT(migraphx::verify::verify_range_with_tolerance(
                gpu_result,
                migraphx::verify::expected{ref_result},
                migraphx::verify::tolerance{0.02}));
        else
            EXPECT(migraphx::verify::verify_rms_range(gpu_result, ref_result));
    }
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }
15 changes: 15 additions & 0 deletions tools/api/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,21 @@ void quantize_int8_wrap(program& prog, const target& t, quantize_int8_options& o
migraphx::quantize_int8(prog, t, options.calibration, options.op_names);
}

/// Options consumed by quantize_fp8_wrap: the calibration parameter maps
/// collected through add_calibration_data.
struct quantize_fp8_options
{
    std::vector<parameter_map> calibration{};
};

/// Appends a copy of `data` to the calibration list stored in `options`.
void add_calibration_data(quantize_fp8_options& options, parameter_map& data)
{
    auto& samples = options.calibration;
    samples.push_back(data);
}

/// Adapter that unpacks `options` and forwards its calibration data to
/// migraphx::quantize_fp8 together with the program and target.
void quantize_fp8_wrap(program& prog, const target& t, quantize_fp8_options& options)
{
    auto& calibration = options.calibration;
    migraphx::quantize_fp8(prog, t, calibration);
}

#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
Expand Down
Loading