From b981b1153c8860cbac9b86c468661661ce255ca5 Mon Sep 17 00:00:00 2001 From: Stefan Holst Date: Thu, 17 Jun 2021 12:05:55 +0900 Subject: [PATCH] add sdata to control individual sims --- src/kyupy/wave_sim.py | 64 +++++++++++++++++++++++++++++-------------- 1 file changed, 44 insertions(+), 20 deletions(-) diff --git a/src/kyupy/wave_sim.py b/src/kyupy/wave_sim.py index bd04f10..3412d5c 100644 --- a/src/kyupy/wave_sim.py +++ b/src/kyupy/wave_sim.py @@ -122,6 +122,9 @@ class WaveSim: self.lst_eat_valid = False self.cdata = np.zeros((len(self.interface), sims, 7), dtype='float32') + + self.sdata = np.zeros((sims, 4), dtype='float32') + self.sdata[...,0] = 1.0 if isinstance(wavecaps, int): wavecaps = [wavecaps] * len(circuit.lines) @@ -328,7 +331,7 @@ class WaveSim: sims = min(sims or self.sims, self.sims) for op_start, op_stop in zip(self.level_starts, self.level_stops): self.overflows += level_eval(self.ops, op_start, op_stop, self.state, self.sat, 0, sims, - self.timing, sd, seed) + self.timing, self.sdata, sd, seed) self.lst_eat_valid = False def wave(self, line, vector): @@ -521,12 +524,12 @@ class WaveSim: @numba.njit -def level_eval(ops, op_start, op_stop, state, sat, st_start, st_stop, line_times, sd, seed): +def level_eval(ops, op_start, op_stop, state, sat, st_start, st_stop, line_times, sdata, sd, seed): overflows = 0 for op_idx in range(op_start, op_stop): op = ops[op_idx] for st_idx in range(st_start, st_stop): - overflows += wave_eval(op, state, sat, st_idx, line_times, sd, seed) + overflows += wave_eval(op, state, sat, st_idx, line_times, sdata[st_idx], sd, seed) return overflows @@ -547,7 +550,7 @@ def rand_gauss(seed, sd): @numba.njit -def wave_eval(op, state, sat, st_idx, line_times, sd=0.0, seed=0): +def wave_eval(op, state, sat, st_idx, line_times, sdata, sd=0.0, seed=0): lut, z_idx, a_idx, b_idx = op overflows = int(0) @@ -563,9 +566,11 @@ def wave_eval(op, state, sat, st_idx, line_times, sd=0.0, seed=0): if z_cur == 1: state[z_mem, st_idx] = TMIN - a = state[a_mem, st_idx] + line_times[a_idx, 0, z_cur] * rand_gauss(_seed ^ a_mem ^ z_cur, sd) - b = state[b_mem, st_idx] + line_times[b_idx, 0, z_cur] * rand_gauss(_seed ^ b_mem ^ z_cur, sd) - + a = state[a_mem, st_idx] + line_times[a_idx, 0, z_cur] * rand_gauss(_seed ^ a_mem ^ z_cur, sd) * sdata[0] + if int(sdata[1]) == a_idx: a += sdata[2+z_cur] + b = state[b_mem, st_idx] + line_times[b_idx, 0, z_cur] * rand_gauss(_seed ^ b_mem ^ z_cur, sd) * sdata[0] + if int(sdata[1]) == b_idx: b += sdata[2+z_cur] + previous_t = TMIN current_t = min(a, b) @@ -576,15 +581,21 @@ def wave_eval(op, state, sat, st_idx, line_times, sd=0.0, seed=0): if b < a: b_cur += 1 b = state[b_mem + b_cur, st_idx] - b += line_times[b_idx, 0, z_val ^ 1] * rand_gauss(_seed ^ b_mem ^ z_val ^ 1, sd) - thresh = line_times[b_idx, 1, z_val] * rand_gauss(_seed ^ b_mem ^ z_val, sd) + b += line_times[b_idx, 0, z_val ^ 1] * rand_gauss(_seed ^ b_mem ^ z_val ^ 1, sd) * sdata[0] + thresh = line_times[b_idx, 1, z_val] * rand_gauss(_seed ^ b_mem ^ z_val, sd) * sdata[0] + if int(sdata[1]) == b_idx: + b += sdata[2+(z_val^1)] + thresh += sdata[2+z_val] inputs ^= 2 next_t = b else: a_cur += 1 a = state[a_mem + a_cur, st_idx] - a += line_times[a_idx, 0, z_val ^ 1] * rand_gauss(_seed ^ a_mem ^ z_val ^ 1, sd) - thresh = line_times[a_idx, 1, z_val] * rand_gauss(_seed ^ a_mem ^ z_val, sd) + a += line_times[a_idx, 0, z_val ^ 1] * rand_gauss(_seed ^ a_mem ^ z_val ^ 1, sd) * sdata[0] + thresh = line_times[a_idx, 1, z_val] * rand_gauss(_seed ^ a_mem ^ z_val, sd) * sdata[0] + if int(sdata[1]) == a_idx: + a += sdata[2+(z_val^1)] + thresh += sdata[2+z_val] inputs ^= 1 next_t = a @@ -636,6 +647,7 @@ class WaveSimCuda(WaveSim): self.d_timing = cuda.to_device(self.timing) self.d_tdata = cuda.to_device(self.tdata) self.d_cdata = cuda.to_device(self.cdata) + self.d_sdata = cuda.to_device(self.sdata) self._block_dim = (32, 16) @@ -650,6 +662,9 @@ class WaveSimCuda(WaveSim): def set_line_delay(self, line, polarity, delay): self.d_timing[line, 0, polarity] = delay + + def sdata_to_device(self): + cuda.to_device(self.sdata, to=self.d_sdata) def assign(self, vectors, time=0.0, offset=0): assert (offset % 8) == 0 @@ -676,7 +691,7 @@ class WaveSimCuda(WaveSim): for op_start, op_stop in zip(self.level_starts, self.level_stops): grid_dim = self._grid_dim(sims, op_stop - op_start) wave_kernel[grid_dim, self._block_dim](self.d_ops, op_start, op_stop, self.d_state, self.sat, int(0), - sims, self.d_timing, sd, seed) + sims, self.d_timing, self.d_sdata, sd, seed) cuda.synchronize() self.lst_eat_valid = False @@ -858,7 +873,7 @@ def rand_gauss_dev(seed, sd): @cuda.jit() -def wave_kernel(ops, op_start, op_stop, state, sat, st_start, st_stop, line_times, sd, seed): +def wave_kernel(ops, op_start, op_stop, state, sat, st_start, st_stop, line_times, sdata, sd, seed): x, y = cuda.grid(2) st_idx = st_start + x op_idx = op_start + y @@ -869,6 +884,7 @@ def wave_kernel(ops, op_start, op_stop, state, sat, st_start, st_stop, line_time a_idx = ops[op_idx, 2] b_idx = ops[op_idx, 3] overflows = int(0) + sdata = sdata[st_idx] _seed = (seed << 4) + (z_idx << 20) + (st_idx << 1) @@ -882,9 +898,11 @@ def wave_kernel(ops, op_start, op_stop, state, sat, st_start, st_stop, line_time if z_cur == 1: state[z_mem, st_idx] = TMIN - a = state[a_mem, st_idx] + line_times[a_idx, 0, z_cur] * rand_gauss_dev(_seed ^ a_mem ^ z_cur, sd) - b = state[b_mem, st_idx] + line_times[b_idx, 0, z_cur] * rand_gauss_dev(_seed ^ b_mem ^ z_cur, sd) - + a = state[a_mem, st_idx] + line_times[a_idx, 0, z_cur] * rand_gauss_dev(_seed ^ a_mem ^ z_cur, sd) * sdata[0] + if int(sdata[1]) == a_idx: a += sdata[2+z_cur] + b = state[b_mem, st_idx] + line_times[b_idx, 0, z_cur] * rand_gauss_dev(_seed ^ b_mem ^ z_cur, sd) * sdata[0] + if int(sdata[1]) == b_idx: b += sdata[2+z_cur] + previous_t = TMIN current_t = min(a, b) @@ -895,15 +913,21 @@ def wave_kernel(ops, op_start, op_stop, state, sat, st_start, st_stop, line_time if b < a: b_cur += 1 b = state[b_mem + b_cur, st_idx] - b += line_times[b_idx, 0, z_val ^ 1] * rand_gauss_dev(_seed ^ b_mem ^ z_val ^ 1, sd) - thresh = line_times[b_idx, 1, z_val] * rand_gauss_dev(_seed ^ b_mem ^ z_val, sd) + b += line_times[b_idx, 0, z_val ^ 1] * rand_gauss_dev(_seed ^ b_mem ^ z_val ^ 1, sd) * sdata[0] + thresh = line_times[b_idx, 1, z_val] * rand_gauss_dev(_seed ^ b_mem ^ z_val, sd) * sdata[0] + if int(sdata[1]) == b_idx: + b += sdata[2+(z_val^1)] + thresh += sdata[2+z_val] inputs ^= 2 next_t = b else: a_cur += 1 a = state[a_mem + a_cur, st_idx] - a += line_times[a_idx, 0, z_val ^ 1] * rand_gauss_dev(_seed ^ a_mem ^ z_val ^ 1, sd) - thresh = line_times[a_idx, 1, z_val] * rand_gauss_dev(_seed ^ a_mem ^ z_val, sd) + a += line_times[a_idx, 0, z_val ^ 1] * rand_gauss_dev(_seed ^ a_mem ^ z_val ^ 1, sd) * sdata[0] + thresh = line_times[a_idx, 1, z_val] * rand_gauss_dev(_seed ^ a_mem ^ z_val, sd) * sdata[0] + if int(sdata[1]) == a_idx: + a += sdata[2+(z_val^1)] + thresh += sdata[2+z_val] inputs ^= 1 next_t = a