import networkx as nx
from ..node.node import Node
import numpy as np
from typing import Optional, Union, List, Dict, Tuple
import torch
from ..node.Consumers import CubeConsumer, LabelConsumer
from .meta_routing import get_forward_metadata, get_fit_metadata
[docs]
class MemoryExecutor:
[docs]
def __init__(self, graph: nx.DiGraph, nodes: dict[str, Node], entry_point: str):
self.graph = graph
self.nodes = nodes
self.entry_point = entry_point
[docs]
def forward(self, X: np.ndarray, Y: Optional[Union[np.ndarray, List]] = None, M: Optional[Union[np.ndarray, List]] = None) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Pass data through the graph by starting at the root node and flowing through all
intermediary stages.
Parameters
----------
X : np.ndarray
Input data
Y : Optional[Union[np.ndarray, List]], optional
Label data
M : Optional[Union[np.ndarray, List]], optional
Metadata by default None
Returns
-------
tuple[np.ndarray, np.ndarray, np.ndarray]
Residuals of processed X, Y, and M
"""
sorted_graph = list(nx.topological_sort(self.graph))
assert (sorted_graph[0] == self.entry_point)
xs = X
ys = Y or [None]*len(xs)
ms = M or [None]*len(xs)
intermediary = {}
intermediary_labels = {}
intermediary_metas = {}
intermediary[self.entry_point], intermediary_labels[self.entry_point], intermediary_metas[self.entry_point] = self.forward_node(
self.nodes[self.entry_point], xs, ys, ms)
for node in sorted_graph[1:]:
self._forward_helper(node, intermediary,
intermediary_labels, intermediary_metas)
results = intermediary[sorted_graph[-1]]
return results
def _forward_helper(self, current: str, intermediary: dict, intermediary_labels: dict, intermediary_metas: dict):
"""Helper function to aggregate inputs and calculate products from a given node.
Parameters
----------
current : str
id for current node
intermediary : dict
Dictionary containing intermediary products with key as id of node
intermediary_labels : np.ndarray
Dictionary containing intermediary labels with key as id of node
intermediary_metas : np.ndarray
Dictionary containing intermediary metadata with key as id of node
"""
p_nodes = list(self.graph.predecessors(current))
# TODO how to concat multiple input data from multiple nodes
use_prods = np.concatenate([intermediary[p] for p in p_nodes], axis=-1)
no_labels = intermediary_labels[p_nodes[0]] is None
if not no_labels:
if isinstance(intermediary_labels[p_nodes[0]], np.ndarray):
use_labels = np.concatenate(
[intermediary_labels[p] for p in p_nodes], axis=-1)
else:
use_labels = [intermediary_labels[p] for p in p_nodes]
else:
use_labels = []
no_metas = intermediary_metas[p_nodes[0]] is None
if not no_metas:
use_metas = [intermediary_metas[p] for p in p_nodes]
else:
use_metas = []
intermediary[current], intermediary_labels[current], intermediary_metas[current] = self.forward_node(
self.nodes[current], use_prods, use_labels, use_metas)
if self._not_needed_anymore(current, intermediary):
# Free memory that is not needed for the current passthrough anymore
intermediary.pop(current)
intermediary_labels.pop(current)
intermediary_metas.pop(current)
[docs]
def forward_node(self, node: Node, data: np.ndarray, labels: np.ndarray, metadata: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Pass data through a node which has already been trained/fit.
Parameters
----------
node : Node
Node within the graph
data : np.ndarray
Data to pass through the nodes
labels : np.ndarray
Labels associated with input data
metadata : np.ndarray
Metadata needed for forward pass
Returns
-------
tuple[np.ndarray, np.ndarray, np.ndarray]
Output data, output labels, output metadata
"""
additional_meta = get_forward_metadata(node, metadata)
if len(additional_meta) > 0:
out = node.forward(data, **additional_meta)
else:
out = node.forward(data)
if isinstance(out, Tuple):
return out
else:
return out, labels, metadata
def _not_needed_anymore(self, id: str, intermediary: list[Node]) -> bool:
"""Private function to determine if a node products are still needed or can be safely removed.
Parameters
----------
id : str
Alphanumeric identifier for node to check
intermediary : list[Node]
List node nodes for which the current node's data is an intermediary
Returns
-------
bool
If all successors are already present in intermediary, it will return True
"""
return all([succs in intermediary for succs in self.graph.successors(id)]) and \
len(list(self.graph.successors(id))
) > 0 # Do not remove a terminal nodes data
[docs]
def train(self, train_dataloader: torch.utils.data.DataLoader, test_dataloader: torch.utils.data.DataLoader):
"""Train a graph use a dataloader to iteratively pass data through the graph
Parameters
----------
train_dataloader : torch.utils.data.DataLoader
Training dataloader
test_dataloader : torch.utils.data.DataLoader
Test dataloader
Raises
------
TypeError
Raises error if dataloaders passed to train function are not pytorch dataloaders
"""
if not isinstance(train_dataloader, torch.utils.data.DataLoader) or not isinstance(test_dataloader, torch.utils.data.DataLoader):
raise TypeError(
"train or test dataloader argument is not a pytorch DataLoader!")
for x, y, m in iter(train_dataloader):
self.fit(np.squeeze(x), np.squeeze(y), m, warm_start=True)
# test stage
test_results = []
for x, y, m in iter(test_dataloader):
test_results.append(self.forward(x, y, m))
# do some metrics
[docs]
def fit(self, X: np.ndarray, Y: Optional[Union[np.ndarray, List]] = None, M: Optional[Union[np.ndarray, List]] = None, warm_start=False):
"""Take a graph of uninitialized nodes and fit then given a set of inputs and outputs
Parameters
----------
X : np.ndarray
Input data
Y : Optional[Union[np.ndarray, List]], optional
Input labels, by default None
M : Optional[Union[np.ndarray, List]], optional
Input metadata, by default None
"""
# training stage
sorted_graph = list(nx.topological_sort(self.graph))
assert (sorted_graph[0] == self.entry_point)
intermediary = {}
intermediary_labels = {}
intermediary_metas = {}
intermediary[self.entry_point], intermediary_labels[self.entry_point], intermediary_metas[self.entry_point] = self.fit_node(
self.nodes[self.entry_point], X, Y, M, warm_start=warm_start)
for node in sorted_graph[1:]:
self._fit_helper(node, intermediary,
intermediary_labels, intermediary_metas)
def _fit_helper(self, current: str, intermediary: dict, intermediary_labels: dict, intermediary_metas: dict, warm_start=False):
"""Private helper function to fit an individual node.
Parameters
----------
current : str
id of current node in graph
intermediary : str
Dictionary containing intermediary products
intermediary_labels : np.ndarray
Dictionary containing intermediary labels
intermediary_metas : np.ndarray
Dictionary containing intermediary metadata
"""
p_nodes = list(self.graph.predecessors(current))
# TODO how to concat multiple input data from multiple nodes
use_prods = np.concatenate([intermediary[p] for p in p_nodes], axis=-1)
no_labels = intermediary_labels[p_nodes[0]] is None
if not no_labels:
use_labels = np.concatenate(
[intermediary_labels[p] for p in p_nodes], axis=-1)
else:
use_labels = None
no_metas = intermediary_metas[p_nodes[0]] is None
if not no_metas:
use_metas = np.concatenate(
[intermediary_metas[p] for p in p_nodes], axis=-1)
else:
use_metas = None
intermediary[current], intermediary_labels[current], intermediary_metas[current] = self.fit_node(
self.nodes[current], use_prods, use_labels, use_metas)
if self._not_needed_anymore(current, intermediary):
# Free memory that is not needed for the current passthrough anymore
intermediary.pop(current)
intermediary_labels.pop(current)
intermediary_metas.pop(current)
[docs]
def fit_node(self, node: Node, data: np.ndarray, labels: np.ndarray, metadata: np.ndarray, warm_start=False) -> np.ndarray:
"""Private function wrapper to call the fit function for an individual node
Parameters
----------
node : Node
Graph node that will be fit
data : np.ndarray
Training data
labels : np.ndarray
Training labels
metadata : np.ndarray
Training metadata
Returns
-------
np.ndarray
Results of passing data through the fit node
Raises
------
RuntimeError
Data is empty (length 0)
"""
node_input = []
if isinstance(node, CubeConsumer):
node_input.append(data)
if isinstance(node, LabelConsumer):
node_input.append(labels)
if len(node_input) == 0:
raise RuntimeError(
F"Node {node} invalid, does not indicate input data type!")
additional_meta = get_fit_metadata(node, metadata)
if node.freezed == False:
if len(additional_meta) > 0:
node.fit(*node_input, **additional_meta, warm_start=warm_start)
else:
node.fit(*node_input, warm_start=warm_start)
return self.forward_node(node, data, labels, metadata)
[docs]
class HummingBirdExecutor:
[docs]
def __init__(self, graph: nx.DiGraph, nodes: dict[str, Node], entry_point: str):
self.graph = graph
self.nodes = nodes
self.entry_point = entry_point
self.sorted_nodes = list(nx.topological_sort(self.graph))