crowdcount.engine
crowdcount.engine.train(model, train_set: torch.utils.data.Dataset, test_set: torch.utils.data.Dataset, train_loss, test_loss, cuda_num=[0], optim='Adam', scheduler_flag=True, learning_rate=1e-05, weight_decay=0.0001, train_batch=1, test_batch=1, num_worker=8, epoch_num=2000, learning_decay=0.995, saver=None, enlarge_num=1)
Start training the model.
- Parameters
model (torch.nn.Module) – the model to train.
train_set (torch.utils.data.Dataset or object) – training dataset, wrapped in a torch.utils.data.DataLoader.
test_set (torch.utils.data.Dataset or object) – test dataset, wrapped in a torch.utils.data.DataLoader.
train_loss (object) – training loss function, constructed from crowdcount.utils.
test_loss (object) – test loss function, constructed from crowdcount.utils.
cuda_num (list, optional) – indices of the CUDA devices to use (default: [0]).
optim (str, optional) – optimizer, "Adam" or "SGD". If "Adam", torch.optim.Adam is used; if "SGD", torch.optim.SGD is used (default: "Adam").
scheduler_flag (bool, optional) – if True, the learning rate is decayed by learning_decay at every scheduler step (default: True).
learning_rate (float, optional) – learning rate used in optimizer (default: 1e-5).
weight_decay (float, optional) – weight decay (L2 penalty) (default: 1e-4).
train_batch (int, optional) – batch size for training (default: 1).
test_batch (int, optional) – batch size for testing (default: 1).
num_worker (int, optional) – how many subprocesses to use for data loading; 0 means the data is loaded in the main process (default: 8).
epoch_num (int, optional) – number of epochs to train (default: 2000).
learning_decay (float, optional) – learning-rate decay factor used by the scheduler (default: 0.995).
saver (crowdcount.utils.Saver, optional) – saver used to save the model (default: None).
enlarge_num (int, optional) – scale factor used to enlarge the density map (default: 1).
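The descriptions of scheduler_flag and learning_decay suggest an exponential learning-rate schedule. The following is a minimal sketch, assuming the scheduler multiplies the rate by learning_decay once per step, the way torch.optim.lr_scheduler.ExponentialLR does with gamma=learning_decay. This page does not say whether train steps the scheduler per epoch or per batch, and the helper names lr_after and steps_to_fraction are hypothetical.

```python
import math


def lr_after(steps: int, learning_rate: float = 1e-5, learning_decay: float = 0.995) -> float:
    """Learning rate after `steps` scheduler steps.

    Assumes exponential decay: lr_n = learning_rate * learning_decay ** n
    (the rule used by torch.optim.lr_scheduler.ExponentialLR).
    """
    if steps < 0:
        raise ValueError("steps must be non-negative")
    return learning_rate * learning_decay ** steps


def steps_to_fraction(fraction: float, learning_decay: float = 0.995) -> int:
    """Smallest number of steps n such that learning_decay ** n <= fraction."""
    if not (0.0 < fraction < 1.0 and 0.0 < learning_decay < 1.0):
        raise ValueError("fraction and learning_decay must lie in (0, 1)")
    # Solve learning_decay ** n <= fraction for n; both logs are negative.
    return math.ceil(math.log(fraction) / math.log(learning_decay))


if __name__ == "__main__":
    print("lr after 1 step:      ", lr_after(1))
    print("steps to halve the lr:", steps_to_fraction(0.5))
    print("lr after 2000 steps:  ", lr_after(2000))
```

With the defaults, the rate halves roughly every 139 steps. If the scheduler steps once per epoch, then after epoch_num=2000 epochs the rate would be about 1e-5 × 0.995^2000 ≈ 4.4e-10, which is worth keeping in mind when choosing epoch_num and learning_decay together.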