cleanup

vinx13 committed Dec 15, 2023
1 parent d848568 commit 724fc62
Showing 1 changed file with 0 additions and 40 deletions.
40 changes: 0 additions & 40 deletions mlc_llm/relax_model/mixtral.py
@@ -150,46 +150,6 @@ def __init__(self, config: MixtralConfig):
         self.dtype = config.dtype
         self.hidden_size = config.hidden_size

-    def topk(self, x, is_ascend, index_dtype, k=-1):
-        if not is_ascend and k == 2:
-            # fast path
-            total_rows = x.struct_info.shape[0]
-            result = nn.emit(
-                relax.call_dps_packed(
-                    "top2",
-                    [x],
-                    out_sinfo=[
-                        relax.TensorStructInfo([total_rows, k], x.struct_info.dtype),
-                        relax.TensorStructInfo([total_rows, k], index_dtype),
-                    ],
-                )
-            )
-            return relax.TupleGetItem(result, 0), relax.TupleGetItem(result, 1)
-
-        # topk along axis -1
-        result = nn.emit(
-            relax.call_dps_packed(
-                "tvm.contrib.thrust.sort_dps",
-                [x, is_ascend],
-                out_sinfo=[
-                    x.struct_info,
-                    relax.TensorStructInfo(x.struct_info.shape, index_dtype),
-                ],
-            )
-        )
-        sorted_x = relax.TupleGetItem(result, 0)
-        indices = relax.TupleGetItem(result, 1)
-        if k != -1:
-            ndim = len(x.struct_info.shape)
-            beg = [0] * ndim
-            end = [x.struct_info.shape[i] for i in range(ndim - 1)] + [k]
-            axes = list(range(ndim))
-            sorted_x = nn.emit(
-                relax.op.strided_slice(sorted_x, axes, beg, end, assume_inbound=True)
-            )
-            indices = nn.emit(relax.op.strided_slice(indices, axes, beg, end, assume_inbound=True))
-        return sorted_x, indices
-
     def scatter(self, linear_out, indices):
         @T.prim_func
         def scatter_func(
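
For context on what was removed: topk() had two paths. When k == 2 with descending order it dispatched to a packed "top2" kernel; otherwise it ran a full sort through "tvm.contrib.thrust.sort_dps" and kept only the first k columns of the last axis via relax.op.strided_slice. Below is a minimal NumPy sketch of that general sort-then-slice path; it illustrates the logic only and is not the Relax implementation. The name topk_by_sort and the sample scores are made up for the example.

import numpy as np

def topk_by_sort(x: np.ndarray, k: int = -1, is_ascend: bool = False):
    # Sort indices along the last axis; reverse for a descending (top-k) query.
    order = np.argsort(x, axis=-1)
    if not is_ascend:
        order = order[..., ::-1]
    sorted_x = np.take_along_axis(x, order, axis=-1)
    if k != -1:
        # Same effect as the strided_slice over the last axis in the removed code.
        sorted_x = sorted_x[..., :k]
        order = order[..., :k]
    return sorted_x, order

# Example: keep the two largest entries per row, the case the k == 2 fast path targeted.
scores = np.array([[0.1, 0.7, 0.2, 0.9],
                   [0.3, 0.05, 0.6, 0.05]])
values, indices = topk_by_sort(scores, k=2)  # values: [[0.9, 0.7], [0.6, 0.3]]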
