Build the experimental MPS low-bit linear ops with CMake.

What this patch does:
  * kernels/mps/codegen/gen_metal_shader_lib.py: the output header path is
    taken from argv[1] and MPS_DIR is derived from the script location, so
    the TORCHAO_ROOT environment variable is no longer read. The output
    directory is created before the header is written (the absolute path is
    used so a bare file name does not produce an empty directory name).
  * ops/mps/CMakeLists.txt, ops/mps/build.sh, ops/mps/.gitignore: new CMake
    build that generates metal_shader_lib.h under the install prefix and
    builds/installs the shared library
    torchao_ops_mps_linear_fp_act_xbit_weight_aten into cmake-out/.
    build.sh locates site-packages through sysconfig (distutils is gone in
    Python 3.12+) and quotes every path it expands.
  * ops/mps/register.mm moves to ops/mps/aten/register.mm and drops the
    PYBIND11_MODULE registration; ops/mps/setup.py is deleted.
  * ops/mps/test/*.py: instead of importing torchao_mps_ops, the tests call
    torch.ops.load_library on the built dylib when the torchao ops are not
    already registered. A load failure is re-raised as RuntimeError chained
    to the original exception.

NOTE(review): the #include targets in the register.mm hunk below are empty
in the copy of the patch this was restored from; fill them in from the
repository before applying.

diff --git a/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py b/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py
index 56dcc9b730..eea7e42666 100644
--- a/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py
+++ b/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py
@@ -1,11 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import Optional
 import os
+import sys
 import yaml
 
-torchao_root: Optional[str] = os.getenv("TORCHAO_ROOT")
-assert torchao_root is not None, "TORCHAO_ROOT is not set"
+if len(sys.argv) != 2:
+    print("Usage: gen_metal_shader_lib.py <output_file>")
+    sys.exit(1)
+
+# Output file where the generated code will be written
+OUTPUT_FILE = sys.argv[1]
 
-MPS_DIR = os.path.join(torchao_root, "torchao", "experimental", "kernels", "mps")
+MPS_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 
 # Path to yaml file containing the list of .metal files to include
 METAL_YAML = os.path.join(MPS_DIR, "metal.yaml")
@@ -21,9 +32,6 @@
 # Path to the folder containing the .metal files
 METAL_DIR = os.path.join(MPS_DIR, "metal")
 
-# Output file where the generated code will be written
-OUTPUT_FILE = os.path.join(MPS_DIR, "src", "metal_shader_lib.h")
-
 prefix = """/**
  * This file is generated by gen_metal_shader_lib.py
  */
