diff --git a/src/kyupy/circuit.py b/src/kyupy/circuit.py index c3e8487..fd8dc05 100644 --- a/src/kyupy/circuit.py +++ b/src/kyupy/circuit.py @@ -18,6 +18,8 @@ from typing import Union, Any import numpy as np +type NestedNumericList = list['int|NestedNumericList'] +type NestedStrIntDict = dict[str,'int|NestedStrIntDict'] class GrowingList[T](list[T]): def __setitem__(self, index, value): @@ -290,10 +292,10 @@ class Circuit: return list(self.io_nodes) + [n for n in self.nodes if 'dff' in n.kind.lower()] + [n for n in self.nodes if 'latch' in n.kind.lower()] def io_locs(self, prefix): - """Returns the indices of primary I/Os that start with given name prefix. + """Returns a list of indices of primary I/Os that start with given name prefix. The returned values are used to index into the :py:attr:`io_nodes` array. - If only one I/O cell matches the given prefix, a single integer is returned. + If only one I/O cell matches the given prefix, a list with a single integer is returned. If a bus matches the given prefix, a sorted list of indices is returned. Busses are identified by integers in the cell names following the given prefix. Lists for bus indices are sorted from LSB (e.g. :code:`data[0]`) to MSB (e.g. :code:`data[31]`). @@ -311,8 +313,8 @@ class Circuit: """ return self._locs(prefix, self.s_nodes) - def _locs(self, prefix, nodes:list[Node]) -> Node|list[Any]: # can return list[list[...]] - d_top = dict() + def _locs(self, prefix, nodes:list[Node]) -> NestedNumericList: # can return list[list[...]] + d_top: NestedStrIntDict = dict() for i, n in enumerate(nodes): if m := re.match(fr'({re.escape(prefix)}.*?)((?:[\d_\[\]])*$)', n.name): path = [m[1]] + [int(v) for v in re.split(r'[_\[\]]+', m[2]) if len(v) > 0] @@ -320,13 +322,14 @@ class Circuit: for j in path[:-1]: d[j] = d.get(j, dict()) d = d[j] + assert isinstance(d, dict) d[path[-1]] = i # sort recursively for multi-dimensional lists. - def sorted_values(d): return [sorted_values(v) for k, v in sorted(d.items())] if isinstance(d, dict) else d + def sorted_values(d) -> NestedNumericList: return [sorted_values(v) for k, v in sorted(d.items())] if isinstance(d, dict) else d l = sorted_values(d_top) - while isinstance(l, list) and len(l) == 1: l = l[0] - return l #None if isinstance(l, list) and len(l) == 0 else l + while isinstance(l, list) and len(l) == 1 and isinstance(l[0], list): l = l[0] + return l @property def stats(self):