Source code for cuvis_ai.pipeline.graph

import os
import shutil
import torch
import sys
from typing import Any
from datetime import datetime
from os.path import expanduser
from typing import Optional, Any, Union, Iterator
import networkx as nx
from typing import List, Union
from collections import defaultdict
import pkg_resources  # part of setuptools
from ..node import Node
from ..node.wrap import make_node
from ..node.Consumers import *
from ..data.OutputFormat import OutputFormat
from ..utils.numpy import get_shape_without_batch, check_array_shape
from ..utils.filesystem import change_working_dir
from ..utils.serializer import YamlSerializer
from ..utils.dependencies import get_installed_packages_str
import numpy as np
import tempfile
from pathlib import Path
import importlib
from .executor import MemoryExecutor, HummingBirdExecutor
from copy import copy, deepcopy
from functools import lru_cache
from ..node.skorch import SkorchWrapped
from ..node.sklearn import SklearnWrapped


[docs] def maybe_wrap_node(node): if isinstance(node, Node): return node return make_node(node)
[docs] class Graph(): """Main class for connecting nodes in a CUVIS.AI processing graph """
[docs] def __init__(self, name: str) -> None: self.graph = nx.DiGraph() self.nodes: dict[str, Node] = {} self.entry_point = None self.name = name
[docs] def add_node(self, node: Node, parent: list[Node] | Node = None) -> None: """Add a new node into the graph structure Parameters ---------- node : Node CUVIS.AI type node parent : list[Node] | Node, optional Node(s) that the child node should be connected to, with data flowing from parent(s) to child, by default None. Raises ------ ValueError If no parent is provided, node is assumed to be the base node of the graph. This event will raise an error to prevent base from being overwritten. ValueError If parent(s) do not already belong to the graph. ValueError If parent(s) and child nodes are mismatched in expected data size. """ if parent is None: # this is the first Node of the graph if self.entry_point is not None: raise ValueError("Graph already has base node") self.entry_point = node.id parent = [] if isinstance(parent, Node): parent = [parent] # Check if operation is valid if not all([self.graph.has_node(p.id) for p in parent]): raise ValueError("Not all parents are part of the Graph") if not all([check_array_shape(p.output_dim, node.input_dim) for p in parent]): raise ValueError('Unsatisfied dimensionality constraint!') self.graph.add_node(node.id) for p in parent: self.graph.add_edge(p.id, node.id) self.nodes[node.id] = node # Remove if verify fails if not self._verify(): self.delete_node(node)
[docs] def add_base_node(self, node: Node) -> None: """Adds new node into the graph by creating the first entry point. Parameters ---------- node : Node CUVIS.AI node to add to the graph """ node = maybe_wrap_node(node) self.graph.add_node(node.id) self.nodes[node.id] = node self.entry_point = node.id
[docs] def add_edge(self, node: Node, node2: Node) -> None: """Adds sequential nodes to create a directed edge. At least one of the nodes should already be in the graph. Parameters ---------- node : Node Parent node. node2 : Node Child node. """ node = maybe_wrap_node(node) node2 = maybe_wrap_node(node2) self.graph.add_edge(node.id, node2.id) self.nodes[node.id] = node self.nodes[node2.id] = node2 if not self._verify(): # TODO Issue: This could potentially leave the graph in an invalid state # Delete nodes and connection del self.nodes[node.id] del self.nodes[node2.id] # Remove the nodes from the graph as a whole self.graph.remove_nodes_from([node.id, node2.id])
[docs] def custom_copy(self): # Create a new instance of the class new_instance = self.__class__.__new__(self.__class__) new_instance.name = deepcopy(self.name) new_instance.graph = deepcopy(self.graph) # Deep copy new_instance.nodes = copy(self.nodes) # Shallow copy new_instance.entry_point = deepcopy(self.entry_point) return new_instance
def __rshift__(self, other: Node): """Compose with *other*. Example: t = a >> b >> c """ new_graph = self.custom_copy() if new_graph.entry_point == None: new_graph.add_base_node(other) return new_graph # Get all nodes without successors sink_nodes = [ new_graph.nodes[node] for node in new_graph.graph.nodes if new_graph.graph.out_degree(node) == 0] if (len(sink_nodes) == 1): new_graph.add_edge(sink_nodes[0], other) return new_graph def __repr__(self) -> str: res = self.name + ":\n" for node in self.nodes: res += f"{node}\n" return res def _verify_input_outputs(self) -> bool: """Private function to validate the integrity of data passed between nodes. Returns ------- bool Inputs and outputs of all nodes are congruent. """ all_edges = list(self.graph.edges) for start, end in all_edges: # TODO: Issue what if multiple Nodes feed into the same successor Node, how would the shape look like? if not check_array_shape(self.nodes[start].output_dim, self.nodes[end].input_dim): # TODO reenable this, for now skip print('Unsatisfied dimensionality constraint!') # return True return True def _verify(self) -> bool: """Private function to verify the integrity of the processing graph. Returns ------- bool Graph meets/does not meet requirements for ordered and error-free flow of data. """ if len(self.nodes.keys()) == 0: print('Empty graph!') return True elif len(self.nodes.keys()) == 1: print('Single stage graph!') return True # Check that no cycles exist if len(list(nx.simple_cycles(self.graph))) > 0: return False # Get all edges in the graph if not self._verify_input_outputs(): return False return True
[docs] def delete_node(self, id: Node | str) -> None: """Removes a node from the graph. To successfully remove a node, it must not have successors. Parameters ---------- id : Node | str UUID for target node to delete, or a copy of the node itself. Raises ------ ValueError Node to delete contains successors in the graph. ValueError Node does not exist in the graph. """ if isinstance(id, Node): id = id.id # Check if operation is valid if not len(list(self.graph.successors(id))) == 0: raise ValueError( "The node does have successors, removing it would invalidate the Graph structure") if not id in self.nodes: raise ValueError("Cannot remove node, it no longer exists") self.graph.remove_edges_from([id]) del self.nodes[id]
[docs] def serialize(self, data_dir: Path) -> dict: """Convert graph structure and all contained nodes to a serializable YAML format. Numeric data and fit models will be stored in zipped directory named with current time. """ from importlib.metadata import version data_dir = Path(data_dir) nodes_data = {} for key, node in self.nodes.items(): serialized = node.serialize(data_dir) node_data = {'__node_module__': str(node.__module__), '__node_class__': str(node.__class__.__name__)} # maybe serialize source code if 'code' in serialized.keys(): import cuvis_ai.utils.inspect as ins cls = serialized.pop('code') node_code = ins.get_src(cls) with open(data_dir / f'{cls.__name__}.py', 'w') as f: f.writelines(node_code) node_data['__node_code__'] = f'{cls.__name__}.py' node_data |= serialized nodes_data[key] = node_data edges_data = [{'from': start, 'to': end} for start, end in list(self.graph.edges)] output = { 'edges': edges_data, 'nodes': nodes_data, 'name': self.name, 'entry_point': self.entry_point, 'version': version('cuvis_ai'), 'packages': get_installed_packages_str() } return output
[docs] def load(self, structure: dict, data_dir: Path) -> None: data_dir = Path(data_dir) self.name = structure.get('name') installed_cuvis_version = pkg_resources.require('cuvis_ai')[0].version serialized_cuvis_version = structure.get('version') if installed_cuvis_version != serialized_cuvis_version: raise ValueError(f'Incorrect version of cuvis_ai package. Installed {installed_cuvis_version} but serialized with {serialized_cuvis_version}') # nopep8 if not structure.get('nodes'): print('No node information available!') LOAD_SOURCE_FILES = True for key, params in structure.get('nodes').items(): node_module = params.get('__node_module__') node_class = params.get('__node_class__') node_code = params.get('__node_code__', None) if not node_code is None and LOAD_SOURCE_FILES: spec = importlib.util.spec_from_file_location( node_module, data_dir / node_code) module = importlib.util.module_from_spec(spec) sys.modules[node_module] = module spec.loader.exec_module(module) cls = getattr( module, node_class) else: cls = getattr(importlib.import_module(node_module), node_class) if not issubclass(cls, Node): cls = make_node(cls) if 'params' in params.keys(): stage = cls(**params['params']) else: stage = cls() stage.load(params, data_dir) self.nodes[key] = stage # Set the entry point self.entry_point = structure.get('entry_point') # Create the graph instance self.graph = nx.DiGraph() # Handle base case where there is only one node if len(structure.get('nodes')) > 1: # Graph has at least one valid edge for edge in structure.get('edges'): self.graph.add_edge(edge.get('from'), edge.get('to')) else: # Only single node exists, add it into the graph self.add_base_node(list(self.nodes.values())[0])
[docs] def save_to_file(self, filepath) -> None: filepath = Path(filepath) os.makedirs(filepath.parent, exist_ok=True) with tempfile.TemporaryDirectory() as tmpDir: with change_working_dir(tmpDir): graph_data = self.serialize('.') serial = YamlSerializer(tmpDir, 'main') serial.serialize(graph_data) shutil.make_archive( f'{str(filepath)}', 'zip', tmpDir) print(f'Project saved to {str(filepath)}')
[docs] @classmethod def load_from_file(cls, filepath: str) -> None: """Reconstruct the graph from a file path defining the location of a zip archive. Parameters ---------- filepath : str Location of zip archive """ new_graph = cls('Loaded') with tempfile.TemporaryDirectory() as tmpDir: shutil.unpack_archive(filepath, tmpDir) with change_working_dir(tmpDir): serial = YamlSerializer(tmpDir, 'main') graph_data = serial.load() new_graph.load(graph_data, '.') return new_graph
[docs] def forward(self, X: np.ndarray, Y: Optional[Union[np.ndarray, List]] = None, M: Optional[Union[np.ndarray, List]] = None, backend: str = 'memory') -> tuple[np.ndarray, np.ndarray, np.ndarray]: if backend == 'memory': executor = MemoryExecutor(self.graph, self.nodes, self.entry_point) elif backend == 'hummingbird': from hummingbird.ml import convert from copy import copy def convert_node(node): new_node = copy(node) if '_wrapped' in node.__dict__ and isinstance(node, SklearnWrapped): new_node._wrapped = convert(node._wrapped, 'torch') return new_node nodes = {k: convert_node(v) for k, v in self.nodes.items()} executor = MemoryExecutor( self.graph, nodes, self.entry_point) else: raise ValueError("Unknown Backend") return executor.forward(X, Y, M)
[docs] def fit(self, X: np.ndarray, Y: Optional[Union[np.ndarray, List]] = None, M: Optional[Union[np.ndarray, List]] = None): executor = MemoryExecutor(self.graph, self.nodes, self.entry_point) executor.fit(X, Y, M)
[docs] def train(self, train_dl: torch.utils.data.DataLoader, test_dl: torch.utils.data.DataLoader): executor = MemoryExecutor(self.graph, self.nodes, self.entry_point) executor.train(train_dl, test_dl)
@property @lru_cache(maxsize=128) def torch_layers(self) -> List[torch.nn.Module]: """Get a list with all pytorch layers in the Graph. """ layers = [] for key, node in self.nodes.items(): if isinstance(node, SkorchWrapped): layers.append(node.net.model_) return layers
[docs] def parameters(self) -> Iterator: """Iterate over all (pytorch-) parameters in all layers contained in the Graph. """ for layer in self.torch_layers: yield from layer.parameters()
[docs] def freeze(self): for node in self.nodes.values(): if node.initialized: node.freezed = True else: raise RuntimeError("Tried freezing a uninitialized node")