crowdcount.engine

crowdcount.engine.train(model, train_set: torch.utils.data.Dataset, test_set: torch.utils.data.Dataset, train_loss, test_loss, cuda_num=[0], optim='Adam', scheduler_flag=True, learning_rate=1e-05, weight_decay=0.0001, train_batch=1, test_batch=1, num_worker=8, epoch_num=2000, learning_decay=0.995, saver=None, enlarge_num=1)

Starts training the model.

Parameters
  • model (torch.nn.Module) – the model built to train.

  • train_set (torch.utils.data.Dataset or object) – training dataset, which is wrapped in a torch.utils.data.DataLoader.

  • test_set (torch.utils.data.Dataset or object) – test dataset, which is wrapped in a torch.utils.data.DataLoader.

  • train_loss (object) – training loss function, constructed from crowdcount.utils.

  • test_loss (object) – test loss function, constructed from crowdcount.utils.

  • cuda_num (list, optional) – CUDA devices to use (default: [0]).

  • optim (str, optional) – optimizer, “Adam” | “SGD”: “Adam” selects torch.optim.Adam and “SGD” selects torch.optim.SGD (default: “Adam”).

  • scheduler_flag (bool, optional) – if True, the learning rate declines at every step according to learning_decay (default: True).

  • learning_rate (float, optional) – learning rate used by the optimizer (default: 1e-5).

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 1e-4).

  • train_batch (int, optional) – training batch size (default: 1).

  • test_batch (int, optional) – test batch size (default: 1).

  • num_worker (int, optional) – number of subprocesses to use for data loading; 0 means the data is loaded in the main process (default: 8).

  • epoch_num (int, optional) – number of epochs to train (default: 2000).

  • learning_decay (float, optional) – learning rate decay factor used by the scheduler (default: 0.995).

  • saver (crowdcount.utils.Saver, optional) – saver used to save the model (default: None).

  • enlarge_num (int, optional) – scale factor used to enlarge the density map (default: 1).
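
To illustrate how learning_rate and learning_decay might interact when scheduler_flag is True, here is a minimal sketch. It assumes a purely multiplicative rule (the rate is multiplied by learning_decay once per step), which this page does not confirm; the helper name decayed_learning_rates is hypothetical and not part of crowdcount.

```python
def decayed_learning_rates(learning_rate=1e-5, learning_decay=0.995, steps=5):
    """Return the learning rate at each step under an ASSUMED multiplicative decay.

    Hypothetical helper for illustration only; not part of crowdcount.
    """
    rates = []
    lr = learning_rate
    for _ in range(steps):
        rates.append(lr)
        # Assumed update rule: lr <- lr * learning_decay
        lr *= learning_decay
    return rates


rates = decayed_learning_rates(steps=3)
for step, lr in enumerate(rates):
    print(f"step {step}: lr = {lr:.6e}")
```

Under this assumption, a decay factor close to 1 such as 0.995 lowers the rate slowly, but over many steps the reduction compounds.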
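
The optional parameters above can be collected into one dictionary and passed to train as keyword arguments. The sketch below copies the defaults from the signature on this page; TRAIN_DEFAULTS and build_train_kwargs are hypothetical names introduced for illustration. The call itself is left as a comment because model, the datasets, and the loss objects must be built separately.

```python
# Defaults copied from the documented signature of crowdcount.engine.train.
TRAIN_DEFAULTS = {
    "cuda_num": [0],
    "optim": "Adam",
    "scheduler_flag": True,
    "learning_rate": 1e-05,
    "weight_decay": 0.0001,
    "train_batch": 1,
    "test_batch": 1,
    "num_worker": 8,
    "epoch_num": 2000,
    "learning_decay": 0.995,
    "saver": None,
    "enlarge_num": 1,
}


def build_train_kwargs(**overrides):
    """Merge overrides into the documented defaults (hypothetical helper)."""
    unknown = set(overrides) - set(TRAIN_DEFAULTS)
    if unknown:
        # Reject names that the documented signature does not list.
        raise TypeError(f"unknown train() arguments: {sorted(unknown)}")
    return {**TRAIN_DEFAULTS, **overrides}


kwargs = build_train_kwargs(optim="SGD", epoch_num=100, cuda_num=[0, 1])

# Assumed usage, with model, train_set, test_set, train_loss and test_loss
# constructed elsewhere:
# from crowdcount.engine import train
# train(model, train_set, test_set, train_loss, test_loss, **kwargs)
```

Keeping the overrides in one place makes it easier to see which settings differ from the defaults for a given run.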