Source code for cuvis_ai.utils.torch
import torch
import torch.nn as nn
from enum import Enum
from .dict import remove_prefix
[docs]
def guess_input_dimensionalty(model: nn.Module) -> InputDimension:
"""
Inspects the first layer of a PyTorch model to guess if the model expects
1D (flattened), 2D (spatial), or 3D input data.
Returns:
InputDimension: Enum value indicating likely input dimensions.
"""
# Get the first layer of the model
first_layer = next(model.children(), None)
if len(first_layer._modules.items()) > 0:
first_layer = list(first_layer._modules.values())[0]
# Check if there is a first layer in the model
if first_layer is None:
return InputDimension.Unknown # No layers in the model
# Analyze the type of the first layer
if isinstance(first_layer, nn.Conv2d):
# 2D convolution, likely expects spatial (image-like) data
if first_layer.in_channels == 1:
return InputDimension.Two
else:
return InputDimension.Three
elif isinstance(first_layer, nn.Conv1d):
# 1D convolution, likely expects sequence data with channels
return InputDimension.One
elif isinstance(first_layer, nn.Conv3d):
# 3D convolution, likely expects 3D spatial data (e.g., video)
return InputDimension.Three
elif isinstance(first_layer, nn.Linear):
return InputDimension.One # Linear layer, likely expects flattened data
return guess_input_dimensionalty(first_layer)
[docs]
def get_output_shape(input_shape, model):
dummy_input = torch.randn(*input_shape)
with torch.no_grad():
output = model(dummy_input)
return output.shape
[docs]
def guess_state_dict_format(state_dict):
keys = set(state_dict.keys())
if 'pytorch-lightning_version' in keys:
return 'lightning'
[docs]
def extract_state_dict(state_dict, format='torch'):
existing_format = guess_state_dict_format(state_dict)
if existing_format == 'lightning':
if format == 'torch':
return remove_prefix(state_dict['state_dict'], 'model.', keep_only=True)
return state_dict