DataFrameDataModule

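DataFrameDataModule wraps a pandas DataFrame in a data module. The column named by label_key is used as the label, and batches are drawn from the frame through loaders such as train_dataloader().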

 DataFrameDataModule (df:pandas.core.frame.DataFrame=<factory>,
                      label_key:str='label', batch_size:Optional[int]=64,
                      include_time:Optional[bool]=False,
                      device:Optional[littyping.device.Device]=None)
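The signature suggests that one column of df (named by label_key, 'label' by default) is the target and the remaining columns are inputs. The minimal sketch below shows that kind of split in plain pandas. split_features_and_label is a hypothetical helper, not part of this library, and it assumes every non-label column is treated as a feature. batch_size, include_time and device are not modelled here.

```python
import pandas as pd


def split_features_and_label(df: pd.DataFrame, label_key: str = "label"):
    """Hypothetical helper: separate the label column from the feature columns.

    Assumes every column other than `label_key` is a feature.
    """
    if label_key not in df.columns:
        raise KeyError(f"label column {label_key!r} not found in DataFrame")
    features = df.drop(columns=label_key)  # all remaining columns
    labels = df[label_key]                 # the single target column
    return features, labels


# Toy frame that uses the default label column name.
toy = pd.DataFrame({"label": [0, 1, 0], "a": [1.0, 2.0, 3.0], "b": [4.0, 5.0, 6.0]})
X, y = split_features_and_label(toy)
print(X.shape, y.shape)  # (3, 2) (3,)
```

Raising early on a missing label column gives a clearer error than a failure deep inside a data loader.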
df = MockTimeSeries(set_index=True).df
df = df.reset_index().drop(columns='series')  # move index levels back into columns, then drop 'series'
df.head()
dfm = DataFrameDataModule(df=df, label_key='time')  # use the 'time' column as the label
dfm.df.head()
   time  feature_0  feature_1  feature_2
0     0          4          8          0
1     1          0          0          0
2     2          0          2          2
3     3          6          0          4
4     4          7          4          3
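The preparation step at the top of the example (reset_index() followed by dropping 'series') can be illustrated on a toy frame. This is a sketch under an assumption: that MockTimeSeries(set_index=True).df is indexed by ('series', 'time'). The stand-in frame and its values below are hypothetical.

```python
import pandas as pd

# Hypothetical stand-in for MockTimeSeries(set_index=True).df,
# assumed to be indexed by (series, time). Values are toy data.
idx = pd.MultiIndex.from_tuples([(0, 0), (0, 1), (1, 0)], names=["series", "time"])
mock = pd.DataFrame({"feature_0": [4, 0, 5]}, index=idx)

# reset_index() turns both index levels into ordinary columns;
# drop(columns="series") then removes the series identifier.
flat = mock.reset_index().drop(columns="series")
print(list(flat.columns))  # ['time', 'feature_0']
```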
b = next(iter(dfm.train_dataloader()))  # take the first training batch
b[0].shape, b[1].shape
(torch.Size([3, 9, 3]), torch.Size([3, 9]))
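One reading of these shapes, offered as an assumption rather than documented behaviour: with label_key='time', the three remaining columns would be the features, which matches the trailing dimension of 3 in b[0], while b[1] (the label) has no feature dimension. The sketch below checks the column count on the five rows shown by dfm.df.head(), copied by hand.

```python
import pandas as pd

# The five rows displayed by dfm.df.head() above, copied by hand.
head = pd.DataFrame({
    "time": [0, 1, 2, 3, 4],
    "feature_0": [4, 0, 0, 6, 7],
    "feature_1": [8, 0, 2, 0, 4],
    "feature_2": [0, 0, 2, 4, 3],
})

label_key = "time"
# Assumption: every column except the label column is a feature.
feature_cols = [c for c in head.columns if c != label_key]
print(feature_cols)       # ['feature_0', 'feature_1', 'feature_2']
print(len(feature_cols))  # 3
```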
dfm.train_ds.df.shape
(23, 4)