Skip to content

Commit a4ae812

Browse files
Internal change
PiperOrigin-RevId: 471682920
1 parent 56453cd commit a4ae812

File tree

1 file changed

+6
-2
lines changed
  • official/legacy/image_classification/resnet

1 file changed

+6
-2
lines changed

official/legacy/image_classification/resnet/common.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,14 @@ def get_config(self):
106106
}
107107

108108

109-
def get_optimizer(learning_rate=0.1):
109+
def get_optimizer(learning_rate=0.1, use_legacy_optimizer=True):
110110
"""Returns optimizer to use."""
111111
# The learning_rate is overwritten at the beginning of each step by callback.
112-
return tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
112+
if use_legacy_optimizer:
113+
return tf.keras.optimizers.legacy.SGD(
114+
learning_rate=learning_rate, momentum=0.9)
115+
else:
116+
return tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
113117

114118

115119
def get_callbacks(pruning_method=None,

0 commit comments

Comments
 (0)