Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix misconception between FFT and Inverse FFT #153

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions circlestark/fast_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

# Converts a list of evaluations to a list of coefficients. Note that the
# coefficients are in a "weird" basis: 1, y, x, xy, 2x^2-1...
def fft(vals, is_top_level=True):
def inv_fft(vals, is_top_level=True):
vals = vals.copy()
shape_suffix = vals.shape[1:]
size = vals.shape[0]
Expand All @@ -33,7 +33,7 @@ def fft(vals, is_top_level=True):
)

# Converts a list of coefficients into a list of evaluations
def inv_fft(vals):
def fft(vals):
vals = vals.copy()
shape_suffix = vals.shape[1:]
size = vals.shape[0]
Expand Down
4 changes: 2 additions & 2 deletions circlestark/fast_fri.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
merkelize_top_dimension, get_challenges, rbo_index_to_original
)
from precomputes import folded_rbos, invx, invy
from fast_fft import fft
from fast_fft import inv_fft
from merkle import merkelize, hash, get_branch, verify_branch

BASE_CASE_SIZE = 64
Expand Down Expand Up @@ -169,6 +169,6 @@ def verify_low_degree(proof, extra_entropy=b''):
o = zeros_like(final_values)
N = final_values.shape[0]
o[rbo_index_to_original(N, cp.arange(N))] = final_values
coeffs = fft(o, is_top_level=False)
coeffs = inv_fft(o, is_top_level=False)
assert coeffs[N//2:] == 0
return True
16 changes: 8 additions & 8 deletions circlestark/fast_stark.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
# Tweaks the last row of a trace or constraint object to reduce its degree
def tweak_last_row(obj):
obj = obj.copy()
coeffs = fft(obj)
coeffs = inv_fft(obj)
cls = obj.__class__
tweak_value = fft(cls.append(cls.zeros(obj.shape[0]-1), cls([1])))[-1]
tweak_value = inv_fft(cls.append(cls.zeros(obj.shape[0]-1), cls([1])))[-1]
obj[-1] -= coeffs[-1] / tweak_value
return obj

Expand Down Expand Up @@ -114,7 +114,7 @@ def mk_stark(check_constraint,
trace_length*ext_degree:
trace_length*ext_degree*2
]
trace_ext = inv_fft(pad_to(fft(trace), trace_length*ext_degree))
trace_ext = fft(pad_to(inv_fft(trace), trace_length*ext_degree))
print('Generated trace extension', time.time() - START)
# Decompose the trace into the public part and the private part:
# trace = public * V + private. We commit to the private part, and show
Expand All @@ -124,13 +124,13 @@ def mk_stark(check_constraint,
public_args,
trace[cp.array(public_args)],
)
V_ext = inv_fft(pad_to(fft(V), trace_length*ext_degree))
I_ext = inv_fft(pad_to(fft(I), trace_length*ext_degree))
V_ext = fft(pad_to(inv_fft(V), trace_length*ext_degree))
I_ext = fft(pad_to(inv_fft(I), trace_length*ext_degree))
print('Generated V,I', time.time() - START)
trace_quotient_ext = (
(trace_ext - I_ext) / V_ext.reshape(V_ext.shape+(1,))
)
constants_ext = inv_fft(pad_to(fft(constants), trace_length*ext_degree))
constants_ext = fft(pad_to(inv_fft(constants), trace_length*ext_degree))
rolled_trace_ext = M31.append(
trace_ext[ext_degree:],
trace_ext[:ext_degree]
Expand Down Expand Up @@ -283,9 +283,9 @@ def f2():
# Generate the Merkle tree of constants
def build_constants_tree(constants, H_degree=2):
trace_length = constants.shape[0]
constants_coeffs = fft(pad_to(constants, trace_length))
constants_coeffs = inv_fft(pad_to(constants, trace_length))
return merkelize_top_dimension(
inv_fft(pad_to(constants_coeffs, trace_length*H_degree*2))
fft(pad_to(constants_coeffs, trace_length*H_degree*2))
)

# Generate the verification key (basically the Merkle root of the
Expand Down
12 changes: 6 additions & 6 deletions circlestark/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def halve_single_domain_value(value):
else:
return 2*value**2-1

def fft(vals, domain=None):
def inv_fft(vals, domain=None):
if len(vals) == 1:
return vals
if domain is None:
Expand All @@ -86,19 +86,19 @@ def fft(vals, domain=None):
f0 = [(L+R)/2 for L,R in zip(left, right)]
f1 = [(L-R)/(2*x) for L,R,x in zip(left, right, domain)]
o = [0] * len(domain)
o[::2] = fft(f0, half_domain)
o[1::2] = fft(f1, half_domain)
o[::2] = inv_fft(f0, half_domain)
o[1::2] = inv_fft(f1, half_domain)
return o

def inv_fft(vals, domain=None):
def fft(vals, domain=None):
if len(vals) == 1:
#print('o', vals)
return vals
if domain is None:
domain = get_initial_domain_of_size(vals[0].__class__, len(vals))
half_domain = halve_domain(domain)
f0 = inv_fft(vals[::2], half_domain)
f1 = inv_fft(vals[1::2], half_domain)
f0 = fft(vals[::2], half_domain)
f1 = fft(vals[1::2], half_domain)
if isinstance(domain[0], tuple):
left = [L+y*R for L,R,(x,y) in zip(f0, f1, domain)]
right = [L-y*R for L,R,(x,y) in zip(f0, f1, domain)]
Expand Down
6 changes: 3 additions & 3 deletions circlestark/fri.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@

def extend_trace(field, trace):
small_domain = get_initial_domain_of_size(field, len(trace))
coeffs = fft.fft(trace, small_domain)
coeffs = fft.inv_fft(trace, small_domain)
big_domain = get_initial_domain_of_size(field, len(trace)*2)
return fft.inv_fft(trace, big_domain)
return fft.fft(trace, big_domain)

def line_function(P1, P2, domain):
x1, y1 = P1
Expand Down Expand Up @@ -72,7 +72,7 @@ def get_challenges(root, domain_size, num_challenges):

def is_rbo_low_degree(evaluations, domain):
halflen = len(evaluations)//2
return fft(
return inv_fft(
undo_folded_reverse_bit_order(evaluations),
undo_folded_reverse_bit_order(domain)
)[halflen:] == [0] * halflen
Expand Down
28 changes: 14 additions & 14 deletions circlestark/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,17 @@ def test_basic_arithmetic():

def test_fft():
INPUT_SIZE = 512
data = [M(3**i) for i in range(INPUT_SIZE)]
coeffs = fft(data)
data2 = inv_fft(coeffs)
assert data2 == data
coeffs1 = [B(3**i) for i in range(INPUT_SIZE)] + [B(0)] * INPUT_SIZE
evaluations = fft(coeffs1)
coeffs2 = inv_fft(evaluations)
assert coeffs2 == coeffs1
print("Basic FFT test passed")

def test_fri():
print("Testing FRI")
INPUT_SIZE = 4096
coeffs = [B(3**i) for i in range(INPUT_SIZE)] + [B(0)] * INPUT_SIZE
evaluations = inv_fft(coeffs)
evaluations = fft(coeffs)
global fri_proof
fri_proof = prove_low_degree([EB(v) for v in evaluations])
length = (
Expand All @@ -76,19 +76,19 @@ def test_fri():
def test_fast_fft():
print("Testing fast FFT")
INPUT_SIZE = 2**13
data = [pow(3, i, 2**31-1) for i in range(INPUT_SIZE)]
npdata = M31(data)
coeffs = [pow(3, i, modulus) for i in range(INPUT_SIZE)] + [0] * INPUT_SIZE
npcoeffs = M31(coeffs)
t0 = time.time()
coeffs1 = fft([B(x) for x in data])
evaluations1 = fft([B(x) for x in coeffs])
t1 = time.time()
print("Computed size-{} slow fft in {} sec".format(INPUT_SIZE, t1 - t0))
t1 = time.time()
coeffs2 = f_fft(npdata)
print(coeffs2)
evaluations2 = f_fft(npcoeffs)
print(evaluations2)
t2 = time.time()
print("Computed size-{} fast fft in {} sec".format(INPUT_SIZE, t2 - t1))
assert [int(x) for x in coeffs2] == coeffs1
assert f_inv_fft(coeffs2) == npdata
assert [int(x) for x in evaluations2] == evaluations1
assert f_inv_fft(evaluations2) == npcoeffs
print("Fast FFT checks passed")

def test_fast_fri():
Expand All @@ -97,7 +97,7 @@ def test_fast_fri():
coeffs = M31(
[pow(3, i, modulus) for i in range(INPUT_SIZE)] + [0] * INPUT_SIZE,
)
evaluations = f_inv_fft(coeffs)
evaluations = f_fft(coeffs)
print('ev', evaluations, evaluations.__class__)
proof = f_prove_low_degree(evaluations.to_extended())
assert f_verify_low_degree(proof)
Expand All @@ -122,7 +122,7 @@ def test_mega_fri():
M31.zeros(INPUT_SIZE)
)
t1 = time.time()
evaluations = f_inv_fft(coeffs)
evaluations = f_fft(coeffs)
print("Low-degree extended coeffs in time {}".format(time.time() - t1))
t2 = time.time()
proof = f_prove_low_degree(evaluations.to_extended())
Expand Down