cuvis_ai.transformation.torch_transformation.TorchTransformation

class cuvis_ai.transformation.torch_transformation.TorchTransformation(function_name: str | None = None, *, operand_b: Any | None = None, **kwargs)

Bases: Node, BaseTransformation

Node that transforms data by applying a PyTorch function.

Parameters:
  • function_name (str, optional) – The name of the PyTorch function to use. Almost any function available in the torch module can be used.

  • operand_b (Any, optional) – A constant value passed to the function alongside the regular input data.

  • kwargs (dict) – Additional keyword arguments, passed to the PyTorch function every time it is called.
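The parameters above suggest a simple pattern: look up function_name by name, then forward the fixed keyword arguments on every call. The sketch below illustrates that pattern only. The class TorchTransformationSketch and its namespace argument are hypothetical and not part of cuvis_ai; the real node looks the name up in torch, while the sketch takes the namespace as an argument so it runs without PyTorch installed. operand_b is left out here.

```python
import builtins
import functools
import operator
from typing import Any, Callable


class TorchTransformationSketch:
    """Hypothetical stand-in for the name-lookup logic of TorchTransformation.

    The real node resolves function_name in the torch module; here the
    lookup namespace is passed in explicitly (an assumption for the demo).
    """

    def __init__(self, function_name: str, *, namespace: Any, **kwargs: Any) -> None:
        func = getattr(namespace, function_name, None)
        if not callable(func):
            raise ValueError(f"{function_name!r} is not a callable in {namespace!r}")
        self.function_name = function_name
        # Bind the extra keyword arguments once so every call receives them.
        self._func: Callable[..., Any] = functools.partial(func, **kwargs)

    def __call__(self, *operands: Any) -> Any:
        return self._func(*operands)


# Fixed kwargs are forwarded on each call: round(3.14159, ndigits=2)
rounder = TorchTransformationSketch("round", namespace=builtins, ndigits=2)
print(rounder(3.14159))  # 3.14

# A two-operand function resolved by name: operator.add(2, 3)
adder = TorchTransformationSketch("add", namespace=operator)
print(adder(2, 3))  # 5
```

Binding kwargs with functools.partial at construction time means a misspelled function_name fails early, when the node is created, rather than on the first forward pass.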

__init__(function_name: str | None = None, *, operand_b: Any | None = None, **kwargs)

Methods

__init__([function_name, operand_b])

check_input_dim(X[, Y])

Check that the parameters of the input data match user expectations.

check_output_dim(X[, Y])

Check that the parameters of the output data match user expectations.

forward(X[, Y])

Apply the PyTorch function named by function_name to X. This node essentially runs torch.<function_name>(X, Y).

get_fit_requested_meta()

get_forward_requested_meta()

load(params, serial_dir)

Load this node from a serialized graph.

serialize(serial_dir)

Serialize this node and save it to serial_dir.

set_fit_meta_request(**kwargs)

set_forward_meta_request(**kwargs)

Attributes

input_dim

Returns the required shape of the input data.

output_dim

Returns the shape of the output data.

check_input_dim(X: Iterable, Y: Iterable | None = None)

Check that the parameters of the input data match user expectations.

Parameters: X (array-like): Input data.

Returns: (bool) True if the data are valid.

check_output_dim(X: Any, Y: Any | None = None)

Check that the parameters of the output data match user expectations.

Parameters: X (array-like): Input data.

Returns: (bool) True if the data are valid.

forward(X: ndarray, Y: ndarray | None = None)

Apply the PyTorch function named by function_name to X. This node essentially runs torch.<function_name>(X, Y).

Parameters:
  • X (np.ndarray) – The first operand of the PyTorch function.

  • Y (np.ndarray, optional) – The second operand of the PyTorch function.

Returns:

The result of the PyTorch function, together with any additional labels or metadata passed along with X.

Return type:

Any, np.ndarray
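The call torch.<function_name>(X, Y) covers both one-operand and two-operand functions. The sketch below shows one plausible way forward could choose its operands. The helper forward_sketch is hypothetical, and the precedence rule is an assumption the docs do not confirm: a constant operand_b, when set, replaces Y as the second operand. Any NumPy/PyTorch conversion the real node performs is omitted, and stdlib operator functions stand in for torch functions.

```python
import operator
from typing import Any, Callable, Optional


def forward_sketch(func: Callable[..., Any], X: Any, Y: Optional[Any] = None,
                   operand_b: Optional[Any] = None) -> Any:
    """Hypothetical mirror of forward(): call func(X[, second operand]).

    Assumption: operand_b, if set, takes precedence over Y; with neither,
    func is applied to X alone.
    """
    second = operand_b if operand_b is not None else Y
    if second is None:
        return func(X)          # unary case, like torch.abs(X)
    return func(X, second)      # binary case, like torch.add(X, Y)


print(forward_sketch(operator.neg, 5))                   # -5
print(forward_sketch(operator.sub, 10, 4))               # 6
print(forward_sketch(operator.sub, 10, 4, operand_b=1))  # 9
```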

get_fit_requested_meta()

get_forward_requested_meta()

property input_dim: Tuple[int, int, int]

Returns the required shape of the input data. Any dimension that does not matter is reported as -1 in that position.

Returns: (tuple) Required shape of the data.
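A shape that uses -1 as a wildcard can be compared against concrete data with a per-dimension check. The helper shape_matches below is hypothetical and not part of cuvis_ai; it is one plausible way to apply the input_dim convention, not necessarily what check_input_dim does internally.

```python
from typing import Sequence


def shape_matches(shape: Sequence[int], expected: Sequence[int]) -> bool:
    """Return True if shape fits expected, where -1 in expected means "any size".

    Hypothetical helper illustrating the -1 convention of input_dim.
    """
    if len(shape) != len(expected):
        return False  # a different number of dimensions can never match
    return all(e == -1 or s == e for s, e in zip(shape, expected))


print(shape_matches((64, 64, 3), (-1, -1, 3)))  # True
print(shape_matches((64, 64, 4), (-1, -1, 3)))  # False
```

With a NumPy array X, a call such as shape_matches(X.shape, node.input_dim) would express the check.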

load(params: dict, serial_dir: str)

Load this node from a serialized graph.

property output_dim: Tuple[int, int, int]

Returns the shape of the output data. Any dimension that depends on the input is reported as -1 in that position.

Returns: (tuple) Expected shape of the output data.

serialize(serial_dir: str) → str

Serialize this node and save it to serial_dir.
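Since this node's configuration consists of its constructor arguments, a serialize/load round trip only needs to preserve function_name, operand_b, and kwargs. The sketch below illustrates that idea under stated assumptions: the class SerializableSketch, the JSON format, and the key names are hypothetical and do not describe cuvis_ai's actual on-disk layout. JSON also limits operand_b to JSON-compatible values such as numbers and lists.

```python
import json
from typing import Any, Dict, Optional


class SerializableSketch:
    """Hypothetical serialize()/load() round trip for a node whose state is
    just its constructor arguments. Format and keys are assumptions."""

    def __init__(self, function_name: Optional[str] = None, *,
                 operand_b: Any = None, **kwargs: Any) -> None:
        self.function_name = function_name
        self.operand_b = operand_b
        self.kwargs = kwargs

    def serialize(self, serial_dir: str) -> str:
        # No auxiliary files are needed here, so serial_dir is unused; the
        # settings are returned as a string for the caller to store.
        return json.dumps({
            "function_name": self.function_name,
            "operand_b": self.operand_b,
            "kwargs": self.kwargs,
        })

    def load(self, params: Dict[str, Any], serial_dir: str) -> None:
        # Restore the settings from an already-parsed params dict.
        self.function_name = params["function_name"]
        self.operand_b = params["operand_b"]
        self.kwargs = dict(params["kwargs"])


original = SerializableSketch("add", operand_b=2, alpha=1)
restored = SerializableSketch()
restored.load(json.loads(original.serialize("unused")), "unused")
print(restored.function_name, restored.operand_b, restored.kwargs)  # add 2 {'alpha': 1}
```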

set_fit_meta_request(**kwargs)

set_forward_meta_request(**kwargs)