Source code for cuvis_ai.deciders.combining_decider


import functools
import os
import pickle as pk
from typing import Callable, Dict

import numpy as np

from ..utils.numpy import flatten_batch_and_spatial, unflatten_batch_and_spatial
from .base_decider import BaseDecider


def all_agree(decisions: np.ndarray) -> bool:
    """Decide positively only when every entry equals the first one.

    Parameters
    ----------
    decisions : np.ndarray
        Values to compare; each entry is compared against ``decisions[0]``.

    Returns
    -------
    bool
        ``True`` if all entries of ``decisions`` equal its first entry.
    """
    reference = decisions[0]
    matches = decisions == reference
    return np.all(matches)
def _count_at_least(decisions: np.ndarray, n: int) -> bool:
    """Return True when at least ``n`` entries of ``decisions`` are non-zero."""
    return np.count_nonzero(decisions) >= n


def at_least_n_agree(n: int) -> Callable[[np.ndarray], bool]:
    """Build a rule that is satisfied when at least ``n`` values are non-zero.

    Parameters
    ----------
    n : int
        Minimum number of non-zero entries required for a positive decision.

    Returns
    -------
    Callable[[np.ndarray], bool]
        The rule. It is a ``functools.partial`` over a module-level helper
        rather than a lambda, because lambdas cannot be pickled and
        ``CombiningDecider.serialize`` pickles its rule.
    """
    return functools.partial(_count_at_least, n=n)
class CombiningDecider(BaseDecider):
    """Decider using values of multiple channels to classify the result.

    The data of all channels at a spatial location are utilized in the chosen
    decision strategy to classify each data point.

    Parameters
    ----------
    channel_count : int
        The number of channels to expect
    rule : Callable[[np.ndarray], bool]
        The decision strategy to use. :meth:`all_agree` and
        :meth:`at_least_n_agree` are provided here. Custom strategies may also
        be used. The rule is called once per data point with the 1-D vector
        of that point's channel values and should return a single boolean.
    """

    def __init__(self, channel_count: int = None,
                 rule: Callable[[np.ndarray], bool] = None) -> None:
        super().__init__()
        self.n = channel_count
        # np.vectorize(None) raises TypeError on recent NumPy, so an
        # unconfigured instance (e.g. one to be restored via load()) keeps
        # rule=None instead of failing in the constructor.
        self.rule = self._vectorize_rule(rule) if rule is not None else None
        self.initialized = bool(rule is not None and channel_count is not None)

    @staticmethod
    def _vectorize_rule(rule):
        """Wrap ``rule`` so it is applied to each vector along the last axis."""
        # Unwrap an existing np.vectorize (e.g. an element-wise wrapper read
        # from an older save file) so the per-point signature is always used.
        if isinstance(rule, np.vectorize):
            rule = rule.pyfunc
        # "(n)->()" hands the rule a whole channel vector and expects one value
        # back; without it, np.vectorize would call the rule on every scalar.
        # otypes is set so size-0 inputs do not raise.
        return np.vectorize(rule, otypes=[bool], signature="(n)->()")

    def forward(self, X: np.ndarray) -> np.ndarray:
        """Apply the chosen ``rule`` to the input data.

        Parameters
        ----------
        X : np.ndarray
            Data to classify; channels are expected on the last axis
            (see ``input_dim``).

        Returns
        -------
        np.ndarray :
            Data classified to a single channel boolean matrix.
        """
        flatten_soft_output = flatten_batch_and_spatial(X)
        # One boolean per data point; keep a trailing channel axis of size 1
        # so the result matches output_dim ([-1, -1, 1]).
        decisions = self.rule(flatten_soft_output).reshape(-1, 1)
        return unflatten_batch_and_spatial(decisions, X.shape)

    @BaseDecider.input_dim.getter
    def input_dim(self):
        """Expected input shape: any spatial size with ``channel_count`` channels."""
        return [-1, -1, self.n]

    @BaseDecider.output_dim.getter
    def output_dim(self):
        """
        Returns the provided shape for the output data.
        If a dimension is not important it will return -1 in the specific position.

        Returns
        -------
        list
            Provided shape for data
        """
        return [-1, -1, 1]

    def serialize(self, directory: str):
        """Convert the class into a serialized representation.

        The vectorized rule is pickled into ``directory``.

        Parameters
        ----------
        directory : str
            Directory the rule pickle is written to.

        Returns
        -------
        dict or None
            ``{"class_count": ..., "rules_file": ...}``, where ``rules_file`` is
            the pickle's file name relative to ``directory``; ``None`` if the
            module is not fully initialized.
        """
        if not self.initialized:
            print('Module not fully initialized, skipping output!')
            return None
        # Write the pickle inside the target directory and record only its
        # file name, so load() can resolve it against its own filepath.
        dump_file = f"{hash(self.rule)}_pca.pkl"
        with open(os.path.join(directory, dump_file), "wb") as f:
            pk.dump(self.rule, f)
        data = {
            "class_count": self.n,
            "rules_file": dump_file,
        }
        return data

    def load(self, params: dict, filepath: str):
        """Load this node from a serialized graph.

        Parameters
        ----------
        params : dict
            Parameters as produced by :meth:`serialize` (``class_count`` and
            ``rules_file``).
        filepath : str
            Directory against which ``rules_file`` is resolved.

        Raises
        ------
        ValueError
            If ``class_count`` is missing or not an int, or if the rule
            pickle cannot be read.
        """
        try:
            self.n = int(params["class_count"])
        except (KeyError, TypeError, ValueError) as err:
            raise ValueError("Could not read attribute 'class_count' as int. "
                             F"Read '{params}' from save file!") from err
        try:
            dump_file = os.path.join(filepath, params["rules_file"])
            # SECURITY: unpickling can execute arbitrary code; only load save
            # files from trusted sources.
            with open(dump_file, 'rb') as f:
                rule = pk.load(f)
            self.rule = self._vectorize_rule(rule)
        except Exception as err:
            raise ValueError(
                "Failed to restore attribute 'rule' from save file!") from err
        self.initialized = True