diff --git a/src/kyupy/circuit.py b/src/kyupy/circuit.py index 556057c..c3e8487 100644 --- a/src/kyupy/circuit.py +++ b/src/kyupy/circuit.py @@ -614,20 +614,24 @@ class Circuit: if marks[n]: yield n - def fanout(self, origin_nodes): + def fanout(self, origin_nodes: list[Node], node_filter = lambda n: 'dff' not in n.name.lower()): """Generator function to iterate over the fan-out cone of a given list of origin nodes. - Nodes are yielded in topological order. + origin_nodes are yielded first, followed by nodes driven by them in a breadth-first manner. + The search stops at nodes for which node_filter returns False. + Only origin_nodes and nodes for which node_filter returned True are yielded. + By default, search stops at flip-flops. """ - marks = [False] * len(self.nodes) - for n in origin_nodes: - marks[n] = True - for n in self.topological_order(): - if not marks[n]: - for line in n.ins.without_nones(): - marks[n] |= marks[line.driver] - if marks[n]: - yield n + queue = deque(origin_nodes) + yielded = set() + while len(queue) > 0: + n = queue.popleft() + for line in n.outs.without_nones(): + succ = line.reader + if succ not in yielded and node_filter(succ): + yielded.add(succ) + queue.append(succ) + yield n def fanout_free_regions(self): for stem in self.reversed_topological_order():