Skip to content

Commit 9027b4d

Browse files
authored
fix: type checking (#993)
Squashed commits: fix: type checking; update license; format; format; update catalog; revert type annotation; format; format; update
1 parent a80a788 commit 9027b4d

File tree

8 files changed

+68
-24
lines changed

8 files changed

+68
-24
lines changed

python/datafusion/catalog.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,12 @@ def __init__(self, table: df_internal.Table) -> None:
6666
"""This constructor is not typically called by the end user."""
6767
self.table = table
6868

69+
@property
6970
def schema(self) -> pyarrow.Schema:
7071
"""Returns the schema associated with this table."""
71-
return self.table.schema()
72+
return self.table.schema
7273

7374
@property
7475
def kind(self) -> str:
7576
"""Returns the kind of table."""
76-
return self.table.kind()
77+
return self.table.kind

python/datafusion/context.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,7 @@ def register_table(self, name: str, table: Table) -> None:
728728
name: Name of the resultant table.
729729
table: DataFusion table to add to the session context.
730730
"""
731-
self.ctx.register_table(name, table)
731+
self.ctx.register_table(name, table.table)
732732

733733
def deregister_table(self, name: str) -> None:
734734
"""Remove a table from the session."""
@@ -767,7 +767,7 @@ def register_parquet(
767767
file_extension: str = ".parquet",
768768
skip_metadata: bool = True,
769769
schema: pyarrow.Schema | None = None,
770-
file_sort_order: list[list[Expr]] | None = None,
770+
file_sort_order: list[list[SortExpr]] | None = None,
771771
) -> None:
772772
"""Register a Parquet file as a table.
773773
@@ -798,7 +798,9 @@ def register_parquet(
798798
file_extension,
799799
skip_metadata,
800800
schema,
801-
file_sort_order,
801+
[sort_list_to_raw_sort_list(exprs) for exprs in file_sort_order]
802+
if file_sort_order is not None
803+
else None,
802804
)
803805

804806
def register_csv(
@@ -934,7 +936,7 @@ def register_udwf(self, udwf: WindowUDF) -> None:
934936

935937
def catalog(self, name: str = "datafusion") -> Catalog:
936938
"""Retrieve a catalog by name."""
937-
return self.ctx.catalog(name)
939+
return Catalog(self.ctx.catalog(name))
938940

939941
@deprecated(
940942
"Use the catalog provider interface ``SessionContext.Catalog`` to "
@@ -1054,7 +1056,7 @@ def read_parquet(
10541056
file_extension: str = ".parquet",
10551057
skip_metadata: bool = True,
10561058
schema: pyarrow.Schema | None = None,
1057-
file_sort_order: list[list[Expr]] | None = None,
1059+
file_sort_order: list[list[Expr | SortExpr]] | None = None,
10581060
) -> DataFrame:
10591061
"""Read a Parquet source into a :py:class:`~datafusion.dataframe.Dataframe`.
10601062
@@ -1078,6 +1080,11 @@ def read_parquet(
10781080
"""
10791081
if table_partition_cols is None:
10801082
table_partition_cols = []
1083+
file_sort_order = (
1084+
[sort_list_to_raw_sort_list(f) for f in file_sort_order]
1085+
if file_sort_order is not None
1086+
else None
1087+
)
10811088
return DataFrame(
10821089
self.ctx.read_parquet(
10831090
str(path),
@@ -1121,7 +1128,7 @@ def read_table(self, table: Table) -> DataFrame:
11211128
:py:class:`~datafusion.catalog.ListingTable`, create a
11221129
:py:class:`~datafusion.dataframe.DataFrame`.
11231130
"""
1124-
return DataFrame(self.ctx.read_table(table))
1131+
return DataFrame(self.ctx.read_table(table.table))
11251132

11261133
def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:
11271134
"""Execute the ``plan`` and return the results."""

python/datafusion/dataframe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from enum import Enum
5353

5454
from datafusion._internal import DataFrame as DataFrameInternal
55+
from datafusion._internal import expr as expr_internal
5556
from datafusion.expr import Expr, SortExpr, sort_or_default
5657

5758

@@ -277,7 +278,7 @@ def with_columns(
277278

278279
def _simplify_expression(
279280
*exprs: Expr | Iterable[Expr], **named_exprs: Expr
280-
) -> list[Expr]:
281+
) -> list[expr_internal.Expr]:
281282
expr_list = []
282283
for expr in exprs:
283284
if isinstance(expr, Expr):

python/datafusion/expr.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:
176176
"""Helper function to return a default Sort if an Expr is provided."""
177177
if isinstance(e, SortExpr):
178178
return e.raw_sort
179-
return SortExpr(e.expr, True, True).raw_sort
179+
return SortExpr(e, True, True).raw_sort
180180

181181

182182
def sort_list_to_raw_sort_list(
@@ -231,7 +231,7 @@ def variant_name(self) -> str:
231231

232232
def __richcmp__(self, other: Expr, op: int) -> Expr:
233233
"""Comparison operator."""
234-
return Expr(self.expr.__richcmp__(other, op))
234+
return Expr(self.expr.__richcmp__(other.expr, op))
235235

236236
def __repr__(self) -> str:
237237
"""Generate a string representation of this expression."""
@@ -417,7 +417,7 @@ def sort(self, ascending: bool = True, nulls_first: bool = True) -> SortExpr:
417417
ascending: If true, sort in ascending order.
418418
nulls_first: Return null values first.
419419
"""
420-
return SortExpr(self.expr, ascending=ascending, nulls_first=nulls_first)
420+
return SortExpr(self, ascending=ascending, nulls_first=nulls_first)
421421

422422
def is_null(self) -> Expr:
423423
"""Returns ``True`` if this expression is null."""
@@ -789,7 +789,7 @@ class SortExpr:
789789

790790
def __init__(self, expr: Expr, ascending: bool, nulls_first: bool) -> None:
791791
"""This constructor should not be called by the end user."""
792-
self.raw_sort = expr_internal.SortExpr(expr, ascending, nulls_first)
792+
self.raw_sort = expr_internal.SortExpr(expr.expr, ascending, nulls_first)
793793

794794
def expr(self) -> Expr:
795795
"""Return the raw expr backing the SortExpr."""

python/datafusion/functions.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def concat_ws(separator: str, *args: Expr) -> Expr:
366366

367367
def order_by(expr: Expr, ascending: bool = True, nulls_first: bool = True) -> SortExpr:
368368
"""Creates a new sort expression."""
369-
return SortExpr(expr.expr, ascending=ascending, nulls_first=nulls_first)
369+
return SortExpr(expr, ascending=ascending, nulls_first=nulls_first)
370370

371371

372372
def alias(expr: Expr, name: str) -> Expr:
@@ -942,6 +942,7 @@ def to_timestamp_millis(arg: Expr, *formatters: Expr) -> Expr:
942942
943943
See :py:func:`to_timestamp` for a description on how to use formatters.
944944
"""
945+
formatters = [f.expr for f in formatters]
945946
return Expr(f.to_timestamp_millis(arg.expr, *formatters))
946947

947948

@@ -950,6 +951,7 @@ def to_timestamp_micros(arg: Expr, *formatters: Expr) -> Expr:
950951
951952
See :py:func:`to_timestamp` for a description on how to use formatters.
952953
"""
954+
formatters = [f.expr for f in formatters]
953955
return Expr(f.to_timestamp_micros(arg.expr, *formatters))
954956

955957

@@ -958,6 +960,7 @@ def to_timestamp_nanos(arg: Expr, *formatters: Expr) -> Expr:
958960
959961
See :py:func:`to_timestamp` for a description on how to use formatters.
960962
"""
963+
formatters = [f.expr for f in formatters]
961964
return Expr(f.to_timestamp_nanos(arg.expr, *formatters))
962965

963966

@@ -966,6 +969,7 @@ def to_timestamp_seconds(arg: Expr, *formatters: Expr) -> Expr:
966969
967970
See :py:func:`to_timestamp` for a description on how to use formatters.
968971
"""
972+
formatters = [f.expr for f in formatters]
969973
return Expr(f.to_timestamp_seconds(arg.expr, *formatters))
970974

971975

@@ -1078,9 +1082,9 @@ def range(start: Expr, stop: Expr, step: Expr) -> Expr:
10781082
return Expr(f.range(start.expr, stop.expr, step.expr))
10791083

10801084

1081-
def uuid(arg: Expr) -> Expr:
1085+
def uuid() -> Expr:
10821086
"""Returns uuid v4 as a string value."""
1083-
return Expr(f.uuid(arg.expr))
1087+
return Expr(f.uuid())
10841088

10851089

10861090
def struct(*args: Expr) -> Expr:

python/datafusion/input/location.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,20 @@ def is_correct_input(self, input_item: Any, table_name: str, **kwargs):
3737

3838
def build_table(
3939
self,
40-
input_file: str,
40+
input_item: str,
4141
table_name: str,
4242
**kwargs,
4343
) -> SqlTable:
4444
"""Create a table from the input source."""
45-
_, extension = os.path.splitext(input_file)
45+
_, extension = os.path.splitext(input_item)
4646
format = extension.lstrip(".").lower()
4747
num_rows = 0 # Total number of rows in the file. Used for statistics
4848
columns = []
4949
if format == "parquet":
5050
import pyarrow.parquet as pq
5151

5252
# Read the Parquet metadata
53-
metadata = pq.read_metadata(input_file)
53+
metadata = pq.read_metadata(input_item)
5454
num_rows = metadata.num_rows
5555
# Iterate through the schema and build the SqlTable
5656
for col in metadata.schema:
@@ -69,7 +69,7 @@ def build_table(
6969
# to get that information. However, this should only be occurring
7070
# at table creation time and therefore shouldn't
7171
# slow down query performance.
72-
with open(input_file, "r") as file:
72+
with open(input_item, "r") as file:
7373
reader = csv.reader(file)
7474
header_row = next(reader)
7575
print(header_row)
@@ -84,6 +84,6 @@ def build_table(
8484
)
8585

8686
# Input could possibly be multiple files. Create a list if so
87-
input_files = glob.glob(input_file)
87+
input_files = glob.glob(input_item)
8888

8989
return SqlTable(table_name, columns, num_rows, input_files)

python/datafusion/udf.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class ScalarUDF:
8585

8686
def __init__(
8787
self,
88-
name: Optional[str],
88+
name: str,
8989
func: Callable[..., _R],
9090
input_types: pyarrow.DataType | list[pyarrow.DataType],
9191
return_type: _R,
@@ -182,7 +182,7 @@ class AggregateUDF:
182182

183183
def __init__(
184184
self,
185-
name: Optional[str],
185+
name: str,
186186
accumulator: Callable[[], Accumulator],
187187
input_types: list[pyarrow.DataType],
188188
return_type: pyarrow.DataType,
@@ -277,6 +277,7 @@ def sum_bias_10() -> Summarize:
277277
)
278278
if name is None:
279279
name = accum.__call__().__class__.__qualname__.lower()
280+
assert name is not None
280281
if isinstance(input_types, pyarrow.DataType):
281282
input_types = [input_types]
282283
return AggregateUDF(
@@ -462,7 +463,7 @@ class WindowUDF:
462463

463464
def __init__(
464465
self,
465-
name: Optional[str],
466+
name: str,
466467
func: Callable[[], WindowEvaluator],
467468
input_types: list[pyarrow.DataType],
468469
return_type: pyarrow.DataType,

python/tests/test_functions.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,7 +871,22 @@ def test_temporal_functions(df):
871871
f.to_timestamp_millis(literal("2023-09-07 05:06:14.523952")),
872872
f.to_timestamp_micros(literal("2023-09-07 05:06:14.523952")),
873873
f.extract(literal("day"), column("d")),
874+
f.to_timestamp(
875+
literal("2023-09-07 05:06:14.523952000"), literal("%Y-%m-%d %H:%M:%S.%f")
876+
),
877+
f.to_timestamp_seconds(
878+
literal("2023-09-07 05:06:14.523952000"), literal("%Y-%m-%d %H:%M:%S.%f")
879+
),
880+
f.to_timestamp_millis(
881+
literal("2023-09-07 05:06:14.523952000"), literal("%Y-%m-%d %H:%M:%S.%f")
882+
),
883+
f.to_timestamp_micros(
884+
literal("2023-09-07 05:06:14.523952000"), literal("%Y-%m-%d %H:%M:%S.%f")
885+
),
874886
f.to_timestamp_nanos(literal("2023-09-07 05:06:14.523952")),
887+
f.to_timestamp_nanos(
888+
literal("2023-09-07 05:06:14.523952000"), literal("%Y-%m-%d %H:%M:%S.%f")
889+
),
875890
)
876891
result = df.collect()
877892
assert len(result) == 1
@@ -913,6 +928,21 @@ def test_temporal_functions(df):
913928
assert result.column(11) == pa.array(
914929
[datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("ns")
915930
)
931+
assert result.column(12) == pa.array(
932+
[datetime(2023, 9, 7, 5, 6, 14)] * 3, type=pa.timestamp("s")
933+
)
934+
assert result.column(13) == pa.array(
935+
[datetime(2023, 9, 7, 5, 6, 14, 523000)] * 3, type=pa.timestamp("ms")
936+
)
937+
assert result.column(14) == pa.array(
938+
[datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("us")
939+
)
940+
assert result.column(15) == pa.array(
941+
[datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("ns")
942+
)
943+
assert result.column(16) == pa.array(
944+
[datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("ns")
945+
)
916946

917947

918948
def test_arrow_cast(df):

0 commit comments

Comments (0)