diff --git a/src/kyupy/circuit.py b/src/kyupy/circuit.py index 76d4738..65ebfbb 100644 --- a/src/kyupy/circuit.py +++ b/src/kyupy/circuit.py @@ -13,6 +13,8 @@ Circuit graphs also define an ordering of inputs, outputs and other nodes to eas from collections import deque, defaultdict import re +import numpy as np + class GrowingList(list): def __setitem__(self, index, value): @@ -359,7 +361,7 @@ class Circuit: of the input and output ports must match the pins of the replaced node. """ ios = set(impl.io_nodes) - impl_in_lines = [n.outs[0] for n in impl.io_nodes if len(n.ins) == 0] + impl_in_nodes = [n for n in impl.io_nodes if len(n.ins) == 0] impl_out_lines = [n.ins[0] for n in impl.io_nodes if len(n.ins) > 0] designated_cell = None if len(impl_out_lines) > 0: @@ -367,9 +369,9 @@ class Circuit: while n.kind == '__fork__' and n not in ios: n = n.ins[0].driver designated_cell = n - node_in_lines = list(node.ins) + [None] * (len(impl_in_lines)-len(node.ins)) + node_in_lines = list(node.ins) + [None] * (len(impl_in_nodes)-len(node.ins)) node_out_lines = list(node.outs) + [None] * (len(impl_out_lines)-len(node.outs)) - assert len(node_in_lines) == len(impl_in_lines) + assert len(node_in_lines) == len(impl_in_nodes) assert len(node_out_lines) == len(impl_out_lines) node_map = dict() if designated_cell is not None: @@ -386,13 +388,20 @@ class Circuit: node_map[n] = Node(self, f'{node.name}~{n.name}', n.kind) elif len(n.outs) > 0 and len(n.ins) > 0: # output is also read by impl. circuit, need to add a fork. node_map[n] = Node(self, f'{node.name}~{n.name}') + elif len(n.ins) == 0 and len(n.outs) > 1: # input is read by multiple nodes, need to add fork. + node_map[n] = Node(self, f'{node.name}~{n.name}') for l in impl.lines: # add all internal lines to main circuit if l.reader in node_map and l.driver in node_map: Line(self, (node_map[l.driver], l.driver_pin), (node_map[l.reader], l.reader_pin)) - for l, ll in zip(impl_in_lines, node_in_lines): # connect inputs + for inn, ll in zip(impl_in_nodes, node_in_lines): # connect inputs if ll is None: continue - ll.reader = node_map[l.reader] - ll.reader_pin = l.reader_pin + if len(inn.outs) == 1: + l = inn.outs[0] + ll.reader = node_map[l.reader] + ll.reader_pin = l.reader_pin + else: + ll.reader = node_map[inn] # connect to existing fork + ll.reader_pin = 0 ll.reader.ins[ll.reader_pin] = ll for l, ll in zip(impl_out_lines, node_out_lines): # connect outputs if ll is None: @@ -467,7 +476,7 @@ class Circuit: Nodes without input lines and nodes whose :py:attr:`Node.kind` contains the substrings 'dff' or 'latch' are yielded first. """ - visit_count = [0] * len(self.nodes) + visit_count = np.zeros(len(self.nodes), dtype=np.uint32) queue = deque(n for n in self.nodes if len(n.ins) == 0 or 'dff' in n.kind.lower() or 'latch' in n.kind.lower()) while len(queue) > 0: n = queue.popleft() @@ -479,6 +488,16 @@ class Circuit: queue.append(succ) yield n + def topological_order_with_level(self): + level = np.zeros(len(self.nodes), dtype=np.int32) - 1 + for n in self.topological_order(): + if len(n.ins) == 0 or 'dff' in n.kind.lower() or 'latch' in n.kind.lower(): + l = 0 + else: + l = level[[l.driver.index for l in n.ins if l is not None]].max() + 1 + level[n] = l + yield n, l + def topological_line_order(self): """Generator function to iterate over all lines in topological order. """ @@ -540,3 +559,31 @@ class Circuit: queue.extend(preds) region.append(n) yield stem, region + + def dot(self, format='svg'): + from graphviz import Digraph + dot = Digraph(format=format, graph_attr={'rankdir': 'LR', 'splines': 'true'}) + + node_level = np.zeros(len(self.nodes), dtype=np.uint32) + level_nodes = defaultdict(list) + for n, lv in self.topological_order_with_level(): + level_nodes[lv].append(n) + node_level[n] = lv + + for lv in level_nodes: + with dot.subgraph() as s: + s.attr(rank='same') + for n in level_nodes[lv]: + ins = '|'.join([f'{i}' for i in range(len(n.ins))]) + outs = '|'.join([f'{i}' for i in range(len(n.outs))]) + s.node(name=str(n.index), label = f'{{{{{ins}}}|{n.index}\n{n.kind}\n{n.name}|{{{outs}}}}}', shape='record') + + for l in self.lines: + driver, reader = f'{l.driver.index}:o{l.driver_pin}', f'{l.reader.index}:i{l.reader_pin}' + if node_level[l.driver] >= node_level[l.reader]: + dot.edge(driver, reader, style='dotted', label=str(l.index)) + pass + else: + dot.edge(driver, reader, label=str(l.index)) + + return dot