
Commit 8a6239f

randolf-scholz authored and Skylion007 committed
[typing] Add type hints to __init__ methods in torch.distributions. (pytorch#144197)
Fixes pytorch#144196. Extends pytorch#144106 and pytorch#144110.

## Open Problems:

- [ ] Annotating with `numbers.Number` is a bad idea; consider using `float`, `SupportsFloat`, or some `Protocol` instead. pytorch#144197 (comment)

## Notes

- `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped.
- `categorical.py`: converted `else` branches of mutually exclusive arguments to `if` branches[^2].
- ~~`dirichlet.py`: replaced `axis` with `dim` arguments.~~ pytorch#144402
- `geometric.py`: converted `else` branches of mutually exclusive arguments to `if` branches[^2].
- ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** turns out the bug is related to the typing of `torch.Size`. pytorch#144218
- `independent.py`: made `Independent` a generic class of its base distribution.
- `multivariate_normal.py`: converted `else` branches of mutually exclusive arguments to `if` branches[^2].
- `relaxed_bernoulli.py`: added class-level type hint for `base_dist`.
- `relaxed_categorical.py`: added class-level type hint for `base_dist`.
- ~~`transforms.py`: added missing argument to docstring of `ReshapeTransform`.~~ pytorch#144401
- ~~`transforms.py`: fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ pytorch#144400
- `transforms.py`: added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`.
- `transforms.py`: added `type: ignore` comments to `AffineTransform.__eq__`[^1].
- `transforms.py`: fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, while `Distribution.domain` is defined as `Optional[Constraint]`.
- skipped: `constraints.py`, `constraint_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`.

## Remark

`TransformedDistribution.__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`.

```python
import torch
from torch.distributions import *

b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1))
t = StickBreakingTransform()
d1 = TransformedDistribution(b1, t)
d2 = TransformedDistribution(b2, t)
print(d1.base_dist)  # Independent with 1 dimension
print(d2.base_dist)  # MultivariateNormal
```

One could consider changing this to `if reinterpreted_batch_ndims > 1:`.

[^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see <python/mypy#3186>). This results in us having to add type-ignore comments in several places.

[^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped.

Pull Request resolved: pytorch#144197
Approved by: https://github.com/malfet

Co-authored-by: Aaron Gokaslan <[email protected]>
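To illustrate footnote [^2]: `mypy` narrows types per variable, so the combined check `(probs is None) == (logits is None)` tells it nothing about `logits` inside the `else` branch. A minimal standalone sketch (not part of the commit) of the failure mode and the `assert` workaround used throughout the diffs below:

```python
from typing import Optional

from torch import Tensor


def init(probs: Optional[Tensor] = None, logits: Optional[Tensor] = None) -> None:
    if (probs is None) == (logits is None):
        raise ValueError("Either `probs` or `logits` must be specified, but not both.")
    if probs is not None:
        probs.size()  # OK: `probs` is narrowed to `Tensor`
    else:
        # Without the assert, mypy reports (roughly):
        #   Item "None" of "Optional[Tensor]" has no attribute "size"
        # because it cannot connect the mutual-exclusion check to `logits`.
        assert logits is not None  # helps mypy
        logits.size()  # OK: `logits` is now narrowed to `Tensor`
```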
1 parent 1b77f73 commit 8a6239f

Some content is hidden: large commits have some content hidden by default.

42 files changed: +382 additions, −84 deletions

torch/distributions/bernoulli.py

Lines changed: 10 additions & 2 deletions
```diff
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 import torch
 from torch import nan, Tensor
 from torch.distributions import constraints
@@ -10,7 +12,7 @@
     probs_to_logits,
 )
 from torch.nn.functional import binary_cross_entropy_with_logits
-from torch.types import _Number
+from torch.types import _Number, Number
 
 
 __all__ = ["Bernoulli"]
@@ -41,7 +43,12 @@ class Bernoulli(ExponentialFamily):
     has_enumerate_support = True
     _mean_carrier_measure = 0
 
-    def __init__(self, probs=None, logits=None, validate_args=None):
+    def __init__(
+        self,
+        probs: Optional[Union[Tensor, Number]] = None,
+        logits: Optional[Union[Tensor, Number]] = None,
+        validate_args: Optional[bool] = None,
+    ) -> None:
         if (probs is None) == (logits is None):
             raise ValueError(
                 "Either `probs` or `logits` must be specified, but not both."
@@ -50,6 +57,7 @@ def __init__(self, probs=None, logits=None, validate_args=None):
             is_scalar = isinstance(probs, _Number)
             (self.probs,) = broadcast_all(probs)
         else:
+            assert logits is not None  # helps mypy
             is_scalar = isinstance(logits, _Number)
             (self.logits,) = broadcast_all(logits)
         self._param = self.probs if probs is not None else self.logits
```
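For context, `Number` here is `torch.types.Number`, which (at the time of writing) is an alias for `Union[int, float, bool]`, so the annotated signature covers both the plain-scalar and tensor parameterizations. A quick usage sketch, not part of the diff:

```python
import torch
from torch.distributions import Bernoulli

b1 = Bernoulli(probs=0.3)                       # scalar: matched by `Number`
b2 = Bernoulli(probs=torch.tensor([0.3, 0.7]))  # tensor: matched by `Tensor`
print(b1.sample(), b2.sample())
```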

torch/distributions/beta.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 import torch
 from torch import Tensor
 from torch.distributions import constraints
@@ -36,7 +38,12 @@ class Beta(ExponentialFamily):
     support = constraints.unit_interval
     has_rsample = True
 
-    def __init__(self, concentration1, concentration0, validate_args=None):
+    def __init__(
+        self,
+        concentration1: Union[Tensor, float],
+        concentration0: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         if isinstance(concentration1, _Number) and isinstance(concentration0, _Number):
             concentration1_concentration0 = torch.tensor(
                 [float(concentration1), float(concentration0)]
```

torch/distributions/binomial.py

Lines changed: 10 additions & 1 deletion
```diff
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 import torch
 from torch import Tensor
 from torch.distributions import constraints
@@ -50,7 +52,13 @@ class Binomial(Distribution):
     }
     has_enumerate_support = True
 
-    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
+    def __init__(
+        self,
+        total_count: Union[Tensor, int] = 1,
+        probs: Optional[Tensor] = None,
+        logits: Optional[Tensor] = None,
+        validate_args: Optional[bool] = None,
+    ) -> None:
         if (probs is None) == (logits is None):
             raise ValueError(
                 "Either `probs` or `logits` must be specified, but not both."
@@ -62,6 +70,7 @@ def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
             ) = broadcast_all(total_count, probs)
             self.total_count = self.total_count.type_as(self.probs)
         else:
+            assert logits is not None  # helps mypy
             (
                 self.total_count,
                 self.logits,
```

torch/distributions/categorical.py

Lines changed: 9 additions & 1 deletion
```diff
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional
+
 import torch
 from torch import nan, Tensor
 from torch.distributions import constraints
@@ -51,7 +53,12 @@ class Categorical(Distribution):
     arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
     has_enumerate_support = True
 
-    def __init__(self, probs=None, logits=None, validate_args=None):
+    def __init__(
+        self,
+        probs: Optional[Tensor] = None,
+        logits: Optional[Tensor] = None,
+        validate_args: Optional[bool] = None,
+    ) -> None:
         if (probs is None) == (logits is None):
             raise ValueError(
                 "Either `probs` or `logits` must be specified, but not both."
@@ -61,6 +68,7 @@ def __init__(self, probs=None, logits=None, validate_args=None):
                 raise ValueError("`probs` parameter must be at least one-dimensional.")
             self.probs = probs / probs.sum(-1, keepdim=True)
         else:
+            assert logits is not None  # helps mypy
             if logits.dim() < 1:
                 raise ValueError("`logits` parameter must be at least one-dimensional.")
             # Normalize
```

torch/distributions/cauchy.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import math
+from typing import Optional, Union
 
 import torch
 from torch import inf, nan, Tensor
@@ -34,7 +35,12 @@ class Cauchy(Distribution):
     support = constraints.real
     has_rsample = True
 
-    def __init__(self, loc, scale, validate_args=None):
+    def __init__(
+        self,
+        loc: Union[Tensor, float],
+        scale: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         self.loc, self.scale = broadcast_all(loc, scale)
         if isinstance(loc, _Number) and isinstance(scale, _Number):
             batch_shape = torch.Size()
```

torch/distributions/chi2.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 from torch import Tensor
 from torch.distributions import constraints
 from torch.distributions.gamma import Gamma
@@ -25,7 +27,11 @@ class Chi2(Gamma):
 
     arg_constraints = {"df": constraints.positive}
 
-    def __init__(self, df, validate_args=None):
+    def __init__(
+        self,
+        df: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         super().__init__(0.5 * df, 0.5, validate_args=validate_args)
 
     def expand(self, batch_shape, _instance=None):
```

torch/distributions/continuous_bernoulli.py

Lines changed: 8 additions & 2 deletions
```diff
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import math
+from typing import Optional, Union
 
 import torch
 from torch import Tensor
@@ -13,7 +14,7 @@
     probs_to_logits,
 )
 from torch.nn.functional import binary_cross_entropy_with_logits
-from torch.types import _Number, _size
+from torch.types import _Number, _size, Number
 
 
 __all__ = ["ContinuousBernoulli"]
@@ -52,7 +53,11 @@ class ContinuousBernoulli(ExponentialFamily):
     has_rsample = True
 
     def __init__(
-        self, probs=None, logits=None, lims=(0.499, 0.501), validate_args=None
+        self,
+        probs: Optional[Union[Tensor, Number]] = None,
+        logits: Optional[Union[Tensor, Number]] = None,
+        lims: tuple[float, float] = (0.499, 0.501),
+        validate_args: Optional[bool] = None,
     ) -> None:
         if (probs is None) == (logits is None):
             raise ValueError(
@@ -68,6 +73,7 @@ def __init__(
                     raise ValueError("The parameter probs has invalid values")
             self.probs = clamp_probs(self.probs)
         else:
+            assert logits is not None  # helps mypy
             is_scalar = isinstance(logits, _Number)
             (self.logits,) = broadcast_all(logits)
         self._param = self.probs if probs is not None else self.logits
```

torch/distributions/dirichlet.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional
+
 import torch
 from torch import Tensor
 from torch.autograd import Function
@@ -54,7 +56,11 @@ class Dirichlet(ExponentialFamily):
     support = constraints.simplex
     has_rsample = True
 
-    def __init__(self, concentration, validate_args=None):
+    def __init__(
+        self,
+        concentration: Tensor,
+        validate_args: Optional[bool] = None,
+    ) -> None:
         if concentration.dim() < 1:
             raise ValueError(
                 "`concentration` parameter must be at least one-dimensional."
```

torch/distributions/distribution.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -44,7 +44,7 @@ def __init__(
         batch_shape: torch.Size = torch.Size(),
         event_shape: torch.Size = torch.Size(),
         validate_args: Optional[bool] = None,
-    ):
+    ) -> None:
         self._batch_shape = batch_shape
         self._event_shape = event_shape
         if validate_args is not None:
```
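The `-> None` here is not merely cosmetic: under `mypy`'s default settings, the bodies of fully unannotated functions are not type-checked, and annotating the signature (a return type suffices) opts the body back in. A small illustration, assuming default `mypy` settings:

```python
# mypy: allow-untyped-defs
from typing import Optional


def unchecked(flag=None):
    flag.upper()  # body is skipped: the function has no annotations


def checked(flag: Optional[str] = None) -> None:
    flag.upper()  # checked: mypy flags that `flag` may be `None`
```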

torch/distributions/exponential.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 import torch
 from torch import Tensor
 from torch.distributions import constraints
@@ -46,7 +48,11 @@ def stddev(self) -> Tensor:
     def variance(self) -> Tensor:
         return self.rate.pow(-2)
 
-    def __init__(self, rate, validate_args=None):
+    def __init__(
+        self,
+        rate: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         (self.rate,) = broadcast_all(rate)
         batch_shape = torch.Size() if isinstance(rate, _Number) else self.rate.size()
         super().__init__(batch_shape, validate_args=validate_args)
```

torch/distributions/fishersnedecor.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 import torch
 from torch import nan, Tensor
 from torch.distributions import constraints
@@ -31,7 +33,12 @@ class FisherSnedecor(Distribution):
     support = constraints.positive
     has_rsample = True
 
-    def __init__(self, df1, df2, validate_args=None):
+    def __init__(
+        self,
+        df1: Union[Tensor, float],
+        df2: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         self.df1, self.df2 = broadcast_all(df1, df2)
         self._gamma1 = Gamma(self.df1 * 0.5, self.df1)
         self._gamma2 = Gamma(self.df2 * 0.5, self.df2)
```

torch/distributions/gamma.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 import torch
 from torch import Tensor
 from torch.distributions import constraints
@@ -52,7 +54,12 @@ def mode(self) -> Tensor:
     def variance(self) -> Tensor:
         return self.concentration / self.rate.pow(2)
 
-    def __init__(self, concentration, rate, validate_args=None):
+    def __init__(
+        self,
+        concentration: Union[Tensor, float],
+        rate: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         self.concentration, self.rate = broadcast_all(concentration, rate)
         if isinstance(concentration, _Number) and isinstance(rate, _Number):
             batch_shape = torch.Size()
```

torch/distributions/geometric.py

Lines changed: 11 additions & 2 deletions
```diff
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Optional, Union
+
 import torch
 from torch import Tensor
 from torch.distributions import constraints
@@ -10,7 +12,7 @@
     probs_to_logits,
 )
 from torch.nn.functional import binary_cross_entropy_with_logits
-from torch.types import _Number
+from torch.types import _Number, Number
 
 
 __all__ = ["Geometric"]
@@ -45,19 +47,26 @@ class Geometric(Distribution):
     arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
     support = constraints.nonnegative_integer
 
-    def __init__(self, probs=None, logits=None, validate_args=None):
+    def __init__(
+        self,
+        probs: Optional[Union[Tensor, Number]] = None,
+        logits: Optional[Union[Tensor, Number]] = None,
+        validate_args: Optional[bool] = None,
+    ) -> None:
         if (probs is None) == (logits is None):
             raise ValueError(
                 "Either `probs` or `logits` must be specified, but not both."
             )
         if probs is not None:
             (self.probs,) = broadcast_all(probs)
         else:
+            assert logits is not None  # helps mypy
             (self.logits,) = broadcast_all(logits)
         probs_or_logits = probs if probs is not None else logits
         if isinstance(probs_or_logits, _Number):
             batch_shape = torch.Size()
         else:
+            assert probs_or_logits is not None  # helps mypy
             batch_shape = probs_or_logits.size()
         super().__init__(batch_shape, validate_args=validate_args)
         if self._validate_args and probs is not None:
```
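The second assert covers a subtler gap: in the `else` branch of `isinstance(probs_or_logits, _Number)`, `mypy` only rules out the scalar types, so from its perspective `probs_or_logits` may still be `None` and calling `.size()` would be flagged. A sketch of the same pattern (not part of the diff):

```python
from typing import Optional, Union

import torch
from torch import Tensor


def batch_shape_of(x: Optional[Union[Tensor, float]]) -> torch.Size:
    if isinstance(x, float):
        return torch.Size()  # scalar parameterization: empty batch shape
    # Here `x` is narrowed only to `Optional[Tensor]`; `None` is still possible,
    # so `x.size()` alone would be an error.
    assert x is not None  # helps mypy
    return x.size()
```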

torch/distributions/gumbel.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import math
+from typing import Optional, Union
 
 import torch
 from torch import Tensor
@@ -33,7 +34,12 @@ class Gumbel(TransformedDistribution):
     arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
     support = constraints.real
 
-    def __init__(self, loc, scale, validate_args=None):
+    def __init__(
+        self,
+        loc: Union[Tensor, float],
+        scale: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         self.loc, self.scale = broadcast_all(loc, scale)
         finfo = torch.finfo(self.loc.dtype)
         if isinstance(loc, _Number) and isinstance(scale, _Number):
```

torch/distributions/half_cauchy.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import math
+from typing import Optional, Union
 
 import torch
 from torch import inf, Tensor
@@ -33,8 +34,13 @@ class HalfCauchy(TransformedDistribution):
     arg_constraints = {"scale": constraints.positive}
     support = constraints.nonnegative
     has_rsample = True
+    base_dist: Cauchy
 
-    def __init__(self, scale, validate_args=None):
+    def __init__(
+        self,
+        scale: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         base_dist = Cauchy(0, scale, validate_args=False)
         super().__init__(base_dist, AbsTransform(), validate_args=validate_args)
 
```

torch/distributions/half_normal.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import math
+from typing import Optional, Union
 
 import torch
 from torch import inf, Tensor
@@ -33,8 +34,13 @@ class HalfNormal(TransformedDistribution):
     arg_constraints = {"scale": constraints.positive}
     support = constraints.nonnegative
     has_rsample = True
+    base_dist: Normal
 
-    def __init__(self, scale, validate_args=None):
+    def __init__(
+        self,
+        scale: Union[Tensor, float],
+        validate_args: Optional[bool] = None,
+    ) -> None:
         base_dist = Normal(0, scale, validate_args=False)
         super().__init__(base_dist, AbsTransform(), validate_args=validate_args)
 
```
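The class-level `base_dist` hints added here and in `half_cauchy.py` (and, per the notes above, in `relaxed_bernoulli.py` and `relaxed_categorical.py`) narrow the attribute that `TransformedDistribution` otherwise exposes as a plain `Distribution`, so subclass-specific attributes type-check. A hedged sketch of what this enables:

```python
import torch
from torch.distributions import HalfNormal

hn = HalfNormal(torch.tensor(1.0))
# With `base_dist: Normal` declared on the class, static checkers know the
# base distribution is a `Normal`, so `Normal`-only attributes resolve:
print(hn.base_dist.scale)
```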
