Source code for cuvis_ai.data.BaseDataSet

from abc import ABC, abstractmethod
import torch
import numpy as np
from torchvision.datasets import VisionDataset
from typing import Optional, Callable, Union, Dict
from .OutputFormat import OutputFormat

# NumPy scalar types that ``BaseDataSet.set_datatype`` accepts, grouped by kind.
# The order is kept stable because it is echoed in that method's error message.
C_SUPPORTED_DTYPES = (
    # floating point
    np.float64,
    np.float32,
    np.float16,
    # complex
    np.complex64,
    np.complex128,
    # signed integers
    np.int64,
    np.int32,
    np.int16,
    np.int8,
    # unsigned integers
    np.uint8,
    # boolean
    np.bool_,
)


class BaseDataSet(VisionDataset):
    """Base dataset that adds output-format selection and a target NumPy dtype.

    On top of ``VisionDataset`` this class stores:

    * ``output_format`` / ``output_lambda`` - decide which parts of a sample
      ``_get_return_shape`` hands back.
    * ``provide_datatype`` - the NumPy datatype the cube should be converted
      into before it is returned (see ``set_datatype``); defaults to
      ``np.float32``.
    """

    def __init__(self,
                 root: Optional[str] = None,
                 transforms: Optional[Callable] = None,
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None,
                 output_format: OutputFormat = OutputFormat.Full,
                 output_lambda: Optional[Callable] = None,
                 ):
        """Initialise the dataset.

        Args:
            root: Forwarded to ``VisionDataset.__init__``.
            transforms: Forwarded to ``VisionDataset.__init__``.
            transform: Forwarded to ``VisionDataset.__init__``.
            target_transform: Forwarded to ``VisionDataset.__init__``.
            output_format: Selects the tuple layout produced by
                ``_get_return_shape``.
            output_lambda: Callable ``(data, labels, metadata)`` used when
                ``output_format`` is ``OutputFormat.CustomFilter``.
        """
        super().__init__(root, transforms=transforms, transform=transform,
                         target_transform=target_transform)
        self.output_format = output_format
        self.output_lambda = output_lambda
        self.provide_datatype: np.dtype = np.float32

    def _get_return_shape(self, data: np.ndarray, labels: Dict, metadata: Dict):
        """Arrange one sample according to ``self.output_format``.

        Args:
            data: The sample data, always returned as the first element.
            labels: Label records. For the bounding-box and segmentation
                formats this is iterated and every element is indexed with
                ``'bbox'`` / ``'segmentation'``, so despite the ``Dict``
                annotation it is used as an iterable of dicts.
            metadata: Metadata returned unchanged by the ``Full``,
                ``Metadata`` and ``*WithMeta`` formats.

        Returns:
            A ``(data, labels_part, metadata_part)`` tuple where unused parts
            are empty lists, or whatever ``output_lambda`` returns for
            ``OutputFormat.CustomFilter``.

        Raises:
            NotImplementedError: If the format is ``CustomFilter`` without an
                ``output_lambda``, or is not one of the handled formats.
        """
        fmt = self.output_format
        if fmt == OutputFormat.Full:
            return (data, labels, metadata)
        if fmt == OutputFormat.BoundingBox:
            return (data, [label['bbox'] for label in labels], [])
        if fmt == OutputFormat.SegmentationMask:
            return (data, [label['segmentation'] for label in labels], [])
        if fmt == OutputFormat.Metadata:
            return (data, [], metadata)
        if fmt == OutputFormat.BoundingBoxWithMeta:
            return (data, [label['bbox'] for label in labels], metadata)
        if fmt == OutputFormat.SegmentationMaskWithMeta:
            return (data, [label['segmentation'] for label in labels], metadata)
        if fmt == OutputFormat.CustomFilter:
            if self.output_lambda is not None:
                return self.output_lambda(data, labels, metadata)
            # Same exception type as before, but say what is actually missing.
            raise NotImplementedError(
                "OutputFormat.CustomFilter requires an output_lambda to be set.")
        raise NotImplementedError(f"Unsupported output format: {fmt!r}")

    def _apply_transform(self, d: Union[Dict, torch.Tensor], convert_to_numpy: Optional[bool] = False):
        """Run ``self.transforms`` on ``d`` and optionally convert tensors to NumPy.

        Args:
            d: A tensor, or a (possibly nested) dict containing tensors.
            convert_to_numpy: When truthy, every ``torch.Tensor`` in the
                result (top level or inside nested dicts) is replaced by a
                NumPy array. Dicts are updated in place.

        Returns:
            The transformed (and optionally converted) object. ``d`` is
            passed through untouched when ``self.transforms`` is ``None``.
        """
        def unTensorify(source):
            if isinstance(source, torch.Tensor):
                # detach()/cpu() keep the conversion from raising for tensors
                # that require grad or live on an accelerator; for plain CPU
                # tensors this is equivalent to calling .numpy() directly.
                return source.detach().cpu().numpy()
            if isinstance(source, dict):
                # Only values of existing keys are reassigned, so mutating
                # the dict while iterating over it is safe.
                for key, value in source.items():
                    if isinstance(value, (torch.Tensor, dict)):
                        source[key] = unTensorify(value)
            return source

        # NOTE(review): when `transform`/`target_transform` are given,
        # torchvision's VisionDataset appears to wrap them in a two-argument
        # StandardTransform - confirm that this single-argument call is
        # only ever used with a `transforms` callable.
        ret = self.transforms(d) if self.transforms is not None else d
        return unTensorify(ret) if convert_to_numpy else ret

    @staticmethod
    def _dtype_label(dtype) -> str:
        """Return a readable name for a dtype-like object, for error messages."""
        if isinstance(dtype, np.dtype):
            return dtype.name
        # NumPy scalar type classes (np.float32, ...) expose __name__, not
        # .name; anything else falls back to its repr.
        return getattr(dtype, "__name__", repr(dtype))

    def set_datatype(self, dtype: np.dtype):
        """Specify a Numpy datatype to transform the cube into before returning it.

        Valid data types are:
        np.float64, np.float32, np.float16, np.complex64, np.complex128,
        np.int64, np.int32, np.int16, np.int8, np.uint8, np.bool_

        Args:
            dtype: One of the entries of ``C_SUPPORTED_DTYPES`` (or a value
                that compares equal to one of them).

        Raises:
            ValueError: If ``dtype`` is not a supported datatype.
        """
        if dtype in C_SUPPORTED_DTYPES:
            self.provide_datatype = dtype
            return
        supported = [self._dtype_label(d) for d in C_SUPPORTED_DTYPES]
        raise ValueError("Unsupported data type: " + self._dtype_label(dtype)
                         + " - use one of: " + str(supported))

    def get_datatype(self):
        """Get the current datatype set that all data will be converted into before return."""
        return self.provide_datatype