@ -122,6 +122,9 @@ class WaveSim:
				@@ -122,6 +122,9 @@ class WaveSim:
					 
			
		
	
		
			
				
					        self . lst_eat_valid  =  False  
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					        self . cdata  =  np . zeros ( ( len ( self . interface ) ,  sims ,  7 ) ,  dtype = ' float32 ' )  
			
		
	
		
			
				
					                      
			
		
	
		
			
				
					        self . sdata  =  np . zeros ( ( sims ,  4 ) ,  dtype = ' float32 ' )  
			
		
	
		
			
				
					        self . sdata [ . . . , 0 ]  =  1.0  
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					        if  isinstance ( wavecaps ,  int ) :  
			
		
	
		
			
				
					            wavecaps  =  [ wavecaps ]  *  len ( circuit . lines )  
			
		
	
	
		
			
				
					
						
							
								 
						
						
							
								 
						
						
					 
				
				@ -158,7 +161,8 @@ class WaveSim:
				@@ -158,7 +161,8 @@ class WaveSim:
					 
			
		
	
		
			
				
					                if  kind  ==  ' __fork__ ' :  
			
		
	
		
			
				
					                    if  not  strip_forks :  
			
		
	
		
			
				
					                        for  o_line  in  n . outs :  
			
		
	
		
			
				
					                            ops . append ( ( 0b1010 ,  o_line . index ,  i0_idx ,  i1_idx ) )  
			
		
	
		
			
				
					                            if  o_line  is  not  None :  
			
		
	
		
			
				
					                                ops . append ( ( 0b1010 ,  o_line . index ,  i0_idx ,  i1_idx ) )  
			
		
	
		
			
				
					                elif  kind . startswith ( ' nand ' ) :  
			
		
	
		
			
				
					                    ops . append ( ( 0b0111 ,  o0_idx ,  i0_idx ,  i1_idx ) )  
			
		
	
		
			
				
					                elif  kind . startswith ( ' nor ' ) :  
			
		
	
	
		
			
				
					
						
							
								 
						
						
							
								 
						
						
					 
				
				@ -328,7 +332,7 @@ class WaveSim:
				@@ -328,7 +332,7 @@ class WaveSim:
					 
			
		
	
		
			
				
					        sims  =  min ( sims  or  self . sims ,  self . sims )  
			
		
	
		
			
				
					        for  op_start ,  op_stop  in  zip ( self . level_starts ,  self . level_stops ) :  
			
		
	
		
			
				
					            self . overflows  + =  level_eval ( self . ops ,  op_start ,  op_stop ,  self . state ,  self . sat ,  0 ,  sims ,  
			
		
	
		
			
				
					                                         self . timing ,  sd ,  seed )  
			
		
	
		
			
				
					                                         self . timing ,  self . sdata ,  sd ,  seed )  
			
		
	
		
			
				
					        self . lst_eat_valid  =  False  
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					    def  wave ( self ,  line ,  vector ) :  
			
		
	
	
		
			
				
					
						
							
								 
						
						
							
								 
						
						
					 
				
				@ -521,12 +525,12 @@ class WaveSim:
				@@ -521,12 +525,12 @@ class WaveSim:
					 
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					@numba . njit  
			
		
	
		
			
				
					def  level_eval ( ops ,  op_start ,  op_stop ,  state ,  sat ,  st_start ,  st_stop ,  line_times ,  sd ,  seed ) :  
			
		
	
		
			
				
					def  level_eval ( ops ,  op_start ,  op_stop ,  state ,  sat ,  st_start ,  st_stop ,  line_times ,  sdata ,  sd  ,  seed ) :  
			
		
	
		
			
				
					    overflows  =  0  
			
		
	
		
			
				
					    for  op_idx  in  range ( op_start ,  op_stop ) :  
			
		
	
		
			
				
					        op  =  ops [ op_idx ]  
			
		
	
		
			
				
					        for  st_idx  in  range ( st_start ,  st_stop ) :  
			
		
	
		
			
				
					            overflows  + =  wave_eval ( op ,  state ,  sat ,  st_idx ,  line_times ,  sd ,  seed )  
			
		
	
		
			
				
					            overflows  + =  wave_eval ( op ,  state ,  sat ,  st_idx ,  line_times ,  sdata [ st_idx ] ,  sd  ,  seed )  
			
		
	
		
			
				
					    return  overflows  
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					
 
			
		
	
	
		
			
				
					
						
						
						
							
								 
						
					 
				
				@ -547,7 +551,7 @@ def rand_gauss(seed, sd):
				@@ -547,7 +551,7 @@ def rand_gauss(seed, sd):
					 
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					@numba . njit  
			
		
	
		
			
				
					def  wave_eval ( op ,  state ,  sat ,  st_idx ,  line_times ,  sd = 0.0 ,  seed = 0 ) :  
			
		
	
		
			
				
					def  wave_eval ( op ,  state ,  sat ,  st_idx ,  line_times ,  sdata ,  sd  = 0.0 ,  seed = 0 ) :  
			
		
	
		
			
				
					    lut ,  z_idx ,  a_idx ,  b_idx  =  op  
			
		
	
		
			
				
					    overflows  =  int ( 0 )  
			
		
	
		
			
				
					
 
			
		
	
	
		
			
				
					
						
						
						
							
								 
						
					 
				
				@ -563,9 +567,11 @@ def wave_eval(op, state, sat, st_idx, line_times, sd=0.0, seed=0):
				@@ -563,9 +567,11 @@ def wave_eval(op, state, sat, st_idx, line_times, sd=0.0, seed=0):
					 
			
		
	
		
			
				
					    if  z_cur  ==  1 :  
			
		
	
		
			
				
					        state [ z_mem ,  st_idx ]  =  TMIN  
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					    a  =  state [ a_mem ,  st_idx ]  +  line_times [ a_idx ,  0 ,  z_cur ]  *  rand_gauss ( _seed  ^  a_mem  ^  z_cur ,  sd )  
			
		
	
		
			
				
					    b  =  state [ b_mem ,  st_idx ]  +  line_times [ b_idx ,  0 ,  z_cur ]  *  rand_gauss ( _seed  ^  b_mem  ^  z_cur ,  sd )  
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					    a  =  state [ a_mem ,  st_idx ]  +  line_times [ a_idx ,  0 ,  z_cur ]  *  rand_gauss ( _seed  ^  a_mem  ^  z_cur ,  sd )  *  sdata [ 0 ]  
			
		
	
		
			
				
					    if  int ( sdata [ 1 ] )  ==  a_idx :  a  + =  sdata [ 2 + z_cur ]  
			
		
	
		
			
				
					    b  =  state [ b_mem ,  st_idx ]  +  line_times [ b_idx ,  0 ,  z_cur ]  *  rand_gauss ( _seed  ^  b_mem  ^  z_cur ,  sd )  *  sdata [ 0 ]  
			
		
	
		
			
				
					    if  int ( sdata [ 1 ] )  ==  b_idx :  b  + =  sdata [ 2 + z_cur ]  
			
		
	
		
			
				
					     
			
		
	
		
			
				
					    previous_t  =  TMIN  
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					    current_t  =  min ( a ,  b )  
			
		
	
	
		
			
				
					
						
						
						
							
								 
						
					 
				
				@ -576,15 +582,21 @@ def wave_eval(op, state, sat, st_idx, line_times, sd=0.0, seed=0):
				@@ -576,15 +582,21 @@ def wave_eval(op, state, sat, st_idx, line_times, sd=0.0, seed=0):
					 
			
		
	
		
			
				
					        if  b  <  a :  
			
		
	
		
			
				
					            b_cur  + =  1  
			
		
	
		
			
				
					            b  =  state [ b_mem  +  b_cur ,  st_idx ]  
			
		
	
		
			
				
					            b  + =  line_times [ b_idx ,  0 ,  z_val  ^  1 ]  *  rand_gauss ( _seed  ^  b_mem  ^  z_val  ^  1 ,  sd )  
			
		
	
		
			
				
					            thresh  =  line_times [ b_idx ,  1 ,  z_val ]  *  rand_gauss ( _seed  ^  b_mem  ^  z_val ,  sd )  
			
		
	
		
			
				
					            b  + =  line_times [ b_idx ,  0 ,  z_val  ^  1 ]  *  rand_gauss ( _seed  ^  b_mem  ^  z_val  ^  1 ,  sd )  *  sdata [ 0 ]  
			
		
	
		
			
				
					            thresh  =  line_times [ b_idx ,  1 ,  z_val ]  *  rand_gauss ( _seed  ^  b_mem  ^  z_val ,  sd )  *  sdata [ 0 ]  
			
		
	
		
			
				
					            if  int ( sdata [ 1 ] )  ==  b_idx :  
			
		
	
		
			
				
					                b  + =  sdata [ 2 + ( z_val ^ 1 ) ]  
			
		
	
		
			
				
					                thresh  + =  sdata [ 2 + z_val ]  
			
		
	
		
			
				
					            inputs  ^ =  2  
			
		
	
		
			
				
					            next_t  =  b  
			
		
	
		
			
				
					        else :  
			
		
	
		
			
				
					            a_cur  + =  1  
			
		
	
		
			
				
					            a  =  state [ a_mem  +  a_cur ,  st_idx ]  
			
		
	
		
			
				
					            a  + =  line_times [ a_idx ,  0 ,  z_val  ^  1 ]  *  rand_gauss ( _seed  ^  a_mem  ^  z_val  ^  1 ,  sd )  
			
		
	
		
			
				
					            thresh  =  line_times [ a_idx ,  1 ,  z_val ]  *  rand_gauss ( _seed  ^  a_mem  ^  z_val ,  sd )  
			
		
	
		
			
				
					            a  + =  line_times [ a_idx ,  0 ,  z_val  ^  1 ]  *  rand_gauss ( _seed  ^  a_mem  ^  z_val  ^  1 ,  sd )  *  sdata [ 0 ]  
			
		
	
		
			
				
					            thresh  =  line_times [ a_idx ,  1 ,  z_val ]  *  rand_gauss ( _seed  ^  a_mem  ^  z_val ,  sd )  *  sdata [ 0 ]  
			
		
	
		
			
				
					            if  int ( sdata [ 1 ] )  ==  a_idx :  
			
		
	
		
			
				
					                a  + =  sdata [ 2 + ( z_val ^ 1 ) ]  
			
		
	
		
			
				
					                thresh  + =  sdata [ 2 + z_val ]  
			
		
	
		
			
				
					            inputs  ^ =  1  
			
		
	
		
			
				
					            next_t  =  a  
			
		
	
		
			
				
					
 
			
		
	
	
		
			
				
					
						
							
								 
						
						
							
								 
						
						
					 
				
				@ -618,6 +630,7 @@ def wave_eval(op, state, sat, st_idx, line_times, sd=0.0, seed=0):
				@@ -618,6 +630,7 @@ def wave_eval(op, state, sat, st_idx, line_times, sd=0.0, seed=0):
					 
			
		
	
		
			
				
					    return  overflows  
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					class  WaveSimCuda ( WaveSim ) :  
			
		
	
		
			
				
					    """ A GPU-accelerated waveform-based combinational logic timing simulator.  
			
		
	
		
			
				
					
 
			
		
	
	
		
			
				
					
						
						
						
							
								 
						
					 
				
				@ -636,6 +649,7 @@ class WaveSimCuda(WaveSim):
				@@ -636,6 +649,7 @@ class WaveSimCuda(WaveSim):
					 
			
		
	
		
			
				
					        self . d_timing  =  cuda . to_device ( self . timing )  
			
		
	
		
			
				
					        self . d_tdata  =  cuda . to_device ( self . tdata )  
			
		
	
		
			
				
					        self . d_cdata  =  cuda . to_device ( self . cdata )  
			
		
	
		
			
				
					        self . d_sdata  =  cuda . to_device ( self . sdata )  
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					        self . _block_dim  =  ( 32 ,  16 )  
			
		
	
		
			
				
					
 
			
		
	
	
		
			
				
					
						
						
						
							
								 
						
					 
				
				@ -650,6 +664,9 @@ class WaveSimCuda(WaveSim):
				@@ -650,6 +664,9 @@ class WaveSimCuda(WaveSim):
					 
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					    def  set_line_delay ( self ,  line ,  polarity ,  delay ) :  
			
		
	
		
			
				
					        self . d_timing [ line ,  0 ,  polarity ]  =  delay  
			
		
	
		
			
				
					                      
			
		
	
		
			
				
					    def  sdata_to_device ( self ) :  
			
		
	
		
			
				
					        cuda . to_device ( self . sdata ,  to = self . d_sdata )  
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					    def  assign ( self ,  vectors ,  time = 0.0 ,  offset = 0 ) :  
			
		
	
		
			
				
					        assert  ( offset  %  8 )  ==  0  
			
		
	
	
		
			
				
					
						
							
								 
						
						
							
								 
						
						
					 
				
				@ -676,7 +693,7 @@ class WaveSimCuda(WaveSim):
				@@ -676,7 +693,7 @@ class WaveSimCuda(WaveSim):
					 
			
		
	
		
			
				
					        for  op_start ,  op_stop  in  zip ( self . level_starts ,  self . level_stops ) :  
			
		
	
		
			
				
					            grid_dim  =  self . _grid_dim ( sims ,  op_stop  -  op_start )  
			
		
	
		
			
				
					            wave_kernel [ grid_dim ,  self . _block_dim ] ( self . d_ops ,  op_start ,  op_stop ,  self . d_state ,  self . sat ,  int ( 0 ) ,  
			
		
	
		
			
				
					                                                   sims ,  self . d_timing ,  sd ,  seed )  
			
		
	
		
			
				
					                                                   sims ,  self . d_timing ,  self . d_sdata ,  sd ,  seed )  
			
		
	
		
			
				
					        cuda . synchronize ( )  
			
		
	
		
			
				
					        self . lst_eat_valid  =  False  
			
		
	
		
			
				
					
 
			
		
	
	
		
			
				
					
						
							
								 
						
						
							
								 
						
						
					 
				
				@ -858,7 +875,7 @@ def rand_gauss_dev(seed, sd):
				@@ -858,7 +875,7 @@ def rand_gauss_dev(seed, sd):
					 
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					@cuda . jit ( )  
			
		
	
		
			
				
					def  wave_kernel ( ops ,  op_start ,  op_stop ,  state ,  sat ,  st_start ,  st_stop ,  line_times ,  sd ,  seed ) :  
			
		
	
		
			
				
					def  wave_kernel ( ops ,  op_start ,  op_stop ,  state ,  sat ,  st_start ,  st_stop ,  line_times ,  sdata ,  sd  ,  seed ) :  
			
		
	
		
			
				
					    x ,  y  =  cuda . grid ( 2 )  
			
		
	
		
			
				
					    st_idx  =  st_start  +  x  
			
		
	
		
			
				
					    op_idx  =  op_start  +  y  
			
		
	
	
		
			
				
					
						
						
						
							
								 
						
					 
				
				@ -869,6 +886,7 @@ def wave_kernel(ops, op_start, op_stop, state, sat, st_start, st_stop, line_time
				@@ -869,6 +886,7 @@ def wave_kernel(ops, op_start, op_stop, state, sat, st_start, st_stop, line_time
					 
			
		
	
		
			
				
					    a_idx  =  ops [ op_idx ,  2 ]  
			
		
	
		
			
				
					    b_idx  =  ops [ op_idx ,  3 ]  
			
		
	
		
			
				
					    overflows  =  int ( 0 )  
			
		
	
		
			
				
					    sdata  =  sdata [ st_idx ]  
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					    _seed  =  ( seed  <<  4 )  +  ( z_idx  <<  20 )  +  ( st_idx  <<  1 )  
			
		
	
		
			
				
					
 
			
		
	
	
		
			
				
					
						
						
						
							
								 
						
					 
				
				@ -882,9 +900,11 @@ def wave_kernel(ops, op_start, op_stop, state, sat, st_start, st_stop, line_time
				@@ -882,9 +900,11 @@ def wave_kernel(ops, op_start, op_stop, state, sat, st_start, st_stop, line_time
					 
			
		
	
		
			
				
					    if  z_cur  ==  1 :  
			
		
	
		
			
				
					        state [ z_mem ,  st_idx ]  =  TMIN  
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					    a  =  state [ a_mem ,  st_idx ]  +  line_times [ a_idx ,  0 ,  z_cur ]  *  rand_gauss_dev ( _seed  ^  a_mem  ^  z_cur ,  sd )  
			
		
	
		
			
				
					    b  =  state [ b_mem ,  st_idx ]  +  line_times [ b_idx ,  0 ,  z_cur ]  *  rand_gauss_dev ( _seed  ^  b_mem  ^  z_cur ,  sd )  
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					    a  =  state [ a_mem ,  st_idx ]  +  line_times [ a_idx ,  0 ,  z_cur ]  *  rand_gauss_dev ( _seed  ^  a_mem  ^  z_cur ,  sd )  *  sdata [ 0 ]  
			
		
	
		
			
				
					    if  int ( sdata [ 1 ] )  ==  a_idx :  a  + =  sdata [ 2 + z_cur ]  
			
		
	
		
			
				
					    b  =  state [ b_mem ,  st_idx ]  +  line_times [ b_idx ,  0 ,  z_cur ]  *  rand_gauss_dev ( _seed  ^  b_mem  ^  z_cur ,  sd )  *  sdata [ 0 ]  
			
		
	
		
			
				
					    if  int ( sdata [ 1 ] )  ==  b_idx :  b  + =  sdata [ 2 + z_cur ]  
			
		
	
		
			
				
					     
			
		
	
		
			
				
					    previous_t  =  TMIN  
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					    current_t  =  min ( a ,  b )  
			
		
	
	
		
			
				
					
						
						
						
							
								 
						
					 
				
				@ -895,15 +915,21 @@ def wave_kernel(ops, op_start, op_stop, state, sat, st_start, st_stop, line_time
				@@ -895,15 +915,21 @@ def wave_kernel(ops, op_start, op_stop, state, sat, st_start, st_stop, line_time
					 
			
		
	
		
			
				
					        if  b  <  a :  
			
		
	
		
			
				
					            b_cur  + =  1  
			
		
	
		
			
				
					            b  =  state [ b_mem  +  b_cur ,  st_idx ]  
			
		
	
		
			
				
					            b  + =  line_times [ b_idx ,  0 ,  z_val  ^  1 ]  *  rand_gauss_dev ( _seed  ^  b_mem  ^  z_val  ^  1 ,  sd )  
			
		
	
		
			
				
					            thresh  =  line_times [ b_idx ,  1 ,  z_val ]  *  rand_gauss_dev ( _seed  ^  b_mem  ^  z_val ,  sd )  
			
		
	
		
			
				
					            b  + =  line_times [ b_idx ,  0 ,  z_val  ^  1 ]  *  rand_gauss_dev ( _seed  ^  b_mem  ^  z_val  ^  1 ,  sd )  *  sdata [ 0 ]  
			
		
	
		
			
				
					            thresh  =  line_times [ b_idx ,  1 ,  z_val ]  *  rand_gauss_dev ( _seed  ^  b_mem  ^  z_val ,  sd )  *  sdata [ 0 ]  
			
		
	
		
			
				
					            if  int ( sdata [ 1 ] )  ==  b_idx :  
			
		
	
		
			
				
					                b  + =  sdata [ 2 + ( z_val ^ 1 ) ]  
			
		
	
		
			
				
					                thresh  + =  sdata [ 2 + z_val ]  
			
		
	
		
			
				
					            inputs  ^ =  2  
			
		
	
		
			
				
					            next_t  =  b  
			
		
	
		
			
				
					        else :  
			
		
	
		
			
				
					            a_cur  + =  1  
			
		
	
		
			
				
					            a  =  state [ a_mem  +  a_cur ,  st_idx ]  
			
		
	
		
			
				
					            a  + =  line_times [ a_idx ,  0 ,  z_val  ^  1 ]  *  rand_gauss_dev ( _seed  ^  a_mem  ^  z_val  ^  1 ,  sd )  
			
		
	
		
			
				
					            thresh  =  line_times [ a_idx ,  1 ,  z_val ]  *  rand_gauss_dev ( _seed  ^  a_mem  ^  z_val ,  sd )  
			
		
	
		
			
				
					            a  + =  line_times [ a_idx ,  0 ,  z_val  ^  1 ]  *  rand_gauss_dev ( _seed  ^  a_mem  ^  z_val  ^  1 ,  sd )  *  sdata [ 0 ]  
			
		
	
		
			
				
					            thresh  =  line_times [ a_idx ,  1 ,  z_val ]  *  rand_gauss_dev ( _seed  ^  a_mem  ^  z_val ,  sd )  *  sdata [ 0 ]  
			
		
	
		
			
				
					            if  int ( sdata [ 1 ] )  ==  a_idx :  
			
		
	
		
			
				
					                a  + =  sdata [ 2 + ( z_val ^ 1 ) ]  
			
		
	
		
			
				
					                thresh  + =  sdata [ 2 + z_val ]  
			
		
	
		
			
				
					            inputs  ^ =  1  
			
		
	
		
			
				
					            next_t  =  a