Deep-learning Traffic Prediction Models (PyTorch)

We present PyTorch implementations of popular deep-learning traffic prediction models in the following GitHub repository:
GitHub: https://github.com/cdsnlab/traffic-pytorch

traffic-pytorch

An integrated platform for urban intelligence tasks, including traffic and demand prediction.

Traffic prediction

We report MAE / RMSE on the pems-bay dataset (12 steps / 1 hour).

Model     MAE    RMSE
DCRNN     0.92   1.58
GMAN      1.99   3.87
WaveNet   4.70   7.53
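For readers unfamiliar with the two metrics, here is a minimal sketch of MAE and RMSE in plain Python. It is an illustration of the standard definitions only; the repository's own evaluation code may differ in details we have not verified (for example, how missing readings are handled). Note that 12 steps spanning 1 hour corresponds to one step every 5 minutes.

```python
import math
from typing import Sequence


def _check(y_true: Sequence[float], y_pred: Sequence[float]) -> None:
    # Both metrics need paired, non-empty inputs.
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("inputs must be non-empty and of equal length")


def mae(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean absolute error: average of |y - y_hat|."""
    _check(y_true, y_pred)
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def rmse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Root mean squared error: square root of the average of (y - y_hat)^2."""
    _check(y_true, y_pred)
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


if __name__ == "__main__":
    # Toy values, not taken from any dataset.
    truth = [60.0, 62.0, 58.0, 65.0]
    pred = [61.0, 60.0, 58.0, 68.0]
    print(f"MAE={mae(truth, pred):.2f} RMSE={rmse(truth, pred):.2f}")  # MAE=1.50 RMSE=1.87
```

RMSE is never smaller than MAE, and the gap widens when a few predictions are far off, which is why both are usually reported together.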

We report MAE / RMSE on the PeMS dataset (9 steps).

Model     MAE     RMSE
STGCN     18.30   18.92
ASTGCN    2.94    5.50
MSTGCN    2.94    5.52
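The step counts above (12 and 9) match the --num_pred values in the training commands, so they appear to be prediction horizons. As a rough illustration of what a horizon means, the sketch below slices one sensor's series into (history, target) pairs. The history length num_his is a hypothetical parameter introduced here; we have not checked how the repository builds its samples.

```python
from typing import List, Sequence, Tuple


def make_windows(
    series: Sequence[float], num_his: int, num_pred: int
) -> List[Tuple[List[float], List[float]]]:
    """Split a 1-D series into (history, target) pairs.

    Each pair uses `num_his` consecutive past steps as input and the
    following `num_pred` steps as the prediction target. `num_his` is a
    hypothetical name used only for this illustration.
    """
    if num_his <= 0 or num_pred <= 0:
        raise ValueError("num_his and num_pred must be positive")
    windows = []
    # Last valid start leaves room for num_his + num_pred steps.
    for start in range(len(series) - num_his - num_pred + 1):
        history = list(series[start:start + num_his])
        target = list(series[start + num_his:start + num_his + num_pred])
        windows.append((history, target))
    return windows


if __name__ == "__main__":
    # One day of 5-minute readings is 288 steps (toy values).
    day = [float(i) for i in range(288)]
    print(len(make_windows(day, num_his=12, num_pred=12)))  # 265
```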

Getting Started

Data

  • pems-bay, metr-la: Download the .h5 files from Google Drive and place them in the datasets directory.
  • PeMSD7: Download the files from the STGCN GitHub repository and place them in the datasets directory.
  • PEMS: Download the files from the ASTGNN GitHub repository and place them in the datasets directory.
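Before training, it can help to confirm the downloads landed where train.py will look (the commands below pass --ddir ../datasets/). This is a small sketch; the file names in EXPECTED_FILES are assumptions, so adjust them, and add entries for PeMSD7 and PEMS, to match what the downloads actually contain.

```python
from pathlib import Path
from typing import Dict, List

# Assumed file names; edit to match the files you actually downloaded.
EXPECTED_FILES: Dict[str, List[str]] = {
    "pems-bay": ["pems-bay.h5"],
    "metr-la": ["metr-la.h5"],
}


def missing_files(ddir: str, expected: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Return, per dataset, the expected files that are not present in `ddir`."""
    root = Path(ddir)
    missing: Dict[str, List[str]] = {}
    for dataset, names in expected.items():
        absent = [name for name in names if not (root / name).is_file()]
        if absent:
            missing[dataset] = absent
    return missing


if __name__ == "__main__":
    problems = missing_files("../datasets/", EXPECTED_FILES)
    if problems:
        for dataset, names in problems.items():
            print(f"{dataset}: missing {', '.join(names)}")
    else:
        print("all expected dataset files found")
```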

Environment

# Create and activate a Python 3.7 environment ($ENV_NAME$ is a placeholder for your environment name)
conda create -n $ENV_NAME$ python=3.7
conda activate $ENV_NAME$

# CUDA 11.3
pip install torch==1.11.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
# Or, for CUDA 10.2
pip install torch==1.11.0+cu102 --extra-index-url https://download.pytorch.org/whl/cu102
# Remaining dependencies
pip install -r requirements.txt
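The "+cu113" / "+cu102" suffix in the pip commands selects the CUDA build of PyTorch. A quick sketch for checking which build ended up in the environment (it runs on Python 3.7 and falls back gracefully if torch is missing):

```python
from typing import Optional, Tuple


def split_local_version(version: str) -> Tuple[str, Optional[str]]:
    """Split '1.11.0+cu113' into ('1.11.0', 'cu113'); without '+', return (version, None)."""
    base, sep, local = version.partition("+")
    return base, (local if sep else None)


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("torch is not installed in this environment")
    else:
        base, build = split_local_version(torch.__version__)
        print("torch", base, "build tag", build)
        print("compiled CUDA version:", torch.version.cuda)
        print("CUDA available:", torch.cuda.is_available())
```

If "CUDA available" prints False on a GPU machine, the installed wheel and the local driver are a likely mismatch to investigate first.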

Train

# Replace $DEVICE$ with the device to train on.

# DCRNN
python train.py --model DCRNN --ddir ../datasets/ --dname pems-bay --device $DEVICE$ --num_pred 12

# GMAN
python train.py --model GMAN --ddir ../datasets/ --dname pems-bay --device $DEVICE$ --num_pred 12

# WaveNet 
python train.py --model WaveNet --ddir ../datasets/ --dname pems-bay --device $DEVICE$ --num_pred 12

# STGCN
python train.py --model STGCN --ddir ../datasets/ --dname PEMSD --device $DEVICE$ --num_pred 9

# ASTGCN
python train.py --model ASTGCN --ddir ../datasets/ --dname PEMSD --device $DEVICE$ --num_pred 9

# MSTGCN 
python train.py --model MSTGCN --ddir ../datasets/ --dname PEMSD --device $DEVICE$ --num_pred 9
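The six commands above differ only in --model, --dname and --num_pred. A small sketch that rebuilds them from one table and runs them in sequence; it uses only the flags shown above. The device string "cuda:0" is a hypothetical example, since the accepted format for --device is not documented here. By default it is a dry run that just prints the commands.

```python
import shlex
import subprocess
from typing import List

# (model, dataset name, prediction steps), mirroring the commands above.
CONFIGS = [
    ("DCRNN", "pems-bay", 12),
    ("GMAN", "pems-bay", 12),
    ("WaveNet", "pems-bay", 12),
    ("STGCN", "PEMSD", 9),
    ("ASTGCN", "PEMSD", 9),
    ("MSTGCN", "PEMSD", 9),
]


def build_command(
    model: str, dname: str, num_pred: int, device: str, ddir: str = "../datasets/"
) -> List[str]:
    """Assemble the train.py argument list for one configuration."""
    return [
        "python", "train.py",
        "--model", model,
        "--ddir", ddir,
        "--dname", dname,
        "--device", device,
        "--num_pred", str(num_pred),
    ]


def run_all(device: str, dry_run: bool = True) -> None:
    """Print every command; execute them one by one unless dry_run is True."""
    for model, dname, num_pred in CONFIGS:
        cmd = build_command(model, dname, num_pred, device)
        print(" ".join(shlex.quote(part) for part in cmd))
        if not dry_run:
            # check=True stops the sweep at the first failing run.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(device="cuda:0")  # hypothetical device value; dry run only
```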
