• 自定义数据集类必须是 Dataset 的子类

from torch.utils.data import Dataset, DataLoader

class TrajData(Dataset):
    """
    用于处理变长轨迹数据的自定义数据集类
    """

    def __init__(self, data_dict, sequence_length=10):
        self.data = []
        self.sequence_length = sequence_length

        # 计算每条轨迹的实际长度
        for key in data_dict:
            traj_data = data_dict[key]
            for i in range(len(traj_data) - self.sequence_length):
                input_sequence = traj_data[i:i + self.sequence_length, :-2]  # 特征
                target_sequence = traj_data[i + 1:i + self.sequence_length + 1, -2:]  # 标签
                self.data.append((input_sequence, target_sequence))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
  • 这是轨迹预测的一个自定义数据集类,前10步预测后10步

  • 标签的位置的 x,y 坐标

  • 将处理好的数据存在 data 中

  • 然后定义批处理函数

def collate_fn(batch):
    """
    批处理函数,当数据长度一致时,无需填充和掩蔽。
    """
    inputs, targets = zip(*[(item[0], item[1]) for item in batch])

    # 将 numpy 数组转换为 torch.Tensor
    padded_inputs = torch.stack([torch.tensor(input_seq, dtype=torch.float32) for input_seq in inputs], dim=0)
    padded_targets = torch.stack([torch.tensor(target_seq, dtype=torch.float32) for target_seq in targets], dim=0)

    return padded_inputs, padded_targets
  • item 就是 batch 中的一条数据,batch 是从自定义数据集类的 data 中抽取的

  • item 为 (input_sequence, target_sequence)

  • 按 dim=0 堆叠,即输入变为 (batch,10,features),输出变为(batch,10,2)

  • 然后定义数据迭代器

train_loader = DataLoader(
        dataset=data_train,
        batch_size=opt.batch_size,
        collate_fn=collate_fn,
        shuffle=True
    )
  • data_train 为自定义数据集的实例

  • shuffle 表示顺序是否打乱