Fix mypy errors attributed to pytorch_lightning.demos.boring_classes #14201


Merged: 27 commits into master from krishnakalyan3:mypy_boring on Aug 26, 2022
Commits (27):
86d7e6a  tomlchange (krishnakalyan3, Aug 15, 2022)
b4429b9  type code (krishnakalyan3, Aug 15, 2022)
119d23a  mypy issues (krishnakalyan3, Aug 15, 2022)
66d5c84  fix with Any (krishnakalyan3, Aug 15, 2022)
9de784d  Apply suggestions from code review (justusschock, Aug 15, 2022)
2b21f87  type annotations (krishnakalyan3, Aug 17, 2022)
b867abe  scheduler import (krishnakalyan3, Aug 17, 2022)
96d4d9b  update types (krishnakalyan3, Aug 17, 2022)
127a3f0  update signature (krishnakalyan3, Aug 17, 2022)
d55fcdd  remove ignores (krishnakalyan3, Aug 17, 2022)
fe3ebdc  minor type (krishnakalyan3, Aug 17, 2022)
0598774  ignore mypy issues (krishnakalyan3, Aug 17, 2022)
c8035b4  fix tensor args (krishnakalyan3, Aug 19, 2022)
bcca207  added couple asserts and comments (Aug 22, 2022)
4baa8bf  merge master (Aug 22, 2022)
41a4329  resolve error by using iterator instead of generator (Aug 22, 2022)
7a6d188  assert won't work, let's cast (Aug 22, 2022)
f1249b5  remove incorrect module (krishnakalyan3, Aug 22, 2022)
efdacb8  Merge branch 'mypy_boring' of github.com:krishnakalyan3/pytorch-light… (Aug 22, 2022)
c6cb3c8  apply suggestions (Aug 22, 2022)
d1d178b  remove spaces from type ignore (Aug 22, 2022)
0b1e7fb  Merge branch 'master' into mypy_boring (otaj, Aug 22, 2022)
aa71cdc  Merge branch 'master' into mypy_boring (otaj, Aug 23, 2022)
1ecb5c7  Merge branch 'master' into mypy_boring (otaj, Aug 24, 2022)
550ffa7  Merge branch 'master' into mypy_boring (otaj, Aug 25, 2022)
96e1474  Merge branch 'master' into mypy_boring (rohitgr7, Aug 25, 2022)
54ed458  Merge branch 'master' into mypy_boring (otaj, Aug 26, 2022)
1 change: 0 additions & 1 deletion pyproject.toml
@@ -51,7 +51,6 @@ warn_no_return = "False"
 module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.core.datamodule",
-    "pytorch_lightning.demos.boring_classes",
     "pytorch_lightning.demos.mnist_datamodule",
     "pytorch_lightning.profilers.base",
     "pytorch_lightning.profilers.pytorch",
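
With pytorch_lightning.demos.boring_classes gone from this per-module exclusion list, mypy now checks the module under the repository's strict settings. As a rough, hypothetical illustration of the class of error this surfaces (not the PR's actual output):

# Hypothetical sketch: under strict mypy (e.g. disallow_untyped_defs),
# unannotated signatures like the old ones in this diff are reported as
# "error: Function is missing a type annotation".
import torch
from torch import Tensor


class UntypedDataset:
    def __getitem__(self, index):  # flagged once the module is type-checked
        return torch.randn(32)


class TypedDataset:
    def __getitem__(self, index: int) -> Tensor:  # passes strict checking
        return torch.randn(32)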
77 changes: 43 additions & 34 deletions src/pytorch_lightning/demos/boring_classes.py
@@ -11,22 +11,27 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import cast, Dict, Iterator, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset

 from pytorch_lightning import LightningDataModule, LightningModule
+from pytorch_lightning.core.optimizer import LightningOptimizer
+from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT


 class RandomDictDataset(Dataset):
     def __init__(self, size: int, length: int):
         self.len = length
         self.data = torch.randn(length, size)

-    def __getitem__(self, index):
+    def __getitem__(self, index: int) -> Dict[str, Tensor]:
         a = self.data[index]
         b = a + 2
         return {"a": a, "b": b}
@@ -40,7 +45,7 @@ def __init__(self, size: int, length: int):
         self.len = length
         self.data = torch.randn(length, size)

-    def __getitem__(self, index):
+    def __getitem__(self, index: int) -> Tensor:
         return self.data[index]

     def __len__(self) -> int:
@@ -52,7 +57,7 @@ def __init__(self, size: int, count: int):
         self.count = count
         self.size = size

-    def __iter__(self):
+    def __iter__(self) -> Iterator[Tensor]:
         for _ in range(self.count):
             yield torch.randn(self.size)

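A note on the Iterator[Tensor] return type above (cf. the "resolve error by using iterator instead of generator" commit): a def containing yield produces a generator, and typing.Generator[Y, S, R] is a subtype of Iterator[Y], so the shorter Iterator form is valid whenever the send/return channels go unused. A minimal sketch:

from typing import Generator, Iterator

import torch
from torch import Tensor


def tensors_short(count: int, size: int) -> Iterator[Tensor]:
    # Conventional annotation when the generator's send/return types are unused.
    for _ in range(count):
        yield torch.randn(size)


def tensors_long(count: int, size: int) -> Generator[Tensor, None, None]:
    # Equivalent for mypy, just more verbose.
    for _ in range(count):
        yield torch.randn(size)
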
@@ -62,16 +67,16 @@ def __init__(self, size: int, count: int):
         self.count = count
         self.size = size

-    def __iter__(self):
+    def __iter__(self) -> Iterator[Tensor]:
         for _ in range(len(self)):
             yield torch.randn(self.size)

-    def __len__(self):
+    def __len__(self) -> int:
         return self.count


 class BoringModel(LightningModule):
-    def __init__(self):
+    def __init__(self) -> None:
         """Testing PL Module.

         Use as follows:
@@ -90,60 +95,63 @@ def training_step(...):
         super().__init__()
         self.layer = torch.nn.Linear(32, 2)

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
         return self.layer(x)

-    def loss(self, batch, preds):
+    def loss(self, batch: Tensor, preds: Tensor) -> Tensor:
         # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
         return torch.nn.functional.mse_loss(preds, torch.ones_like(preds))

-    def step(self, x):
+    def step(self, x: Tensor) -> Tensor:
         x = self(x)
         out = torch.nn.functional.mse_loss(x, torch.ones_like(x))
         return out

-    def training_step(self, batch, batch_idx):
+    def training_step(self, batch: Tensor, batch_idx: int) -> STEP_OUTPUT:  # type: ignore[override]
         output = self(batch)
         loss = self.loss(batch, output)
         return {"loss": loss}

-    def training_step_end(self, training_step_outputs):
+    def training_step_end(self, training_step_outputs: STEP_OUTPUT) -> STEP_OUTPUT:
         return training_step_outputs

-    def training_epoch_end(self, outputs) -> None:
+    def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
+        outputs = cast(List[Dict[str, Tensor]], outputs)
         torch.stack([x["loss"] for x in outputs]).mean()

-    def validation_step(self, batch, batch_idx):
+    def validation_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]:  # type: ignore[override]
         output = self(batch)
         loss = self.loss(batch, output)
         return {"x": loss}

-    def validation_epoch_end(self, outputs) -> None:
+    def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
+        outputs = cast(List[Dict[str, Tensor]], outputs)
         torch.stack([x["x"] for x in outputs]).mean()

-    def test_step(self, batch, batch_idx):
+    def test_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]:  # type: ignore[override]
         output = self(batch)
         loss = self.loss(batch, output)
         return {"y": loss}

-    def test_epoch_end(self, outputs) -> None:
+    def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
+        outputs = cast(List[Dict[str, Tensor]], outputs)
         torch.stack([x["y"] for x in outputs]).mean()

-    def configure_optimizers(self):
+    def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_LRScheduler]]:
         optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
         lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
         return [optimizer], [lr_scheduler]

-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         return DataLoader(RandomDataset(32, 64))

-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         return DataLoader(RandomDataset(32, 64))

-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         return DataLoader(RandomDataset(32, 64))

-    def predict_dataloader(self):
+    def predict_dataloader(self) -> DataLoader:
         return DataLoader(RandomDataset(32, 64))


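The *_epoch_end hooks above receive Lightning's broad EPOCH_OUTPUT union, so the PR narrows the value with typing.cast before indexing (cf. the "assert won't work, let's cast" commit). cast() only informs the type checker and performs no runtime validation. A minimal sketch, with EPOCH_OUTPUT stood in by a hypothetical alias:

from typing import cast, Dict, List, Union

import torch
from torch import Tensor

# Hypothetical stand-in for Lightning's EPOCH_OUTPUT union.
EpochOutput = Union[List[Tensor], List[Dict[str, Tensor]]]


def epoch_mean(outputs: EpochOutput) -> Tensor:
    # cast() is a no-op at runtime; it just tells mypy which union member
    # to assume here so that x["loss"] type-checks.
    typed = cast(List[Dict[str, Tensor]], outputs)
    return torch.stack([x["loss"] for x in typed]).mean()


print(epoch_mean([{"loss": torch.tensor(1.0)}, {"loss": torch.tensor(3.0)}]))  # tensor(2.)
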
@@ -155,7 +163,7 @@ def __init__(self, data_dir: str = "./"):
         self.checkpoint_state: Optional[str] = None
         self.random_full = RandomDataset(32, 64 * 4)

-    def setup(self, stage: Optional[str] = None):
+    def setup(self, stage: Optional[str] = None) -> None:
         if stage == "fit" or stage is None:
             self.random_train = Subset(self.random_full, indices=range(64))

@@ -168,26 +176,27 @@ def setup(self, stage: Optional[str] = None):
         if stage == "predict" or stage is None:
             self.random_predict = Subset(self.random_full, indices=range(64 * 3, 64 * 4))

-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         return DataLoader(self.random_train)

-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         return DataLoader(self.random_val)

-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         return DataLoader(self.random_test)

-    def predict_dataloader(self):
+    def predict_dataloader(self) -> DataLoader:
         return DataLoader(self.random_predict)


 class ManualOptimBoringModel(BoringModel):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.automatic_optimization = False

-    def training_step(self, batch, batch_idx):
+    def training_step(self, batch: Tensor, batch_idx: int) -> STEP_OUTPUT:  # type: ignore[override]
         opt = self.optimizers()
+        assert isinstance(opt, (Optimizer, LightningOptimizer))
         output = self(batch)
         loss = self.loss(batch, output)
         opt.zero_grad()
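
In manual optimization, self.optimizers() is typed as a union (a bare Optimizer, a LightningOptimizer wrapper, or a list of them), so the added assert isinstance(...) narrows the type before zero_grad()/step() are called. A minimal sketch of the narrowing idiom, using hypothetical stand-in classes:

from typing import List, Union


class Opt:
    def zero_grad(self) -> None:
        print("zeroed")


def train_step(opt: Union[Opt, List[Opt]]) -> None:
    # Without the assert, mypy rejects opt.zero_grad(): List[Opt] has no such
    # attribute. The isinstance check narrows the union for the checker and
    # also fails fast at runtime if a list of optimizers shows up.
    assert isinstance(opt, Opt)
    opt.zero_grad()


train_step(Opt())  # prints "zeroed"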
@@ -202,21 +211,21 @@ def __init__(self, out_dim: int = 10, learning_rate: float = 0.02):
         self.l1 = torch.nn.Linear(32, out_dim)
         self.learning_rate = learning_rate

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
         return torch.relu(self.l1(x.view(x.size(0), -1)))

-    def training_step(self, batch, batch_nb):
+    def training_step(self, batch: Tensor, batch_nb: int) -> STEP_OUTPUT:  # type: ignore[override]
         x = batch
         x = self(x)
         loss = x.sum()
         return loss

-    def configure_optimizers(self):
+    def configure_optimizers(self) -> torch.optim.Optimizer:
         return torch.optim.Adam(self.parameters(), lr=self.learning_rate)


 class Net(nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.conv1 = nn.Conv2d(1, 32, 3, 1)
         self.conv2 = nn.Conv2d(32, 64, 3, 1)
@@ -225,7 +234,7 @@ def __init__(self):
         self.fc1 = nn.Linear(9216, 128)
         self.fc2 = nn.Linear(128, 10)

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.conv1(x)
         x = F.relu(x)
         x = self.conv2(x)
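
Why do the LightningModule hooks in this diff carry "# type: ignore[override]" while Net.forward does not? LightningModule declares its hooks with deliberately broad signatures (e.g. *args: Any, **kwargs: Any), and narrowing a parameter type in a subclass violates Liskov substitutability, which mypy reports as an [override] error; plain nn.Module imposes no comparable constraint on forward here. A hypothetical sketch of the same failure mode:

from torch import Tensor


class Base:
    def forward(self, x: object) -> object:  # broad base-class signature
        return x


class Child(Base):
    # Narrowing the parameter from object to Tensor breaks substitutability:
    # callers holding a Base may pass any object. mypy reports roughly
    #   error: Argument 1 of "forward" is incompatible with supertype "Base"  [override]
    def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
        return x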