diff --git a/src/kyupy/circuit.py b/src/kyupy/circuit.py index 559ae5d..d8c3182 100644 --- a/src/kyupy/circuit.py +++ b/src/kyupy/circuit.py @@ -99,6 +99,11 @@ class Node: del self.circuit.cells[self.name] self.circuit = None + def __eq__(self, other): + """Checks equality of node name and kind. Does not check pin connections. + """ + return self.name == other.name and self.kind == other.kind + class Line: """A line is a directional 1:1 connection between two nodes. @@ -172,6 +177,10 @@ class Line: def __lt__(self, other): return self.index < other.index + def __eq__(self, other): + return self.driver == other.driver and self.driver_pin == other.driver_pin and \ + self.reader == other.reader and self.reader_pin == other.reader_pin + class Circuit: """A Circuit is a container for interconnected nodes and lines. @@ -238,6 +247,32 @@ class Circuit: c.interface.append(n) return c + def __getstate__(self): + nodes = [(node.name, node.kind) for node in self.nodes] + lines = [(line.driver.index, line.driver_pin, line.reader.index, line.reader_pin) for line in self.lines] + interface = [n.index for n in self.interface] + return {'name': self.name, + 'nodes': nodes, + 'lines': lines, + 'interface': interface } + + def __setstate__(self, state): + self.name = state['name'] + self.nodes = IndexList() + self.lines = IndexList() + self.interface = GrowingList() + self.cells = {} + self.forks = {} + for s in state['nodes']: + Node(self, *s) + for driver, driver_pin, reader, reader_pin in state['lines']: + Line(self, (self.nodes[driver], driver_pin), (self.nodes[reader], reader_pin)) + for n in state['interface']: + self.interface.append(self.nodes[n]) + + def __eq__(self, other): + return self.nodes == other.nodes and self.lines == other.lines and self.interface == other.interface + def dump(self): """Returns a string representation of the circuit and all its nodes. """ diff --git a/tests/test_circuit.py b/tests/test_circuit.py index b5d6055..446ba90 100644 --- a/tests/test_circuit.py +++ b/tests/test_circuit.py @@ -1,5 +1,7 @@ -from kyupy.circuit import Circuit, Node, Line +import pickle +from kyupy.circuit import Circuit, Node, Line +from kyupy import verilog def test_lines(): c = Circuit() @@ -99,3 +101,12 @@ def test_circuit(): for n in c.topological_order(): repr(n) + + +def test_pickle(mydir): + c = verilog.load(mydir / 'b14.v.gz') + assert c is not None + cs = pickle.dumps(c) + assert cs is not None + c2 = pickle.loads(cs) + assert c == c2