Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
fix: incorrect dispatch called
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 18, 2024
1 parent 8d7c497 commit 0b19476
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 10 deletions.
9 changes: 4 additions & 5 deletions src/impl/batched_mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,14 @@ function batched_matmul!(z::AbstractArray{zT, 3}, ::LoopedArrayOp,
return
end

function batched_matmul_cpu!(
z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3},
α::Number=true, β::Number=false) where {zT, xT, yT}
function batched_matmul_cpu!(z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3},
y::AbstractArray{yT, 3}) where {zT, xT, yT}
if can_loopvec_args(batchview(z, 1), batchview(x, 1), batchview(y, 1)) &&
!unsafe_known(explicit_blas_loaded())
batched_matmul_loopvec_impl!(z, x, y, α, β)
batched_matmul_loopvec_impl!(z, x, y)
return
end
NNlib.batched_mul!(z, x, y, α, β)
NNlib.batched_mul!(z, x, y)
return
end

Expand Down
8 changes: 3 additions & 5 deletions test/common_ops/bias_act_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,9 @@
@jet bias_act_loss2(act, x, b)
@jet bias_act_loss3(act, x, b)

if act !== lisht
@test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any broken=(T ==
Float16)
@test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any broken=(T ==
Float16)
if act !== lisht && T != Float16
@test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any
@test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any
end

@test_gradients(__Fix1(bias_act_loss1, act), x, b; atol, rtol,
Expand Down

0 comments on commit 0b19476

Please sign in to comment.