
Commit fffe3f8

DRY train scripts

and tidy up some of the tracking and training code into functions

1 parent 1402fbe commit fffe3f8

File tree

3 files changed (+154, -223 lines)


train-simple-conv.py

Lines changed: 25 additions & 111 deletions
@@ -1,16 +1,19 @@
-import argparse, logging, os, sys
+import argparse
+import logging
+import os
 from pathlib import Path
 
 import torch
-from torch.utils.data import DataLoader, TensorDataset
 import torch.nn as nn
 
-import wandb
-from mlflow import log_metric, log_param, log_artifacts, set_experiment, set_tags
-from torch.utils.tensorboard import SummaryWriter
+from training import train, load_data, log_epoch, track_run, checkpoint_model
+
+ARCHITECTURE="Simple conv"
+EXPERIMENT_NAME="ml-downscaling-emulator"
+TAGS = ["baseline", ARCHITECTURE, "debug"]
 
 def get_args():
-    parser = argparse.ArgumentParser(description='Train U-Net to downscale',
+    parser = argparse.ArgumentParser(description=f'Train {ARCHITECTURE} to downscale',
                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument('--loss', '-l', dest='loss', type=str, default='l1', help='Loss function')
     parser.add_argument('--data', dest='data_dir', type=Path, required=True,
@@ -19,67 +22,10 @@ def get_args():
                         help='Base path to storage for models')
     parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=4, help='Batch size')
+    parser.add_argument('--lr', dest='learning_rate', type=float, default=2e-4, help='Learning rate')
 
     return parser.parse_args()
 
-def train_epoch(model, dataloader, device, epoch):
-    model.train()
-
-    epoch_loss = 0.0
-
-    for i, (batch_X, batch_y) in enumerate(dataloader):
-        loss = train_on_batch(batch_X.to(device), batch_y.to(device), model)
-        epoch_loss += loss.item()
-
-        # Log progress at least every 10th batch
-        if (len(dataloader) <= 10) or ((i+1) % max(len(dataloader)//10,1) == 0):
-            logging.info(f"Epoch {epoch}: Batch {i}: Batch Train Loss {loss.item()} Running Epoch Train Loss {epoch_loss}")
-
-    return epoch_loss
-
-def val_epoch(model, dataloader, device, epoch):
-    model.eval()
-
-    epoch_val_loss = 0
-    for batch_X, batch_y in dataloader:
-        val_loss = val_on_batch(batch_X.to(device), batch_y.to(device), model)
-
-        # Progress
-        epoch_val_loss += val_loss.item()
-
-    model.train()
-
-    return epoch_val_loss
-
-def train_on_batch(batch_X, batch_y, model):
-    # Compute prediction and loss
-    outputs_tensor = model(batch_X)
-    loss = criterion(outputs_tensor, batch_y)
-
-    # Backpropagation
-    optimizer.zero_grad()
-    loss.backward()
-    optimizer.step()
-
-    return loss
-
-def val_on_batch(batch_X, batch_y, model):
-    with torch.no_grad():
-        # Compute prediction and loss
-        outputs_tensor = model(batch_X)
-        loss = criterion(outputs_tensor, batch_y)
-
-    return loss
-
-def load_data(data_dirpath, batch_size):
-    train_set = TensorDataset(torch.load(data_dirpath/'train_X.pt'), torch.load(data_dirpath/'train_y.pt'))
-    val_set = TensorDataset(torch.load(data_dirpath/'val_X.pt'), torch.load(data_dirpath/'val_y.pt'))
-
-    train_dl = DataLoader(train_set, batch_size=batch_size)
-    val_dl = DataLoader(val_set, batch_size=batch_size)
-
-    return train_dl, val_dl
-
 if __name__ == '__main__':
     args = get_args()
 
@@ -97,7 +43,8 @@ def load_data(data_dirpath, batch_size):
 
     # Setup model, loss and optimiser
     num_predictors, _, _ = train_dl.dataset[0][0].shape
-    model = nn.Conv2d(2, 1, kernel_size=31, padding=15)
+    model_opts = dict(kernel_size=31, padding=15)
+    model = nn.Conv2d(num_predictors, 1, **model_opts).to(device=device)
 
     if args.loss == 'l1':
         criterion = torch.nn.L1Loss().to(device)
@@ -106,59 +53,26 @@ def load_data(data_dirpath, batch_size):
     else:
         raise("Unkwown loss function")
 
-    learning_rate = 2e-4
-    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
 
-    config = dict(
+    run_config = dict(
        dataset = args.data_dir,
        optimizer = "Adam",
-       learning_rate = learning_rate,
+       learning_rate = args.learning_rate,
        batch_size = args.batch_size,
-       architecture = "Simple conv",
+       architecture = ARCHITECTURE,
        device = device,
        epochs=args.epochs
    )
 
-    wandb.init(
-        project="ml-downscaling-emulator",
-        tags=["baseline", "Simple conv", "debug"],
-        config=config
-    )
+    with track_run(EXPERIMENT_NAME, run_config, TAGS) as (wandb_run, tb_writer):
+        # Fit model
+        wandb_run.watch(model, criterion=criterion, log_freq=100)
+        for (epoch, epoch_metrics) in train(train_dl, val_dl, model, criterion, optimizer, args.epochs, device):
+            log_epoch(epoch, epoch_metrics, wandb_run, tb_writer)
+
+            # Checkpoint model
+            if (epoch % 10 == 0) or (epoch + 1 == args.epochs): # every 10th epoch or final one (to be safe)
+                checkpoint_model(model, args.model_checkpoints_dir, epoch)
 
-    wandb.watch(model, criterion=criterion, log_freq=100)
-
-    set_experiment("ml-downscaling-emulator")
-    set_tags({"model": "Simple conv", "purpose": "baseline"})
-    log_param("dataset", args.data_dir)
-    log_param("optimizer", "Adam")
-    log_param("learning_rate", learning_rate)
-    log_param("batch_size", args.batch_size)
-    log_param("architecture", "Simple conv")
-    log_param("device", device)
-    log_param("epochs", args.epochs)
-
-    writer = SummaryWriter()
-
-    # Fit model
-    for epoch in range(args.epochs):
-        # Update model based on training data
-        epoch_train_loss = train_epoch(model, train_dl, device, epoch)
-
-        # Compute validation loss
-        epoch_val_loss = val_epoch(model, val_dl, device, epoch)
-
-        logging.info(f"Epoch {epoch}: Train Loss {epoch_train_loss} Val Loss {epoch_val_loss}")
-        wandb.log({"train/loss": epoch_train_loss, "val/loss": epoch_val_loss})
-        log_metric("train/loss",epoch_train_loss, step=epoch)
-        log_metric("val/loss", epoch_val_loss, step=epoch)
-        writer.add_scalar("train/loss", epoch_train_loss, epoch)
-        writer.add_scalar("val/loss", epoch_val_loss, epoch)
-        # Checkpoint model
-        if (epoch % 10 == 0) or (epoch + 1 == args.epochs): # every 10th epoch or final one (to be safe)
-            model_checkpoint_path = args.model_checkpoints_dir / f"model-epoch{epoch}.pth"
-            torch.save(model, model_checkpoint_path)
-            logging.info(f"Epoch {epoch}: Saved model to {model_checkpoint_path}")
-
-    # writer.add_hparams(config, {"train/loss": epoch_train_loss, "val/loss": epoch_val_loss})
-    writer.flush()
     logging.info(f"Finished {os.path.basename(__file__)}")

train-unet.py

Lines changed: 25 additions & 112 deletions
@@ -1,20 +1,23 @@
-import argparse, logging, os, sys
+import argparse
+import logging
+import os
 from pathlib import Path
+import sys
+
+import torch
 dir2 = os.path.abspath('unet/unet')
 dir1 = os.path.dirname(dir2)
 if not dir1 in sys.path: sys.path.append(dir1)
 import unet
-import torch
-from torch.utils.data import random_split, DataLoader, TensorDataset
 
-import numpy as np
+from training import train, load_data, log_epoch, track_run, checkpoint_model
 
-import wandb
-from mlflow import log_metric, log_param, log_artifacts, set_experiment, set_tags
-from torch.utils.tensorboard import SummaryWriter
+ARCHITECTURE="U-Net"
+EXPERIMENT_NAME="ml-downscaling-emulator"
+TAGS = ["baseline", ARCHITECTURE]
 
 def get_args():
-    parser = argparse.ArgumentParser(description='Train U-Net to downscale',
+    parser = argparse.ArgumentParser(description=f'Train {ARCHITECTURE} to downscale',
                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument('--loss', '-l', dest='loss', type=str, default='l1', help='Loss function')
     parser.add_argument('--data', dest='data_dir', type=Path, required=True,
@@ -23,67 +26,10 @@ def get_args():
                         help='Base path to storage for models')
     parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
     parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=4, help='Batch size')
+    parser.add_argument('--lr', dest='learning_rate', type=float, default=2e-4, help='Learning rate')
 
     return parser.parse_args()
 
-def train_epoch(model, dataloader, device, epoch):
-    model.train()
-
-    epoch_loss = 0.0
-
-    for i, (batch_X, batch_y) in enumerate(dataloader):
-        loss = train_on_batch(batch_X.to(device), batch_y.to(device), model)
-        epoch_loss += loss.item()
-
-        # Log progress at least every 10th batch
-        if (len(dataloader) <= 10) or ((i+1) % max(len(dataloader)//10,1) == 0):
-            logging.info(f"Epoch {epoch}: Batch {i}: Batch Train Loss {loss.item()} Running Epoch Train Loss {epoch_loss}")
-
-    return epoch_loss
-
-def val_epoch(model, dataloader, device, epoch):
-    model.eval()
-
-    epoch_val_loss = 0
-    for batch_X, batch_y in dataloader:
-        val_loss = val_on_batch(batch_X.to(device), batch_y.to(device), model)
-
-        # Progress
-        epoch_val_loss += val_loss.item()
-
-    model.train()
-
-    return epoch_val_loss
-
-def train_on_batch(batch_X, batch_y, model):
-    # Compute prediction and loss
-    outputs_tensor = model(batch_X)
-    loss = criterion(outputs_tensor, batch_y)
-
-    # Backpropagation
-    optimizer.zero_grad()
-    loss.backward()
-    optimizer.step()
-
-    return loss
-
-def val_on_batch(batch_X, batch_y, model):
-    with torch.no_grad():
-        # Compute prediction and loss
-        outputs_tensor = model(batch_X)
-        loss = criterion(outputs_tensor, batch_y)
-
-    return loss
-
-def load_data(data_dirpath, batch_size):
-    train_set = TensorDataset(torch.load(data_dirpath/'train_X.pt'), torch.load(data_dirpath/'train_y.pt'))
-    val_set = TensorDataset(torch.load(data_dirpath/'val_X.pt'), torch.load(data_dirpath/'val_y.pt'))
-
-    train_dl = DataLoader(train_set, batch_size=batch_size)
-    val_dl = DataLoader(val_set, batch_size=batch_size)
-
-    return train_dl, val_dl
-
 if __name__ == '__main__':
     args = get_args()
 
@@ -110,59 +56,26 @@ def load_data(data_dirpath, batch_size):
     else:
         raise("Unkwown loss function")
 
-    learning_rate = 2e-4
-    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
 
-    config = dict(
+    run_config = dict(
        dataset = args.data_dir,
        optimizer = "Adam",
-       learning_rate = learning_rate,
+       learning_rate = args.learning_rate,
        batch_size = args.batch_size,
-       architecture = "U-Net",
+       architecture = ARCHITECTURE,
        device = device,
        epochs=args.epochs
    )
 
-    wandb.init(
-        project="ml-downscaling-emulator",
-        tags=["baseline", "U-Net"],
-        config=config
-    )
+    with track_run(EXPERIMENT_NAME, run_config, TAGS) as (wandb_run, tb_writer):
+        # Fit model
+        wandb_run.watch(model, criterion=criterion, log_freq=100)
+        for (epoch, epoch_metrics) in train(train_dl, val_dl, model, criterion, optimizer, args.epochs, device):
+            log_epoch(epoch, epoch_metrics, wandb_run, tb_writer)
+
+            # Checkpoint model
+            if (epoch % 10 == 0) or (epoch + 1 == args.epochs): # every 10th epoch or final one (to be safe)
+                checkpoint_model(model, args.model_checkpoints_dir, epoch)
 
-    wandb.watch(model, criterion=criterion, log_freq=100)
-
-    set_experiment("ml-downscaling-emulator")
-    set_tags({"model": "U-Net", "purpose": "baseline"})
-    log_param("dataset", args.data_dir)
-    log_param("optimizer", "Adam")
-    log_param("learning_rate", learning_rate)
-    log_param("batch_size", args.batch_size)
-    log_param("architecture", "U-Net")
-    log_param("device", device)
-    log_param("epochs", args.epochs)
-
-    writer = SummaryWriter()
-
-    # Fit model
-    for epoch in range(args.epochs):
-        # Update model based on training data
-        epoch_train_loss = train_epoch(model, train_dl, device, epoch)
-
-        # Compute validation loss
-        epoch_val_loss = val_epoch(model, val_dl, device, epoch)
-
-        logging.info(f"Epoch {epoch}: Train Loss {epoch_train_loss} Val Loss {epoch_val_loss}")
-        wandb.log({"train/loss": epoch_train_loss, "val/loss": epoch_val_loss})
-        log_metric("train/loss",epoch_train_loss, step=epoch)
-        log_metric("val/loss", epoch_val_loss, step=epoch)
-        writer.add_scalar("train/loss", epoch_train_loss, epoch)
-        writer.add_scalar("val/loss", epoch_val_loss, epoch)
-        # Checkpoint model
-        if (epoch % 10 == 0) or (epoch + 1 == args.epochs): # every 10th epoch or final one (to be safe)
-            model_checkpoint_path = args.model_checkpoints_dir / f"model-epoch{epoch}.pth"
-            torch.save(model, model_checkpoint_path)
-            logging.info(f"Epoch {epoch}: Saved model to {model_checkpoint_path}")
-
-    # writer.add_hparams(config, {"train/loss": epoch_train_loss, "val/loss": epoch_val_loss})
-    writer.flush()
     logging.info(f"Finished {os.path.basename(__file__)}")
