From 724fc621bfbb9bfa1db790e39163dc5f7edd1b93 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 15 Dec 2023 19:19:58 +0000 Subject: [PATCH] cleanup --- mlc_llm/relax_model/mixtral.py | 40 ---------------------------------- 1 file changed, 40 deletions(-) diff --git a/mlc_llm/relax_model/mixtral.py b/mlc_llm/relax_model/mixtral.py index 9b3fde0fae..16839178cd 100644 --- a/mlc_llm/relax_model/mixtral.py +++ b/mlc_llm/relax_model/mixtral.py @@ -150,46 +150,6 @@ def __init__(self, config: MixtralConfig): self.dtype = config.dtype self.hidden_size = config.hidden_size - def topk(self, x, is_ascend, index_dtype, k=-1): - if not is_ascend and k == 2: - # fast path - total_rows = x.struct_info.shape[0] - result = nn.emit( - relax.call_dps_packed( - "top2", - [x], - out_sinfo=[ - relax.TensorStructInfo([total_rows, k], x.struct_info.dtype), - relax.TensorStructInfo([total_rows, k], index_dtype), - ], - ) - ) - return relax.TupleGetItem(result, 0), relax.TupleGetItem(result, 1) - - # topk along axis -1 - result = nn.emit( - relax.call_dps_packed( - "tvm.contrib.thrust.sort_dps", - [x, is_ascend], - out_sinfo=[ - x.struct_info, - relax.TensorStructInfo(x.struct_info.shape, index_dtype), - ], - ) - ) - sorted_x = relax.TupleGetItem(result, 0) - indices = relax.TupleGetItem(result, 1) - if k != -1: - ndim = len(x.struct_info.shape) - beg = [0] * ndim - end = [x.struct_info.shape[i] for i in range(ndim - 1)] + [k] - axes = list(range(ndim)) - sorted_x = nn.emit( - relax.op.strided_slice(sorted_x, axes, beg, end, assume_inbound=True) - ) - indices = nn.emit(relax.op.strided_slice(indices, axes, beg, end, assume_inbound=True)) - return sorted_x, indices - def scatter(self, linear_out, indices): @T.prim_func def scatter_func(