# Source code for diffengine.datasets.transforms.formatting
# flake8: noqa: RET505
from collections.abc import Sequence
import numpy as np
import torch
from mmengine.utils import is_str
from diffengine.datasets.transforms import BaseTransform
from diffengine.registry import TRANSFORMS
def to_tensor(data) -> torch.Tensor:  # noqa
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.

    Args:
    ----
        data: The object to convert.

    Returns:
    -------
        torch.Tensor: ``data`` itself when it already is a tensor; a tensor
        sharing memory with ``data`` for a :class:`numpy.ndarray`; a new
        tensor built by :func:`torch.tensor` for a non-string sequence; a
        one-element ``int64`` tensor for an ``int`` (``bool`` included, since
        it is an ``int`` subclass); a one-element ``float32`` tensor for a
        ``float``.

    Raises:
    ------
        TypeError: If ``data`` is of any other type, including ``str``.
    """
    if isinstance(data, torch.Tensor):
        return data
    elif isinstance(data, np.ndarray):
        # Zero-copy: the returned tensor shares memory with the array.
        return torch.from_numpy(data)
    elif isinstance(data, Sequence) and not isinstance(data, str):
        # Strings are Sequences too, but are not meaningful tensor data.
        return torch.tensor(data)
    elif isinstance(data, int):
        # Scalars are wrapped into a one-element 1-D tensor.
        return torch.tensor([data], dtype=torch.long)
    elif isinstance(data, float):
        return torch.tensor([data], dtype=torch.float32)
    else:
        msg = (f"Type {type(data)} cannot be converted to "
               "tensor. Supported types are: `numpy.ndarray`, `torch.Tensor`,"
               " `Sequence`, `int` and `float`")
        raise TypeError(msg)
@TRANSFORMS.register_module()
class PackInputs(BaseTransform):
    """Gather selected entries of a results dict under a single ``"inputs"`` key.

    **Required Keys:**

    - the keys listed in ``input_keys`` (any that are absent from the
      results dict are silently skipped)

    **Deleted Keys:**

    All other keys in the dict.

    Args:
    ----
        input_keys (list[str] | None): Keys of the elements to feed into the
            model forwarding. ``None`` selects ``['img', 'text']``.
        skip_to_tensor_key (list[str] | None): Keys whose values are copied
            as-is instead of being passed through ``to_tensor``. ``None``
            selects ``['text']``.
    """

    def __init__(self,
                 input_keys: list[str] | None = None,
                 skip_to_tensor_key: list[str] | None = None) -> None:
        # Defaults are built per instance so no list is shared between
        # transforms.
        self.input_keys = ["img", "text"] if input_keys is None else input_keys
        self.skip_to_tensor_key = (["text"] if skip_to_tensor_key is None
                                   else skip_to_tensor_key)

    def transform(self, results: dict) -> dict:
        """Transform the data.

        Args:
        ----
            results (dict): Per-sample data produced by earlier transforms.

        Returns:
        -------
            dict: ``{"inputs": packed}``, where ``packed`` maps each key of
            ``input_keys`` present in ``results`` to its value, converted
            with ``to_tensor`` unless the key is in ``skip_to_tensor_key``.
        """
        packed = {}
        for key in self.input_keys:
            if key not in results:
                continue
            value = results[key]
            if key in self.skip_to_tensor_key:
                # e.g. raw text, which cannot be turned into a tensor.
                packed[key] = value
            else:
                packed[key] = to_tensor(value)
        return {"inputs": packed}