import unittest
from ..preprocessor import PCA, NMF
from ..utils.test import get_np_dummy_data
from functools import wraps
[docs]
class TestPreprocessedNode():
[docs]
def setUp(self):
# setup is handled by the decorator
pass
[docs]
def test_initialization(self):
self.assertTrue(self.node.initialized)
self.assertTrue(self.node.input_dim)
self.assertTrue(self.node.output_dim)
[docs]
def test_correct_output_dim(self):
self.assertTrue(self.node.check_output_dim((15, 20, 15)))
[docs]
def test_passthrough(self):
# check if passthrough generates the correct shape
data = get_np_dummy_data((10, 15, 20, 25))
output = self.node.forward(data)
self.assertTrue(output.shape == (10, 15, 20, 15))
data = get_np_dummy_data((15, 20, 25))
output = self.node.forward(data)
self.assertTrue(output.shape == (15, 20, 15))
class TestUnsupervisedPCA(TestPreprocessedNode, unittest.TestCase):
    """Run the shared preprocessing-node tests against a fitted ``PCA(15)``."""

    def setUp(self):
        """Build a 15-component PCA node and fit it on dummy data."""
        self.node = PCA(15)
        # Same 4-D dummy shape that the shared passthrough test feeds in.
        self.node.fit(get_np_dummy_data((10, 15, 20, 25)))


class TestUnsupervisedNMF(TestPreprocessedNode, unittest.TestCase):
    """Run the shared preprocessing-node tests against a fitted ``NMF(15)``."""

    def setUp(self):
        """Build a 15-component NMF node and fit it on dummy data."""
        self.node = NMF(15)
        # Same 4-D dummy shape that the shared passthrough test feeds in.
        self.node.fit(get_np_dummy_data((10, 15, 20, 25)))


# Allow running this module directly as a script.
if __name__ == "__main__":
    unittest.main()