diff --git a/src/kyupy/bench.py b/src/kyupy/bench.py index fb80598..d678db3 100644 --- a/src/kyupy/bench.py +++ b/src/kyupy/bench.py @@ -74,14 +74,14 @@ GRAMMAR = r""" """ -def parse(text, name=None): +def parse(text, name=None) -> Circuit: """Parses the given ``text`` as ISCAS89 bench code. :param text: A string with bench code. :param name: The name of the circuit. Circuit names are not included in bench descriptions. :return: A :class:`Circuit` object. """ - return Lark(GRAMMAR, parser="lalr", transformer=BenchTransformer(name)).parse(text) + return Lark(GRAMMAR, parser="lalr", transformer=BenchTransformer(name)).parse(text) # type: ignore def load(file, name=None): diff --git a/src/kyupy/circuit.py b/src/kyupy/circuit.py index ee49275..228ecd4 100644 --- a/src/kyupy/circuit.py +++ b/src/kyupy/circuit.py @@ -14,12 +14,12 @@ from __future__ import annotations from collections import deque, defaultdict import re -from typing import Union +from typing import Union, Any import numpy as np -class GrowingList(list): +class GrowingList[T](list[T]): def __setitem__(self, index, value): if value is None: self.has_nones = True if index == len(self): return super().append(value) @@ -28,9 +28,13 @@ class GrowingList(list): self.has_nones = True super().__setitem__(index, value) - def __getitem__(self, index): - if isinstance(index, slice): return super().__getitem__(index) - return super().__getitem__(index) if index < len(self) else None + # Override __getitem__ to return None when reading beyond the list + # instead of throwing an exception. Type checker complains about the None return + # type, though. Probably not needed anyways. + + #def __getitem__(self, index) -> list[T] | T | None: + # if isinstance(index, slice): return super().__getitem__(index) + # return super().__getitem__(index) if index < len(self) else None @property def free_idx(self): @@ -155,7 +159,7 @@ class Line: Use the explicit case only if connections to specific pins are required. It may overwrite any previous line references in the connection list of the nodes. """ - def __init__(self, circuit: Circuit, driver: Union[Node, tuple[Node, int]], reader: Union[Node, tuple[Node, int]]): + def __init__(self, circuit: Circuit, driver: Node | tuple[Node, None|int], reader: Node | tuple[Node, None|int]): self.circuit = circuit """The :class:`Circuit` object the line is part of. """ @@ -168,20 +172,20 @@ class Line: accessing it by :code:`my_data[l.index]` or simply by :code:`my_data[l]`. """ if not isinstance(driver, tuple): driver = (driver, driver.outs.free_idx) - self.driver = driver[0] + self.driver: Node = driver[0] """The :class:`Node` object that drives this line. """ - self.driver_pin = driver[1] + self.driver_pin = driver[1] if driver[1] is not None else self.driver.outs.free_idx """The output pin position of the driver node this line is connected to. This is the position in the list :py:attr:`Node.outs` of the driving node this line referenced from: :code:`self.driver.outs[self.driver_pin] == self`. """ if not isinstance(reader, tuple): reader = (reader, reader.ins.free_idx) - self.reader = reader[0] + self.reader: Node = reader[0] """The :class:`Node` object that reads this line. """ - self.reader_pin = reader[1] + self.reader_pin = reader[1] if reader[1] is not None else self.reader.ins.free_idx """The input pin position of the reader node this line is connected to. This is the position in the list :py:attr:`Node.ins` of the reader node this line referenced from: @@ -203,8 +207,6 @@ class Line: for i, l in enumerate(self.driver.outs): l.driver_pin = i if self.reader is not None: self.reader.ins[self.reader_pin] = None if self.circuit is not None: del self.circuit.lines[self.index] - self.driver = None - self.reader = None self.circuit = None def __index__(self): @@ -309,7 +311,7 @@ class Circuit: """ return self._locs(prefix, self.s_nodes) - def _locs(self, prefix, nodes): + def _locs(self, prefix, nodes:list[Node]) -> Node|list[Any]: # can return list[list[...]] d_top = dict() for i, n in enumerate(nodes): if m := re.match(fr'({re.escape(prefix)}.*?)((?:[\d_\[\]])*$)', n.name): @@ -324,7 +326,7 @@ class Circuit: def sorted_values(d): return [sorted_values(v) for k, v in sorted(d.items())] if isinstance(d, dict) else d l = sorted_values(d_top) while isinstance(l, list) and len(l) == 1: l = l[0] - return None if isinstance(l, list) and len(l) == 0 else l + return l #None if isinstance(l, list) and len(l) == 0 else l @property def stats(self): diff --git a/src/kyupy/logic.py b/src/kyupy/logic.py index 5cfbe92..c850e27 100644 --- a/src/kyupy/logic.py +++ b/src/kyupy/logic.py @@ -47,6 +47,7 @@ The functions in this module use the ``mv...`` and ``bp...`` prefixes to signify from collections.abc import Iterable import numpy as np +from numpy.typing import DTypeLike from . import numba, hr_bytes @@ -433,7 +434,7 @@ def unpackbits(a : np.ndarray): return np.unpackbits(a.view(np.uint8), bitorder='little').reshape(*a.shape, 8*a.itemsize) -def packbits(a, dtype=np.uint8): +def packbits(a, dtype:DTypeLike=np.uint8): """Packs the values of a boolean-valued array ``a`` along its last axis into bits. Similar to ``np.packbits``, but returns an array of given dtype and the shape of ``a`` with the last axis removed. diff --git a/src/kyupy/verilog.py b/src/kyupy/verilog.py index 39bd192..d6d6f33 100644 --- a/src/kyupy/verilog.py +++ b/src/kyupy/verilog.py @@ -258,7 +258,7 @@ GRAMMAR = r""" """ -def parse(text, tlib=KYUPY, branchforks=False): +def parse(text, tlib=KYUPY, branchforks=False) -> Circuit: """Parses the given ``text`` as Verilog code. :param text: A string with Verilog code. @@ -269,7 +269,7 @@ def parse(text, tlib=KYUPY, branchforks=False): (see :py:func:`~kyupy.sdf.DelayFile.interconnects()`). :return: A :py:class:`~kyupy.circuit.Circuit` object. """ - return Lark(GRAMMAR, parser="lalr", transformer=VerilogTransformer(branchforks, tlib)).parse(text) + return Lark(GRAMMAR, parser="lalr", transformer=VerilogTransformer(branchforks, tlib)).parse(text) # type: ignore def load(file, tlib=KYUPY, branchforks=False):