-
Notifications
You must be signed in to change notification settings - Fork 10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix fft rev diff rule #176
base: main
Are you sure you want to change the base?
Conversation
@mofeing is this the right math here? |
auto RT = RankedTensorType::get({1}, resTy.getElementType()); | ||
auto zero_constant = builder.create<ConstantOp>(op.getLoc(), SplatElementsAttr::get( | ||
RT, FloatAttr::get(resTy.getElementType(), 0))); | ||
auto end_constant = builder.create<ConstantOp>(op.getLoc(), SplatElementsAttr::get( | ||
RT, FloatAttr::get(resTy.getElementType(), lengths.back()-1))); | ||
|
||
auto RT64 = RankedTensorType::get({1}, builder.getIntegerType(64)); | ||
|
||
Value start[] = { | ||
builder.create<stablehlo::ConstantOp>(op.getLoc(), SplatElementsAttr::get(RT64, op.getI64IntegerAttr(0))) | ||
}; | ||
Value end[] = { | ||
builder.create<stablehlo::ConstantOp>(op.getLoc(), SplatElementsAttr::get(RT64, op.getI64IntegerAttr(lengths.size()-1))) | ||
}; | ||
ret_constant = builder.create<stablehlo::DynamicUpdateSliceOp>(op.getLoc(), resTy, ret_constant, zero_constant, start); | ||
ret_constant = builder.create<stablehlo::DynamicUpdateSliceOp>(op.getLoc(), resTy, ret_constant, end_constant, end); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i think this is not right, and i believe it's my fault for not explaining it properly
- RFFT => mult = N/2, except for i_n = 0 and i_n = dim(i_n) - 1 whose value is then N
- IRFFT => mult = 2/N, except for i_n = 0 and i_n = dim(i_n) - 1 whose value is then 1/N
what i meant by i_n = 0
is that the n
-th index to be equal to 0 and the res to be "colons", so i really meant [:,:,:,...,:,0]
but now rechecking, this is wrong because it should be for slices where i_0 = 0
and i_0 = dim(i_0) - 1
. it's hard because StableHLO has weird semantics: if doing a 3-dim FFT, it performs FFT on 1st dimension and last 2 ones
so... for the case of RFFT and IRFFT, this Julia code should be equivalent:
value = if FFT
N
elseif IFFT
1/N
elseif RFFT
N/2
else # IRFFT
2/N
end
multiplier = fill(value, size(input))
if RFFT || IRFFT
value = RFFT ? N : 1/N
selectdim(multiplier, 1, 1) .= value
selectdim(multiplier, 1, size(input, 1)) .= value
end
also note that because it's just a multiplier... maybe we could just skip the dynamic_update_slice
in here, let it multiply with the FFT result and call slice
+ multiply
+ dynamic_update_slice
on the result to correct it. this has the advantage that no array instantiation will happen even after optimizations.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hm perhaps I'll let you take it from here then. That said it is much better to have it done here (since in batched mode the constant is generated once vs in the ops itself we'd have to do it for each)
890800b
to
e6dd644
Compare
CC @avik-pal
should fix #170