Skip to content

Commit 0440e4f

Browse files
Prefer covariant types for function parameters
Closes #80
1 parent f151504 commit 0440e4f

File tree

4 files changed

+22
-13
lines changed

4 files changed

+22
-13
lines changed

mcbackend/adapters/pymc.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66
import base64
77
import pickle
8-
from typing import Dict, List, Optional, Sequence, Tuple
8+
from typing import List, Mapping, Optional, Sequence, Tuple, Union
99

1010
import hagelkorn
1111
import numpy
@@ -59,8 +59,8 @@ def __init__(
5959
*,
6060
from_trace: BaseTrace,
6161
length: int,
62-
draws: Dict[str, numpy.ndarray],
63-
stats: Sequence[Dict[str, numpy.ndarray]],
62+
draws: Mapping[str, numpy.ndarray],
63+
stats: Sequence[Mapping[str, numpy.ndarray]],
6464
):
6565
self._length = length
6666
self._draws = draws
@@ -105,7 +105,7 @@ def __init__( # pylint: disable=W0622
105105
self,
106106
backend: Backend,
107107
*,
108-
name: str = None,
108+
name: Optional[str] = None,
109109
model=None,
110110
vars=None,
111111
test_point=None,
@@ -129,7 +129,7 @@ def setup(
129129
self,
130130
draws: int,
131131
chain: int,
132-
sampler_vars: Optional[List[Dict[str, numpy.dtype]]] = None,
132+
sampler_vars: Optional[Sequence[Mapping[str, Union[type, numpy.dtype]]]] = None,
133133
) -> None:
134134
super().setup(draws, chain, sampler_vars)
135135
self.chain = chain

mcbackend/backends/clickhouse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
import time
77
from datetime import datetime, timezone
8-
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
8+
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple
99

1010
import clickhouse_driver
1111
import numpy
@@ -165,7 +165,7 @@ def __init__(
165165
super().__init__(cmeta, rmeta)
166166

167167
def append(
168-
self, draw: Dict[str, numpy.ndarray], stats: Optional[Dict[str, numpy.ndarray]] = None
168+
self, draw: Mapping[str, numpy.ndarray], stats: Optional[Mapping[str, numpy.ndarray]] = None
169169
):
170170
stat = {f"__stat_{sname}": svals for sname, svals in (stats or {}).items()}
171171
params: Dict[str, numpy.ndarray] = {**draw, **stat}

mcbackend/backends/numpy.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
This backend holds draws in memory, managing them via NumPy arrays.
33
"""
44
import math
5-
from typing import Dict, List, Optional, Sequence, Tuple
5+
from typing import Dict, List, Mapping, Optional, Sequence, Tuple
66

77
import numpy
88

@@ -12,8 +12,8 @@
1212

1313
def grow_append(
1414
storage_dict: Dict[str, numpy.ndarray],
15-
values: Dict[str, numpy.ndarray],
16-
rigid: Dict[str, bool],
15+
values: Mapping[str, numpy.ndarray],
16+
rigid: Mapping[str, bool],
1717
draw_idx: int,
1818
):
1919
"""Writes values into storage arrays, growing them if needed."""
@@ -76,7 +76,7 @@ def __init__(self, cmeta: ChainMeta, rmeta: RunMeta, *, preallocate: int) -> Non
7676
super().__init__(cmeta, rmeta)
7777

7878
def append(
79-
self, draw: Dict[str, numpy.ndarray], stats: Optional[Dict[str, numpy.ndarray]] = None
79+
self, draw: Mapping[str, numpy.ndarray], stats: Optional[Mapping[str, numpy.ndarray]] = None
8080
):
8181
grow_append(self._samples, draw, self._var_is_rigid, self._draw_idx)
8282
if stats:

mcbackend/core.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,16 @@
33
"""
44
import collections
55
import logging
6-
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Sized, TypeVar
6+
from typing import (
7+
TYPE_CHECKING,
8+
Dict,
9+
List,
10+
Mapping,
11+
Optional,
12+
Sequence,
13+
Sized,
14+
TypeVar,
15+
)
716

817
import numpy
918

@@ -58,7 +67,7 @@ def __init__(self, cmeta: ChainMeta, rmeta: RunMeta) -> None:
5867
super().__init__()
5968

6069
def append(
61-
self, draw: Dict[str, numpy.ndarray], stats: Optional[Dict[str, numpy.ndarray]] = None
70+
self, draw: Mapping[str, numpy.ndarray], stats: Optional[Mapping[str, numpy.ndarray]] = None
6271
):
6372
"""Appends an iteration to the chain.
6473

0 commit comments

Comments
 (0)