|
|
|
@@ -122,6 +122,9 @@ class WaveSim:
|
|
|
|
|
self.lst_eat_valid = False |
|
|
|
|
|
|
|
|
|
self.cdata = np.zeros((len(self.interface), sims, 7), dtype='float32') |
|
|
|
|
|
|
|
|
|
self.sdata = np.zeros((sims, 4), dtype='float32') |
|
|
|
|
self.sdata[...,0] = 1.0 |
|
|
|
|
|
|
|
|
|
if isinstance(wavecaps, int): |
|
|
|
|
wavecaps = [wavecaps] * len(circuit.lines) |
|
|
|
@@ -328,7 +331,7 @@ class WaveSim:
|
|
|
|
|
sims = min(sims or self.sims, self.sims) |
|
|
|
|
for op_start, op_stop in zip(self.level_starts, self.level_stops): |
|
|
|
|
self.overflows += level_eval(self.ops, op_start, op_stop, self.state, self.sat, 0, sims, |
|
|
|
|
self.timing, sd, seed) |
|
|
|
|
self.timing, self.sdata, sd, seed) |
|
|
|
|
self.lst_eat_valid = False |
|
|
|
|
|
|
|
|
|
def wave(self, line, vector): |
|
|
|
@@ -521,12 +524,12 @@ class WaveSim:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@numba.njit
def level_eval(ops, op_start, op_stop, state, sat, st_start, st_stop, line_times, sdata, sd, seed):
    """Evaluate a range of operations for a range of simulations.

    Calls ``wave_eval`` once for every pair of operation index in
    ``[op_start, op_stop)`` and simulation index in ``[st_start, st_stop)``
    and sums up whatever it returns.

    Args:
        ops: Indexable collection of operations; ``ops[op_idx]`` is passed
            on to ``wave_eval`` unchanged.
        op_start: First operation index to evaluate (inclusive).
        op_stop: Operation index at which to stop (exclusive).
        state: Passed through to ``wave_eval``.
        sat: Passed through to ``wave_eval``.
        st_start: First simulation index to evaluate (inclusive).
        st_stop: Simulation index at which to stop (exclusive).
        line_times: Passed through to ``wave_eval``.
        sdata: Indexed by simulation index; ``sdata[st_idx]`` is handed to
            ``wave_eval`` for simulation ``st_idx``.
        sd: Passed through to ``wave_eval``.
        seed: Passed through to ``wave_eval``.

    Returns:
        The sum of the values returned by all ``wave_eval`` calls.
    """
    overflows = 0
    for op_idx in range(op_start, op_stop):
        op = ops[op_idx]
        for st_idx in range(st_start, st_stop):
            # Each simulation only sees its own row of sdata.
            overflows += wave_eval(op, state, sat, st_idx, line_times, sdata[st_idx], sd, seed)
    return overflows
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -547,7 +550,7 @@ def rand_gauss(seed, sd):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@numba.njit |
|
|
|
|
def wave_eval(op, state, sat, st_idx, line_times, sd=0.0, seed=0): |
|
|
|
|
def wave_eval(op, state, sat, st_idx, line_times, sdata, sd=0.0, seed=0): |
|
|
|
|
lut, z_idx, a_idx, b_idx = op |
|
|
|
|
overflows = int(0) |
|
|
|
|
|
|
|
|
@@ -563,9 +566,11 @@ def wave_eval(op, state, sat, st_idx, line_times, sd=0.0, seed=0):
|
|
|
|
|
if z_cur == 1: |
|
|
|
|
state[z_mem, st_idx] = TMIN |
|
|
|
|
|
|
|
|
|
a = state[a_mem, st_idx] + line_times[a_idx, 0, z_cur] * rand_gauss(_seed ^ a_mem ^ z_cur, sd) |
|
|
|
|
b = state[b_mem, st_idx] + line_times[b_idx, 0, z_cur] * rand_gauss(_seed ^ b_mem ^ z_cur, sd) |
|
|
|
|
|
|
|
|
|
a = state[a_mem, st_idx] + line_times[a_idx, 0, z_cur] * rand_gauss(_seed ^ a_mem ^ z_cur, sd) * sdata[0] |
|
|
|
|
if int(sdata[1]) == a_idx: a += sdata[2+z_cur] |
|
|
|
|
b = state[b_mem, st_idx] + line_times[b_idx, 0, z_cur] * rand_gauss(_seed ^ b_mem ^ z_cur, sd) * sdata[0] |
|
|
|
|
if int(sdata[1]) == b_idx: b += sdata[2+z_cur] |
|
|
|
|
|
|
|
|
|
previous_t = TMIN |
|
|
|
|
|
|
|
|
|
current_t = min(a, b) |
|
|
|
@@ -576,15 +581,21 @@ def wave_eval(op, state, sat, st_idx, line_times, sd=0.0, seed=0):
|
|
|
|
|
if b < a: |
|
|
|
|
b_cur += 1 |
|
|
|
|
b = state[b_mem + b_cur, st_idx] |
|
|
|
|
b += line_times[b_idx, 0, z_val ^ 1] * rand_gauss(_seed ^ b_mem ^ z_val ^ 1, sd) |
|
|
|
|
thresh = line_times[b_idx, 1, z_val] * rand_gauss(_seed ^ b_mem ^ z_val, sd) |
|
|
|
|
b += line_times[b_idx, 0, z_val ^ 1] * rand_gauss(_seed ^ b_mem ^ z_val ^ 1, sd) * sdata[0] |
|
|
|
|
thresh = line_times[b_idx, 1, z_val] * rand_gauss(_seed ^ b_mem ^ z_val, sd) * sdata[0] |
|
|
|
|
if int(sdata[1]) == b_idx: |
|
|
|
|
b += sdata[2+(z_val^1)] |
|
|
|
|
thresh += sdata[2+z_val] |
|
|
|
|
inputs ^= 2 |
|
|
|
|
next_t = b |
|
|
|
|
else: |
|
|
|
|
a_cur += 1 |
|
|
|
|
a = state[a_mem + a_cur, st_idx] |
|
|
|
|
a += line_times[a_idx, 0, z_val ^ 1] * rand_gauss(_seed ^ a_mem ^ z_val ^ 1, sd) |
|
|
|
|
thresh = line_times[a_idx, 1, z_val] * rand_gauss(_seed ^ a_mem ^ z_val, sd) |
|
|
|
|
a += line_times[a_idx, 0, z_val ^ 1] * rand_gauss(_seed ^ a_mem ^ z_val ^ 1, sd) * sdata[0] |
|
|
|
|
thresh = line_times[a_idx, 1, z_val] * rand_gauss(_seed ^ a_mem ^ z_val, sd) * sdata[0] |
|
|
|
|
if int(sdata[1]) == a_idx: |
|
|
|
|
a += sdata[2+(z_val^1)] |
|
|
|
|
thresh += sdata[2+z_val] |
|
|
|
|
inputs ^= 1 |
|
|
|
|
next_t = a |
|
|
|
|
|
|
|
|
@@ -636,6 +647,7 @@ class WaveSimCuda(WaveSim):
|
|
|
|
|
self.d_timing = cuda.to_device(self.timing) |
|
|
|
|
self.d_tdata = cuda.to_device(self.tdata) |
|
|
|
|
self.d_cdata = cuda.to_device(self.cdata) |
|
|
|
|
self.d_sdata = cuda.to_device(self.sdata) |
|
|
|
|
|
|
|
|
|
self._block_dim = (32, 16) |
|
|
|
|
|
|
|
|
@@ -650,6 +662,9 @@ class WaveSimCuda(WaveSim):
|
|
|
|
|
|
|
|
|
|
def set_line_delay(self, line, polarity, delay):
    """Store ``delay`` at ``self.d_timing[line, 0, polarity]``.

    Args:
        line: First index into ``self.d_timing``.
        polarity: Third index into ``self.d_timing``.
        delay: Value written to the selected entry.
    """
    # Only the d_timing buffer is written here; nothing else is updated.
    self.d_timing[line, 0, polarity] = delay
