数据接口

最后更新:05/19/2025(API 文档字符串是自动生成的)。

DataProto 是数据交换的接口。

verl.DataProto 类包含两个关键成员:

  • batch:一个 tensordict.TensorDict 对象,用于存放实际数据

  • meta_info:一个 Dict,包含额外的元信息

TensorDict

DataProto.batch 构建于 tensordict 之上,这是一个 PyTorch 生态系统中的项目。 TensorDict 是一种类字典(dict-like)的张量容器。要实例化一个 TensorDict,你必须指定键值对(key-value pairs)以及批次大小(batch size)。

>>> import torch
>>> from tensordict import TensorDict
>>> tensordict = TensorDict({"zeros": torch.zeros(2, 3, 4), "ones": torch.ones(2, 3, 5)}, batch_size=[2,])
>>> tensordict["twos"] = 2 * torch.ones(2, 5, 6)
>>> zeros = tensordict["zeros"]
>>> tensordict
TensorDict(
fields={
    ones: Tensor(shape=torch.Size([2, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
    twos: Tensor(shape=torch.Size([2, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False),
    zeros: Tensor(shape=torch.Size([2, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)

你还可以沿着其 batch_size 对 TensorDict 进行索引。同时,也可以对 TensorDict 的内容进行集体操作。

>>> tensordict[..., :1]
TensorDict(
fields={
    ones: Tensor(shape=torch.Size([1, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
    twos: Tensor(shape=torch.Size([1, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False),
    zeros: Tensor(shape=torch.Size([1, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([1]),
device=None,
is_shared=False)
>>> tensordict = tensordict.to("cuda:0")
>>> tensordict = tensordict.reshape(6)

有关 tensordict.TensorDict 用法的更多信息,请参阅官方 tensordict 文档。

核心 API