Source code for cuvis_ai.node.wrap


from types import MethodWrapperType, ModuleType
import inspect
import torch
import sklearn
import torchvision
import torchvision.transforms.v2

from .sklearn import _wrap_sklearn_class, _wrap_sklearn_instance
from .skorch import _wrap_torch_class, _wrap_torch_instance
from .torchvision import _wrap_torchvision_class, _wrap_torchvision_instance


def _wrap_class(cls):
    """Dispatch a class to the wrapper matching its framework base class.

    The class is tested against the supported base classes in a fixed
    order and handed to the first wrapper whose base it derives from.

    Args:
        cls: Class to wrap.

    Returns:
        Whatever the selected ``_wrap_*_class`` helper returns.

    Raises:
        ValueError: If ``cls`` derives from none of the supported bases.
    """
    # Order is significant: the torchvision transform check runs ahead of
    # the generic torch.nn.Module check.
    dispatch = (
        (sklearn.base.BaseEstimator, _wrap_sklearn_class),
        (torchvision.transforms.v2.Transform, _wrap_torchvision_class),
        (torch.nn.Module, _wrap_torch_class),
    )
    for base, wrapper in dispatch:
        if issubclass(cls, base):
            return wrapper(cls)
    raise ValueError("Called on unsupported class")


def _wrap_instance(obj):
    """Dispatch an object to the wrapper matching its framework base class.

    The object is tested against the supported base classes in a fixed
    order and handed to the first wrapper whose base it is an instance of.

    Args:
        obj: Already-constructed object to wrap.

    Returns:
        Whatever the selected ``_wrap_*_instance`` helper returns.

    Raises:
        ValueError: If ``obj`` is an instance of none of the supported bases.
    """
    # Order is significant: the torchvision transform check runs ahead of
    # the generic torch.nn.Module check.
    dispatch = (
        (sklearn.base.BaseEstimator, _wrap_sklearn_instance),
        (torchvision.transforms.v2.Transform, _wrap_torchvision_instance),
        (torch.nn.Module, _wrap_torch_instance),
    )
    for base, wrapper in dispatch:
        if isinstance(obj, base):
            return wrapper(obj)
    raise ValueError("Called on unsupported object")


def make_node(wrapped):
    """Node Wrapper / Decorator. Use to wrap a specific module into a node.

    Accepts either a class or an already-constructed object and forwards
    it to the matching internal wrapper.

    Args:
        wrapped: Class or instance to wrap. Python module objects are
            rejected.

    Returns:
        The result of ``_wrap_class`` when ``wrapped`` is a class,
        otherwise the result of ``_wrap_instance``.

    Raises:
        NotImplementedError: If ``wrapped`` is a Python module object.
        ValueError: Propagated from the internal wrappers when the class
            or object is of an unsupported type.
    """
    if isinstance(wrapped, ModuleType):
        raise NotImplementedError('Currently cannot be wrapped')
    if inspect.isclass(wrapped):
        return _wrap_class(wrapped)
    # Everything that is neither a module nor a class is treated as an
    # instance; an explicit isinstance(wrapped, object) test would always
    # be true, so it is omitted.
    return _wrap_instance(wrapped)