Skip to content

Commit

Permalink
add hand-generated wavelet kernel.
Browse files Browse the repository at this point in the history
  • Loading branch information
ooreilly committed Dec 28, 2020
1 parent 10898fc commit 8e48484
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 136 deletions.
10 changes: 10 additions & 0 deletions compute_ans.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
iz = 0
0.00 1.00 2.00 3.00 4.00 5.00 6.00 7.00 8.00 9.00 10.00 11.00 12.00 13.00 14.00 15.00 16.00 17.00 18.00 19.00 20.00 21.00 22.00 23.00 24.00 25.00 26.00 27.00 28.00 29.00 30.00 31.00
iz = 0
0.47 2.93 5.66 8.49 11.31 14.14 16.97 19.80 22.63 25.46 28.28 31.11 33.94 36.77 39.52 42.52 0.18 -0.00 -0.00 0.00 -0.00 -0.00 -0.00 0.00 -0.00 0.00 -0.00 -0.00 -0.00 -0.00 -0.13 0.61
iz = 0
1.82 8.28 16.02 24.00 32.00 40.00 47.79 56.23 0.38 0.03 -0.00 0.00 0.00 -0.00 -0.34 1.86 0.18 -0.00 -0.00 0.00 -0.00 -0.00 -0.00 0.00 -0.00 0.00 -0.00 -0.00 -0.00 -0.00 -0.13 0.61
iz = 0
5.53 23.36 44.72 68.53 0.86 0.09 -0.97 5.25 0.38 0.03 -0.00 0.00 0.00 -0.00 -0.34 1.86 0.18 -0.00 -0.00 0.00 -0.00 -0.00 -0.00 0.00 -0.00 0.00 -0.00 -0.00 -0.00 -0.00 -0.13 0.61
iz = 0
12.57 68.04 -0.56 15.45 0.86 0.09 -0.97 5.25 0.38 0.03 -0.00 0.00 0.00 -0.00 -0.34 1.86 0.18 -0.00 -0.00 0.00 -0.00 -0.00 -0.00 0.00 -0.00 0.00 -0.00 -0.00 -0.00 -0.00 -0.13 0.61
67 changes: 67 additions & 0 deletions generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import numpy as np

def declare(n, decl="float"):
    """Return ``n`` C statements, one per line, that read ``p_in`` into locals.

    Statement ``k`` has the form ``<decl> p_<k> = p_in[stride * <k>];``.
    ``decl`` is inserted verbatim in front of the variable name and is always
    followed by a single space, so an empty ``decl`` yields a line that starts
    with a space.  No trailing newline is appended.
    """
    template = "%s p_%d = p_in[stride * %d];"
    return "\n".join(template % (decl, k, k) for k in range(n))

def header():
    """Return the first line of the generated kernel: its signature and ``{``."""
    signature = (
        "inline __device__ void opt5ds79_compute2("
        "float * __restrict__ p_in, const int stride)"
    )
    return signature + " {"

def mirror(x, n):
    """Reflect each index in ``x`` back towards the range ``[0, n)``.

    Every index goes through the same two passes: take the absolute value,
    then replace any value that is still ``>= n`` by ``2 * n - 2 - value``.
    Returns a new list with one entry per element of ``x``.  Indices far
    enough outside the range can still come out negative after both passes.
    """
    far_edge = 2 * n - 2

    def fold(value):
        # Reflect about the last valid index when the value lies beyond it.
        return far_edge - value if value >= n else value

    return [fold(abs(fold(abs(index)))) for index in x]

def compute(i, n):
    """Return CUDA source computing output pair ``i`` of a length-``n`` level.

    The returned text holds two brace-delimited blocks.  The first combines
    nine ``p_<k>`` locals centred on sample ``2 * i`` with the coefficients
    ``al0`` .. ``al4`` and stores the sum at element ``i``.  The second
    combines seven locals centred on sample ``2 * i + 1`` with ``ah0`` ..
    ``ah3`` and stores the sum at element ``n // 2 + i``.  Sample indices
    outside ``[0, n)`` are mapped back into range with ``mirror``.  The
    coefficient names are only referenced, not defined, by the emitted text.

    Both stores index ``p_in`` in units of ``stride``, matching the strided
    loads emitted by ``declare``.  Previously the first store ignored
    ``stride`` and was only correct for ``stride == 1``.

    NOTE(review): the emitted blocks read the ``p_<k>`` locals rather than
    ``p_in``, so levels after the first see whatever those locals held when
    they were declared -- confirm that this is intended.
    """
    offsets = range(-4, 5)
    # Low-pass taps are centred on the even sample, high-pass taps on the odd
    # sample that follows it.
    p_m4, p_m3, p_m2, p_m1, p_00, p_p1, p_p2, p_p3, p_p4 = mirror(
        [2 * i + k for k in offsets], n)
    # The high-pass filter has seven taps, so the outermost two are unused.
    _, q_m3, q_m2, q_m1, q_00, q_p1, q_p2, q_p3, _ = mirror(
        [2 * i + 1 + k for k in offsets], n)
    s = """
{
float acc1 = al4 * (p_%d + p_%d);
acc1 += al1 * (p_%d + p_%d);
acc1 += al0 * p_%d;
float acc2 = al3 * (p_%d + p_%d);
acc2 += al2 * (p_%d + p_%d);
p_in[%d*stride] = acc1 + acc2;
}
// High
{
const int nl = %d;
float acc1 = ah3 * (p_%d + p_%d);
acc1 += ah0 * p_%d;
float acc2 = ah2 * (p_%d + p_%d);
acc2 += ah1 * (p_%d + p_%d);
p_in[(nl+%d)*stride] = acc1 + acc2;
}
""" % (p_m4, p_p4, p_m1, p_p1, p_00, p_m3, p_p3, p_m2, p_p2, i, n // 2,
       q_m3, q_p3, q_00, q_m2, q_p2, q_m1, q_p1, i)
    return s


# Write the generated kernel source to stdout, starting with its signature.
print(header())

# Emit one `float p_<k> = p_in[stride * <k>];` line for each of 32 samples.
n = 32
print(declare(n))

# Emit one low/high block pair per output index, halving the level length
# each pass until it drops below 4.  While debugging, printf/print_array
# calls can be emitted here to dump p_in after each level.
while n >= 4:
    half = int(n) // 2
    for i in range(half):
        print(compute(i, n))
    n = half

# Emit the 32 assignment lines again without a type, then close the function.
print(declare(32, decl=""), "}", sep="\n")
Loading

0 comments on commit 8e48484

Please sign in to comment.