Skip to content

Commit b4b09c5

Browse files
committed
Introduce retrieve_nodes() and retrieve_edges()
Generic parameters and defaults don't mix well (see python/mypy#3737) so we need to use overloads instead, or InstancesResult is always InstancesResult[Any, Any]. Overloads don't work on retrieve() due to defaults and ordering of params, so we create two new methods that can be safely typed.
1 parent daf4403 commit b4b09c5

File tree

2 files changed

+94
-23
lines changed

2 files changed

+94
-23
lines changed

cognite/client/_api/data_modeling/instances.py

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -271,27 +271,81 @@ def __iter__(self) -> Iterator[Node]:
271271
"""
272272
return self(None, "node")
273273

274+
@overload
275+
def retrieve_edges(
276+
self,
277+
edges: EdgeId | Sequence[EdgeId] | tuple[str, str] | Sequence[tuple[str, str]],
278+
*,
279+
edge_cls: type[T_Edge],
280+
) -> EdgeList[T_Edge]: ...
281+
282+
@overload
283+
def retrieve_edges(
284+
self,
285+
edges: EdgeId | Sequence[EdgeId] | tuple[str, str] | Sequence[tuple[str, str]],
286+
*,
287+
sources: Source | Sequence[Source] | None = None,
288+
include_typing: bool = False,
289+
) -> EdgeList[Edge]: ...
290+
291+
def retrieve_edges(
292+
self,
293+
edges: EdgeId | Sequence[EdgeId] | tuple[str, str] | Sequence[tuple[str, str]],
294+
edge_cls: type[T_Edge] = Edge, # type: ignore
295+
sources: Source | Sequence[Source] | None = None,
296+
include_typing: bool = False,
297+
) -> EdgeList[T_Edge]:
298+
res = self._retrieve_typed(
299+
nodes=None, edges=edges, node_cls=Node, edge_cls=edge_cls, sources=sources, include_typing=include_typing
300+
)
301+
return res.edges
302+
303+
@overload
304+
def retrieve_nodes(
305+
self,
306+
nodes: NodeId | Sequence[NodeId] | tuple[str, str] | Sequence[tuple[str, str]],
307+
*,
308+
node_cls: type[T_Node],
309+
) -> NodeList[T_Node]: ...
310+
311+
@overload
312+
def retrieve_nodes(
313+
self,
314+
nodes: NodeId | Sequence[NodeId] | tuple[str, str] | Sequence[tuple[str, str]],
315+
*,
316+
sources: Source | Sequence[Source] | None = None,
317+
include_typing: bool = False,
318+
) -> NodeList[Node]: ...
319+
320+
def retrieve_nodes(
321+
self,
322+
nodes: NodeId | Sequence[NodeId] | tuple[str, str] | Sequence[tuple[str, str]],
323+
node_cls: type[T_Node] = Node, # type: ignore
324+
sources: Source | Sequence[Source] | None = None,
325+
include_typing: bool = False,
326+
) -> NodeList[T_Node]:
327+
res = self._retrieve_typed(
328+
nodes=nodes, edges=None, node_cls=node_cls, edge_cls=Edge, sources=sources, include_typing=include_typing
329+
)
330+
return res.nodes
331+
274332
def retrieve(
275333
self,
276334
nodes: NodeId | Sequence[NodeId] | tuple[str, str] | Sequence[tuple[str, str]] | None = None,
277335
edges: EdgeId | Sequence[EdgeId] | tuple[str, str] | Sequence[tuple[str, str]] | None = None,
278336
sources: Source | Sequence[Source] | None = None,
279337
include_typing: bool = False,
280-
node_cls: type[T_Node] = Node, # type: ignore[assignment]
281-
edge_cls: type[T_Edge] = Edge, # type: ignore[assignment]
282-
) -> InstancesResult[T_Node, T_Edge]:
338+
) -> InstancesResult[Node, Edge]:
283339
"""`Retrieve one or more instance by id(s). <https://developer.cognite.com/api#tag/Instances/operation/byExternalIdsInstances>`_
284340
285341
Args:
286342
nodes (NodeId | Sequence[NodeId] | tuple[str, str] | Sequence[tuple[str, str]] | None): Node ids
287343
edges (EdgeId | Sequence[EdgeId] | tuple[str, str] | Sequence[tuple[str, str]] | None): Edge ids
288344
sources (Source | Sequence[Source] | None): Retrieve properties from the listed - by reference - views.
289345
include_typing (bool): Whether to return property type information as part of the result.
290-
node_cls (type[T_Node]): Node class to use when returning nodes.
291-
edge_cls (type[T_Edge]): Edge class to use when returning edges.
292346
293347
Returns:
294-
InstancesResult[T_Node, T_Edge]: Requested instances.
348+
InstancesResult[Node, Edge]: Requested instances.
295349
296350
Examples:
297351
@@ -324,6 +378,19 @@ def retrieve(
324378
... EdgeId("mySpace", "myEdge"),
325379
... sources=("myspace", "myView"))
326380
"""
381+
return self._retrieve_typed(
382+
nodes=nodes, edges=edges, sources=sources, include_typing=include_typing, node_cls=Node, edge_cls=Edge
383+
)
384+
385+
def _retrieve_typed(
386+
self,
387+
nodes: NodeId | Sequence[NodeId] | tuple[str, str] | Sequence[tuple[str, str]] | None,
388+
edges: EdgeId | Sequence[EdgeId] | tuple[str, str] | Sequence[tuple[str, str]] | None,
389+
sources: Source | Sequence[Source] | None,
390+
include_typing: bool,
391+
node_cls: type[T_Node],
392+
edge_cls: type[T_Edge],
393+
) -> InstancesResult[T_Node, T_Edge]:
327394
identifiers = self._load_node_and_edge_ids(nodes, edges)
328395

329396
sources = self._to_sources(sources, node_cls, edge_cls)

tests/tests_integration/test_api/test_data_modeling/test_instances.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,7 @@ def test_apply_retrieve_and_delete(self, cognite_client: CogniteClient, person_v
494494

495495
try:
496496
created = cognite_client.data_modeling.instances.apply(new_node, replace=True)
497-
retrieved: InstancesResult[Node, Edge] = cognite_client.data_modeling.instances.retrieve(new_node.as_id())
497+
retrieved = cognite_client.data_modeling.instances.retrieve(new_node.as_id())
498498

499499
assert len(created.nodes) == 1
500500
assert created.nodes[0].created_time
@@ -503,9 +503,7 @@ def test_apply_retrieve_and_delete(self, cognite_client: CogniteClient, person_v
503503
assert retrieved.nodes[0].as_id() == new_node.as_id()
504504

505505
deleted_result = cognite_client.data_modeling.instances.delete(new_node.as_id())
506-
retrieved_deleted: InstancesResult[Node, Edge] = cognite_client.data_modeling.instances.retrieve(
507-
new_node.as_id()
508-
)
506+
retrieved_deleted = cognite_client.data_modeling.instances.retrieve(new_node.as_id())
509507

510508
assert len(deleted_result.nodes) == 1
511509
assert deleted_result.nodes[0] == new_node.as_id()
@@ -581,7 +579,7 @@ def test_apply_auto_create_nodes(self, cognite_client: CogniteClient, person_vie
581579
created_edges = cognite_client.data_modeling.instances.apply(
582580
edges=person_to_actor, auto_create_start_nodes=True, auto_create_end_nodes=True, replace=True
583581
)
584-
created_nodes: InstancesResult[Node, Edge] = cognite_client.data_modeling.instances.retrieve(node_pair)
582+
created_nodes = cognite_client.data_modeling.instances.retrieve(node_pair)
585583

586584
assert len(created_edges.edges) == 1
587585
assert created_edges.edges[0].created_time
@@ -600,13 +598,13 @@ def test_delete_non_existent(self, cognite_client: CogniteClient, integration_te
600598
assert res.edges == []
601599

602600
def test_retrieve_multiple(self, cognite_client: CogniteClient, movie_nodes: NodeList) -> None:
603-
retrieved: InstancesResult[Node, Edge] = cognite_client.data_modeling.instances.retrieve(movie_nodes.as_ids())
601+
retrieved = cognite_client.data_modeling.instances.retrieve(movie_nodes.as_ids())
604602
assert len(retrieved.nodes) == len(movie_nodes)
605603

606604
def test_retrieve_nodes_and_edges_using_id_tuples(
607605
self, cognite_client: CogniteClient, movie_nodes: NodeList, movie_edges: EdgeList
608606
) -> None:
609-
retrieved: InstancesResult[Node, Edge] = cognite_client.data_modeling.instances.retrieve(
607+
retrieved = cognite_client.data_modeling.instances.retrieve(
610608
nodes=[(id.space, id.external_id) for id in movie_nodes.as_ids()],
611609
edges=[(id.space, id.external_id) for id in movie_edges.as_ids()],
612610
)
@@ -616,17 +614,25 @@ def test_retrieve_nodes_and_edges_using_id_tuples(
616614
def test_retrieve_nodes_and_edges(
617615
self, cognite_client: CogniteClient, movie_nodes: NodeList, movie_edges: EdgeList
618616
) -> None:
619-
retrieved: InstancesResult[Node, Edge] = cognite_client.data_modeling.instances.retrieve(
617+
retrieved = cognite_client.data_modeling.instances.retrieve(
620618
nodes=movie_nodes.as_ids(), edges=movie_edges.as_ids()
621619
)
622620
assert set(retrieved.nodes.as_ids()) == set(movie_nodes.as_ids())
623621
assert set(retrieved.edges.as_ids()) == set(movie_edges.as_ids())
624622

623+
def test_retrieve_nodes(self, cognite_client: CogniteClient, movie_nodes: NodeList) -> None:
624+
retrieved = cognite_client.data_modeling.instances.retrieve_nodes(movie_nodes.as_ids())
625+
assert set(retrieved.as_ids()) == set(movie_nodes.as_ids())
626+
627+
def test_retrieve_edges(self, cognite_client: CogniteClient, movie_edges: EdgeList) -> None:
628+
retrieved = cognite_client.data_modeling.instances.retrieve_edges(movie_edges.as_ids())
629+
assert set(retrieved.as_ids()) == set(movie_edges.as_ids())
630+
625631
def test_retrieve_multiple_with_missing(self, cognite_client: CogniteClient, movie_nodes: NodeList) -> None:
626632
ids_without_missing = movie_nodes.as_ids()
627633
ids_with_missing = [*ids_without_missing, NodeId("myNonExistingSpace", "myImaginaryContainer")]
628634

629-
retrieved: InstancesResult[Node, Edge] = cognite_client.data_modeling.instances.retrieve(ids_with_missing)
635+
retrieved = cognite_client.data_modeling.instances.retrieve(ids_with_missing)
630636
assert retrieved.nodes.as_ids() == ids_without_missing
631637

632638
def test_retrieve_non_existent(self, cognite_client: CogniteClient) -> None:
@@ -877,9 +883,7 @@ def test_retrieve_in_units(
877883
node = node_with_1_1_pressure_in_bar
878884
source = SourceSelector(unit_view.as_id(), target_units=[TargetUnit("pressure", UnitReference("pressure:pa"))])
879885

880-
retrieved: InstancesResult[Node, Edge] = cognite_client.data_modeling.instances.retrieve(
881-
node.as_id(), sources=[source]
882-
)
886+
retrieved = cognite_client.data_modeling.instances.retrieve(node.as_id(), sources=[source])
883887
assert retrieved.nodes
884888
assert math.isclose(cast(float, retrieved.nodes[0]["pressure"]), 1.1 * 1e5)
885889

@@ -965,9 +969,9 @@ def test_write_typed_node(self, cognite_client: CogniteClient, integration_test_
965969
assert len(created.nodes) == 1
966970
assert created.nodes[0].external_id == external_id
967971

968-
retrieved = cognite_client.data_modeling.instances.retrieve(
972+
retrieved = cognite_client.data_modeling.instances.retrieve_nodes(
969973
primitive.as_id(), node_cls=PrimitiveNullableRead
970-
).nodes
974+
)
971975
assert len(retrieved) == 1
972976
assert isinstance(retrieved[0], PrimitiveNullableRead)
973977
assert retrieved[0].text == "text"
@@ -1000,9 +1004,9 @@ def test_write_typed_node_listed_properties(
10001004
assert len(created.nodes) == 1
10011005
assert created.nodes[0].external_id == external_id
10021006

1003-
retrieved = cognite_client.data_modeling.instances.retrieve(
1007+
retrieved = cognite_client.data_modeling.instances.retrieve_nodes(
10041008
primitive_listed.as_id(), node_cls=PrimitiveListedRead
1005-
).nodes
1009+
)
10061010
assert len(retrieved) == 1
10071011
assert isinstance(retrieved[0], PrimitiveListedRead)
10081012
assert retrieved[0].text == ["text"]
@@ -1022,7 +1026,7 @@ def test_write_type_node_instance_property_descriptor(
10221026
assert len(created.nodes) == 1
10231027
assert created.nodes[0].external_id == external_id
10241028

1025-
retrieved = cognite_client.data_modeling.instances.retrieve(person.as_id(), node_cls=PersonRead).nodes
1029+
retrieved = cognite_client.data_modeling.instances.retrieve_nodes(person.as_id(), node_cls=PersonRead)
10261030
assert len(retrieved) == 1
10271031
assert isinstance(retrieved[0], PersonRead)
10281032
assert retrieved[0].name == "John Doe"

0 commit comments

Comments
 (0)