-import argparse, logging, os, sys
+import argparse
+import logging
+import os
from pathlib import Path

import torch
-from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn

-import wandb
-from mlflow import log_metric, log_param, log_artifacts, set_experiment, set_tags
-from torch.utils.tensorboard import SummaryWriter
+from training import train, load_data, log_epoch, track_run, checkpoint_model
+
+ARCHITECTURE = "Simple conv"
+EXPERIMENT_NAME = "ml-downscaling-emulator"
+TAGS = ["baseline", ARCHITECTURE, "debug"]


def get_args():
-    parser = argparse.ArgumentParser(description='Train U-Net to downscale',
+    parser = argparse.ArgumentParser(description=f'Train {ARCHITECTURE} to downscale',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--loss', '-l', dest='loss', type=str, default='l1', help='Loss function')
    parser.add_argument('--data', dest='data_dir', type=Path, required=True,
@@ -19,67 +22,10 @@ def get_args():
                        help='Base path to storage for models')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=4, help='Batch size')
+    parser.add_argument('--lr', dest='learning_rate', type=float, default=2e-4, help='Learning rate')

    return parser.parse_args()

-def train_epoch(model, dataloader, device, epoch):
-    model.train()
-
-    epoch_loss = 0.0
-
-    for i, (batch_X, batch_y) in enumerate(dataloader):
-        loss = train_on_batch(batch_X.to(device), batch_y.to(device), model)
-        epoch_loss += loss.item()
-
-        # Log progress at least every 10th batch
-        if (len(dataloader) <= 10) or ((i + 1) % max(len(dataloader) // 10, 1) == 0):
-            logging.info(f"Epoch {epoch}: Batch {i}: Batch Train Loss {loss.item()} Running Epoch Train Loss {epoch_loss}")
-
-    return epoch_loss
-
-def val_epoch(model, dataloader, device, epoch):
-    model.eval()
-
-    epoch_val_loss = 0
-    for batch_X, batch_y in dataloader:
-        val_loss = val_on_batch(batch_X.to(device), batch_y.to(device), model)
-
-        # Progress
-        epoch_val_loss += val_loss.item()
-
-    model.train()
-
-    return epoch_val_loss
-
-def train_on_batch(batch_X, batch_y, model):
-    # Compute prediction and loss
-    outputs_tensor = model(batch_X)
-    loss = criterion(outputs_tensor, batch_y)
-
-    # Backpropagation
-    optimizer.zero_grad()
-    loss.backward()
-    optimizer.step()
-
-    return loss
-
-def val_on_batch(batch_X, batch_y, model):
-    with torch.no_grad():
-        # Compute prediction and loss
-        outputs_tensor = model(batch_X)
-        loss = criterion(outputs_tensor, batch_y)
-
-    return loss
-
-def load_data(data_dirpath, batch_size):
-    train_set = TensorDataset(torch.load(data_dirpath / 'train_X.pt'), torch.load(data_dirpath / 'train_y.pt'))
-    val_set = TensorDataset(torch.load(data_dirpath / 'val_X.pt'), torch.load(data_dirpath / 'val_y.pt'))
-
-    train_dl = DataLoader(train_set, batch_size=batch_size)
-    val_dl = DataLoader(val_set, batch_size=batch_size)
-
-    return train_dl, val_dl
-
if __name__ == '__main__':
    args = get_args()

@@ -97,7 +43,8 @@ def load_data(data_dirpath, batch_size):

    # Setup model, loss and optimiser
    num_predictors, _, _ = train_dl.dataset[0][0].shape
-    model = nn.Conv2d(2, 1, kernel_size=31, padding=15)
+    model_opts = dict(kernel_size=31, padding=15)
+    model = nn.Conv2d(num_predictors, 1, **model_opts).to(device=device)

    if args.loss == 'l1':
        criterion = torch.nn.L1Loss().to(device)
@@ -106,59 +53,26 @@ def load_data(data_dirpath, batch_size):
    else:
        raise ValueError("Unknown loss function")

-    learning_rate = 2e-4
-    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

-    config = dict(
+    run_config = dict(
        dataset=args.data_dir,
        optimizer="Adam",
-        learning_rate=learning_rate,
+        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
-        architecture="Simple conv",
+        architecture=ARCHITECTURE,
        device=device,
        epochs=args.epochs
    )

-    wandb.init(
-        project="ml-downscaling-emulator",
-        tags=["baseline", "Simple conv", "debug"],
-        config=config
-    )
+    with track_run(EXPERIMENT_NAME, run_config, TAGS) as (wandb_run, tb_writer):
+        # Fit model
+        wandb_run.watch(model, criterion=criterion, log_freq=100)
+        for (epoch, epoch_metrics) in train(train_dl, val_dl, model, criterion, optimizer, args.epochs, device):
+            log_epoch(epoch, epoch_metrics, wandb_run, tb_writer)
+
+            # Checkpoint model
+            if (epoch % 10 == 0) or (epoch + 1 == args.epochs):  # every 10th epoch or final one (to be safe)
+                checkpoint_model(model, args.model_checkpoints_dir, epoch)

-    wandb.watch(model, criterion=criterion, log_freq=100)
-
-    set_experiment("ml-downscaling-emulator")
-    set_tags({"model": "Simple conv", "purpose": "baseline"})
-    log_param("dataset", args.data_dir)
-    log_param("optimizer", "Adam")
-    log_param("learning_rate", learning_rate)
-    log_param("batch_size", args.batch_size)
-    log_param("architecture", "Simple conv")
-    log_param("device", device)
-    log_param("epochs", args.epochs)
-
-    writer = SummaryWriter()
-
-    # Fit model
-    for epoch in range(args.epochs):
-        # Update model based on training data
-        epoch_train_loss = train_epoch(model, train_dl, device, epoch)
-
-        # Compute validation loss
-        epoch_val_loss = val_epoch(model, val_dl, device, epoch)
-
-        logging.info(f"Epoch {epoch}: Train Loss {epoch_train_loss} Val Loss {epoch_val_loss}")
-        wandb.log({"train/loss": epoch_train_loss, "val/loss": epoch_val_loss})
-        log_metric("train/loss", epoch_train_loss, step=epoch)
-        log_metric("val/loss", epoch_val_loss, step=epoch)
-        writer.add_scalar("train/loss", epoch_train_loss, epoch)
-        writer.add_scalar("val/loss", epoch_val_loss, epoch)
-        # Checkpoint model
-        if (epoch % 10 == 0) or (epoch + 1 == args.epochs):  # every 10th epoch or final one (to be safe)
-            model_checkpoint_path = args.model_checkpoints_dir / f"model-epoch{epoch}.pth"
-            torch.save(model, model_checkpoint_path)
-            logging.info(f"Epoch {epoch}: Saved model to {model_checkpoint_path}")
-
-        # writer.add_hparams(config, {"train/loss": epoch_train_loss, "val/loss": epoch_val_loss})
-        writer.flush()

    logging.info(f"Finished {os.path.basename(__file__)}")
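
Note: `train`, `load_data`, `log_epoch`, `track_run` and `checkpoint_model` are imported from a `training` module that is not part of this diff. As a rough guide only, here is a minimal sketch of what the tracking and checkpoint helpers might look like, reconstructed from the inline wandb/TensorBoard/checkpoint code removed above. The signatures follow the call sites in the diff; the bodies are assumptions, and the real module may differ (for example, it might also re-add the MLflow calls dropped here).

# training.py -- hypothetical sketch, not the module actually added in this commit
from contextlib import contextmanager

import torch
import wandb
from torch.utils.tensorboard import SummaryWriter


@contextmanager
def track_run(experiment_name, run_config, tags):
    """Open a W&B run plus a TensorBoard writer and close both when the block exits."""
    with wandb.init(project=experiment_name, tags=tags, config=run_config) as wandb_run:
        tb_writer = SummaryWriter()
        try:
            yield wandb_run, tb_writer
        finally:
            tb_writer.close()


def log_epoch(epoch, epoch_metrics, wandb_run, tb_writer):
    """Send one epoch's metrics (e.g. {"train/loss": ..., "val/loss": ...}) to both trackers."""
    wandb_run.log(epoch_metrics)
    for name, value in epoch_metrics.items():
        tb_writer.add_scalar(name, value, epoch)
    tb_writer.flush()


def checkpoint_model(model, checkpoints_dir, epoch):
    """Save the whole model object for this epoch, as the removed inline code did."""
    checkpoint_path = checkpoints_dir / f"model-epoch{epoch}.pth"
    torch.save(model, checkpoint_path)
    return checkpoint_path

Under that assumption, the main script only sees `wandb_run` and `tb_writer` through the context manager, so the experiment-tracking backends can change without touching the training loop.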