@@ -48,6 +56,7 @@
 
 """
 
+os.makedirs(os.path.dirname(os.path.abspath(OUTPUT_FILE)), exist_ok=True)
 with open(OUTPUT_FILE, "w") as outf:
     outf.write(prefix)
     for file in metal_files:
diff --git a/torchao/experimental/ops/mps/.gitignore b/torchao/experimental/ops/mps/.gitignore
new file mode 100644
index 0000000000..d48f17d1c5
--- /dev/null
+++ b/torchao/experimental/ops/mps/.gitignore
@@ -0,0 +1 @@
+cmake-out/
diff --git a/torchao/experimental/ops/mps/CMakeLists.txt b/torchao/experimental/ops/mps/CMakeLists.txt
new file mode 100644
index 0000000000..044433ef95
--- /dev/null
+++ b/torchao/experimental/ops/mps/CMakeLists.txt
@@ -0,0 +1,60 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+cmake_minimum_required(VERSION 3.19)
+
+project(torchao_ops_mps_linear_fp_act_xbit_weight)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED YES)
+
+if (NOT CMAKE_BUILD_TYPE)
+  set(CMAKE_BUILD_TYPE Release)
+endif()
+
+if (CMAKE_SYSTEM_NAME STREQUAL "Darwin")
+  if (NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
+    message(FATAL_ERROR "Unified Memory requires Apple Silicon architecture")
+  endif()
+else()
+  message(FATAL_ERROR "Torchao experimental mps ops can only be built on macOS/iOS")
+endif()
+
+find_package(Torch REQUIRED)
+
+# Generate metal_shader_lib.h by running gen_metal_shader_lib.py
+set(GENERATED_METAL_SHADER_LIB ${CMAKE_INSTALL_PREFIX}/include/torchao/experimental/kernels/mps/src/metal_shader_lib.h)
+add_custom_command(
+  OUTPUT ${GENERATED_METAL_SHADER_LIB}
+  COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/codegen/gen_metal_shader_lib.py ${GENERATED_METAL_SHADER_LIB}
+  COMMENT "Generating metal_shader_lib.h using gen_metal_shader_lib.py"
+)
+add_custom_target(generated_metal_shader_lib ALL DEPENDS ${GENERATED_METAL_SHADER_LIB})
+
+if(NOT TORCHAO_INCLUDE_DIRS)
+  set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../..)
+endif()
+message(STATUS "TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")
+
+include_directories(${TORCHAO_INCLUDE_DIRS})
+include_directories(${CMAKE_INSTALL_PREFIX}/include)
+add_library(torchao_ops_mps_linear_fp_act_xbit_weight_aten SHARED aten/register.mm)
+add_dependencies(torchao_ops_mps_linear_fp_act_xbit_weight_aten generated_metal_shader_lib)
+
+target_include_directories(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE "${TORCH_INCLUDE_DIRS}")
+target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE "${TORCH_LIBRARIES}")
+target_compile_definitions(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE USE_ATEN=1)
+
+# Enable Metal support
+find_library(METAL_LIB Metal)
+find_library(FOUNDATION_LIB Foundation)
+target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE ${METAL_LIB} ${FOUNDATION_LIB})
+
+install(
+  TARGETS torchao_ops_mps_linear_fp_act_xbit_weight_aten
+  EXPORT _targets
+  DESTINATION lib
+)
diff --git a/torchao/experimental/ops/mps/register.mm b/torchao/experimental/ops/mps/aten/register.mm
similarity index 98%
rename from torchao/experimental/ops/mps/register.mm
rename to torchao/experimental/ops/mps/aten/register.mm
index 44946a30f0..92a3ba89f0 100644
--- a/torchao/experimental/ops/mps/register.mm
+++ b/torchao/experimental/ops/mps/aten/register.mm
@@ -5,7 +5,7 @@
 // LICENSE file in the root directory of this source tree.
 
 // clang-format off
-#include
+#include
 #include
 #include
 // clang-format on
@@ -147,9 +147,6 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) {
   return B;
 }
 
-// Registers _C as a Python extension module.
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
-
 TORCH_LIBRARY(torchao, m) {
   m.def("_pack_weight_1bit(Tensor W) -> Tensor");
   m.def("_pack_weight_2bit(Tensor W) -> Tensor");
diff --git a/torchao/experimental/ops/mps/build.sh b/torchao/experimental/ops/mps/build.sh
new file mode 100644
index 0000000000..1ea032f8c6
--- /dev/null
+++ b/torchao/experimental/ops/mps/build.sh
@@ -0,0 +1,19 @@
+#!/bin/bash -eu
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+cd "$(dirname "${BASH_SOURCE[0]}")"
+
+export CMAKE_PREFIX_PATH="$(python -c 'import sysconfig; print(sysconfig.get_path("purelib"))')"
+echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}"
+export CMAKE_OUT="${PWD}/cmake-out"
+echo "CMAKE_OUT: ${CMAKE_OUT}"
+
+cmake -DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
+  -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \
+  -S . \
+  -B "${CMAKE_OUT}"
+cmake --build "${CMAKE_OUT}" -j 16 --target install --config Release
diff --git a/torchao/experimental/ops/mps/setup.py b/torchao/experimental/ops/mps/setup.py
deleted file mode 100644
index 1205d43d45..0000000000
--- a/torchao/experimental/ops/mps/setup.py
+++ /dev/null
@@ -1,23 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import os
-from setuptools import setup
-from torch.utils.cpp_extension import CppExtension, BuildExtension
-
-setup(
-    name="torchao_mps_ops",
-    version="1.0",
-    ext_modules=[
-        CppExtension(
-            name="torchao_mps_ops",
-            sources=["register.mm"],
-            include_dirs=[os.getenv("TORCHAO_ROOT")],
-            extra_compile_args=["-DUSE_ATEN=1"],
-        ),
-    ],
-    cmdclass={"build_ext": BuildExtension},
-)
diff --git a/torchao/experimental/ops/mps/test/test_lowbit.py b/torchao/experimental/ops/mps/test/test_lowbit.py
index 797c5dac29..f4c460a368 100644
--- a/torchao/experimental/ops/mps/test/test_lowbit.py
+++ b/torchao/experimental/ops/mps/test/test_lowbit.py
@@ -4,25 +4,38 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import os
+import sys
 import torch
-import torchao_mps_ops
 import unittest
 
+from parameterized import parameterized
 
-def parameterized(test_cases):
-    def decorator(func):
-        def wrapper(self):
-            for case in test_cases:
-                with self.subTest(case=case):
-                    func(self, *case)
+libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib"
+libpath = os.path.abspath(
+    os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
+)
 
-        return wrapper
-
-    return decorator
+try:
+    for nbit in range(1, 8):
+        getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
+        getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
+except AttributeError:
+    try:
+        torch.ops.load_library(libpath)
+    except Exception as e:
+        raise RuntimeError(f"Failed to load library {libpath}") from e
+    else:
+        try:
+            for nbit in range(1, 8):
+                getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
+                getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
+        except AttributeError as e:
+            raise e
 
 
 class TestLowBitQuantWeightsLinear(unittest.TestCase):
-    cases = [
+    CASES = [
         (nbit, *param)
         for nbit in range(1, 8)
         for param in [
@@ -73,7 +86,7 @@ def _reference_linear_lowbit_quant_weights(self, A, W, group_size, S, Z, nbit):
         W = scales * W + zeros
         return torch.mm(A, W.t())
 
-    @parameterized(cases)
+    @parameterized.expand(CASES)
     def test_linear(self, nbit, M=1, K=32, N=32, group_size=32):
         print(f"nbit: {nbit}, M: {M}, K: {K}, N: {N}, group_size: {group_size}")
         A, W, S, Z = self._init_tensors(group_size, M, K, N, nbit=nbit)
diff --git a/torchao/experimental/ops/mps/test/test_quantizer.py b/torchao/experimental/ops/mps/test/test_quantizer.py
index 87f67c5452..00c08738c2 100644
--- a/torchao/experimental/ops/mps/test/test_quantizer.py
+++ b/torchao/experimental/ops/mps/test/test_quantizer.py
@@ -11,13 +11,34 @@
 import sys
 
 import torch
-import torchao_mps_ops
 import unittest
 
 from parameterized import parameterized
 from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
 from torchao.experimental.quant_api import _quantize
 
+libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib"
+libpath = os.path.abspath(
+    os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
+)
+
+try:
+    for nbit in range(1, 8):
+        getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
+        getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
+except AttributeError:
+    try:
+        torch.ops.load_library(libpath)
+    except Exception as e:
+        raise RuntimeError(f"Failed to load library {libpath}") from e
+    else:
+        try:
+            for nbit in range(1, 8):
+                getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
+                getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
+        except AttributeError as e:
+            raise e
+
 
 class TestUIntxWeightOnlyLinearQuantizer(unittest.TestCase):
     BITWIDTHS = range(1, 8)