import unittest
import os
import shutil
import numpy as np
from ..utils.test import get_np_dummy_data
from ..preprocessor import PCA, NMF
from ..unsupervised import GMM, KMeans, MeanShift
from ..transformation import Reflectance, TorchTransformation
from ..supervised import SVM, QDA, LDA
from ..tv_transforms import Bandpass
from ..utils.serializer import YamlSerializer
# Attribute types that are compared between the original and the reloaded
# node; attributes whose exact type is not listed here are skipped.
TYPES_TO_CHECK = (int, float, str, bool, list, tuple, np.ndarray)
# Scratch directory the serialization test creates, writes into and removes.
TEST_DIR = "./test/temp"
class TestNodeSerialization:
    """Mixin providing a serialization round-trip test for ``self.node``.

    Concrete test cases combine this class with ``unittest.TestCase`` and
    assign the node under test to ``self.node`` in ``setUp``.
    """

    def test_serialization(self):
        """Serialize ``self.node``, reload it, and compare the attributes.

        The node is serialized into ``TEST_DIR``, written and read back via
        ``YamlSerializer``, and loaded into a fresh instance of the same
        class. Every attribute of the reloaded node whose exact type is in
        ``TYPES_TO_CHECK`` must equal the attribute of the original node.
        The test fails with one message line per differing attribute.
        """
        os.makedirs(TEST_DIR, exist_ok=True)
        mismatches = []
        try:
            node_params = self.node.serialize(TEST_DIR)
            serializer = YamlSerializer(TEST_DIR, 'test_node')
            serializer.serialize(node_params)
            node_dict = serializer.load()

            lnode = self.node.__class__()
            lnode.id = self.node.id
            lnode.load(node_dict, TEST_DIR)

            for attr in vars(lnode):
                loaded = getattr(lnode, attr)
                # Exact type match on purpose: this keeps the set of compared
                # attributes the same as before (subclasses are skipped).
                if type(loaded) not in TYPES_TO_CHECK:
                    continue
                original = getattr(self.node, attr)

                if isinstance(loaded, np.ndarray) or isinstance(original, np.ndarray):
                    # array_equal also copes with differing shapes, where an
                    # element-wise ``!=`` would broadcast or fail outright.
                    differs = not np.array_equal(loaded, original)
                else:
                    differs = loaded != original

                if differs:
                    mismatches.append(f"Attribute '{attr}' not equal! "
                                      f"{loaded} != {original}")
        finally:
            # Always remove the scratch directory, even when serialization or
            # loading raised, so later tests start from a clean state.
            shutil.rmtree(TEST_DIR, ignore_errors=True)

        self.assertFalse(mismatches, "\n".join(mismatches))
class TestPreprocessorPCA(TestNodeSerialization, unittest.TestCase):
    """Serialization round-trip test case for the ``PCA`` node."""

    def setUp(self):
        """Create ``PCA(15)`` and fit it on dummy data before each test."""
        self.node = PCA(15)
        dummy_shape = (10, 15, 20, 25)
        training_data = get_np_dummy_data(dummy_shape)
        self.node.fit(training_data)
class TestPreprocessorNMF(TestNodeSerialization, unittest.TestCase):
    """Serialization round-trip test case for the ``NMF`` node."""

    def setUp(self):
        """Create ``NMF(15)`` and fit it on dummy data before each test."""
        self.node = NMF(15)
        dummy_shape = (10, 15, 20, 25)
        training_data = get_np_dummy_data(dummy_shape)
        self.node.fit(training_data)
class TestUnsupervisedKMeans(TestNodeSerialization, unittest.TestCase):
    """Serialization round-trip test case for the ``KMeans`` node."""

    def setUp(self):
        """Create ``KMeans(15)`` and fit it on dummy data before each test."""
        self.node = KMeans(15)
        dummy_shape = (10, 15, 20, 25)
        training_data = get_np_dummy_data(dummy_shape)
        self.node.fit(training_data)
class TestUnsupervisedGMM(TestNodeSerialization, unittest.TestCase):
    """Serialization round-trip test case for the ``GMM`` node."""

    def setUp(self):
        """Create ``GMM(15)`` and fit it on dummy data before each test."""
        self.node = GMM(15)
        dummy_shape = (10, 15, 20, 25)
        training_data = get_np_dummy_data(dummy_shape)
        self.node.fit(training_data)
class TestUnsupervisedMeanShift(TestNodeSerialization, unittest.TestCase):
    """Serialization round-trip test case for the ``MeanShift`` node."""

    def setUp(self):
        """Create ``MeanShift()`` and fit it on dummy data before each test."""
        self.node = MeanShift()
        dummy_shape = (10, 15, 20, 25)
        training_data = get_np_dummy_data(dummy_shape)
        self.node.fit(training_data)
# class TestTransformationTorch(TestNodeSerialization, unittest.TestCase):
#
# def setUp(self):
# self.node = TorchTransformation("add", operand_b=5)
# self.node.fit(get_np_dummy_data((10, 15, 20, 25)))
# class TestTransformationTorchVision(TestNodeSerialization, unittest.TestCase):
#
# def setUp(self):
# self.node = Bandpass(5, 10)
# self.node.fit(get_np_dummy_data((10, 15, 20, 25)))
# class TestTransformationReflectance(TestNodeSerialization, unittest.TestCase):
#
# def setUp(self):
# self.node = Reflectance(0.1, 1.8)
# self.node.fit(get_np_dummy_data((10, 15, 20, 25)))
# class TestSupervisedSVM(TestNodeSerialization, unittest.TestCase):
#
# def setUp(self):
# self.node = SVM()
# self.node.fit(get_np_dummy_data((15, 20, 25)),
# np.where(get_np_dummy_data((15, 20, 1)) > 0.5, 1, 0))
# class TestSupervisedQDA(TestNodeSerialization, unittest.TestCase):
#
# def setUp(self):
# self.node = QDA()
# self.node.fit(get_np_dummy_data((15, 20, 25)),
# np.where(get_np_dummy_data((15, 20, 1)) > 0.5, 1, 0))
class TestSupervisedLDA(TestNodeSerialization, unittest.TestCase):
    """Serialization round-trip test case for the ``LDA`` node."""

    def setUp(self):
        """Create ``LDA()`` and fit it on dummy data with 0/1 labels.

        The labels are obtained by thresholding a second dummy array at 0.5.
        """
        self.node = LDA()
        features = get_np_dummy_data((15, 20, 25))
        label_source = get_np_dummy_data((15, 20, 1))
        labels = np.where(label_source > 0.5, 1, 0)
        self.node.fit(features, labels)