Browse Source

fix substitute for inputs with fo, dot graph

devel
Stefan Holst 1 year ago
parent
commit
5e573b0408
  1. 57
      src/kyupy/circuit.py

57
src/kyupy/circuit.py

@ -13,6 +13,8 @@ Circuit graphs also define an ordering of inputs, outputs and other nodes to eas @@ -13,6 +13,8 @@ Circuit graphs also define an ordering of inputs, outputs and other nodes to eas
from collections import deque, defaultdict
import re
import numpy as np
class GrowingList(list):
def __setitem__(self, index, value):
@ -359,7 +361,7 @@ class Circuit: @@ -359,7 +361,7 @@ class Circuit:
of the input and output ports must match the pins of the replaced node.
"""
ios = set(impl.io_nodes)
impl_in_lines = [n.outs[0] for n in impl.io_nodes if len(n.ins) == 0]
impl_in_nodes = [n for n in impl.io_nodes if len(n.ins) == 0]
impl_out_lines = [n.ins[0] for n in impl.io_nodes if len(n.ins) > 0]
designated_cell = None
if len(impl_out_lines) > 0:
@ -367,9 +369,9 @@ class Circuit: @@ -367,9 +369,9 @@ class Circuit:
while n.kind == '__fork__' and n not in ios:
n = n.ins[0].driver
designated_cell = n
node_in_lines = list(node.ins) + [None] * (len(impl_in_lines)-len(node.ins))
node_in_lines = list(node.ins) + [None] * (len(impl_in_nodes)-len(node.ins))
node_out_lines = list(node.outs) + [None] * (len(impl_out_lines)-len(node.outs))
assert len(node_in_lines) == len(impl_in_lines)
assert len(node_in_lines) == len(impl_in_nodes)
assert len(node_out_lines) == len(impl_out_lines)
node_map = dict()
if designated_cell is not None:
@ -386,13 +388,20 @@ class Circuit: @@ -386,13 +388,20 @@ class Circuit:
node_map[n] = Node(self, f'{node.name}~{n.name}', n.kind)
elif len(n.outs) > 0 and len(n.ins) > 0: # output is also read by impl. circuit, need to add a fork.
node_map[n] = Node(self, f'{node.name}~{n.name}')
elif len(n.ins) == 0 and len(n.outs) > 1: # input is read by multiple nodes, need to add fork.
node_map[n] = Node(self, f'{node.name}~{n.name}')
for l in impl.lines: # add all internal lines to main circuit
if l.reader in node_map and l.driver in node_map:
Line(self, (node_map[l.driver], l.driver_pin), (node_map[l.reader], l.reader_pin))
for l, ll in zip(impl_in_lines, node_in_lines): # connect inputs
for inn, ll in zip(impl_in_nodes, node_in_lines): # connect inputs
if ll is None: continue
if len(inn.outs) == 1:
l = inn.outs[0]
ll.reader = node_map[l.reader]
ll.reader_pin = l.reader_pin
else:
ll.reader = node_map[inn] # connect to existing fork
ll.reader_pin = 0
ll.reader.ins[ll.reader_pin] = ll
for l, ll in zip(impl_out_lines, node_out_lines): # connect outputs
if ll is None:
@ -467,7 +476,7 @@ class Circuit: @@ -467,7 +476,7 @@ class Circuit:
Nodes without input lines and nodes whose :py:attr:`Node.kind` contains the
substrings 'dff' or 'latch' are yielded first.
"""
visit_count = [0] * len(self.nodes)
visit_count = np.zeros(len(self.nodes), dtype=np.uint32)
queue = deque(n for n in self.nodes if len(n.ins) == 0 or 'dff' in n.kind.lower() or 'latch' in n.kind.lower())
while len(queue) > 0:
n = queue.popleft()
@ -479,6 +488,16 @@ class Circuit: @@ -479,6 +488,16 @@ class Circuit:
queue.append(succ)
yield n
def topological_order_with_level(self):
level = np.zeros(len(self.nodes), dtype=np.int32) - 1
for n in self.topological_order():
if len(n.ins) == 0 or 'dff' in n.kind.lower() or 'latch' in n.kind.lower():
l = 0
else:
l = level[[l.driver.index for l in n.ins if l is not None]].max() + 1
level[n] = l
yield n, l
def topological_line_order(self):
"""Generator function to iterate over all lines in topological order.
"""
@ -540,3 +559,31 @@ class Circuit: @@ -540,3 +559,31 @@ class Circuit:
queue.extend(preds)
region.append(n)
yield stem, region
def dot(self, format='svg'):
from graphviz import Digraph
dot = Digraph(format=format, graph_attr={'rankdir': 'LR', 'splines': 'true'})
node_level = np.zeros(len(self.nodes), dtype=np.uint32)
level_nodes = defaultdict(list)
for n, lv in self.topological_order_with_level():
level_nodes[lv].append(n)
node_level[n] = lv
for lv in level_nodes:
with dot.subgraph() as s:
s.attr(rank='same')
for n in level_nodes[lv]:
ins = '|'.join([f'<i{i}>{i}' for i in range(len(n.ins))])
outs = '|'.join([f'<o{i}>{i}' for i in range(len(n.outs))])
s.node(name=str(n.index), label = f'{{{{{ins}}}|{n.index}\n{n.kind}\n{n.name}|{{{outs}}}}}', shape='record')
for l in self.lines:
driver, reader = f'{l.driver.index}:o{l.driver_pin}', f'{l.reader.index}:i{l.reader_pin}'
if node_level[l.driver] >= node_level[l.reader]:
dot.edge(driver, reader, style='dotted', label=str(l.index))
pass
else:
dot.edge(driver, reader, label=str(l.index))
return dot

Loading…
Cancel
Save