first commit

This commit is contained in:
Ayxan
2022-05-23 00:16:32 +04:00
commit d660f2a4ca
24786 changed files with 4428337 additions and 0 deletions
@@ -0,0 +1,30 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# This file is called from a test in test_schema.py.
import pyarrow as pa
# the types where to_pandas_dtype returns a non-numpy dtype
cases = [
(pa.timestamp('ns', tz='UTC'), "datetime64[ns, UTC]"),
]
for arrow_type, pandas_type in cases:
assert str(arrow_type.to_pandas_dtype()) == pandas_type
@@ -0,0 +1,68 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# distutils: language=c++
# cython: language_level = 3
import pyarrow as pa
from pyarrow.lib cimport *
from pyarrow.lib import frombytes, tobytes
# basic test to roundtrip through a BoundFunction
ctypedef CStatus visit_string_cb(const c_string&)
cdef extern from * namespace "arrow::py" nogil:
"""
#include <functional>
#include <string>
#include <vector>
#include "arrow/status.h"
namespace arrow {
namespace py {
Status VisitStrings(const std::vector<std::string>& strs,
std::function<Status(const std::string&)> cb) {
for (const std::string& str : strs) {
RETURN_NOT_OK(cb(str));
}
return Status::OK();
}
} // namespace py
} // namespace arrow
"""
cdef CStatus CVisitStrings" arrow::py::VisitStrings"(
vector[c_string], function[visit_string_cb])
cdef void _visit_strings_impl(py_cb, const c_string& s) except *:
py_cb(frombytes(s))
def _visit_strings(strings, cb):
cdef:
function[visit_string_cb] c_cb
vector[c_string] c_strings
c_cb = BindFunction[visit_string_cb](&_visit_strings_impl, cb)
for s in strings:
c_strings.push_back(tobytes(s))
check_status(CVisitStrings(c_strings, c_cb))
@@ -0,0 +1,313 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import pathlib
import subprocess
from tempfile import TemporaryDirectory
import pytest
import hypothesis as h
from pyarrow.util import find_free_port
from pyarrow import Codec
# setup hypothesis profiles
h.settings.register_profile('ci', max_examples=1000)
h.settings.register_profile('dev', max_examples=50)
h.settings.register_profile('debug', max_examples=10,
verbosity=h.Verbosity.verbose)
# load default hypothesis profile, either set HYPOTHESIS_PROFILE environment
# variable or pass --hypothesis-profile option to pytest, to see the generated
# examples try:
# pytest pyarrow -sv --enable-hypothesis --hypothesis-profile=debug
h.settings.load_profile(os.environ.get('HYPOTHESIS_PROFILE', 'dev'))
# Set this at the beginning before the AWS SDK was loaded to avoid reading in
# user configuration values.
os.environ['AWS_CONFIG_FILE'] = "/dev/null"
groups = [
'brotli',
'bz2',
'cython',
'dataset',
'hypothesis',
'fastparquet',
'gandiva',
'gdb',
'gzip',
'hdfs',
'large_memory',
'lz4',
'memory_leak',
'nopandas',
'orc',
'pandas',
'parquet',
'parquet_encryption',
'plasma',
's3',
'snappy',
'tensorflow',
'flight',
'slow',
'requires_testing_data',
'zstd',
]
defaults = {
'brotli': Codec.is_available('brotli'),
'bz2': Codec.is_available('bz2'),
'cython': False,
'dataset': False,
'fastparquet': False,
'flight': False,
'gandiva': False,
'gdb': True,
'gzip': Codec.is_available('gzip'),
'hdfs': False,
'hypothesis': False,
'large_memory': False,
'lz4': Codec.is_available('lz4'),
'memory_leak': False,
'nopandas': False,
'orc': False,
'pandas': False,
'parquet': False,
'parquet_encryption': False,
'plasma': False,
'requires_testing_data': True,
's3': False,
'slow': False,
'snappy': Codec.is_available('snappy'),
'tensorflow': False,
'zstd': Codec.is_available('zstd'),
}
try:
import cython # noqa
defaults['cython'] = True
except ImportError:
pass
try:
import fastparquet # noqa
defaults['fastparquet'] = True
except ImportError:
pass
try:
import pyarrow.gandiva # noqa
defaults['gandiva'] = True
except ImportError:
pass
try:
import pyarrow.dataset # noqa
defaults['dataset'] = True
except ImportError:
pass
try:
import pyarrow.orc # noqa
defaults['orc'] = True
except ImportError:
pass
try:
import pandas # noqa
defaults['pandas'] = True
except ImportError:
defaults['nopandas'] = True
try:
import pyarrow.parquet # noqa
defaults['parquet'] = True
except ImportError:
pass
try:
import pyarrow.parquet.encryption # noqa
defaults['parquet_encryption'] = True
except ImportError:
pass
try:
import pyarrow.plasma # noqa
defaults['plasma'] = True
except ImportError:
pass
try:
import tensorflow # noqa
defaults['tensorflow'] = True
except ImportError:
pass
try:
import pyarrow.flight # noqa
defaults['flight'] = True
except ImportError:
pass
try:
from pyarrow.fs import S3FileSystem # noqa
defaults['s3'] = True
except ImportError:
pass
try:
from pyarrow.fs import HadoopFileSystem # noqa
defaults['hdfs'] = True
except ImportError:
pass
def pytest_addoption(parser):
# Create options to selectively enable test groups
def bool_env(name, default=None):
value = os.environ.get(name.upper())
if not value: # missing or empty
return default
value = value.lower()
if value in {'1', 'true', 'on', 'yes', 'y'}:
return True
elif value in {'0', 'false', 'off', 'no', 'n'}:
return False
else:
raise ValueError('{}={} is not parsable as boolean'
.format(name.upper(), value))
for group in groups:
default = bool_env('PYARROW_TEST_{}'.format(group), defaults[group])
parser.addoption('--enable-{}'.format(group),
action='store_true', default=default,
help=('Enable the {} test group'.format(group)))
parser.addoption('--disable-{}'.format(group),
action='store_true', default=False,
help=('Disable the {} test group'.format(group)))
class PyArrowConfig:
def __init__(self):
self.is_enabled = {}
def apply_mark(self, mark):
group = mark.name
if group in groups:
self.requires(group)
def requires(self, group):
if not self.is_enabled[group]:
pytest.skip('{} NOT enabled'.format(group))
def pytest_configure(config):
# Apply command-line options to initialize PyArrow-specific config object
config.pyarrow = PyArrowConfig()
for mark in groups:
config.addinivalue_line(
"markers", mark,
)
enable_flag = '--enable-{}'.format(mark)
disable_flag = '--disable-{}'.format(mark)
is_enabled = (config.getoption(enable_flag) and not
config.getoption(disable_flag))
config.pyarrow.is_enabled[mark] = is_enabled
def pytest_runtest_setup(item):
# Apply test markers to skip tests selectively
for mark in item.iter_markers():
item.config.pyarrow.apply_mark(mark)
@pytest.fixture
def tempdir(tmpdir):
# convert pytest's LocalPath to pathlib.Path
return pathlib.Path(tmpdir.strpath)
@pytest.fixture(scope='session')
def base_datadir():
return pathlib.Path(__file__).parent / 'data'
@pytest.fixture(autouse=True)
def disable_aws_metadata(monkeypatch):
"""Stop the AWS SDK from trying to contact the EC2 metadata server.
Otherwise, this causes a 5 second delay in tests that exercise the
S3 filesystem.
"""
monkeypatch.setenv("AWS_EC2_METADATA_DISABLED", "true")
# TODO(kszucs): move the following fixtures to test_fs.py once the previous
# parquet dataset implementation and hdfs implementation are removed.
@pytest.fixture(scope='session')
def hdfs_connection():
host = os.environ.get('ARROW_HDFS_TEST_HOST', 'default')
port = int(os.environ.get('ARROW_HDFS_TEST_PORT', 0))
user = os.environ.get('ARROW_HDFS_TEST_USER', 'hdfs')
return host, port, user
@pytest.fixture(scope='session')
def s3_connection():
host, port = 'localhost', find_free_port()
access_key, secret_key = 'arrow', 'apachearrow'
return host, port, access_key, secret_key
@pytest.fixture(scope='session')
def s3_server(s3_connection):
host, port, access_key, secret_key = s3_connection
address = '{}:{}'.format(host, port)
env = os.environ.copy()
env.update({
'MINIO_ACCESS_KEY': access_key,
'MINIO_SECRET_KEY': secret_key
})
with TemporaryDirectory() as tempdir:
args = ['minio', '--compat', 'server', '--quiet', '--address',
address, tempdir]
proc = None
try:
proc = subprocess.Popen(args, env=env)
except OSError:
pytest.skip('`minio` command cannot be located')
else:
yield {
'connection': s3_connection,
'process': proc,
'tempdir': tempdir
}
finally:
if proc is not None:
proc.kill()
@@ -0,0 +1,22 @@
<!---
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
The ORC and JSON files come from the `examples` directory in the Apache ORC
source tree:
https://github.com/apache/orc/tree/master/examples
@@ -0,0 +1,26 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# This file is called from a test in test_serialization.py.
import sys
import pyarrow as pa
with open(sys.argv[1], 'rb') as f:
data = f.read()
pa.deserialize(data)
@@ -0,0 +1,172 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections import OrderedDict
from datetime import date, time
import numpy as np
import pandas as pd
import pyarrow as pa
def dataframe_with_arrays(include_index=False):
"""
Dataframe with numpy arrays columns of every possible primitive type.
Returns
-------
df: pandas.DataFrame
schema: pyarrow.Schema
Arrow schema definition that is in line with the constructed df.
"""
dtypes = [('i1', pa.int8()), ('i2', pa.int16()),
('i4', pa.int32()), ('i8', pa.int64()),
('u1', pa.uint8()), ('u2', pa.uint16()),
('u4', pa.uint32()), ('u8', pa.uint64()),
('f4', pa.float32()), ('f8', pa.float64())]
arrays = OrderedDict()
fields = []
for dtype, arrow_dtype in dtypes:
fields.append(pa.field(dtype, pa.list_(arrow_dtype)))
arrays[dtype] = [
np.arange(10, dtype=dtype),
np.arange(5, dtype=dtype),
None,
np.arange(1, dtype=dtype)
]
fields.append(pa.field('str', pa.list_(pa.string())))
arrays['str'] = [
np.array(["1", "ä"], dtype="object"),
None,
np.array(["1"], dtype="object"),
np.array(["1", "2", "3"], dtype="object")
]
fields.append(pa.field('datetime64', pa.list_(pa.timestamp('ms'))))
arrays['datetime64'] = [
np.array(['2007-07-13T01:23:34.123456789',
None,
'2010-08-13T05:46:57.437699912'],
dtype='datetime64[ms]'),
None,
None,
np.array(['2007-07-13T02',
None,
'2010-08-13T05:46:57.437699912'],
dtype='datetime64[ms]'),
]
if include_index:
fields.append(pa.field('__index_level_0__', pa.int64()))
df = pd.DataFrame(arrays)
schema = pa.schema(fields)
return df, schema
def dataframe_with_lists(include_index=False, parquet_compatible=False):
"""
Dataframe with list columns of every possible primitive type.
Returns
-------
df: pandas.DataFrame
schema: pyarrow.Schema
Arrow schema definition that is in line with the constructed df.
parquet_compatible: bool
Exclude types not supported by parquet
"""
arrays = OrderedDict()
fields = []
fields.append(pa.field('int64', pa.list_(pa.int64())))
arrays['int64'] = [
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
[0, 1, 2, 3, 4],
None,
[],
np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9] * 2,
dtype=np.int64)[::2]
]
fields.append(pa.field('double', pa.list_(pa.float64())))
arrays['double'] = [
[0., 1., 2., 3., 4., 5., 6., 7., 8., 9.],
[0., 1., 2., 3., 4.],
None,
[],
np.array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.] * 2)[::2],
]
fields.append(pa.field('bytes_list', pa.list_(pa.binary())))
arrays['bytes_list'] = [
[b"1", b"f"],
None,
[b"1"],
[b"1", b"2", b"3"],
[],
]
fields.append(pa.field('str_list', pa.list_(pa.string())))
arrays['str_list'] = [
["1", "ä"],
None,
["1"],
["1", "2", "3"],
[],
]
date_data = [
[],
[date(2018, 1, 1), date(2032, 12, 30)],
[date(2000, 6, 7)],
None,
[date(1969, 6, 9), date(1972, 7, 3)]
]
time_data = [
[time(23, 11, 11), time(1, 2, 3), time(23, 59, 59)],
[],
[time(22, 5, 59)],
None,
[time(0, 0, 0), time(18, 0, 2), time(12, 7, 3)]
]
temporal_pairs = [
(pa.date32(), date_data),
(pa.date64(), date_data),
(pa.time32('s'), time_data),
(pa.time32('ms'), time_data),
(pa.time64('us'), time_data)
]
if not parquet_compatible:
temporal_pairs += [
(pa.time64('ns'), time_data),
]
for value_type, data in temporal_pairs:
field_name = '{}_list'.format(value_type)
field_type = pa.list_(value_type)
field = pa.field(field_name, field_type)
fields.append(field)
arrays[field_name] = data
if include_index:
fields.append(pa.field('__index_level_0__', pa.int64()))
df = pd.DataFrame(arrays)
schema = pa.schema(fields)
return df, schema
@@ -0,0 +1,44 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# This file is called from a test in test_pandas.py.
from concurrent.futures import ThreadPoolExecutor
import faulthandler
import sys
import pyarrow as pa
num_threads = 60
timeout = 10 # seconds
def thread_func(i):
pa.array([i]).to_pandas()
def main():
# In case of import deadlock, crash after a finite timeout
faulthandler.dump_traceback_later(timeout, exit=True)
with ThreadPoolExecutor(num_threads) as pool:
assert "pandas" not in sys.modules # pandas is imported lazily
list(pool.map(thread_func, range(num_threads)))
assert "pandas" in sys.modules
if __name__ == "__main__":
main()
@@ -0,0 +1,27 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
# Marks all of the tests in this module
# Ignore these with pytest ... -m 'not parquet'
pytestmark = [
pytest.mark.parquet,
pytest.mark.filterwarnings(
"ignore:Passing 'use_legacy_dataset=True':DeprecationWarning"
),
]
@@ -0,0 +1,189 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import io
import numpy as np
import pytest
import pyarrow as pa
from pyarrow.tests import util
legacy_filter_mark = pytest.mark.filterwarnings(
"ignore:Passing 'use_legacy:FutureWarning"
)
parametrize_legacy_dataset = pytest.mark.parametrize(
"use_legacy_dataset",
[pytest.param(True, marks=legacy_filter_mark),
pytest.param(False, marks=pytest.mark.dataset)]
)
parametrize_legacy_dataset_not_supported = pytest.mark.parametrize(
"use_legacy_dataset",
[pytest.param(True, marks=legacy_filter_mark),
pytest.param(False, marks=pytest.mark.skip)]
)
parametrize_legacy_dataset_fixed = pytest.mark.parametrize(
"use_legacy_dataset",
[pytest.param(True, marks=[pytest.mark.xfail, legacy_filter_mark]),
pytest.param(False, marks=pytest.mark.dataset)]
)
# Marks all of the tests in this module
# Ignore these with pytest ... -m 'not parquet'
pytestmark = pytest.mark.parquet
def _write_table(table, path, **kwargs):
# So we see the ImportError somewhere
import pyarrow.parquet as pq
from pyarrow.pandas_compat import _pandas_api
if _pandas_api.is_data_frame(table):
table = pa.Table.from_pandas(table)
pq.write_table(table, path, **kwargs)
return table
def _read_table(*args, **kwargs):
import pyarrow.parquet as pq
table = pq.read_table(*args, **kwargs)
table.validate(full=True)
return table
def _roundtrip_table(table, read_table_kwargs=None,
write_table_kwargs=None, use_legacy_dataset=False):
read_table_kwargs = read_table_kwargs or {}
write_table_kwargs = write_table_kwargs or {}
writer = pa.BufferOutputStream()
_write_table(table, writer, **write_table_kwargs)
reader = pa.BufferReader(writer.getvalue())
return _read_table(reader, use_legacy_dataset=use_legacy_dataset,
**read_table_kwargs)
def _check_roundtrip(table, expected=None, read_table_kwargs=None,
use_legacy_dataset=False, **write_table_kwargs):
if expected is None:
expected = table
read_table_kwargs = read_table_kwargs or {}
# intentionally check twice
result = _roundtrip_table(table, read_table_kwargs=read_table_kwargs,
write_table_kwargs=write_table_kwargs,
use_legacy_dataset=use_legacy_dataset)
assert result.equals(expected)
result = _roundtrip_table(result, read_table_kwargs=read_table_kwargs,
write_table_kwargs=write_table_kwargs,
use_legacy_dataset=use_legacy_dataset)
assert result.equals(expected)
def _roundtrip_pandas_dataframe(df, write_kwargs, use_legacy_dataset=False):
table = pa.Table.from_pandas(df)
result = _roundtrip_table(
table, write_table_kwargs=write_kwargs,
use_legacy_dataset=use_legacy_dataset)
return result.to_pandas()
def _random_integers(size, dtype):
# We do not generate integers outside the int64 range
platform_int_info = np.iinfo('int_')
iinfo = np.iinfo(dtype)
return np.random.randint(max(iinfo.min, platform_int_info.min),
min(iinfo.max, platform_int_info.max),
size=size).astype(dtype)
def _test_dataframe(size=10000, seed=0):
import pandas as pd
np.random.seed(seed)
df = pd.DataFrame({
'uint8': _random_integers(size, np.uint8),
'uint16': _random_integers(size, np.uint16),
'uint32': _random_integers(size, np.uint32),
'uint64': _random_integers(size, np.uint64),
'int8': _random_integers(size, np.int8),
'int16': _random_integers(size, np.int16),
'int32': _random_integers(size, np.int32),
'int64': _random_integers(size, np.int64),
'float32': np.random.randn(size).astype(np.float32),
'float64': np.arange(size, dtype=np.float64),
'bool': np.random.randn(size) > 0,
'strings': [util.rands(10) for i in range(size)],
'all_none': [None] * size,
'all_none_category': [None] * size
})
# TODO(PARQUET-1015)
# df['all_none_category'] = df['all_none_category'].astype('category')
return df
def make_sample_file(table_or_df):
import pyarrow.parquet as pq
if isinstance(table_or_df, pa.Table):
a_table = table_or_df
else:
a_table = pa.Table.from_pandas(table_or_df)
buf = io.BytesIO()
_write_table(a_table, buf, compression='SNAPPY', version='2.6',
coerce_timestamps='ms')
buf.seek(0)
return pq.ParquetFile(buf)
def alltypes_sample(size=10000, seed=0, categorical=False):
import pandas as pd
np.random.seed(seed)
arrays = {
'uint8': np.arange(size, dtype=np.uint8),
'uint16': np.arange(size, dtype=np.uint16),
'uint32': np.arange(size, dtype=np.uint32),
'uint64': np.arange(size, dtype=np.uint64),
'int8': np.arange(size, dtype=np.int16),
'int16': np.arange(size, dtype=np.int16),
'int32': np.arange(size, dtype=np.int32),
'int64': np.arange(size, dtype=np.int64),
'float32': np.arange(size, dtype=np.float32),
'float64': np.arange(size, dtype=np.float64),
'bool': np.random.randn(size) > 0,
# TODO(wesm): Test other timestamp resolutions now that arrow supports
# them
'datetime': np.arange("2016-01-01T00:00:00.001", size,
dtype='datetime64[ms]'),
'timedelta': np.arange(0, size, dtype="timedelta64[s]"),
'str': pd.Series([str(x) for x in range(size)]),
'empty_str': [''] * size,
'str_with_nulls': [None] + [str(x) for x in range(size - 2)] + [None],
'null': [None] * size,
'null_list': [None] * 2 + [[None] * (x % 4) for x in range(size - 2)],
}
if categorical:
arrays['str_category'] = arrays['str'].astype('category')
return pd.DataFrame(arrays)
@@ -0,0 +1,87 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from pyarrow.util import guid
@pytest.fixture(scope='module')
def datadir(base_datadir):
return base_datadir / 'parquet'
@pytest.fixture
def s3_bucket(s3_server):
boto3 = pytest.importorskip('boto3')
botocore = pytest.importorskip('botocore')
host, port, access_key, secret_key = s3_server['connection']
s3 = boto3.resource(
's3',
endpoint_url='http://{}:{}'.format(host, port),
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
config=botocore.client.Config(signature_version='s3v4'),
region_name='us-east-1'
)
bucket = s3.Bucket('test-s3fs')
try:
bucket.create()
except Exception:
# we get BucketAlreadyOwnedByYou error with fsspec handler
pass
return 'test-s3fs'
@pytest.fixture
def s3_example_s3fs(s3_server, s3_bucket):
s3fs = pytest.importorskip('s3fs')
host, port, access_key, secret_key = s3_server['connection']
fs = s3fs.S3FileSystem(
key=access_key,
secret=secret_key,
client_kwargs={
'endpoint_url': 'http://{}:{}'.format(host, port)
}
)
test_path = '{}/{}'.format(s3_bucket, guid())
fs.mkdir(test_path)
yield fs, test_path
try:
fs.rm(test_path, recursive=True)
except FileNotFoundError:
pass
@pytest.fixture
def s3_example_fs(s3_server):
from pyarrow.fs import FileSystem
host, port, access_key, secret_key = s3_server['connection']
uri = (
"s3://{}:{}@mybucket/data.parquet?scheme=http&endpoint_override={}:{}"
.format(access_key, secret_key, host, port)
)
fs, path = FileSystem.from_uri(uri)
fs.create_dir("mybucket")
yield fs, uri, path
@@ -0,0 +1,60 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import base64
import pyarrow.parquet.encryption as pe
class InMemoryKmsClient(pe.KmsClient):
"""This is a mock class implementation of KmsClient, built for testing only.
"""
def __init__(self, config):
"""Create an InMemoryKmsClient instance."""
pe.KmsClient.__init__(self)
self.master_keys_map = config.custom_kms_conf
def wrap_key(self, key_bytes, master_key_identifier):
"""Not a secure cipher - the wrapped key
is just the master key concatenated with key bytes"""
master_key_bytes = self.master_keys_map[master_key_identifier].encode(
'utf-8')
wrapped_key = b"".join([master_key_bytes, key_bytes])
result = base64.b64encode(wrapped_key)
return result
def unwrap_key(self, wrapped_key, master_key_identifier):
"""Not a secure cipher - just extract the key from
the wrapped key"""
expected_master_key = self.master_keys_map[master_key_identifier]
decoded_wrapped_key = base64.b64decode(wrapped_key)
master_key_bytes = decoded_wrapped_key[:16]
decrypted_key = decoded_wrapped_key[16:]
if (expected_master_key == master_key_bytes.decode('utf-8')):
return decrypted_key
raise ValueError("Incorrect master key used",
master_key_bytes, decrypted_key)
def verify_file_encrypted(path):
"""Verify that the file is encrypted by looking at its first 4 bytes.
If it's the magic string PARE
then this is a parquet with encrypted footer."""
with open(path, "rb") as file:
magic_str = file.read(4)
# Verify magic string for parquet with encrypted footer is PARE
assert magic_str == b'PARE'
@@ -0,0 +1,799 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections import OrderedDict
import io
import numpy as np
import pytest
import pyarrow as pa
from pyarrow import fs
from pyarrow.filesystem import LocalFileSystem, FileSystem
from pyarrow.tests import util
from pyarrow.tests.parquet.common import (_check_roundtrip, _roundtrip_table,
parametrize_legacy_dataset,
_test_dataframe)
try:
import pyarrow.parquet as pq
from pyarrow.tests.parquet.common import _read_table, _write_table
except ImportError:
pq = None
try:
import pandas as pd
import pandas.testing as tm
from pyarrow.tests.pandas_examples import dataframe_with_lists
from pyarrow.tests.parquet.common import alltypes_sample
except ImportError:
pd = tm = None
def test_parquet_invalid_version(tempdir):
table = pa.table({'a': [1, 2, 3]})
with pytest.raises(ValueError, match="Unsupported Parquet format version"):
_write_table(table, tempdir / 'test_version.parquet', version="2.2")
with pytest.raises(ValueError, match="Unsupported Parquet data page " +
"version"):
_write_table(table, tempdir / 'test_version.parquet',
data_page_version="2.2")
@parametrize_legacy_dataset
def test_set_data_page_size(use_legacy_dataset):
arr = pa.array([1, 2, 3] * 100000)
t = pa.Table.from_arrays([arr], names=['f0'])
# 128K, 512K
page_sizes = [2 << 16, 2 << 18]
for target_page_size in page_sizes:
_check_roundtrip(t, data_page_size=target_page_size,
use_legacy_dataset=use_legacy_dataset)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_set_write_batch_size(use_legacy_dataset):
df = _test_dataframe(100)
table = pa.Table.from_pandas(df, preserve_index=False)
_check_roundtrip(
table, data_page_size=10, write_batch_size=1, version='2.4'
)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_set_dictionary_pagesize_limit(use_legacy_dataset):
df = _test_dataframe(100)
table = pa.Table.from_pandas(df, preserve_index=False)
_check_roundtrip(table, dictionary_pagesize_limit=1,
data_page_size=10, version='2.4')
with pytest.raises(TypeError):
_check_roundtrip(table, dictionary_pagesize_limit="a",
data_page_size=10, version='2.4')
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_chunked_table_write(use_legacy_dataset):
# ARROW-232
tables = []
batch = pa.RecordBatch.from_pandas(alltypes_sample(size=10))
tables.append(pa.Table.from_batches([batch] * 3))
df, _ = dataframe_with_lists()
batch = pa.RecordBatch.from_pandas(df)
tables.append(pa.Table.from_batches([batch] * 3))
for data_page_version in ['1.0', '2.0']:
for use_dictionary in [True, False]:
for table in tables:
_check_roundtrip(
table, version='2.6',
use_legacy_dataset=use_legacy_dataset,
data_page_version=data_page_version,
use_dictionary=use_dictionary)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_memory_map(tempdir, use_legacy_dataset):
df = alltypes_sample(size=10)
table = pa.Table.from_pandas(df)
_check_roundtrip(table, read_table_kwargs={'memory_map': True},
version='2.6', use_legacy_dataset=use_legacy_dataset)
filename = str(tempdir / 'tmp_file')
with open(filename, 'wb') as f:
_write_table(table, f, version='2.6')
table_read = pq.read_pandas(filename, memory_map=True,
use_legacy_dataset=use_legacy_dataset)
assert table_read.equals(table)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_enable_buffered_stream(tempdir, use_legacy_dataset):
df = alltypes_sample(size=10)
table = pa.Table.from_pandas(df)
_check_roundtrip(table, read_table_kwargs={'buffer_size': 1025},
version='2.6', use_legacy_dataset=use_legacy_dataset)
filename = str(tempdir / 'tmp_file')
with open(filename, 'wb') as f:
_write_table(table, f, version='2.6')
table_read = pq.read_pandas(filename, buffer_size=4096,
use_legacy_dataset=use_legacy_dataset)
assert table_read.equals(table)
@parametrize_legacy_dataset
def test_special_chars_filename(tempdir, use_legacy_dataset):
table = pa.Table.from_arrays([pa.array([42])], ["ints"])
filename = "foo # bar"
path = tempdir / filename
assert not path.exists()
_write_table(table, str(path))
assert path.exists()
table_read = _read_table(str(path), use_legacy_dataset=use_legacy_dataset)
assert table_read.equals(table)
@parametrize_legacy_dataset
def test_invalid_source(use_legacy_dataset):
# Test that we provide an helpful error message pointing out
# that None wasn't expected when trying to open a Parquet None file.
#
# Depending on use_legacy_dataset the message changes slightly
# but in both cases it should point out that None wasn't expected.
with pytest.raises(TypeError, match="None"):
pq.read_table(None, use_legacy_dataset=use_legacy_dataset)
with pytest.raises(TypeError, match="None"):
pq.ParquetFile(None)
@pytest.mark.slow
def test_file_with_over_int16_max_row_groups():
# PARQUET-1857: Parquet encryption support introduced a INT16_MAX upper
# limit on the number of row groups, but this limit only impacts files with
# encrypted row group metadata because of the int16 row group ordinal used
# in the Parquet Thrift metadata. Unencrypted files are not impacted, so
# this test checks that it works (even if it isn't a good idea)
t = pa.table([list(range(40000))], names=['f0'])
_check_roundtrip(t, row_group_size=1)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_empty_table_roundtrip(use_legacy_dataset):
df = alltypes_sample(size=10)
# Create a non-empty table to infer the types correctly, then slice to 0
table = pa.Table.from_pandas(df)
table = pa.Table.from_arrays(
[col.chunk(0)[:0] for col in table.itercolumns()],
names=table.schema.names)
assert table.schema.field('null').type == pa.null()
assert table.schema.field('null_list').type == pa.list_(pa.null())
_check_roundtrip(
table, version='2.6', use_legacy_dataset=use_legacy_dataset)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_empty_table_no_columns(use_legacy_dataset):
df = pd.DataFrame()
empty = pa.Table.from_pandas(df, preserve_index=False)
_check_roundtrip(empty, use_legacy_dataset=use_legacy_dataset)
@parametrize_legacy_dataset
def test_write_nested_zero_length_array_chunk_failure(use_legacy_dataset):
# Bug report in ARROW-3792
cols = OrderedDict(
int32=pa.int32(),
list_string=pa.list_(pa.string())
)
data = [[], [OrderedDict(int32=1, list_string=('G',)), ]]
# This produces a table with a column like
# <Column name='list_string' type=ListType(list<item: string>)>
# [
# [],
# [
# [
# "G"
# ]
# ]
# ]
#
# Each column is a ChunkedArray with 2 elements
my_arrays = [pa.array(batch, type=pa.struct(cols)).flatten()
for batch in data]
my_batches = [pa.RecordBatch.from_arrays(batch, schema=pa.schema(cols))
for batch in my_arrays]
tbl = pa.Table.from_batches(my_batches, pa.schema(cols))
_check_roundtrip(tbl, use_legacy_dataset=use_legacy_dataset)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_multiple_path_types(tempdir, use_legacy_dataset):
# Test compatibility with PEP 519 path-like objects
path = tempdir / 'zzz.parquet'
df = pd.DataFrame({'x': np.arange(10, dtype=np.int64)})
_write_table(df, path)
table_read = _read_table(path, use_legacy_dataset=use_legacy_dataset)
df_read = table_read.to_pandas()
tm.assert_frame_equal(df, df_read)
# Test compatibility with plain string paths
path = str(tempdir) + 'zzz.parquet'
df = pd.DataFrame({'x': np.arange(10, dtype=np.int64)})
_write_table(df, path)
table_read = _read_table(path, use_legacy_dataset=use_legacy_dataset)
df_read = table_read.to_pandas()
tm.assert_frame_equal(df, df_read)
@parametrize_legacy_dataset
def test_fspath(tempdir, use_legacy_dataset):
# ARROW-12472 support __fspath__ objects without using str()
path = tempdir / "test.parquet"
table = pa.table({"a": [1, 2, 3]})
_write_table(table, path)
fs_protocol_obj = util.FSProtocolClass(path)
result = _read_table(
fs_protocol_obj, use_legacy_dataset=use_legacy_dataset
)
assert result.equals(table)
# combined with non-local filesystem raises
with pytest.raises(TypeError):
_read_table(fs_protocol_obj, filesystem=FileSystem())
@pytest.mark.dataset
@parametrize_legacy_dataset
@pytest.mark.parametrize("filesystem", [
None, fs.LocalFileSystem(), LocalFileSystem._get_instance()
])
def test_relative_paths(tempdir, use_legacy_dataset, filesystem):
# reading and writing from relative paths
table = pa.table({"a": [1, 2, 3]})
# reading
pq.write_table(table, str(tempdir / "data.parquet"))
with util.change_cwd(tempdir):
result = pq.read_table("data.parquet", filesystem=filesystem,
use_legacy_dataset=use_legacy_dataset)
assert result.equals(table)
# writing
with util.change_cwd(tempdir):
pq.write_table(table, "data2.parquet", filesystem=filesystem)
result = pq.read_table(tempdir / "data2.parquet")
assert result.equals(table)
def test_read_non_existing_file():
# ensure we have a proper error message
with pytest.raises(FileNotFoundError):
pq.read_table('i-am-not-existing.parquet')
def test_file_error_python_exception():
class BogusFile(io.BytesIO):
def read(self, *args):
raise ZeroDivisionError("zorglub")
def seek(self, *args):
raise ZeroDivisionError("zorglub")
# ensure the Python exception is restored
with pytest.raises(ZeroDivisionError, match="zorglub"):
pq.read_table(BogusFile(b""))
@parametrize_legacy_dataset
def test_parquet_read_from_buffer(tempdir, use_legacy_dataset):
# reading from a buffer from python's open()
table = pa.table({"a": [1, 2, 3]})
pq.write_table(table, str(tempdir / "data.parquet"))
with open(str(tempdir / "data.parquet"), "rb") as f:
result = pq.read_table(f, use_legacy_dataset=use_legacy_dataset)
assert result.equals(table)
with open(str(tempdir / "data.parquet"), "rb") as f:
result = pq.read_table(pa.PythonFile(f),
use_legacy_dataset=use_legacy_dataset)
assert result.equals(table)
@parametrize_legacy_dataset
def test_byte_stream_split(use_legacy_dataset):
# This is only a smoke test.
arr_float = pa.array(list(map(float, range(100))))
arr_int = pa.array(list(map(int, range(100))))
data_float = [arr_float, arr_float]
table = pa.Table.from_arrays(data_float, names=['a', 'b'])
# Check with byte_stream_split for both columns.
_check_roundtrip(table, expected=table, compression="gzip",
use_dictionary=False, use_byte_stream_split=True)
# Check with byte_stream_split for column 'b' and dictionary
# for column 'a'.
_check_roundtrip(table, expected=table, compression="gzip",
use_dictionary=['a'],
use_byte_stream_split=['b'])
# Check with a collision for both columns.
_check_roundtrip(table, expected=table, compression="gzip",
use_dictionary=['a', 'b'],
use_byte_stream_split=['a', 'b'])
# Check with mixed column types.
mixed_table = pa.Table.from_arrays([arr_float, arr_int],
names=['a', 'b'])
_check_roundtrip(mixed_table, expected=mixed_table,
use_dictionary=['b'],
use_byte_stream_split=['a'])
# Try to use the wrong data type with the byte_stream_split encoding.
# This should throw an exception.
table = pa.Table.from_arrays([arr_int], names=['tmp'])
with pytest.raises(IOError):
_check_roundtrip(table, expected=table, use_byte_stream_split=True,
use_dictionary=False,
use_legacy_dataset=use_legacy_dataset)
@parametrize_legacy_dataset
def test_column_encoding(use_legacy_dataset):
arr_float = pa.array(list(map(float, range(100))))
arr_int = pa.array(list(map(int, range(100))))
mixed_table = pa.Table.from_arrays([arr_float, arr_int],
names=['a', 'b'])
# Check "BYTE_STREAM_SPLIT" for column 'a' and "PLAIN" column_encoding for
# column 'b'.
_check_roundtrip(mixed_table, expected=mixed_table, use_dictionary=False,
column_encoding={'a': "BYTE_STREAM_SPLIT", 'b': "PLAIN"},
use_legacy_dataset=use_legacy_dataset)
# Check "PLAIN" for all columns.
_check_roundtrip(mixed_table, expected=mixed_table,
use_dictionary=False,
column_encoding="PLAIN",
use_legacy_dataset=use_legacy_dataset)
# Try to pass "BYTE_STREAM_SPLIT" column encoding for integer column 'b'.
# This should throw an error as it is only supports FLOAT and DOUBLE.
with pytest.raises(IOError,
match="BYTE_STREAM_SPLIT only supports FLOAT and"
" DOUBLE"):
_check_roundtrip(mixed_table, expected=mixed_table,
use_dictionary=False,
column_encoding={'b': "BYTE_STREAM_SPLIT"},
use_legacy_dataset=use_legacy_dataset)
# Try to pass "DELTA_BINARY_PACKED".
# This should throw an error as it is only supported for reading.
with pytest.raises(IOError,
match="Not yet implemented: Selected encoding is"
" not supported."):
_check_roundtrip(mixed_table, expected=mixed_table,
use_dictionary=False,
column_encoding={'b': "DELTA_BINARY_PACKED"},
use_legacy_dataset=use_legacy_dataset)
# Try to pass "RLE_DICTIONARY".
# This should throw an error as dictionary encoding is already used by
# default and not supported to be specified as "fallback" encoding
with pytest.raises(ValueError):
_check_roundtrip(mixed_table, expected=mixed_table,
use_dictionary=False,
column_encoding="RLE_DICTIONARY",
use_legacy_dataset=use_legacy_dataset)
# Try to pass unsupported encoding.
with pytest.raises(ValueError):
_check_roundtrip(mixed_table, expected=mixed_table,
use_dictionary=False,
column_encoding={'a': "MADE_UP_ENCODING"},
use_legacy_dataset=use_legacy_dataset)
# Try to pass column_encoding and use_dictionary.
# This should throw an error.
with pytest.raises(ValueError):
_check_roundtrip(mixed_table, expected=mixed_table,
use_dictionary=['b'],
column_encoding={'b': "PLAIN"},
use_legacy_dataset=use_legacy_dataset)
# Try to pass column_encoding and use_dictionary=True (default value).
# This should throw an error.
with pytest.raises(ValueError):
_check_roundtrip(mixed_table, expected=mixed_table,
column_encoding={'b': "PLAIN"},
use_legacy_dataset=use_legacy_dataset)
# Try to pass column_encoding and use_byte_stream_split on same column.
# This should throw an error.
with pytest.raises(ValueError):
_check_roundtrip(mixed_table, expected=mixed_table,
use_dictionary=False,
use_byte_stream_split=['a'],
column_encoding={'a': "RLE",
'b': "BYTE_STREAM_SPLIT"},
use_legacy_dataset=use_legacy_dataset)
# Try to pass column_encoding and use_byte_stream_split=True.
# This should throw an error.
with pytest.raises(ValueError):
_check_roundtrip(mixed_table, expected=mixed_table,
use_dictionary=False,
use_byte_stream_split=True,
column_encoding={'a': "RLE",
'b': "BYTE_STREAM_SPLIT"},
use_legacy_dataset=use_legacy_dataset)
# Try to pass column_encoding=True.
# This should throw an error.
with pytest.raises(TypeError):
_check_roundtrip(mixed_table, expected=mixed_table,
use_dictionary=False,
column_encoding=True,
use_legacy_dataset=use_legacy_dataset)
@parametrize_legacy_dataset
def test_compression_level(use_legacy_dataset):
arr = pa.array(list(map(int, range(1000))))
data = [arr, arr]
table = pa.Table.from_arrays(data, names=['a', 'b'])
# Check one compression level.
_check_roundtrip(table, expected=table, compression="gzip",
compression_level=1,
use_legacy_dataset=use_legacy_dataset)
# Check another one to make sure that compression_level=1 does not
# coincide with the default one in Arrow.
_check_roundtrip(table, expected=table, compression="gzip",
compression_level=5,
use_legacy_dataset=use_legacy_dataset)
# Check that the user can provide a compression per column
_check_roundtrip(table, expected=table,
compression={'a': "gzip", 'b': "snappy"},
use_legacy_dataset=use_legacy_dataset)
# Check that the user can provide a compression level per column
_check_roundtrip(table, expected=table, compression="gzip",
compression_level={'a': 2, 'b': 3},
use_legacy_dataset=use_legacy_dataset)
# Check if both LZ4 compressors are working
# (level < 3 -> fast, level >= 3 -> HC)
_check_roundtrip(table, expected=table, compression="lz4",
compression_level=1,
use_legacy_dataset=use_legacy_dataset)
_check_roundtrip(table, expected=table, compression="lz4",
compression_level=9,
use_legacy_dataset=use_legacy_dataset)
# Check that specifying a compression level for a codec which does allow
# specifying one, results into an error.
# Uncompressed, snappy and lzo do not support specifying a compression
# level.
# GZIP (zlib) allows for specifying a compression level but as of up
# to version 1.2.11 the valid range is [-1, 9].
invalid_combinations = [("snappy", 4), ("gzip", -1337),
("None", 444), ("lzo", 14)]
buf = io.BytesIO()
for (codec, level) in invalid_combinations:
with pytest.raises((ValueError, OSError)):
_write_table(table, buf, compression=codec,
compression_level=level)
def test_sanitized_spark_field_names():
a0 = pa.array([0, 1, 2, 3, 4])
name = 'prohib; ,\t{}'
table = pa.Table.from_arrays([a0], [name])
result = _roundtrip_table(table, write_table_kwargs={'flavor': 'spark'})
expected_name = 'prohib______'
assert result.schema[0].name == expected_name
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_multithreaded_read(use_legacy_dataset):
df = alltypes_sample(size=10000)
table = pa.Table.from_pandas(df)
buf = io.BytesIO()
_write_table(table, buf, compression='SNAPPY', version='2.6')
buf.seek(0)
table1 = _read_table(
buf, use_threads=True, use_legacy_dataset=use_legacy_dataset)
buf.seek(0)
table2 = _read_table(
buf, use_threads=False, use_legacy_dataset=use_legacy_dataset)
assert table1.equals(table2)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_min_chunksize(use_legacy_dataset):
data = pd.DataFrame([np.arange(4)], columns=['A', 'B', 'C', 'D'])
table = pa.Table.from_pandas(data.reset_index())
buf = io.BytesIO()
_write_table(table, buf, chunk_size=-1)
buf.seek(0)
result = _read_table(buf, use_legacy_dataset=use_legacy_dataset)
assert result.equals(table)
with pytest.raises(ValueError):
_write_table(table, buf, chunk_size=0)
@pytest.mark.pandas
def test_write_error_deletes_incomplete_file(tempdir):
# ARROW-1285
df = pd.DataFrame({'a': list('abc'),
'b': list(range(1, 4)),
'c': np.arange(3, 6).astype('u1'),
'd': np.arange(4.0, 7.0, dtype='float64'),
'e': [True, False, True],
'f': pd.Categorical(list('abc')),
'g': pd.date_range('20130101', periods=3),
'h': pd.date_range('20130101', periods=3,
tz='US/Eastern'),
'i': pd.date_range('20130101', periods=3, freq='ns')})
pdf = pa.Table.from_pandas(df)
filename = tempdir / 'tmp_file'
try:
_write_table(pdf, filename)
except pa.ArrowException:
pass
assert not filename.exists()
@parametrize_legacy_dataset
def test_read_non_existent_file(tempdir, use_legacy_dataset):
path = 'non-existent-file.parquet'
try:
pq.read_table(path, use_legacy_dataset=use_legacy_dataset)
except Exception as e:
assert path in e.args[0]
@parametrize_legacy_dataset
def test_read_table_doesnt_warn(datadir, use_legacy_dataset):
with pytest.warns(None) as record:
pq.read_table(datadir / 'v0.7.1.parquet',
use_legacy_dataset=use_legacy_dataset)
if use_legacy_dataset:
# FutureWarning: 'use_legacy_dataset=True'
assert len(record) == 1
else:
assert len(record) == 0
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_zlib_compression_bug(use_legacy_dataset):
# ARROW-3514: "zlib deflate failed, output buffer too small"
table = pa.Table.from_arrays([pa.array(['abc', 'def'])], ['some_col'])
f = io.BytesIO()
pq.write_table(table, f, compression='gzip')
f.seek(0)
roundtrip = pq.read_table(f, use_legacy_dataset=use_legacy_dataset)
tm.assert_frame_equal(roundtrip.to_pandas(), table.to_pandas())
@parametrize_legacy_dataset
def test_parquet_file_too_small(tempdir, use_legacy_dataset):
path = str(tempdir / "test.parquet")
# TODO(dataset) with datasets API it raises OSError instead
with pytest.raises((pa.ArrowInvalid, OSError),
match='size is 0 bytes'):
with open(path, 'wb') as f:
pass
pq.read_table(path, use_legacy_dataset=use_legacy_dataset)
with pytest.raises((pa.ArrowInvalid, OSError),
match='size is 4 bytes'):
with open(path, 'wb') as f:
f.write(b'ffff')
pq.read_table(path, use_legacy_dataset=use_legacy_dataset)
@pytest.mark.pandas
@pytest.mark.fastparquet
@pytest.mark.filterwarnings("ignore:RangeIndex:FutureWarning")
@pytest.mark.filterwarnings("ignore:tostring:DeprecationWarning:fastparquet")
def test_fastparquet_cross_compatibility(tempdir):
fp = pytest.importorskip('fastparquet')
df = pd.DataFrame(
{
"a": list("abc"),
"b": list(range(1, 4)),
"c": np.arange(4.0, 7.0, dtype="float64"),
"d": [True, False, True],
"e": pd.date_range("20130101", periods=3),
"f": pd.Categorical(["a", "b", "a"]),
# fastparquet writes list as BYTE_ARRAY JSON, so no roundtrip
# "g": [[1, 2], None, [1, 2, 3]],
}
)
table = pa.table(df)
# Arrow -> fastparquet
file_arrow = str(tempdir / "cross_compat_arrow.parquet")
pq.write_table(table, file_arrow, compression=None)
fp_file = fp.ParquetFile(file_arrow)
df_fp = fp_file.to_pandas()
tm.assert_frame_equal(df, df_fp)
# Fastparquet -> arrow
file_fastparquet = str(tempdir / "cross_compat_fastparquet.parquet")
fp.write(file_fastparquet, df)
table_fp = pq.read_pandas(file_fastparquet)
# for fastparquet written file, categoricals comes back as strings
# (no arrow schema in parquet metadata)
df['f'] = df['f'].astype(object)
tm.assert_frame_equal(table_fp.to_pandas(), df)
@parametrize_legacy_dataset
@pytest.mark.parametrize('array_factory', [
lambda: pa.array([0, None] * 10),
lambda: pa.array([0, None] * 10).dictionary_encode(),
lambda: pa.array(["", None] * 10),
lambda: pa.array(["", None] * 10).dictionary_encode(),
])
@pytest.mark.parametrize('use_dictionary', [False, True])
@pytest.mark.parametrize('read_dictionary', [False, True])
def test_buffer_contents(
array_factory, use_dictionary, read_dictionary, use_legacy_dataset
):
# Test that null values are deterministically initialized to zero
# after a roundtrip through Parquet.
# See ARROW-8006 and ARROW-8011.
orig_table = pa.Table.from_pydict({"col": array_factory()})
bio = io.BytesIO()
pq.write_table(orig_table, bio, use_dictionary=True)
bio.seek(0)
read_dictionary = ['col'] if read_dictionary else None
table = pq.read_table(bio, use_threads=False,
read_dictionary=read_dictionary,
use_legacy_dataset=use_legacy_dataset)
for col in table.columns:
[chunk] = col.chunks
buf = chunk.buffers()[1]
assert buf.to_pybytes() == buf.size * b"\0"
def test_parquet_compression_roundtrip(tempdir):
# ARROW-10480: ensure even with nonstandard Parquet file naming
# conventions, writing and then reading a file works. In
# particular, ensure that we don't automatically double-compress
# the stream due to auto-detecting the extension in the filename
table = pa.table([pa.array(range(4))], names=["ints"])
path = tempdir / "arrow-10480.pyarrow.gz"
pq.write_table(table, path, compression="GZIP")
result = pq.read_table(path)
assert result.equals(table)
def test_empty_row_groups(tempdir):
# ARROW-3020
table = pa.Table.from_arrays([pa.array([], type='int32')], ['f0'])
path = tempdir / 'empty_row_groups.parquet'
num_groups = 3
with pq.ParquetWriter(path, table.schema) as writer:
for i in range(num_groups):
writer.write_table(table)
reader = pq.ParquetFile(path)
assert reader.metadata.num_row_groups == num_groups
for i in range(num_groups):
assert reader.read_row_group(i).equals(table)
def test_reads_over_batch(tempdir):
data = [None] * (1 << 20)
data.append([1])
# Large list<int64> with mostly nones and one final
# value. This should force batched reads when
# reading back.
table = pa.Table.from_arrays([data], ['column'])
path = tempdir / 'arrow-11607.parquet'
pq.write_table(table, path)
table2 = pq.read_table(path)
assert table == table2
@pytest.mark.dataset
def test_permutation_of_column_order(tempdir):
# ARROW-2366
case = tempdir / "dataset_column_order_permutation"
case.mkdir(exist_ok=True)
data1 = pa.table([[1, 2, 3], [.1, .2, .3]], names=['a', 'b'])
pq.write_table(data1, case / "data1.parquet")
data2 = pa.table([[.4, .5, .6], [4, 5, 6]], names=['b', 'a'])
pq.write_table(data2, case / "data2.parquet")
table = pq.read_table(str(case))
table2 = pa.table([[1, 2, 3, 4, 5, 6],
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]],
names=['a', 'b'])
assert table == table2
def test_read_table_legacy_deprecated(tempdir):
# ARROW-15870
table = pa.table({'a': [1, 2, 3]})
path = tempdir / 'data.parquet'
pq.write_table(table, path)
with pytest.warns(
FutureWarning, match="Passing 'use_legacy_dataset=True'"
):
pq.read_table(path, use_legacy_dataset=True)
@@ -0,0 +1,114 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import pyarrow as pa
from pyarrow.tests.parquet.common import parametrize_legacy_dataset
try:
import pyarrow.parquet as pq
from pyarrow.tests.parquet.common import (_read_table,
_check_roundtrip)
except ImportError:
pq = None
try:
import pandas as pd
import pandas.testing as tm
from pyarrow.tests.parquet.common import _roundtrip_pandas_dataframe
except ImportError:
pd = tm = None
# Tests for ARROW-11497
_test_data_simple = [
{'items': [1, 2]},
{'items': [0]},
]
_test_data_complex = [
{'items': [{'name': 'elem1', 'value': '1'},
{'name': 'elem2', 'value': '2'}]},
{'items': [{'name': 'elem1', 'value': '0'}]},
]
parametrize_test_data = pytest.mark.parametrize(
"test_data", [_test_data_simple, _test_data_complex])
@pytest.mark.pandas
@parametrize_legacy_dataset
@parametrize_test_data
def test_write_compliant_nested_type_enable(tempdir,
use_legacy_dataset, test_data):
# prepare dataframe for testing
df = pd.DataFrame(data=test_data)
# verify that we can read/write pandas df with new flag
_roundtrip_pandas_dataframe(df,
write_kwargs={
'use_compliant_nested_type': True},
use_legacy_dataset=use_legacy_dataset)
# Write to a parquet file with compliant nested type
table = pa.Table.from_pandas(df, preserve_index=False)
path = str(tempdir / 'data.parquet')
with pq.ParquetWriter(path, table.schema,
use_compliant_nested_type=True,
version='2.6') as writer:
writer.write_table(table)
# Read back as a table
new_table = _read_table(path)
# Validate that "items" columns compliant to Parquet nested format
# Should be like this: list<element: struct<name: string, value: string>>
assert isinstance(new_table.schema.types[0], pa.ListType)
assert new_table.schema.types[0].value_field.name == 'element'
# Verify that the new table can be read/written correctly
_check_roundtrip(new_table,
use_legacy_dataset=use_legacy_dataset,
use_compliant_nested_type=True)
@pytest.mark.pandas
@parametrize_legacy_dataset
@parametrize_test_data
def test_write_compliant_nested_type_disable(tempdir,
use_legacy_dataset, test_data):
# prepare dataframe for testing
df = pd.DataFrame(data=test_data)
# verify that we can read/write with new flag disabled (default behaviour)
_roundtrip_pandas_dataframe(df, write_kwargs={},
use_legacy_dataset=use_legacy_dataset)
# Write to a parquet file while disabling compliant nested type
table = pa.Table.from_pandas(df, preserve_index=False)
path = str(tempdir / 'data.parquet')
with pq.ParquetWriter(path, table.schema, version='2.6') as writer:
writer.write_table(table)
new_table = _read_table(path)
# Validate that "items" columns is not compliant to Parquet nested format
# Should be like this: list<item: struct<name: string, value: string>>
assert isinstance(new_table.schema.types[0], pa.ListType)
assert new_table.schema.types[0].value_field.name == 'item'
# Verify that the new table can be read/written correctly
_check_roundtrip(new_table,
use_legacy_dataset=use_legacy_dataset,
use_compliant_nested_type=False)
@@ -0,0 +1,526 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import decimal
import io
import numpy as np
import pytest
import pyarrow as pa
from pyarrow.tests import util
from pyarrow.tests.parquet.common import (_check_roundtrip,
parametrize_legacy_dataset)
try:
import pyarrow.parquet as pq
from pyarrow.tests.parquet.common import _read_table, _write_table
except ImportError:
pq = None
try:
import pandas as pd
import pandas.testing as tm
from pyarrow.tests.pandas_examples import (dataframe_with_arrays,
dataframe_with_lists)
from pyarrow.tests.parquet.common import alltypes_sample
except ImportError:
pd = tm = None
# General roundtrip of data types
# -----------------------------------------------------------------------------
@pytest.mark.pandas
@parametrize_legacy_dataset
@pytest.mark.parametrize('chunk_size', [None, 1000])
def test_parquet_2_0_roundtrip(tempdir, chunk_size, use_legacy_dataset):
df = alltypes_sample(size=10000, categorical=True)
filename = tempdir / 'pandas_roundtrip.parquet'
arrow_table = pa.Table.from_pandas(df)
assert arrow_table.schema.pandas_metadata is not None
_write_table(arrow_table, filename, version='2.6',
coerce_timestamps='ms', chunk_size=chunk_size)
table_read = pq.read_pandas(
filename, use_legacy_dataset=use_legacy_dataset)
assert table_read.schema.pandas_metadata is not None
read_metadata = table_read.schema.metadata
assert arrow_table.schema.metadata == read_metadata
df_read = table_read.to_pandas()
tm.assert_frame_equal(df, df_read)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_parquet_1_0_roundtrip(tempdir, use_legacy_dataset):
size = 10000
np.random.seed(0)
df = pd.DataFrame({
'uint8': np.arange(size, dtype=np.uint8),
'uint16': np.arange(size, dtype=np.uint16),
'uint32': np.arange(size, dtype=np.uint32),
'uint64': np.arange(size, dtype=np.uint64),
'int8': np.arange(size, dtype=np.int16),
'int16': np.arange(size, dtype=np.int16),
'int32': np.arange(size, dtype=np.int32),
'int64': np.arange(size, dtype=np.int64),
'float32': np.arange(size, dtype=np.float32),
'float64': np.arange(size, dtype=np.float64),
'bool': np.random.randn(size) > 0,
'str': [str(x) for x in range(size)],
'str_with_nulls': [None] + [str(x) for x in range(size - 2)] + [None],
'empty_str': [''] * size
})
filename = tempdir / 'pandas_roundtrip.parquet'
arrow_table = pa.Table.from_pandas(df)
_write_table(arrow_table, filename, version='1.0')
table_read = _read_table(filename, use_legacy_dataset=use_legacy_dataset)
df_read = table_read.to_pandas()
# We pass uint32_t as int64_t if we write Parquet version 1.0
df['uint32'] = df['uint32'].values.astype(np.int64)
tm.assert_frame_equal(df, df_read)
# Dictionary
# -----------------------------------------------------------------------------
def _simple_table_write_read(table, use_legacy_dataset):
bio = pa.BufferOutputStream()
pq.write_table(table, bio)
contents = bio.getvalue()
return pq.read_table(
pa.BufferReader(contents), use_legacy_dataset=use_legacy_dataset
)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_direct_read_dictionary(use_legacy_dataset):
# ARROW-3325
repeats = 10
nunique = 5
data = [
[util.rands(10) for i in range(nunique)] * repeats,
]
table = pa.table(data, names=['f0'])
bio = pa.BufferOutputStream()
pq.write_table(table, bio)
contents = bio.getvalue()
result = pq.read_table(pa.BufferReader(contents),
read_dictionary=['f0'],
use_legacy_dataset=use_legacy_dataset)
# Compute dictionary-encoded subfield
expected = pa.table([table[0].dictionary_encode()], names=['f0'])
assert result.equals(expected)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_direct_read_dictionary_subfield(use_legacy_dataset):
repeats = 10
nunique = 5
data = [
[[util.rands(10)] for i in range(nunique)] * repeats,
]
table = pa.table(data, names=['f0'])
bio = pa.BufferOutputStream()
pq.write_table(table, bio)
contents = bio.getvalue()
result = pq.read_table(pa.BufferReader(contents),
read_dictionary=['f0.list.item'],
use_legacy_dataset=use_legacy_dataset)
arr = pa.array(data[0])
values_as_dict = arr.values.dictionary_encode()
inner_indices = values_as_dict.indices.cast('int32')
new_values = pa.DictionaryArray.from_arrays(inner_indices,
values_as_dict.dictionary)
offsets = pa.array(range(51), type='int32')
expected_arr = pa.ListArray.from_arrays(offsets, new_values)
expected = pa.table([expected_arr], names=['f0'])
assert result.equals(expected)
assert result[0].num_chunks == 1
@parametrize_legacy_dataset
def test_dictionary_array_automatically_read(use_legacy_dataset):
# ARROW-3246
# Make a large dictionary, a little over 4MB of data
dict_length = 4000
dict_values = pa.array([('x' * 1000 + '_{}'.format(i))
for i in range(dict_length)])
num_chunks = 10
chunk_size = 100
chunks = []
for i in range(num_chunks):
indices = np.random.randint(0, dict_length,
size=chunk_size).astype(np.int32)
chunks.append(pa.DictionaryArray.from_arrays(pa.array(indices),
dict_values))
table = pa.table([pa.chunked_array(chunks)], names=['f0'])
result = _simple_table_write_read(table, use_legacy_dataset)
assert result.equals(table)
# The only key in the metadata was the Arrow schema key
assert result.schema.metadata is None
# Decimal
# -----------------------------------------------------------------------------
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_decimal_roundtrip(tempdir, use_legacy_dataset):
num_values = 10
columns = {}
for precision in range(1, 39):
for scale in range(0, precision + 1):
with util.random_seed(0):
random_decimal_values = [
util.randdecimal(precision, scale)
for _ in range(num_values)
]
column_name = ('dec_precision_{:d}_scale_{:d}'
.format(precision, scale))
columns[column_name] = random_decimal_values
expected = pd.DataFrame(columns)
filename = tempdir / 'decimals.parquet'
string_filename = str(filename)
table = pa.Table.from_pandas(expected)
_write_table(table, string_filename)
result_table = _read_table(
string_filename, use_legacy_dataset=use_legacy_dataset)
result = result_table.to_pandas()
tm.assert_frame_equal(result, expected)
@pytest.mark.pandas
@pytest.mark.xfail(
raises=OSError, reason='Parquet does not support negative scale'
)
def test_decimal_roundtrip_negative_scale(tempdir):
expected = pd.DataFrame({'decimal_num': [decimal.Decimal('1.23E4')]})
filename = tempdir / 'decimals.parquet'
string_filename = str(filename)
t = pa.Table.from_pandas(expected)
_write_table(t, string_filename)
result_table = _read_table(string_filename)
result = result_table.to_pandas()
tm.assert_frame_equal(result, expected)
# List types
# -----------------------------------------------------------------------------
@parametrize_legacy_dataset
@pytest.mark.parametrize('dtype', [int, float])
def test_single_pylist_column_roundtrip(tempdir, dtype, use_legacy_dataset):
filename = tempdir / 'single_{}_column.parquet'.format(dtype.__name__)
data = [pa.array(list(map(dtype, range(5))))]
table = pa.Table.from_arrays(data, names=['a'])
_write_table(table, filename)
table_read = _read_table(filename, use_legacy_dataset=use_legacy_dataset)
for i in range(table.num_columns):
col_written = table[i]
col_read = table_read[i]
assert table.field(i).name == table_read.field(i).name
assert col_read.num_chunks == 1
data_written = col_written.chunk(0)
data_read = col_read.chunk(0)
assert data_written.equals(data_read)
@parametrize_legacy_dataset
def test_empty_lists_table_roundtrip(use_legacy_dataset):
# ARROW-2744: Shouldn't crash when writing an array of empty lists
arr = pa.array([[], []], type=pa.list_(pa.int32()))
table = pa.Table.from_arrays([arr], ["A"])
_check_roundtrip(table, use_legacy_dataset=use_legacy_dataset)
@parametrize_legacy_dataset
def test_nested_list_nonnullable_roundtrip_bug(use_legacy_dataset):
# Reproduce failure in ARROW-5630
typ = pa.list_(pa.field("item", pa.float32(), False))
num_rows = 10000
t = pa.table([
pa.array(([[0] * ((i + 5) % 10) for i in range(0, 10)] *
(num_rows // 10)), type=typ)
], ['a'])
_check_roundtrip(
t, data_page_size=4096, use_legacy_dataset=use_legacy_dataset)
@parametrize_legacy_dataset
def test_nested_list_struct_multiple_batches_roundtrip(
tempdir, use_legacy_dataset
):
# Reproduce failure in ARROW-11024
data = [[{'x': 'abc', 'y': 'abc'}]]*100 + [[{'x': 'abc', 'y': 'gcb'}]]*100
table = pa.table([pa.array(data)], names=['column'])
_check_roundtrip(
table, row_group_size=20, use_legacy_dataset=use_legacy_dataset)
# Reproduce failure in ARROW-11069 (plain non-nested structs with strings)
data = pa.array(
[{'a': '1', 'b': '2'}, {'a': '3', 'b': '4'}, {'a': '5', 'b': '6'}]*10
)
table = pa.table({'column': data})
_check_roundtrip(
table, row_group_size=10, use_legacy_dataset=use_legacy_dataset)
def test_writing_empty_lists():
# ARROW-2591: [Python] Segmentation fault issue in pq.write_table
arr1 = pa.array([[], []], pa.list_(pa.int32()))
table = pa.Table.from_arrays([arr1], ['list(int32)'])
_check_roundtrip(table)
@pytest.mark.pandas
def test_column_of_arrays(tempdir):
df, schema = dataframe_with_arrays()
filename = tempdir / 'pandas_roundtrip.parquet'
arrow_table = pa.Table.from_pandas(df, schema=schema)
_write_table(arrow_table, filename, version='2.6', coerce_timestamps='ms')
table_read = _read_table(filename)
df_read = table_read.to_pandas()
tm.assert_frame_equal(df, df_read)
@pytest.mark.pandas
def test_column_of_lists(tempdir):
df, schema = dataframe_with_lists(parquet_compatible=True)
filename = tempdir / 'pandas_roundtrip.parquet'
arrow_table = pa.Table.from_pandas(df, schema=schema)
_write_table(arrow_table, filename, version='2.6')
table_read = _read_table(filename)
df_read = table_read.to_pandas()
tm.assert_frame_equal(df, df_read)
def test_large_list_records():
# This was fixed in PARQUET-1100
list_lengths = np.random.randint(0, 500, size=50)
list_lengths[::10] = 0
list_values = [list(map(int, np.random.randint(0, 100, size=x)))
if i % 8 else None
for i, x in enumerate(list_lengths)]
a1 = pa.array(list_values)
table = pa.Table.from_arrays([a1], ['int_lists'])
_check_roundtrip(table)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_parquet_nested_convenience(tempdir, use_legacy_dataset):
# ARROW-1684
df = pd.DataFrame({
'a': [[1, 2, 3], None, [4, 5], []],
'b': [[1.], None, None, [6., 7.]],
})
path = str(tempdir / 'nested_convenience.parquet')
table = pa.Table.from_pandas(df, preserve_index=False)
_write_table(table, path)
read = pq.read_table(
path, columns=['a'], use_legacy_dataset=use_legacy_dataset)
tm.assert_frame_equal(read.to_pandas(), df[['a']])
read = pq.read_table(
path, columns=['a', 'b'], use_legacy_dataset=use_legacy_dataset)
tm.assert_frame_equal(read.to_pandas(), df)
# Binary
# -----------------------------------------------------------------------------
def test_fixed_size_binary():
t0 = pa.binary(10)
data = [b'fooooooooo', None, b'barooooooo', b'quxooooooo']
a0 = pa.array(data, type=t0)
table = pa.Table.from_arrays([a0],
['binary[10]'])
_check_roundtrip(table)
# Large types
# -----------------------------------------------------------------------------
@pytest.mark.slow
@pytest.mark.large_memory
def test_large_table_int32_overflow():
size = np.iinfo('int32').max + 1
arr = np.ones(size, dtype='uint8')
parr = pa.array(arr, type=pa.uint8())
table = pa.Table.from_arrays([parr], names=['one'])
f = io.BytesIO()
_write_table(table, f)
def _simple_table_roundtrip(table, use_legacy_dataset=False, **write_kwargs):
stream = pa.BufferOutputStream()
_write_table(table, stream, **write_kwargs)
buf = stream.getvalue()
return _read_table(buf, use_legacy_dataset=use_legacy_dataset)
@pytest.mark.slow
@pytest.mark.large_memory
@parametrize_legacy_dataset
def test_byte_array_exactly_2gb(use_legacy_dataset):
# Test edge case reported in ARROW-3762
val = b'x' * (1 << 10)
base = pa.array([val] * ((1 << 21) - 1))
cases = [
[b'x' * 1023], # 2^31 - 1
[b'x' * 1024], # 2^31
[b'x' * 1025] # 2^31 + 1
]
for case in cases:
values = pa.chunked_array([base, pa.array(case)])
t = pa.table([values], names=['f0'])
result = _simple_table_roundtrip(
t, use_legacy_dataset=use_legacy_dataset, use_dictionary=False)
assert t.equals(result)
@pytest.mark.slow
@pytest.mark.pandas
@pytest.mark.large_memory
@parametrize_legacy_dataset
def test_binary_array_overflow_to_chunked(use_legacy_dataset):
# ARROW-3762
# 2^31 + 1 bytes
values = [b'x'] + [
b'x' * (1 << 20)
] * 2 * (1 << 10)
df = pd.DataFrame({'byte_col': values})
tbl = pa.Table.from_pandas(df, preserve_index=False)
read_tbl = _simple_table_roundtrip(
tbl, use_legacy_dataset=use_legacy_dataset)
col0_data = read_tbl[0]
assert isinstance(col0_data, pa.ChunkedArray)
# Split up into 2GB chunks
assert col0_data.num_chunks == 2
assert tbl.equals(read_tbl)
@pytest.mark.slow
@pytest.mark.pandas
@pytest.mark.large_memory
@parametrize_legacy_dataset
def test_list_of_binary_large_cell(use_legacy_dataset):
# ARROW-4688
data = []
# TODO(wesm): handle chunked children
# 2^31 - 1 bytes in a single cell
# data.append([b'x' * (1 << 20)] * 2047 + [b'x' * ((1 << 20) - 1)])
# A little under 2GB in cell each containing approximately 10MB each
data.extend([[b'x' * 1000000] * 10] * 214)
arr = pa.array(data)
table = pa.Table.from_arrays([arr], ['chunky_cells'])
read_table = _simple_table_roundtrip(
table, use_legacy_dataset=use_legacy_dataset)
assert table.equals(read_table)
def test_large_binary():
data = [b'foo', b'bar'] * 50
for type in [pa.large_binary(), pa.large_string()]:
arr = pa.array(data, type=type)
table = pa.Table.from_arrays([arr], names=['strs'])
for use_dictionary in [False, True]:
_check_roundtrip(table, use_dictionary=use_dictionary)
@pytest.mark.slow
@pytest.mark.large_memory
def test_large_binary_huge():
s = b'xy' * 997
data = [s] * ((1 << 33) // len(s))
for type in [pa.large_binary(), pa.large_string()]:
arr = pa.array(data, type=type)
table = pa.Table.from_arrays([arr], names=['strs'])
for use_dictionary in [False, True]:
_check_roundtrip(table, use_dictionary=use_dictionary)
del arr, table
@pytest.mark.large_memory
def test_large_binary_overflow():
s = b'x' * (1 << 31)
arr = pa.array([s], type=pa.large_binary())
table = pa.Table.from_arrays([arr], names=['strs'])
for use_dictionary in [False, True]:
writer = pa.BufferOutputStream()
with pytest.raises(
pa.ArrowInvalid,
match="Parquet cannot store strings with size 2GB or more"):
_write_table(table, writer, use_dictionary=use_dictionary)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,446 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
import io
import numpy as np
import pytest
import pyarrow as pa
from pyarrow.tests.parquet.common import (
_check_roundtrip, parametrize_legacy_dataset)
try:
import pyarrow.parquet as pq
from pyarrow.tests.parquet.common import _read_table, _write_table
except ImportError:
pq = None
try:
import pandas as pd
import pandas.testing as tm
from pyarrow.tests.parquet.common import _roundtrip_pandas_dataframe
except ImportError:
pd = tm = None
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_pandas_parquet_datetime_tz(use_legacy_dataset):
s = pd.Series([datetime.datetime(2017, 9, 6)])
s = s.dt.tz_localize('utc')
s.index = s
# Both a column and an index to hit both use cases
df = pd.DataFrame({'tz_aware': s,
'tz_eastern': s.dt.tz_convert('US/Eastern')},
index=s)
f = io.BytesIO()
arrow_table = pa.Table.from_pandas(df)
_write_table(arrow_table, f, coerce_timestamps='ms')
f.seek(0)
table_read = pq.read_pandas(f, use_legacy_dataset=use_legacy_dataset)
df_read = table_read.to_pandas()
tm.assert_frame_equal(df, df_read)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_datetime_timezone_tzinfo(use_legacy_dataset):
value = datetime.datetime(2018, 1, 1, 1, 23, 45,
tzinfo=datetime.timezone.utc)
df = pd.DataFrame({'foo': [value]})
_roundtrip_pandas_dataframe(
df, write_kwargs={}, use_legacy_dataset=use_legacy_dataset)
@pytest.mark.pandas
def test_coerce_timestamps(tempdir):
from collections import OrderedDict
# ARROW-622
arrays = OrderedDict()
fields = [pa.field('datetime64',
pa.list_(pa.timestamp('ms')))]
arrays['datetime64'] = [
np.array(['2007-07-13T01:23:34.123456789',
None,
'2010-08-13T05:46:57.437699912'],
dtype='datetime64[ms]'),
None,
None,
np.array(['2007-07-13T02',
None,
'2010-08-13T05:46:57.437699912'],
dtype='datetime64[ms]'),
]
df = pd.DataFrame(arrays)
schema = pa.schema(fields)
filename = tempdir / 'pandas_roundtrip.parquet'
arrow_table = pa.Table.from_pandas(df, schema=schema)
_write_table(arrow_table, filename, version='2.6', coerce_timestamps='us')
table_read = _read_table(filename)
df_read = table_read.to_pandas()
df_expected = df.copy()
for i, x in enumerate(df_expected['datetime64']):
if isinstance(x, np.ndarray):
df_expected['datetime64'][i] = x.astype('M8[us]')
tm.assert_frame_equal(df_expected, df_read)
with pytest.raises(ValueError):
_write_table(arrow_table, filename, version='2.6',
coerce_timestamps='unknown')
@pytest.mark.pandas
def test_coerce_timestamps_truncated(tempdir):
"""
ARROW-2555: Test that we can truncate timestamps when coercing if
explicitly allowed.
"""
dt_us = datetime.datetime(year=2017, month=1, day=1, hour=1, minute=1,
second=1, microsecond=1)
dt_ms = datetime.datetime(year=2017, month=1, day=1, hour=1, minute=1,
second=1)
fields_us = [pa.field('datetime64', pa.timestamp('us'))]
arrays_us = {'datetime64': [dt_us, dt_ms]}
df_us = pd.DataFrame(arrays_us)
schema_us = pa.schema(fields_us)
filename = tempdir / 'pandas_truncated.parquet'
table_us = pa.Table.from_pandas(df_us, schema=schema_us)
_write_table(table_us, filename, version='2.6', coerce_timestamps='ms',
allow_truncated_timestamps=True)
table_ms = _read_table(filename)
df_ms = table_ms.to_pandas()
arrays_expected = {'datetime64': [dt_ms, dt_ms]}
df_expected = pd.DataFrame(arrays_expected)
tm.assert_frame_equal(df_expected, df_ms)
@pytest.mark.pandas
def test_date_time_types(tempdir):
t1 = pa.date32()
data1 = np.array([17259, 17260, 17261], dtype='int32')
a1 = pa.array(data1, type=t1)
t2 = pa.date64()
data2 = data1.astype('int64') * 86400000
a2 = pa.array(data2, type=t2)
t3 = pa.timestamp('us')
start = pd.Timestamp('2001-01-01').value / 1000
data3 = np.array([start, start + 1, start + 2], dtype='int64')
a3 = pa.array(data3, type=t3)
t4 = pa.time32('ms')
data4 = np.arange(3, dtype='i4')
a4 = pa.array(data4, type=t4)
t5 = pa.time64('us')
a5 = pa.array(data4.astype('int64'), type=t5)
t6 = pa.time32('s')
a6 = pa.array(data4, type=t6)
ex_t6 = pa.time32('ms')
ex_a6 = pa.array(data4 * 1000, type=ex_t6)
t7 = pa.timestamp('ns')
start = pd.Timestamp('2001-01-01').value
data7 = np.array([start, start + 1000, start + 2000],
dtype='int64')
a7 = pa.array(data7, type=t7)
table = pa.Table.from_arrays([a1, a2, a3, a4, a5, a6, a7],
['date32', 'date64', 'timestamp[us]',
'time32[s]', 'time64[us]',
'time32_from64[s]',
'timestamp[ns]'])
# date64 as date32
# time32[s] to time32[ms]
expected = pa.Table.from_arrays([a1, a1, a3, a4, a5, ex_a6, a7],
['date32', 'date64', 'timestamp[us]',
'time32[s]', 'time64[us]',
'time32_from64[s]',
'timestamp[ns]'])
_check_roundtrip(table, expected=expected, version='2.6')
t0 = pa.timestamp('ms')
data0 = np.arange(4, dtype='int64')
a0 = pa.array(data0, type=t0)
t1 = pa.timestamp('us')
data1 = np.arange(4, dtype='int64')
a1 = pa.array(data1, type=t1)
t2 = pa.timestamp('ns')
data2 = np.arange(4, dtype='int64')
a2 = pa.array(data2, type=t2)
table = pa.Table.from_arrays([a0, a1, a2],
['ts[ms]', 'ts[us]', 'ts[ns]'])
expected = pa.Table.from_arrays([a0, a1, a2],
['ts[ms]', 'ts[us]', 'ts[ns]'])
# int64 for all timestamps supported by default
filename = tempdir / 'int64_timestamps.parquet'
_write_table(table, filename, version='2.6')
parquet_schema = pq.ParquetFile(filename).schema
for i in range(3):
assert parquet_schema.column(i).physical_type == 'INT64'
read_table = _read_table(filename)
assert read_table.equals(expected)
t0_ns = pa.timestamp('ns')
data0_ns = np.array(data0 * 1000000, dtype='int64')
a0_ns = pa.array(data0_ns, type=t0_ns)
t1_ns = pa.timestamp('ns')
data1_ns = np.array(data1 * 1000, dtype='int64')
a1_ns = pa.array(data1_ns, type=t1_ns)
expected = pa.Table.from_arrays([a0_ns, a1_ns, a2],
['ts[ms]', 'ts[us]', 'ts[ns]'])
# int96 nanosecond timestamps produced upon request
filename = tempdir / 'explicit_int96_timestamps.parquet'
_write_table(table, filename, version='2.6',
use_deprecated_int96_timestamps=True)
parquet_schema = pq.ParquetFile(filename).schema
for i in range(3):
assert parquet_schema.column(i).physical_type == 'INT96'
read_table = _read_table(filename)
assert read_table.equals(expected)
# int96 nanosecond timestamps implied by flavor 'spark'
filename = tempdir / 'spark_int96_timestamps.parquet'
_write_table(table, filename, version='2.6',
flavor='spark')
parquet_schema = pq.ParquetFile(filename).schema
for i in range(3):
assert parquet_schema.column(i).physical_type == 'INT96'
read_table = _read_table(filename)
assert read_table.equals(expected)
@pytest.mark.pandas
@pytest.mark.parametrize('unit', ['s', 'ms', 'us', 'ns'])
def test_coerce_int96_timestamp_unit(unit):
i_s = pd.Timestamp('2010-01-01').value / 1000000000 # := 1262304000
d_s = np.arange(i_s, i_s + 10, 1, dtype='int64')
d_ms = d_s * 1000
d_us = d_ms * 1000
d_ns = d_us * 1000
a_s = pa.array(d_s, type=pa.timestamp('s'))
a_ms = pa.array(d_ms, type=pa.timestamp('ms'))
a_us = pa.array(d_us, type=pa.timestamp('us'))
a_ns = pa.array(d_ns, type=pa.timestamp('ns'))
arrays = {"s": a_s, "ms": a_ms, "us": a_us, "ns": a_ns}
names = ['ts_s', 'ts_ms', 'ts_us', 'ts_ns']
table = pa.Table.from_arrays([a_s, a_ms, a_us, a_ns], names)
# For either Parquet version, coercing to nanoseconds is allowed
# if Int96 storage is used
expected = pa.Table.from_arrays([arrays.get(unit)]*4, names)
read_table_kwargs = {"coerce_int96_timestamp_unit": unit}
_check_roundtrip(table, expected,
read_table_kwargs=read_table_kwargs,
use_deprecated_int96_timestamps=True)
_check_roundtrip(table, expected, version='2.6',
read_table_kwargs=read_table_kwargs,
use_deprecated_int96_timestamps=True)
@pytest.mark.pandas
@pytest.mark.parametrize('pq_reader_method', ['ParquetFile', 'read_table'])
def test_coerce_int96_timestamp_overflow(pq_reader_method, tempdir):
def get_table(pq_reader_method, filename, **kwargs):
if pq_reader_method == "ParquetFile":
return pq.ParquetFile(filename, **kwargs).read()
elif pq_reader_method == "read_table":
return pq.read_table(filename, **kwargs)
# Recreating the initial JIRA issue referenced in ARROW-12096
oob_dts = [
datetime.datetime(1000, 1, 1),
datetime.datetime(2000, 1, 1),
datetime.datetime(3000, 1, 1)
]
df = pd.DataFrame({"a": oob_dts})
table = pa.table(df)
filename = tempdir / "test_round_trip_overflow.parquet"
pq.write_table(table, filename, use_deprecated_int96_timestamps=True,
version="1.0")
# with the default resolution of ns, we get wrong values for INT96
# that are out of bounds for nanosecond range
tab_error = get_table(pq_reader_method, filename)
assert tab_error["a"].to_pylist() != oob_dts
# avoid this overflow by specifying the resolution to use for INT96 values
tab_correct = get_table(
pq_reader_method, filename, coerce_int96_timestamp_unit="s"
)
df_correct = tab_correct.to_pandas(timestamp_as_object=True)
tm.assert_frame_equal(df, df_correct)
def test_timestamp_restore_timezone():
# ARROW-5888, restore timezone from serialized metadata
ty = pa.timestamp('ms', tz='America/New_York')
arr = pa.array([1, 2, 3], type=ty)
t = pa.table([arr], names=['f0'])
_check_roundtrip(t)
def test_timestamp_restore_timezone_nanosecond():
# ARROW-9634, also restore timezone for nanosecond data that get stored
# as microseconds in the parquet file
ty = pa.timestamp('ns', tz='America/New_York')
arr = pa.array([1000, 2000, 3000], type=ty)
table = pa.table([arr], names=['f0'])
ty_us = pa.timestamp('us', tz='America/New_York')
expected = pa.table([arr.cast(ty_us)], names=['f0'])
_check_roundtrip(table, expected=expected)
@pytest.mark.pandas
def test_list_of_datetime_time_roundtrip():
# ARROW-4135
times = pd.to_datetime(['09:00', '09:30', '10:00', '10:30', '11:00',
'11:30', '12:00'])
df = pd.DataFrame({'time': [times.time]})
_roundtrip_pandas_dataframe(df, write_kwargs={})
@pytest.mark.pandas
def test_parquet_version_timestamp_differences():
i_s = pd.Timestamp('2010-01-01').value / 1000000000 # := 1262304000
d_s = np.arange(i_s, i_s + 10, 1, dtype='int64')
d_ms = d_s * 1000
d_us = d_ms * 1000
d_ns = d_us * 1000
a_s = pa.array(d_s, type=pa.timestamp('s'))
a_ms = pa.array(d_ms, type=pa.timestamp('ms'))
a_us = pa.array(d_us, type=pa.timestamp('us'))
a_ns = pa.array(d_ns, type=pa.timestamp('ns'))
names = ['ts:s', 'ts:ms', 'ts:us', 'ts:ns']
table = pa.Table.from_arrays([a_s, a_ms, a_us, a_ns], names)
# Using Parquet version 1.0, seconds should be coerced to milliseconds
# and nanoseconds should be coerced to microseconds by default
expected = pa.Table.from_arrays([a_ms, a_ms, a_us, a_us], names)
_check_roundtrip(table, expected)
# Using Parquet version 2.0, seconds should be coerced to milliseconds
# and nanoseconds should be retained by default
expected = pa.Table.from_arrays([a_ms, a_ms, a_us, a_ns], names)
_check_roundtrip(table, expected, version='2.6')
# Using Parquet version 1.0, coercing to milliseconds or microseconds
# is allowed
expected = pa.Table.from_arrays([a_ms, a_ms, a_ms, a_ms], names)
_check_roundtrip(table, expected, coerce_timestamps='ms')
# Using Parquet version 2.0, coercing to milliseconds or microseconds
# is allowed
expected = pa.Table.from_arrays([a_us, a_us, a_us, a_us], names)
_check_roundtrip(table, expected, version='2.6', coerce_timestamps='us')
# TODO: after pyarrow allows coerce_timestamps='ns', tests like the
# following should pass ...
# Using Parquet version 1.0, coercing to nanoseconds is not allowed
# expected = None
# with pytest.raises(NotImplementedError):
# _roundtrip_table(table, coerce_timestamps='ns')
# Using Parquet version 2.0, coercing to nanoseconds is allowed
# expected = pa.Table.from_arrays([a_ns, a_ns, a_ns, a_ns], names)
# _check_roundtrip(table, expected, version='2.6', coerce_timestamps='ns')
# For either Parquet version, coercing to nanoseconds is allowed
# if Int96 storage is used
expected = pa.Table.from_arrays([a_ns, a_ns, a_ns, a_ns], names)
_check_roundtrip(table, expected,
use_deprecated_int96_timestamps=True)
_check_roundtrip(table, expected, version='2.6',
use_deprecated_int96_timestamps=True)
@pytest.mark.pandas
def test_noncoerced_nanoseconds_written_without_exception(tempdir):
# ARROW-1957: the Parquet version 2.0 writer preserves Arrow
# nanosecond timestamps by default
n = 9
df = pd.DataFrame({'x': range(n)},
index=pd.date_range('2017-01-01', freq='1n', periods=n))
tb = pa.Table.from_pandas(df)
filename = tempdir / 'written.parquet'
try:
pq.write_table(tb, filename, version='2.6')
except Exception:
pass
assert filename.exists()
recovered_table = pq.read_table(filename)
assert tb.equals(recovered_table)
# Loss of data through coercion (without explicit override) still an error
filename = tempdir / 'not_written.parquet'
with pytest.raises(ValueError):
pq.write_table(tb, filename, coerce_timestamps='ms', version='2.6')
def test_duration_type():
# ARROW-6780
arrays = [pa.array([0, 1, 2, 3], type=pa.duration(unit))
for unit in ["s", "ms", "us", "ns"]]
table = pa.Table.from_arrays(arrays, ["d[s]", "d[ms]", "d[us]", "d[ns]"])
_check_roundtrip(table)
@@ -0,0 +1,530 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from datetime import timedelta
import pyarrow as pa
try:
import pyarrow.parquet as pq
import pyarrow.parquet.encryption as pe
except ImportError:
pq = None
pe = None
else:
from pyarrow.tests.parquet.encryption import (
InMemoryKmsClient, verify_file_encrypted)
PARQUET_NAME = 'encrypted_table.in_mem.parquet'
FOOTER_KEY = b"0123456789112345"
FOOTER_KEY_NAME = "footer_key"
COL_KEY = b"1234567890123450"
COL_KEY_NAME = "col_key"
# Marks all of the tests in this module
# Ignore these with pytest ... -m 'not parquet_encryption'
pytestmark = pytest.mark.parquet_encryption
@pytest.fixture(scope='module')
def data_table():
data_table = pa.Table.from_pydict({
'a': pa.array([1, 2, 3]),
'b': pa.array(['a', 'b', 'c']),
'c': pa.array(['x', 'y', 'z'])
})
return data_table
@pytest.fixture(scope='module')
def basic_encryption_config():
basic_encryption_config = pe.EncryptionConfiguration(
footer_key=FOOTER_KEY_NAME,
column_keys={
COL_KEY_NAME: ["a", "b"],
})
return basic_encryption_config
def test_encrypted_parquet_write_read(tempdir, data_table):
"""Write an encrypted parquet, verify it's encrypted, and then read it."""
path = tempdir / PARQUET_NAME
# Encrypt the footer with the footer key,
# encrypt column `a` and column `b` with another key,
# keep `c` plaintext
encryption_config = pe.EncryptionConfiguration(
footer_key=FOOTER_KEY_NAME,
column_keys={
COL_KEY_NAME: ["a", "b"],
},
encryption_algorithm="AES_GCM_V1",
cache_lifetime=timedelta(minutes=5.0),
data_key_length_bits=256)
kms_connection_config = pe.KmsConnectionConfig(
custom_kms_conf={
FOOTER_KEY_NAME: FOOTER_KEY.decode("UTF-8"),
COL_KEY_NAME: COL_KEY.decode("UTF-8"),
}
)
def kms_factory(kms_connection_configuration):
return InMemoryKmsClient(kms_connection_configuration)
crypto_factory = pe.CryptoFactory(kms_factory)
# Write with encryption properties
write_encrypted_parquet(path, data_table, encryption_config,
kms_connection_config, crypto_factory)
verify_file_encrypted(path)
# Read with decryption properties
decryption_config = pe.DecryptionConfiguration(
cache_lifetime=timedelta(minutes=5.0))
result_table = read_encrypted_parquet(
path, decryption_config, kms_connection_config, crypto_factory)
assert data_table.equals(result_table)
def write_encrypted_parquet(path, table, encryption_config,
kms_connection_config, crypto_factory):
file_encryption_properties = crypto_factory.file_encryption_properties(
kms_connection_config, encryption_config)
assert(file_encryption_properties is not None)
with pq.ParquetWriter(
path, table.schema,
encryption_properties=file_encryption_properties) as writer:
writer.write_table(table)
def read_encrypted_parquet(path, decryption_config,
kms_connection_config, crypto_factory):
file_decryption_properties = crypto_factory.file_decryption_properties(
kms_connection_config, decryption_config)
assert(file_decryption_properties is not None)
meta = pq.read_metadata(
path, decryption_properties=file_decryption_properties)
assert(meta.num_columns == 3)
schema = pq.read_schema(
path, decryption_properties=file_decryption_properties)
assert(len(schema.names) == 3)
result = pq.ParquetFile(
path, decryption_properties=file_decryption_properties)
return result.read(use_threads=False)
def test_encrypted_parquet_write_read_wrong_key(tempdir, data_table):
"""Write an encrypted parquet, verify it's encrypted,
and then read it using wrong keys."""
path = tempdir / PARQUET_NAME
# Encrypt the footer with the footer key,
# encrypt column `a` and column `b` with another key,
# keep `c` plaintext
encryption_config = pe.EncryptionConfiguration(
footer_key=FOOTER_KEY_NAME,
column_keys={
COL_KEY_NAME: ["a", "b"],
},
encryption_algorithm="AES_GCM_V1",
cache_lifetime=timedelta(minutes=5.0),
data_key_length_bits=256)
kms_connection_config = pe.KmsConnectionConfig(
custom_kms_conf={
FOOTER_KEY_NAME: FOOTER_KEY.decode("UTF-8"),
COL_KEY_NAME: COL_KEY.decode("UTF-8"),
}
)
def kms_factory(kms_connection_configuration):
return InMemoryKmsClient(kms_connection_configuration)
crypto_factory = pe.CryptoFactory(kms_factory)
# Write with encryption properties
write_encrypted_parquet(path, data_table, encryption_config,
kms_connection_config, crypto_factory)
verify_file_encrypted(path)
# Read with decryption properties
wrong_kms_connection_config = pe.KmsConnectionConfig(
custom_kms_conf={
# Wrong keys - mixup in names
FOOTER_KEY_NAME: COL_KEY.decode("UTF-8"),
COL_KEY_NAME: FOOTER_KEY.decode("UTF-8"),
}
)
decryption_config = pe.DecryptionConfiguration(
cache_lifetime=timedelta(minutes=5.0))
with pytest.raises(ValueError, match=r"Incorrect master key used"):
read_encrypted_parquet(
path, decryption_config, wrong_kms_connection_config,
crypto_factory)
def test_encrypted_parquet_read_no_decryption_config(tempdir, data_table):
"""Write an encrypted parquet, verify it's encrypted,
but then try to read it without decryption properties."""
test_encrypted_parquet_write_read(tempdir, data_table)
# Read without decryption properties
with pytest.raises(IOError, match=r"no decryption"):
pq.ParquetFile(tempdir / PARQUET_NAME).read()
def test_encrypted_parquet_read_metadata_no_decryption_config(
tempdir, data_table):
"""Write an encrypted parquet, verify it's encrypted,
but then try to read its metadata without decryption properties."""
test_encrypted_parquet_write_read(tempdir, data_table)
# Read metadata without decryption properties
with pytest.raises(IOError, match=r"no decryption"):
pq.read_metadata(tempdir / PARQUET_NAME)
def test_encrypted_parquet_read_schema_no_decryption_config(
tempdir, data_table):
"""Write an encrypted parquet, verify it's encrypted,
but then try to read its schema without decryption properties."""
test_encrypted_parquet_write_read(tempdir, data_table)
with pytest.raises(IOError, match=r"no decryption"):
pq.read_schema(tempdir / PARQUET_NAME)
def test_encrypted_parquet_write_no_col_key(tempdir, data_table):
"""Write an encrypted parquet, but give only footer key,
without column key."""
path = tempdir / 'encrypted_table_no_col_key.in_mem.parquet'
# Encrypt the footer with the footer key
encryption_config = pe.EncryptionConfiguration(
footer_key=FOOTER_KEY_NAME)
kms_connection_config = pe.KmsConnectionConfig(
custom_kms_conf={
FOOTER_KEY_NAME: FOOTER_KEY.decode("UTF-8"),
COL_KEY_NAME: COL_KEY.decode("UTF-8"),
}
)
def kms_factory(kms_connection_configuration):
return InMemoryKmsClient(kms_connection_configuration)
crypto_factory = pe.CryptoFactory(kms_factory)
with pytest.raises(OSError,
match="Either column_keys or uniform_encryption "
"must be set"):
# Write with encryption properties
write_encrypted_parquet(path, data_table, encryption_config,
kms_connection_config, crypto_factory)
def test_encrypted_parquet_write_kms_error(tempdir, data_table,
basic_encryption_config):
"""Write an encrypted parquet, but raise KeyError in KmsClient."""
path = tempdir / 'encrypted_table_kms_error.in_mem.parquet'
encryption_config = basic_encryption_config
# Empty master_keys_map
kms_connection_config = pe.KmsConnectionConfig()
def kms_factory(kms_connection_configuration):
# Empty master keys map will cause KeyError to be raised
# on wrap/unwrap calls
return InMemoryKmsClient(kms_connection_configuration)
crypto_factory = pe.CryptoFactory(kms_factory)
with pytest.raises(KeyError, match="footer_key"):
# Write with encryption properties
write_encrypted_parquet(path, data_table, encryption_config,
kms_connection_config, crypto_factory)
def test_encrypted_parquet_write_kms_specific_error(tempdir, data_table,
basic_encryption_config):
"""Write an encrypted parquet, but raise KeyError in KmsClient."""
path = tempdir / 'encrypted_table_kms_error.in_mem.parquet'
encryption_config = basic_encryption_config
# Empty master_keys_map
kms_connection_config = pe.KmsConnectionConfig()
class ThrowingKmsClient(pe.KmsClient):
"""A KmsClient implementation that throws exception in
wrap/unwrap calls
"""
def __init__(self, config):
"""Create an InMemoryKmsClient instance."""
pe.KmsClient.__init__(self)
self.config = config
def wrap_key(self, key_bytes, master_key_identifier):
raise ValueError("Cannot Wrap Key")
def unwrap_key(self, wrapped_key, master_key_identifier):
raise ValueError("Cannot Unwrap Key")
def kms_factory(kms_connection_configuration):
# Exception thrown in wrap/unwrap calls
return ThrowingKmsClient(kms_connection_configuration)
crypto_factory = pe.CryptoFactory(kms_factory)
with pytest.raises(ValueError, match="Cannot Wrap Key"):
# Write with encryption properties
write_encrypted_parquet(path, data_table, encryption_config,
kms_connection_config, crypto_factory)
def test_encrypted_parquet_write_kms_factory_error(tempdir, data_table,
basic_encryption_config):
"""Write an encrypted parquet, but raise ValueError in kms_factory."""
path = tempdir / 'encrypted_table_kms_factory_error.in_mem.parquet'
encryption_config = basic_encryption_config
# Empty master_keys_map
kms_connection_config = pe.KmsConnectionConfig()
def kms_factory(kms_connection_configuration):
raise ValueError('Cannot create KmsClient')
crypto_factory = pe.CryptoFactory(kms_factory)
with pytest.raises(ValueError,
match="Cannot create KmsClient"):
# Write with encryption properties
write_encrypted_parquet(path, data_table, encryption_config,
kms_connection_config, crypto_factory)
def test_encrypted_parquet_write_kms_factory_type_error(
tempdir, data_table, basic_encryption_config):
"""Write an encrypted parquet, but use wrong KMS client type
that doesn't implement KmsClient."""
path = tempdir / 'encrypted_table_kms_factory_error.in_mem.parquet'
encryption_config = basic_encryption_config
# Empty master_keys_map
kms_connection_config = pe.KmsConnectionConfig()
class WrongTypeKmsClient():
"""This is not an implementation of KmsClient.
"""
def __init__(self, config):
self.master_keys_map = config.custom_kms_conf
def wrap_key(self, key_bytes, master_key_identifier):
return None
def unwrap_key(self, wrapped_key, master_key_identifier):
return None
def kms_factory(kms_connection_configuration):
return WrongTypeKmsClient(kms_connection_configuration)
crypto_factory = pe.CryptoFactory(kms_factory)
with pytest.raises(TypeError):
# Write with encryption properties
write_encrypted_parquet(path, data_table, encryption_config,
kms_connection_config, crypto_factory)
def test_encrypted_parquet_encryption_configuration():
def validate_encryption_configuration(encryption_config):
assert(FOOTER_KEY_NAME == encryption_config.footer_key)
assert(["a", "b"] == encryption_config.column_keys[COL_KEY_NAME])
assert("AES_GCM_CTR_V1" == encryption_config.encryption_algorithm)
assert(encryption_config.plaintext_footer)
assert(not encryption_config.double_wrapping)
assert(timedelta(minutes=10.0) == encryption_config.cache_lifetime)
assert(not encryption_config.internal_key_material)
assert(192 == encryption_config.data_key_length_bits)
encryption_config = pe.EncryptionConfiguration(
footer_key=FOOTER_KEY_NAME,
column_keys={COL_KEY_NAME: ["a", "b"], },
encryption_algorithm="AES_GCM_CTR_V1",
plaintext_footer=True,
double_wrapping=False,
cache_lifetime=timedelta(minutes=10.0),
internal_key_material=False,
data_key_length_bits=192,
)
validate_encryption_configuration(encryption_config)
encryption_config_1 = pe.EncryptionConfiguration(
footer_key=FOOTER_KEY_NAME)
encryption_config_1.column_keys = {COL_KEY_NAME: ["a", "b"], }
encryption_config_1.encryption_algorithm = "AES_GCM_CTR_V1"
encryption_config_1.plaintext_footer = True
encryption_config_1.double_wrapping = False
encryption_config_1.cache_lifetime = timedelta(minutes=10.0)
encryption_config_1.internal_key_material = False
encryption_config_1.data_key_length_bits = 192
validate_encryption_configuration(encryption_config_1)
def test_encrypted_parquet_decryption_configuration():
decryption_config = pe.DecryptionConfiguration(
cache_lifetime=timedelta(minutes=10.0))
assert(timedelta(minutes=10.0) == decryption_config.cache_lifetime)
decryption_config_1 = pe.DecryptionConfiguration()
decryption_config_1.cache_lifetime = timedelta(minutes=10.0)
assert(timedelta(minutes=10.0) == decryption_config_1.cache_lifetime)
def test_encrypted_parquet_kms_configuration():
def validate_kms_connection_config(kms_connection_config):
assert("Instance1" == kms_connection_config.kms_instance_id)
assert("URL1" == kms_connection_config.kms_instance_url)
assert("MyToken" == kms_connection_config.key_access_token)
assert({"key1": "key_material_1", "key2": "key_material_2"} ==
kms_connection_config.custom_kms_conf)
kms_connection_config = pe.KmsConnectionConfig(
kms_instance_id="Instance1",
kms_instance_url="URL1",
key_access_token="MyToken",
custom_kms_conf={
"key1": "key_material_1",
"key2": "key_material_2",
})
validate_kms_connection_config(kms_connection_config)
kms_connection_config_1 = pe.KmsConnectionConfig()
kms_connection_config_1.kms_instance_id = "Instance1"
kms_connection_config_1.kms_instance_url = "URL1"
kms_connection_config_1.key_access_token = "MyToken"
kms_connection_config_1.custom_kms_conf = {
"key1": "key_material_1",
"key2": "key_material_2",
}
validate_kms_connection_config(kms_connection_config_1)
@pytest.mark.xfail(reason="Plaintext footer - reading plaintext column subset"
" reads encrypted columns too")
def test_encrypted_parquet_write_read_plain_footer_single_wrapping(
tempdir, data_table):
"""Write an encrypted parquet, with plaintext footer
and with single wrapping,
verify it's encrypted, and then read plaintext columns."""
path = tempdir / PARQUET_NAME
# Encrypt the footer with the footer key,
# encrypt column `a` and column `b` with another key,
# keep `c` plaintext
encryption_config = pe.EncryptionConfiguration(
footer_key=FOOTER_KEY_NAME,
column_keys={
COL_KEY_NAME: ["a", "b"],
},
plaintext_footer=True,
double_wrapping=False)
kms_connection_config = pe.KmsConnectionConfig(
custom_kms_conf={
FOOTER_KEY_NAME: FOOTER_KEY.decode("UTF-8"),
COL_KEY_NAME: COL_KEY.decode("UTF-8"),
}
)
def kms_factory(kms_connection_configuration):
return InMemoryKmsClient(kms_connection_configuration)
crypto_factory = pe.CryptoFactory(kms_factory)
# Write with encryption properties
write_encrypted_parquet(path, data_table, encryption_config,
kms_connection_config, crypto_factory)
# # Read without decryption properties only the plaintext column
# result = pq.ParquetFile(path)
# result_table = result.read(columns='c', use_threads=False)
# assert table.num_rows == result_table.num_rows
@pytest.mark.xfail(reason="External key material not supported yet")
def test_encrypted_parquet_write_external(tempdir, data_table):
"""Write an encrypted parquet, with external key
material.
Currently it's not implemented, so should throw
an exception"""
path = tempdir / PARQUET_NAME
# Encrypt the file with the footer key
encryption_config = pe.EncryptionConfiguration(
footer_key=FOOTER_KEY_NAME,
column_keys={},
internal_key_material=False)
kms_connection_config = pe.KmsConnectionConfig(
custom_kms_conf={FOOTER_KEY_NAME: FOOTER_KEY.decode("UTF-8")}
)
def kms_factory(kms_connection_configuration):
return InMemoryKmsClient(kms_connection_configuration)
crypto_factory = pe.CryptoFactory(kms_factory)
# Write with encryption properties
write_encrypted_parquet(path, data_table, encryption_config,
kms_connection_config, crypto_factory)
@pytest.mark.skip(reason="ARROW-14114: Multithreaded read sometimes fails"
"decryption finalization or with Segmentation fault")
def test_encrypted_parquet_loop(tempdir, data_table, basic_encryption_config):
"""Write an encrypted parquet, verify it's encrypted,
and then read it multithreaded in a loop."""
path = tempdir / PARQUET_NAME
# Encrypt the footer with the footer key,
# encrypt column `a` and column `b` with another key,
# keep `c` plaintext
encryption_config = basic_encryption_config
kms_connection_config = pe.KmsConnectionConfig(
custom_kms_conf={
FOOTER_KEY_NAME: FOOTER_KEY.decode("UTF-8"),
COL_KEY_NAME: COL_KEY.decode("UTF-8"),
}
)
def kms_factory(kms_connection_configuration):
return InMemoryKmsClient(kms_connection_configuration)
crypto_factory = pe.CryptoFactory(kms_factory)
# Write with encryption properties
write_encrypted_parquet(path, data_table, encryption_config,
kms_connection_config, crypto_factory)
verify_file_encrypted(path)
decryption_config = pe.DecryptionConfiguration(
cache_lifetime=timedelta(minutes=5.0))
for i in range(50):
# Read with decryption properties
file_decryption_properties = crypto_factory.file_decryption_properties(
kms_connection_config, decryption_config)
assert(file_decryption_properties is not None)
result = pq.ParquetFile(
path, decryption_properties=file_decryption_properties)
result_table = result.read(use_threads=True)
assert data_table.equals(result_table)
@@ -0,0 +1,528 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
import decimal
from collections import OrderedDict
import numpy as np
import pytest
import pyarrow as pa
from pyarrow.tests.parquet.common import _check_roundtrip, make_sample_file
try:
import pyarrow.parquet as pq
from pyarrow.tests.parquet.common import _write_table
except ImportError:
pq = None
try:
import pandas as pd
import pandas.testing as tm
from pyarrow.tests.parquet.common import alltypes_sample
except ImportError:
pd = tm = None
@pytest.mark.pandas
def test_parquet_metadata_api():
df = alltypes_sample(size=10000)
df = df.reindex(columns=sorted(df.columns))
df.index = np.random.randint(0, 1000000, size=len(df))
fileh = make_sample_file(df)
ncols = len(df.columns)
# Series of sniff tests
meta = fileh.metadata
repr(meta)
assert meta.num_rows == len(df)
assert meta.num_columns == ncols + 1 # +1 for index
assert meta.num_row_groups == 1
assert meta.format_version == '2.6'
assert 'parquet-cpp' in meta.created_by
assert isinstance(meta.serialized_size, int)
assert isinstance(meta.metadata, dict)
# Schema
schema = fileh.schema
assert meta.schema is schema
assert len(schema) == ncols + 1 # +1 for index
repr(schema)
col = schema[0]
repr(col)
assert col.name == df.columns[0]
assert col.max_definition_level == 1
assert col.max_repetition_level == 0
assert col.max_repetition_level == 0
assert col.physical_type == 'BOOLEAN'
assert col.converted_type == 'NONE'
with pytest.raises(IndexError):
schema[ncols + 1] # +1 for index
with pytest.raises(IndexError):
schema[-1]
# Row group
for rg in range(meta.num_row_groups):
rg_meta = meta.row_group(rg)
assert isinstance(rg_meta, pq.RowGroupMetaData)
repr(rg_meta)
for col in range(rg_meta.num_columns):
col_meta = rg_meta.column(col)
assert isinstance(col_meta, pq.ColumnChunkMetaData)
repr(col_meta)
with pytest.raises(IndexError):
meta.row_group(-1)
with pytest.raises(IndexError):
meta.row_group(meta.num_row_groups + 1)
rg_meta = meta.row_group(0)
assert rg_meta.num_rows == len(df)
assert rg_meta.num_columns == ncols + 1 # +1 for index
assert rg_meta.total_byte_size > 0
with pytest.raises(IndexError):
col_meta = rg_meta.column(-1)
with pytest.raises(IndexError):
col_meta = rg_meta.column(ncols + 2)
col_meta = rg_meta.column(0)
assert col_meta.file_offset > 0
assert col_meta.file_path == '' # created from BytesIO
assert col_meta.physical_type == 'BOOLEAN'
assert col_meta.num_values == 10000
assert col_meta.path_in_schema == 'bool'
assert col_meta.is_stats_set is True
assert isinstance(col_meta.statistics, pq.Statistics)
assert col_meta.compression == 'SNAPPY'
assert col_meta.encodings == ('PLAIN', 'RLE')
assert col_meta.has_dictionary_page is False
assert col_meta.dictionary_page_offset is None
assert col_meta.data_page_offset > 0
assert col_meta.total_compressed_size > 0
assert col_meta.total_uncompressed_size > 0
with pytest.raises(NotImplementedError):
col_meta.has_index_page
with pytest.raises(NotImplementedError):
col_meta.index_page_offset
def test_parquet_metadata_lifetime(tempdir):
# ARROW-6642 - ensure that chained access keeps parent objects alive
table = pa.table({'a': [1, 2, 3]})
pq.write_table(table, tempdir / 'test_metadata_segfault.parquet')
parquet_file = pq.ParquetFile(tempdir / 'test_metadata_segfault.parquet')
parquet_file.metadata.row_group(0).column(0).statistics
@pytest.mark.pandas
@pytest.mark.parametrize(
(
'data',
'type',
'physical_type',
'min_value',
'max_value',
'null_count',
'num_values',
'distinct_count'
),
[
([1, 2, 2, None, 4], pa.uint8(), 'INT32', 1, 4, 1, 4, 0),
([1, 2, 2, None, 4], pa.uint16(), 'INT32', 1, 4, 1, 4, 0),
([1, 2, 2, None, 4], pa.uint32(), 'INT32', 1, 4, 1, 4, 0),
([1, 2, 2, None, 4], pa.uint64(), 'INT64', 1, 4, 1, 4, 0),
([-1, 2, 2, None, 4], pa.int8(), 'INT32', -1, 4, 1, 4, 0),
([-1, 2, 2, None, 4], pa.int16(), 'INT32', -1, 4, 1, 4, 0),
([-1, 2, 2, None, 4], pa.int32(), 'INT32', -1, 4, 1, 4, 0),
([-1, 2, 2, None, 4], pa.int64(), 'INT64', -1, 4, 1, 4, 0),
(
[-1.1, 2.2, 2.3, None, 4.4], pa.float32(),
'FLOAT', -1.1, 4.4, 1, 4, 0
),
(
[-1.1, 2.2, 2.3, None, 4.4], pa.float64(),
'DOUBLE', -1.1, 4.4, 1, 4, 0
),
(
['', 'b', chr(1000), None, 'aaa'], pa.binary(),
'BYTE_ARRAY', b'', chr(1000).encode('utf-8'), 1, 4, 0
),
(
[True, False, False, True, True], pa.bool_(),
'BOOLEAN', False, True, 0, 5, 0
),
(
[b'\x00', b'b', b'12', None, b'aaa'], pa.binary(),
'BYTE_ARRAY', b'\x00', b'b', 1, 4, 0
),
]
)
def test_parquet_column_statistics_api(data, type, physical_type, min_value,
max_value, null_count, num_values,
distinct_count):
df = pd.DataFrame({'data': data})
schema = pa.schema([pa.field('data', type)])
table = pa.Table.from_pandas(df, schema=schema, safe=False)
fileh = make_sample_file(table)
meta = fileh.metadata
rg_meta = meta.row_group(0)
col_meta = rg_meta.column(0)
stat = col_meta.statistics
assert stat.has_min_max
assert _close(type, stat.min, min_value)
assert _close(type, stat.max, max_value)
assert stat.null_count == null_count
assert stat.num_values == num_values
# TODO(kszucs) until parquet-cpp API doesn't expose HasDistinctCount
# method, missing distinct_count is represented as zero instead of None
assert stat.distinct_count == distinct_count
assert stat.physical_type == physical_type
def _close(type, left, right):
if type == pa.float32():
return abs(left - right) < 1E-7
elif type == pa.float64():
return abs(left - right) < 1E-13
else:
return left == right
# ARROW-6339
@pytest.mark.pandas
def test_parquet_raise_on_unset_statistics():
df = pd.DataFrame({"t": pd.Series([pd.NaT], dtype="datetime64[ns]")})
meta = make_sample_file(pa.Table.from_pandas(df)).metadata
assert not meta.row_group(0).column(0).statistics.has_min_max
assert meta.row_group(0).column(0).statistics.max is None
def test_statistics_convert_logical_types(tempdir):
# ARROW-5166, ARROW-4139
# (min, max, type)
cases = [(10, 11164359321221007157, pa.uint64()),
(10, 4294967295, pa.uint32()),
("ähnlich", "öffentlich", pa.utf8()),
(datetime.time(10, 30, 0, 1000), datetime.time(15, 30, 0, 1000),
pa.time32('ms')),
(datetime.time(10, 30, 0, 1000), datetime.time(15, 30, 0, 1000),
pa.time64('us')),
(datetime.datetime(2019, 6, 24, 0, 0, 0, 1000),
datetime.datetime(2019, 6, 25, 0, 0, 0, 1000),
pa.timestamp('ms')),
(datetime.datetime(2019, 6, 24, 0, 0, 0, 1000),
datetime.datetime(2019, 6, 25, 0, 0, 0, 1000),
pa.timestamp('us')),
(datetime.date(2019, 6, 24),
datetime.date(2019, 6, 25),
pa.date32()),
(decimal.Decimal("20.123"),
decimal.Decimal("20.124"),
pa.decimal128(12, 5))]
for i, (min_val, max_val, typ) in enumerate(cases):
t = pa.Table.from_arrays([pa.array([min_val, max_val], type=typ)],
['col'])
path = str(tempdir / ('example{}.parquet'.format(i)))
pq.write_table(t, path, version='2.6')
pf = pq.ParquetFile(path)
stats = pf.metadata.row_group(0).column(0).statistics
assert stats.min == min_val
assert stats.max == max_val
def test_parquet_write_disable_statistics(tempdir):
table = pa.Table.from_pydict(
OrderedDict([
('a', pa.array([1, 2, 3])),
('b', pa.array(['a', 'b', 'c']))
])
)
_write_table(table, tempdir / 'data.parquet')
meta = pq.read_metadata(tempdir / 'data.parquet')
for col in [0, 1]:
cc = meta.row_group(0).column(col)
assert cc.is_stats_set is True
assert cc.statistics is not None
_write_table(table, tempdir / 'data2.parquet', write_statistics=False)
meta = pq.read_metadata(tempdir / 'data2.parquet')
for col in [0, 1]:
cc = meta.row_group(0).column(col)
assert cc.is_stats_set is False
assert cc.statistics is None
_write_table(table, tempdir / 'data3.parquet', write_statistics=['a'])
meta = pq.read_metadata(tempdir / 'data3.parquet')
cc_a = meta.row_group(0).column(0)
cc_b = meta.row_group(0).column(1)
assert cc_a.is_stats_set is True
assert cc_b.is_stats_set is False
assert cc_a.statistics is not None
assert cc_b.statistics is None
def test_field_id_metadata():
# ARROW-7080
field_id = b'PARQUET:field_id'
inner = pa.field('inner', pa.int32(), metadata={field_id: b'100'})
middle = pa.field('middle', pa.struct(
[inner]), metadata={field_id: b'101'})
fields = [
pa.field('basic', pa.int32(), metadata={
b'other': b'abc', field_id: b'1'}),
pa.field(
'list',
pa.list_(pa.field('list-inner', pa.int32(),
metadata={field_id: b'10'})),
metadata={field_id: b'11'}),
pa.field('struct', pa.struct([middle]), metadata={field_id: b'102'}),
pa.field('no-metadata', pa.int32()),
pa.field('non-integral-field-id', pa.int32(),
metadata={field_id: b'xyz'}),
pa.field('negative-field-id', pa.int32(),
metadata={field_id: b'-1000'})
]
arrs = [[] for _ in fields]
table = pa.table(arrs, schema=pa.schema(fields))
bio = pa.BufferOutputStream()
pq.write_table(table, bio)
contents = bio.getvalue()
pf = pq.ParquetFile(pa.BufferReader(contents))
schema = pf.schema_arrow
assert schema[0].metadata[field_id] == b'1'
assert schema[0].metadata[b'other'] == b'abc'
list_field = schema[1]
assert list_field.metadata[field_id] == b'11'
list_item_field = list_field.type.value_field
assert list_item_field.metadata[field_id] == b'10'
struct_field = schema[2]
assert struct_field.metadata[field_id] == b'102'
struct_middle_field = struct_field.type[0]
assert struct_middle_field.metadata[field_id] == b'101'
struct_inner_field = struct_middle_field.type[0]
assert struct_inner_field.metadata[field_id] == b'100'
assert schema[3].metadata is None
# Invalid input is passed through (ok) but does not
# have field_id in parquet (not tested)
assert schema[4].metadata[field_id] == b'xyz'
assert schema[5].metadata[field_id] == b'-1000'
@pytest.mark.pandas
def test_multi_dataset_metadata(tempdir):
filenames = ["ARROW-1983-dataset.0", "ARROW-1983-dataset.1"]
metapath = str(tempdir / "_metadata")
# create a test dataset
df = pd.DataFrame({
'one': [1, 2, 3],
'two': [-1, -2, -3],
'three': [[1, 2], [2, 3], [3, 4]],
})
table = pa.Table.from_pandas(df)
# write dataset twice and collect/merge metadata
_meta = None
for filename in filenames:
meta = []
pq.write_table(table, str(tempdir / filename),
metadata_collector=meta)
meta[0].set_file_path(filename)
if _meta is None:
_meta = meta[0]
else:
_meta.append_row_groups(meta[0])
# Write merged metadata-only file
with open(metapath, "wb") as f:
_meta.write_metadata_file(f)
# Read back the metadata
meta = pq.read_metadata(metapath)
md = meta.to_dict()
_md = _meta.to_dict()
for key in _md:
if key != 'serialized_size':
assert _md[key] == md[key]
assert _md['num_columns'] == 3
assert _md['num_rows'] == 6
assert _md['num_row_groups'] == 2
assert _md['serialized_size'] == 0
assert md['serialized_size'] > 0
def test_write_metadata(tempdir):
path = str(tempdir / "metadata")
schema = pa.schema([("a", "int64"), ("b", "float64")])
# write a pyarrow schema
pq.write_metadata(schema, path)
parquet_meta = pq.read_metadata(path)
schema_as_arrow = parquet_meta.schema.to_arrow_schema()
assert schema_as_arrow.equals(schema)
# ARROW-8980: Check that the ARROW:schema metadata key was removed
if schema_as_arrow.metadata:
assert b'ARROW:schema' not in schema_as_arrow.metadata
# pass through writer keyword arguments
for version in ["1.0", "2.0", "2.4", "2.6"]:
pq.write_metadata(schema, path, version=version)
parquet_meta = pq.read_metadata(path)
# The version is stored as a single integer in the Parquet metadata,
# so it cannot correctly express dotted format versions
expected_version = "1.0" if version == "1.0" else "2.6"
assert parquet_meta.format_version == expected_version
# metadata_collector: list of FileMetaData objects
table = pa.table({'a': [1, 2], 'b': [.1, .2]}, schema=schema)
pq.write_table(table, tempdir / "data.parquet")
parquet_meta = pq.read_metadata(str(tempdir / "data.parquet"))
pq.write_metadata(
schema, path, metadata_collector=[parquet_meta, parquet_meta]
)
parquet_meta_mult = pq.read_metadata(path)
assert parquet_meta_mult.num_row_groups == 2
# append metadata with different schema raises an error
with pytest.raises(RuntimeError, match="requires equal schemas"):
pq.write_metadata(
pa.schema([("a", "int32"), ("b", "null")]),
path, metadata_collector=[parquet_meta, parquet_meta]
)
def test_table_large_metadata():
# ARROW-8694
my_schema = pa.schema([pa.field('f0', 'double')],
metadata={'large': 'x' * 10000000})
table = pa.table([np.arange(10)], schema=my_schema)
_check_roundtrip(table)
@pytest.mark.pandas
def test_compare_schemas():
df = alltypes_sample(size=10000)
fileh = make_sample_file(df)
fileh2 = make_sample_file(df)
fileh3 = make_sample_file(df[df.columns[::2]])
# ParquetSchema
assert isinstance(fileh.schema, pq.ParquetSchema)
assert fileh.schema.equals(fileh.schema)
assert fileh.schema == fileh.schema
assert fileh.schema.equals(fileh2.schema)
assert fileh.schema == fileh2.schema
assert fileh.schema != 'arbitrary object'
assert not fileh.schema.equals(fileh3.schema)
assert fileh.schema != fileh3.schema
# ColumnSchema
assert isinstance(fileh.schema[0], pq.ColumnSchema)
assert fileh.schema[0].equals(fileh.schema[0])
assert fileh.schema[0] == fileh.schema[0]
assert not fileh.schema[0].equals(fileh.schema[1])
assert fileh.schema[0] != fileh.schema[1]
assert fileh.schema[0] != 'arbitrary object'
@pytest.mark.pandas
def test_read_schema(tempdir):
N = 100
df = pd.DataFrame({
'index': np.arange(N),
'values': np.random.randn(N)
}, columns=['index', 'values'])
data_path = tempdir / 'test.parquet'
table = pa.Table.from_pandas(df)
_write_table(table, data_path)
read1 = pq.read_schema(data_path)
read2 = pq.read_schema(data_path, memory_map=True)
assert table.schema.equals(read1)
assert table.schema.equals(read2)
assert table.schema.metadata[b'pandas'] == read1.metadata[b'pandas']
def test_parquet_metadata_empty_to_dict(tempdir):
# https://issues.apache.org/jira/browse/ARROW-10146
table = pa.table({"a": pa.array([], type="int64")})
pq.write_table(table, tempdir / "data.parquet")
metadata = pq.read_metadata(tempdir / "data.parquet")
# ensure this doesn't error / statistics set to None
metadata_dict = metadata.to_dict()
assert len(metadata_dict["row_groups"]) == 1
assert len(metadata_dict["row_groups"][0]["columns"]) == 1
assert metadata_dict["row_groups"][0]["columns"][0]["statistics"] is None
@pytest.mark.slow
@pytest.mark.large_memory
def test_metadata_exceeds_message_size():
# ARROW-13655: Thrift may enable a default message size that limits
# the size of Parquet metadata that can be written.
NCOLS = 1000
NREPEATS = 4000
table = pa.table({str(i): np.random.randn(10) for i in range(NCOLS)})
with pa.BufferOutputStream() as out:
pq.write_table(table, out)
buf = out.getvalue()
original_metadata = pq.read_metadata(pa.BufferReader(buf))
metadata = pq.read_metadata(pa.BufferReader(buf))
for i in range(NREPEATS):
metadata.append_row_groups(original_metadata)
with pa.BufferOutputStream() as out:
metadata.write_metadata_file(out)
buf = out.getvalue()
metadata = pq.read_metadata(pa.BufferReader(buf))
@@ -0,0 +1,707 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import io
import json
import numpy as np
import pytest
import pyarrow as pa
from pyarrow.fs import LocalFileSystem, SubTreeFileSystem
from pyarrow.tests.parquet.common import (
parametrize_legacy_dataset, parametrize_legacy_dataset_not_supported)
from pyarrow.util import guid
from pyarrow.vendored.version import Version
try:
import pyarrow.parquet as pq
from pyarrow.tests.parquet.common import (_read_table, _test_dataframe,
_write_table)
except ImportError:
pq = None
try:
import pandas as pd
import pandas.testing as tm
from pyarrow.tests.parquet.common import (_roundtrip_pandas_dataframe,
alltypes_sample)
except ImportError:
pd = tm = None
@pytest.mark.pandas
def test_pandas_parquet_custom_metadata(tempdir):
df = alltypes_sample(size=10000)
filename = tempdir / 'pandas_roundtrip.parquet'
arrow_table = pa.Table.from_pandas(df)
assert b'pandas' in arrow_table.schema.metadata
_write_table(arrow_table, filename, version='2.6', coerce_timestamps='ms')
metadata = pq.read_metadata(filename).metadata
assert b'pandas' in metadata
js = json.loads(metadata[b'pandas'].decode('utf8'))
assert js['index_columns'] == [{'kind': 'range',
'name': None,
'start': 0, 'stop': 10000,
'step': 1}]
@pytest.mark.pandas
def test_merging_parquet_tables_with_different_pandas_metadata(tempdir):
# ARROW-3728: Merging Parquet Files - Pandas Meta in Schema Mismatch
schema = pa.schema([
pa.field('int', pa.int16()),
pa.field('float', pa.float32()),
pa.field('string', pa.string())
])
df1 = pd.DataFrame({
'int': np.arange(3, dtype=np.uint8),
'float': np.arange(3, dtype=np.float32),
'string': ['ABBA', 'EDDA', 'ACDC']
})
df2 = pd.DataFrame({
'int': [4, 5],
'float': [1.1, None],
'string': [None, None]
})
table1 = pa.Table.from_pandas(df1, schema=schema, preserve_index=False)
table2 = pa.Table.from_pandas(df2, schema=schema, preserve_index=False)
assert not table1.schema.equals(table2.schema, check_metadata=True)
assert table1.schema.equals(table2.schema)
writer = pq.ParquetWriter(tempdir / 'merged.parquet', schema=schema)
writer.write_table(table1)
writer.write_table(table2)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_pandas_parquet_column_multiindex(tempdir, use_legacy_dataset):
df = alltypes_sample(size=10)
df.columns = pd.MultiIndex.from_tuples(
list(zip(df.columns, df.columns[::-1])),
names=['level_1', 'level_2']
)
filename = tempdir / 'pandas_roundtrip.parquet'
arrow_table = pa.Table.from_pandas(df)
assert arrow_table.schema.pandas_metadata is not None
_write_table(arrow_table, filename, version='2.6', coerce_timestamps='ms')
table_read = pq.read_pandas(
filename, use_legacy_dataset=use_legacy_dataset)
df_read = table_read.to_pandas()
tm.assert_frame_equal(df, df_read)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_pandas_parquet_2_0_roundtrip_read_pandas_no_index_written(
tempdir, use_legacy_dataset
):
df = alltypes_sample(size=10000)
filename = tempdir / 'pandas_roundtrip.parquet'
arrow_table = pa.Table.from_pandas(df, preserve_index=False)
js = arrow_table.schema.pandas_metadata
assert not js['index_columns']
# ARROW-2170
# While index_columns should be empty, columns needs to be filled still.
assert js['columns']
_write_table(arrow_table, filename, version='2.6', coerce_timestamps='ms')
table_read = pq.read_pandas(
filename, use_legacy_dataset=use_legacy_dataset)
js = table_read.schema.pandas_metadata
assert not js['index_columns']
read_metadata = table_read.schema.metadata
assert arrow_table.schema.metadata == read_metadata
df_read = table_read.to_pandas()
tm.assert_frame_equal(df, df_read)
# TODO(dataset) duplicate column selection actually gives duplicate columns now
@pytest.mark.pandas
@parametrize_legacy_dataset_not_supported
def test_pandas_column_selection(tempdir, use_legacy_dataset):
size = 10000
np.random.seed(0)
df = pd.DataFrame({
'uint8': np.arange(size, dtype=np.uint8),
'uint16': np.arange(size, dtype=np.uint16)
})
filename = tempdir / 'pandas_roundtrip.parquet'
arrow_table = pa.Table.from_pandas(df)
_write_table(arrow_table, filename)
table_read = _read_table(
filename, columns=['uint8'], use_legacy_dataset=use_legacy_dataset)
df_read = table_read.to_pandas()
tm.assert_frame_equal(df[['uint8']], df_read)
# ARROW-4267: Selection of duplicate columns still leads to these columns
# being read uniquely.
table_read = _read_table(
filename, columns=['uint8', 'uint8'],
use_legacy_dataset=use_legacy_dataset)
df_read = table_read.to_pandas()
tm.assert_frame_equal(df[['uint8']], df_read)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_pandas_parquet_native_file_roundtrip(tempdir, use_legacy_dataset):
df = _test_dataframe(10000)
arrow_table = pa.Table.from_pandas(df)
imos = pa.BufferOutputStream()
_write_table(arrow_table, imos, version='2.6')
buf = imos.getvalue()
reader = pa.BufferReader(buf)
df_read = _read_table(
reader, use_legacy_dataset=use_legacy_dataset).to_pandas()
tm.assert_frame_equal(df, df_read)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_read_pandas_column_subset(tempdir, use_legacy_dataset):
df = _test_dataframe(10000)
arrow_table = pa.Table.from_pandas(df)
imos = pa.BufferOutputStream()
_write_table(arrow_table, imos, version='2.6')
buf = imos.getvalue()
reader = pa.BufferReader(buf)
df_read = pq.read_pandas(
reader, columns=['strings', 'uint8'],
use_legacy_dataset=use_legacy_dataset
).to_pandas()
tm.assert_frame_equal(df[['strings', 'uint8']], df_read)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_pandas_parquet_empty_roundtrip(tempdir, use_legacy_dataset):
df = _test_dataframe(0)
arrow_table = pa.Table.from_pandas(df)
imos = pa.BufferOutputStream()
_write_table(arrow_table, imos, version='2.6')
buf = imos.getvalue()
reader = pa.BufferReader(buf)
df_read = _read_table(
reader, use_legacy_dataset=use_legacy_dataset).to_pandas()
tm.assert_frame_equal(df, df_read)
@pytest.mark.pandas
def test_pandas_can_write_nested_data(tempdir):
data = {
"agg_col": [
{"page_type": 1},
{"record_type": 1},
{"non_consecutive_home": 0},
],
"uid_first": "1001"
}
df = pd.DataFrame(data=data)
arrow_table = pa.Table.from_pandas(df)
imos = pa.BufferOutputStream()
# This succeeds under V2
_write_table(arrow_table, imos)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_pandas_parquet_pyfile_roundtrip(tempdir, use_legacy_dataset):
filename = tempdir / 'pandas_pyfile_roundtrip.parquet'
size = 5
df = pd.DataFrame({
'int64': np.arange(size, dtype=np.int64),
'float32': np.arange(size, dtype=np.float32),
'float64': np.arange(size, dtype=np.float64),
'bool': np.random.randn(size) > 0,
'strings': ['foo', 'bar', None, 'baz', 'qux']
})
arrow_table = pa.Table.from_pandas(df)
with filename.open('wb') as f:
_write_table(arrow_table, f, version="1.0")
data = io.BytesIO(filename.read_bytes())
table_read = _read_table(data, use_legacy_dataset=use_legacy_dataset)
df_read = table_read.to_pandas()
tm.assert_frame_equal(df, df_read)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_pandas_parquet_configuration_options(tempdir, use_legacy_dataset):
size = 10000
np.random.seed(0)
df = pd.DataFrame({
'uint8': np.arange(size, dtype=np.uint8),
'uint16': np.arange(size, dtype=np.uint16),
'uint32': np.arange(size, dtype=np.uint32),
'uint64': np.arange(size, dtype=np.uint64),
'int8': np.arange(size, dtype=np.int16),
'int16': np.arange(size, dtype=np.int16),
'int32': np.arange(size, dtype=np.int32),
'int64': np.arange(size, dtype=np.int64),
'float32': np.arange(size, dtype=np.float32),
'float64': np.arange(size, dtype=np.float64),
'bool': np.random.randn(size) > 0
})
filename = tempdir / 'pandas_roundtrip.parquet'
arrow_table = pa.Table.from_pandas(df)
for use_dictionary in [True, False]:
_write_table(arrow_table, filename, version='2.6',
use_dictionary=use_dictionary)
table_read = _read_table(
filename, use_legacy_dataset=use_legacy_dataset)
df_read = table_read.to_pandas()
tm.assert_frame_equal(df, df_read)
for write_statistics in [True, False]:
_write_table(arrow_table, filename, version='2.6',
write_statistics=write_statistics)
table_read = _read_table(filename,
use_legacy_dataset=use_legacy_dataset)
df_read = table_read.to_pandas()
tm.assert_frame_equal(df, df_read)
for compression in ['NONE', 'SNAPPY', 'GZIP', 'LZ4', 'ZSTD']:
if (compression != 'NONE' and
not pa.lib.Codec.is_available(compression)):
continue
_write_table(arrow_table, filename, version='2.6',
compression=compression)
table_read = _read_table(
filename, use_legacy_dataset=use_legacy_dataset)
df_read = table_read.to_pandas()
tm.assert_frame_equal(df, df_read)
@pytest.mark.pandas
def test_spark_flavor_preserves_pandas_metadata():
df = _test_dataframe(size=100)
df.index = np.arange(0, 10 * len(df), 10)
df.index.name = 'foo'
result = _roundtrip_pandas_dataframe(df, {'version': '2.0',
'flavor': 'spark'})
tm.assert_frame_equal(result, df)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_index_column_name_duplicate(tempdir, use_legacy_dataset):
data = {
'close': {
pd.Timestamp('2017-06-30 01:31:00'): 154.99958999999998,
pd.Timestamp('2017-06-30 01:32:00'): 154.99958999999998,
},
'time': {
pd.Timestamp('2017-06-30 01:31:00'): pd.Timestamp(
'2017-06-30 01:31:00'
),
pd.Timestamp('2017-06-30 01:32:00'): pd.Timestamp(
'2017-06-30 01:32:00'
),
}
}
path = str(tempdir / 'data.parquet')
dfx = pd.DataFrame(data).set_index('time', drop=False)
tdfx = pa.Table.from_pandas(dfx)
_write_table(tdfx, path)
arrow_table = _read_table(path, use_legacy_dataset=use_legacy_dataset)
result_df = arrow_table.to_pandas()
tm.assert_frame_equal(result_df, dfx)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_multiindex_duplicate_values(tempdir, use_legacy_dataset):
num_rows = 3
numbers = list(range(num_rows))
index = pd.MultiIndex.from_arrays(
[['foo', 'foo', 'bar'], numbers],
names=['foobar', 'some_numbers'],
)
df = pd.DataFrame({'numbers': numbers}, index=index)
table = pa.Table.from_pandas(df)
filename = tempdir / 'dup_multi_index_levels.parquet'
_write_table(table, filename)
result_table = _read_table(filename, use_legacy_dataset=use_legacy_dataset)
assert table.equals(result_table)
result_df = result_table.to_pandas()
tm.assert_frame_equal(result_df, df)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_backwards_compatible_index_naming(datadir, use_legacy_dataset):
expected_string = b"""\
carat cut color clarity depth table price x y z
0.23 Ideal E SI2 61.5 55.0 326 3.95 3.98 2.43
0.21 Premium E SI1 59.8 61.0 326 3.89 3.84 2.31
0.23 Good E VS1 56.9 65.0 327 4.05 4.07 2.31
0.29 Premium I VS2 62.4 58.0 334 4.20 4.23 2.63
0.31 Good J SI2 63.3 58.0 335 4.34 4.35 2.75
0.24 Very Good J VVS2 62.8 57.0 336 3.94 3.96 2.48
0.24 Very Good I VVS1 62.3 57.0 336 3.95 3.98 2.47
0.26 Very Good H SI1 61.9 55.0 337 4.07 4.11 2.53
0.22 Fair E VS2 65.1 61.0 337 3.87 3.78 2.49
0.23 Very Good H VS1 59.4 61.0 338 4.00 4.05 2.39"""
expected = pd.read_csv(io.BytesIO(expected_string), sep=r'\s{2,}',
index_col=None, header=0, engine='python')
table = _read_table(
datadir / 'v0.7.1.parquet', use_legacy_dataset=use_legacy_dataset)
result = table.to_pandas()
tm.assert_frame_equal(result, expected)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_backwards_compatible_index_multi_level_named(
datadir, use_legacy_dataset
):
expected_string = b"""\
carat cut color clarity depth table price x y z
0.23 Ideal E SI2 61.5 55.0 326 3.95 3.98 2.43
0.21 Premium E SI1 59.8 61.0 326 3.89 3.84 2.31
0.23 Good E VS1 56.9 65.0 327 4.05 4.07 2.31
0.29 Premium I VS2 62.4 58.0 334 4.20 4.23 2.63
0.31 Good J SI2 63.3 58.0 335 4.34 4.35 2.75
0.24 Very Good J VVS2 62.8 57.0 336 3.94 3.96 2.48
0.24 Very Good I VVS1 62.3 57.0 336 3.95 3.98 2.47
0.26 Very Good H SI1 61.9 55.0 337 4.07 4.11 2.53
0.22 Fair E VS2 65.1 61.0 337 3.87 3.78 2.49
0.23 Very Good H VS1 59.4 61.0 338 4.00 4.05 2.39"""
expected = pd.read_csv(
io.BytesIO(expected_string), sep=r'\s{2,}',
index_col=['cut', 'color', 'clarity'],
header=0, engine='python'
).sort_index()
table = _read_table(datadir / 'v0.7.1.all-named-index.parquet',
use_legacy_dataset=use_legacy_dataset)
result = table.to_pandas()
tm.assert_frame_equal(result, expected)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_backwards_compatible_index_multi_level_some_named(
datadir, use_legacy_dataset
):
expected_string = b"""\
carat cut color clarity depth table price x y z
0.23 Ideal E SI2 61.5 55.0 326 3.95 3.98 2.43
0.21 Premium E SI1 59.8 61.0 326 3.89 3.84 2.31
0.23 Good E VS1 56.9 65.0 327 4.05 4.07 2.31
0.29 Premium I VS2 62.4 58.0 334 4.20 4.23 2.63
0.31 Good J SI2 63.3 58.0 335 4.34 4.35 2.75
0.24 Very Good J VVS2 62.8 57.0 336 3.94 3.96 2.48
0.24 Very Good I VVS1 62.3 57.0 336 3.95 3.98 2.47
0.26 Very Good H SI1 61.9 55.0 337 4.07 4.11 2.53
0.22 Fair E VS2 65.1 61.0 337 3.87 3.78 2.49
0.23 Very Good H VS1 59.4 61.0 338 4.00 4.05 2.39"""
expected = pd.read_csv(
io.BytesIO(expected_string),
sep=r'\s{2,}', index_col=['cut', 'color', 'clarity'],
header=0, engine='python'
).sort_index()
expected.index = expected.index.set_names(['cut', None, 'clarity'])
table = _read_table(datadir / 'v0.7.1.some-named-index.parquet',
use_legacy_dataset=use_legacy_dataset)
result = table.to_pandas()
tm.assert_frame_equal(result, expected)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_backwards_compatible_column_metadata_handling(
datadir, use_legacy_dataset
):
expected = pd.DataFrame(
{'a': [1, 2, 3], 'b': [.1, .2, .3],
'c': pd.date_range("2017-01-01", periods=3, tz='Europe/Brussels')})
expected.index = pd.MultiIndex.from_arrays(
[['a', 'b', 'c'],
pd.date_range("2017-01-01", periods=3, tz='Europe/Brussels')],
names=['index', None])
path = datadir / 'v0.7.1.column-metadata-handling.parquet'
table = _read_table(path, use_legacy_dataset=use_legacy_dataset)
result = table.to_pandas()
tm.assert_frame_equal(result, expected)
table = _read_table(
path, columns=['a'], use_legacy_dataset=use_legacy_dataset)
result = table.to_pandas()
tm.assert_frame_equal(result, expected[['a']].reset_index(drop=True))
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_categorical_index_survives_roundtrip(use_legacy_dataset):
# ARROW-3652, addressed by ARROW-3246
df = pd.DataFrame([['a', 'b'], ['c', 'd']], columns=['c1', 'c2'])
df['c1'] = df['c1'].astype('category')
df = df.set_index(['c1'])
table = pa.Table.from_pandas(df)
bos = pa.BufferOutputStream()
pq.write_table(table, bos)
ref_df = pq.read_pandas(
bos.getvalue(), use_legacy_dataset=use_legacy_dataset).to_pandas()
assert isinstance(ref_df.index, pd.CategoricalIndex)
assert ref_df.index.equals(df.index)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_categorical_order_survives_roundtrip(use_legacy_dataset):
# ARROW-6302
df = pd.DataFrame({"a": pd.Categorical(
["a", "b", "c", "a"], categories=["b", "c", "d"], ordered=True)})
table = pa.Table.from_pandas(df)
bos = pa.BufferOutputStream()
pq.write_table(table, bos)
contents = bos.getvalue()
result = pq.read_pandas(
contents, use_legacy_dataset=use_legacy_dataset).to_pandas()
tm.assert_frame_equal(result, df)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_pandas_categorical_na_type_row_groups(use_legacy_dataset):
# ARROW-5085
df = pd.DataFrame({"col": [None] * 100, "int": [1.0] * 100})
df_category = df.astype({"col": "category", "int": "category"})
table = pa.Table.from_pandas(df)
table_cat = pa.Table.from_pandas(df_category)
buf = pa.BufferOutputStream()
# it works
pq.write_table(table_cat, buf, version='2.6', chunk_size=10)
result = pq.read_table(
buf.getvalue(), use_legacy_dataset=use_legacy_dataset)
# Result is non-categorical
assert result[0].equals(table[0])
assert result[1].equals(table[1])
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_pandas_categorical_roundtrip(use_legacy_dataset):
# ARROW-5480, this was enabled by ARROW-3246
# Have one of the categories unobserved and include a null (-1)
codes = np.array([2, 0, 0, 2, 0, -1, 2], dtype='int32')
categories = ['foo', 'bar', 'baz']
df = pd.DataFrame({'x': pd.Categorical.from_codes(
codes, categories=categories)})
buf = pa.BufferOutputStream()
pq.write_table(pa.table(df), buf)
result = pq.read_table(
buf.getvalue(), use_legacy_dataset=use_legacy_dataset).to_pandas()
assert result.x.dtype == 'category'
assert (result.x.cat.categories == categories).all()
tm.assert_frame_equal(result, df)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_write_to_dataset_pandas_preserve_extensiondtypes(
tempdir, use_legacy_dataset
):
# ARROW-8251 - preserve pandas extension dtypes in roundtrip
if Version(pd.__version__) < Version("1.0.0"):
pytest.skip("__arrow_array__ added to pandas in 1.0.0")
df = pd.DataFrame({'part': 'a', "col": [1, 2, 3]})
df['col'] = df['col'].astype("Int64")
table = pa.table(df)
pq.write_to_dataset(
table, str(tempdir / "case1"), partition_cols=['part'],
use_legacy_dataset=use_legacy_dataset
)
result = pq.read_table(
str(tempdir / "case1"), use_legacy_dataset=use_legacy_dataset
).to_pandas()
tm.assert_frame_equal(result[["col"]], df[["col"]])
pq.write_to_dataset(
table, str(tempdir / "case2"), use_legacy_dataset=use_legacy_dataset
)
result = pq.read_table(
str(tempdir / "case2"), use_legacy_dataset=use_legacy_dataset
).to_pandas()
tm.assert_frame_equal(result[["col"]], df[["col"]])
pq.write_table(table, str(tempdir / "data.parquet"))
result = pq.read_table(
str(tempdir / "data.parquet"), use_legacy_dataset=use_legacy_dataset
).to_pandas()
tm.assert_frame_equal(result[["col"]], df[["col"]])
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_write_to_dataset_pandas_preserve_index(tempdir, use_legacy_dataset):
# ARROW-8251 - preserve pandas index in roundtrip
df = pd.DataFrame({'part': ['a', 'a', 'b'], "col": [1, 2, 3]})
df.index = pd.Index(['a', 'b', 'c'], name="idx")
table = pa.table(df)
df_cat = df[["col", "part"]].copy()
df_cat["part"] = df_cat["part"].astype("category")
pq.write_to_dataset(
table, str(tempdir / "case1"), partition_cols=['part'],
use_legacy_dataset=use_legacy_dataset
)
result = pq.read_table(
str(tempdir / "case1"), use_legacy_dataset=use_legacy_dataset
).to_pandas()
tm.assert_frame_equal(result, df_cat)
pq.write_to_dataset(
table, str(tempdir / "case2"), use_legacy_dataset=use_legacy_dataset
)
result = pq.read_table(
str(tempdir / "case2"), use_legacy_dataset=use_legacy_dataset
).to_pandas()
tm.assert_frame_equal(result, df)
pq.write_table(table, str(tempdir / "data.parquet"))
result = pq.read_table(
str(tempdir / "data.parquet"), use_legacy_dataset=use_legacy_dataset
).to_pandas()
tm.assert_frame_equal(result, df)
@pytest.mark.pandas
@pytest.mark.parametrize('preserve_index', [True, False, None])
def test_dataset_read_pandas_common_metadata(tempdir, preserve_index):
# ARROW-1103
nfiles = 5
size = 5
dirpath = tempdir / guid()
dirpath.mkdir()
test_data = []
frames = []
paths = []
for i in range(nfiles):
df = _test_dataframe(size, seed=i)
df.index = pd.Index(np.arange(i * size, (i + 1) * size), name='index')
path = dirpath / '{}.parquet'.format(i)
table = pa.Table.from_pandas(df, preserve_index=preserve_index)
# Obliterate metadata
table = table.replace_schema_metadata(None)
assert table.schema.metadata is None
_write_table(table, path)
test_data.append(table)
frames.append(df)
paths.append(path)
# Write _metadata common file
table_for_metadata = pa.Table.from_pandas(
df, preserve_index=preserve_index
)
pq.write_metadata(table_for_metadata.schema, dirpath / '_metadata')
dataset = pq.ParquetDataset(dirpath)
columns = ['uint8', 'strings']
result = dataset.read_pandas(columns=columns).to_pandas()
expected = pd.concat([x[columns] for x in frames])
expected.index.name = (
df.index.name if preserve_index is not False else None)
tm.assert_frame_equal(result, expected)
@pytest.mark.pandas
def test_read_pandas_passthrough_keywords(tempdir):
# ARROW-11464 - previously not all keywords were passed through (such as
# the filesystem keyword)
df = pd.DataFrame({'a': [1, 2, 3]})
filename = tempdir / 'data.parquet'
_write_table(df, filename)
result = pq.read_pandas(
'data.parquet',
filesystem=SubTreeFileSystem(str(tempdir), LocalFileSystem())
)
assert result.equals(pa.table(df))
@pytest.mark.pandas
def test_read_pandas_map_fields(tempdir):
# ARROW-10140 - table created from Pandas with mapping fields
df = pd.DataFrame({
'col1': pd.Series([
[('id', 'something'), ('value2', 'else')],
[('id', 'something2'), ('value', 'else2')],
]),
'col2': pd.Series(['foo', 'bar'])
})
filename = tempdir / 'data.parquet'
udt = pa.map_(pa.string(), pa.string())
schema = pa.schema([pa.field('col1', udt), pa.field('col2', pa.string())])
arrow_table = pa.Table.from_pandas(df, schema)
_write_table(arrow_table, filename)
result = pq.read_pandas(filename).to_pandas()
tm.assert_frame_equal(result, df)
@@ -0,0 +1,274 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import io
import os
import pytest
import pyarrow as pa
try:
import pyarrow.parquet as pq
from pyarrow.tests.parquet.common import _write_table
except ImportError:
pq = None
try:
import pandas as pd
import pandas.testing as tm
from pyarrow.tests.parquet.common import alltypes_sample
except ImportError:
pd = tm = None
@pytest.mark.pandas
def test_pass_separate_metadata():
# ARROW-471
df = alltypes_sample(size=10000)
a_table = pa.Table.from_pandas(df)
buf = io.BytesIO()
_write_table(a_table, buf, compression='snappy', version='2.6')
buf.seek(0)
metadata = pq.read_metadata(buf)
buf.seek(0)
fileh = pq.ParquetFile(buf, metadata=metadata)
tm.assert_frame_equal(df, fileh.read().to_pandas())
@pytest.mark.pandas
def test_read_single_row_group():
# ARROW-471
N, K = 10000, 4
df = alltypes_sample(size=N)
a_table = pa.Table.from_pandas(df)
buf = io.BytesIO()
_write_table(a_table, buf, row_group_size=N / K,
compression='snappy', version='2.6')
buf.seek(0)
pf = pq.ParquetFile(buf)
assert pf.num_row_groups == K
row_groups = [pf.read_row_group(i) for i in range(K)]
result = pa.concat_tables(row_groups)
tm.assert_frame_equal(df, result.to_pandas())
@pytest.mark.pandas
def test_read_single_row_group_with_column_subset():
N, K = 10000, 4
df = alltypes_sample(size=N)
a_table = pa.Table.from_pandas(df)
buf = io.BytesIO()
_write_table(a_table, buf, row_group_size=N / K,
compression='snappy', version='2.6')
buf.seek(0)
pf = pq.ParquetFile(buf)
cols = list(df.columns[:2])
row_groups = [pf.read_row_group(i, columns=cols) for i in range(K)]
result = pa.concat_tables(row_groups)
tm.assert_frame_equal(df[cols], result.to_pandas())
# ARROW-4267: Selection of duplicate columns still leads to these columns
# being read uniquely.
row_groups = [pf.read_row_group(i, columns=cols + cols) for i in range(K)]
result = pa.concat_tables(row_groups)
tm.assert_frame_equal(df[cols], result.to_pandas())
@pytest.mark.pandas
def test_read_multiple_row_groups():
N, K = 10000, 4
df = alltypes_sample(size=N)
a_table = pa.Table.from_pandas(df)
buf = io.BytesIO()
_write_table(a_table, buf, row_group_size=N / K,
compression='snappy', version='2.6')
buf.seek(0)
pf = pq.ParquetFile(buf)
assert pf.num_row_groups == K
result = pf.read_row_groups(range(K))
tm.assert_frame_equal(df, result.to_pandas())
@pytest.mark.pandas
def test_read_multiple_row_groups_with_column_subset():
N, K = 10000, 4
df = alltypes_sample(size=N)
a_table = pa.Table.from_pandas(df)
buf = io.BytesIO()
_write_table(a_table, buf, row_group_size=N / K,
compression='snappy', version='2.6')
buf.seek(0)
pf = pq.ParquetFile(buf)
cols = list(df.columns[:2])
result = pf.read_row_groups(range(K), columns=cols)
tm.assert_frame_equal(df[cols], result.to_pandas())
# ARROW-4267: Selection of duplicate columns still leads to these columns
# being read uniquely.
result = pf.read_row_groups(range(K), columns=cols + cols)
tm.assert_frame_equal(df[cols], result.to_pandas())
@pytest.mark.pandas
def test_scan_contents():
N, K = 10000, 4
df = alltypes_sample(size=N)
a_table = pa.Table.from_pandas(df)
buf = io.BytesIO()
_write_table(a_table, buf, row_group_size=N / K,
compression='snappy', version='2.6')
buf.seek(0)
pf = pq.ParquetFile(buf)
assert pf.scan_contents() == 10000
assert pf.scan_contents(df.columns[:4]) == 10000
def test_parquet_file_pass_directory_instead_of_file(tempdir):
# ARROW-7208
path = tempdir / 'directory'
os.mkdir(str(path))
with pytest.raises(IOError, match="Expected file path"):
pq.ParquetFile(path)
def test_read_column_invalid_index():
table = pa.table([pa.array([4, 5]), pa.array(["foo", "bar"])],
names=['ints', 'strs'])
bio = pa.BufferOutputStream()
pq.write_table(table, bio)
f = pq.ParquetFile(bio.getvalue())
assert f.reader.read_column(0).to_pylist() == [4, 5]
assert f.reader.read_column(1).to_pylist() == ["foo", "bar"]
for index in (-1, 2):
with pytest.raises((ValueError, IndexError)):
f.reader.read_column(index)
@pytest.mark.pandas
@pytest.mark.parametrize('batch_size', [300, 1000, 1300])
def test_iter_batches_columns_reader(tempdir, batch_size):
total_size = 3000
chunk_size = 1000
# TODO: Add categorical support
df = alltypes_sample(size=total_size)
filename = tempdir / 'pandas_roundtrip.parquet'
arrow_table = pa.Table.from_pandas(df)
_write_table(arrow_table, filename, version='2.6',
coerce_timestamps='ms', chunk_size=chunk_size)
file_ = pq.ParquetFile(filename)
for columns in [df.columns[:10], df.columns[10:]]:
batches = file_.iter_batches(batch_size=batch_size, columns=columns)
batch_starts = range(0, total_size+batch_size, batch_size)
for batch, start in zip(batches, batch_starts):
end = min(total_size, start + batch_size)
tm.assert_frame_equal(
batch.to_pandas(),
df.iloc[start:end, :].loc[:, columns].reset_index(drop=True)
)
@pytest.mark.pandas
@pytest.mark.parametrize('chunk_size', [1000])
def test_iter_batches_reader(tempdir, chunk_size):
df = alltypes_sample(size=10000, categorical=True)
filename = tempdir / 'pandas_roundtrip.parquet'
arrow_table = pa.Table.from_pandas(df)
assert arrow_table.schema.pandas_metadata is not None
_write_table(arrow_table, filename, version='2.6',
coerce_timestamps='ms', chunk_size=chunk_size)
file_ = pq.ParquetFile(filename)
def get_all_batches(f):
for row_group in range(f.num_row_groups):
batches = f.iter_batches(
batch_size=900,
row_groups=[row_group],
)
for batch in batches:
yield batch
batches = list(get_all_batches(file_))
batch_no = 0
for i in range(file_.num_row_groups):
tm.assert_frame_equal(
batches[batch_no].to_pandas(),
file_.read_row_groups([i]).to_pandas().head(900)
)
batch_no += 1
tm.assert_frame_equal(
batches[batch_no].to_pandas().reset_index(drop=True),
file_.read_row_groups([i]).to_pandas().iloc[900:].reset_index(
drop=True
)
)
batch_no += 1
@pytest.mark.pandas
@pytest.mark.parametrize('pre_buffer', [False, True])
def test_pre_buffer(pre_buffer):
N, K = 10000, 4
df = alltypes_sample(size=N)
a_table = pa.Table.from_pandas(df)
buf = io.BytesIO()
_write_table(a_table, buf, row_group_size=N / K,
compression='snappy', version='2.6')
buf.seek(0)
pf = pq.ParquetFile(buf, pre_buffer=pre_buffer)
assert pf.read().num_rows == N
@@ -0,0 +1,322 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import pyarrow as pa
from pyarrow import fs
from pyarrow.filesystem import FileSystem, LocalFileSystem
from pyarrow.tests.parquet.common import parametrize_legacy_dataset
try:
import pyarrow.parquet as pq
from pyarrow.tests.parquet.common import _read_table, _test_dataframe
except ImportError:
pq = None
try:
import pandas as pd
import pandas.testing as tm
except ImportError:
pd = tm = None
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_parquet_incremental_file_build(tempdir, use_legacy_dataset):
df = _test_dataframe(100)
df['unique_id'] = 0
arrow_table = pa.Table.from_pandas(df, preserve_index=False)
out = pa.BufferOutputStream()
writer = pq.ParquetWriter(out, arrow_table.schema, version='2.6')
frames = []
for i in range(10):
df['unique_id'] = i
arrow_table = pa.Table.from_pandas(df, preserve_index=False)
writer.write_table(arrow_table)
frames.append(df.copy())
writer.close()
buf = out.getvalue()
result = _read_table(
pa.BufferReader(buf), use_legacy_dataset=use_legacy_dataset)
expected = pd.concat(frames, ignore_index=True)
tm.assert_frame_equal(result.to_pandas(), expected)
def test_validate_schema_write_table(tempdir):
# ARROW-2926
simple_fields = [
pa.field('POS', pa.uint32()),
pa.field('desc', pa.string())
]
simple_schema = pa.schema(simple_fields)
# simple_table schema does not match simple_schema
simple_from_array = [pa.array([1]), pa.array(['bla'])]
simple_table = pa.Table.from_arrays(simple_from_array, ['POS', 'desc'])
path = tempdir / 'simple_validate_schema.parquet'
with pq.ParquetWriter(path, simple_schema,
version='2.6',
compression='snappy', flavor='spark') as w:
with pytest.raises(ValueError):
w.write_table(simple_table)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_parquet_writer_context_obj(tempdir, use_legacy_dataset):
df = _test_dataframe(100)
df['unique_id'] = 0
arrow_table = pa.Table.from_pandas(df, preserve_index=False)
out = pa.BufferOutputStream()
with pq.ParquetWriter(out, arrow_table.schema, version='2.6') as writer:
frames = []
for i in range(10):
df['unique_id'] = i
arrow_table = pa.Table.from_pandas(df, preserve_index=False)
writer.write_table(arrow_table)
frames.append(df.copy())
buf = out.getvalue()
result = _read_table(
pa.BufferReader(buf), use_legacy_dataset=use_legacy_dataset)
expected = pd.concat(frames, ignore_index=True)
tm.assert_frame_equal(result.to_pandas(), expected)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_parquet_writer_context_obj_with_exception(
tempdir, use_legacy_dataset
):
df = _test_dataframe(100)
df['unique_id'] = 0
arrow_table = pa.Table.from_pandas(df, preserve_index=False)
out = pa.BufferOutputStream()
error_text = 'Artificial Error'
try:
with pq.ParquetWriter(out,
arrow_table.schema,
version='2.6') as writer:
frames = []
for i in range(10):
df['unique_id'] = i
arrow_table = pa.Table.from_pandas(df, preserve_index=False)
writer.write_table(arrow_table)
frames.append(df.copy())
if i == 5:
raise ValueError(error_text)
except Exception as e:
assert str(e) == error_text
buf = out.getvalue()
result = _read_table(
pa.BufferReader(buf), use_legacy_dataset=use_legacy_dataset)
expected = pd.concat(frames, ignore_index=True)
tm.assert_frame_equal(result.to_pandas(), expected)
@pytest.mark.pandas
@pytest.mark.parametrize("filesystem", [
None,
LocalFileSystem._get_instance(),
fs.LocalFileSystem(),
])
def test_parquet_writer_write_wrappers(tempdir, filesystem):
df = _test_dataframe(100)
table = pa.Table.from_pandas(df, preserve_index=False)
batch = pa.RecordBatch.from_pandas(df, preserve_index=False)
path_table = str(tempdir / 'data_table.parquet')
path_batch = str(tempdir / 'data_batch.parquet')
with pq.ParquetWriter(
path_table, table.schema, filesystem=filesystem, version='2.6'
) as writer:
writer.write_table(table)
result = _read_table(path_table).to_pandas()
tm.assert_frame_equal(result, df)
with pq.ParquetWriter(
path_batch, table.schema, filesystem=filesystem, version='2.6'
) as writer:
writer.write_batch(batch)
result = _read_table(path_batch).to_pandas()
tm.assert_frame_equal(result, df)
with pq.ParquetWriter(
path_table, table.schema, filesystem=filesystem, version='2.6'
) as writer:
writer.write(table)
result = _read_table(path_table).to_pandas()
tm.assert_frame_equal(result, df)
with pq.ParquetWriter(
path_batch, table.schema, filesystem=filesystem, version='2.6'
) as writer:
writer.write(batch)
result = _read_table(path_batch).to_pandas()
tm.assert_frame_equal(result, df)
@pytest.mark.pandas
@pytest.mark.parametrize("filesystem", [
None,
LocalFileSystem._get_instance(),
fs.LocalFileSystem(),
])
def test_parquet_writer_filesystem_local(tempdir, filesystem):
df = _test_dataframe(100)
table = pa.Table.from_pandas(df, preserve_index=False)
path = str(tempdir / 'data.parquet')
with pq.ParquetWriter(
path, table.schema, filesystem=filesystem, version='2.6'
) as writer:
writer.write_table(table)
result = _read_table(path).to_pandas()
tm.assert_frame_equal(result, df)
@pytest.mark.pandas
@pytest.mark.s3
def test_parquet_writer_filesystem_s3(s3_example_fs):
df = _test_dataframe(100)
table = pa.Table.from_pandas(df, preserve_index=False)
fs, uri, path = s3_example_fs
with pq.ParquetWriter(
path, table.schema, filesystem=fs, version='2.6'
) as writer:
writer.write_table(table)
result = _read_table(uri).to_pandas()
tm.assert_frame_equal(result, df)
@pytest.mark.pandas
@pytest.mark.s3
def test_parquet_writer_filesystem_s3_uri(s3_example_fs):
df = _test_dataframe(100)
table = pa.Table.from_pandas(df, preserve_index=False)
fs, uri, path = s3_example_fs
with pq.ParquetWriter(uri, table.schema, version='2.6') as writer:
writer.write_table(table)
result = _read_table(path, filesystem=fs).to_pandas()
tm.assert_frame_equal(result, df)
@pytest.mark.pandas
@pytest.mark.s3
def test_parquet_writer_filesystem_s3fs(s3_example_s3fs):
df = _test_dataframe(100)
table = pa.Table.from_pandas(df, preserve_index=False)
fs, directory = s3_example_s3fs
path = directory + "/test.parquet"
with pq.ParquetWriter(
path, table.schema, filesystem=fs, version='2.6'
) as writer:
writer.write_table(table)
result = _read_table(path, filesystem=fs).to_pandas()
tm.assert_frame_equal(result, df)
@pytest.mark.pandas
def test_parquet_writer_filesystem_buffer_raises():
df = _test_dataframe(100)
table = pa.Table.from_pandas(df, preserve_index=False)
filesystem = fs.LocalFileSystem()
# Should raise ValueError when filesystem is passed with file-like object
with pytest.raises(ValueError, match="specified path is file-like"):
pq.ParquetWriter(
pa.BufferOutputStream(), table.schema, filesystem=filesystem
)
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_parquet_writer_with_caller_provided_filesystem(use_legacy_dataset):
out = pa.BufferOutputStream()
class CustomFS(FileSystem):
def __init__(self):
self.path = None
self.mode = None
def open(self, path, mode='rb'):
self.path = path
self.mode = mode
return out
fs = CustomFS()
fname = 'expected_fname.parquet'
df = _test_dataframe(100)
table = pa.Table.from_pandas(df, preserve_index=False)
with pq.ParquetWriter(fname, table.schema, filesystem=fs, version='2.6') \
as writer:
writer.write_table(table)
assert fs.path == fname
assert fs.mode == 'wb'
assert out.closed
buf = out.getvalue()
table_read = _read_table(
pa.BufferReader(buf), use_legacy_dataset=use_legacy_dataset)
df_read = table_read.to_pandas()
tm.assert_frame_equal(df_read, df)
# Should raise ValueError when filesystem is passed with file-like object
with pytest.raises(ValueError) as err_info:
pq.ParquetWriter(pa.BufferOutputStream(), table.schema, filesystem=fs)
expected_msg = ("filesystem passed but where is file-like, so"
" there is nothing to open with filesystem.")
assert str(err_info) == expected_msg
@@ -0,0 +1,55 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# distutils: language=c++
# cython: language_level = 3
from pyarrow.lib cimport *
def get_array_length(obj):
# An example function accessing both the pyarrow Cython API
# and the Arrow C++ API
cdef shared_ptr[CArray] arr = pyarrow_unwrap_array(obj)
if arr.get() == NULL:
raise TypeError("not an array")
return arr.get().length()
def make_null_array(length):
# An example function that returns a PyArrow object without PyArrow
# being imported explicitly at the Python level.
cdef shared_ptr[CArray] null_array
null_array.reset(new CNullArray(length))
return pyarrow_wrap_array(null_array)
def cast_scalar(scalar, to_type):
cdef:
shared_ptr[CScalar] c_scalar
shared_ptr[CDataType] c_type
CResult[shared_ptr[CScalar]] c_result
c_scalar = pyarrow_unwrap_scalar(scalar)
if c_scalar.get() == NULL:
raise TypeError("not a scalar")
c_type = pyarrow_unwrap_data_type(to_type)
if c_type.get() == NULL:
raise TypeError("not a type")
c_result = c_scalar.get().CastTo(c_type)
c_scalar = GetResultValue(c_result)
return pyarrow_wrap_scalar(c_scalar)
@@ -0,0 +1,25 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# This file is called from a test in test_ipc.py.
import sys
import pyarrow as pa
with open(sys.argv[1], 'rb') as f:
pa.ipc.open_file(f).read_all().to_pandas()
@@ -0,0 +1,449 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
import sys
import pytest
import hypothesis as h
import hypothesis.strategies as st
import hypothesis.extra.numpy as npst
try:
import hypothesis.extra.pytz as tzst
except ImportError:
tzst = None
try:
import zoneinfo
except ImportError:
zoneinfo = None
if sys.platform == 'win32':
try:
import tzdata # noqa:F401
except ImportError:
zoneinfo = None
import numpy as np
import pyarrow as pa
# TODO(kszucs): alphanum_text, surrogate_text
custom_text = st.text(
alphabet=st.characters(
min_codepoint=0x41,
max_codepoint=0x7E
)
)
null_type = st.just(pa.null())
bool_type = st.just(pa.bool_())
binary_type = st.just(pa.binary())
string_type = st.just(pa.string())
large_binary_type = st.just(pa.large_binary())
large_string_type = st.just(pa.large_string())
fixed_size_binary_type = st.builds(
pa.binary,
st.integers(min_value=0, max_value=16)
)
binary_like_types = st.one_of(
binary_type,
string_type,
large_binary_type,
large_string_type,
fixed_size_binary_type
)
signed_integer_types = st.sampled_from([
pa.int8(),
pa.int16(),
pa.int32(),
pa.int64()
])
unsigned_integer_types = st.sampled_from([
pa.uint8(),
pa.uint16(),
pa.uint32(),
pa.uint64()
])
integer_types = st.one_of(signed_integer_types, unsigned_integer_types)
floating_types = st.sampled_from([
pa.float16(),
pa.float32(),
pa.float64()
])
decimal128_type = st.builds(
pa.decimal128,
precision=st.integers(min_value=1, max_value=38),
scale=st.integers(min_value=1, max_value=38)
)
decimal256_type = st.builds(
pa.decimal256,
precision=st.integers(min_value=1, max_value=76),
scale=st.integers(min_value=1, max_value=76)
)
numeric_types = st.one_of(integer_types, floating_types,
decimal128_type, decimal256_type)
date_types = st.sampled_from([
pa.date32(),
pa.date64()
])
time_types = st.sampled_from([
pa.time32('s'),
pa.time32('ms'),
pa.time64('us'),
pa.time64('ns')
])
if tzst and zoneinfo:
timezones = st.one_of(st.none(), tzst.timezones(), st.timezones())
elif tzst:
timezones = st.one_of(st.none(), tzst.timezones())
elif zoneinfo:
timezones = st.one_of(st.none(), st.timezones())
else:
timezones = st.none()
timestamp_types = st.builds(
pa.timestamp,
unit=st.sampled_from(['s', 'ms', 'us', 'ns']),
tz=timezones
)
duration_types = st.builds(
pa.duration,
st.sampled_from(['s', 'ms', 'us', 'ns'])
)
interval_types = st.just(pa.month_day_nano_interval())
temporal_types = st.one_of(
date_types,
time_types,
timestamp_types,
duration_types,
interval_types
)
primitive_types = st.one_of(
null_type,
bool_type,
numeric_types,
temporal_types,
binary_like_types
)
metadata = st.dictionaries(st.text(), st.text())
@st.composite
def fields(draw, type_strategy=primitive_types):
name = draw(custom_text)
typ = draw(type_strategy)
if pa.types.is_null(typ):
nullable = True
else:
nullable = draw(st.booleans())
meta = draw(metadata)
return pa.field(name, type=typ, nullable=nullable, metadata=meta)
def list_types(item_strategy=primitive_types):
return (
st.builds(pa.list_, item_strategy) |
st.builds(pa.large_list, item_strategy) |
st.builds(
pa.list_,
item_strategy,
st.integers(min_value=0, max_value=16)
)
)
@st.composite
def struct_types(draw, item_strategy=primitive_types):
fields_strategy = st.lists(fields(item_strategy))
fields_rendered = draw(fields_strategy)
field_names = [field.name for field in fields_rendered]
# check that field names are unique, see ARROW-9997
h.assume(len(set(field_names)) == len(field_names))
return pa.struct(fields_rendered)
def dictionary_types(key_strategy=None, value_strategy=None):
key_strategy = key_strategy or signed_integer_types
value_strategy = value_strategy or st.one_of(
bool_type,
integer_types,
st.sampled_from([pa.float32(), pa.float64()]),
binary_type,
string_type,
fixed_size_binary_type,
)
return st.builds(pa.dictionary, key_strategy, value_strategy)
@st.composite
def map_types(draw, key_strategy=primitive_types,
item_strategy=primitive_types):
key_type = draw(key_strategy)
h.assume(not pa.types.is_null(key_type))
value_type = draw(item_strategy)
return pa.map_(key_type, value_type)
# union type
# extension type
def schemas(type_strategy=primitive_types, max_fields=None):
children = st.lists(fields(type_strategy), max_size=max_fields)
return st.builds(pa.schema, children)
all_types = st.deferred(
lambda: (
primitive_types |
list_types() |
struct_types() |
dictionary_types() |
map_types() |
list_types(all_types) |
struct_types(all_types)
)
)
all_fields = fields(all_types)
all_schemas = schemas(all_types)
_default_array_sizes = st.integers(min_value=0, max_value=20)
@st.composite
def _pylist(draw, value_type, size, nullable=True):
arr = draw(arrays(value_type, size=size, nullable=False))
return arr.to_pylist()
@st.composite
def _pymap(draw, key_type, value_type, size, nullable=True):
length = draw(size)
keys = draw(_pylist(key_type, size=length, nullable=False))
values = draw(_pylist(value_type, size=length, nullable=nullable))
return list(zip(keys, values))
@st.composite
def arrays(draw, type, size=None, nullable=True):
if isinstance(type, st.SearchStrategy):
ty = draw(type)
elif isinstance(type, pa.DataType):
ty = type
else:
raise TypeError('Type must be a pyarrow DataType')
if isinstance(size, st.SearchStrategy):
size = draw(size)
elif size is None:
size = draw(_default_array_sizes)
elif not isinstance(size, int):
raise TypeError('Size must be an integer')
if pa.types.is_null(ty):
h.assume(nullable)
value = st.none()
elif pa.types.is_boolean(ty):
value = st.booleans()
elif pa.types.is_integer(ty):
values = draw(npst.arrays(ty.to_pandas_dtype(), shape=(size,)))
return pa.array(values, type=ty)
elif pa.types.is_floating(ty):
values = draw(npst.arrays(ty.to_pandas_dtype(), shape=(size,)))
# Workaround ARROW-4952: no easy way to assert array equality
# in a NaN-tolerant way.
values[np.isnan(values)] = -42.0
return pa.array(values, type=ty)
elif pa.types.is_decimal(ty):
# TODO(kszucs): properly limit the precision
# value = st.decimals(places=type.scale, allow_infinity=False)
h.reject()
elif pa.types.is_time(ty):
value = st.times()
elif pa.types.is_date(ty):
value = st.dates()
elif pa.types.is_timestamp(ty):
if zoneinfo is None:
pytest.skip('no module named zoneinfo (or tzdata on Windows)')
if ty.tz is None:
pytest.skip('requires timezone not None')
min_int64 = -(2**63)
max_int64 = 2**63 - 1
min_datetime = datetime.datetime.fromtimestamp(
min_int64 // 10**9) + datetime.timedelta(hours=12)
max_datetime = datetime.datetime.fromtimestamp(
max_int64 // 10**9) - datetime.timedelta(hours=12)
try:
offset = ty.tz.split(":")
offset_hours = int(offset[0])
offset_min = int(offset[1])
tz = datetime.timedelta(hours=offset_hours, minutes=offset_min)
except ValueError:
tz = zoneinfo.ZoneInfo(ty.tz)
value = st.datetimes(timezones=st.just(tz), min_value=min_datetime,
max_value=max_datetime)
elif pa.types.is_duration(ty):
value = st.timedeltas()
elif pa.types.is_interval(ty):
value = st.timedeltas()
elif pa.types.is_binary(ty) or pa.types.is_large_binary(ty):
value = st.binary()
elif pa.types.is_string(ty) or pa.types.is_large_string(ty):
value = st.text()
elif pa.types.is_fixed_size_binary(ty):
value = st.binary(min_size=ty.byte_width, max_size=ty.byte_width)
elif pa.types.is_list(ty):
value = _pylist(ty.value_type, size=size, nullable=nullable)
elif pa.types.is_large_list(ty):
value = _pylist(ty.value_type, size=size, nullable=nullable)
elif pa.types.is_fixed_size_list(ty):
value = _pylist(ty.value_type, size=ty.list_size, nullable=nullable)
elif pa.types.is_dictionary(ty):
values = _pylist(ty.value_type, size=size, nullable=nullable)
return pa.array(draw(values), type=ty)
elif pa.types.is_map(ty):
value = _pymap(ty.key_type, ty.item_type, size=_default_array_sizes,
nullable=nullable)
elif pa.types.is_struct(ty):
h.assume(len(ty) > 0)
fields, child_arrays = [], []
for field in ty:
fields.append(field)
child_arrays.append(draw(arrays(field.type, size=size)))
return pa.StructArray.from_arrays(child_arrays, fields=fields)
else:
raise NotImplementedError(ty)
if nullable:
value = st.one_of(st.none(), value)
values = st.lists(value, min_size=size, max_size=size)
return pa.array(draw(values), type=ty)
@st.composite
def chunked_arrays(draw, type, min_chunks=0, max_chunks=None, chunk_size=None):
if isinstance(type, st.SearchStrategy):
type = draw(type)
# TODO(kszucs): remove it, field metadata is not kept
h.assume(not pa.types.is_struct(type))
chunk = arrays(type, size=chunk_size)
chunks = st.lists(chunk, min_size=min_chunks, max_size=max_chunks)
return pa.chunked_array(draw(chunks), type=type)
@st.composite
def record_batches(draw, type, rows=None, max_fields=None):
if isinstance(rows, st.SearchStrategy):
rows = draw(rows)
elif rows is None:
rows = draw(_default_array_sizes)
elif not isinstance(rows, int):
raise TypeError('Rows must be an integer')
schema = draw(schemas(type, max_fields=max_fields))
children = [draw(arrays(field.type, size=rows)) for field in schema]
# TODO(kszucs): the names and schema arguments are not consistent with
# Table.from_array's arguments
return pa.RecordBatch.from_arrays(children, names=schema)
@st.composite
def tables(draw, type, rows=None, max_fields=None):
if isinstance(rows, st.SearchStrategy):
rows = draw(rows)
elif rows is None:
rows = draw(_default_array_sizes)
elif not isinstance(rows, int):
raise TypeError('Rows must be an integer')
schema = draw(schemas(type, max_fields=max_fields))
children = [draw(arrays(field.type, size=rows)) for field in schema]
return pa.Table.from_arrays(children, schema=schema)
all_arrays = arrays(all_types)
all_chunked_arrays = chunked_arrays(all_types)
all_record_batches = record_batches(all_types)
all_tables = tables(all_types)
# Define the same rules as above for pandas tests by excluding certain types
# from the generation because of known issues.
pandas_compatible_primitive_types = st.one_of(
null_type,
bool_type,
integer_types,
st.sampled_from([pa.float32(), pa.float64()]),
decimal128_type,
date_types,
time_types,
# Need to exclude timestamp and duration types otherwise hypothesis
# discovers ARROW-10210
# timestamp_types,
# duration_types
interval_types,
binary_type,
string_type,
large_binary_type,
large_string_type,
)
# Need to exclude floating point types otherwise hypothesis discovers
# ARROW-10211
pandas_compatible_dictionary_value_types = st.one_of(
bool_type,
integer_types,
binary_type,
string_type,
fixed_size_binary_type,
)
def pandas_compatible_list_types(
item_strategy=pandas_compatible_primitive_types
):
# Need to exclude fixed size list type otherwise hypothesis discovers
# ARROW-10194
return (
st.builds(pa.list_, item_strategy) |
st.builds(pa.large_list, item_strategy)
)
pandas_compatible_types = st.deferred(
lambda: st.one_of(
pandas_compatible_primitive_types,
pandas_compatible_list_types(pandas_compatible_primitive_types),
struct_types(pandas_compatible_primitive_types),
dictionary_types(
value_strategy=pandas_compatible_dictionary_value_types
),
pandas_compatible_list_types(pandas_compatible_types),
struct_types(pandas_compatible_types)
)
)
@@ -0,0 +1,43 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import numpy as np
import pyarrow as pa
import pyarrow.tests.util as test_util
try:
import pandas as pd
except ImportError:
pass
@pytest.mark.memory_leak
@pytest.mark.pandas
def test_deserialize_pandas_arrow_7956():
df = pd.DataFrame({'a': np.arange(10000),
'b': [test_util.rands(5) for _ in range(10000)]})
def action():
df_bytes = pa.ipc.serialize_pandas(df).to_pybytes()
buf = pa.py_buffer(df_bytes)
pa.ipc.deserialize_pandas(buf)
# Abort at 128MB threshold
test_util.memory_leak_check(action, threshold=1 << 27, iterations=100)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,67 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import weakref
import numpy as np
import pyarrow as pa
from pyarrow.lib import StringBuilder
def test_weakref():
sbuilder = StringBuilder()
wr = weakref.ref(sbuilder)
assert wr() is not None
del sbuilder
assert wr() is None
def test_string_builder_append():
sbuilder = StringBuilder()
sbuilder.append(b"a byte string")
sbuilder.append("a string")
sbuilder.append(np.nan)
sbuilder.append(None)
assert len(sbuilder) == 4
assert sbuilder.null_count == 2
arr = sbuilder.finish()
assert len(sbuilder) == 0
assert isinstance(arr, pa.Array)
assert arr.null_count == 2
assert arr.type == 'str'
expected = ["a byte string", "a string", None, None]
assert arr.to_pylist() == expected
def test_string_builder_append_values():
sbuilder = StringBuilder()
sbuilder.append_values([np.nan, None, "text", None, "other text"])
assert sbuilder.null_count == 3
arr = sbuilder.finish()
assert arr.null_count == 3
expected = [None, None, "text", None, "other text"]
assert arr.to_pylist() == expected
def test_string_builder_append_after_finish():
sbuilder = StringBuilder()
sbuilder.append_values([np.nan, None, "text", None, "other text"])
arr = sbuilder.finish()
sbuilder.append("No effect")
expected = [None, None, "text", None, "other text"]
assert arr.to_pylist() == expected
@@ -0,0 +1,413 @@
# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import gc
import pyarrow as pa
try:
from pyarrow.cffi import ffi
except ImportError:
ffi = None
import pytest
try:
import pandas as pd
import pandas.testing as tm
except ImportError:
pd = tm = None
needs_cffi = pytest.mark.skipif(ffi is None,
reason="test needs cffi package installed")
assert_schema_released = pytest.raises(
ValueError, match="Cannot import released ArrowSchema")
assert_array_released = pytest.raises(
ValueError, match="Cannot import released ArrowArray")
assert_stream_released = pytest.raises(
ValueError, match="Cannot import released ArrowArrayStream")
class ParamExtType(pa.PyExtensionType):
def __init__(self, width):
self._width = width
pa.PyExtensionType.__init__(self, pa.binary(width))
@property
def width(self):
return self._width
def __reduce__(self):
return ParamExtType, (self.width,)
def make_schema():
return pa.schema([('ints', pa.list_(pa.int32()))],
metadata={b'key1': b'value1'})
def make_extension_schema():
return pa.schema([('ext', ParamExtType(3))],
metadata={b'key1': b'value1'})
def make_batch():
return pa.record_batch([[[1], [2, 42]]], make_schema())
def make_extension_batch():
schema = make_extension_schema()
ext_col = schema[0].type.wrap_array(pa.array([b"foo", b"bar"],
type=pa.binary(3)))
return pa.record_batch([ext_col], schema)
def make_batches():
schema = make_schema()
return [
pa.record_batch([[[1], [2, 42]]], schema),
pa.record_batch([[None, [], [5, 6]]], schema),
]
def make_serialized(schema, batches):
with pa.BufferOutputStream() as sink:
with pa.ipc.new_stream(sink, schema) as out:
for batch in batches:
out.write(batch)
return sink.getvalue()
@needs_cffi
def test_export_import_type():
c_schema = ffi.new("struct ArrowSchema*")
ptr_schema = int(ffi.cast("uintptr_t", c_schema))
gc.collect() # Make sure no Arrow data dangles in a ref cycle
old_allocated = pa.total_allocated_bytes()
typ = pa.list_(pa.int32())
typ._export_to_c(ptr_schema)
assert pa.total_allocated_bytes() > old_allocated
# Delete and recreate C++ object from exported pointer
del typ
assert pa.total_allocated_bytes() > old_allocated
typ_new = pa.DataType._import_from_c(ptr_schema)
assert typ_new == pa.list_(pa.int32())
assert pa.total_allocated_bytes() == old_allocated
# Now released
with assert_schema_released:
pa.DataType._import_from_c(ptr_schema)
# Invalid format string
pa.int32()._export_to_c(ptr_schema)
bad_format = ffi.new("char[]", b"zzz")
c_schema.format = bad_format
with pytest.raises(ValueError,
match="Invalid or unsupported format string"):
pa.DataType._import_from_c(ptr_schema)
# Now released
with assert_schema_released:
pa.DataType._import_from_c(ptr_schema)
@needs_cffi
def test_export_import_field():
c_schema = ffi.new("struct ArrowSchema*")
ptr_schema = int(ffi.cast("uintptr_t", c_schema))
gc.collect() # Make sure no Arrow data dangles in a ref cycle
old_allocated = pa.total_allocated_bytes()
field = pa.field("test", pa.list_(pa.int32()), nullable=True)
field._export_to_c(ptr_schema)
assert pa.total_allocated_bytes() > old_allocated
# Delete and recreate C++ object from exported pointer
del field
assert pa.total_allocated_bytes() > old_allocated
field_new = pa.Field._import_from_c(ptr_schema)
assert field_new == pa.field("test", pa.list_(pa.int32()), nullable=True)
assert pa.total_allocated_bytes() == old_allocated
# Now released
with assert_schema_released:
pa.Field._import_from_c(ptr_schema)
@needs_cffi
def test_export_import_array():
c_schema = ffi.new("struct ArrowSchema*")
ptr_schema = int(ffi.cast("uintptr_t", c_schema))
c_array = ffi.new("struct ArrowArray*")
ptr_array = int(ffi.cast("uintptr_t", c_array))
gc.collect() # Make sure no Arrow data dangles in a ref cycle
old_allocated = pa.total_allocated_bytes()
# Type is known up front
typ = pa.list_(pa.int32())
arr = pa.array([[1], [2, 42]], type=typ)
py_value = arr.to_pylist()
arr._export_to_c(ptr_array)
assert pa.total_allocated_bytes() > old_allocated
# Delete recreate C++ object from exported pointer
del arr
arr_new = pa.Array._import_from_c(ptr_array, typ)
assert arr_new.to_pylist() == py_value
assert arr_new.type == pa.list_(pa.int32())
assert pa.total_allocated_bytes() > old_allocated
del arr_new, typ
assert pa.total_allocated_bytes() == old_allocated
# Now released
with assert_array_released:
pa.Array._import_from_c(ptr_array, pa.list_(pa.int32()))
# Type is exported and imported at the same time
arr = pa.array([[1], [2, 42]], type=pa.list_(pa.int32()))
py_value = arr.to_pylist()
arr._export_to_c(ptr_array, ptr_schema)
# Delete and recreate C++ objects from exported pointers
del arr
arr_new = pa.Array._import_from_c(ptr_array, ptr_schema)
assert arr_new.to_pylist() == py_value
assert arr_new.type == pa.list_(pa.int32())
assert pa.total_allocated_bytes() > old_allocated
del arr_new
assert pa.total_allocated_bytes() == old_allocated
# Now released
with assert_schema_released:
pa.Array._import_from_c(ptr_array, ptr_schema)
def check_export_import_schema(schema_factory):
c_schema = ffi.new("struct ArrowSchema*")
ptr_schema = int(ffi.cast("uintptr_t", c_schema))
gc.collect() # Make sure no Arrow data dangles in a ref cycle
old_allocated = pa.total_allocated_bytes()
schema_factory()._export_to_c(ptr_schema)
assert pa.total_allocated_bytes() > old_allocated
# Delete and recreate C++ object from exported pointer
schema_new = pa.Schema._import_from_c(ptr_schema)
assert schema_new == schema_factory()
assert pa.total_allocated_bytes() == old_allocated
del schema_new
assert pa.total_allocated_bytes() == old_allocated
# Now released
with assert_schema_released:
pa.Schema._import_from_c(ptr_schema)
# Not a struct type
pa.int32()._export_to_c(ptr_schema)
with pytest.raises(ValueError,
match="ArrowSchema describes non-struct type"):
pa.Schema._import_from_c(ptr_schema)
# Now released
with assert_schema_released:
pa.Schema._import_from_c(ptr_schema)
@needs_cffi
def test_export_import_schema():
check_export_import_schema(make_schema)
@needs_cffi
def test_export_import_schema_with_extension():
check_export_import_schema(make_extension_schema)
@needs_cffi
def test_export_import_schema_float_pointer():
# Previous versions of the R Arrow library used to pass pointer
# values as a double.
c_schema = ffi.new("struct ArrowSchema*")
ptr_schema = int(ffi.cast("uintptr_t", c_schema))
match = "Passing a pointer value as a float is unsafe"
with pytest.warns(UserWarning, match=match):
make_schema()._export_to_c(float(ptr_schema))
with pytest.warns(UserWarning, match=match):
schema_new = pa.Schema._import_from_c(float(ptr_schema))
assert schema_new == make_schema()
def check_export_import_batch(batch_factory):
c_schema = ffi.new("struct ArrowSchema*")
ptr_schema = int(ffi.cast("uintptr_t", c_schema))
c_array = ffi.new("struct ArrowArray*")
ptr_array = int(ffi.cast("uintptr_t", c_array))
gc.collect() # Make sure no Arrow data dangles in a ref cycle
old_allocated = pa.total_allocated_bytes()
# Schema is known up front
batch = batch_factory()
schema = batch.schema
py_value = batch.to_pydict()
batch._export_to_c(ptr_array)
assert pa.total_allocated_bytes() > old_allocated
# Delete and recreate C++ object from exported pointer
del batch
batch_new = pa.RecordBatch._import_from_c(ptr_array, schema)
assert batch_new.to_pydict() == py_value
assert batch_new.schema == schema
assert pa.total_allocated_bytes() > old_allocated
del batch_new, schema
assert pa.total_allocated_bytes() == old_allocated
# Now released
with assert_array_released:
pa.RecordBatch._import_from_c(ptr_array, make_schema())
# Type is exported and imported at the same time
batch = batch_factory()
py_value = batch.to_pydict()
batch._export_to_c(ptr_array, ptr_schema)
# Delete and recreate C++ objects from exported pointers
del batch
batch_new = pa.RecordBatch._import_from_c(ptr_array, ptr_schema)
assert batch_new.to_pydict() == py_value
assert batch_new.schema == batch_factory().schema
assert pa.total_allocated_bytes() > old_allocated
del batch_new
assert pa.total_allocated_bytes() == old_allocated
# Now released
with assert_schema_released:
pa.RecordBatch._import_from_c(ptr_array, ptr_schema)
# Not a struct type
pa.int32()._export_to_c(ptr_schema)
batch_factory()._export_to_c(ptr_array)
with pytest.raises(ValueError,
match="ArrowSchema describes non-struct type"):
pa.RecordBatch._import_from_c(ptr_array, ptr_schema)
# Now released
with assert_schema_released:
pa.RecordBatch._import_from_c(ptr_array, ptr_schema)
@needs_cffi
def test_export_import_batch():
check_export_import_batch(make_batch)
@needs_cffi
def test_export_import_batch_with_extension():
check_export_import_batch(make_extension_batch)
def _export_import_batch_reader(ptr_stream, reader_factory):
# Prepare input
batches = make_batches()
schema = batches[0].schema
reader = reader_factory(schema, batches)
reader._export_to_c(ptr_stream)
# Delete and recreate C++ object from exported pointer
del reader, batches
reader_new = pa.RecordBatchReader._import_from_c(ptr_stream)
assert reader_new.schema == schema
got_batches = list(reader_new)
del reader_new
assert got_batches == make_batches()
# Test read_pandas()
if pd is not None:
batches = make_batches()
schema = batches[0].schema
expected_df = pa.Table.from_batches(batches).to_pandas()
reader = reader_factory(schema, batches)
reader._export_to_c(ptr_stream)
del reader, batches
reader_new = pa.RecordBatchReader._import_from_c(ptr_stream)
got_df = reader_new.read_pandas()
del reader_new
tm.assert_frame_equal(expected_df, got_df)
def make_ipc_stream_reader(schema, batches):
return pa.ipc.open_stream(make_serialized(schema, batches))
def make_py_record_batch_reader(schema, batches):
return pa.RecordBatchReader.from_batches(schema, batches)
@needs_cffi
@pytest.mark.parametrize('reader_factory',
[make_ipc_stream_reader,
make_py_record_batch_reader])
def test_export_import_batch_reader(reader_factory):
c_stream = ffi.new("struct ArrowArrayStream*")
ptr_stream = int(ffi.cast("uintptr_t", c_stream))
gc.collect() # Make sure no Arrow data dangles in a ref cycle
old_allocated = pa.total_allocated_bytes()
_export_import_batch_reader(ptr_stream, reader_factory)
assert pa.total_allocated_bytes() == old_allocated
# Now released
with assert_stream_released:
pa.RecordBatchReader._import_from_c(ptr_stream)
@needs_cffi
def test_imported_batch_reader_error():
c_stream = ffi.new("struct ArrowArrayStream*")
ptr_stream = int(ffi.cast("uintptr_t", c_stream))
schema = pa.schema([('foo', pa.int32())])
batches = [pa.record_batch([[1, 2, 3]], schema=schema),
pa.record_batch([[4, 5, 6]], schema=schema)]
buf = make_serialized(schema, batches)
# Open a corrupt/incomplete stream and export it
reader = pa.ipc.open_stream(buf[:-16])
reader._export_to_c(ptr_stream)
del reader
reader_new = pa.RecordBatchReader._import_from_c(ptr_stream)
batch = reader_new.read_next_batch()
assert batch == batches[0]
with pytest.raises(OSError,
match="Expected to be able to read 16 bytes "
"for message body, got 8"):
reader_new.read_next_batch()
# Again, but call read_all()
reader = pa.ipc.open_stream(buf[:-16])
reader._export_to_c(ptr_stream)
del reader
reader_new = pa.RecordBatchReader._import_from_c(ptr_stream)
with pytest.raises(OSError,
match="Expected to be able to read 16 bytes "
"for message body, got 8"):
reader_new.read_all()
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,792 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
UNTESTED:
read_message
"""
import sys
import sysconfig
import pytest
import pyarrow as pa
import numpy as np
cuda = pytest.importorskip("pyarrow.cuda")
platform = sysconfig.get_platform()
# TODO: enable ppc64 when Arrow C++ supports IPC in ppc64 systems:
has_ipc_support = platform == 'linux-x86_64' # or 'ppc64' in platform
cuda_ipc = pytest.mark.skipif(
not has_ipc_support,
reason='CUDA IPC not supported in platform `%s`' % (platform))
global_context = None # for flake8
global_context1 = None # for flake8
def setup_module(module):
module.global_context = cuda.Context(0)
module.global_context1 = cuda.Context(cuda.Context.get_num_devices() - 1)
def teardown_module(module):
del module.global_context
def test_Context():
assert cuda.Context.get_num_devices() > 0
assert global_context.device_number == 0
assert global_context1.device_number == cuda.Context.get_num_devices() - 1
with pytest.raises(ValueError,
match=("device_number argument must "
"be non-negative less than")):
cuda.Context(cuda.Context.get_num_devices())
@pytest.mark.parametrize("size", [0, 1, 1000])
def test_manage_allocate_free_host(size):
buf = cuda.new_host_buffer(size)
arr = np.frombuffer(buf, dtype=np.uint8)
arr[size//4:3*size//4] = 1
arr_cp = arr.copy()
arr2 = np.frombuffer(buf, dtype=np.uint8)
np.testing.assert_equal(arr2, arr_cp)
assert buf.size == size
def test_context_allocate_del():
bytes_allocated = global_context.bytes_allocated
cudabuf = global_context.new_buffer(128)
assert global_context.bytes_allocated == bytes_allocated + 128
del cudabuf
assert global_context.bytes_allocated == bytes_allocated
def make_random_buffer(size, target='host'):
"""Return a host or device buffer with random data.
"""
if target == 'host':
assert size >= 0
buf = pa.allocate_buffer(size)
assert buf.size == size
arr = np.frombuffer(buf, dtype=np.uint8)
assert arr.size == size
arr[:] = np.random.randint(low=1, high=255, size=size, dtype=np.uint8)
assert arr.sum() > 0 or size == 0
arr_ = np.frombuffer(buf, dtype=np.uint8)
np.testing.assert_equal(arr, arr_)
return arr, buf
elif target == 'device':
arr, buf = make_random_buffer(size, target='host')
dbuf = global_context.new_buffer(size)
assert dbuf.size == size
dbuf.copy_from_host(buf, position=0, nbytes=size)
return arr, dbuf
raise ValueError('invalid target value')
@pytest.mark.parametrize("size", [0, 1, 1000])
def test_context_device_buffer(size):
# Creating device buffer from host buffer;
arr, buf = make_random_buffer(size)
cudabuf = global_context.buffer_from_data(buf)
assert cudabuf.size == size
arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
np.testing.assert_equal(arr, arr2)
# CudaBuffer does not support buffer protocol
with pytest.raises(BufferError):
memoryview(cudabuf)
# Creating device buffer from array:
cudabuf = global_context.buffer_from_data(arr)
assert cudabuf.size == size
arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
np.testing.assert_equal(arr, arr2)
# Creating device buffer from bytes:
cudabuf = global_context.buffer_from_data(arr.tobytes())
assert cudabuf.size == size
arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
np.testing.assert_equal(arr, arr2)
# Creating a device buffer from another device buffer, view:
cudabuf2 = cudabuf.slice(0, cudabuf.size)
assert cudabuf2.size == size
arr2 = np.frombuffer(cudabuf2.copy_to_host(), dtype=np.uint8)
np.testing.assert_equal(arr, arr2)
if size > 1:
cudabuf2.copy_from_host(arr[size//2:])
arr3 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
np.testing.assert_equal(np.concatenate((arr[size//2:], arr[size//2:])),
arr3)
cudabuf2.copy_from_host(arr[:size//2]) # restoring arr
# Creating a device buffer from another device buffer, copy:
cudabuf2 = global_context.buffer_from_data(cudabuf)
assert cudabuf2.size == size
arr2 = np.frombuffer(cudabuf2.copy_to_host(), dtype=np.uint8)
np.testing.assert_equal(arr, arr2)
cudabuf2.copy_from_host(arr[size//2:])
arr3 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
np.testing.assert_equal(arr, arr3)
# Slice of a device buffer
cudabuf2 = cudabuf.slice(0, cudabuf.size+10)
assert cudabuf2.size == size
arr2 = np.frombuffer(cudabuf2.copy_to_host(), dtype=np.uint8)
np.testing.assert_equal(arr, arr2)
cudabuf2 = cudabuf.slice(size//4, size+10)
assert cudabuf2.size == size - size//4
arr2 = np.frombuffer(cudabuf2.copy_to_host(), dtype=np.uint8)
np.testing.assert_equal(arr[size//4:], arr2)
# Creating a device buffer from a slice of host buffer
soffset = size//4
ssize = 2*size//4
cudabuf = global_context.buffer_from_data(buf, offset=soffset,
size=ssize)
assert cudabuf.size == ssize
arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
np.testing.assert_equal(arr[soffset:soffset + ssize], arr2)
cudabuf = global_context.buffer_from_data(buf.slice(offset=soffset,
length=ssize))
assert cudabuf.size == ssize
arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
np.testing.assert_equal(arr[soffset:soffset + ssize], arr2)
# Creating a device buffer from a slice of an array
cudabuf = global_context.buffer_from_data(arr, offset=soffset, size=ssize)
assert cudabuf.size == ssize
arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
np.testing.assert_equal(arr[soffset:soffset + ssize], arr2)
cudabuf = global_context.buffer_from_data(arr[soffset:soffset+ssize])
assert cudabuf.size == ssize
arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
np.testing.assert_equal(arr[soffset:soffset + ssize], arr2)
# Creating a device buffer from a slice of bytes
cudabuf = global_context.buffer_from_data(arr.tobytes(),
offset=soffset,
size=ssize)
assert cudabuf.size == ssize
arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
np.testing.assert_equal(arr[soffset:soffset + ssize], arr2)
# Creating a device buffer from size
cudabuf = global_context.new_buffer(size)
assert cudabuf.size == size
# Creating device buffer from a slice of another device buffer:
cudabuf = global_context.buffer_from_data(arr)
cudabuf2 = cudabuf.slice(soffset, ssize)
assert cudabuf2.size == ssize
arr2 = np.frombuffer(cudabuf2.copy_to_host(), dtype=np.uint8)
np.testing.assert_equal(arr[soffset:soffset+ssize], arr2)
# Creating device buffer from HostBuffer
buf = cuda.new_host_buffer(size)
arr_ = np.frombuffer(buf, dtype=np.uint8)
arr_[:] = arr
cudabuf = global_context.buffer_from_data(buf)
assert cudabuf.size == size
arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
np.testing.assert_equal(arr, arr2)
# Creating device buffer from HostBuffer slice
cudabuf = global_context.buffer_from_data(buf, offset=soffset, size=ssize)
assert cudabuf.size == ssize
arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
np.testing.assert_equal(arr[soffset:soffset+ssize], arr2)
cudabuf = global_context.buffer_from_data(
buf.slice(offset=soffset, length=ssize))
assert cudabuf.size == ssize
arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
np.testing.assert_equal(arr[soffset:soffset+ssize], arr2)
@pytest.mark.parametrize("size", [0, 1, 1000])
def test_context_from_object(size):
ctx = global_context
arr, cbuf = make_random_buffer(size, target='device')
dtype = arr.dtype
# Creating device buffer from a CUDA host buffer
hbuf = cuda.new_host_buffer(size * arr.dtype.itemsize)
np.frombuffer(hbuf, dtype=dtype)[:] = arr
cbuf2 = ctx.buffer_from_object(hbuf)
assert cbuf2.size == cbuf.size
arr2 = np.frombuffer(cbuf2.copy_to_host(), dtype=dtype)
np.testing.assert_equal(arr, arr2)
# Creating device buffer from a device buffer
cbuf2 = ctx.buffer_from_object(cbuf2)
assert cbuf2.size == cbuf.size
arr2 = np.frombuffer(cbuf2.copy_to_host(), dtype=dtype)
np.testing.assert_equal(arr, arr2)
# Trying to create a device buffer from a Buffer
with pytest.raises(pa.ArrowTypeError,
match=('buffer is not backed by a CudaBuffer')):
ctx.buffer_from_object(pa.py_buffer(b"123"))
# Trying to create a device buffer from numpy.array
with pytest.raises(pa.ArrowTypeError,
match=("cannot create device buffer view from "
".* \'numpy.ndarray\'")):
ctx.buffer_from_object(np.array([1, 2, 3]))
def test_foreign_buffer():
ctx = global_context
dtype = np.dtype(np.uint8)
size = 10
hbuf = cuda.new_host_buffer(size * dtype.itemsize)
# test host buffer memory reference counting
rc = sys.getrefcount(hbuf)
fbuf = ctx.foreign_buffer(hbuf.address, hbuf.size, hbuf)
assert sys.getrefcount(hbuf) == rc + 1
del fbuf
assert sys.getrefcount(hbuf) == rc
# test postponed deallocation of host buffer memory
fbuf = ctx.foreign_buffer(hbuf.address, hbuf.size, hbuf)
del hbuf
fbuf.copy_to_host()
# test deallocating the host buffer memory making it inaccessible
hbuf = cuda.new_host_buffer(size * dtype.itemsize)
fbuf = ctx.foreign_buffer(hbuf.address, hbuf.size)
del hbuf
with pytest.raises(pa.ArrowIOError,
match=('Cuda error ')):
fbuf.copy_to_host()
@pytest.mark.parametrize("size", [0, 1, 1000])
def test_CudaBuffer(size):
arr, buf = make_random_buffer(size)
assert arr.tobytes() == buf.to_pybytes()
cbuf = global_context.buffer_from_data(buf)
assert cbuf.size == size
assert not cbuf.is_cpu
assert arr.tobytes() == cbuf.to_pybytes()
if size > 0:
assert cbuf.address > 0
for i in range(size):
assert cbuf[i] == arr[i]
for s in [
slice(None),
slice(size//4, size//2),
]:
assert cbuf[s].to_pybytes() == arr[s].tobytes()
sbuf = cbuf.slice(size//4, size//2)
assert sbuf.parent == cbuf
with pytest.raises(TypeError,
match="Do not call CudaBuffer's constructor directly"):
cuda.CudaBuffer()
@pytest.mark.parametrize("size", [0, 1, 1000])
def test_HostBuffer(size):
arr, buf = make_random_buffer(size)
assert arr.tobytes() == buf.to_pybytes()
hbuf = cuda.new_host_buffer(size)
np.frombuffer(hbuf, dtype=np.uint8)[:] = arr
assert hbuf.size == size
assert hbuf.is_cpu
assert arr.tobytes() == hbuf.to_pybytes()
for i in range(size):
assert hbuf[i] == arr[i]
for s in [
slice(None),
slice(size//4, size//2),
]:
assert hbuf[s].to_pybytes() == arr[s].tobytes()
sbuf = hbuf.slice(size//4, size//2)
assert sbuf.parent == hbuf
del hbuf
with pytest.raises(TypeError,
match="Do not call HostBuffer's constructor directly"):
cuda.HostBuffer()
@pytest.mark.parametrize("size", [0, 1, 1000])
def test_copy_from_to_host(size):
# Create a buffer in host containing range(size)
buf = pa.allocate_buffer(size, resizable=True) # in host
assert isinstance(buf, pa.Buffer)
assert not isinstance(buf, cuda.CudaBuffer)
arr = np.frombuffer(buf, dtype=np.uint8)
assert arr.size == size
arr[:] = range(size)
arr_ = np.frombuffer(buf, dtype=np.uint8)
np.testing.assert_equal(arr, arr_)
device_buffer = global_context.new_buffer(size)
assert isinstance(device_buffer, cuda.CudaBuffer)
assert isinstance(device_buffer, pa.Buffer)
assert device_buffer.size == size
assert not device_buffer.is_cpu
device_buffer.copy_from_host(buf, position=0, nbytes=size)
buf2 = device_buffer.copy_to_host(position=0, nbytes=size)
arr2 = np.frombuffer(buf2, dtype=np.uint8)
np.testing.assert_equal(arr, arr2)
@pytest.mark.parametrize("size", [0, 1, 1000])
def test_copy_to_host(size):
arr, dbuf = make_random_buffer(size, target='device')
buf = dbuf.copy_to_host()
assert buf.is_cpu
np.testing.assert_equal(arr, np.frombuffer(buf, dtype=np.uint8))
buf = dbuf.copy_to_host(position=size//4)
assert buf.is_cpu
np.testing.assert_equal(arr[size//4:], np.frombuffer(buf, dtype=np.uint8))
buf = dbuf.copy_to_host(position=size//4, nbytes=size//8)
assert buf.is_cpu
np.testing.assert_equal(arr[size//4:size//4+size//8],
np.frombuffer(buf, dtype=np.uint8))
buf = dbuf.copy_to_host(position=size//4, nbytes=0)
assert buf.is_cpu
assert buf.size == 0
for (position, nbytes) in [
(size+2, -1), (-2, -1), (size+1, 0), (-3, 0),
]:
with pytest.raises(ValueError,
match='position argument is out-of-range'):
dbuf.copy_to_host(position=position, nbytes=nbytes)
for (position, nbytes) in [
(0, size+1), (size//2, (size+1)//2+1), (size, 1)
]:
with pytest.raises(ValueError,
match=('requested more to copy than'
' available from device buffer')):
dbuf.copy_to_host(position=position, nbytes=nbytes)
buf = pa.allocate_buffer(size//4)
dbuf.copy_to_host(buf=buf)
np.testing.assert_equal(arr[:size//4], np.frombuffer(buf, dtype=np.uint8))
if size < 12:
return
dbuf.copy_to_host(buf=buf, position=12)
np.testing.assert_equal(arr[12:12+size//4],
np.frombuffer(buf, dtype=np.uint8))
dbuf.copy_to_host(buf=buf, nbytes=12)
np.testing.assert_equal(arr[:12], np.frombuffer(buf, dtype=np.uint8)[:12])
dbuf.copy_to_host(buf=buf, nbytes=12, position=6)
np.testing.assert_equal(arr[6:6+12],
np.frombuffer(buf, dtype=np.uint8)[:12])
for (position, nbytes) in [
(0, size+10), (10, size-5),
(0, size//2), (size//4, size//4+1)
]:
with pytest.raises(ValueError,
match=('requested copy does not '
'fit into host buffer')):
dbuf.copy_to_host(buf=buf, position=position, nbytes=nbytes)
@pytest.mark.parametrize("dest_ctx", ['same', 'another'])
@pytest.mark.parametrize("size", [0, 1, 1000])
def test_copy_from_device(dest_ctx, size):
arr, buf = make_random_buffer(size=size, target='device')
lst = arr.tolist()
if dest_ctx == 'another':
dest_ctx = global_context1
if buf.context.device_number == dest_ctx.device_number:
pytest.skip("not a multi-GPU system")
else:
dest_ctx = buf.context
dbuf = dest_ctx.new_buffer(size)
def put(*args, **kwargs):
dbuf.copy_from_device(buf, *args, **kwargs)
rbuf = dbuf.copy_to_host()
return np.frombuffer(rbuf, dtype=np.uint8).tolist()
assert put() == lst
if size > 4:
assert put(position=size//4) == lst[:size//4]+lst[:-size//4]
assert put() == lst
assert put(position=1, nbytes=size//2) == \
lst[:1] + lst[:size//2] + lst[-(size-size//2-1):]
for (position, nbytes) in [
(size+2, -1), (-2, -1), (size+1, 0), (-3, 0),
]:
with pytest.raises(ValueError,
match='position argument is out-of-range'):
put(position=position, nbytes=nbytes)
for (position, nbytes) in [
(0, size+1),
]:
with pytest.raises(ValueError,
match=('requested more to copy than'
' available from device buffer')):
put(position=position, nbytes=nbytes)
if size < 4:
return
for (position, nbytes) in [
(size//2, (size+1)//2+1)
]:
with pytest.raises(ValueError,
match=('requested more to copy than'
' available in device buffer')):
put(position=position, nbytes=nbytes)
@pytest.mark.parametrize("size", [0, 1, 1000])
def test_copy_from_host(size):
arr, buf = make_random_buffer(size=size, target='host')
lst = arr.tolist()
dbuf = global_context.new_buffer(size)
def put(*args, **kwargs):
dbuf.copy_from_host(buf, *args, **kwargs)
rbuf = dbuf.copy_to_host()
return np.frombuffer(rbuf, dtype=np.uint8).tolist()
assert put() == lst
if size > 4:
assert put(position=size//4) == lst[:size//4]+lst[:-size//4]
assert put() == lst
assert put(position=1, nbytes=size//2) == \
lst[:1] + lst[:size//2] + lst[-(size-size//2-1):]
for (position, nbytes) in [
(size+2, -1), (-2, -1), (size+1, 0), (-3, 0),
]:
with pytest.raises(ValueError,
match='position argument is out-of-range'):
put(position=position, nbytes=nbytes)
for (position, nbytes) in [
(0, size+1),
]:
with pytest.raises(ValueError,
match=('requested more to copy than'
' available from host buffer')):
put(position=position, nbytes=nbytes)
if size < 4:
return
for (position, nbytes) in [
(size//2, (size+1)//2+1)
]:
with pytest.raises(ValueError,
match=('requested more to copy than'
' available in device buffer')):
put(position=position, nbytes=nbytes)
def test_BufferWriter():
def allocate(size):
cbuf = global_context.new_buffer(size)
writer = cuda.BufferWriter(cbuf)
return cbuf, writer
def test_writes(total_size, chunksize, buffer_size=0):
cbuf, writer = allocate(total_size)
arr, buf = make_random_buffer(size=total_size, target='host')
if buffer_size > 0:
writer.buffer_size = buffer_size
position = writer.tell()
assert position == 0
writer.write(buf.slice(length=chunksize))
assert writer.tell() == chunksize
writer.seek(0)
position = writer.tell()
assert position == 0
while position < total_size:
bytes_to_write = min(chunksize, total_size - position)
writer.write(buf.slice(offset=position, length=bytes_to_write))
position += bytes_to_write
writer.flush()
assert cbuf.size == total_size
cbuf.context.synchronize()
buf2 = cbuf.copy_to_host()
cbuf.context.synchronize()
assert buf2.size == total_size
arr2 = np.frombuffer(buf2, dtype=np.uint8)
np.testing.assert_equal(arr, arr2)
total_size, chunk_size = 1 << 16, 1000
test_writes(total_size, chunk_size)
test_writes(total_size, chunk_size, total_size // 16)
cbuf, writer = allocate(100)
writer.write(np.arange(100, dtype=np.uint8))
writer.writeat(50, np.arange(25, dtype=np.uint8))
writer.write(np.arange(25, dtype=np.uint8))
writer.flush()
arr = np.frombuffer(cbuf.copy_to_host(), np.uint8)
np.testing.assert_equal(arr[:50], np.arange(50, dtype=np.uint8))
np.testing.assert_equal(arr[50:75], np.arange(25, dtype=np.uint8))
np.testing.assert_equal(arr[75:], np.arange(25, dtype=np.uint8))
def test_BufferWriter_edge_cases():
# edge cases, see cuda-test.cc for more information:
size = 1000
cbuf = global_context.new_buffer(size)
writer = cuda.BufferWriter(cbuf)
arr, buf = make_random_buffer(size=size, target='host')
assert writer.buffer_size == 0
writer.buffer_size = 100
assert writer.buffer_size == 100
writer.write(buf.slice(length=0))
assert writer.tell() == 0
writer.write(buf.slice(length=10))
writer.buffer_size = 200
assert writer.buffer_size == 200
assert writer.num_bytes_buffered == 0
writer.write(buf.slice(offset=10, length=300))
assert writer.num_bytes_buffered == 0
writer.write(buf.slice(offset=310, length=200))
assert writer.num_bytes_buffered == 0
writer.write(buf.slice(offset=510, length=390))
writer.write(buf.slice(offset=900, length=100))
writer.flush()
buf2 = cbuf.copy_to_host()
assert buf2.size == size
arr2 = np.frombuffer(buf2, dtype=np.uint8)
np.testing.assert_equal(arr, arr2)
def test_BufferReader():
size = 1000
arr, cbuf = make_random_buffer(size=size, target='device')
reader = cuda.BufferReader(cbuf)
reader.seek(950)
assert reader.tell() == 950
data = reader.read(100)
assert len(data) == 50
assert reader.tell() == 1000
reader.seek(925)
arr2 = np.zeros(100, dtype=np.uint8)
n = reader.readinto(arr2)
assert n == 75
assert reader.tell() == 1000
np.testing.assert_equal(arr[925:], arr2[:75])
reader.seek(0)
assert reader.tell() == 0
buf2 = reader.read_buffer()
arr2 = np.frombuffer(buf2.copy_to_host(), dtype=np.uint8)
np.testing.assert_equal(arr, arr2)
def test_BufferReader_zero_size():
arr, cbuf = make_random_buffer(size=0, target='device')
reader = cuda.BufferReader(cbuf)
reader.seek(0)
data = reader.read()
assert len(data) == 0
assert reader.tell() == 0
buf2 = reader.read_buffer()
arr2 = np.frombuffer(buf2.copy_to_host(), dtype=np.uint8)
np.testing.assert_equal(arr, arr2)
def make_recordbatch(length):
schema = pa.schema([pa.field('f0', pa.int16()),
pa.field('f1', pa.int16())])
a0 = pa.array(np.random.randint(0, 255, size=length, dtype=np.int16))
a1 = pa.array(np.random.randint(0, 255, size=length, dtype=np.int16))
batch = pa.record_batch([a0, a1], schema=schema)
return batch
def test_batch_serialize():
batch = make_recordbatch(10)
hbuf = batch.serialize()
cbuf = cuda.serialize_record_batch(batch, global_context)
# Test that read_record_batch works properly
cbatch = cuda.read_record_batch(cbuf, batch.schema)
assert isinstance(cbatch, pa.RecordBatch)
assert batch.schema == cbatch.schema
assert batch.num_columns == cbatch.num_columns
assert batch.num_rows == cbatch.num_rows
# Deserialize CUDA-serialized batch on host
buf = cbuf.copy_to_host()
assert hbuf.equals(buf)
batch2 = pa.ipc.read_record_batch(buf, batch.schema)
assert hbuf.equals(batch2.serialize())
assert batch.num_columns == batch2.num_columns
assert batch.num_rows == batch2.num_rows
assert batch.column(0).equals(batch2.column(0))
assert batch.equals(batch2)
def make_table():
a0 = pa.array([0, 1, 42, None], type=pa.int16())
a1 = pa.array([[0, 1], [2], [], None], type=pa.list_(pa.int32()))
a2 = pa.array([("ab", True), ("cde", False), (None, None), None],
type=pa.struct([("strs", pa.utf8()),
("bools", pa.bool_())]))
# Dictionaries are validated on the IPC read path, but that can produce
# issues for GPU-located dictionaries. Check that they work fine.
a3 = pa.DictionaryArray.from_arrays(
indices=[0, 1, 1, None],
dictionary=pa.array(['foo', 'bar']))
a4 = pa.DictionaryArray.from_arrays(
indices=[2, 1, 2, None],
dictionary=a1)
a5 = pa.DictionaryArray.from_arrays(
indices=[2, 1, 0, None],
dictionary=a2)
arrays = [a0, a1, a2, a3, a4, a5]
schema = pa.schema([('f{}'.format(i), arr.type)
for i, arr in enumerate(arrays)])
batch = pa.record_batch(arrays, schema=schema)
table = pa.Table.from_batches([batch])
return table
def make_table_cuda():
htable = make_table()
# Serialize the host table to bytes
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, htable.schema) as out:
out.write_table(htable)
hbuf = pa.py_buffer(sink.getvalue().to_pybytes())
# Copy the host bytes to a device buffer
dbuf = global_context.new_buffer(len(hbuf))
dbuf.copy_from_host(hbuf, nbytes=len(hbuf))
# Deserialize the device buffer into a Table
dtable = pa.ipc.open_stream(cuda.BufferReader(dbuf)).read_all()
return hbuf, htable, dbuf, dtable
def test_table_deserialize():
# ARROW-9659: make sure that we can deserialize a GPU-located table
# without crashing when initializing or validating the underlying arrays.
hbuf, htable, dbuf, dtable = make_table_cuda()
# Assert basic fields the same between host and device tables
assert htable.schema == dtable.schema
assert htable.num_rows == dtable.num_rows
assert htable.num_columns == dtable.num_columns
# Assert byte-level equality
assert hbuf.equals(dbuf.copy_to_host())
# Copy DtoH and assert the tables are still equivalent
assert htable.equals(pa.ipc.open_stream(
dbuf.copy_to_host()
).read_all())
def test_create_table_with_device_buffers():
# ARROW-11872: make sure that we can create an Arrow Table from
# GPU-located Arrays without crashing.
hbuf, htable, dbuf, dtable = make_table_cuda()
# Construct a new Table from the device Table
dtable2 = pa.Table.from_arrays(dtable.columns, dtable.column_names)
# Assert basic fields the same between host and device tables
assert htable.schema == dtable2.schema
assert htable.num_rows == dtable2.num_rows
assert htable.num_columns == dtable2.num_columns
# Assert byte-level equality
assert hbuf.equals(dbuf.copy_to_host())
# Copy DtoH and assert the tables are still equivalent
assert htable.equals(pa.ipc.open_stream(
dbuf.copy_to_host()
).read_all())
def other_process_for_test_IPC(handle_buffer, expected_arr):
other_context = pa.cuda.Context(0)
ipc_handle = pa.cuda.IpcMemHandle.from_buffer(handle_buffer)
ipc_buf = other_context.open_ipc_buffer(ipc_handle)
ipc_buf.context.synchronize()
buf = ipc_buf.copy_to_host()
assert buf.size == expected_arr.size, repr((buf.size, expected_arr.size))
arr = np.frombuffer(buf, dtype=expected_arr.dtype)
np.testing.assert_equal(arr, expected_arr)
@cuda_ipc
@pytest.mark.parametrize("size", [0, 1, 1000])
def test_IPC(size):
import multiprocessing
ctx = multiprocessing.get_context('spawn')
arr, cbuf = make_random_buffer(size=size, target='device')
ipc_handle = cbuf.export_for_ipc()
handle_buffer = ipc_handle.serialize()
p = ctx.Process(target=other_process_for_test_IPC,
args=(handle_buffer, arr))
p.start()
p.join()
assert p.exitcode == 0
@@ -0,0 +1,235 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import pyarrow as pa
import numpy as np
dtypes = ['uint8', 'int16', 'float32']
cuda = pytest.importorskip("pyarrow.cuda")
nb_cuda = pytest.importorskip("numba.cuda")
from numba.cuda.cudadrv.devicearray import DeviceNDArray # noqa: E402
context_choices = None
context_choice_ids = ['pyarrow.cuda', 'numba.cuda']
def setup_module(module):
np.random.seed(1234)
ctx1 = cuda.Context()
nb_ctx1 = ctx1.to_numba()
nb_ctx2 = nb_cuda.current_context()
ctx2 = cuda.Context.from_numba(nb_ctx2)
module.context_choices = [(ctx1, nb_ctx1), (ctx2, nb_ctx2)]
def teardown_module(module):
del module.context_choices
@pytest.mark.parametrize("c", range(len(context_choice_ids)),
ids=context_choice_ids)
def test_context(c):
ctx, nb_ctx = context_choices[c]
assert ctx.handle == nb_ctx.handle.value
assert ctx.handle == ctx.to_numba().handle.value
ctx2 = cuda.Context.from_numba(nb_ctx)
assert ctx.handle == ctx2.handle
size = 10
buf = ctx.new_buffer(size)
assert ctx.handle == buf.context.handle
def make_random_buffer(size, target='host', dtype='uint8', ctx=None):
"""Return a host or device buffer with random data.
"""
dtype = np.dtype(dtype)
if target == 'host':
assert size >= 0
buf = pa.allocate_buffer(size*dtype.itemsize)
arr = np.frombuffer(buf, dtype=dtype)
arr[:] = np.random.randint(low=0, high=255, size=size,
dtype=np.uint8)
return arr, buf
elif target == 'device':
arr, buf = make_random_buffer(size, target='host', dtype=dtype)
dbuf = ctx.new_buffer(size * dtype.itemsize)
dbuf.copy_from_host(buf, position=0, nbytes=buf.size)
return arr, dbuf
raise ValueError('invalid target value')
@pytest.mark.parametrize("c", range(len(context_choice_ids)),
ids=context_choice_ids)
@pytest.mark.parametrize("dtype", dtypes, ids=dtypes)
@pytest.mark.parametrize("size", [0, 1, 8, 1000])
def test_from_object(c, dtype, size):
ctx, nb_ctx = context_choices[c]
arr, cbuf = make_random_buffer(size, target='device', dtype=dtype, ctx=ctx)
# Creating device buffer from numba DeviceNDArray:
darr = nb_cuda.to_device(arr)
cbuf2 = ctx.buffer_from_object(darr)
assert cbuf2.size == cbuf.size
arr2 = np.frombuffer(cbuf2.copy_to_host(), dtype=dtype)
np.testing.assert_equal(arr, arr2)
# Creating device buffer from a slice of numba DeviceNDArray:
if size >= 8:
# 1-D arrays
for s in [slice(size//4, None, None),
slice(size//4, -(size//4), None)]:
cbuf2 = ctx.buffer_from_object(darr[s])
arr2 = np.frombuffer(cbuf2.copy_to_host(), dtype=dtype)
np.testing.assert_equal(arr[s], arr2)
# cannot test negative strides due to numba bug, see its issue 3705
if 0:
rdarr = darr[::-1]
cbuf2 = ctx.buffer_from_object(rdarr)
assert cbuf2.size == cbuf.size
arr2 = np.frombuffer(cbuf2.copy_to_host(), dtype=dtype)
np.testing.assert_equal(arr, arr2)
with pytest.raises(ValueError,
match=('array data is non-contiguous')):
ctx.buffer_from_object(darr[::2])
# a rectangular 2-D array
s1 = size//4
s2 = size//s1
assert s1 * s2 == size
cbuf2 = ctx.buffer_from_object(darr.reshape(s1, s2))
assert cbuf2.size == cbuf.size
arr2 = np.frombuffer(cbuf2.copy_to_host(), dtype=dtype)
np.testing.assert_equal(arr, arr2)
with pytest.raises(ValueError,
match=('array data is non-contiguous')):
ctx.buffer_from_object(darr.reshape(s1, s2)[:, ::2])
# a 3-D array
s1 = 4
s2 = size//8
s3 = size//(s1*s2)
assert s1 * s2 * s3 == size
cbuf2 = ctx.buffer_from_object(darr.reshape(s1, s2, s3))
assert cbuf2.size == cbuf.size
arr2 = np.frombuffer(cbuf2.copy_to_host(), dtype=dtype)
np.testing.assert_equal(arr, arr2)
with pytest.raises(ValueError,
match=('array data is non-contiguous')):
ctx.buffer_from_object(darr.reshape(s1, s2, s3)[::2])
# Creating device buffer from am object implementing cuda array
# interface:
class MyObj:
def __init__(self, darr):
self.darr = darr
@property
def __cuda_array_interface__(self):
return self.darr.__cuda_array_interface__
cbuf2 = ctx.buffer_from_object(MyObj(darr))
assert cbuf2.size == cbuf.size
arr2 = np.frombuffer(cbuf2.copy_to_host(), dtype=dtype)
np.testing.assert_equal(arr, arr2)
@pytest.mark.parametrize("c", range(len(context_choice_ids)),
ids=context_choice_ids)
@pytest.mark.parametrize("dtype", dtypes, ids=dtypes)
def test_numba_memalloc(c, dtype):
ctx, nb_ctx = context_choices[c]
dtype = np.dtype(dtype)
# Allocate memory using numba context
# Warning: this will not be reflected in pyarrow context manager
# (e.g bytes_allocated does not change)
size = 10
mem = nb_ctx.memalloc(size * dtype.itemsize)
darr = DeviceNDArray((size,), (dtype.itemsize,), dtype, gpu_data=mem)
darr[:5] = 99
darr[5:] = 88
np.testing.assert_equal(darr.copy_to_host()[:5], 99)
np.testing.assert_equal(darr.copy_to_host()[5:], 88)
# wrap numba allocated memory with CudaBuffer
cbuf = cuda.CudaBuffer.from_numba(mem)
arr2 = np.frombuffer(cbuf.copy_to_host(), dtype=dtype)
np.testing.assert_equal(arr2, darr.copy_to_host())
@pytest.mark.parametrize("c", range(len(context_choice_ids)),
ids=context_choice_ids)
@pytest.mark.parametrize("dtype", dtypes, ids=dtypes)
def test_pyarrow_memalloc(c, dtype):
ctx, nb_ctx = context_choices[c]
size = 10
arr, cbuf = make_random_buffer(size, target='device', dtype=dtype, ctx=ctx)
# wrap CudaBuffer with numba device array
mem = cbuf.to_numba()
darr = DeviceNDArray(arr.shape, arr.strides, arr.dtype, gpu_data=mem)
np.testing.assert_equal(darr.copy_to_host(), arr)
@pytest.mark.parametrize("c", range(len(context_choice_ids)),
ids=context_choice_ids)
@pytest.mark.parametrize("dtype", dtypes, ids=dtypes)
def test_numba_context(c, dtype):
ctx, nb_ctx = context_choices[c]
size = 10
with nb_cuda.gpus[0]:
arr, cbuf = make_random_buffer(size, target='device',
dtype=dtype, ctx=ctx)
assert cbuf.context.handle == nb_ctx.handle.value
mem = cbuf.to_numba()
darr = DeviceNDArray(arr.shape, arr.strides, arr.dtype, gpu_data=mem)
np.testing.assert_equal(darr.copy_to_host(), arr)
darr[0] = 99
cbuf.context.synchronize()
arr2 = np.frombuffer(cbuf.copy_to_host(), dtype=dtype)
assert arr2[0] == 99
@pytest.mark.parametrize("c", range(len(context_choice_ids)),
ids=context_choice_ids)
@pytest.mark.parametrize("dtype", dtypes, ids=dtypes)
def test_pyarrow_jit(c, dtype):
ctx, nb_ctx = context_choices[c]
@nb_cuda.jit
def increment_by_one(an_array):
pos = nb_cuda.grid(1)
if pos < an_array.size:
an_array[pos] += 1
# applying numba.cuda kernel to memory hold by CudaBuffer
size = 10
arr, cbuf = make_random_buffer(size, target='device', dtype=dtype, ctx=ctx)
threadsperblock = 32
blockspergrid = (arr.size + (threadsperblock - 1)) // threadsperblock
mem = cbuf.to_numba()
darr = DeviceNDArray(arr.shape, arr.strides, arr.dtype, gpu_data=mem)
increment_by_one[blockspergrid, threadsperblock](darr)
cbuf.context.synchronize()
arr1 = np.frombuffer(cbuf.copy_to_host(), dtype=arr.dtype)
np.testing.assert_equal(arr1, arr + 1)
@@ -0,0 +1,180 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import shutil
import subprocess
import sys
import pytest
import pyarrow as pa
import pyarrow.tests.util as test_util
here = os.path.dirname(os.path.abspath(__file__))
test_ld_path = os.environ.get('PYARROW_TEST_LD_PATH', '')
if os.name == 'posix':
compiler_opts = ['-std=c++11']
else:
compiler_opts = []
setup_template = """if 1:
from setuptools import setup
from Cython.Build import cythonize
import numpy as np
import pyarrow as pa
ext_modules = cythonize({pyx_file!r})
compiler_opts = {compiler_opts!r}
custom_ld_path = {test_ld_path!r}
for ext in ext_modules:
# XXX required for numpy/numpyconfig.h,
# included from arrow/python/api.h
ext.include_dirs.append(np.get_include())
ext.include_dirs.append(pa.get_include())
ext.libraries.extend(pa.get_libraries())
ext.library_dirs.extend(pa.get_library_dirs())
if custom_ld_path:
ext.library_dirs.append(custom_ld_path)
ext.extra_compile_args.extend(compiler_opts)
print("Extension module:",
ext, ext.include_dirs, ext.libraries, ext.library_dirs)
setup(
ext_modules=ext_modules,
)
"""
def check_cython_example_module(mod):
arr = pa.array([1, 2, 3])
assert mod.get_array_length(arr) == 3
with pytest.raises(TypeError, match="not an array"):
mod.get_array_length(None)
scal = pa.scalar(123)
cast_scal = mod.cast_scalar(scal, pa.utf8())
assert cast_scal == pa.scalar("123")
with pytest.raises(NotImplementedError,
match="casting scalars of type int64 to type list"):
mod.cast_scalar(scal, pa.list_(pa.int64()))
@pytest.mark.cython
def test_cython_api(tmpdir):
"""
Basic test for the Cython API.
"""
# Fail early if cython is not found
import cython # noqa
with tmpdir.as_cwd():
# Set up temporary workspace
pyx_file = 'pyarrow_cython_example.pyx'
shutil.copyfile(os.path.join(here, pyx_file),
os.path.join(str(tmpdir), pyx_file))
# Create setup.py file
setup_code = setup_template.format(pyx_file=pyx_file,
compiler_opts=compiler_opts,
test_ld_path=test_ld_path)
with open('setup.py', 'w') as f:
f.write(setup_code)
# ARROW-2263: Make environment with this pyarrow/ package first on the
# PYTHONPATH, for local dev environments
subprocess_env = test_util.get_modified_env_with_pythonpath()
# Compile extension module
subprocess.check_call([sys.executable, 'setup.py',
'build_ext', '--inplace'],
env=subprocess_env)
# Check basic functionality
orig_path = sys.path[:]
sys.path.insert(0, str(tmpdir))
try:
mod = __import__('pyarrow_cython_example')
check_cython_example_module(mod)
finally:
sys.path = orig_path
# Check the extension module is loadable from a subprocess without
# pyarrow imported first.
code = """if 1:
import sys
mod = __import__({mod_name!r})
arr = mod.make_null_array(5)
assert mod.get_array_length(arr) == 5
assert arr.null_count == 5
""".format(mod_name='pyarrow_cython_example')
if sys.platform == 'win32':
delim, var = ';', 'PATH'
else:
delim, var = ':', 'LD_LIBRARY_PATH'
subprocess_env[var] = delim.join(
pa.get_library_dirs() + [subprocess_env.get(var, '')]
)
subprocess.check_call([sys.executable, '-c', code],
stdout=subprocess.PIPE,
env=subprocess_env)
@pytest.mark.cython
def test_visit_strings(tmpdir):
with tmpdir.as_cwd():
# Set up temporary workspace
pyx_file = 'bound_function_visit_strings.pyx'
shutil.copyfile(os.path.join(here, pyx_file),
os.path.join(str(tmpdir), pyx_file))
# Create setup.py file
setup_code = setup_template.format(pyx_file=pyx_file,
compiler_opts=compiler_opts,
test_ld_path=test_ld_path)
with open('setup.py', 'w') as f:
f.write(setup_code)
subprocess_env = test_util.get_modified_env_with_pythonpath()
# Compile extension module
subprocess.check_call([sys.executable, 'setup.py',
'build_ext', '--inplace'],
env=subprocess_env)
sys.path.insert(0, str(tmpdir))
mod = __import__('bound_function_visit_strings')
strings = ['a', 'b', 'c']
visited = []
mod._visit_strings(strings, visited.append)
assert visited == strings
with pytest.raises(ValueError, match="wtf"):
def raise_on_b(s):
if s == 'b':
raise ValueError('wtf')
mod._visit_strings(strings, raise_on_b)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,23 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Check that various deprecation warnings are raised
# flake8: noqa
import pyarrow as pa
import pytest
@@ -0,0 +1,192 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import pyarrow as pa
try:
import pyarrow.dataset as ds
import pyarrow._exec_plan as ep
except ImportError:
pass
pytestmark = pytest.mark.dataset
def test_joins_corner_cases():
t1 = pa.Table.from_pydict({
"colA": [1, 2, 3, 4, 5, 6],
"col2": ["a", "b", "c", "d", "e", "f"]
})
t2 = pa.Table.from_pydict({
"colB": [1, 2, 3, 4, 5],
"col3": ["A", "B", "C", "D", "E"]
})
with pytest.raises(pa.ArrowInvalid):
ep._perform_join("left outer", t1, "", t2, "")
with pytest.raises(TypeError):
ep._perform_join("left outer", None, "colA", t2, "colB")
with pytest.raises(ValueError):
ep._perform_join("super mario join", t1, "colA", t2, "colB")
@pytest.mark.parametrize("jointype,expected", [
("left semi", {
"colA": [1, 2],
"col2": ["a", "b"]
}),
("right semi", {
"colB": [1, 2],
"col3": ["A", "B"]
}),
("left anti", {
"colA": [6],
"col2": ["f"]
}),
("right anti", {
"colB": [99],
"col3": ["Z"]
}),
("inner", {
"colA": [1, 2],
"col2": ["a", "b"],
"col3": ["A", "B"]
}),
("left outer", {
"colA": [1, 2, 6],
"col2": ["a", "b", "f"],
"col3": ["A", "B", None]
}),
("right outer", {
"col2": ["a", "b", None],
"colB": [1, 2, 99],
"col3": ["A", "B", "Z"]
}),
("full outer", {
"colA": [1, 2, 6, 99],
"col2": ["a", "b", "f", None],
"col3": ["A", "B", None, "Z"]
})
])
@pytest.mark.parametrize("use_threads", [True, False])
@pytest.mark.parametrize("use_datasets", [False, True])
def test_joins(jointype, expected, use_threads, use_datasets):
# Allocate table here instead of using parametrize
# this prevents having arrow allocated memory forever around.
expected = pa.table(expected)
t1 = pa.Table.from_pydict({
"colA": [1, 2, 6],
"col2": ["a", "b", "f"]
})
t2 = pa.Table.from_pydict({
"colB": [99, 2, 1],
"col3": ["Z", "B", "A"]
})
if use_datasets:
t1 = ds.dataset([t1])
t2 = ds.dataset([t2])
r = ep._perform_join(jointype, t1, "colA", t2, "colB",
use_threads=use_threads, coalesce_keys=True)
r = r.combine_chunks()
if "right" in jointype:
r = r.sort_by("colB")
else:
r = r.sort_by("colA")
assert r == expected
def test_table_join_collisions():
t1 = pa.table({
"colA": [1, 2, 6],
"colB": [10, 20, 60],
"colVals": ["a", "b", "f"]
})
t2 = pa.table({
"colB": [99, 20, 10],
"colVals": ["Z", "B", "A"],
"colUniq": [100, 200, 300],
"colA": [99, 2, 1],
})
result = ep._perform_join(
"full outer", t1, ["colA", "colB"], t2, ["colA", "colB"])
assert result.combine_chunks() == pa.table([
[1, 2, 6, None],
[10, 20, 60, None],
["a", "b", "f", None],
[10, 20, None, 99],
["A", "B", None, "Z"],
[300, 200, None, 100],
[1, 2, None, 99],
], names=["colA", "colB", "colVals", "colB", "colVals", "colUniq", "colA"])
result = ep._perform_join("full outer", t1, "colA",
t2, "colA", right_suffix="_r",
coalesce_keys=False)
assert result.combine_chunks() == pa.table({
"colA": [1, 2, 6, None],
"colB": [10, 20, 60, None],
"colVals": ["a", "b", "f", None],
"colB_r": [10, 20, None, 99],
"colVals_r": ["A", "B", None, "Z"],
"colUniq": [300, 200, None, 100],
"colA_r": [1, 2, None, 99],
})
result = ep._perform_join("full outer", t1, "colA",
t2, "colA", right_suffix="_r",
coalesce_keys=True)
assert result.combine_chunks() == pa.table({
"colA": [1, 2, 6, 99],
"colB": [10, 20, 60, None],
"colVals": ["a", "b", "f", None],
"colB_r": [10, 20, None, 99],
"colVals_r": ["A", "B", None, "Z"],
"colUniq": [300, 200, None, 100]
})
def test_table_join_keys_order():
t1 = pa.table({
"colB": [10, 20, 60],
"colA": [1, 2, 6],
"colVals": ["a", "b", "f"]
})
t2 = pa.table({
"colVals": ["Z", "B", "A"],
"colX": [99, 2, 1],
})
result = ep._perform_join("full outer", t1, "colA", t2, "colX",
left_suffix="_l", right_suffix="_r",
coalesce_keys=True)
assert result.combine_chunks() == pa.table({
"colB": [10, 20, 60, None],
"colA": [1, 2, 6, 99],
"colVals_l": ["a", "b", "f", None],
"colVals_r": ["A", "B", None, "Z"],
})
@@ -0,0 +1,852 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pickle
import weakref
import numpy as np
import pyarrow as pa
import pytest
class IntegerType(pa.PyExtensionType):
def __init__(self):
pa.PyExtensionType.__init__(self, pa.int64())
def __reduce__(self):
return IntegerType, ()
class UuidType(pa.PyExtensionType):
def __init__(self):
pa.PyExtensionType.__init__(self, pa.binary(16))
def __reduce__(self):
return UuidType, ()
class ParamExtType(pa.PyExtensionType):
def __init__(self, width):
self._width = width
pa.PyExtensionType.__init__(self, pa.binary(width))
@property
def width(self):
return self._width
def __reduce__(self):
return ParamExtType, (self.width,)
class MyStructType(pa.PyExtensionType):
storage_type = pa.struct([('left', pa.int64()),
('right', pa.int64())])
def __init__(self):
pa.PyExtensionType.__init__(self, self.storage_type)
def __reduce__(self):
return MyStructType, ()
class MyListType(pa.PyExtensionType):
def __init__(self, storage_type):
pa.PyExtensionType.__init__(self, storage_type)
def __reduce__(self):
return MyListType, (self.storage_type,)
class AnnotatedType(pa.PyExtensionType):
"""
Generic extension type that can store any storage type.
"""
def __init__(self, storage_type, annotation):
self.annotation = annotation
super().__init__(storage_type)
def __reduce__(self):
return AnnotatedType, (self.storage_type, self.annotation)
def ipc_write_batch(batch):
stream = pa.BufferOutputStream()
writer = pa.RecordBatchStreamWriter(stream, batch.schema)
writer.write_batch(batch)
writer.close()
return stream.getvalue()
def ipc_read_batch(buf):
reader = pa.RecordBatchStreamReader(buf)
return reader.read_next_batch()
def test_ext_type_basics():
ty = UuidType()
assert ty.extension_name == "arrow.py_extension_type"
def test_ext_type_str():
ty = IntegerType()
expected = "extension<arrow.py_extension_type<IntegerType>>"
assert str(ty) == expected
assert pa.DataType.__str__(ty) == expected
def test_ext_type_repr():
ty = IntegerType()
assert repr(ty) == "IntegerType(DataType(int64))"
def test_ext_type__lifetime():
ty = UuidType()
wr = weakref.ref(ty)
del ty
assert wr() is None
def test_ext_type__storage_type():
ty = UuidType()
assert ty.storage_type == pa.binary(16)
assert ty.__class__ is UuidType
ty = ParamExtType(5)
assert ty.storage_type == pa.binary(5)
assert ty.__class__ is ParamExtType
def test_uuid_type_pickle():
for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
ty = UuidType()
ser = pickle.dumps(ty, protocol=proto)
del ty
ty = pickle.loads(ser)
wr = weakref.ref(ty)
assert ty.extension_name == "arrow.py_extension_type"
del ty
assert wr() is None
def test_ext_type_equality():
a = ParamExtType(5)
b = ParamExtType(6)
c = ParamExtType(6)
assert a != b
assert b == c
d = UuidType()
e = UuidType()
assert a != d
assert d == e
def test_ext_array_basics():
ty = ParamExtType(3)
storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
arr = pa.ExtensionArray.from_storage(ty, storage)
arr.validate()
assert arr.type is ty
assert arr.storage.equals(storage)
def test_ext_array_lifetime():
ty = ParamExtType(3)
storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
arr = pa.ExtensionArray.from_storage(ty, storage)
refs = [weakref.ref(ty), weakref.ref(arr), weakref.ref(storage)]
del ty, storage, arr
for ref in refs:
assert ref() is None
def test_ext_array_to_pylist():
ty = ParamExtType(3)
storage = pa.array([b"foo", b"bar", None], type=pa.binary(3))
arr = pa.ExtensionArray.from_storage(ty, storage)
assert arr.to_pylist() == [b"foo", b"bar", None]
def test_ext_array_errors():
ty = ParamExtType(4)
storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
with pytest.raises(TypeError, match="Incompatible storage type"):
pa.ExtensionArray.from_storage(ty, storage)
def test_ext_array_equality():
storage1 = pa.array([b"0123456789abcdef"], type=pa.binary(16))
storage2 = pa.array([b"0123456789abcdef"], type=pa.binary(16))
storage3 = pa.array([], type=pa.binary(16))
ty1 = UuidType()
ty2 = ParamExtType(16)
a = pa.ExtensionArray.from_storage(ty1, storage1)
b = pa.ExtensionArray.from_storage(ty1, storage2)
assert a.equals(b)
c = pa.ExtensionArray.from_storage(ty1, storage3)
assert not a.equals(c)
d = pa.ExtensionArray.from_storage(ty2, storage1)
assert not a.equals(d)
e = pa.ExtensionArray.from_storage(ty2, storage2)
assert d.equals(e)
f = pa.ExtensionArray.from_storage(ty2, storage3)
assert not d.equals(f)
def test_ext_array_wrap_array():
ty = ParamExtType(3)
storage = pa.array([b"foo", b"bar", None], type=pa.binary(3))
arr = ty.wrap_array(storage)
arr.validate(full=True)
assert isinstance(arr, pa.ExtensionArray)
assert arr.type == ty
assert arr.storage == storage
storage = pa.chunked_array([[b"abc", b"def"], [b"ghi"]],
type=pa.binary(3))
arr = ty.wrap_array(storage)
arr.validate(full=True)
assert isinstance(arr, pa.ChunkedArray)
assert arr.type == ty
assert arr.chunk(0).storage == storage.chunk(0)
assert arr.chunk(1).storage == storage.chunk(1)
# Wrong storage type
storage = pa.array([b"foo", b"bar", None])
with pytest.raises(TypeError, match="Incompatible storage type"):
ty.wrap_array(storage)
# Not an array or chunked array
with pytest.raises(TypeError, match="Expected array or chunked array"):
ty.wrap_array(None)
def test_ext_scalar_from_array():
data = [b"0123456789abcdef", b"0123456789abcdef",
b"zyxwvutsrqponmlk", None]
storage = pa.array(data, type=pa.binary(16))
ty1 = UuidType()
ty2 = ParamExtType(16)
a = pa.ExtensionArray.from_storage(ty1, storage)
b = pa.ExtensionArray.from_storage(ty2, storage)
scalars_a = list(a)
assert len(scalars_a) == 4
for s, val in zip(scalars_a, data):
assert isinstance(s, pa.ExtensionScalar)
assert s.is_valid == (val is not None)
assert s.type == ty1
if val is not None:
assert s.value == pa.scalar(val, storage.type)
else:
assert s.value is None
assert s.as_py() == val
scalars_b = list(b)
assert len(scalars_b) == 4
for sa, sb in zip(scalars_a, scalars_b):
assert sa.is_valid == sb.is_valid
assert sa.as_py() == sb.as_py()
assert sa != sb
def test_ext_scalar_from_storage():
ty = UuidType()
s = pa.ExtensionScalar.from_storage(ty, None)
assert isinstance(s, pa.ExtensionScalar)
assert s.type == ty
assert s.is_valid is False
assert s.value is None
s = pa.ExtensionScalar.from_storage(ty, b"0123456789abcdef")
assert isinstance(s, pa.ExtensionScalar)
assert s.type == ty
assert s.is_valid is True
assert s.value == pa.scalar(b"0123456789abcdef", ty.storage_type)
s = pa.ExtensionScalar.from_storage(ty, pa.scalar(None, ty.storage_type))
assert isinstance(s, pa.ExtensionScalar)
assert s.type == ty
assert s.is_valid is False
assert s.value is None
s = pa.ExtensionScalar.from_storage(
ty, pa.scalar(b"0123456789abcdef", ty.storage_type))
assert isinstance(s, pa.ExtensionScalar)
assert s.type == ty
assert s.is_valid is True
assert s.value == pa.scalar(b"0123456789abcdef", ty.storage_type)
def test_ext_array_pickling():
for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
ty = ParamExtType(3)
storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
arr = pa.ExtensionArray.from_storage(ty, storage)
ser = pickle.dumps(arr, protocol=proto)
del ty, storage, arr
arr = pickle.loads(ser)
arr.validate()
assert isinstance(arr, pa.ExtensionArray)
assert arr.type == ParamExtType(3)
assert arr.type.storage_type == pa.binary(3)
assert arr.storage.type == pa.binary(3)
assert arr.storage.to_pylist() == [b"foo", b"bar"]
def test_ext_array_conversion_to_numpy():
storage1 = pa.array([1, 2, 3], type=pa.int64())
storage2 = pa.array([b"123", b"456", b"789"], type=pa.binary(3))
ty1 = IntegerType()
ty2 = ParamExtType(3)
arr1 = pa.ExtensionArray.from_storage(ty1, storage1)
arr2 = pa.ExtensionArray.from_storage(ty2, storage2)
result = arr1.to_numpy()
expected = np.array([1, 2, 3], dtype="int64")
np.testing.assert_array_equal(result, expected)
with pytest.raises(ValueError, match="zero_copy_only was True"):
arr2.to_numpy()
result = arr2.to_numpy(zero_copy_only=False)
expected = np.array([b"123", b"456", b"789"])
np.testing.assert_array_equal(result, expected)
@pytest.mark.pandas
def test_ext_array_conversion_to_pandas():
import pandas as pd
storage1 = pa.array([1, 2, 3], type=pa.int64())
storage2 = pa.array([b"123", b"456", b"789"], type=pa.binary(3))
ty1 = IntegerType()
ty2 = ParamExtType(3)
arr1 = pa.ExtensionArray.from_storage(ty1, storage1)
arr2 = pa.ExtensionArray.from_storage(ty2, storage2)
result = arr1.to_pandas()
expected = pd.Series([1, 2, 3], dtype="int64")
pd.testing.assert_series_equal(result, expected)
result = arr2.to_pandas()
expected = pd.Series([b"123", b"456", b"789"], dtype=object)
pd.testing.assert_series_equal(result, expected)
@pytest.fixture
def struct_w_ext_data():
storage1 = pa.array([1, 2, 3], type=pa.int64())
storage2 = pa.array([b"123", b"456", b"789"], type=pa.binary(3))
ty1 = IntegerType()
ty2 = ParamExtType(3)
arr1 = pa.ExtensionArray.from_storage(ty1, storage1)
arr2 = pa.ExtensionArray.from_storage(ty2, storage2)
sarr1 = pa.StructArray.from_arrays([arr1], ["f0"])
sarr2 = pa.StructArray.from_arrays([arr2], ["f1"])
return [sarr1, sarr2]
def test_struct_w_ext_array_to_numpy(struct_w_ext_data):
# ARROW-15291
# Check that we don't segfault when trying to build
# a numpy array from a StructArray with a field being
# an ExtensionArray
result = struct_w_ext_data[0].to_numpy(zero_copy_only=False)
expected = np.array([{'f0': 1}, {'f0': 2},
{'f0': 3}], dtype=object)
np.testing.assert_array_equal(result, expected)
result = struct_w_ext_data[1].to_numpy(zero_copy_only=False)
expected = np.array([{'f1': b'123'}, {'f1': b'456'},
{'f1': b'789'}], dtype=object)
np.testing.assert_array_equal(result, expected)
@pytest.mark.pandas
def test_struct_w_ext_array_to_pandas(struct_w_ext_data):
# ARROW-15291
# Check that we don't segfault when trying to build
# a Pandas dataframe from a StructArray with a field
# being an ExtensionArray
import pandas as pd
result = struct_w_ext_data[0].to_pandas()
expected = pd.Series([{'f0': 1}, {'f0': 2},
{'f0': 3}], dtype=object)
pd.testing.assert_series_equal(result, expected)
result = struct_w_ext_data[1].to_pandas()
expected = pd.Series([{'f1': b'123'}, {'f1': b'456'},
{'f1': b'789'}], dtype=object)
pd.testing.assert_series_equal(result, expected)
def test_cast_kernel_on_extension_arrays():
# test array casting
storage = pa.array([1, 2, 3, 4], pa.int64())
arr = pa.ExtensionArray.from_storage(IntegerType(), storage)
# test that no allocation happens during identity cast
allocated_before_cast = pa.total_allocated_bytes()
casted = arr.cast(pa.int64())
assert pa.total_allocated_bytes() == allocated_before_cast
cases = [
(pa.int64(), pa.Int64Array),
(pa.int32(), pa.Int32Array),
(pa.int16(), pa.Int16Array),
(pa.uint64(), pa.UInt64Array),
(pa.uint32(), pa.UInt32Array),
(pa.uint16(), pa.UInt16Array)
]
for typ, klass in cases:
casted = arr.cast(typ)
assert casted.type == typ
assert isinstance(casted, klass)
# test chunked array casting
arr = pa.chunked_array([arr, arr])
casted = arr.cast(pa.int16())
assert casted.type == pa.int16()
assert isinstance(casted, pa.ChunkedArray)
def test_casting_to_extension_type_raises():
arr = pa.array([1, 2, 3, 4], pa.int64())
with pytest.raises(pa.ArrowNotImplementedError):
arr.cast(IntegerType())
def test_null_storage_type():
ext_type = AnnotatedType(pa.null(), {"key": "value"})
storage = pa.array([None] * 10, pa.null())
arr = pa.ExtensionArray.from_storage(ext_type, storage)
assert arr.null_count == 10
arr.validate(full=True)
def example_batch():
ty = ParamExtType(3)
storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
arr = pa.ExtensionArray.from_storage(ty, storage)
return pa.RecordBatch.from_arrays([arr], ["exts"])
def check_example_batch(batch):
arr = batch.column(0)
assert isinstance(arr, pa.ExtensionArray)
assert arr.type.storage_type == pa.binary(3)
assert arr.storage.to_pylist() == [b"foo", b"bar"]
return arr
def test_ipc():
batch = example_batch()
buf = ipc_write_batch(batch)
del batch
batch = ipc_read_batch(buf)
arr = check_example_batch(batch)
assert arr.type == ParamExtType(3)
def test_ipc_unknown_type():
batch = example_batch()
buf = ipc_write_batch(batch)
del batch
orig_type = ParamExtType
try:
# Simulate the original Python type being unavailable.
# Deserialization should not fail but return a placeholder type.
del globals()['ParamExtType']
batch = ipc_read_batch(buf)
arr = check_example_batch(batch)
assert isinstance(arr.type, pa.UnknownExtensionType)
# Can be serialized again
buf2 = ipc_write_batch(batch)
del batch, arr
batch = ipc_read_batch(buf2)
arr = check_example_batch(batch)
assert isinstance(arr.type, pa.UnknownExtensionType)
finally:
globals()['ParamExtType'] = orig_type
# Deserialize again with the type restored
batch = ipc_read_batch(buf2)
arr = check_example_batch(batch)
assert arr.type == ParamExtType(3)
class PeriodArray(pa.ExtensionArray):
pass
class PeriodType(pa.ExtensionType):
def __init__(self, freq):
# attributes need to be set first before calling
# super init (as that calls serialize)
self._freq = freq
pa.ExtensionType.__init__(self, pa.int64(), 'test.period')
@property
def freq(self):
return self._freq
def __arrow_ext_serialize__(self):
return "freq={}".format(self.freq).encode()
@classmethod
def __arrow_ext_deserialize__(cls, storage_type, serialized):
serialized = serialized.decode()
assert serialized.startswith("freq=")
freq = serialized.split('=')[1]
return PeriodType(freq)
def __eq__(self, other):
if isinstance(other, pa.BaseExtensionType):
return (type(self) == type(other) and
self.freq == other.freq)
else:
return NotImplemented
class PeriodTypeWithClass(PeriodType):
def __init__(self, freq):
PeriodType.__init__(self, freq)
def __arrow_ext_class__(self):
return PeriodArray
@classmethod
def __arrow_ext_deserialize__(cls, storage_type, serialized):
freq = PeriodType.__arrow_ext_deserialize__(
storage_type, serialized).freq
return PeriodTypeWithClass(freq)
@pytest.fixture(params=[PeriodType('D'), PeriodTypeWithClass('D')])
def registered_period_type(request):
# setup
period_type = request.param
period_class = period_type.__arrow_ext_class__()
pa.register_extension_type(period_type)
yield period_type, period_class
# teardown
try:
pa.unregister_extension_type('test.period')
except KeyError:
pass
def test_generic_ext_type():
period_type = PeriodType('D')
assert period_type.extension_name == "test.period"
assert period_type.storage_type == pa.int64()
# default ext_class expected.
assert period_type.__arrow_ext_class__() == pa.ExtensionArray
def test_generic_ext_type_ipc(registered_period_type):
period_type, period_class = registered_period_type
storage = pa.array([1, 2, 3, 4], pa.int64())
arr = pa.ExtensionArray.from_storage(period_type, storage)
batch = pa.RecordBatch.from_arrays([arr], ["ext"])
# check the built array has exactly the expected clss
assert type(arr) == period_class
buf = ipc_write_batch(batch)
del batch
batch = ipc_read_batch(buf)
result = batch.column(0)
# check the deserialized array class is the expected one
assert type(result) == period_class
assert result.type.extension_name == "test.period"
assert arr.storage.to_pylist() == [1, 2, 3, 4]
# we get back an actual PeriodType
assert isinstance(result.type, PeriodType)
assert result.type.freq == 'D'
assert result.type == period_type
# using different parametrization as how it was registered
period_type_H = period_type.__class__('H')
assert period_type_H.extension_name == "test.period"
assert period_type_H.freq == 'H'
arr = pa.ExtensionArray.from_storage(period_type_H, storage)
batch = pa.RecordBatch.from_arrays([arr], ["ext"])
buf = ipc_write_batch(batch)
del batch
batch = ipc_read_batch(buf)
result = batch.column(0)
assert isinstance(result.type, PeriodType)
assert result.type.freq == 'H'
assert type(result) == period_class
def test_generic_ext_type_ipc_unknown(registered_period_type):
period_type, _ = registered_period_type
storage = pa.array([1, 2, 3, 4], pa.int64())
arr = pa.ExtensionArray.from_storage(period_type, storage)
batch = pa.RecordBatch.from_arrays([arr], ["ext"])
buf = ipc_write_batch(batch)
del batch
# unregister type before loading again => reading unknown extension type
# as plain array (but metadata in schema's field are preserved)
pa.unregister_extension_type('test.period')
batch = ipc_read_batch(buf)
result = batch.column(0)
assert isinstance(result, pa.Int64Array)
ext_field = batch.schema.field('ext')
assert ext_field.metadata == {
b'ARROW:extension:metadata': b'freq=D',
b'ARROW:extension:name': b'test.period'
}
def test_generic_ext_type_equality():
period_type = PeriodType('D')
assert period_type.extension_name == "test.period"
period_type2 = PeriodType('D')
period_type3 = PeriodType('H')
assert period_type == period_type2
assert not period_type == period_type3
def test_generic_ext_type_register(registered_period_type):
# test that trying to register other type does not segfault
with pytest.raises(TypeError):
pa.register_extension_type(pa.string())
# register second time raises KeyError
period_type = PeriodType('D')
with pytest.raises(KeyError):
pa.register_extension_type(period_type)
@pytest.mark.parquet
def test_parquet_period(tmpdir, registered_period_type):
# Parquet support for primitive extension types
period_type, period_class = registered_period_type
storage = pa.array([1, 2, 3, 4], pa.int64())
arr = pa.ExtensionArray.from_storage(period_type, storage)
table = pa.table([arr], names=["ext"])
import pyarrow.parquet as pq
filename = tmpdir / 'period_extension_type.parquet'
pq.write_table(table, filename)
# Stored in parquet as storage type but with extension metadata saved
# in the serialized arrow schema
meta = pq.read_metadata(filename)
assert meta.schema.column(0).physical_type == "INT64"
assert b"ARROW:schema" in meta.metadata
import base64
decoded_schema = base64.b64decode(meta.metadata[b"ARROW:schema"])
schema = pa.ipc.read_schema(pa.BufferReader(decoded_schema))
# Since the type could be reconstructed, the extension type metadata is
# absent.
assert schema.field("ext").metadata == {}
# When reading in, properly create extension type if it is registered
result = pq.read_table(filename)
assert result.schema.field("ext").type == period_type
assert result.schema.field("ext").metadata == {}
# Get the exact array class defined by the registered type.
result_array = result.column("ext").chunk(0)
assert type(result_array) is period_class
# When the type is not registered, read in as storage type
pa.unregister_extension_type(period_type.extension_name)
result = pq.read_table(filename)
assert result.schema.field("ext").type == pa.int64()
# The extension metadata is present for roundtripping.
assert result.schema.field("ext").metadata == {
b'ARROW:extension:metadata': b'freq=D',
b'ARROW:extension:name': b'test.period'
}
@pytest.mark.parquet
def test_parquet_extension_with_nested_storage(tmpdir):
# Parquet support for extension types with nested storage type
import pyarrow.parquet as pq
struct_array = pa.StructArray.from_arrays(
[pa.array([0, 1], type="int64"), pa.array([4, 5], type="int64")],
names=["left", "right"])
list_array = pa.array([[1, 2, 3], [4, 5]], type=pa.list_(pa.int32()))
mystruct_array = pa.ExtensionArray.from_storage(MyStructType(),
struct_array)
mylist_array = pa.ExtensionArray.from_storage(
MyListType(list_array.type), list_array)
orig_table = pa.table({'structs': mystruct_array,
'lists': mylist_array})
filename = tmpdir / 'nested_extension_storage.parquet'
pq.write_table(orig_table, filename)
table = pq.read_table(filename)
assert table.column('structs').type == mystruct_array.type
assert table.column('lists').type == mylist_array.type
assert table == orig_table
@pytest.mark.parquet
def test_parquet_nested_extension(tmpdir):
# Parquet support for extension types nested in struct or list
import pyarrow.parquet as pq
ext_type = IntegerType()
storage = pa.array([4, 5, 6, 7], type=pa.int64())
ext_array = pa.ExtensionArray.from_storage(ext_type, storage)
# Struct of extensions
struct_array = pa.StructArray.from_arrays(
[storage, ext_array],
names=['ints', 'exts'])
orig_table = pa.table({'structs': struct_array})
filename = tmpdir / 'struct_of_ext.parquet'
pq.write_table(orig_table, filename)
table = pq.read_table(filename)
assert table.column(0).type == struct_array.type
assert table == orig_table
# List of extensions
list_array = pa.ListArray.from_arrays([0, 1, None, 3], ext_array)
orig_table = pa.table({'lists': list_array})
filename = tmpdir / 'list_of_ext.parquet'
pq.write_table(orig_table, filename)
table = pq.read_table(filename)
assert table.column(0).type == list_array.type
assert table == orig_table
# Large list of extensions
list_array = pa.LargeListArray.from_arrays([0, 1, None, 3], ext_array)
orig_table = pa.table({'lists': list_array})
filename = tmpdir / 'list_of_ext.parquet'
pq.write_table(orig_table, filename)
table = pq.read_table(filename)
assert table.column(0).type == list_array.type
assert table == orig_table
@pytest.mark.parquet
def test_parquet_extension_nested_in_extension(tmpdir):
# Parquet support for extension<list<extension>>
import pyarrow.parquet as pq
inner_ext_type = IntegerType()
inner_storage = pa.array([4, 5, 6, 7], type=pa.int64())
inner_ext_array = pa.ExtensionArray.from_storage(inner_ext_type,
inner_storage)
list_array = pa.ListArray.from_arrays([0, 1, None, 3], inner_ext_array)
mylist_array = pa.ExtensionArray.from_storage(
MyListType(list_array.type), list_array)
orig_table = pa.table({'lists': mylist_array})
filename = tmpdir / 'ext_of_list_of_ext.parquet'
pq.write_table(orig_table, filename)
table = pq.read_table(filename)
assert table.column(0).type == mylist_array.type
assert table == orig_table
def test_to_numpy():
period_type = PeriodType('D')
storage = pa.array([1, 2, 3, 4], pa.int64())
arr = pa.ExtensionArray.from_storage(period_type, storage)
expected = storage.to_numpy()
result = arr.to_numpy()
np.testing.assert_array_equal(result, expected)
result = np.asarray(arr)
np.testing.assert_array_equal(result, expected)
# chunked array
a1 = pa.chunked_array([arr, arr])
a2 = pa.chunked_array([arr, arr], type=period_type)
expected = np.hstack([expected, expected])
for charr in [a1, a2]:
assert charr.type == period_type
for result in [np.asarray(charr), charr.to_numpy()]:
assert result.dtype == np.int64
np.testing.assert_array_equal(result, expected)
# zero chunks
charr = pa.chunked_array([], type=period_type)
assert charr.type == period_type
for result in [np.asarray(charr), charr.to_numpy()]:
assert result.dtype == np.int64
np.testing.assert_array_equal(result, np.array([], dtype='int64'))
def test_empty_take():
# https://issues.apache.org/jira/browse/ARROW-13474
ext_type = IntegerType()
storage = pa.array([], type=pa.int64())
empty_arr = pa.ExtensionArray.from_storage(ext_type, storage)
result = empty_arr.filter(pa.array([], pa.bool_()))
assert len(result) == 0
assert result.equals(empty_arr)
result = empty_arr.take(pa.array([], pa.int32()))
assert len(result) == 0
assert result.equals(empty_arr)
@@ -0,0 +1,840 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import io
import os
import sys
import tempfile
import pytest
import hypothesis as h
import hypothesis.strategies as st
import numpy as np
import pyarrow as pa
import pyarrow.tests.strategies as past
from pyarrow.feather import (read_feather, write_feather, read_table,
FeatherDataset)
try:
from pandas.testing import assert_frame_equal
import pandas as pd
import pyarrow.pandas_compat
except ImportError:
pass
@pytest.fixture(scope='module')
def datadir(base_datadir):
return base_datadir / 'feather'
def random_path(prefix='feather_'):
return tempfile.mktemp(prefix=prefix)
@pytest.fixture(scope="module", params=[1, 2])
def version(request):
yield request.param
@pytest.fixture(scope="module", params=[None, "uncompressed", "lz4", "zstd"])
def compression(request):
if request.param in ['lz4', 'zstd'] and not pa.Codec.is_available(
request.param):
pytest.skip(f'{request.param} is not available')
yield request.param
TEST_FILES = None
def setup_module(module):
global TEST_FILES
TEST_FILES = []
def teardown_module(module):
for path in TEST_FILES:
try:
os.remove(path)
except os.error:
pass
@pytest.mark.pandas
def test_file_not_exist():
with pytest.raises(pa.ArrowIOError):
read_feather('test_invalid_file')
def _check_pandas_roundtrip(df, expected=None, path=None,
columns=None, use_threads=False,
version=None, compression=None,
compression_level=None):
if path is None:
path = random_path()
if version is None:
version = 2
TEST_FILES.append(path)
write_feather(df, path, compression=compression,
compression_level=compression_level, version=version)
if not os.path.exists(path):
raise Exception('file not written')
result = read_feather(path, columns, use_threads=use_threads)
if expected is None:
expected = df
assert_frame_equal(result, expected)
def _check_arrow_roundtrip(table, path=None, compression=None):
if path is None:
path = random_path()
TEST_FILES.append(path)
write_feather(table, path, compression=compression)
if not os.path.exists(path):
raise Exception('file not written')
result = read_table(path)
assert result.equals(table)
def _assert_error_on_write(df, exc, path=None, version=2):
# check that we are raising the exception
# on writing
if path is None:
path = random_path()
TEST_FILES.append(path)
def f():
write_feather(df, path, version=version)
pytest.raises(exc, f)
def test_dataset(version):
num_values = (100, 100)
num_files = 5
paths = [random_path() for i in range(num_files)]
data = {
"col_" + str(i): np.random.randn(num_values[0])
for i in range(num_values[1])
}
table = pa.table(data)
TEST_FILES.extend(paths)
for index, path in enumerate(paths):
rows = (
index * (num_values[0] // num_files),
(index + 1) * (num_values[0] // num_files),
)
write_feather(table[rows[0]: rows[1]], path, version=version)
data = FeatherDataset(paths).read_table()
assert data.equals(table)
@pytest.mark.pandas
def test_float_no_nulls(version):
data = {}
numpy_dtypes = ['f4', 'f8']
num_values = 100
for dtype in numpy_dtypes:
values = np.random.randn(num_values)
data[dtype] = values.astype(dtype)
df = pd.DataFrame(data)
_check_pandas_roundtrip(df, version=version)
@pytest.mark.pandas
def test_read_table(version):
num_values = (100, 100)
path = random_path()
TEST_FILES.append(path)
values = np.random.randint(0, 100, size=num_values)
columns = ['col_' + str(i) for i in range(100)]
table = pa.Table.from_arrays(values, columns)
write_feather(table, path, version=version)
result = read_table(path)
assert result.equals(table)
# Test without memory mapping
result = read_table(path, memory_map=False)
assert result.equals(table)
result = read_feather(path, memory_map=False)
assert_frame_equal(table.to_pandas(), result)
@pytest.mark.pandas
def test_use_threads(version):
# ARROW-14470
num_values = (10, 10)
path = random_path()
TEST_FILES.append(path)
values = np.random.randint(0, 10, size=num_values)
columns = ['col_' + str(i) for i in range(10)]
table = pa.Table.from_arrays(values, columns)
write_feather(table, path, version=version)
result = read_feather(path)
assert_frame_equal(table.to_pandas(), result)
# Test read_feather with use_threads=False
result = read_feather(path, use_threads=False)
assert_frame_equal(table.to_pandas(), result)
# Test read_table with use_threads=False
result = read_table(path, use_threads=False)
assert result.equals(table)
@pytest.mark.pandas
def test_float_nulls(version):
num_values = 100
path = random_path()
TEST_FILES.append(path)
null_mask = np.random.randint(0, 10, size=num_values) < 3
dtypes = ['f4', 'f8']
expected_cols = []
arrays = []
for name in dtypes:
values = np.random.randn(num_values).astype(name)
arrays.append(pa.array(values, mask=null_mask))
values[null_mask] = np.nan
expected_cols.append(values)
table = pa.table(arrays, names=dtypes)
_check_arrow_roundtrip(table)
df = table.to_pandas()
_check_pandas_roundtrip(df, version=version)
@pytest.mark.pandas
def test_integer_no_nulls(version):
data, arr = {}, []
numpy_dtypes = ['i1', 'i2', 'i4', 'i8',
'u1', 'u2', 'u4', 'u8']
num_values = 100
for dtype in numpy_dtypes:
values = np.random.randint(0, 100, size=num_values)
data[dtype] = values.astype(dtype)
arr.append(values.astype(dtype))
df = pd.DataFrame(data)
_check_pandas_roundtrip(df, version=version)
table = pa.table(arr, names=numpy_dtypes)
_check_arrow_roundtrip(table)
@pytest.mark.pandas
def test_platform_numpy_integers(version):
data = {}
numpy_dtypes = ['longlong']
num_values = 100
for dtype in numpy_dtypes:
values = np.random.randint(0, 100, size=num_values)
data[dtype] = values.astype(dtype)
df = pd.DataFrame(data)
_check_pandas_roundtrip(df, version=version)
@pytest.mark.pandas
def test_integer_with_nulls(version):
# pandas requires upcast to float dtype
path = random_path()
TEST_FILES.append(path)
int_dtypes = ['i1', 'i2', 'i4', 'i8', 'u1', 'u2', 'u4', 'u8']
num_values = 100
arrays = []
null_mask = np.random.randint(0, 10, size=num_values) < 3
expected_cols = []
for name in int_dtypes:
values = np.random.randint(0, 100, size=num_values)
arrays.append(pa.array(values, mask=null_mask))
expected = values.astype('f8')
expected[null_mask] = np.nan
expected_cols.append(expected)
table = pa.table(arrays, names=int_dtypes)
_check_arrow_roundtrip(table)
df = table.to_pandas()
_check_pandas_roundtrip(df, version=version)
@pytest.mark.pandas
def test_boolean_no_nulls(version):
num_values = 100
np.random.seed(0)
df = pd.DataFrame({'bools': np.random.randn(num_values) > 0})
_check_pandas_roundtrip(df, version=version)
@pytest.mark.pandas
def test_boolean_nulls(version):
# pandas requires upcast to object dtype
path = random_path()
TEST_FILES.append(path)
num_values = 100
np.random.seed(0)
mask = np.random.randint(0, 10, size=num_values) < 3
values = np.random.randint(0, 10, size=num_values) < 5
table = pa.table([pa.array(values, mask=mask)], names=['bools'])
_check_arrow_roundtrip(table)
df = table.to_pandas()
_check_pandas_roundtrip(df, version=version)
def test_buffer_bounds_error(version):
# ARROW-1676
path = random_path()
TEST_FILES.append(path)
for i in range(16, 256):
table = pa.Table.from_arrays(
[pa.array([None] + list(range(i)), type=pa.float64())],
names=["arr"]
)
_check_arrow_roundtrip(table)
def test_boolean_object_nulls(version):
repeats = 100
table = pa.Table.from_arrays(
[np.array([False, None, True] * repeats, dtype=object)],
names=["arr"]
)
_check_arrow_roundtrip(table)
@pytest.mark.pandas
def test_delete_partial_file_on_error(version):
if sys.platform == 'win32':
pytest.skip('Windows hangs on to file handle for some reason')
class CustomClass:
pass
# strings will fail
df = pd.DataFrame(
{
'numbers': range(5),
'strings': [b'foo', None, 'bar', CustomClass(), np.nan]},
columns=['numbers', 'strings'])
path = random_path()
try:
write_feather(df, path, version=version)
except Exception:
pass
assert not os.path.exists(path)
@pytest.mark.pandas
def test_strings(version):
repeats = 1000
# Mixed bytes, unicode, strings coerced to binary
values = [b'foo', None, 'bar', 'qux', np.nan]
df = pd.DataFrame({'strings': values * repeats})
ex_values = [b'foo', None, b'bar', b'qux', np.nan]
expected = pd.DataFrame({'strings': ex_values * repeats})
_check_pandas_roundtrip(df, expected, version=version)
# embedded nulls are ok
values = ['foo', None, 'bar', 'qux', None]
df = pd.DataFrame({'strings': values * repeats})
expected = pd.DataFrame({'strings': values * repeats})
_check_pandas_roundtrip(df, expected, version=version)
values = ['foo', None, 'bar', 'qux', np.nan]
df = pd.DataFrame({'strings': values * repeats})
expected = pd.DataFrame({'strings': values * repeats})
_check_pandas_roundtrip(df, expected, version=version)
@pytest.mark.pandas
def test_empty_strings(version):
df = pd.DataFrame({'strings': [''] * 10})
_check_pandas_roundtrip(df, version=version)
@pytest.mark.pandas
def test_all_none(version):
df = pd.DataFrame({'all_none': [None] * 10})
_check_pandas_roundtrip(df, version=version)
@pytest.mark.pandas
def test_all_null_category(version):
# ARROW-1188
df = pd.DataFrame({"A": (1, 2, 3), "B": (None, None, None)})
df = df.assign(B=df.B.astype("category"))
_check_pandas_roundtrip(df, version=version)
@pytest.mark.pandas
def test_multithreaded_read(version):
data = {'c{}'.format(i): [''] * 10
for i in range(100)}
df = pd.DataFrame(data)
_check_pandas_roundtrip(df, use_threads=True, version=version)
@pytest.mark.pandas
def test_nan_as_null(version):
# Create a nan that is not numpy.nan
values = np.array(['foo', np.nan, np.nan * 2, 'bar'] * 10)
df = pd.DataFrame({'strings': values})
_check_pandas_roundtrip(df, version=version)
@pytest.mark.pandas
def test_category(version):
repeats = 1000
values = ['foo', None, 'bar', 'qux', np.nan]
df = pd.DataFrame({'strings': values * repeats})
df['strings'] = df['strings'].astype('category')
values = ['foo', None, 'bar', 'qux', None]
expected = pd.DataFrame({'strings': pd.Categorical(values * repeats)})
_check_pandas_roundtrip(df, expected, version=version)
@pytest.mark.pandas
def test_timestamp(version):
df = pd.DataFrame({'naive': pd.date_range('2016-03-28', periods=10)})
df['with_tz'] = (df.naive.dt.tz_localize('utc')
.dt.tz_convert('America/Los_Angeles'))
_check_pandas_roundtrip(df, version=version)
@pytest.mark.pandas
def test_timestamp_with_nulls(version):
df = pd.DataFrame({'test': [pd.Timestamp(2016, 1, 1),
None,
pd.Timestamp(2016, 1, 3)]})
df['with_tz'] = df.test.dt.tz_localize('utc')
_check_pandas_roundtrip(df, version=version)
@pytest.mark.pandas
@pytest.mark.xfail(reason="not supported", raises=TypeError)
def test_timedelta_with_nulls_v1():
df = pd.DataFrame({'test': [pd.Timedelta('1 day'),
None,
pd.Timedelta('3 day')]})
_check_pandas_roundtrip(df, version=1)
@pytest.mark.pandas
def test_timedelta_with_nulls():
df = pd.DataFrame({'test': [pd.Timedelta('1 day'),
None,
pd.Timedelta('3 day')]})
_check_pandas_roundtrip(df, version=2)
@pytest.mark.pandas
def test_out_of_float64_timestamp_with_nulls(version):
df = pd.DataFrame(
{'test': pd.DatetimeIndex([1451606400000000001,
None, 14516064000030405])})
df['with_tz'] = df.test.dt.tz_localize('utc')
_check_pandas_roundtrip(df, version=version)
@pytest.mark.pandas
def test_non_string_columns(version):
df = pd.DataFrame({0: [1, 2, 3, 4],
1: [True, False, True, False]})
expected = df
if version == 1:
expected = df.rename(columns=str)
_check_pandas_roundtrip(df, expected, version=version)
@pytest.mark.pandas
@pytest.mark.skipif(not os.path.supports_unicode_filenames,
reason='unicode filenames not supported')
def test_unicode_filename(version):
# GH #209
name = (b'Besa_Kavaj\xc3\xab.feather').decode('utf-8')
df = pd.DataFrame({'foo': [1, 2, 3, 4]})
_check_pandas_roundtrip(df, path=random_path(prefix=name),
version=version)
@pytest.mark.pandas
def test_read_columns(version):
df = pd.DataFrame({
'foo': [1, 2, 3, 4],
'boo': [5, 6, 7, 8],
'woo': [1, 3, 5, 7]
})
expected = df[['boo', 'woo']]
_check_pandas_roundtrip(df, expected, version=version,
columns=['boo', 'woo'])
def test_overwritten_file(version):
path = random_path()
TEST_FILES.append(path)
num_values = 100
np.random.seed(0)
values = np.random.randint(0, 10, size=num_values)
table = pa.table({'ints': values})
write_feather(table, path)
table = pa.table({'more_ints': values[0:num_values//2]})
_check_arrow_roundtrip(table, path=path)
@pytest.mark.pandas
def test_filelike_objects(version):
buf = io.BytesIO()
# the copy makes it non-strided
df = pd.DataFrame(np.arange(12).reshape(4, 3),
columns=['a', 'b', 'c']).copy()
write_feather(df, buf, version=version)
buf.seek(0)
result = read_feather(buf)
assert_frame_equal(result, df)
@pytest.mark.pandas
@pytest.mark.filterwarnings("ignore:Sparse:FutureWarning")
@pytest.mark.filterwarnings("ignore:DataFrame.to_sparse:FutureWarning")
def test_sparse_dataframe(version):
if not pa.pandas_compat._pandas_api.has_sparse:
pytest.skip("version of pandas does not support SparseDataFrame")
# GH #221
data = {'A': [0, 1, 2],
'B': [1, 0, 1]}
df = pd.DataFrame(data).to_sparse(fill_value=1)
expected = df.to_dense()
_check_pandas_roundtrip(df, expected, version=version)
@pytest.mark.pandas
def test_duplicate_columns_pandas():
# https://github.com/wesm/feather/issues/53
# not currently able to handle duplicate columns
df = pd.DataFrame(np.arange(12).reshape(4, 3),
columns=list('aaa')).copy()
_assert_error_on_write(df, ValueError)
def test_duplicate_columns():
# only works for version 2
table = pa.table([[1, 2, 3], [4, 5, 6], [7, 8, 9]], names=['a', 'a', 'b'])
_check_arrow_roundtrip(table)
_assert_error_on_write(table, ValueError, version=1)
@pytest.mark.pandas
def test_unsupported():
# https://github.com/wesm/feather/issues/240
# serializing actual python objects
# custom python objects
class A:
pass
df = pd.DataFrame({'a': [A(), A()]})
_assert_error_on_write(df, ValueError)
# non-strings
df = pd.DataFrame({'a': ['a', 1, 2.0]})
_assert_error_on_write(df, TypeError)
@pytest.mark.pandas
def test_v2_set_chunksize():
df = pd.DataFrame({'A': np.arange(1000)})
table = pa.table(df)
buf = io.BytesIO()
write_feather(table, buf, chunksize=250, version=2)
result = buf.getvalue()
ipc_file = pa.ipc.open_file(pa.BufferReader(result))
assert ipc_file.num_record_batches == 4
assert len(ipc_file.get_batch(0)) == 250
@pytest.mark.pandas
@pytest.mark.lz4
@pytest.mark.snappy
@pytest.mark.zstd
def test_v2_compression_options():
df = pd.DataFrame({'A': np.arange(1000)})
cases = [
# compression, compression_level
('uncompressed', None),
('lz4', None),
('lz4', 1),
('lz4', 12),
('zstd', 1),
('zstd', 10)
]
for compression, compression_level in cases:
_check_pandas_roundtrip(df, compression=compression,
compression_level=compression_level)
buf = io.BytesIO()
# Trying to compress with V1
with pytest.raises(
ValueError,
match="Feather V1 files do not support compression option"):
write_feather(df, buf, compression='lz4', version=1)
# Trying to set chunksize with V1
with pytest.raises(
ValueError,
match="Feather V1 files do not support chunksize option"):
write_feather(df, buf, chunksize=4096, version=1)
# Unsupported compressor
with pytest.raises(ValueError,
match='compression="snappy" not supported'):
write_feather(df, buf, compression='snappy')
def test_v2_lz4_default_compression():
# ARROW-8750: Make sure that the compression=None option selects lz4 if
# it's available
if not pa.Codec.is_available('lz4_frame'):
pytest.skip("LZ4 compression support is not built in C++")
# some highly compressible data
t = pa.table([np.repeat(0, 100000)], names=['f0'])
buf = io.BytesIO()
write_feather(t, buf)
default_result = buf.getvalue()
buf = io.BytesIO()
write_feather(t, buf, compression='uncompressed')
uncompressed_result = buf.getvalue()
assert len(default_result) < len(uncompressed_result)
def test_v1_unsupported_types():
table = pa.table([pa.array([[1, 2, 3], [], None])], names=['f0'])
buf = io.BytesIO()
with pytest.raises(TypeError,
match=("Unsupported Feather V1 type: "
"list<item: int64>. "
"Use V2 format to serialize all Arrow types.")):
write_feather(table, buf, version=1)
@pytest.mark.slow
@pytest.mark.pandas
def test_large_dataframe(version):
df = pd.DataFrame({'A': np.arange(400000000)})
_check_pandas_roundtrip(df, version=version)
@pytest.mark.large_memory
@pytest.mark.pandas
def test_chunked_binary_error_message():
# ARROW-3058: As Feather does not yet support chunked columns, we at least
# make sure it's clear to the user what is going on
# 2^31 + 1 bytes
values = [b'x'] + [
b'x' * (1 << 20)
] * 2 * (1 << 10)
df = pd.DataFrame({'byte_col': values})
# Works fine with version 2
buf = io.BytesIO()
write_feather(df, buf, version=2)
result = read_feather(pa.BufferReader(buf.getvalue()))
assert_frame_equal(result, df)
with pytest.raises(ValueError, match="'byte_col' exceeds 2GB maximum "
"capacity of a Feather binary column. This restriction "
"may be lifted in the future"):
write_feather(df, io.BytesIO(), version=1)
def test_feather_without_pandas(tempdir, version):
# ARROW-8345
table = pa.table([pa.array([1, 2, 3])], names=['f0'])
path = str(tempdir / "data.feather")
_check_arrow_roundtrip(table, path)
@pytest.mark.pandas
def test_read_column_selection(version):
# ARROW-8641
df = pd.DataFrame(np.arange(12).reshape(4, 3), columns=['a', 'b', 'c'])
# select columns as string names or integer indices
_check_pandas_roundtrip(
df, columns=['a', 'c'], expected=df[['a', 'c']], version=version)
_check_pandas_roundtrip(
df, columns=[0, 2], expected=df[['a', 'c']], version=version)
# different order is followed
_check_pandas_roundtrip(
df, columns=['b', 'a'], expected=df[['b', 'a']], version=version)
_check_pandas_roundtrip(
df, columns=[1, 0], expected=df[['b', 'a']], version=version)
def test_read_column_duplicated_selection(tempdir, version):
# duplicated columns in the column selection
table = pa.table([[1, 2, 3], [4, 5, 6], [7, 8, 9]], names=['a', 'b', 'c'])
path = str(tempdir / "data.feather")
write_feather(table, path, version=version)
expected = pa.table([[1, 2, 3], [4, 5, 6], [1, 2, 3]],
names=['a', 'b', 'a'])
for col_selection in [['a', 'b', 'a'], [0, 1, 0]]:
result = read_table(path, columns=col_selection)
assert result.equals(expected)
def test_read_column_duplicated_in_file(tempdir):
# duplicated columns in feather file (only works for feather v2)
table = pa.table([[1, 2, 3], [4, 5, 6], [7, 8, 9]], names=['a', 'b', 'a'])
path = str(tempdir / "data.feather")
write_feather(table, path, version=2)
# no selection works fine
result = read_table(path)
assert result.equals(table)
# selection with indices works
result = read_table(path, columns=[0, 2])
assert result.column_names == ['a', 'a']
# selection with column names errors
with pytest.raises(ValueError):
read_table(path, columns=['a', 'b'])
def test_nested_types(compression):
# https://issues.apache.org/jira/browse/ARROW-8860
table = pa.table({'col': pa.StructArray.from_arrays(
[[0, 1, 2], [1, 2, 3]], names=["f1", "f2"])})
_check_arrow_roundtrip(table, compression=compression)
table = pa.table({'col': pa.array([[1, 2], [3, 4]])})
_check_arrow_roundtrip(table, compression=compression)
table = pa.table({'col': pa.array([[[1, 2], [3, 4]], [[5, 6], None]])})
_check_arrow_roundtrip(table, compression=compression)
@h.given(past.all_tables, st.sampled_from(["uncompressed", "lz4", "zstd"]))
def test_roundtrip(table, compression):
_check_arrow_roundtrip(table, compression=compression)
@pytest.mark.lz4
def test_feather_v017_experimental_compression_backward_compatibility(datadir):
# ARROW-11163 - ensure newer pyarrow versions can read the old feather
# files from version 0.17.0 with experimental compression support (before
# it was officially added to IPC format in 1.0.0)
# file generated with:
# table = pa.table({'a': range(5)})
# from pyarrow import feather
# feather.write_feather(
# table, "v0.17.0.version.2-compression.lz4.feather",
# compression="lz4", version=2)
expected = pa.table({'a': range(5)})
result = read_table(datadir / "v0.17.0.version.2-compression.lz4.feather")
assert result.equals(expected)
@pytest.mark.pandas
def test_preserve_index_pandas(version):
df = pd.DataFrame({'a': [1, 2, 3]}, index=['a', 'b', 'c'])
if version == 1:
expected = df.reset_index(drop=True).rename(columns=str)
else:
expected = df
_check_pandas_roundtrip(df, expected, version=version)
@@ -0,0 +1,74 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pyarrow as pa
from pyarrow import filesystem
import os
import pytest
def test_filesystem_deprecated():
with pytest.warns(FutureWarning):
filesystem.LocalFileSystem()
with pytest.warns(FutureWarning):
filesystem.LocalFileSystem.get_instance()
def test_filesystem_deprecated_toplevel():
with pytest.warns(FutureWarning):
pa.localfs
with pytest.warns(FutureWarning):
pa.FileSystem
with pytest.warns(FutureWarning):
pa.LocalFileSystem
with pytest.warns(FutureWarning):
pa.HadoopFileSystem
def test_resolve_uri():
uri = "file:///home/user/myfile.parquet"
fs, path = filesystem.resolve_filesystem_and_path(uri)
assert isinstance(fs, filesystem.LocalFileSystem)
assert path == "/home/user/myfile.parquet"
def test_resolve_local_path():
for uri in ['/home/user/myfile.parquet',
'myfile.parquet',
'my # file ? parquet',
'C:/Windows/myfile.parquet',
r'C:\\Windows\\myfile.parquet',
]:
fs, path = filesystem.resolve_filesystem_and_path(uri)
assert isinstance(fs, filesystem.LocalFileSystem)
assert path == uri
def test_resolve_home_directory():
uri = '~/myfile.parquet'
fs, path = filesystem.resolve_filesystem_and_path(uri)
assert isinstance(fs, filesystem.LocalFileSystem)
assert path == os.path.expanduser(uri)
local_fs = filesystem.LocalFileSystem()
fs, path = filesystem.resolve_filesystem_and_path(uri, local_fs)
assert path == os.path.expanduser(uri)
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,391 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
import pytest
import pyarrow as pa
@pytest.mark.gandiva
def test_tree_exp_builder():
import pyarrow.gandiva as gandiva
builder = gandiva.TreeExprBuilder()
field_a = pa.field('a', pa.int32())
field_b = pa.field('b', pa.int32())
schema = pa.schema([field_a, field_b])
field_result = pa.field('res', pa.int32())
node_a = builder.make_field(field_a)
node_b = builder.make_field(field_b)
assert node_a.return_type() == field_a.type
condition = builder.make_function("greater_than", [node_a, node_b],
pa.bool_())
if_node = builder.make_if(condition, node_a, node_b, pa.int32())
expr = builder.make_expression(if_node, field_result)
assert expr.result().type == pa.int32()
projector = gandiva.make_projector(
schema, [expr], pa.default_memory_pool())
# Gandiva generates compute kernel function named `@expr_X`
assert projector.llvm_ir.find("@expr_") != -1
a = pa.array([10, 12, -20, 5], type=pa.int32())
b = pa.array([5, 15, 15, 17], type=pa.int32())
e = pa.array([10, 15, 15, 17], type=pa.int32())
input_batch = pa.RecordBatch.from_arrays([a, b], names=['a', 'b'])
r, = projector.evaluate(input_batch)
assert r.equals(e)
@pytest.mark.gandiva
def test_table():
import pyarrow.gandiva as gandiva
table = pa.Table.from_arrays([pa.array([1.0, 2.0]), pa.array([3.0, 4.0])],
['a', 'b'])
builder = gandiva.TreeExprBuilder()
node_a = builder.make_field(table.schema.field("a"))
node_b = builder.make_field(table.schema.field("b"))
sum = builder.make_function("add", [node_a, node_b], pa.float64())
field_result = pa.field("c", pa.float64())
expr = builder.make_expression(sum, field_result)
projector = gandiva.make_projector(
table.schema, [expr], pa.default_memory_pool())
# TODO: Add .evaluate function which can take Tables instead of
# RecordBatches
r, = projector.evaluate(table.to_batches()[0])
e = pa.array([4.0, 6.0])
assert r.equals(e)
@pytest.mark.gandiva
def test_filter():
import pyarrow.gandiva as gandiva
table = pa.Table.from_arrays([pa.array([1.0 * i for i in range(10000)])],
['a'])
builder = gandiva.TreeExprBuilder()
node_a = builder.make_field(table.schema.field("a"))
thousand = builder.make_literal(1000.0, pa.float64())
cond = builder.make_function("less_than", [node_a, thousand], pa.bool_())
condition = builder.make_condition(cond)
assert condition.result().type == pa.bool_()
filter = gandiva.make_filter(table.schema, condition)
# Gandiva generates compute kernel function named `@expr_X`
assert filter.llvm_ir.find("@expr_") != -1
result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool())
assert result.to_array().equals(pa.array(range(1000), type=pa.uint32()))
@pytest.mark.gandiva
def test_in_expr():
import pyarrow.gandiva as gandiva
arr = pa.array(["ga", "an", "nd", "di", "iv", "va"])
table = pa.Table.from_arrays([arr], ["a"])
# string
builder = gandiva.TreeExprBuilder()
node_a = builder.make_field(table.schema.field("a"))
cond = builder.make_in_expression(node_a, ["an", "nd"], pa.string())
condition = builder.make_condition(cond)
filter = gandiva.make_filter(table.schema, condition)
result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool())
assert result.to_array().equals(pa.array([1, 2], type=pa.uint32()))
# int32
arr = pa.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 4])
table = pa.Table.from_arrays([arr.cast(pa.int32())], ["a"])
node_a = builder.make_field(table.schema.field("a"))
cond = builder.make_in_expression(node_a, [1, 5], pa.int32())
condition = builder.make_condition(cond)
filter = gandiva.make_filter(table.schema, condition)
result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool())
assert result.to_array().equals(pa.array([1, 3, 4, 8], type=pa.uint32()))
# int64
arr = pa.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 4])
table = pa.Table.from_arrays([arr], ["a"])
node_a = builder.make_field(table.schema.field("a"))
cond = builder.make_in_expression(node_a, [1, 5], pa.int64())
condition = builder.make_condition(cond)
filter = gandiva.make_filter(table.schema, condition)
result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool())
assert result.to_array().equals(pa.array([1, 3, 4, 8], type=pa.uint32()))
@pytest.mark.skip(reason="Gandiva C++ did not have *real* binary, "
"time and date support.")
def test_in_expr_todo():
import pyarrow.gandiva as gandiva
# TODO: Implement reasonable support for timestamp, time & date.
# Current exceptions:
# pyarrow.lib.ArrowException: ExpressionValidationError:
# Evaluation expression for IN clause returns XXXX values are of typeXXXX
# binary
arr = pa.array([b"ga", b"an", b"nd", b"di", b"iv", b"va"])
table = pa.Table.from_arrays([arr], ["a"])
builder = gandiva.TreeExprBuilder()
node_a = builder.make_field(table.schema.field("a"))
cond = builder.make_in_expression(node_a, [b'an', b'nd'], pa.binary())
condition = builder.make_condition(cond)
filter = gandiva.make_filter(table.schema, condition)
result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool())
assert result.to_array().equals(pa.array([1, 2], type=pa.uint32()))
# timestamp
datetime_1 = datetime.datetime.utcfromtimestamp(1542238951.621877)
datetime_2 = datetime.datetime.utcfromtimestamp(1542238911.621877)
datetime_3 = datetime.datetime.utcfromtimestamp(1542238051.621877)
arr = pa.array([datetime_1, datetime_2, datetime_3])
table = pa.Table.from_arrays([arr], ["a"])
builder = gandiva.TreeExprBuilder()
node_a = builder.make_field(table.schema.field("a"))
cond = builder.make_in_expression(node_a, [datetime_2], pa.timestamp('ms'))
condition = builder.make_condition(cond)
filter = gandiva.make_filter(table.schema, condition)
result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool())
assert list(result.to_array()) == [1]
# time
time_1 = datetime_1.time()
time_2 = datetime_2.time()
time_3 = datetime_3.time()
arr = pa.array([time_1, time_2, time_3])
table = pa.Table.from_arrays([arr], ["a"])
builder = gandiva.TreeExprBuilder()
node_a = builder.make_field(table.schema.field("a"))
cond = builder.make_in_expression(node_a, [time_2], pa.time64('ms'))
condition = builder.make_condition(cond)
filter = gandiva.make_filter(table.schema, condition)
result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool())
assert list(result.to_array()) == [1]
# date
date_1 = datetime_1.date()
date_2 = datetime_2.date()
date_3 = datetime_3.date()
arr = pa.array([date_1, date_2, date_3])
table = pa.Table.from_arrays([arr], ["a"])
builder = gandiva.TreeExprBuilder()
node_a = builder.make_field(table.schema.field("a"))
cond = builder.make_in_expression(node_a, [date_2], pa.date32())
condition = builder.make_condition(cond)
filter = gandiva.make_filter(table.schema, condition)
result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool())
assert list(result.to_array()) == [1]
@pytest.mark.gandiva
def test_boolean():
import pyarrow.gandiva as gandiva
table = pa.Table.from_arrays([
pa.array([1., 31., 46., 3., 57., 44., 22.]),
pa.array([5., 45., 36., 73., 83., 23., 76.])],
['a', 'b'])
builder = gandiva.TreeExprBuilder()
node_a = builder.make_field(table.schema.field("a"))
node_b = builder.make_field(table.schema.field("b"))
fifty = builder.make_literal(50.0, pa.float64())
eleven = builder.make_literal(11.0, pa.float64())
cond_1 = builder.make_function("less_than", [node_a, fifty], pa.bool_())
cond_2 = builder.make_function("greater_than", [node_a, node_b],
pa.bool_())
cond_3 = builder.make_function("less_than", [node_b, eleven], pa.bool_())
cond = builder.make_or([builder.make_and([cond_1, cond_2]), cond_3])
condition = builder.make_condition(cond)
filter = gandiva.make_filter(table.schema, condition)
result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool())
assert result.to_array().equals(pa.array([0, 2, 5], type=pa.uint32()))
@pytest.mark.gandiva
def test_literals():
import pyarrow.gandiva as gandiva
builder = gandiva.TreeExprBuilder()
builder.make_literal(True, pa.bool_())
builder.make_literal(0, pa.uint8())
builder.make_literal(1, pa.uint16())
builder.make_literal(2, pa.uint32())
builder.make_literal(3, pa.uint64())
builder.make_literal(4, pa.int8())
builder.make_literal(5, pa.int16())
builder.make_literal(6, pa.int32())
builder.make_literal(7, pa.int64())
builder.make_literal(8.0, pa.float32())
builder.make_literal(9.0, pa.float64())
builder.make_literal("hello", pa.string())
builder.make_literal(b"world", pa.binary())
builder.make_literal(True, "bool")
builder.make_literal(0, "uint8")
builder.make_literal(1, "uint16")
builder.make_literal(2, "uint32")
builder.make_literal(3, "uint64")
builder.make_literal(4, "int8")
builder.make_literal(5, "int16")
builder.make_literal(6, "int32")
builder.make_literal(7, "int64")
builder.make_literal(8.0, "float32")
builder.make_literal(9.0, "float64")
builder.make_literal("hello", "string")
builder.make_literal(b"world", "binary")
with pytest.raises(TypeError):
builder.make_literal("hello", pa.int64())
with pytest.raises(TypeError):
builder.make_literal(True, None)
@pytest.mark.gandiva
def test_regex():
import pyarrow.gandiva as gandiva
elements = ["park", "sparkle", "bright spark and fire", "spark"]
data = pa.array(elements, type=pa.string())
table = pa.Table.from_arrays([data], names=['a'])
builder = gandiva.TreeExprBuilder()
node_a = builder.make_field(table.schema.field("a"))
regex = builder.make_literal("%spark%", pa.string())
like = builder.make_function("like", [node_a, regex], pa.bool_())
field_result = pa.field("b", pa.bool_())
expr = builder.make_expression(like, field_result)
projector = gandiva.make_projector(
table.schema, [expr], pa.default_memory_pool())
r, = projector.evaluate(table.to_batches()[0])
b = pa.array([False, True, True, True], type=pa.bool_())
assert r.equals(b)
@pytest.mark.gandiva
def test_get_registered_function_signatures():
import pyarrow.gandiva as gandiva
signatures = gandiva.get_registered_function_signatures()
assert type(signatures[0].return_type()) is pa.DataType
assert type(signatures[0].param_types()) is list
assert hasattr(signatures[0], "name")
@pytest.mark.gandiva
def test_filter_project():
import pyarrow.gandiva as gandiva
mpool = pa.default_memory_pool()
# Create a table with some sample data
array0 = pa.array([10, 12, -20, 5, 21, 29], pa.int32())
array1 = pa.array([5, 15, 15, 17, 12, 3], pa.int32())
array2 = pa.array([1, 25, 11, 30, -21, None], pa.int32())
table = pa.Table.from_arrays([array0, array1, array2], ['a', 'b', 'c'])
field_result = pa.field("res", pa.int32())
builder = gandiva.TreeExprBuilder()
node_a = builder.make_field(table.schema.field("a"))
node_b = builder.make_field(table.schema.field("b"))
node_c = builder.make_field(table.schema.field("c"))
greater_than_function = builder.make_function("greater_than",
[node_a, node_b], pa.bool_())
filter_condition = builder.make_condition(
greater_than_function)
project_condition = builder.make_function("less_than",
[node_b, node_c], pa.bool_())
if_node = builder.make_if(project_condition,
node_b, node_c, pa.int32())
expr = builder.make_expression(if_node, field_result)
# Build a filter for the expressions.
filter = gandiva.make_filter(table.schema, filter_condition)
# Build a projector for the expressions.
projector = gandiva.make_projector(
table.schema, [expr], mpool, "UINT32")
# Evaluate filter
selection_vector = filter.evaluate(table.to_batches()[0], mpool)
# Evaluate project
r, = projector.evaluate(
table.to_batches()[0], selection_vector)
exp = pa.array([1, -21, None], pa.int32())
assert r.equals(exp)
@pytest.mark.gandiva
def test_to_string():
import pyarrow.gandiva as gandiva
builder = gandiva.TreeExprBuilder()
assert str(builder.make_literal(2.0, pa.float64())
).startswith('(const double) 2 raw(')
assert str(builder.make_literal(2, pa.int64())) == '(const int64) 2'
assert str(builder.make_field(pa.field('x', pa.float64()))) == '(double) x'
assert str(builder.make_field(pa.field('y', pa.string()))) == '(string) y'
field_z = builder.make_field(pa.field('z', pa.bool_()))
func_node = builder.make_function('not', [field_z], pa.bool_())
assert str(func_node) == 'bool not((bool) z)'
field_y = builder.make_field(pa.field('y', pa.bool_()))
and_node = builder.make_and([func_node, field_y])
assert str(and_node) == 'bool not((bool) z) && (bool) y'
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,447 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import pickle
import random
import unittest
from io import BytesIO
from os.path import join as pjoin
import numpy as np
import pytest
import pyarrow as pa
from pyarrow.pandas_compat import _pandas_api
from pyarrow.tests import util
from pyarrow.tests.parquet.common import _test_dataframe
from pyarrow.tests.parquet.test_dataset import (
_test_read_common_metadata_files, _test_write_to_dataset_with_partitions,
_test_write_to_dataset_no_partitions
)
from pyarrow.util import guid
# ----------------------------------------------------------------------
# HDFS tests
def check_libhdfs_present():
if not pa.have_libhdfs():
message = 'No libhdfs available on system'
if os.environ.get('PYARROW_HDFS_TEST_LIBHDFS_REQUIRE'):
pytest.fail(message)
else:
pytest.skip(message)
def hdfs_test_client():
host = os.environ.get('ARROW_HDFS_TEST_HOST', 'default')
user = os.environ.get('ARROW_HDFS_TEST_USER', None)
try:
port = int(os.environ.get('ARROW_HDFS_TEST_PORT', 0))
except ValueError:
raise ValueError('Env variable ARROW_HDFS_TEST_PORT was not '
'an integer')
with pytest.warns(FutureWarning):
return pa.hdfs.connect(host, port, user)
@pytest.mark.hdfs
class HdfsTestCases:
def _make_test_file(self, hdfs, test_name, test_path, test_data):
base_path = pjoin(self.tmp_path, test_name)
hdfs.mkdir(base_path)
full_path = pjoin(base_path, test_path)
with hdfs.open(full_path, 'wb') as f:
f.write(test_data)
return full_path
@classmethod
def setUpClass(cls):
cls.check_driver()
cls.hdfs = hdfs_test_client()
cls.tmp_path = '/tmp/pyarrow-test-{}'.format(random.randint(0, 1000))
cls.hdfs.mkdir(cls.tmp_path)
@classmethod
def tearDownClass(cls):
cls.hdfs.delete(cls.tmp_path, recursive=True)
cls.hdfs.close()
def test_pickle(self):
s = pickle.dumps(self.hdfs)
h2 = pickle.loads(s)
assert h2.is_open
assert h2.host == self.hdfs.host
assert h2.port == self.hdfs.port
assert h2.user == self.hdfs.user
assert h2.kerb_ticket == self.hdfs.kerb_ticket
# smoketest unpickled client works
h2.ls(self.tmp_path)
def test_cat(self):
path = pjoin(self.tmp_path, 'cat-test')
data = b'foobarbaz'
with self.hdfs.open(path, 'wb') as f:
f.write(data)
contents = self.hdfs.cat(path)
assert contents == data
def test_capacity_space(self):
capacity = self.hdfs.get_capacity()
space_used = self.hdfs.get_space_used()
disk_free = self.hdfs.df()
assert capacity > 0
assert capacity > space_used
assert disk_free == (capacity - space_used)
def test_close(self):
client = hdfs_test_client()
assert client.is_open
client.close()
assert not client.is_open
with pytest.raises(Exception):
client.ls('/')
def test_mkdir(self):
path = pjoin(self.tmp_path, 'test-dir/test-dir')
parent_path = pjoin(self.tmp_path, 'test-dir')
self.hdfs.mkdir(path)
assert self.hdfs.exists(path)
self.hdfs.delete(parent_path, recursive=True)
assert not self.hdfs.exists(path)
def test_mv_rename(self):
path = pjoin(self.tmp_path, 'mv-test')
new_path = pjoin(self.tmp_path, 'mv-new-test')
data = b'foobarbaz'
with self.hdfs.open(path, 'wb') as f:
f.write(data)
assert self.hdfs.exists(path)
self.hdfs.mv(path, new_path)
assert not self.hdfs.exists(path)
assert self.hdfs.exists(new_path)
assert self.hdfs.cat(new_path) == data
self.hdfs.rename(new_path, path)
assert self.hdfs.cat(path) == data
def test_info(self):
path = pjoin(self.tmp_path, 'info-base')
file_path = pjoin(path, 'ex')
self.hdfs.mkdir(path)
data = b'foobarbaz'
with self.hdfs.open(file_path, 'wb') as f:
f.write(data)
path_info = self.hdfs.info(path)
file_path_info = self.hdfs.info(file_path)
assert path_info['kind'] == 'directory'
assert file_path_info['kind'] == 'file'
assert file_path_info['size'] == len(data)
def test_exists_isdir_isfile(self):
dir_path = pjoin(self.tmp_path, 'info-base')
file_path = pjoin(dir_path, 'ex')
missing_path = pjoin(dir_path, 'this-path-is-missing')
self.hdfs.mkdir(dir_path)
with self.hdfs.open(file_path, 'wb') as f:
f.write(b'foobarbaz')
assert self.hdfs.exists(dir_path)
assert self.hdfs.exists(file_path)
assert not self.hdfs.exists(missing_path)
assert self.hdfs.isdir(dir_path)
assert not self.hdfs.isdir(file_path)
assert not self.hdfs.isdir(missing_path)
assert not self.hdfs.isfile(dir_path)
assert self.hdfs.isfile(file_path)
assert not self.hdfs.isfile(missing_path)
def test_disk_usage(self):
path = pjoin(self.tmp_path, 'disk-usage-base')
p1 = pjoin(path, 'p1')
p2 = pjoin(path, 'p2')
subdir = pjoin(path, 'subdir')
p3 = pjoin(subdir, 'p3')
if self.hdfs.exists(path):
self.hdfs.delete(path, True)
self.hdfs.mkdir(path)
self.hdfs.mkdir(subdir)
data = b'foobarbaz'
for file_path in [p1, p2, p3]:
with self.hdfs.open(file_path, 'wb') as f:
f.write(data)
assert self.hdfs.disk_usage(path) == len(data) * 3
def test_ls(self):
base_path = pjoin(self.tmp_path, 'ls-test')
self.hdfs.mkdir(base_path)
dir_path = pjoin(base_path, 'a-dir')
f1_path = pjoin(base_path, 'a-file-1')
self.hdfs.mkdir(dir_path)
f = self.hdfs.open(f1_path, 'wb')
f.write(b'a' * 10)
contents = sorted(self.hdfs.ls(base_path, False))
assert contents == [dir_path, f1_path]
def test_chmod_chown(self):
path = pjoin(self.tmp_path, 'chmod-test')
with self.hdfs.open(path, 'wb') as f:
f.write(b'a' * 10)
def test_download_upload(self):
base_path = pjoin(self.tmp_path, 'upload-test')
data = b'foobarbaz'
buf = BytesIO(data)
buf.seek(0)
self.hdfs.upload(base_path, buf)
out_buf = BytesIO()
self.hdfs.download(base_path, out_buf)
out_buf.seek(0)
assert out_buf.getvalue() == data
def test_file_context_manager(self):
path = pjoin(self.tmp_path, 'ctx-manager')
data = b'foo'
with self.hdfs.open(path, 'wb') as f:
f.write(data)
with self.hdfs.open(path, 'rb') as f:
assert f.size() == 3
result = f.read(10)
assert result == data
def test_open_not_exist(self):
path = pjoin(self.tmp_path, 'does-not-exist-123')
with pytest.raises(FileNotFoundError):
self.hdfs.open(path)
def test_open_write_error(self):
with pytest.raises((FileExistsError, IsADirectoryError)):
self.hdfs.open('/', 'wb')
def test_read_whole_file(self):
path = pjoin(self.tmp_path, 'read-whole-file')
data = b'foo' * 1000
with self.hdfs.open(path, 'wb') as f:
f.write(data)
with self.hdfs.open(path, 'rb') as f:
result = f.read()
assert result == data
def _write_multiple_hdfs_pq_files(self, tmpdir):
import pyarrow.parquet as pq
nfiles = 10
size = 5
test_data = []
for i in range(nfiles):
df = _test_dataframe(size, seed=i)
df['index'] = np.arange(i * size, (i + 1) * size)
# Hack so that we don't have a dtype cast in v1 files
df['uint32'] = df['uint32'].astype(np.int64)
path = pjoin(tmpdir, '{}.parquet'.format(i))
table = pa.Table.from_pandas(df, preserve_index=False)
with self.hdfs.open(path, 'wb') as f:
pq.write_table(table, f)
test_data.append(table)
expected = pa.concat_tables(test_data)
return expected
@pytest.mark.pandas
@pytest.mark.parquet
def test_read_multiple_parquet_files(self):
tmpdir = pjoin(self.tmp_path, 'multi-parquet-' + guid())
self.hdfs.mkdir(tmpdir)
expected = self._write_multiple_hdfs_pq_files(tmpdir)
result = self.hdfs.read_parquet(tmpdir)
_pandas_api.assert_frame_equal(result.to_pandas()
.sort_values(by='index')
.reset_index(drop=True),
expected.to_pandas())
@pytest.mark.pandas
@pytest.mark.parquet
def test_read_multiple_parquet_files_with_uri(self):
import pyarrow.parquet as pq
tmpdir = pjoin(self.tmp_path, 'multi-parquet-uri-' + guid())
self.hdfs.mkdir(tmpdir)
expected = self._write_multiple_hdfs_pq_files(tmpdir)
path = _get_hdfs_uri(tmpdir)
result = pq.read_table(path)
_pandas_api.assert_frame_equal(result.to_pandas()
.sort_values(by='index')
.reset_index(drop=True),
expected.to_pandas())
@pytest.mark.pandas
@pytest.mark.parquet
def test_read_write_parquet_files_with_uri(self):
import pyarrow.parquet as pq
tmpdir = pjoin(self.tmp_path, 'uri-parquet-' + guid())
self.hdfs.mkdir(tmpdir)
path = _get_hdfs_uri(pjoin(tmpdir, 'test.parquet'))
size = 5
df = _test_dataframe(size, seed=0)
# Hack so that we don't have a dtype cast in v1 files
df['uint32'] = df['uint32'].astype(np.int64)
table = pa.Table.from_pandas(df, preserve_index=False)
pq.write_table(table, path, filesystem=self.hdfs)
result = pq.read_table(
path, filesystem=self.hdfs, use_legacy_dataset=True
).to_pandas()
_pandas_api.assert_frame_equal(result, df)
@pytest.mark.parquet
@pytest.mark.pandas
def test_read_common_metadata_files(self):
tmpdir = pjoin(self.tmp_path, 'common-metadata-' + guid())
self.hdfs.mkdir(tmpdir)
_test_read_common_metadata_files(self.hdfs, tmpdir)
@pytest.mark.parquet
@pytest.mark.pandas
def test_write_to_dataset_with_partitions(self):
tmpdir = pjoin(self.tmp_path, 'write-partitions-' + guid())
self.hdfs.mkdir(tmpdir)
_test_write_to_dataset_with_partitions(
tmpdir, filesystem=self.hdfs)
@pytest.mark.parquet
@pytest.mark.pandas
def test_write_to_dataset_no_partitions(self):
tmpdir = pjoin(self.tmp_path, 'write-no_partitions-' + guid())
self.hdfs.mkdir(tmpdir)
_test_write_to_dataset_no_partitions(
tmpdir, filesystem=self.hdfs)
class TestLibHdfs(HdfsTestCases, unittest.TestCase):
@classmethod
def check_driver(cls):
check_libhdfs_present()
def test_orphaned_file(self):
hdfs = hdfs_test_client()
file_path = self._make_test_file(hdfs, 'orphaned_file_test', 'fname',
b'foobarbaz')
f = hdfs.open(file_path)
hdfs = None
f = None # noqa
def _get_hdfs_uri(path):
host = os.environ.get('ARROW_HDFS_TEST_HOST', 'localhost')
try:
port = int(os.environ.get('ARROW_HDFS_TEST_PORT', 0))
except ValueError:
raise ValueError('Env variable ARROW_HDFS_TEST_PORT was not '
'an integer')
uri = "hdfs://{}:{}{}".format(host, port, path)
return uri
@pytest.mark.hdfs
@pytest.mark.pandas
@pytest.mark.parquet
@pytest.mark.fastparquet
def test_fastparquet_read_with_hdfs():
from pandas.testing import assert_frame_equal
check_libhdfs_present()
try:
import snappy # noqa
except ImportError:
pytest.skip('fastparquet test requires snappy')
import pyarrow.parquet as pq
fastparquet = pytest.importorskip('fastparquet')
fs = hdfs_test_client()
df = util.make_dataframe()
table = pa.Table.from_pandas(df)
path = '/tmp/testing.parquet'
with fs.open(path, 'wb') as f:
pq.write_table(table, f)
parquet_file = fastparquet.ParquetFile(path, open_with=fs.open)
result = parquet_file.to_pandas()
assert_frame_equal(result, df)
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,326 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections import OrderedDict
import io
import itertools
import json
import pickle
import string
import unittest
import numpy as np
import pytest
import pyarrow as pa
from pyarrow.json import read_json, ReadOptions, ParseOptions
def generate_col_names():
# 'a', 'b'... 'z', then 'aa', 'ab'...
letters = string.ascii_lowercase
yield from letters
for first in letters:
for second in letters:
yield first + second
def make_random_json(num_cols=2, num_rows=10, linesep='\r\n'):
arr = np.random.RandomState(42).randint(0, 1000, size=(num_cols, num_rows))
col_names = list(itertools.islice(generate_col_names(), num_cols))
lines = []
for row in arr.T:
json_obj = OrderedDict([(k, int(v)) for (k, v) in zip(col_names, row)])
lines.append(json.dumps(json_obj))
data = linesep.join(lines).encode()
columns = [pa.array(col, type=pa.int64()) for col in arr]
expected = pa.Table.from_arrays(columns, col_names)
return data, expected
def check_options_class_pickling(cls, **attr_values):
opts = cls(**attr_values)
new_opts = pickle.loads(pickle.dumps(opts,
protocol=pickle.HIGHEST_PROTOCOL))
for name, value in attr_values.items():
assert getattr(new_opts, name) == value
def test_read_options():
cls = ReadOptions
opts = cls()
assert opts.block_size > 0
opts.block_size = 12345
assert opts.block_size == 12345
assert opts.use_threads is True
opts.use_threads = False
assert opts.use_threads is False
opts = cls(block_size=1234, use_threads=False)
assert opts.block_size == 1234
assert opts.use_threads is False
check_options_class_pickling(cls, block_size=1234,
use_threads=False)
def test_parse_options():
cls = ParseOptions
opts = cls()
assert opts.newlines_in_values is False
assert opts.explicit_schema is None
opts.newlines_in_values = True
assert opts.newlines_in_values is True
schema = pa.schema([pa.field('foo', pa.int32())])
opts.explicit_schema = schema
assert opts.explicit_schema == schema
assert opts.unexpected_field_behavior == "infer"
for value in ["ignore", "error", "infer"]:
opts.unexpected_field_behavior = value
assert opts.unexpected_field_behavior == value
with pytest.raises(ValueError):
opts.unexpected_field_behavior = "invalid-value"
check_options_class_pickling(cls, explicit_schema=schema,
newlines_in_values=False,
unexpected_field_behavior="ignore")
class BaseTestJSONRead:
def read_bytes(self, b, **kwargs):
return self.read_json(pa.py_buffer(b), **kwargs)
def check_names(self, table, names):
assert table.num_columns == len(names)
assert [c.name for c in table.columns] == names
def test_file_object(self):
data = b'{"a": 1, "b": 2}\n'
expected_data = {'a': [1], 'b': [2]}
bio = io.BytesIO(data)
table = self.read_json(bio)
assert table.to_pydict() == expected_data
# Text files not allowed
sio = io.StringIO(data.decode())
with pytest.raises(TypeError):
self.read_json(sio)
def test_block_sizes(self):
rows = b'{"a": 1}\n{"a": 2}\n{"a": 3}'
read_options = ReadOptions()
parse_options = ParseOptions()
for data in [rows, rows + b'\n']:
for newlines_in_values in [False, True]:
parse_options.newlines_in_values = newlines_in_values
read_options.block_size = 4
with pytest.raises(ValueError,
match="try to increase block size"):
self.read_bytes(data, read_options=read_options,
parse_options=parse_options)
# Validate reader behavior with various block sizes.
# There used to be bugs in this area.
for block_size in range(9, 20):
read_options.block_size = block_size
table = self.read_bytes(data, read_options=read_options,
parse_options=parse_options)
assert table.to_pydict() == {'a': [1, 2, 3]}
def test_no_newline_at_end(self):
rows = b'{"a": 1,"b": 2, "c": 3}\n{"a": 4,"b": 5, "c": 6}'
table = self.read_bytes(rows)
assert table.to_pydict() == {
'a': [1, 4],
'b': [2, 5],
'c': [3, 6],
}
def test_simple_ints(self):
# Infer integer columns
rows = b'{"a": 1,"b": 2, "c": 3}\n{"a": 4,"b": 5, "c": 6}\n'
table = self.read_bytes(rows)
schema = pa.schema([('a', pa.int64()),
('b', pa.int64()),
('c', pa.int64())])
assert table.schema == schema
assert table.to_pydict() == {
'a': [1, 4],
'b': [2, 5],
'c': [3, 6],
}
def test_simple_varied(self):
# Infer various kinds of data
rows = (b'{"a": 1,"b": 2, "c": "3", "d": false}\n'
b'{"a": 4.0, "b": -5, "c": "foo", "d": true}\n')
table = self.read_bytes(rows)
schema = pa.schema([('a', pa.float64()),
('b', pa.int64()),
('c', pa.string()),
('d', pa.bool_())])
assert table.schema == schema
assert table.to_pydict() == {
'a': [1.0, 4.0],
'b': [2, -5],
'c': ["3", "foo"],
'd': [False, True],
}
def test_simple_nulls(self):
# Infer various kinds of data, with nulls
rows = (b'{"a": 1, "b": 2, "c": null, "d": null, "e": null}\n'
b'{"a": null, "b": -5, "c": "foo", "d": null, "e": true}\n'
b'{"a": 4.5, "b": null, "c": "nan", "d": null,"e": false}\n')
table = self.read_bytes(rows)
schema = pa.schema([('a', pa.float64()),
('b', pa.int64()),
('c', pa.string()),
('d', pa.null()),
('e', pa.bool_())])
assert table.schema == schema
assert table.to_pydict() == {
'a': [1.0, None, 4.5],
'b': [2, -5, None],
'c': [None, "foo", "nan"],
'd': [None, None, None],
'e': [None, True, False],
}
def test_empty_lists(self):
# ARROW-10955: Infer list(null)
rows = b'{"a": []}'
table = self.read_bytes(rows)
schema = pa.schema([('a', pa.list_(pa.null()))])
assert table.schema == schema
assert table.to_pydict() == {'a': [[]]}
def test_empty_rows(self):
rows = b'{}\n{}\n'
table = self.read_bytes(rows)
schema = pa.schema([])
assert table.schema == schema
assert table.num_columns == 0
assert table.num_rows == 2
def test_reconcile_accross_blocks(self):
# ARROW-12065: reconciling inferred types across blocks
first_row = b'{ }\n'
read_options = ReadOptions(block_size=len(first_row))
for next_rows, expected_pylist in [
(b'{"a": 0}', [None, 0]),
(b'{"a": []}', [None, []]),
(b'{"a": []}\n{"a": [[1]]}', [None, [], [[1]]]),
(b'{"a": {}}', [None, {}]),
(b'{"a": {}}\n{"a": {"b": {"c": 1}}}',
[None, {"b": None}, {"b": {"c": 1}}]),
]:
table = self.read_bytes(first_row + next_rows,
read_options=read_options)
expected = {"a": expected_pylist}
assert table.to_pydict() == expected
# Check that the issue was exercised
assert table.column("a").num_chunks > 1
def test_explicit_schema_with_unexpected_behaviour(self):
# infer by default
rows = (b'{"foo": "bar", "num": 0}\n'
b'{"foo": "baz", "num": 1}\n')
schema = pa.schema([
('foo', pa.binary())
])
opts = ParseOptions(explicit_schema=schema)
table = self.read_bytes(rows, parse_options=opts)
assert table.schema == pa.schema([
('foo', pa.binary()),
('num', pa.int64())
])
assert table.to_pydict() == {
'foo': [b'bar', b'baz'],
'num': [0, 1],
}
# ignore the unexpected fields
opts = ParseOptions(explicit_schema=schema,
unexpected_field_behavior="ignore")
table = self.read_bytes(rows, parse_options=opts)
assert table.schema == pa.schema([
('foo', pa.binary()),
])
assert table.to_pydict() == {
'foo': [b'bar', b'baz'],
}
# raise error
opts = ParseOptions(explicit_schema=schema,
unexpected_field_behavior="error")
with pytest.raises(pa.ArrowInvalid,
match="JSON parse error: unexpected field"):
self.read_bytes(rows, parse_options=opts)
def test_small_random_json(self):
data, expected = make_random_json(num_cols=2, num_rows=10)
table = self.read_bytes(data)
assert table.schema == expected.schema
assert table.equals(expected)
assert table.to_pydict() == expected.to_pydict()
def test_stress_block_sizes(self):
# Test a number of small block sizes to stress block stitching
data_base, expected = make_random_json(num_cols=2, num_rows=100)
read_options = ReadOptions()
parse_options = ParseOptions()
for data in [data_base, data_base.rstrip(b'\r\n')]:
for newlines_in_values in [False, True]:
parse_options.newlines_in_values = newlines_in_values
for block_size in [22, 23, 37]:
read_options.block_size = block_size
table = self.read_bytes(data, read_options=read_options,
parse_options=parse_options)
assert table.schema == expected.schema
if not table.equals(expected):
# Better error output
assert table.to_pydict() == expected.to_pydict()
class TestSerialJSONRead(BaseTestJSONRead, unittest.TestCase):
def read_json(self, *args, **kwargs):
read_options = kwargs.setdefault('read_options', ReadOptions())
read_options.use_threads = False
table = read_json(*args, **kwargs)
table.validate(full=True)
return table
class TestParallelJSONRead(BaseTestJSONRead, unittest.TestCase):
def read_json(self, *args, **kwargs):
read_options = kwargs.setdefault('read_options', ReadOptions())
read_options.use_threads = True
table = read_json(*args, **kwargs)
table.validate(full=True)
return table
@@ -0,0 +1,433 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import os
import pyarrow as pa
import pyarrow.jvm as pa_jvm
import pytest
import sys
import xml.etree.ElementTree as ET
jpype = pytest.importorskip("jpype")
@pytest.fixture(scope="session")
def root_allocator():
# This test requires Arrow Java to be built in the same source tree
try:
arrow_dir = os.environ["ARROW_SOURCE_DIR"]
except KeyError:
arrow_dir = os.path.join(os.path.dirname(__file__), '..', '..', '..')
pom_path = os.path.join(arrow_dir, 'java', 'pom.xml')
tree = ET.parse(pom_path)
version = tree.getroot().find(
'POM:version',
namespaces={
'POM': 'http://maven.apache.org/POM/4.0.0'
}).text
jar_path = os.path.join(
arrow_dir, 'java', 'tools', 'target',
'arrow-tools-{}-jar-with-dependencies.jar'.format(version))
jar_path = os.getenv("ARROW_TOOLS_JAR", jar_path)
kwargs = {}
# This will be the default behaviour in jpype 0.8+
kwargs['convertStrings'] = False
jpype.startJVM(jpype.getDefaultJVMPath(), "-Djava.class.path=" + jar_path,
**kwargs)
return jpype.JPackage("org").apache.arrow.memory.RootAllocator(sys.maxsize)
def test_jvm_buffer(root_allocator):
# Create a Java buffer
jvm_buffer = root_allocator.buffer(8)
for i in range(8):
jvm_buffer.setByte(i, 8 - i)
orig_refcnt = jvm_buffer.refCnt()
# Convert to Python
buf = pa_jvm.jvm_buffer(jvm_buffer)
# Check its content
assert buf.to_pybytes() == b'\x08\x07\x06\x05\x04\x03\x02\x01'
# Check Java buffer lifetime is tied to PyArrow buffer lifetime
assert jvm_buffer.refCnt() == orig_refcnt + 1
del buf
assert jvm_buffer.refCnt() == orig_refcnt
def test_jvm_buffer_released(root_allocator):
import jpype.imports # noqa
from java.lang import IllegalArgumentException
jvm_buffer = root_allocator.buffer(8)
jvm_buffer.release()
with pytest.raises(IllegalArgumentException):
pa_jvm.jvm_buffer(jvm_buffer)
def _jvm_field(jvm_spec):
om = jpype.JClass('com.fasterxml.jackson.databind.ObjectMapper')()
pojo_Field = jpype.JClass('org.apache.arrow.vector.types.pojo.Field')
return om.readValue(jvm_spec, pojo_Field)
def _jvm_schema(jvm_spec, metadata=None):
field = _jvm_field(jvm_spec)
schema_cls = jpype.JClass('org.apache.arrow.vector.types.pojo.Schema')
fields = jpype.JClass('java.util.ArrayList')()
fields.add(field)
if metadata:
dct = jpype.JClass('java.util.HashMap')()
for k, v in metadata.items():
dct.put(k, v)
return schema_cls(fields, dct)
else:
return schema_cls(fields)
# In the following, we use the JSON serialization of the Field objects in Java.
# This ensures that we neither rely on the exact mechanics on how to construct
# them using Java code as well as enables us to define them as parameters
# without to invoke the JVM.
#
# The specifications were created using:
#
# om = jpype.JClass('com.fasterxml.jackson.databind.ObjectMapper')()
# field = … # Code to instantiate the field
# jvm_spec = om.writeValueAsString(field)
@pytest.mark.parametrize('pa_type,jvm_spec', [
(pa.null(), '{"name":"null"}'),
(pa.bool_(), '{"name":"bool"}'),
(pa.int8(), '{"name":"int","bitWidth":8,"isSigned":true}'),
(pa.int16(), '{"name":"int","bitWidth":16,"isSigned":true}'),
(pa.int32(), '{"name":"int","bitWidth":32,"isSigned":true}'),
(pa.int64(), '{"name":"int","bitWidth":64,"isSigned":true}'),
(pa.uint8(), '{"name":"int","bitWidth":8,"isSigned":false}'),
(pa.uint16(), '{"name":"int","bitWidth":16,"isSigned":false}'),
(pa.uint32(), '{"name":"int","bitWidth":32,"isSigned":false}'),
(pa.uint64(), '{"name":"int","bitWidth":64,"isSigned":false}'),
(pa.float16(), '{"name":"floatingpoint","precision":"HALF"}'),
(pa.float32(), '{"name":"floatingpoint","precision":"SINGLE"}'),
(pa.float64(), '{"name":"floatingpoint","precision":"DOUBLE"}'),
(pa.time32('s'), '{"name":"time","unit":"SECOND","bitWidth":32}'),
(pa.time32('ms'), '{"name":"time","unit":"MILLISECOND","bitWidth":32}'),
(pa.time64('us'), '{"name":"time","unit":"MICROSECOND","bitWidth":64}'),
(pa.time64('ns'), '{"name":"time","unit":"NANOSECOND","bitWidth":64}'),
(pa.timestamp('s'), '{"name":"timestamp","unit":"SECOND",'
'"timezone":null}'),
(pa.timestamp('ms'), '{"name":"timestamp","unit":"MILLISECOND",'
'"timezone":null}'),
(pa.timestamp('us'), '{"name":"timestamp","unit":"MICROSECOND",'
'"timezone":null}'),
(pa.timestamp('ns'), '{"name":"timestamp","unit":"NANOSECOND",'
'"timezone":null}'),
(pa.timestamp('ns', tz='UTC'), '{"name":"timestamp","unit":"NANOSECOND"'
',"timezone":"UTC"}'),
(pa.timestamp('ns', tz='Europe/Paris'), '{"name":"timestamp",'
'"unit":"NANOSECOND","timezone":"Europe/Paris"}'),
(pa.date32(), '{"name":"date","unit":"DAY"}'),
(pa.date64(), '{"name":"date","unit":"MILLISECOND"}'),
(pa.decimal128(19, 4), '{"name":"decimal","precision":19,"scale":4}'),
(pa.string(), '{"name":"utf8"}'),
(pa.binary(), '{"name":"binary"}'),
(pa.binary(10), '{"name":"fixedsizebinary","byteWidth":10}'),
# TODO(ARROW-2609): complex types that have children
# pa.list_(pa.int32()),
# pa.struct([pa.field('a', pa.int32()),
# pa.field('b', pa.int8()),
# pa.field('c', pa.string())]),
# pa.union([pa.field('a', pa.binary(10)),
# pa.field('b', pa.string())], mode=pa.lib.UnionMode_DENSE),
# pa.union([pa.field('a', pa.binary(10)),
# pa.field('b', pa.string())], mode=pa.lib.UnionMode_SPARSE),
# TODO: DictionaryType requires a vector in the type
# pa.dictionary(pa.int32(), pa.array(['a', 'b', 'c'])),
])
@pytest.mark.parametrize('nullable', [True, False])
def test_jvm_types(root_allocator, pa_type, jvm_spec, nullable):
if pa_type == pa.null() and not nullable:
return
spec = {
'name': 'field_name',
'nullable': nullable,
'type': json.loads(jvm_spec),
# TODO: This needs to be set for complex types
'children': []
}
jvm_field = _jvm_field(json.dumps(spec))
result = pa_jvm.field(jvm_field)
expected_field = pa.field('field_name', pa_type, nullable=nullable)
assert result == expected_field
jvm_schema = _jvm_schema(json.dumps(spec))
result = pa_jvm.schema(jvm_schema)
assert result == pa.schema([expected_field])
# Schema with custom metadata
jvm_schema = _jvm_schema(json.dumps(spec), {'meta': 'data'})
result = pa_jvm.schema(jvm_schema)
assert result == pa.schema([expected_field], {'meta': 'data'})
# Schema with custom field metadata
spec['metadata'] = [{'key': 'field meta', 'value': 'field data'}]
jvm_schema = _jvm_schema(json.dumps(spec))
result = pa_jvm.schema(jvm_schema)
expected_field = expected_field.with_metadata(
{'field meta': 'field data'})
assert result == pa.schema([expected_field])
# These test parameters mostly use an integer range as an input as this is
# often the only type that is understood by both Python and Java
# implementations of Arrow.
@pytest.mark.parametrize('pa_type,py_data,jvm_type', [
(pa.bool_(), [True, False, True, True], 'BitVector'),
(pa.uint8(), list(range(128)), 'UInt1Vector'),
(pa.uint16(), list(range(128)), 'UInt2Vector'),
(pa.int32(), list(range(128)), 'IntVector'),
(pa.int64(), list(range(128)), 'BigIntVector'),
(pa.float32(), list(range(128)), 'Float4Vector'),
(pa.float64(), list(range(128)), 'Float8Vector'),
(pa.timestamp('s'), list(range(128)), 'TimeStampSecVector'),
(pa.timestamp('ms'), list(range(128)), 'TimeStampMilliVector'),
(pa.timestamp('us'), list(range(128)), 'TimeStampMicroVector'),
(pa.timestamp('ns'), list(range(128)), 'TimeStampNanoVector'),
# TODO(ARROW-2605): These types miss a conversion from pure Python objects
# * pa.time32('s')
# * pa.time32('ms')
# * pa.time64('us')
# * pa.time64('ns')
(pa.date32(), list(range(128)), 'DateDayVector'),
(pa.date64(), list(range(128)), 'DateMilliVector'),
# TODO(ARROW-2606): pa.decimal128(19, 4)
])
def test_jvm_array(root_allocator, pa_type, py_data, jvm_type):
# Create vector
cls = "org.apache.arrow.vector.{}".format(jvm_type)
jvm_vector = jpype.JClass(cls)("vector", root_allocator)
jvm_vector.allocateNew(len(py_data))
for i, val in enumerate(py_data):
# char and int are ambiguous overloads for these two setSafe calls
if jvm_type in {'UInt1Vector', 'UInt2Vector'}:
val = jpype.JInt(val)
jvm_vector.setSafe(i, val)
jvm_vector.setValueCount(len(py_data))
py_array = pa.array(py_data, type=pa_type)
jvm_array = pa_jvm.array(jvm_vector)
assert py_array.equals(jvm_array)
def test_jvm_array_empty(root_allocator):
cls = "org.apache.arrow.vector.{}".format('IntVector')
jvm_vector = jpype.JClass(cls)("vector", root_allocator)
jvm_vector.allocateNew()
jvm_array = pa_jvm.array(jvm_vector)
assert len(jvm_array) == 0
assert jvm_array.type == pa.int32()
# These test parameters mostly use an integer range as an input as this is
# often the only type that is understood by both Python and Java
# implementations of Arrow.
@pytest.mark.parametrize('pa_type,py_data,jvm_type,jvm_spec', [
# TODO: null
(pa.bool_(), [True, False, True, True], 'BitVector', '{"name":"bool"}'),
(
pa.uint8(),
list(range(128)),
'UInt1Vector',
'{"name":"int","bitWidth":8,"isSigned":false}'
),
(
pa.uint16(),
list(range(128)),
'UInt2Vector',
'{"name":"int","bitWidth":16,"isSigned":false}'
),
(
pa.uint32(),
list(range(128)),
'UInt4Vector',
'{"name":"int","bitWidth":32,"isSigned":false}'
),
(
pa.uint64(),
list(range(128)),
'UInt8Vector',
'{"name":"int","bitWidth":64,"isSigned":false}'
),
(
pa.int8(),
list(range(128)),
'TinyIntVector',
'{"name":"int","bitWidth":8,"isSigned":true}'
),
(
pa.int16(),
list(range(128)),
'SmallIntVector',
'{"name":"int","bitWidth":16,"isSigned":true}'
),
(
pa.int32(),
list(range(128)),
'IntVector',
'{"name":"int","bitWidth":32,"isSigned":true}'
),
(
pa.int64(),
list(range(128)),
'BigIntVector',
'{"name":"int","bitWidth":64,"isSigned":true}'
),
# TODO: float16
(
pa.float32(),
list(range(128)),
'Float4Vector',
'{"name":"floatingpoint","precision":"SINGLE"}'
),
(
pa.float64(),
list(range(128)),
'Float8Vector',
'{"name":"floatingpoint","precision":"DOUBLE"}'
),
(
pa.timestamp('s'),
list(range(128)),
'TimeStampSecVector',
'{"name":"timestamp","unit":"SECOND","timezone":null}'
),
(
pa.timestamp('ms'),
list(range(128)),
'TimeStampMilliVector',
'{"name":"timestamp","unit":"MILLISECOND","timezone":null}'
),
(
pa.timestamp('us'),
list(range(128)),
'TimeStampMicroVector',
'{"name":"timestamp","unit":"MICROSECOND","timezone":null}'
),
(
pa.timestamp('ns'),
list(range(128)),
'TimeStampNanoVector',
'{"name":"timestamp","unit":"NANOSECOND","timezone":null}'
),
# TODO(ARROW-2605): These types miss a conversion from pure Python objects
# * pa.time32('s')
# * pa.time32('ms')
# * pa.time64('us')
# * pa.time64('ns')
(
pa.date32(),
list(range(128)),
'DateDayVector',
'{"name":"date","unit":"DAY"}'
),
(
pa.date64(),
list(range(128)),
'DateMilliVector',
'{"name":"date","unit":"MILLISECOND"}'
),
# TODO(ARROW-2606): pa.decimal128(19, 4)
])
def test_jvm_record_batch(root_allocator, pa_type, py_data, jvm_type,
jvm_spec):
# Create vector
cls = "org.apache.arrow.vector.{}".format(jvm_type)
jvm_vector = jpype.JClass(cls)("vector", root_allocator)
jvm_vector.allocateNew(len(py_data))
for i, val in enumerate(py_data):
if jvm_type in {'UInt1Vector', 'UInt2Vector'}:
val = jpype.JInt(val)
jvm_vector.setSafe(i, val)
jvm_vector.setValueCount(len(py_data))
# Create field
spec = {
'name': 'field_name',
'nullable': False,
'type': json.loads(jvm_spec),
# TODO: This needs to be set for complex types
'children': []
}
jvm_field = _jvm_field(json.dumps(spec))
# Create VectorSchemaRoot
jvm_fields = jpype.JClass('java.util.ArrayList')()
jvm_fields.add(jvm_field)
jvm_vectors = jpype.JClass('java.util.ArrayList')()
jvm_vectors.add(jvm_vector)
jvm_vsr = jpype.JClass('org.apache.arrow.vector.VectorSchemaRoot')
jvm_vsr = jvm_vsr(jvm_fields, jvm_vectors, len(py_data))
py_record_batch = pa.RecordBatch.from_arrays(
[pa.array(py_data, type=pa_type)],
['col']
)
jvm_record_batch = pa_jvm.record_batch(jvm_vsr)
assert py_record_batch.equals(jvm_record_batch)
def _string_to_varchar_holder(ra, string):
nvch_cls = "org.apache.arrow.vector.holders.NullableVarCharHolder"
holder = jpype.JClass(nvch_cls)()
if string is None:
holder.isSet = 0
else:
holder.isSet = 1
value = jpype.JClass("java.lang.String")("string")
std_charsets = jpype.JClass("java.nio.charset.StandardCharsets")
bytes_ = value.getBytes(std_charsets.UTF_8)
holder.buffer = ra.buffer(len(bytes_))
holder.buffer.setBytes(0, bytes_, 0, len(bytes_))
holder.start = 0
holder.end = len(bytes_)
return holder
# TODO(ARROW-2607)
@pytest.mark.xfail(reason="from_buffers is only supported for "
"primitive arrays yet")
def test_jvm_string_array(root_allocator):
data = ["string", None, "töst"]
cls = "org.apache.arrow.vector.VarCharVector"
jvm_vector = jpype.JClass(cls)("vector", root_allocator)
jvm_vector.allocateNew()
for i, string in enumerate(data):
holder = _string_to_varchar_holder(root_allocator, "string")
jvm_vector.setSafe(i, holder)
jvm_vector.setValueCount(i + 1)
py_array = pa.array(data, type=pa.string())
jvm_array = pa_jvm.array(jvm_vector)
assert py_array.equals(jvm_array)
@@ -0,0 +1,245 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
import os
import signal
import subprocess
import sys
import weakref
import pyarrow as pa
import pytest
possible_backends = ["system", "jemalloc", "mimalloc"]
should_have_jemalloc = sys.platform == "linux"
should_have_mimalloc = sys.platform == "win32"
def supported_factories():
yield pa.default_memory_pool
for backend in pa.supported_memory_backends():
yield getattr(pa, f"{backend}_memory_pool")
@contextlib.contextmanager
def allocate_bytes(pool, nbytes):
"""
Temporarily allocate *nbytes* from the given *pool*.
"""
arr = pa.array([b"x" * nbytes], type=pa.binary(), memory_pool=pool)
# Fetch the values buffer from the varbinary array and release the rest,
# to get the desired allocation amount
buf = arr.buffers()[2]
arr = None
assert len(buf) == nbytes
try:
yield
finally:
buf = None
def check_allocated_bytes(pool):
"""
Check allocation stats on *pool*.
"""
allocated_before = pool.bytes_allocated()
max_mem_before = pool.max_memory()
with allocate_bytes(pool, 512):
assert pool.bytes_allocated() == allocated_before + 512
new_max_memory = pool.max_memory()
assert pool.max_memory() >= max_mem_before
assert pool.bytes_allocated() == allocated_before
assert pool.max_memory() == new_max_memory
def test_default_allocated_bytes():
pool = pa.default_memory_pool()
with allocate_bytes(pool, 1024):
check_allocated_bytes(pool)
assert pool.bytes_allocated() == pa.total_allocated_bytes()
def test_proxy_memory_pool():
pool = pa.proxy_memory_pool(pa.default_memory_pool())
check_allocated_bytes(pool)
wr = weakref.ref(pool)
assert wr() is not None
del pool
assert wr() is None
def test_logging_memory_pool(capfd):
pool = pa.logging_memory_pool(pa.default_memory_pool())
check_allocated_bytes(pool)
out, err = capfd.readouterr()
assert err == ""
assert out.count("Allocate:") > 0
assert out.count("Allocate:") == out.count("Free:")
def test_set_memory_pool():
old_pool = pa.default_memory_pool()
pool = pa.proxy_memory_pool(old_pool)
pa.set_memory_pool(pool)
try:
allocated_before = pool.bytes_allocated()
with allocate_bytes(None, 512):
assert pool.bytes_allocated() == allocated_before + 512
assert pool.bytes_allocated() == allocated_before
finally:
pa.set_memory_pool(old_pool)
def test_default_backend_name():
pool = pa.default_memory_pool()
assert pool.backend_name in possible_backends
def test_release_unused():
pool = pa.default_memory_pool()
pool.release_unused()
def check_env_var(name, expected, *, expect_warning=False):
code = f"""if 1:
import pyarrow as pa
pool = pa.default_memory_pool()
assert pool.backend_name in {expected!r}, pool.backend_name
"""
env = dict(os.environ)
env['ARROW_DEFAULT_MEMORY_POOL'] = name
res = subprocess.run([sys.executable, "-c", code], env=env,
universal_newlines=True, stderr=subprocess.PIPE)
if res.returncode != 0:
print(res.stderr, file=sys.stderr)
res.check_returncode() # fail
errlines = res.stderr.splitlines()
if expect_warning:
assert len(errlines) == 1
assert f"Unsupported backend '{name}'" in errlines[0]
else:
assert len(errlines) == 0
def test_env_var():
check_env_var("system", ["system"])
if should_have_jemalloc:
check_env_var("jemalloc", ["jemalloc"])
if should_have_mimalloc:
check_env_var("mimalloc", ["mimalloc"])
check_env_var("nonexistent", possible_backends, expect_warning=True)
def test_specific_memory_pools():
specific_pools = set()
def check(factory, name, *, can_fail=False):
if can_fail:
try:
pool = factory()
except NotImplementedError:
return
else:
pool = factory()
assert pool.backend_name == name
specific_pools.add(pool)
check(pa.system_memory_pool, "system")
check(pa.jemalloc_memory_pool, "jemalloc",
can_fail=not should_have_jemalloc)
check(pa.mimalloc_memory_pool, "mimalloc",
can_fail=not should_have_mimalloc)
def test_supported_memory_backends():
backends = pa.supported_memory_backends()
assert "system" in backends
if should_have_jemalloc:
assert "jemalloc" in backends
if should_have_mimalloc:
assert "mimalloc" in backends
def run_debug_memory_pool(pool_factory, env_value):
"""
Run a piece of code making an invalid memory write with the
ARROW_DEBUG_MEMORY_POOL environment variable set to a specific value.
"""
code = f"""if 1:
import ctypes
import pyarrow as pa
pool = pa.{pool_factory}()
buf = pa.allocate_buffer(64, memory_pool=pool)
# Write memory out of bounds
ptr = ctypes.cast(buf.address, ctypes.POINTER(ctypes.c_ubyte))
ptr[64] = 0
del buf
"""
env = dict(os.environ)
env['ARROW_DEBUG_MEMORY_POOL'] = env_value
res = subprocess.run([sys.executable, "-c", code], env=env,
universal_newlines=True, stderr=subprocess.PIPE)
print(res.stderr, file=sys.stderr)
return res
@pytest.mark.parametrize('pool_factory', supported_factories())
def test_debug_memory_pool_abort(pool_factory):
res = run_debug_memory_pool(pool_factory.__name__, "abort")
if os.name == "posix":
assert res.returncode == -signal.SIGABRT
else:
assert res.returncode != 0
assert "Wrong size on deallocation" in res.stderr
@pytest.mark.parametrize('pool_factory', supported_factories())
def test_debug_memory_pool_trap(pool_factory):
res = run_debug_memory_pool(pool_factory.__name__, "trap")
if os.name == "posix":
assert res.returncode == -signal.SIGTRAP
else:
assert res.returncode != 0
assert "Wrong size on deallocation" in res.stderr
@pytest.mark.parametrize('pool_factory', supported_factories())
def test_debug_memory_pool_warn(pool_factory):
res = run_debug_memory_pool(pool_factory.__name__, "warn")
res.check_returncode()
assert "Wrong size on deallocation" in res.stderr
@pytest.mark.parametrize('pool_factory', supported_factories())
def test_debug_memory_pool_disabled(pool_factory):
res = run_debug_memory_pool(pool_factory.__name__, "")
# The subprocess either returned successfully or was killed by a signal
# (due to writing out of bounds), depending on the underlying allocator.
if os.name == "posix":
assert res.returncode <= 0
else:
res.check_returncode()
assert res.stderr == ""
@@ -0,0 +1,215 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import subprocess
import sys
import pytest
import pyarrow as pa
def test_get_include():
include_dir = pa.get_include()
assert os.path.exists(os.path.join(include_dir, 'arrow', 'api.h'))
@pytest.mark.skipif('sys.platform != "win32"')
def test_get_library_dirs_win32():
assert any(os.path.exists(os.path.join(directory, 'arrow.lib'))
for directory in pa.get_library_dirs())
def test_cpu_count():
n = pa.cpu_count()
assert n > 0
try:
pa.set_cpu_count(n + 5)
assert pa.cpu_count() == n + 5
finally:
pa.set_cpu_count(n)
def test_io_thread_count():
n = pa.io_thread_count()
assert n > 0
try:
pa.set_io_thread_count(n + 5)
assert pa.io_thread_count() == n + 5
finally:
pa.set_io_thread_count(n)
def test_env_var_io_thread_count():
# Test that the number of IO threads can be overriden with the
# ARROW_IO_THREADS environment variable.
code = """if 1:
import pyarrow as pa
print(pa.io_thread_count())
"""
def run_with_env_var(env_var):
env = os.environ.copy()
env['ARROW_IO_THREADS'] = env_var
res = subprocess.run([sys.executable, "-c", code], env=env,
capture_output=True)
res.check_returncode()
return res.stdout.decode(), res.stderr.decode()
out, err = run_with_env_var('17')
assert out.strip() == '17'
assert err == ''
for v in ('-1', 'z'):
out, err = run_with_env_var(v)
assert out.strip() == '8' # default value
assert ("ARROW_IO_THREADS does not contain a valid number of threads"
in err.strip())
def test_build_info():
assert isinstance(pa.cpp_build_info, pa.BuildInfo)
assert isinstance(pa.cpp_version_info, pa.VersionInfo)
assert isinstance(pa.cpp_version, str)
assert isinstance(pa.__version__, str)
assert pa.cpp_build_info.version_info == pa.cpp_version_info
assert pa.cpp_build_info.build_type in (
'debug', 'release', 'minsizerel', 'relwithdebinfo')
# assert pa.version == pa.__version__ # XXX currently false
def test_runtime_info():
info = pa.runtime_info()
assert isinstance(info, pa.RuntimeInfo)
possible_simd_levels = ('none', 'sse4_2', 'avx', 'avx2', 'avx512')
assert info.simd_level in possible_simd_levels
assert info.detected_simd_level in possible_simd_levels
if info.simd_level != 'none':
env = os.environ.copy()
env['ARROW_USER_SIMD_LEVEL'] = 'none'
code = f"""if 1:
import pyarrow as pa
info = pa.runtime_info()
assert info.simd_level == 'none', info.simd_level
assert info.detected_simd_level == {info.detected_simd_level!r},\
info.detected_simd_level
"""
subprocess.check_call([sys.executable, "-c", code], env=env)
@pytest.mark.parametrize('klass', [
pa.Field,
pa.Schema,
pa.ChunkedArray,
pa.RecordBatch,
pa.Table,
pa.Buffer,
pa.Array,
pa.Tensor,
pa.DataType,
pa.ListType,
pa.LargeListType,
pa.FixedSizeListType,
pa.UnionType,
pa.SparseUnionType,
pa.DenseUnionType,
pa.StructType,
pa.Time32Type,
pa.Time64Type,
pa.TimestampType,
pa.Decimal128Type,
pa.Decimal256Type,
pa.DictionaryType,
pa.FixedSizeBinaryType,
pa.NullArray,
pa.NumericArray,
pa.IntegerArray,
pa.FloatingPointArray,
pa.BooleanArray,
pa.Int8Array,
pa.Int16Array,
pa.Int32Array,
pa.Int64Array,
pa.UInt8Array,
pa.UInt16Array,
pa.UInt32Array,
pa.UInt64Array,
pa.ListArray,
pa.LargeListArray,
pa.MapArray,
pa.FixedSizeListArray,
pa.UnionArray,
pa.BinaryArray,
pa.StringArray,
pa.FixedSizeBinaryArray,
pa.DictionaryArray,
pa.Date32Array,
pa.Date64Array,
pa.TimestampArray,
pa.Time32Array,
pa.Time64Array,
pa.DurationArray,
pa.Decimal128Array,
pa.Decimal256Array,
pa.StructArray,
pa.Scalar,
pa.BooleanScalar,
pa.Int8Scalar,
pa.Int16Scalar,
pa.Int32Scalar,
pa.Int64Scalar,
pa.UInt8Scalar,
pa.UInt16Scalar,
pa.UInt32Scalar,
pa.UInt64Scalar,
pa.HalfFloatScalar,
pa.FloatScalar,
pa.DoubleScalar,
pa.Decimal128Scalar,
pa.Decimal256Scalar,
pa.Date32Scalar,
pa.Date64Scalar,
pa.Time32Scalar,
pa.Time64Scalar,
pa.TimestampScalar,
pa.DurationScalar,
pa.StringScalar,
pa.BinaryScalar,
pa.FixedSizeBinaryScalar,
pa.ListScalar,
pa.LargeListScalar,
pa.MapScalar,
pa.FixedSizeListScalar,
pa.UnionScalar,
pa.StructScalar,
pa.DictionaryScalar,
pa.ipc.Message,
pa.ipc.MessageReader,
pa.MemoryPool,
pa.LoggingMemoryPool,
pa.ProxyMemoryPool,
])
def test_extension_type_constructor_errors(klass):
# ARROW-2638: prevent calling extension class constructors directly
msg = "Do not call {cls}'s constructor directly, use .* instead."
with pytest.raises(TypeError, match=msg.format(cls=klass.__name__)):
klass()
@@ -0,0 +1,636 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import decimal
import datetime
import pyarrow as pa
from pyarrow import fs
from pyarrow.tests import util
# Marks all of the tests in this module
# Ignore these with pytest ... -m 'not orc'
pytestmark = pytest.mark.orc
try:
from pandas.testing import assert_frame_equal
import pandas as pd
except ImportError:
pass
@pytest.fixture(scope="module")
def datadir(base_datadir):
return base_datadir / "orc"
def fix_example_values(actual_cols, expected_cols):
"""
Fix type of expected values (as read from JSON) according to
actual ORC datatype.
"""
for name in expected_cols:
expected = expected_cols[name]
actual = actual_cols[name]
if (name == "map" and
[d.keys() == {'key', 'value'} for m in expected for d in m]):
# convert [{'key': k, 'value': v}, ...] to [(k, v), ...]
for i, m in enumerate(expected):
expected_cols[name][i] = [(d['key'], d['value']) for d in m]
continue
typ = actual[0].__class__
if issubclass(typ, datetime.datetime):
# timestamp fields are represented as strings in JSON files
expected = pd.to_datetime(expected)
elif issubclass(typ, datetime.date):
# date fields are represented as strings in JSON files
expected = expected.dt.date
elif typ is decimal.Decimal:
converted_decimals = [None] * len(expected)
# decimal fields are represented as reals in JSON files
for i, (d, v) in enumerate(zip(actual, expected)):
if not pd.isnull(v):
exp = d.as_tuple().exponent
factor = 10 ** -exp
converted_decimals[i] = (
decimal.Decimal(round(v * factor)).scaleb(exp))
expected = pd.Series(converted_decimals)
expected_cols[name] = expected
def check_example_values(orc_df, expected_df, start=None, stop=None):
if start is not None or stop is not None:
expected_df = expected_df[start:stop].reset_index(drop=True)
assert_frame_equal(orc_df, expected_df, check_dtype=False)
def check_example_file(orc_path, expected_df, need_fix=False):
"""
Check a ORC file against the expected columns dictionary.
"""
from pyarrow import orc
orc_file = orc.ORCFile(orc_path)
# Exercise ORCFile.read()
table = orc_file.read()
assert isinstance(table, pa.Table)
table.validate()
# This workaround needed because of ARROW-3080
orc_df = pd.DataFrame(table.to_pydict())
assert set(expected_df.columns) == set(orc_df.columns)
# reorder columns if necessary
if not orc_df.columns.equals(expected_df.columns):
expected_df = expected_df.reindex(columns=orc_df.columns)
if need_fix:
fix_example_values(orc_df, expected_df)
check_example_values(orc_df, expected_df)
# Exercise ORCFile.read_stripe()
json_pos = 0
for i in range(orc_file.nstripes):
batch = orc_file.read_stripe(i)
check_example_values(pd.DataFrame(batch.to_pydict()),
expected_df,
start=json_pos,
stop=json_pos + len(batch))
json_pos += len(batch)
assert json_pos == orc_file.nrows
@pytest.mark.pandas
@pytest.mark.parametrize('filename', [
'TestOrcFile.test1.orc',
'TestOrcFile.testDate1900.orc',
'decimal.orc'
])
def test_example_using_json(filename, datadir):
"""
Check a ORC file example against the equivalent JSON file, as given
in the Apache ORC repository (the JSON file has one JSON object per
line, corresponding to one row in the ORC file).
"""
# Read JSON file
path = datadir / filename
table = pd.read_json(str(path.with_suffix('.jsn.gz')), lines=True)
check_example_file(path, table, need_fix=True)
def test_orcfile_empty(datadir):
from pyarrow import orc
table = orc.ORCFile(datadir / "TestOrcFile.emptyFile.orc").read()
assert table.num_rows == 0
expected_schema = pa.schema([
("boolean1", pa.bool_()),
("byte1", pa.int8()),
("short1", pa.int16()),
("int1", pa.int32()),
("long1", pa.int64()),
("float1", pa.float32()),
("double1", pa.float64()),
("bytes1", pa.binary()),
("string1", pa.string()),
("middle", pa.struct(
[("list", pa.list_(
pa.struct([("int1", pa.int32()),
("string1", pa.string())])))
])),
("list", pa.list_(
pa.struct([("int1", pa.int32()),
("string1", pa.string())])
)),
("map", pa.map_(pa.string(),
pa.struct([("int1", pa.int32()),
("string1", pa.string())])
)),
])
assert table.schema == expected_schema
def test_filesystem_uri(tmpdir):
from pyarrow import orc
table = pa.table({"a": [1, 2, 3]})
directory = tmpdir / "data_dir"
directory.mkdir()
path = directory / "data.orc"
orc.write_table(table, str(path))
# filesystem object
result = orc.read_table(path, filesystem=fs.LocalFileSystem())
assert result.equals(table)
# filesystem URI
result = orc.read_table(
"data_dir/data.orc", filesystem=util._filesystem_uri(tmpdir))
assert result.equals(table)
# use the path only
result = orc.read_table(
util._filesystem_uri(path))
assert result.equals(table)
def test_orcfile_readwrite(tmpdir):
from pyarrow import orc
a = pa.array([1, None, 3, None])
b = pa.array([None, "Arrow", None, "ORC"])
table = pa.table({"int64": a, "utf8": b})
file = tmpdir.join("test.orc")
orc.write_table(table, file)
output_table = orc.read_table(file)
assert table.equals(output_table)
output_table = orc.read_table(file, [])
assert 4 == output_table.num_rows
assert 0 == output_table.num_columns
output_table = orc.read_table(file, columns=["int64"])
assert 4 == output_table.num_rows
assert 1 == output_table.num_columns
def test_bytesio_readwrite():
from pyarrow import orc
from io import BytesIO
buf = BytesIO()
a = pa.array([1, None, 3, None])
b = pa.array([None, "Arrow", None, "ORC"])
table = pa.table({"int64": a, "utf8": b})
orc.write_table(table, buf)
buf.seek(0)
orc_file = orc.ORCFile(buf)
output_table = orc_file.read()
assert table.equals(output_table)
def test_buffer_readwrite():
from pyarrow import orc
buffer_output_stream = pa.BufferOutputStream()
a = pa.array([1, None, 3, None])
b = pa.array([None, "Arrow", None, "ORC"])
table = pa.table({"int64": a, "utf8": b})
orc.write_table(table, buffer_output_stream)
buffer_reader = pa.BufferReader(buffer_output_stream.getvalue())
orc_file = orc.ORCFile(buffer_reader)
output_table = orc_file.read()
assert table.equals(output_table)
# Check for default WriteOptions
assert orc_file.compression == 'UNCOMPRESSED'
assert orc_file.file_version == '0.12'
assert orc_file.row_index_stride == 10000
assert orc_file.compression_size == 65536
# deprecated keyword order
buffer_output_stream = pa.BufferOutputStream()
with pytest.warns(FutureWarning):
orc.write_table(buffer_output_stream, table)
buffer_reader = pa.BufferReader(buffer_output_stream.getvalue())
orc_file = orc.ORCFile(buffer_reader)
output_table = orc_file.read()
assert table.equals(output_table)
# Check for default WriteOptions
assert orc_file.compression == 'UNCOMPRESSED'
assert orc_file.file_version == '0.12'
assert orc_file.row_index_stride == 10000
assert orc_file.compression_size == 65536
@pytest.mark.snappy
def test_buffer_readwrite_with_writeoptions():
from pyarrow import orc
buffer_output_stream = pa.BufferOutputStream()
a = pa.array([1, None, 3, None])
b = pa.array([None, "Arrow", None, "ORC"])
table = pa.table({"int64": a, "utf8": b})
orc.write_table(
table,
buffer_output_stream,
compression='snappy',
file_version='0.11',
row_index_stride=5000,
compression_block_size=32768,
)
buffer_reader = pa.BufferReader(buffer_output_stream.getvalue())
orc_file = orc.ORCFile(buffer_reader)
output_table = orc_file.read()
assert table.equals(output_table)
# Check for modified WriteOptions
assert orc_file.compression == 'SNAPPY'
assert orc_file.file_version == '0.11'
assert orc_file.row_index_stride == 5000
assert orc_file.compression_size == 32768
# deprecated keyword order
buffer_output_stream = pa.BufferOutputStream()
with pytest.warns(FutureWarning):
orc.write_table(
buffer_output_stream,
table,
compression='uncompressed',
file_version='0.11',
row_index_stride=20000,
compression_block_size=16384,
)
buffer_reader = pa.BufferReader(buffer_output_stream.getvalue())
orc_file = orc.ORCFile(buffer_reader)
output_table = orc_file.read()
assert table.equals(output_table)
# Check for default WriteOptions
assert orc_file.compression == 'UNCOMPRESSED'
assert orc_file.file_version == '0.11'
assert orc_file.row_index_stride == 20000
assert orc_file.compression_size == 16384
def test_buffer_readwrite_with_bad_writeoptions():
from pyarrow import orc
buffer_output_stream = pa.BufferOutputStream()
a = pa.array([1, None, 3, None])
table = pa.table({"int64": a})
# batch_size must be a positive integer
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
batch_size=0,
)
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
batch_size=-100,
)
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
batch_size=1024.23,
)
# file_version must be 0.11 or 0.12
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
file_version=0.13,
)
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
file_version='1.1',
)
# stripe_size must be a positive integer
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
stripe_size=0,
)
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
stripe_size=-400,
)
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
stripe_size=4096.73,
)
# compression must be among the given options
with pytest.raises(TypeError):
orc.write_table(
table,
buffer_output_stream,
compression=0,
)
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
compression='none',
)
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
compression='zlid',
)
# compression_block_size must be a positive integer
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
compression_block_size=0,
)
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
compression_block_size=-200,
)
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
compression_block_size=1096.73,
)
# compression_strategy must be among the given options
with pytest.raises(TypeError):
orc.write_table(
table,
buffer_output_stream,
compression_strategy=0,
)
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
compression_strategy='no',
)
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
compression_strategy='large',
)
# row_index_stride must be a positive integer
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
row_index_stride=0,
)
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
row_index_stride=-800,
)
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
row_index_stride=3096.29,
)
# padding_tolerance must be possible to cast to float
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
padding_tolerance='cat',
)
# dictionary_key_size_threshold must be possible to cast to
# float between 0.0 and 1.0
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
dictionary_key_size_threshold='arrow',
)
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
dictionary_key_size_threshold=1.2,
)
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
dictionary_key_size_threshold=-3.2,
)
# bloom_filter_columns must be convertible to a list containing
# nonnegative integers
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
bloom_filter_columns="string",
)
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
bloom_filter_columns=[0, 1.4],
)
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
bloom_filter_columns={0, 2, -1},
)
# bloom_filter_fpp must be convertible to a float between 0.0 and 1.0
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
bloom_filter_fpp='arrow',
)
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
bloom_filter_fpp=1.1,
)
with pytest.raises(ValueError):
orc.write_table(
table,
buffer_output_stream,
bloom_filter_fpp=-0.1,
)
def test_column_selection(tempdir):
from pyarrow import orc
# create a table with nested types
inner = pa.field('inner', pa.int64())
middle = pa.field('middle', pa.struct([inner]))
fields = [
pa.field('basic', pa.int32()),
pa.field(
'list', pa.list_(pa.field('item', pa.int32()))
),
pa.field(
'struct', pa.struct([middle, pa.field('inner2', pa.int64())])
),
pa.field(
'list-struct', pa.list_(pa.field(
'item', pa.struct([
pa.field('inner1', pa.int64()),
pa.field('inner2', pa.int64())
])
))
),
pa.field('basic2', pa.int64()),
]
arrs = [
[0], [[1, 2]], [{"middle": {"inner": 3}, "inner2": 4}],
[[{"inner1": 5, "inner2": 6}, {"inner1": 7, "inner2": 8}]], [9]]
table = pa.table(arrs, schema=pa.schema(fields))
path = str(tempdir / 'test.orc')
orc.write_table(table, path)
orc_file = orc.ORCFile(path)
# default selecting all columns
result1 = orc_file.read()
assert result1.equals(table)
# selecting with columns names
result2 = orc_file.read(columns=["basic", "basic2"])
assert result2.equals(table.select(["basic", "basic2"]))
result3 = orc_file.read(columns=["list", "struct", "basic2"])
assert result3.equals(table.select(["list", "struct", "basic2"]))
# using dotted paths
result4 = orc_file.read(columns=["struct.middle.inner"])
expected4 = pa.table({"struct": [{"middle": {"inner": 3}}]})
assert result4.equals(expected4)
result5 = orc_file.read(columns=["struct.inner2"])
expected5 = pa.table({"struct": [{"inner2": 4}]})
assert result5.equals(expected5)
result6 = orc_file.read(
columns=["list", "struct.middle.inner", "struct.inner2"]
)
assert result6.equals(table.select(["list", "struct"]))
result7 = orc_file.read(columns=["list-struct.inner1"])
expected7 = pa.table({"list-struct": [[{"inner1": 5}, {"inner1": 7}]]})
assert result7.equals(expected7)
# selecting with (Arrow-based) field indices
result2 = orc_file.read(columns=[0, 4])
assert result2.equals(table.select(["basic", "basic2"]))
result3 = orc_file.read(columns=[1, 2, 3])
assert result3.equals(table.select(["list", "struct", "list-struct"]))
# error on non-existing name or index
with pytest.raises(IOError):
# liborc returns ParseError, which gets translated into IOError
# instead of ValueError
orc_file.read(columns=["wrong"])
with pytest.raises(ValueError):
orc_file.read(columns=[5])
def test_wrong_usage_orc_writer(tempdir):
from pyarrow import orc
path = str(tempdir / 'test.orc')
with orc.ORCWriter(path) as writer:
with pytest.raises(AttributeError):
writer.test()
def test_orc_writer_with_null_arrays(tempdir):
from pyarrow import orc
import pyarrow as pa
path = str(tempdir / 'test.orc')
a = pa.array([1, None, 3, None])
b = pa.array([None, None, None, None])
table = pa.table({"int64": a, "utf8": b})
with pytest.raises(pa.ArrowNotImplementedError):
orc.write_table(table, path)
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,104 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import pytest
def run_tensorflow_test_with_dtype(tf, plasma, plasma_store_name,
client, use_gpu, dtype):
FORCE_DEVICE = '/gpu' if use_gpu else '/cpu'
object_id = np.random.bytes(20)
data = np.random.randn(3, 244, 244).astype(dtype)
ones = np.ones((3, 244, 244)).astype(dtype)
sess = tf.Session(config=tf.ConfigProto(
allow_soft_placement=True, log_device_placement=True))
def ToPlasma():
data_tensor = tf.constant(data)
ones_tensor = tf.constant(ones)
return plasma.tf_plasma_op.tensor_to_plasma(
[data_tensor, ones_tensor],
object_id,
plasma_store_socket_name=plasma_store_name)
def FromPlasma():
return plasma.tf_plasma_op.plasma_to_tensor(
object_id,
dtype=tf.as_dtype(dtype),
plasma_store_socket_name=plasma_store_name)
with tf.device(FORCE_DEVICE):
to_plasma = ToPlasma()
from_plasma = FromPlasma()
z = from_plasma + 1
sess.run(to_plasma)
# NOTE(zongheng): currently it returns a flat 1D tensor.
# So reshape manually.
out = sess.run(from_plasma)
out = np.split(out, 2)
out0 = out[0].reshape(3, 244, 244)
out1 = out[1].reshape(3, 244, 244)
sess.run(z)
assert np.array_equal(data, out0), "Data not equal!"
assert np.array_equal(ones, out1), "Data not equal!"
# Try getting the data from Python
plasma_object_id = plasma.ObjectID(object_id)
obj = client.get(plasma_object_id)
# Deserialized Tensor should be 64-byte aligned.
assert obj.ctypes.data % 64 == 0
result = np.split(obj, 2)
result0 = result[0].reshape(3, 244, 244)
result1 = result[1].reshape(3, 244, 244)
assert np.array_equal(data, result0), "Data not equal!"
assert np.array_equal(ones, result1), "Data not equal!"
@pytest.mark.plasma
@pytest.mark.tensorflow
@pytest.mark.skip(reason='Until ARROW-4259 is resolved')
def test_plasma_tf_op(use_gpu=False):
import pyarrow.plasma as plasma
import tensorflow as tf
plasma.build_plasma_tensorflow_op()
if plasma.tf_plasma_op is None:
pytest.skip("TensorFlow Op not found")
with plasma.start_plasma_store(10**8) as (plasma_store_name, p):
client = plasma.connect(plasma_store_name)
for dtype in [np.float32, np.float64,
np.int8, np.int16, np.int32, np.int64]:
run_tensorflow_test_with_dtype(tf, plasma, plasma_store_name,
client, use_gpu, dtype)
# Make sure the objects have been released.
for _, info in client.list().items():
assert info['ref_count'] == 0
@@ -0,0 +1,696 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
import decimal
import pickle
import pytest
import weakref
import numpy as np
import pyarrow as pa
@pytest.mark.parametrize(['value', 'ty', 'klass', 'deprecated'], [
(False, None, pa.BooleanScalar, pa.BooleanValue),
(True, None, pa.BooleanScalar, pa.BooleanValue),
(1, None, pa.Int64Scalar, pa.Int64Value),
(-1, None, pa.Int64Scalar, pa.Int64Value),
(1, pa.int8(), pa.Int8Scalar, pa.Int8Value),
(1, pa.uint8(), pa.UInt8Scalar, pa.UInt8Value),
(1, pa.int16(), pa.Int16Scalar, pa.Int16Value),
(1, pa.uint16(), pa.UInt16Scalar, pa.UInt16Value),
(1, pa.int32(), pa.Int32Scalar, pa.Int32Value),
(1, pa.uint32(), pa.UInt32Scalar, pa.UInt32Value),
(1, pa.int64(), pa.Int64Scalar, pa.Int64Value),
(1, pa.uint64(), pa.UInt64Scalar, pa.UInt64Value),
(1.0, None, pa.DoubleScalar, pa.DoubleValue),
(np.float16(1.0), pa.float16(), pa.HalfFloatScalar, pa.HalfFloatValue),
(1.0, pa.float32(), pa.FloatScalar, pa.FloatValue),
(decimal.Decimal("1.123"), None, pa.Decimal128Scalar, pa.Decimal128Value),
(decimal.Decimal("1.1234567890123456789012345678901234567890"),
None, pa.Decimal256Scalar, pa.Decimal256Value),
("string", None, pa.StringScalar, pa.StringValue),
(b"bytes", None, pa.BinaryScalar, pa.BinaryValue),
("largestring", pa.large_string(), pa.LargeStringScalar,
pa.LargeStringValue),
(b"largebytes", pa.large_binary(), pa.LargeBinaryScalar,
pa.LargeBinaryValue),
(b"abc", pa.binary(3), pa.FixedSizeBinaryScalar, pa.FixedSizeBinaryValue),
([1, 2, 3], None, pa.ListScalar, pa.ListValue),
([1, 2, 3, 4], pa.large_list(pa.int8()), pa.LargeListScalar,
pa.LargeListValue),
([1, 2, 3, 4, 5], pa.list_(pa.int8(), 5), pa.FixedSizeListScalar,
pa.FixedSizeListValue),
(datetime.date.today(), None, pa.Date32Scalar, pa.Date32Value),
(datetime.date.today(), pa.date64(), pa.Date64Scalar, pa.Date64Value),
(datetime.datetime.now(), None, pa.TimestampScalar, pa.TimestampValue),
(datetime.datetime.now().time().replace(microsecond=0), pa.time32('s'),
pa.Time32Scalar, pa.Time32Value),
(datetime.datetime.now().time(), None, pa.Time64Scalar, pa.Time64Value),
(datetime.timedelta(days=1), None, pa.DurationScalar, pa.DurationValue),
(pa.MonthDayNano([1, -1, -10100]), None,
pa.MonthDayNanoIntervalScalar, None),
({'a': 1, 'b': [1, 2]}, None, pa.StructScalar, pa.StructValue),
([('a', 1), ('b', 2)], pa.map_(pa.string(), pa.int8()), pa.MapScalar,
pa.MapValue),
])
def test_basics(value, ty, klass, deprecated):
s = pa.scalar(value, type=ty)
assert isinstance(s, klass)
assert s.as_py() == value
assert s == pa.scalar(value, type=ty)
assert s != value
assert s != "else"
assert hash(s) == hash(s)
assert s.is_valid is True
assert s != None # noqa: E711
if deprecated is not None:
with pytest.warns(FutureWarning):
assert isinstance(s, deprecated)
s = pa.scalar(None, type=s.type)
assert s.is_valid is False
assert s.as_py() is None
assert s != pa.scalar(value, type=ty)
# test pickle roundtrip
restored = pickle.loads(pickle.dumps(s))
assert s.equals(restored)
# test that scalars are weak-referenceable
wr = weakref.ref(s)
assert wr() is not None
del s
assert wr() is None
def test_null_singleton():
with pytest.raises(RuntimeError):
pa.NullScalar()
def test_nulls():
null = pa.scalar(None)
assert null is pa.NA
assert null.as_py() is None
assert null != "something"
assert (null == pa.scalar(None)) is True
assert (null == 0) is False
assert pa.NA == pa.NA
assert pa.NA not in [5]
arr = pa.array([None, None])
for v in arr:
assert v is pa.NA
assert v.as_py() is None
# test pickle roundtrip
restored = pickle.loads(pickle.dumps(null))
assert restored.equals(null)
# test that scalars are weak-referenceable
wr = weakref.ref(null)
assert wr() is not None
del null
assert wr() is not None # singleton
def test_hashing():
# ARROW-640
values = list(range(500))
arr = pa.array(values + values)
set_from_array = set(arr)
assert isinstance(set_from_array, set)
assert len(set_from_array) == 500
def test_bool():
false = pa.scalar(False)
true = pa.scalar(True)
assert isinstance(false, pa.BooleanScalar)
assert isinstance(true, pa.BooleanScalar)
assert repr(true) == "<pyarrow.BooleanScalar: True>"
assert str(true) == "True"
assert repr(false) == "<pyarrow.BooleanScalar: False>"
assert str(false) == "False"
assert true.as_py() is True
assert false.as_py() is False
def test_numerics():
# int64
s = pa.scalar(1)
assert isinstance(s, pa.Int64Scalar)
assert repr(s) == "<pyarrow.Int64Scalar: 1>"
assert str(s) == "1"
assert s.as_py() == 1
with pytest.raises(OverflowError):
pa.scalar(-1, type='uint8')
# float64
s = pa.scalar(1.5)
assert isinstance(s, pa.DoubleScalar)
assert repr(s) == "<pyarrow.DoubleScalar: 1.5>"
assert str(s) == "1.5"
assert s.as_py() == 1.5
# float16
s = pa.scalar(np.float16(0.5), type='float16')
assert isinstance(s, pa.HalfFloatScalar)
assert repr(s) == "<pyarrow.HalfFloatScalar: 0.5>"
assert str(s) == "0.5"
assert s.as_py() == 0.5
def test_decimal128():
v = decimal.Decimal("1.123")
s = pa.scalar(v)
assert isinstance(s, pa.Decimal128Scalar)
assert s.as_py() == v
assert s.type == pa.decimal128(4, 3)
v = decimal.Decimal("1.1234")
with pytest.raises(pa.ArrowInvalid):
pa.scalar(v, type=pa.decimal128(4, scale=3))
with pytest.raises(pa.ArrowInvalid):
pa.scalar(v, type=pa.decimal128(5, scale=3))
s = pa.scalar(v, type=pa.decimal128(5, scale=4))
assert isinstance(s, pa.Decimal128Scalar)
assert s.as_py() == v
def test_decimal256():
v = decimal.Decimal("1234567890123456789012345678901234567890.123")
s = pa.scalar(v)
assert isinstance(s, pa.Decimal256Scalar)
assert s.as_py() == v
assert s.type == pa.decimal256(43, 3)
v = decimal.Decimal("1.1234")
with pytest.raises(pa.ArrowInvalid):
pa.scalar(v, type=pa.decimal256(4, scale=3))
with pytest.raises(pa.ArrowInvalid):
pa.scalar(v, type=pa.decimal256(5, scale=3))
s = pa.scalar(v, type=pa.decimal256(5, scale=4))
assert isinstance(s, pa.Decimal256Scalar)
assert s.as_py() == v
def test_date():
# ARROW-5125
d1 = datetime.date(3200, 1, 1)
d2 = datetime.date(1960, 1, 1)
for ty in [pa.date32(), pa.date64()]:
for d in [d1, d2]:
s = pa.scalar(d, type=ty)
assert s.as_py() == d
def test_date_cast():
# ARROW-10472 - casting fo scalars doesn't segfault
scalar = pa.scalar(datetime.datetime(2012, 1, 1), type=pa.timestamp("us"))
expected = datetime.date(2012, 1, 1)
for ty in [pa.date32(), pa.date64()]:
result = scalar.cast(ty)
assert result.as_py() == expected
def test_time():
t1 = datetime.time(18, 0)
t2 = datetime.time(21, 0)
types = [pa.time32('s'), pa.time32('ms'), pa.time64('us'), pa.time64('ns')]
for ty in types:
for t in [t1, t2]:
s = pa.scalar(t, type=ty)
assert s.as_py() == t
def test_cast():
val = pa.scalar(5, type='int8')
assert val.cast('int64') == pa.scalar(5, type='int64')
assert val.cast('uint32') == pa.scalar(5, type='uint32')
assert val.cast('string') == pa.scalar('5', type='string')
with pytest.raises(ValueError):
pa.scalar('foo').cast('int32')
@pytest.mark.pandas
def test_timestamp():
import pandas as pd
arr = pd.date_range('2000-01-01 12:34:56', periods=10).values
units = ['ns', 'us', 'ms', 's']
for i, unit in enumerate(units):
dtype = 'datetime64[{}]'.format(unit)
arrow_arr = pa.Array.from_pandas(arr.astype(dtype))
expected = pd.Timestamp('2000-01-01 12:34:56')
assert arrow_arr[0].as_py() == expected
assert arrow_arr[0].value * 1000**i == expected.value
tz = 'America/New_York'
arrow_type = pa.timestamp(unit, tz=tz)
dtype = 'datetime64[{}]'.format(unit)
arrow_arr = pa.Array.from_pandas(arr.astype(dtype), type=arrow_type)
expected = (pd.Timestamp('2000-01-01 12:34:56')
.tz_localize('utc')
.tz_convert(tz))
assert arrow_arr[0].as_py() == expected
assert arrow_arr[0].value * 1000**i == expected.value
@pytest.mark.nopandas
def test_timestamp_nanos_nopandas():
# ARROW-5450
pytest.importorskip("pytz")
import pytz
tz = 'America/New_York'
ty = pa.timestamp('ns', tz=tz)
# 2000-01-01 00:00:00 + 1 microsecond
s = pa.scalar(946684800000000000 + 1000, type=ty)
tzinfo = pytz.timezone(tz)
expected = datetime.datetime(2000, 1, 1, microsecond=1, tzinfo=tzinfo)
expected = tzinfo.fromutc(expected)
result = s.as_py()
assert result == expected
assert result.year == 1999
assert result.hour == 19
# Non-zero nanos yields ValueError
s = pa.scalar(946684800000000001, type=ty)
with pytest.raises(ValueError):
s.as_py()
def test_timestamp_no_overflow():
# ARROW-5450
pytest.importorskip("pytz")
import pytz
timestamps = [
datetime.datetime(1, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
datetime.datetime(9999, 12, 31, 23, 59, 59, 999999, tzinfo=pytz.utc),
datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
]
for ts in timestamps:
s = pa.scalar(ts, type=pa.timestamp("us", tz="UTC"))
assert s.as_py() == ts
def test_timestamp_fixed_offset_print():
# ARROW-13896
pytest.importorskip("pytz")
arr = pa.array([0], pa.timestamp('s', tz='+02:00'))
assert str(arr[0]) == "1970-01-01 02:00:00+02:00"
def test_duration():
arr = np.array([0, 3600000000000], dtype='timedelta64[ns]')
units = ['us', 'ms', 's']
for i, unit in enumerate(units):
dtype = 'timedelta64[{}]'.format(unit)
arrow_arr = pa.array(arr.astype(dtype))
expected = datetime.timedelta(seconds=60*60)
assert isinstance(arrow_arr[1].as_py(), datetime.timedelta)
assert arrow_arr[1].as_py() == expected
assert (arrow_arr[1].value * 1000**(i+1) ==
expected.total_seconds() * 1e9)
@pytest.mark.pandas
def test_duration_nanos_pandas():
import pandas as pd
arr = pa.array([0, 3600000000000], type=pa.duration('ns'))
expected = pd.Timedelta('1 hour')
assert isinstance(arr[1].as_py(), pd.Timedelta)
assert arr[1].as_py() == expected
assert arr[1].value == expected.value
# Non-zero nanos work fine
arr = pa.array([946684800000000001], type=pa.duration('ns'))
assert arr[0].as_py() == pd.Timedelta(946684800000000001, unit='ns')
@pytest.mark.nopandas
def test_duration_nanos_nopandas():
arr = pa.array([0, 3600000000000], pa.duration('ns'))
expected = datetime.timedelta(seconds=60*60)
assert isinstance(arr[1].as_py(), datetime.timedelta)
assert arr[1].as_py() == expected
assert arr[1].value == expected.total_seconds() * 1e9
# Non-zero nanos yields ValueError
arr = pa.array([946684800000000001], type=pa.duration('ns'))
with pytest.raises(ValueError):
arr[0].as_py()
def test_month_day_nano_interval():
triple = pa.MonthDayNano([-3600, 1800, -50])
arr = pa.array([triple])
assert isinstance(arr[0].as_py(), pa.MonthDayNano)
assert arr[0].as_py() == triple
assert arr[0].value == triple
@pytest.mark.parametrize('value', ['foo', 'mañana'])
@pytest.mark.parametrize(('ty', 'scalar_typ'), [
(pa.string(), pa.StringScalar),
(pa.large_string(), pa.LargeStringScalar)
])
def test_string(value, ty, scalar_typ):
s = pa.scalar(value, type=ty)
assert isinstance(s, scalar_typ)
assert s.as_py() == value
assert s.as_py() != 'something'
assert repr(value) in repr(s)
assert str(s) == str(value)
buf = s.as_buffer()
assert isinstance(buf, pa.Buffer)
assert buf.to_pybytes() == value.encode()
@pytest.mark.parametrize('value', [b'foo', b'bar'])
@pytest.mark.parametrize(('ty', 'scalar_typ'), [
(pa.binary(), pa.BinaryScalar),
(pa.large_binary(), pa.LargeBinaryScalar)
])
def test_binary(value, ty, scalar_typ):
s = pa.scalar(value, type=ty)
assert isinstance(s, scalar_typ)
assert s.as_py() == value
assert str(s) == str(value)
assert repr(value) in repr(s)
assert s.as_py() == value
assert s != b'xxxxx'
buf = s.as_buffer()
assert isinstance(buf, pa.Buffer)
assert buf.to_pybytes() == value
def test_fixed_size_binary():
s = pa.scalar(b'foof', type=pa.binary(4))
assert isinstance(s, pa.FixedSizeBinaryScalar)
assert s.as_py() == b'foof'
with pytest.raises(pa.ArrowInvalid):
pa.scalar(b'foof5', type=pa.binary(4))
@pytest.mark.parametrize(('ty', 'klass'), [
(pa.list_(pa.string()), pa.ListScalar),
(pa.large_list(pa.string()), pa.LargeListScalar)
])
def test_list(ty, klass):
v = ['foo', None]
s = pa.scalar(v, type=ty)
assert s.type == ty
assert len(s) == 2
assert isinstance(s.values, pa.Array)
assert s.values.to_pylist() == v
assert isinstance(s, klass)
assert repr(v) in repr(s)
assert s.as_py() == v
assert s[0].as_py() == 'foo'
assert s[1].as_py() is None
assert s[-1] == s[1]
assert s[-2] == s[0]
with pytest.raises(IndexError):
s[-3]
with pytest.raises(IndexError):
s[2]
def test_list_from_numpy():
s = pa.scalar(np.array([1, 2, 3], dtype=np.int64()))
assert s.type == pa.list_(pa.int64())
assert s.as_py() == [1, 2, 3]
@pytest.mark.pandas
def test_list_from_pandas():
import pandas as pd
s = pa.scalar(pd.Series([1, 2, 3]))
assert s.as_py() == [1, 2, 3]
cases = [
(np.nan, 'null'),
(['string', np.nan], pa.list_(pa.binary())),
(['string', np.nan], pa.list_(pa.utf8())),
([b'string', np.nan], pa.list_(pa.binary(6))),
([True, np.nan], pa.list_(pa.bool_())),
([decimal.Decimal('0'), np.nan], pa.list_(pa.decimal128(12, 2))),
]
for case, ty in cases:
# Both types of exceptions are raised. May want to clean that up
with pytest.raises((ValueError, TypeError)):
pa.scalar(case, type=ty)
# from_pandas option suppresses failure
s = pa.scalar(case, type=ty, from_pandas=True)
def test_fixed_size_list():
s = pa.scalar([1, None, 3], type=pa.list_(pa.int64(), 3))
assert len(s) == 3
assert isinstance(s, pa.FixedSizeListScalar)
assert repr(s) == "<pyarrow.FixedSizeListScalar: [1, None, 3]>"
assert s.as_py() == [1, None, 3]
assert s[0].as_py() == 1
assert s[1].as_py() is None
assert s[-1] == s[2]
with pytest.raises(IndexError):
s[-4]
with pytest.raises(IndexError):
s[3]
def test_struct():
ty = pa.struct([
pa.field('x', pa.int16()),
pa.field('y', pa.float32())
])
v = {'x': 2, 'y': 3.5}
s = pa.scalar(v, type=ty)
assert list(s) == list(s.keys()) == ['x', 'y']
assert list(s.values()) == [
pa.scalar(2, type=pa.int16()),
pa.scalar(3.5, type=pa.float32())
]
assert list(s.items()) == [
('x', pa.scalar(2, type=pa.int16())),
('y', pa.scalar(3.5, type=pa.float32()))
]
assert 'x' in s
assert 'y' in s
assert 'z' not in s
assert 0 not in s
assert s.as_py() == v
assert repr(s) != repr(v)
assert repr(s.as_py()) == repr(v)
assert len(s) == 2
assert isinstance(s['x'], pa.Int16Scalar)
assert isinstance(s['y'], pa.FloatScalar)
assert s['x'].as_py() == 2
assert s['y'].as_py() == 3.5
with pytest.raises(KeyError):
s['non-existent']
s = pa.scalar(None, type=ty)
assert list(s) == list(s.keys()) == ['x', 'y']
assert s.as_py() is None
assert 'x' in s
assert 'y' in s
assert isinstance(s['x'], pa.Int16Scalar)
assert isinstance(s['y'], pa.FloatScalar)
assert s['x'].is_valid is False
assert s['y'].is_valid is False
assert s['x'].as_py() is None
assert s['y'].as_py() is None
def test_struct_duplicate_fields():
ty = pa.struct([
pa.field('x', pa.int16()),
pa.field('y', pa.float32()),
pa.field('x', pa.int64()),
])
s = pa.scalar([('x', 1), ('y', 2.0), ('x', 3)], type=ty)
assert list(s) == list(s.keys()) == ['x', 'y', 'x']
assert len(s) == 3
assert s == s
assert list(s.items()) == [
('x', pa.scalar(1, pa.int16())),
('y', pa.scalar(2.0, pa.float32())),
('x', pa.scalar(3, pa.int64()))
]
assert 'x' in s
assert 'y' in s
assert 'z' not in s
assert 0 not in s
# getitem with field names fails for duplicate fields, works for others
with pytest.raises(KeyError):
s['x']
assert isinstance(s['y'], pa.FloatScalar)
assert s['y'].as_py() == 2.0
# getitem with integer index works for all fields
assert isinstance(s[0], pa.Int16Scalar)
assert s[0].as_py() == 1
assert isinstance(s[1], pa.FloatScalar)
assert s[1].as_py() == 2.0
assert isinstance(s[2], pa.Int64Scalar)
assert s[2].as_py() == 3
assert "pyarrow.StructScalar" in repr(s)
with pytest.raises(ValueError, match="duplicate field names"):
s.as_py()
def test_map():
ty = pa.map_(pa.string(), pa.int8())
v = [('a', 1), ('b', 2)]
s = pa.scalar(v, type=ty)
assert len(s) == 2
assert isinstance(s, pa.MapScalar)
assert isinstance(s.values, pa.Array)
assert repr(s) == "<pyarrow.MapScalar: [('a', 1), ('b', 2)]>"
assert s.values.to_pylist() == [
{'key': 'a', 'value': 1},
{'key': 'b', 'value': 2}
]
# test iteration
for i, j in zip(s, v):
assert i == j
assert s.as_py() == v
assert s[1] == (
pa.scalar('b', type=pa.string()),
pa.scalar(2, type=pa.int8())
)
assert s[-1] == s[1]
assert s[-2] == s[0]
with pytest.raises(IndexError):
s[-3]
with pytest.raises(IndexError):
s[2]
restored = pickle.loads(pickle.dumps(s))
assert restored.equals(s)
def test_dictionary():
indices = pa.array([2, None, 1, 2, 0, None])
dictionary = pa.array(['foo', 'bar', 'baz'])
arr = pa.DictionaryArray.from_arrays(indices, dictionary)
expected = ['baz', None, 'bar', 'baz', 'foo', None]
assert arr.to_pylist() == expected
for j, (i, v) in enumerate(zip(indices, expected)):
s = arr[j]
assert s.as_py() == v
assert s.value.as_py() == v
assert s.index.equals(i)
assert s.dictionary.equals(dictionary)
with pytest.warns(FutureWarning):
assert s.index_value.equals(i)
with pytest.warns(FutureWarning):
assert s.dictionary_value.as_py() == v
restored = pickle.loads(pickle.dumps(s))
assert restored.equals(s)
def test_union():
# sparse
arr = pa.UnionArray.from_sparse(
pa.array([0, 0, 1, 1], type=pa.int8()),
[
pa.array(["a", "b", "c", "d"]),
pa.array([1, 2, 3, 4])
]
)
for s in arr:
assert isinstance(s, pa.UnionScalar)
assert s.type.equals(arr.type)
assert s.is_valid is True
with pytest.raises(pa.ArrowNotImplementedError):
pickle.loads(pickle.dumps(s))
assert arr[0].type_code == 0
assert arr[0].as_py() == "a"
assert arr[1].type_code == 0
assert arr[1].as_py() == "b"
assert arr[2].type_code == 1
assert arr[2].as_py() == 3
assert arr[3].type_code == 1
assert arr[3].as_py() == 4
# dense
arr = pa.UnionArray.from_dense(
types=pa.array([0, 1, 0, 0, 1, 1, 0], type='int8'),
value_offsets=pa.array([0, 0, 2, 1, 1, 2, 3], type='int32'),
children=[
pa.array([b'a', b'b', b'c', b'd'], type='binary'),
pa.array([1, 2, 3], type='int64')
]
)
for s in arr:
assert isinstance(s, pa.UnionScalar)
assert s.type.equals(arr.type)
assert s.is_valid is True
with pytest.raises(pa.ArrowNotImplementedError):
pickle.loads(pickle.dumps(s))
assert arr[0].type_code == 0
assert arr[0].as_py() == b'a'
assert arr[5].type_code == 1
assert arr[5].as_py() == 3
@@ -0,0 +1,730 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections import OrderedDict
import pickle
import sys
import weakref
import pytest
import numpy as np
import pyarrow as pa
import pyarrow.tests.util as test_util
from pyarrow.vendored.version import Version
def test_schema_constructor_errors():
msg = ("Do not call Schema's constructor directly, use `pyarrow.schema` "
"instead")
with pytest.raises(TypeError, match=msg):
pa.Schema()
def test_type_integers():
dtypes = ['int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32', 'uint64']
for name in dtypes:
factory = getattr(pa, name)
t = factory()
assert str(t) == name
def test_type_to_pandas_dtype():
M8_ns = np.dtype('datetime64[ns]')
cases = [
(pa.null(), np.object_),
(pa.bool_(), np.bool_),
(pa.int8(), np.int8),
(pa.int16(), np.int16),
(pa.int32(), np.int32),
(pa.int64(), np.int64),
(pa.uint8(), np.uint8),
(pa.uint16(), np.uint16),
(pa.uint32(), np.uint32),
(pa.uint64(), np.uint64),
(pa.float16(), np.float16),
(pa.float32(), np.float32),
(pa.float64(), np.float64),
(pa.date32(), M8_ns),
(pa.date64(), M8_ns),
(pa.timestamp('ms'), M8_ns),
(pa.binary(), np.object_),
(pa.binary(12), np.object_),
(pa.string(), np.object_),
(pa.list_(pa.int8()), np.object_),
# (pa.list_(pa.int8(), 2), np.object_), # TODO needs pandas conversion
(pa.map_(pa.int64(), pa.float64()), np.object_),
]
for arrow_type, numpy_type in cases:
assert arrow_type.to_pandas_dtype() == numpy_type
@pytest.mark.pandas
def test_type_to_pandas_dtype_check_import():
# ARROW-7980
test_util.invoke_script('arrow_7980.py')
def test_type_list():
value_type = pa.int32()
list_type = pa.list_(value_type)
assert str(list_type) == 'list<item: int32>'
field = pa.field('my_item', pa.string())
l2 = pa.list_(field)
assert str(l2) == 'list<my_item: string>'
def test_type_comparisons():
val = pa.int32()
assert val == pa.int32()
assert val == 'int32'
assert val != 5
def test_type_for_alias():
cases = [
('i1', pa.int8()),
('int8', pa.int8()),
('i2', pa.int16()),
('int16', pa.int16()),
('i4', pa.int32()),
('int32', pa.int32()),
('i8', pa.int64()),
('int64', pa.int64()),
('u1', pa.uint8()),
('uint8', pa.uint8()),
('u2', pa.uint16()),
('uint16', pa.uint16()),
('u4', pa.uint32()),
('uint32', pa.uint32()),
('u8', pa.uint64()),
('uint64', pa.uint64()),
('f4', pa.float32()),
('float32', pa.float32()),
('f8', pa.float64()),
('float64', pa.float64()),
('date32', pa.date32()),
('date64', pa.date64()),
('string', pa.string()),
('str', pa.string()),
('binary', pa.binary()),
('time32[s]', pa.time32('s')),
('time32[ms]', pa.time32('ms')),
('time64[us]', pa.time64('us')),
('time64[ns]', pa.time64('ns')),
('timestamp[s]', pa.timestamp('s')),
('timestamp[ms]', pa.timestamp('ms')),
('timestamp[us]', pa.timestamp('us')),
('timestamp[ns]', pa.timestamp('ns')),
('duration[s]', pa.duration('s')),
('duration[ms]', pa.duration('ms')),
('duration[us]', pa.duration('us')),
('duration[ns]', pa.duration('ns')),
('month_day_nano_interval', pa.month_day_nano_interval()),
]
for val, expected in cases:
assert pa.type_for_alias(val) == expected
def test_type_string():
t = pa.string()
assert str(t) == 'string'
def test_type_timestamp_with_tz():
tz = 'America/Los_Angeles'
t = pa.timestamp('ns', tz=tz)
assert t.unit == 'ns'
assert t.tz == tz
def test_time_types():
t1 = pa.time32('s')
t2 = pa.time32('ms')
t3 = pa.time64('us')
t4 = pa.time64('ns')
assert t1.unit == 's'
assert t2.unit == 'ms'
assert t3.unit == 'us'
assert t4.unit == 'ns'
assert str(t1) == 'time32[s]'
assert str(t4) == 'time64[ns]'
with pytest.raises(ValueError):
pa.time32('us')
with pytest.raises(ValueError):
pa.time64('s')
def test_from_numpy_dtype():
cases = [
(np.dtype('bool'), pa.bool_()),
(np.dtype('int8'), pa.int8()),
(np.dtype('int16'), pa.int16()),
(np.dtype('int32'), pa.int32()),
(np.dtype('int64'), pa.int64()),
(np.dtype('uint8'), pa.uint8()),
(np.dtype('uint16'), pa.uint16()),
(np.dtype('uint32'), pa.uint32()),
(np.dtype('float16'), pa.float16()),
(np.dtype('float32'), pa.float32()),
(np.dtype('float64'), pa.float64()),
(np.dtype('U'), pa.string()),
(np.dtype('S'), pa.binary()),
(np.dtype('datetime64[s]'), pa.timestamp('s')),
(np.dtype('datetime64[ms]'), pa.timestamp('ms')),
(np.dtype('datetime64[us]'), pa.timestamp('us')),
(np.dtype('datetime64[ns]'), pa.timestamp('ns')),
(np.dtype('timedelta64[s]'), pa.duration('s')),
(np.dtype('timedelta64[ms]'), pa.duration('ms')),
(np.dtype('timedelta64[us]'), pa.duration('us')),
(np.dtype('timedelta64[ns]'), pa.duration('ns')),
]
for dt, pt in cases:
result = pa.from_numpy_dtype(dt)
assert result == pt
# Things convertible to numpy dtypes work
assert pa.from_numpy_dtype('U') == pa.string()
assert pa.from_numpy_dtype(np.str_) == pa.string()
assert pa.from_numpy_dtype('int32') == pa.int32()
assert pa.from_numpy_dtype(bool) == pa.bool_()
with pytest.raises(NotImplementedError):
pa.from_numpy_dtype(np.dtype('O'))
with pytest.raises(TypeError):
pa.from_numpy_dtype('not_convertible_to_dtype')
def test_schema():
fields = [
pa.field('foo', pa.int32()),
pa.field('bar', pa.string()),
pa.field('baz', pa.list_(pa.int8()))
]
sch = pa.schema(fields)
assert sch.names == ['foo', 'bar', 'baz']
assert sch.types == [pa.int32(), pa.string(), pa.list_(pa.int8())]
assert len(sch) == 3
assert sch[0].name == 'foo'
assert sch[0].type == fields[0].type
assert sch.field('foo').name == 'foo'
assert sch.field('foo').type == fields[0].type
assert repr(sch) == """\
foo: int32
bar: string
baz: list<item: int8>
child 0, item: int8"""
with pytest.raises(TypeError):
pa.schema([None])
def test_schema_weakref():
fields = [
pa.field('foo', pa.int32()),
pa.field('bar', pa.string()),
pa.field('baz', pa.list_(pa.int8()))
]
schema = pa.schema(fields)
wr = weakref.ref(schema)
assert wr() is not None
del schema
assert wr() is None
def test_schema_to_string_with_metadata():
lorem = """\
Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nulla accumsan vel
turpis et mollis. Aliquam tincidunt arcu id tortor blandit blandit. Donec
eget leo quis lectus scelerisque varius. Class aptent taciti sociosqu ad
litora torquent per conubia nostra, per inceptos himenaeos. Praesent
faucibus, diam eu volutpat iaculis, tellus est porta ligula, a efficitur
turpis nulla facilisis quam. Aliquam vitae lorem erat. Proin a dolor ac libero
dignissim mollis vitae eu mauris. Quisque posuere tellus vitae massa
pellentesque sagittis. Aenean feugiat, diam ac dignissim fermentum, lorem
sapien commodo massa, vel volutpat orci nisi eu justo. Nulla non blandit
sapien. Quisque pretium vestibulum urna eu vehicula."""
# ARROW-7063
my_schema = pa.schema([pa.field("foo", "int32", False,
metadata={"key1": "value1"}),
pa.field("bar", "string", True,
metadata={"key3": "value3"})],
metadata={"lorem": lorem})
assert my_schema.to_string() == """\
foo: int32 not null
-- field metadata --
key1: 'value1'
bar: string
-- field metadata --
key3: 'value3'
-- schema metadata --
lorem: '""" + lorem[:65] + "' + " + str(len(lorem) - 65)
# Metadata that exactly fits
result = pa.schema([('f0', 'int32')],
metadata={'key': 'value' + 'x' * 62}).to_string()
assert result == """\
f0: int32
-- schema metadata --
key: 'valuexxxxxxxxxxxxxxxxxxxxxxxxxxxxx\
xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'"""
assert my_schema.to_string(truncate_metadata=False) == """\
foo: int32 not null
-- field metadata --
key1: 'value1'
bar: string
-- field metadata --
key3: 'value3'
-- schema metadata --
lorem: '{}'""".format(lorem)
assert my_schema.to_string(truncate_metadata=False,
show_field_metadata=False) == """\
foo: int32 not null
bar: string
-- schema metadata --
lorem: '{}'""".format(lorem)
assert my_schema.to_string(truncate_metadata=False,
show_schema_metadata=False) == """\
foo: int32 not null
-- field metadata --
key1: 'value1'
bar: string
-- field metadata --
key3: 'value3'"""
assert my_schema.to_string(truncate_metadata=False,
show_field_metadata=False,
show_schema_metadata=False) == """\
foo: int32 not null
bar: string"""
def test_schema_from_tuples():
fields = [
('foo', pa.int32()),
('bar', pa.string()),
('baz', pa.list_(pa.int8())),
]
sch = pa.schema(fields)
assert sch.names == ['foo', 'bar', 'baz']
assert sch.types == [pa.int32(), pa.string(), pa.list_(pa.int8())]
assert len(sch) == 3
assert repr(sch) == """\
foo: int32
bar: string
baz: list<item: int8>
child 0, item: int8"""
with pytest.raises(TypeError):
pa.schema([('foo', None)])
def test_schema_from_mapping():
fields = OrderedDict([
('foo', pa.int32()),
('bar', pa.string()),
('baz', pa.list_(pa.int8())),
])
sch = pa.schema(fields)
assert sch.names == ['foo', 'bar', 'baz']
assert sch.types == [pa.int32(), pa.string(), pa.list_(pa.int8())]
assert len(sch) == 3
assert repr(sch) == """\
foo: int32
bar: string
baz: list<item: int8>
child 0, item: int8"""
fields = OrderedDict([('foo', None)])
with pytest.raises(TypeError):
pa.schema(fields)
def test_schema_duplicate_fields():
fields = [
pa.field('foo', pa.int32()),
pa.field('bar', pa.string()),
pa.field('foo', pa.list_(pa.int8())),
]
sch = pa.schema(fields)
assert sch.names == ['foo', 'bar', 'foo']
assert sch.types == [pa.int32(), pa.string(), pa.list_(pa.int8())]
assert len(sch) == 3
assert repr(sch) == """\
foo: int32
bar: string
foo: list<item: int8>
child 0, item: int8"""
assert sch[0].name == 'foo'
assert sch[0].type == fields[0].type
with pytest.warns(FutureWarning):
assert sch.field_by_name('bar') == fields[1]
with pytest.warns(FutureWarning):
assert sch.field_by_name('xxx') is None
with pytest.warns((UserWarning, FutureWarning)):
assert sch.field_by_name('foo') is None
# Schema::GetFieldIndex
assert sch.get_field_index('foo') == -1
# Schema::GetAllFieldIndices
assert sch.get_all_field_indices('foo') == [0, 2]
def test_field_flatten():
f0 = pa.field('foo', pa.int32()).with_metadata({b'foo': b'bar'})
assert f0.flatten() == [f0]
f1 = pa.field('bar', pa.float64(), nullable=False)
ff = pa.field('ff', pa.struct([f0, f1]), nullable=False)
assert ff.flatten() == [
pa.field('ff.foo', pa.int32()).with_metadata({b'foo': b'bar'}),
pa.field('ff.bar', pa.float64(), nullable=False)] # XXX
# Nullable parent makes flattened child nullable
ff = pa.field('ff', pa.struct([f0, f1]))
assert ff.flatten() == [
pa.field('ff.foo', pa.int32()).with_metadata({b'foo': b'bar'}),
pa.field('ff.bar', pa.float64())]
fff = pa.field('fff', pa.struct([ff]))
assert fff.flatten() == [pa.field('fff.ff', pa.struct([f0, f1]))]
def test_schema_add_remove_metadata():
fields = [
pa.field('foo', pa.int32()),
pa.field('bar', pa.string()),
pa.field('baz', pa.list_(pa.int8()))
]
s1 = pa.schema(fields)
assert s1.metadata is None
metadata = {b'foo': b'bar', b'pandas': b'badger'}
s2 = s1.with_metadata(metadata)
assert s2.metadata == metadata
s3 = s2.remove_metadata()
assert s3.metadata is None
# idempotent
s4 = s3.remove_metadata()
assert s4.metadata is None
def test_schema_equals():
fields = [
pa.field('foo', pa.int32()),
pa.field('bar', pa.string()),
pa.field('baz', pa.list_(pa.int8()))
]
metadata = {b'foo': b'bar', b'pandas': b'badger'}
sch1 = pa.schema(fields)
sch2 = pa.schema(fields)
sch3 = pa.schema(fields, metadata=metadata)
sch4 = pa.schema(fields, metadata=metadata)
assert sch1.equals(sch2, check_metadata=True)
assert sch3.equals(sch4, check_metadata=True)
assert sch1.equals(sch3)
assert not sch1.equals(sch3, check_metadata=True)
assert not sch1.equals(sch3, check_metadata=True)
del fields[-1]
sch3 = pa.schema(fields)
assert not sch1.equals(sch3)
def test_schema_equals_propagates_check_metadata():
# ARROW-4088
schema1 = pa.schema([
pa.field('foo', pa.int32()),
pa.field('bar', pa.string())
])
schema2 = pa.schema([
pa.field('foo', pa.int32()),
pa.field('bar', pa.string(), metadata={'a': 'alpha'}),
])
assert not schema1.equals(schema2, check_metadata=True)
assert schema1.equals(schema2)
def test_schema_equals_invalid_type():
# ARROW-5873
schema = pa.schema([pa.field("a", pa.int64())])
for val in [None, 'string', pa.array([1, 2])]:
with pytest.raises(TypeError):
schema.equals(val)
def test_schema_equality_operators():
fields = [
pa.field('foo', pa.int32()),
pa.field('bar', pa.string()),
pa.field('baz', pa.list_(pa.int8()))
]
metadata = {b'foo': b'bar', b'pandas': b'badger'}
sch1 = pa.schema(fields)
sch2 = pa.schema(fields)
sch3 = pa.schema(fields, metadata=metadata)
sch4 = pa.schema(fields, metadata=metadata)
assert sch1 == sch2
assert sch3 == sch4
# __eq__ and __ne__ do not check metadata
assert sch1 == sch3
assert not sch1 != sch3
assert sch2 == sch4
# comparison with other types doesn't raise
assert sch1 != []
assert sch3 != 'foo'
def test_schema_get_fields():
fields = [
pa.field('foo', pa.int32()),
pa.field('bar', pa.string()),
pa.field('baz', pa.list_(pa.int8()))
]
schema = pa.schema(fields)
assert schema.field('foo').name == 'foo'
assert schema.field(0).name == 'foo'
assert schema.field(-1).name == 'baz'
with pytest.raises(KeyError):
schema.field('other')
with pytest.raises(TypeError):
schema.field(0.0)
with pytest.raises(IndexError):
schema.field(4)
def test_schema_negative_indexing():
fields = [
pa.field('foo', pa.int32()),
pa.field('bar', pa.string()),
pa.field('baz', pa.list_(pa.int8()))
]
schema = pa.schema(fields)
assert schema[-1].equals(schema[2])
assert schema[-2].equals(schema[1])
assert schema[-3].equals(schema[0])
with pytest.raises(IndexError):
schema[-4]
with pytest.raises(IndexError):
schema[3]
def test_schema_repr_with_dictionaries():
fields = [
pa.field('one', pa.dictionary(pa.int16(), pa.string())),
pa.field('two', pa.int32())
]
sch = pa.schema(fields)
expected = (
"""\
one: dictionary<values=string, indices=int16, ordered=0>
two: int32""")
assert repr(sch) == expected
def test_type_schema_pickling():
cases = [
pa.int8(),
pa.string(),
pa.binary(),
pa.binary(10),
pa.list_(pa.string()),
pa.map_(pa.string(), pa.int8()),
pa.struct([
pa.field('a', 'int8'),
pa.field('b', 'string')
]),
pa.union([
pa.field('a', pa.int8()),
pa.field('b', pa.int16())
], pa.lib.UnionMode_SPARSE),
pa.union([
pa.field('a', pa.int8()),
pa.field('b', pa.int16())
], pa.lib.UnionMode_DENSE),
pa.time32('s'),
pa.time64('us'),
pa.date32(),
pa.date64(),
pa.timestamp('ms'),
pa.timestamp('ns'),
pa.decimal128(12, 2),
pa.decimal256(76, 38),
pa.field('a', 'string', metadata={b'foo': b'bar'}),
pa.list_(pa.field("element", pa.int64())),
pa.large_list(pa.field("element", pa.int64())),
pa.map_(pa.field("key", pa.string(), nullable=False),
pa.field("value", pa.int8()))
]
for val in cases:
roundtripped = pickle.loads(pickle.dumps(val))
assert val == roundtripped
fields = []
for i, f in enumerate(cases):
if isinstance(f, pa.Field):
fields.append(f)
else:
fields.append(pa.field('_f{}'.format(i), f))
schema = pa.schema(fields, metadata={b'foo': b'bar'})
roundtripped = pickle.loads(pickle.dumps(schema))
assert schema == roundtripped
def test_empty_table():
schema1 = pa.schema([
pa.field('f0', pa.int64()),
pa.field('f1', pa.dictionary(pa.int32(), pa.string())),
pa.field('f2', pa.list_(pa.list_(pa.int64()))),
])
# test it preserves field nullability
schema2 = pa.schema([
pa.field('a', pa.int64(), nullable=False),
pa.field('b', pa.int64())
])
for schema in [schema1, schema2]:
table = schema.empty_table()
assert isinstance(table, pa.Table)
assert table.num_rows == 0
assert table.schema == schema
@pytest.mark.pandas
def test_schema_from_pandas():
import pandas as pd
inputs = [
list(range(10)),
pd.Categorical(list(range(10))),
['foo', 'bar', None, 'baz', 'qux'],
np.array([
'2007-07-13T01:23:34.123456789',
'2006-01-13T12:34:56.432539784',
'2010-08-13T05:46:57.437699912'
], dtype='datetime64[ns]'),
]
if Version(pd.__version__) >= Version('1.0.0'):
inputs.append(pd.array([1, 2, None], dtype=pd.Int32Dtype()))
for data in inputs:
df = pd.DataFrame({'a': data})
schema = pa.Schema.from_pandas(df)
expected = pa.Table.from_pandas(df).schema
assert schema == expected
def test_schema_sizeof():
schema = pa.schema([
pa.field('foo', pa.int32()),
pa.field('bar', pa.string()),
])
assert sys.getsizeof(schema) > 30
schema2 = schema.with_metadata({"key": "some metadata"})
assert sys.getsizeof(schema2) > sys.getsizeof(schema)
schema3 = schema.with_metadata({"key": "some more metadata"})
assert sys.getsizeof(schema3) > sys.getsizeof(schema2)
def test_schema_merge():
a = pa.schema([
pa.field('foo', pa.int32()),
pa.field('bar', pa.string()),
pa.field('baz', pa.list_(pa.int8()))
])
b = pa.schema([
pa.field('foo', pa.int32()),
pa.field('qux', pa.bool_())
])
c = pa.schema([
pa.field('quux', pa.dictionary(pa.int32(), pa.string()))
])
d = pa.schema([
pa.field('foo', pa.int64()),
pa.field('qux', pa.bool_())
])
result = pa.unify_schemas([a, b, c])
expected = pa.schema([
pa.field('foo', pa.int32()),
pa.field('bar', pa.string()),
pa.field('baz', pa.list_(pa.int8())),
pa.field('qux', pa.bool_()),
pa.field('quux', pa.dictionary(pa.int32(), pa.string()))
])
assert result.equals(expected)
with pytest.raises(pa.ArrowInvalid):
pa.unify_schemas([b, d])
# ARROW-14002: Try with tuple instead of list
result = pa.unify_schemas((a, b, c))
assert result.equals(expected)
def test_undecodable_metadata():
# ARROW-10214: undecodable metadata shouldn't fail repr()
data1 = b'abcdef\xff\x00'
data2 = b'ghijkl\xff\x00'
schema = pa.schema(
[pa.field('ints', pa.int16(), metadata={'key': data1})],
metadata={'key': data2})
assert 'abcdef' in str(schema)
assert 'ghijkl' in str(schema)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,52 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import pyarrow as pa
def test_serialization_deprecated():
with pytest.warns(FutureWarning):
ser = pa.serialize(1)
with pytest.warns(FutureWarning):
pa.deserialize(ser.to_buffer())
f = pa.BufferOutputStream()
with pytest.warns(FutureWarning):
pa.serialize_to(12, f)
buf = f.getvalue()
f = pa.BufferReader(buf)
with pytest.warns(FutureWarning):
pa.read_serialized(f).deserialize()
with pytest.warns(FutureWarning):
pa.default_serialization_context()
context = pa.lib.SerializationContext()
with pytest.warns(FutureWarning):
pa.register_default_serialization_handlers(context)
def test_serialization_deprecated_toplevel():
with pytest.warns(FutureWarning):
pa.SerializedPyObject()
with pytest.warns(FutureWarning):
pa.SerializationContext()
@@ -0,0 +1,491 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import sys
import weakref
import numpy as np
import pyarrow as pa
try:
from scipy.sparse import csr_matrix, coo_matrix
except ImportError:
coo_matrix = None
csr_matrix = None
try:
import sparse
except ImportError:
sparse = None
tensor_type_pairs = [
('i1', pa.int8()),
('i2', pa.int16()),
('i4', pa.int32()),
('i8', pa.int64()),
('u1', pa.uint8()),
('u2', pa.uint16()),
('u4', pa.uint32()),
('u8', pa.uint64()),
('f2', pa.float16()),
('f4', pa.float32()),
('f8', pa.float64())
]
@pytest.mark.parametrize('sparse_tensor_type', [
pa.SparseCSRMatrix,
pa.SparseCSCMatrix,
pa.SparseCOOTensor,
pa.SparseCSFTensor,
])
def test_sparse_tensor_attrs(sparse_tensor_type):
data = np.array([
[8, 0, 2, 0, 0, 0],
[0, 0, 0, 0, 0, 5],
[3, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 4, 6],
])
dim_names = ('x', 'y')
sparse_tensor = sparse_tensor_type.from_dense_numpy(data, dim_names)
assert sparse_tensor.ndim == 2
assert sparse_tensor.size == 24
assert sparse_tensor.shape == data.shape
assert sparse_tensor.is_mutable
assert sparse_tensor.dim_name(0) == dim_names[0]
assert sparse_tensor.dim_names == dim_names
assert sparse_tensor.non_zero_length == 6
wr = weakref.ref(sparse_tensor)
assert wr() is not None
del sparse_tensor
assert wr() is None
def test_sparse_coo_tensor_base_object():
expected_data = np.array([[8, 2, 5, 3, 4, 6]]).T
expected_coords = np.array([
[0, 0, 1, 2, 3, 3],
[0, 2, 5, 0, 4, 5],
]).T
array = np.array([
[8, 0, 2, 0, 0, 0],
[0, 0, 0, 0, 0, 5],
[3, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 4, 6],
])
sparse_tensor = pa.SparseCOOTensor.from_dense_numpy(array)
n = sys.getrefcount(sparse_tensor)
result_data, result_coords = sparse_tensor.to_numpy()
assert sparse_tensor.has_canonical_format
assert sys.getrefcount(sparse_tensor) == n + 2
sparse_tensor = None
assert np.array_equal(expected_data, result_data)
assert np.array_equal(expected_coords, result_coords)
assert result_coords.flags.c_contiguous # row-major
def test_sparse_csr_matrix_base_object():
data = np.array([[8, 2, 5, 3, 4, 6]]).T
indptr = np.array([0, 2, 3, 4, 6])
indices = np.array([0, 2, 5, 0, 4, 5])
array = np.array([
[8, 0, 2, 0, 0, 0],
[0, 0, 0, 0, 0, 5],
[3, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 4, 6],
])
sparse_tensor = pa.SparseCSRMatrix.from_dense_numpy(array)
n = sys.getrefcount(sparse_tensor)
result_data, result_indptr, result_indices = sparse_tensor.to_numpy()
assert sys.getrefcount(sparse_tensor) == n + 3
sparse_tensor = None
assert np.array_equal(data, result_data)
assert np.array_equal(indptr, result_indptr)
assert np.array_equal(indices, result_indices)
def test_sparse_csf_tensor_base_object():
data = np.array([[8, 2, 5, 3, 4, 6]]).T
indptr = [np.array([0, 2, 3, 4, 6])]
indices = [
np.array([0, 1, 2, 3]),
np.array([0, 2, 5, 0, 4, 5])
]
array = np.array([
[8, 0, 2, 0, 0, 0],
[0, 0, 0, 0, 0, 5],
[3, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 4, 6],
])
sparse_tensor = pa.SparseCSFTensor.from_dense_numpy(array)
n = sys.getrefcount(sparse_tensor)
result_data, result_indptr, result_indices = sparse_tensor.to_numpy()
assert sys.getrefcount(sparse_tensor) == n + 4
sparse_tensor = None
assert np.array_equal(data, result_data)
assert np.array_equal(indptr[0], result_indptr[0])
assert np.array_equal(indices[0], result_indices[0])
assert np.array_equal(indices[1], result_indices[1])
@pytest.mark.parametrize('sparse_tensor_type', [
pa.SparseCSRMatrix,
pa.SparseCSCMatrix,
pa.SparseCOOTensor,
pa.SparseCSFTensor,
])
def test_sparse_tensor_equals(sparse_tensor_type):
def eq(a, b):
assert a.equals(b)
assert a == b
assert not (a != b)
def ne(a, b):
assert not a.equals(b)
assert not (a == b)
assert a != b
data = np.random.randn(10, 6)[::, ::2]
sparse_tensor1 = sparse_tensor_type.from_dense_numpy(data)
sparse_tensor2 = sparse_tensor_type.from_dense_numpy(
np.ascontiguousarray(data))
eq(sparse_tensor1, sparse_tensor2)
data = data.copy()
data[9, 0] = 1.0
sparse_tensor2 = sparse_tensor_type.from_dense_numpy(
np.ascontiguousarray(data))
ne(sparse_tensor1, sparse_tensor2)
@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs)
def test_sparse_coo_tensor_from_dense(dtype_str, arrow_type):
dtype = np.dtype(dtype_str)
expected_data = np.array([[8, 2, 5, 3, 4, 6]]).T.astype(dtype)
expected_coords = np.array([
[0, 0, 1, 2, 3, 3],
[0, 2, 5, 0, 4, 5],
]).T
array = np.array([
[8, 0, 2, 0, 0, 0],
[0, 0, 0, 0, 0, 5],
[3, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 4, 6],
]).astype(dtype)
tensor = pa.Tensor.from_numpy(array)
# Test from numpy array
sparse_tensor = pa.SparseCOOTensor.from_dense_numpy(array)
repr(sparse_tensor)
result_data, result_coords = sparse_tensor.to_numpy()
assert sparse_tensor.type == arrow_type
assert np.array_equal(expected_data, result_data)
assert np.array_equal(expected_coords, result_coords)
# Test from Tensor
sparse_tensor = pa.SparseCOOTensor.from_tensor(tensor)
repr(sparse_tensor)
result_data, result_coords = sparse_tensor.to_numpy()
assert sparse_tensor.type == arrow_type
assert np.array_equal(expected_data, result_data)
assert np.array_equal(expected_coords, result_coords)
@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs)
def test_sparse_csr_matrix_from_dense(dtype_str, arrow_type):
dtype = np.dtype(dtype_str)
data = np.array([[8, 2, 5, 3, 4, 6]]).T.astype(dtype)
indptr = np.array([0, 2, 3, 4, 6])
indices = np.array([0, 2, 5, 0, 4, 5])
array = np.array([
[8, 0, 2, 0, 0, 0],
[0, 0, 0, 0, 0, 5],
[3, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 4, 6],
]).astype(dtype)
tensor = pa.Tensor.from_numpy(array)
# Test from numpy array
sparse_tensor = pa.SparseCSRMatrix.from_dense_numpy(array)
repr(sparse_tensor)
result_data, result_indptr, result_indices = sparse_tensor.to_numpy()
assert sparse_tensor.type == arrow_type
assert np.array_equal(data, result_data)
assert np.array_equal(indptr, result_indptr)
assert np.array_equal(indices, result_indices)
# Test from Tensor
sparse_tensor = pa.SparseCSRMatrix.from_tensor(tensor)
repr(sparse_tensor)
result_data, result_indptr, result_indices = sparse_tensor.to_numpy()
assert sparse_tensor.type == arrow_type
assert np.array_equal(data, result_data)
assert np.array_equal(indptr, result_indptr)
assert np.array_equal(indices, result_indices)
@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs)
def test_sparse_csf_tensor_from_dense_numpy(dtype_str, arrow_type):
dtype = np.dtype(dtype_str)
data = np.array([[8, 2, 5, 3, 4, 6]]).T.astype(dtype)
indptr = [np.array([0, 2, 3, 4, 6])]
indices = [
np.array([0, 1, 2, 3]),
np.array([0, 2, 5, 0, 4, 5])
]
array = np.array([
[8, 0, 2, 0, 0, 0],
[0, 0, 0, 0, 0, 5],
[3, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 4, 6],
]).astype(dtype)
# Test from numpy array
sparse_tensor = pa.SparseCSFTensor.from_dense_numpy(array)
repr(sparse_tensor)
result_data, result_indptr, result_indices = sparse_tensor.to_numpy()
assert sparse_tensor.type == arrow_type
assert np.array_equal(data, result_data)
assert np.array_equal(indptr[0], result_indptr[0])
assert np.array_equal(indices[0], result_indices[0])
assert np.array_equal(indices[1], result_indices[1])
@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs)
def test_sparse_csf_tensor_from_dense_tensor(dtype_str, arrow_type):
dtype = np.dtype(dtype_str)
data = np.array([[8, 2, 5, 3, 4, 6]]).T.astype(dtype)
indptr = [np.array([0, 2, 3, 4, 6])]
indices = [
np.array([0, 1, 2, 3]),
np.array([0, 2, 5, 0, 4, 5])
]
array = np.array([
[8, 0, 2, 0, 0, 0],
[0, 0, 0, 0, 0, 5],
[3, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 4, 6],
]).astype(dtype)
tensor = pa.Tensor.from_numpy(array)
# Test from Tensor
sparse_tensor = pa.SparseCSFTensor.from_tensor(tensor)
repr(sparse_tensor)
result_data, result_indptr, result_indices = sparse_tensor.to_numpy()
assert sparse_tensor.type == arrow_type
assert np.array_equal(data, result_data)
assert np.array_equal(indptr[0], result_indptr[0])
assert np.array_equal(indices[0], result_indices[0])
assert np.array_equal(indices[1], result_indices[1])
@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs)
def test_sparse_coo_tensor_numpy_roundtrip(dtype_str, arrow_type):
dtype = np.dtype(dtype_str)
data = np.array([[1, 2, 3, 4, 5, 6]]).T.astype(dtype)
coords = np.array([
[0, 0, 2, 3, 1, 3],
[0, 2, 0, 4, 5, 5],
]).T
shape = (4, 6)
dim_names = ('x', 'y')
sparse_tensor = pa.SparseCOOTensor.from_numpy(data, coords, shape,
dim_names)
repr(sparse_tensor)
result_data, result_coords = sparse_tensor.to_numpy()
assert sparse_tensor.type == arrow_type
assert np.array_equal(data, result_data)
assert np.array_equal(coords, result_coords)
assert sparse_tensor.dim_names == dim_names
@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs)
def test_sparse_csr_matrix_numpy_roundtrip(dtype_str, arrow_type):
dtype = np.dtype(dtype_str)
data = np.array([[8, 2, 5, 3, 4, 6]]).T.astype(dtype)
indptr = np.array([0, 2, 3, 4, 6])
indices = np.array([0, 2, 5, 0, 4, 5])
shape = (4, 6)
dim_names = ('x', 'y')
sparse_tensor = pa.SparseCSRMatrix.from_numpy(data, indptr, indices,
shape, dim_names)
repr(sparse_tensor)
result_data, result_indptr, result_indices = sparse_tensor.to_numpy()
assert sparse_tensor.type == arrow_type
assert np.array_equal(data, result_data)
assert np.array_equal(indptr, result_indptr)
assert np.array_equal(indices, result_indices)
assert sparse_tensor.dim_names == dim_names
@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs)
def test_sparse_csf_tensor_numpy_roundtrip(dtype_str, arrow_type):
dtype = np.dtype(dtype_str)
data = np.array([[8, 2, 5, 3, 4, 6]]).T.astype(dtype)
indptr = [np.array([0, 2, 3, 4, 6])]
indices = [
np.array([0, 1, 2, 3]),
np.array([0, 2, 5, 0, 4, 5])
]
axis_order = (0, 1)
shape = (4, 6)
dim_names = ('x', 'y')
sparse_tensor = pa.SparseCSFTensor.from_numpy(data, indptr, indices,
shape, axis_order,
dim_names)
repr(sparse_tensor)
result_data, result_indptr, result_indices = sparse_tensor.to_numpy()
assert sparse_tensor.type == arrow_type
assert np.array_equal(data, result_data)
assert np.array_equal(indptr[0], result_indptr[0])
assert np.array_equal(indices[0], result_indices[0])
assert np.array_equal(indices[1], result_indices[1])
assert sparse_tensor.dim_names == dim_names
@pytest.mark.parametrize('sparse_tensor_type', [
pa.SparseCSRMatrix,
pa.SparseCSCMatrix,
pa.SparseCOOTensor,
pa.SparseCSFTensor,
])
@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs)
def test_dense_to_sparse_tensor(dtype_str, arrow_type, sparse_tensor_type):
dtype = np.dtype(dtype_str)
array = np.array([[4, 0, 9, 0],
[0, 7, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 5]]).astype(dtype)
dim_names = ('x', 'y')
sparse_tensor = sparse_tensor_type.from_dense_numpy(array, dim_names)
tensor = sparse_tensor.to_tensor()
result_array = tensor.to_numpy()
assert sparse_tensor.type == arrow_type
assert tensor.type == arrow_type
assert sparse_tensor.dim_names == dim_names
assert np.array_equal(array, result_array)
@pytest.mark.skipif(not coo_matrix, reason="requires scipy")
@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs)
def test_sparse_coo_tensor_scipy_roundtrip(dtype_str, arrow_type):
dtype = np.dtype(dtype_str)
data = np.array([1, 2, 3, 4, 5, 6]).astype(dtype)
row = np.array([0, 0, 2, 3, 1, 3])
col = np.array([0, 2, 0, 4, 5, 5])
shape = (4, 6)
dim_names = ('x', 'y')
# non-canonical sparse coo matrix
scipy_matrix = coo_matrix((data, (row, col)), shape=shape)
sparse_tensor = pa.SparseCOOTensor.from_scipy(scipy_matrix,
dim_names=dim_names)
out_scipy_matrix = sparse_tensor.to_scipy()
assert not scipy_matrix.has_canonical_format
assert not sparse_tensor.has_canonical_format
assert not out_scipy_matrix.has_canonical_format
assert sparse_tensor.type == arrow_type
assert sparse_tensor.dim_names == dim_names
assert scipy_matrix.dtype == out_scipy_matrix.dtype
assert np.array_equal(scipy_matrix.data, out_scipy_matrix.data)
assert np.array_equal(scipy_matrix.row, out_scipy_matrix.row)
assert np.array_equal(scipy_matrix.col, out_scipy_matrix.col)
if dtype_str == 'f2':
dense_array = \
scipy_matrix.astype(np.float32).toarray().astype(np.float16)
else:
dense_array = scipy_matrix.toarray()
assert np.array_equal(dense_array, sparse_tensor.to_tensor().to_numpy())
# canonical sparse coo matrix
scipy_matrix.sum_duplicates()
sparse_tensor = pa.SparseCOOTensor.from_scipy(scipy_matrix,
dim_names=dim_names)
out_scipy_matrix = sparse_tensor.to_scipy()
assert scipy_matrix.has_canonical_format
assert sparse_tensor.has_canonical_format
assert out_scipy_matrix.has_canonical_format
@pytest.mark.skipif(not csr_matrix, reason="requires scipy")
@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs)
def test_sparse_csr_matrix_scipy_roundtrip(dtype_str, arrow_type):
dtype = np.dtype(dtype_str)
data = np.array([8, 2, 5, 3, 4, 6]).astype(dtype)
indptr = np.array([0, 2, 3, 4, 6])
indices = np.array([0, 2, 5, 0, 4, 5])
shape = (4, 6)
dim_names = ('x', 'y')
sparse_array = csr_matrix((data, indices, indptr), shape=shape)
sparse_tensor = pa.SparseCSRMatrix.from_scipy(sparse_array,
dim_names=dim_names)
out_sparse_array = sparse_tensor.to_scipy()
assert sparse_tensor.type == arrow_type
assert sparse_tensor.dim_names == dim_names
assert sparse_array.dtype == out_sparse_array.dtype
assert np.array_equal(sparse_array.data, out_sparse_array.data)
assert np.array_equal(sparse_array.indptr, out_sparse_array.indptr)
assert np.array_equal(sparse_array.indices, out_sparse_array.indices)
if dtype_str == 'f2':
dense_array = \
sparse_array.astype(np.float32).toarray().astype(np.float16)
else:
dense_array = sparse_array.toarray()
assert np.array_equal(dense_array, sparse_tensor.to_tensor().to_numpy())
@pytest.mark.skipif(not sparse, reason="requires pydata/sparse")
@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs)
def test_pydata_sparse_sparse_coo_tensor_roundtrip(dtype_str, arrow_type):
dtype = np.dtype(dtype_str)
data = np.array([1, 2, 3, 4, 5, 6]).astype(dtype)
coords = np.array([
[0, 0, 2, 3, 1, 3],
[0, 2, 0, 4, 5, 5],
])
shape = (4, 6)
dim_names = ("x", "y")
sparse_array = sparse.COO(data=data, coords=coords, shape=shape)
sparse_tensor = pa.SparseCOOTensor.from_pydata_sparse(sparse_array,
dim_names=dim_names)
out_sparse_array = sparse_tensor.to_pydata_sparse()
assert sparse_tensor.type == arrow_type
assert sparse_tensor.dim_names == dim_names
assert sparse_array.dtype == out_sparse_array.dtype
assert np.array_equal(sparse_array.data, out_sparse_array.data)
assert np.array_equal(sparse_array.coords, out_sparse_array.coords)
assert np.array_equal(sparse_array.todense(),
sparse_tensor.to_tensor().to_numpy())
@@ -0,0 +1,61 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import hypothesis as h
import pyarrow as pa
import pyarrow.tests.strategies as past
@h.given(past.all_types)
def test_types(ty):
assert isinstance(ty, pa.lib.DataType)
@h.given(past.all_fields)
def test_fields(field):
assert isinstance(field, pa.lib.Field)
@h.given(past.all_schemas)
def test_schemas(schema):
assert isinstance(schema, pa.lib.Schema)
@h.given(past.all_arrays)
def test_arrays(array):
assert isinstance(array, pa.lib.Array)
@h.given(past.arrays(past.primitive_types, nullable=False))
def test_array_nullability(array):
assert array.null_count == 0
@h.given(past.all_chunked_arrays)
def test_chunked_arrays(chunked_array):
assert isinstance(chunked_array, pa.lib.ChunkedArray)
@h.given(past.all_record_batches)
def test_record_batches(record_bath):
assert isinstance(record_bath, pa.lib.RecordBatch)
@h.given(past.all_tables)
def test_tables(table):
assert isinstance(table, pa.lib.Table)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,216 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import sys
import pytest
import weakref
import numpy as np
import pyarrow as pa
tensor_type_pairs = [
('i1', pa.int8()),
('i2', pa.int16()),
('i4', pa.int32()),
('i8', pa.int64()),
('u1', pa.uint8()),
('u2', pa.uint16()),
('u4', pa.uint32()),
('u8', pa.uint64()),
('f2', pa.float16()),
('f4', pa.float32()),
('f8', pa.float64())
]
def test_tensor_attrs():
data = np.random.randn(10, 4)
tensor = pa.Tensor.from_numpy(data)
assert tensor.ndim == 2
assert tensor.dim_names == []
assert tensor.size == 40
assert tensor.shape == data.shape
assert tensor.strides == data.strides
assert tensor.is_contiguous
assert tensor.is_mutable
# not writeable
data2 = data.copy()
data2.flags.writeable = False
tensor = pa.Tensor.from_numpy(data2)
assert not tensor.is_mutable
# With dim_names
tensor = pa.Tensor.from_numpy(data, dim_names=('x', 'y'))
assert tensor.ndim == 2
assert tensor.dim_names == ['x', 'y']
assert tensor.dim_name(0) == 'x'
assert tensor.dim_name(1) == 'y'
wr = weakref.ref(tensor)
assert wr() is not None
del tensor
assert wr() is None
def test_tensor_base_object():
tensor = pa.Tensor.from_numpy(np.random.randn(10, 4))
n = sys.getrefcount(tensor)
array = tensor.to_numpy() # noqa
assert sys.getrefcount(tensor) == n + 1
@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs)
def test_tensor_numpy_roundtrip(dtype_str, arrow_type):
dtype = np.dtype(dtype_str)
data = (100 * np.random.randn(10, 4)).astype(dtype)
tensor = pa.Tensor.from_numpy(data)
assert tensor.type == arrow_type
repr(tensor)
result = tensor.to_numpy()
assert (data == result).all()
def test_tensor_ipc_roundtrip(tmpdir):
data = np.random.randn(10, 4)
tensor = pa.Tensor.from_numpy(data)
path = os.path.join(str(tmpdir), 'pyarrow-tensor-ipc-roundtrip')
mmap = pa.create_memory_map(path, 1024)
pa.ipc.write_tensor(tensor, mmap)
mmap.seek(0)
result = pa.ipc.read_tensor(mmap)
assert result.equals(tensor)
@pytest.mark.gzip
def test_tensor_ipc_read_from_compressed(tempdir):
# ARROW-5910
data = np.random.randn(10, 4)
tensor = pa.Tensor.from_numpy(data)
path = tempdir / 'tensor-compressed-file'
out_stream = pa.output_stream(path, compression='gzip')
pa.ipc.write_tensor(tensor, out_stream)
out_stream.close()
result = pa.ipc.read_tensor(pa.input_stream(path, compression='gzip'))
assert result.equals(tensor)
def test_tensor_ipc_strided(tmpdir):
data1 = np.random.randn(10, 4)
tensor1 = pa.Tensor.from_numpy(data1[::2])
data2 = np.random.randn(10, 6, 4)
tensor2 = pa.Tensor.from_numpy(data2[::, ::2, ::])
path = os.path.join(str(tmpdir), 'pyarrow-tensor-ipc-strided')
mmap = pa.create_memory_map(path, 2048)
for tensor in [tensor1, tensor2]:
mmap.seek(0)
pa.ipc.write_tensor(tensor, mmap)
mmap.seek(0)
result = pa.ipc.read_tensor(mmap)
assert result.equals(tensor)
def test_tensor_equals():
def eq(a, b):
assert a.equals(b)
assert a == b
assert not (a != b)
def ne(a, b):
assert not a.equals(b)
assert not (a == b)
assert a != b
data = np.random.randn(10, 6, 4)[::, ::2, ::]
tensor1 = pa.Tensor.from_numpy(data)
tensor2 = pa.Tensor.from_numpy(np.ascontiguousarray(data))
eq(tensor1, tensor2)
data = data.copy()
data[9, 0, 0] = 1.0
tensor2 = pa.Tensor.from_numpy(np.ascontiguousarray(data))
ne(tensor1, tensor2)
def test_tensor_hashing():
# Tensors are unhashable
with pytest.raises(TypeError, match="unhashable"):
hash(pa.Tensor.from_numpy(np.arange(10)))
def test_tensor_size():
data = np.random.randn(10, 4)
tensor = pa.Tensor.from_numpy(data)
assert pa.ipc.get_tensor_size(tensor) > (data.size * 8)
def test_read_tensor(tmpdir):
# Create and write tensor tensor
data = np.random.randn(10, 4)
tensor = pa.Tensor.from_numpy(data)
data_size = pa.ipc.get_tensor_size(tensor)
path = os.path.join(str(tmpdir), 'pyarrow-tensor-ipc-read-tensor')
write_mmap = pa.create_memory_map(path, data_size)
pa.ipc.write_tensor(tensor, write_mmap)
# Try to read tensor
read_mmap = pa.memory_map(path, mode='r')
array = pa.ipc.read_tensor(read_mmap).to_numpy()
np.testing.assert_equal(data, array)
def test_tensor_memoryview():
# Tensors support the PEP 3118 buffer protocol
for dtype, expected_format in [(np.int8, '=b'),
(np.int64, '=q'),
(np.uint64, '=Q'),
(np.float16, 'e'),
(np.float64, 'd'),
]:
data = np.arange(10, dtype=dtype)
dtype = data.dtype
lst = data.tolist()
tensor = pa.Tensor.from_numpy(data)
m = memoryview(tensor)
assert m.format == expected_format
assert m.shape == data.shape
assert m.strides == data.strides
assert m.ndim == 1
assert m.nbytes == data.nbytes
assert m.itemsize == data.itemsize
assert m.itemsize * 8 == tensor.type.bit_width
assert np.frombuffer(m, dtype).tolist() == lst
del tensor, data
assert np.frombuffer(m, dtype).tolist() == lst
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,52 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import gc
import signal
import sys
import weakref
import pytest
from pyarrow import util
from pyarrow.tests.util import disabled_gc
def exhibit_signal_refcycle():
# Put an object in the frame locals and return a weakref to it.
# If `signal.getsignal` has a bug where it creates a reference cycle
# keeping alive the current execution frames, `obj` will not be
# destroyed immediately when this function returns.
obj = set()
signal.getsignal(signal.SIGINT)
return weakref.ref(obj)
def test_signal_refcycle():
# Test possible workaround for https://bugs.python.org/issue42248
with disabled_gc():
wr = exhibit_signal_refcycle()
if wr() is None:
pytest.skip(
"Python version does not have the bug we're testing for")
gc.collect()
with disabled_gc():
wr = exhibit_signal_refcycle()
assert wr() is not None
util._break_traceback_cycle_from_frame(sys._getframe(0))
assert wr() is None
@@ -0,0 +1,449 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Utility functions for testing
"""
import contextlib
import decimal
import gc
import numpy as np
import os
import random
import re
import shutil
import signal
import socket
import string
import subprocess
import sys
import time
import pytest
import pyarrow as pa
import pyarrow.fs
def randsign():
"""Randomly choose either 1 or -1.
Returns
-------
sign : int
"""
return random.choice((-1, 1))
@contextlib.contextmanager
def random_seed(seed):
"""Set the random seed inside of a context manager.
Parameters
----------
seed : int
The seed to set
Notes
-----
This function is useful when you want to set a random seed but not affect
the random state of other functions using the random module.
"""
original_state = random.getstate()
random.seed(seed)
try:
yield
finally:
random.setstate(original_state)
def randdecimal(precision, scale):
"""Generate a random decimal value with specified precision and scale.
Parameters
----------
precision : int
The maximum number of digits to generate. Must be an integer between 1
and 38 inclusive.
scale : int
The maximum number of digits following the decimal point. Must be an
integer greater than or equal to 0.
Returns
-------
decimal_value : decimal.Decimal
A random decimal.Decimal object with the specified precision and scale.
"""
assert 1 <= precision <= 38, 'precision must be between 1 and 38 inclusive'
if scale < 0:
raise ValueError(
'randdecimal does not yet support generating decimals with '
'negative scale'
)
max_whole_value = 10 ** (precision - scale) - 1
whole = random.randint(-max_whole_value, max_whole_value)
if not scale:
return decimal.Decimal(whole)
max_fractional_value = 10 ** scale - 1
fractional = random.randint(0, max_fractional_value)
return decimal.Decimal(
'{}.{}'.format(whole, str(fractional).rjust(scale, '0'))
)
def random_ascii(length):
return bytes(np.random.randint(65, 123, size=length, dtype='i1'))
def rands(nchars):
"""
Generate one random string.
"""
RANDS_CHARS = np.array(
list(string.ascii_letters + string.digits), dtype=(np.str_, 1))
return "".join(np.random.choice(RANDS_CHARS, nchars))
def make_dataframe():
import pandas as pd
N = 30
df = pd.DataFrame(
{col: np.random.randn(N) for col in string.ascii_uppercase[:4]},
index=pd.Index([rands(10) for _ in range(N)])
)
return df
def memory_leak_check(f, metric='rss', threshold=1 << 17, iterations=10,
check_interval=1):
"""
Execute the function and try to detect a clear memory leak either internal
to Arrow or caused by a reference counting problem in the Python binding
implementation. Raises exception if a leak detected
Parameters
----------
f : callable
Function to invoke on each iteration
metric : {'rss', 'vms', 'shared'}, default 'rss'
Attribute of psutil.Process.memory_info to use for determining current
memory use
threshold : int, default 128K
Threshold in number of bytes to consider a leak
iterations : int, default 10
Total number of invocations of f
check_interval : int, default 1
Number of invocations of f in between each memory use check
"""
import psutil
proc = psutil.Process()
def _get_use():
gc.collect()
return getattr(proc.memory_info(), metric)
baseline_use = _get_use()
def _leak_check():
current_use = _get_use()
if current_use - baseline_use > threshold:
raise Exception("Memory leak detected. "
"Departure from baseline {} after {} iterations"
.format(current_use - baseline_use, i))
for i in range(iterations):
f()
if i % check_interval == 0:
_leak_check()
def get_modified_env_with_pythonpath():
# Prepend pyarrow root directory to PYTHONPATH
env = os.environ.copy()
existing_pythonpath = env.get('PYTHONPATH', '')
module_path = os.path.abspath(
os.path.dirname(os.path.dirname(pa.__file__)))
if existing_pythonpath:
new_pythonpath = os.pathsep.join((module_path, existing_pythonpath))
else:
new_pythonpath = module_path
env['PYTHONPATH'] = new_pythonpath
return env
def invoke_script(script_name, *args):
subprocess_env = get_modified_env_with_pythonpath()
dir_path = os.path.dirname(os.path.realpath(__file__))
python_file = os.path.join(dir_path, script_name)
cmd = [sys.executable, python_file]
cmd.extend(args)
subprocess.check_call(cmd, env=subprocess_env)
@contextlib.contextmanager
def changed_environ(name, value):
"""
Temporarily set environment variable *name* to *value*.
"""
orig_value = os.environ.get(name)
os.environ[name] = value
try:
yield
finally:
if orig_value is None:
del os.environ[name]
else:
os.environ[name] = orig_value
@contextlib.contextmanager
def change_cwd(path):
curdir = os.getcwd()
os.chdir(str(path))
try:
yield
finally:
os.chdir(curdir)
@contextlib.contextmanager
def disabled_gc():
gc.disable()
try:
yield
finally:
gc.enable()
def _filesystem_uri(path):
# URIs on Windows must follow 'file:///C:...' or 'file:/C:...' patterns.
if os.name == 'nt':
uri = 'file:///{}'.format(path)
else:
uri = 'file://{}'.format(path)
return uri
class FSProtocolClass:
def __init__(self, path):
self._path = path
def __fspath__(self):
return str(self._path)
class ProxyHandler(pyarrow.fs.FileSystemHandler):
"""
A dataset handler that proxies to an underlying filesystem. Useful
to partially wrap an existing filesystem with partial changes.
"""
def __init__(self, fs):
self._fs = fs
def __eq__(self, other):
if isinstance(other, ProxyHandler):
return self._fs == other._fs
return NotImplemented
def __ne__(self, other):
if isinstance(other, ProxyHandler):
return self._fs != other._fs
return NotImplemented
def get_type_name(self):
return "proxy::" + self._fs.type_name
def normalize_path(self, path):
return self._fs.normalize_path(path)
def get_file_info(self, paths):
return self._fs.get_file_info(paths)
def get_file_info_selector(self, selector):
return self._fs.get_file_info(selector)
def create_dir(self, path, recursive):
return self._fs.create_dir(path, recursive=recursive)
def delete_dir(self, path):
return self._fs.delete_dir(path)
def delete_dir_contents(self, path, missing_dir_ok):
return self._fs.delete_dir_contents(path,
missing_dir_ok=missing_dir_ok)
def delete_root_dir_contents(self):
return self._fs.delete_dir_contents("", accept_root_dir=True)
def delete_file(self, path):
return self._fs.delete_file(path)
def move(self, src, dest):
return self._fs.move(src, dest)
def copy_file(self, src, dest):
return self._fs.copy_file(src, dest)
def open_input_stream(self, path):
return self._fs.open_input_stream(path)
def open_input_file(self, path):
return self._fs.open_input_file(path)
def open_output_stream(self, path, metadata):
return self._fs.open_output_stream(path, metadata=metadata)
def open_append_stream(self, path, metadata):
return self._fs.open_append_stream(path, metadata=metadata)
def get_raise_signal():
if sys.version_info >= (3, 8):
return signal.raise_signal
elif os.name == 'nt':
# On Windows, os.kill() doesn't actually send a signal,
# it just terminates the process with the given exit code.
pytest.skip("test requires Python 3.8+ on Windows")
else:
# On Unix, emulate raise_signal() with os.kill().
def raise_signal(signum):
os.kill(os.getpid(), signum)
return raise_signal
@contextlib.contextmanager
def signal_wakeup_fd(*, warn_on_full_buffer=False):
# Use a socket pair, rather a self-pipe, so that select() can be used
# on Windows.
r, w = socket.socketpair()
old_fd = None
try:
r.setblocking(False)
w.setblocking(False)
old_fd = signal.set_wakeup_fd(
w.fileno(), warn_on_full_buffer=warn_on_full_buffer)
yield r
finally:
if old_fd is not None:
signal.set_wakeup_fd(old_fd)
r.close()
w.close()
def _ensure_minio_component_version(component, minimum_year):
full_args = [component, '--version']
proc = subprocess.Popen(full_args, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, encoding='utf-8')
if proc.wait(10) != 0:
return False
stdout = proc.stdout.read()
pattern = component + r' version RELEASE\.(\d+)-.*'
version_match = re.search(pattern, stdout)
if version_match:
version_year = version_match.group(1)
return int(version_year) >= minimum_year
else:
raise FileNotFoundError("minio component older than the minimum year")
def _wait_for_minio_startup(mcdir, address, access_key, secret_key):
start = time.time()
while time.time() - start < 10:
try:
_run_mc_command(mcdir, 'alias', 'set', 'myminio',
f'http://{address}', access_key, secret_key)
return
except ChildProcessError:
time.sleep(1)
raise Exception("mc command could not connect to local minio")
def _run_mc_command(mcdir, *args):
full_args = ['mc', '-C', mcdir] + list(args)
proc = subprocess.Popen(full_args, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, encoding='utf-8')
retval = proc.wait(10)
cmd_str = ' '.join(full_args)
print(f'Cmd: {cmd_str}')
print(f' Return: {retval}')
print(f' Stdout: {proc.stdout.read()}')
print(f' Stderr: {proc.stderr.read()}')
if retval != 0:
raise ChildProcessError("Could not run mc")
def _configure_s3_limited_user(s3_server, policy):
"""
Attempts to use the mc command to configure the minio server
with a special user limited:limited123 which does not have
permission to create buckets. This mirrors some real life S3
configurations where users are given strict permissions.
Arrow S3 operations should still work in such a configuration
(e.g. see ARROW-13685)
"""
if sys.platform == 'win32':
# Can't rely on FileNotFound check because
# there is sometimes an mc command on Windows
# which is unrelated to the minio mc
pytest.skip('The mc command is not installed on Windows')
try:
# ensuring version of mc and minio for the capabilities we need
_ensure_minio_component_version('mc', 2021)
_ensure_minio_component_version('minio', 2021)
tempdir = s3_server['tempdir']
host, port, access_key, secret_key = s3_server['connection']
address = '{}:{}'.format(host, port)
mcdir = os.path.join(tempdir, 'mc')
if os.path.exists(mcdir):
shutil.rmtree(mcdir)
os.mkdir(mcdir)
policy_path = os.path.join(tempdir, 'limited-buckets-policy.json')
with open(policy_path, mode='w') as policy_file:
policy_file.write(policy)
# The s3_server fixture starts the minio process but
# it takes a few moments for the process to become available
_wait_for_minio_startup(mcdir, address, access_key, secret_key)
# These commands create a limited user with a specific
# policy and creates a sample bucket for that user to
# write to
_run_mc_command(mcdir, 'admin', 'policy', 'add',
'myminio/', 'no-create-buckets', policy_path)
_run_mc_command(mcdir, 'admin', 'user', 'add',
'myminio/', 'limited', 'limited123')
_run_mc_command(mcdir, 'admin', 'policy', 'set',
'myminio', 'no-create-buckets', 'user=limited')
_run_mc_command(mcdir, 'mb', 'myminio/existing-bucket',
'--ignore-existing')
except FileNotFoundError:
pytest.skip("Configuring limited s3 user failed")