import networkx as nx
from ..node import Node
[docs]
class ShapeValidator():
[docs]
def __init__(self, input_shape: tuple[int, int, int]) -> None:
self.input_shape = input_shape
[docs]
def verify(self, node: Node, inshape: tuple[int, int, int] = None):
if inshape is None:
inshape = self.input_shape
[docs]
class GraphValidator():
[docs]
def __init__(self, graph: nx.DiGraph) -> None:
self.graph = graph
[docs]
def verify(self) -> bool:
if len(list(nx.simple_cycles(self.graph))) > 0:
return False
return True