Browse Source

add sdata to control individual sims

devel
Stefan Holst 3 years ago
parent
commit
b981b1153c
  1. 60
      src/kyupy/wave_sim.py

60
src/kyupy/wave_sim.py

@ -123,6 +123,9 @@ class WaveSim:
self.cdata = np.zeros((len(self.interface), sims, 7), dtype='float32') self.cdata = np.zeros((len(self.interface), sims, 7), dtype='float32')
self.sdata = np.zeros((sims, 4), dtype='float32')
self.sdata[...,0] = 1.0
if isinstance(wavecaps, int): if isinstance(wavecaps, int):
wavecaps = [wavecaps] * len(circuit.lines) wavecaps = [wavecaps] * len(circuit.lines)
@ -328,7 +331,7 @@ class WaveSim:
sims = min(sims or self.sims, self.sims) sims = min(sims or self.sims, self.sims)
for op_start, op_stop in zip(self.level_starts, self.level_stops): for op_start, op_stop in zip(self.level_starts, self.level_stops):
self.overflows += level_eval(self.ops, op_start, op_stop, self.state, self.sat, 0, sims, self.overflows += level_eval(self.ops, op_start, op_stop, self.state, self.sat, 0, sims,
self.timing, sd, seed) self.timing, self.sdata, sd, seed)
self.lst_eat_valid = False self.lst_eat_valid = False
def wave(self, line, vector): def wave(self, line, vector):
@ -521,12 +524,12 @@ class WaveSim:
@numba.njit @numba.njit
def level_eval(ops, op_start, op_stop, state, sat, st_start, st_stop, line_times, sd, seed): def level_eval(ops, op_start, op_stop, state, sat, st_start, st_stop, line_times, sdata, sd, seed):
overflows = 0 overflows = 0
for op_idx in range(op_start, op_stop): for op_idx in range(op_start, op_stop):
op = ops[op_idx] op = ops[op_idx]
for st_idx in range(st_start, st_stop): for st_idx in range(st_start, st_stop):
overflows += wave_eval(op, state, sat, st_idx, line_times, sd, seed) overflows += wave_eval(op, state, sat, st_idx, line_times, sdata[st_idx], sd, seed)
return overflows return overflows
@ -547,7 +550,7 @@ def rand_gauss(seed, sd):
@numba.njit @numba.njit
def wave_eval(op, state, sat, st_idx, line_times, sd=0.0, seed=0): def wave_eval(op, state, sat, st_idx, line_times, sdata, sd=0.0, seed=0):
lut, z_idx, a_idx, b_idx = op lut, z_idx, a_idx, b_idx = op
overflows = int(0) overflows = int(0)
@ -563,8 +566,10 @@ def wave_eval(op, state, sat, st_idx, line_times, sd=0.0, seed=0):
if z_cur == 1: if z_cur == 1:
state[z_mem, st_idx] = TMIN state[z_mem, st_idx] = TMIN
a = state[a_mem, st_idx] + line_times[a_idx, 0, z_cur] * rand_gauss(_seed ^ a_mem ^ z_cur, sd) a = state[a_mem, st_idx] + line_times[a_idx, 0, z_cur] * rand_gauss(_seed ^ a_mem ^ z_cur, sd) * sdata[0]
b = state[b_mem, st_idx] + line_times[b_idx, 0, z_cur] * rand_gauss(_seed ^ b_mem ^ z_cur, sd) if int(sdata[1]) == a_idx: a += sdata[2+z_cur]
b = state[b_mem, st_idx] + line_times[b_idx, 0, z_cur] * rand_gauss(_seed ^ b_mem ^ z_cur, sd) * sdata[0]
if int(sdata[1]) == b_idx: b += sdata[2+z_cur]
previous_t = TMIN previous_t = TMIN
@ -576,15 +581,21 @@ def wave_eval(op, state, sat, st_idx, line_times, sd=0.0, seed=0):
if b < a: if b < a:
b_cur += 1 b_cur += 1
b = state[b_mem + b_cur, st_idx] b = state[b_mem + b_cur, st_idx]
b += line_times[b_idx, 0, z_val ^ 1] * rand_gauss(_seed ^ b_mem ^ z_val ^ 1, sd) b += line_times[b_idx, 0, z_val ^ 1] * rand_gauss(_seed ^ b_mem ^ z_val ^ 1, sd) * sdata[0]
thresh = line_times[b_idx, 1, z_val] * rand_gauss(_seed ^ b_mem ^ z_val, sd) thresh = line_times[b_idx, 1, z_val] * rand_gauss(_seed ^ b_mem ^ z_val, sd) * sdata[0]
if int(sdata[1]) == b_idx:
b += sdata[2+(z_val^1)]
thresh += sdata[2+z_val]
inputs ^= 2 inputs ^= 2
next_t = b next_t = b
else: else:
a_cur += 1 a_cur += 1
a = state[a_mem + a_cur, st_idx] a = state[a_mem + a_cur, st_idx]
a += line_times[a_idx, 0, z_val ^ 1] * rand_gauss(_seed ^ a_mem ^ z_val ^ 1, sd) a += line_times[a_idx, 0, z_val ^ 1] * rand_gauss(_seed ^ a_mem ^ z_val ^ 1, sd) * sdata[0]
thresh = line_times[a_idx, 1, z_val] * rand_gauss(_seed ^ a_mem ^ z_val, sd) thresh = line_times[a_idx, 1, z_val] * rand_gauss(_seed ^ a_mem ^ z_val, sd) * sdata[0]
if int(sdata[1]) == a_idx:
a += sdata[2+(z_val^1)]
thresh += sdata[2+z_val]
inputs ^= 1 inputs ^= 1
next_t = a next_t = a
@ -636,6 +647,7 @@ class WaveSimCuda(WaveSim):
self.d_timing = cuda.to_device(self.timing) self.d_timing = cuda.to_device(self.timing)
self.d_tdata = cuda.to_device(self.tdata) self.d_tdata = cuda.to_device(self.tdata)
self.d_cdata = cuda.to_device(self.cdata) self.d_cdata = cuda.to_device(self.cdata)
self.d_sdata = cuda.to_device(self.sdata)
self._block_dim = (32, 16) self._block_dim = (32, 16)
@ -651,6 +663,9 @@ class WaveSimCuda(WaveSim):
def set_line_delay(self, line, polarity, delay): def set_line_delay(self, line, polarity, delay):
self.d_timing[line, 0, polarity] = delay self.d_timing[line, 0, polarity] = delay
def sdata_to_device(self):
cuda.to_device(self.sdata, to=self.d_sdata)
def assign(self, vectors, time=0.0, offset=0): def assign(self, vectors, time=0.0, offset=0):
assert (offset % 8) == 0 assert (offset % 8) == 0
byte_offset = offset // 8 byte_offset = offset // 8
@ -676,7 +691,7 @@ class WaveSimCuda(WaveSim):
for op_start, op_stop in zip(self.level_starts, self.level_stops): for op_start, op_stop in zip(self.level_starts, self.level_stops):
grid_dim = self._grid_dim(sims, op_stop - op_start) grid_dim = self._grid_dim(sims, op_stop - op_start)
wave_kernel[grid_dim, self._block_dim](self.d_ops, op_start, op_stop, self.d_state, self.sat, int(0), wave_kernel[grid_dim, self._block_dim](self.d_ops, op_start, op_stop, self.d_state, self.sat, int(0),
sims, self.d_timing, sd, seed) sims, self.d_timing, self.d_sdata, sd, seed)
cuda.synchronize() cuda.synchronize()
self.lst_eat_valid = False self.lst_eat_valid = False
@ -858,7 +873,7 @@ def rand_gauss_dev(seed, sd):
@cuda.jit() @cuda.jit()
def wave_kernel(ops, op_start, op_stop, state, sat, st_start, st_stop, line_times, sd, seed): def wave_kernel(ops, op_start, op_stop, state, sat, st_start, st_stop, line_times, sdata, sd, seed):
x, y = cuda.grid(2) x, y = cuda.grid(2)
st_idx = st_start + x st_idx = st_start + x
op_idx = op_start + y op_idx = op_start + y
@ -869,6 +884,7 @@ def wave_kernel(ops, op_start, op_stop, state, sat, st_start, st_stop, line_time
a_idx = ops[op_idx, 2] a_idx = ops[op_idx, 2]
b_idx = ops[op_idx, 3] b_idx = ops[op_idx, 3]
overflows = int(0) overflows = int(0)
sdata = sdata[st_idx]
_seed = (seed << 4) + (z_idx << 20) + (st_idx << 1) _seed = (seed << 4) + (z_idx << 20) + (st_idx << 1)
@ -882,8 +898,10 @@ def wave_kernel(ops, op_start, op_stop, state, sat, st_start, st_stop, line_time
if z_cur == 1: if z_cur == 1:
state[z_mem, st_idx] = TMIN state[z_mem, st_idx] = TMIN
a = state[a_mem, st_idx] + line_times[a_idx, 0, z_cur] * rand_gauss_dev(_seed ^ a_mem ^ z_cur, sd) a = state[a_mem, st_idx] + line_times[a_idx, 0, z_cur] * rand_gauss_dev(_seed ^ a_mem ^ z_cur, sd) * sdata[0]
b = state[b_mem, st_idx] + line_times[b_idx, 0, z_cur] * rand_gauss_dev(_seed ^ b_mem ^ z_cur, sd) if int(sdata[1]) == a_idx: a += sdata[2+z_cur]
b = state[b_mem, st_idx] + line_times[b_idx, 0, z_cur] * rand_gauss_dev(_seed ^ b_mem ^ z_cur, sd) * sdata[0]
if int(sdata[1]) == b_idx: b += sdata[2+z_cur]
previous_t = TMIN previous_t = TMIN
@ -895,15 +913,21 @@ def wave_kernel(ops, op_start, op_stop, state, sat, st_start, st_stop, line_time
if b < a: if b < a:
b_cur += 1 b_cur += 1
b = state[b_mem + b_cur, st_idx] b = state[b_mem + b_cur, st_idx]
b += line_times[b_idx, 0, z_val ^ 1] * rand_gauss_dev(_seed ^ b_mem ^ z_val ^ 1, sd) b += line_times[b_idx, 0, z_val ^ 1] * rand_gauss_dev(_seed ^ b_mem ^ z_val ^ 1, sd) * sdata[0]
thresh = line_times[b_idx, 1, z_val] * rand_gauss_dev(_seed ^ b_mem ^ z_val, sd) thresh = line_times[b_idx, 1, z_val] * rand_gauss_dev(_seed ^ b_mem ^ z_val, sd) * sdata[0]
if int(sdata[1]) == b_idx:
b += sdata[2+(z_val^1)]
thresh += sdata[2+z_val]
inputs ^= 2 inputs ^= 2
next_t = b next_t = b
else: else:
a_cur += 1 a_cur += 1
a = state[a_mem + a_cur, st_idx] a = state[a_mem + a_cur, st_idx]
a += line_times[a_idx, 0, z_val ^ 1] * rand_gauss_dev(_seed ^ a_mem ^ z_val ^ 1, sd) a += line_times[a_idx, 0, z_val ^ 1] * rand_gauss_dev(_seed ^ a_mem ^ z_val ^ 1, sd) * sdata[0]
thresh = line_times[a_idx, 1, z_val] * rand_gauss_dev(_seed ^ a_mem ^ z_val, sd) thresh = line_times[a_idx, 1, z_val] * rand_gauss_dev(_seed ^ a_mem ^ z_val, sd) * sdata[0]
if int(sdata[1]) == a_idx:
a += sdata[2+(z_val^1)]
thresh += sdata[2+z_val]
inputs ^= 1 inputs ^= 1
next_t = a next_t = a

Loading…
Cancel
Save