Description
Currently, LabelTensor provides an excellent wrapper for keeping track of physical dimensions in SciML tasks. However, performing complex tensor contractions still requires users to remember the integer indices of dimensions (e.g., dim=1), which is error-prone and defeats the purpose of having labels.
I propose adding a named einsum functionality to LabelTensor. This would allow users to perform contractions and rearrangements using the dimension names directly, making the code more readable and physically intuitive.
Proposed API
The goal is to allow a syntax similar to einops, but leveraging the internal labels of the LabelTensor:
```python
import torch
from pina import LabelTensor

# Existing LabelTensor
tensor_a = LabelTensor(torch.rand((200, 30, 30)), ["batch", "width", "height"])

# Proposed: named einsum (either as a static method or an instance method).
# This would sum over 'height' and return a LabelTensor with ['batch', 'width'].
result = LabelTensor.einsum("batch width height -> batch width", tensor_a)
result.labels  # returns ["batch", "width"]
```
Implementation Sketch
The implementation would involve:
- Parsing the input string to identify the requested dimensions.
- Mapping the labels of the input LabelTensor to the single characters used in standard torch.einsum.
- Executing the native PyTorch operation.
- Re-wrapping the output in a LabelTensor with the new label subset.
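To make the steps above concrete, here is a minimal sketch as a standalone helper. The name `named_einsum` and its signature are hypothetical, and it operates on a plain tensor plus a label list rather than hooking into LabelTensor internals; a real implementation would live on LabelTensor itself:

```python
import string
import torch

def named_einsum(equation, tensor, labels):
    """Hypothetical sketch: translate a name-based einsum equation into
    torch.einsum's single-letter format, execute it, and return the
    result together with the output labels."""
    # 1. Parse the equation into input and output dimension names.
    lhs, rhs = (side.strip() for side in equation.split("->"))
    in_names = lhs.split()
    out_names = rhs.split()

    # 2. The input side must match the tensor's labels.
    if in_names != list(labels):
        raise ValueError(f"equation labels {in_names} != tensor labels {labels}")

    # 3. Map each distinct dimension name to one letter (dict preserves order).
    letters = dict(zip(dict.fromkeys(in_names), string.ascii_lowercase))
    letter_eq = (
        "".join(letters[n] for n in in_names)
        + "->"
        + "".join(letters[n] for n in out_names)
    )

    # 4. Execute the native PyTorch operation.
    result = torch.einsum(letter_eq, tensor)

    # 5. Return the output and the new label subset
    #    (a real version would re-wrap in a LabelTensor here).
    return result, out_names

t = torch.rand(200, 30, 30)
out, out_labels = named_einsum(
    "batch width height -> batch width", t, ["batch", "width", "height"]
)
# out has shape (200, 30) and out_labels == ["batch", "width"]
```

This keeps the name-to-letter mapping local to each call, so the same label can map to different letters across calls without conflict.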