自定义数据集类必须是 Dataset 的子类
from torch.utils.data import Dataset, DataLoader
class TrajData(Dataset):
"""
用于处理变长轨迹数据的自定义数据集类
"""
def __init__(self, data_dict, sequence_length=10):
self.data = []
self.sequence_length = sequence_length
# 计算每条轨迹的实际长度
for key in data_dict:
traj_data = data_dict[key]
for i in range(len(traj_data) - self.sequence_length):
input_sequence = traj_data[i:i + self.sequence_length, :-2] # 特征
target_sequence = traj_data[i + 1:i + self.sequence_length + 1, -2:] # 标签
self.data.append((input_sequence, target_sequence))
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
这是轨迹预测的一个自定义数据集类,前10步预测后10步
标签的位置的 x,y 坐标
将处理好的数据存在 data 中
然后定义批处理函数
def collate_fn(batch):
"""
批处理函数,当数据长度一致时,无需填充和掩蔽。
"""
inputs, targets = zip(*[(item[0], item[1]) for item in batch])
# 将 numpy 数组转换为 torch.Tensor
padded_inputs = torch.stack([torch.tensor(input_seq, dtype=torch.float32) for input_seq in inputs], dim=0)
padded_targets = torch.stack([torch.tensor(target_seq, dtype=torch.float32) for target_seq in targets], dim=0)
return padded_inputs, padded_targets
item 就是 batch 中的一条数据,batch 是从自定义数据集类的 data 中抽取的
item 为 (input_sequence, target_sequence)
按 dim=0 堆叠,即输入变为 (batch,10,features),输出变为(batch,10,2)
然后定义数据迭代器
train_loader = DataLoader(
dataset=data_train,
batch_size=opt.batch_size,
collate_fn=collate_fn,
shuffle=True
)
data_train 为自定义数据集的实例
shuffle 表示顺序是否打乱