|
|
|
|
|
|
|
|
|
def sdata_to_device(self):
    """Copy ``self.sdata`` into the existing device buffer ``self.d_sdata``.

    Call this after modifying ``self.sdata`` so the device copy matches.
    """
    # Passing ``to=`` makes the copy target the existing d_sdata buffer
    # (presumably numba.cuda semantics; confirm against the file's imports).
    cuda.to_device(self.sdata, to=self.d_sdata)
|
|
|
|
|
|
|
|
|
def assign(self, vectors, time=0.0, offset=0): |
|
|
|
|
assert (offset % 8) == 0 |
|
|
|
@@ -676,7 +691,7 @@ class WaveSimCuda(WaveSim):
|
|
|
|
|
for op_start, op_stop in zip(self.level_starts, self.level_stops): |
|
|
|
|
grid_dim = self._grid_dim(sims, op_stop - op_start) |
|
|
|
|
wave_kernel[grid_dim, self._block_dim](self.d_ops, op_start, op_stop, self.d_state, self.sat, int(0), |
|
|
|
|
sims, self.d_timing, sd, seed) |
|
|
|
|
sims, self.d_timing, self.d_sdata, sd, seed) |
|
|
|
|
cuda.synchronize() |
|
|
|
|
self.lst_eat_valid = False |
|
|
|
|
|
|
|
|
@@ -858,7 +873,7 @@ def rand_gauss_dev(seed, sd):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@cuda.jit() |
|
|
|
|
def wave_kernel(ops, op_start, op_stop, state, sat, st_start, st_stop, line_times, sd, seed): |
|
|
|
|
def wave_kernel(ops, op_start, op_stop, state, sat, st_start, st_stop, line_times, sdata, sd, seed): |
|
|
|
|
x, y = cuda.grid(2) |
|
|
|
|
st_idx = st_start + x |
|
|
|
|
op_idx = op_start + y |
|
|
|
@@ -869,6 +884,7 @@ def wave_kernel(ops, op_start, op_stop, state, sat, st_start, st_stop, line_time
|
|
|
|
|
a_idx = ops[op_idx, 2] |
|
|
|
|
b_idx = ops[op_idx, 3] |
|
|
|
|
overflows = int(0) |
|
|
|
|
sdata = sdata[st_idx] |
|
|
|
|
|
|
|
|
|
_seed = (seed << 4) + (z_idx << 20) + (st_idx << 1) |
|
|
|
|
|
|
|
|
@@ -882,9 +898,11 @@ def wave_kernel(ops, op_start, op_stop, state, sat, st_start, st_stop, line_time
|
|
|
|
|
if z_cur == 1: |
|
|
|
|
state[z_mem, st_idx] = TMIN |
|
|
|
|
|
|
|
|
|
a = state[a_mem, st_idx] + line_times[a_idx, 0, z_cur] * rand_gauss_dev(_seed ^ a_mem ^ z_cur, sd) |
|
|
|
|
b = state[b_mem, st_idx] + line_times[b_idx, 0, z_cur] * rand_gauss_dev(_seed ^ b_mem ^ z_cur, sd) |
|
|
|
|
|
|
|
|
|
a = state[a_mem, st_idx] + line_times[a_idx, 0, z_cur] * rand_gauss_dev(_seed ^ a_mem ^ z_cur, sd) * sdata[0] |
|
|
|
|
if int(sdata[1]) == a_idx: a += sdata[2+z_cur] |
|
|
|
|
b = state[b_mem, st_idx] + line_times[b_idx, 0, z_cur] * rand_gauss_dev(_seed ^ b_mem ^ z_cur, sd) * sdata[0] |
|
|
|
|
if int(sdata[1]) == b_idx: b += sdata[2+z_cur] |
|
|
|
|
|
|
|
|
|
previous_t = TMIN |
|
|
|
|
|
|
|
|
|
current_t = min(a, b) |
|
|
|
@@ -895,15 +913,21 @@ def wave_kernel(ops, op_start, op_stop, state, sat, st_start, st_stop, line_time
|
|
|
|
|
if b < a: |
|
|
|
|
b_cur += 1 |
|
|
|
|
b = state[b_mem + b_cur, st_idx] |
|
|
|
|
b += line_times[b_idx, 0, z_val ^ 1] * rand_gauss_dev(_seed ^ b_mem ^ z_val ^ 1, sd) |
|
|
|
|
thresh = line_times[b_idx, 1, z_val] * rand_gauss_dev(_seed ^ b_mem ^ z_val, sd) |
|
|
|
|
b += line_times[b_idx, 0, z_val ^ 1] * rand_gauss_dev(_seed ^ b_mem ^ z_val ^ 1, sd) * sdata[0] |
|
|
|
|
thresh = line_times[b_idx, 1, z_val] * rand_gauss_dev(_seed ^ b_mem ^ z_val, sd) * sdata[0] |
|
|
|
|
if int(sdata[1]) == b_idx: |
|
|
|
|
b += sdata[2+(z_val^1)] |
|
|
|
|
thresh += sdata[2+z_val] |
|
|
|
|
inputs ^= 2 |
|
|
|
|
next_t = b |
|
|
|
|
else: |
|
|
|
|
a_cur += 1 |
|
|
|
|
a = state[a_mem + a_cur, st_idx] |
|
|
|
|
a += line_times[a_idx, 0, z_val ^ 1] * rand_gauss_dev(_seed ^ a_mem ^ z_val ^ 1, sd) |
|
|
|
|
thresh = line_times[a_idx, 1, z_val] * rand_gauss_dev(_seed ^ a_mem ^ z_val, sd) |
|
|
|
|
a += line_times[a_idx, 0, z_val ^ 1] * rand_gauss_dev(_seed ^ a_mem ^ z_val ^ 1, sd) * sdata[0] |
|
|
|
|
thresh = line_times[a_idx, 1, z_val] * rand_gauss_dev(_seed ^ a_mem ^ z_val, sd) * sdata[0] |
|
|
|
|
if int(sdata[1]) == a_idx: |
|
|
|
|
a += sdata[2+(z_val^1)] |
|
|
|
|
thresh += sdata[2+z_val] |
|
|
|
|
inputs ^= 1 |
|
|
|
|
next_t = a |
|
|
|
|
|
|
|
|
|