# mirror of https://github.com/aykhans/AzSuicideDataVisualization.git
# synced 2025-04-22 10:28:02 +00:00
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

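"""Tests for the pyarrow.jvm bridge.

These tests start a JVM via jpype and check that Arrow Java buffers, fields,
schemas, arrays and record batches convert into their pyarrow counterparts.
"""
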
import json
import os
import sys
import xml.etree.ElementTree as ET

import pytest

import pyarrow as pa
import pyarrow.jvm as pa_jvm


jpype = pytest.importorskip("jpype")


@pytest.fixture(scope="session")
def root_allocator():
    # This test requires Arrow Java to be built in the same source tree
    try:
        arrow_dir = os.environ["ARROW_SOURCE_DIR"]
    except KeyError:
        arrow_dir = os.path.join(os.path.dirname(__file__), '..', '..', '..')
    pom_path = os.path.join(arrow_dir, 'java', 'pom.xml')
    tree = ET.parse(pom_path)
    version = tree.getroot().find(
        'POM:version',
        namespaces={
            'POM': 'http://maven.apache.org/POM/4.0.0'
        }).text
    jar_path = os.path.join(
        arrow_dir, 'java', 'tools', 'target',
        'arrow-tools-{}-jar-with-dependencies.jar'.format(version))
    jar_path = os.getenv("ARROW_TOOLS_JAR", jar_path)
    kwargs = {}
    # This will be the default behaviour in jpype 0.8+
    kwargs['convertStrings'] = False
    jpype.startJVM(jpype.getDefaultJVMPath(), "-Djava.class.path=" + jar_path,
                   **kwargs)
    return jpype.JPackage("org").apache.arrow.memory.RootAllocator(sys.maxsize)

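
# A typical way to point the fixture above at a build (a sketch; the actual
# paths depend on the local checkout and the Maven build):
#
#     ARROW_SOURCE_DIR=/path/to/arrow pytest test_jvm.py
#
# or, with a prebuilt jar:
#
#     ARROW_TOOLS_JAR=/path/to/arrow-tools-jar-with-dependencies.jar \
#         pytest test_jvm.py

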
def test_jvm_buffer(root_allocator):
    # Create a Java buffer
    jvm_buffer = root_allocator.buffer(8)
    for i in range(8):
        jvm_buffer.setByte(i, 8 - i)

    orig_refcnt = jvm_buffer.refCnt()

    # Convert to Python
    buf = pa_jvm.jvm_buffer(jvm_buffer)

    # Check its content
    assert buf.to_pybytes() == b'\x08\x07\x06\x05\x04\x03\x02\x01'

    # Check Java buffer lifetime is tied to PyArrow buffer lifetime
    assert jvm_buffer.refCnt() == orig_refcnt + 1
    del buf
    assert jvm_buffer.refCnt() == orig_refcnt


def test_jvm_buffer_released(root_allocator):
    import jpype.imports  # noqa
    from java.lang import IllegalArgumentException

    jvm_buffer = root_allocator.buffer(8)
    jvm_buffer.release()

    with pytest.raises(IllegalArgumentException):
        pa_jvm.jvm_buffer(jvm_buffer)


def _jvm_field(jvm_spec):
    """Deserialize a JSON field specification into a Java Field object."""
    om = jpype.JClass('com.fasterxml.jackson.databind.ObjectMapper')()
    pojo_Field = jpype.JClass('org.apache.arrow.vector.types.pojo.Field')
    return om.readValue(jvm_spec, pojo_Field)


def _jvm_schema(jvm_spec, metadata=None):
    """Wrap a single JSON field specification in a Java Schema object."""
    field = _jvm_field(jvm_spec)
    schema_cls = jpype.JClass('org.apache.arrow.vector.types.pojo.Schema')
    fields = jpype.JClass('java.util.ArrayList')()
    fields.add(field)
    if metadata:
        dct = jpype.JClass('java.util.HashMap')()
        for k, v in metadata.items():
            dct.put(k, v)
        return schema_cls(fields, dct)
    else:
        return schema_cls(fields)


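# A usage sketch for the helpers above (the spec string follows the format of
# the parametrized cases below; the concrete values are illustrative):
#
#     jvm_field = _jvm_field('{"name":"field_name","nullable":true,'
#                            '"type":{"name":"bool"},"children":[]}')
#     pa_jvm.field(jvm_field)   # -> pa.field('field_name', pa.bool_())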
# In the following, we use the JSON serialization of the Field objects in
# Java. This ensures that we don't rely on the exact mechanics of how they
# are constructed in Java code, and it lets us define them as test
# parameters without invoking the JVM.
#
# The specifications were created using:
#
# om = jpype.JClass('com.fasterxml.jackson.databind.ObjectMapper')()
# field = …  # Code to instantiate the field
# jvm_spec = om.writeValueAsString(field)
@pytest.mark.parametrize('pa_type,jvm_spec', [
    (pa.null(), '{"name":"null"}'),
    (pa.bool_(), '{"name":"bool"}'),
    (pa.int8(), '{"name":"int","bitWidth":8,"isSigned":true}'),
    (pa.int16(), '{"name":"int","bitWidth":16,"isSigned":true}'),
    (pa.int32(), '{"name":"int","bitWidth":32,"isSigned":true}'),
    (pa.int64(), '{"name":"int","bitWidth":64,"isSigned":true}'),
    (pa.uint8(), '{"name":"int","bitWidth":8,"isSigned":false}'),
    (pa.uint16(), '{"name":"int","bitWidth":16,"isSigned":false}'),
    (pa.uint32(), '{"name":"int","bitWidth":32,"isSigned":false}'),
    (pa.uint64(), '{"name":"int","bitWidth":64,"isSigned":false}'),
    (pa.float16(), '{"name":"floatingpoint","precision":"HALF"}'),
    (pa.float32(), '{"name":"floatingpoint","precision":"SINGLE"}'),
    (pa.float64(), '{"name":"floatingpoint","precision":"DOUBLE"}'),
    (pa.time32('s'), '{"name":"time","unit":"SECOND","bitWidth":32}'),
    (pa.time32('ms'), '{"name":"time","unit":"MILLISECOND","bitWidth":32}'),
    (pa.time64('us'), '{"name":"time","unit":"MICROSECOND","bitWidth":64}'),
    (pa.time64('ns'), '{"name":"time","unit":"NANOSECOND","bitWidth":64}'),
    (pa.timestamp('s'), '{"name":"timestamp","unit":"SECOND",'
                        '"timezone":null}'),
    (pa.timestamp('ms'), '{"name":"timestamp","unit":"MILLISECOND",'
                         '"timezone":null}'),
    (pa.timestamp('us'), '{"name":"timestamp","unit":"MICROSECOND",'
                         '"timezone":null}'),
    (pa.timestamp('ns'), '{"name":"timestamp","unit":"NANOSECOND",'
                         '"timezone":null}'),
    (pa.timestamp('ns', tz='UTC'), '{"name":"timestamp","unit":"NANOSECOND"'
                                   ',"timezone":"UTC"}'),
    (pa.timestamp('ns', tz='Europe/Paris'), '{"name":"timestamp",'
        '"unit":"NANOSECOND","timezone":"Europe/Paris"}'),
    (pa.date32(), '{"name":"date","unit":"DAY"}'),
    (pa.date64(), '{"name":"date","unit":"MILLISECOND"}'),
    (pa.decimal128(19, 4), '{"name":"decimal","precision":19,"scale":4}'),
    (pa.string(), '{"name":"utf8"}'),
    (pa.binary(), '{"name":"binary"}'),
    (pa.binary(10), '{"name":"fixedsizebinary","byteWidth":10}'),
    # TODO(ARROW-2609): complex types that have children
    # pa.list_(pa.int32()),
    # pa.struct([pa.field('a', pa.int32()),
    #            pa.field('b', pa.int8()),
    #            pa.field('c', pa.string())]),
    # pa.union([pa.field('a', pa.binary(10)),
    #           pa.field('b', pa.string())], mode=pa.lib.UnionMode_DENSE),
    # pa.union([pa.field('a', pa.binary(10)),
    #           pa.field('b', pa.string())], mode=pa.lib.UnionMode_SPARSE),
    # TODO: DictionaryType requires a vector in the type
    # pa.dictionary(pa.int32(), pa.array(['a', 'b', 'c'])),
])
@pytest.mark.parametrize('nullable', [True, False])
def test_jvm_types(root_allocator, pa_type, jvm_spec, nullable):
    if pa_type == pa.null() and not nullable:
        return
    spec = {
        'name': 'field_name',
        'nullable': nullable,
        'type': json.loads(jvm_spec),
        # TODO: This needs to be set for complex types
        'children': []
    }
    jvm_field = _jvm_field(json.dumps(spec))
    result = pa_jvm.field(jvm_field)
    expected_field = pa.field('field_name', pa_type, nullable=nullable)
    assert result == expected_field

    jvm_schema = _jvm_schema(json.dumps(spec))
    result = pa_jvm.schema(jvm_schema)
    assert result == pa.schema([expected_field])

    # Schema with custom metadata
    jvm_schema = _jvm_schema(json.dumps(spec), {'meta': 'data'})
    result = pa_jvm.schema(jvm_schema)
    assert result == pa.schema([expected_field], {'meta': 'data'})

    # Schema with custom field metadata
    spec['metadata'] = [{'key': 'field meta', 'value': 'field data'}]
    jvm_schema = _jvm_schema(json.dumps(spec))
    result = pa_jvm.schema(jvm_schema)
    expected_field = expected_field.with_metadata(
        {'field meta': 'field data'})
    assert result == pa.schema([expected_field])


# These test parameters mostly use an integer range as an input as this is
# often the only type that is understood by both Python and Java
# implementations of Arrow.
@pytest.mark.parametrize('pa_type,py_data,jvm_type', [
    (pa.bool_(), [True, False, True, True], 'BitVector'),
    (pa.uint8(), list(range(128)), 'UInt1Vector'),
    (pa.uint16(), list(range(128)), 'UInt2Vector'),
    (pa.int32(), list(range(128)), 'IntVector'),
    (pa.int64(), list(range(128)), 'BigIntVector'),
    (pa.float32(), list(range(128)), 'Float4Vector'),
    (pa.float64(), list(range(128)), 'Float8Vector'),
    (pa.timestamp('s'), list(range(128)), 'TimeStampSecVector'),
    (pa.timestamp('ms'), list(range(128)), 'TimeStampMilliVector'),
    (pa.timestamp('us'), list(range(128)), 'TimeStampMicroVector'),
    (pa.timestamp('ns'), list(range(128)), 'TimeStampNanoVector'),
    # TODO(ARROW-2605): These types miss a conversion from pure Python objects
    # * pa.time32('s')
    # * pa.time32('ms')
    # * pa.time64('us')
    # * pa.time64('ns')
    (pa.date32(), list(range(128)), 'DateDayVector'),
    (pa.date64(), list(range(128)), 'DateMilliVector'),
    # TODO(ARROW-2606): pa.decimal128(19, 4)
])
def test_jvm_array(root_allocator, pa_type, py_data, jvm_type):
    # Create vector
    cls = "org.apache.arrow.vector.{}".format(jvm_type)
    jvm_vector = jpype.JClass(cls)("vector", root_allocator)
    jvm_vector.allocateNew(len(py_data))
    for i, val in enumerate(py_data):
        # char and int are ambiguous overloads for these two setSafe calls
        if jvm_type in {'UInt1Vector', 'UInt2Vector'}:
            val = jpype.JInt(val)
        jvm_vector.setSafe(i, val)
    jvm_vector.setValueCount(len(py_data))

    py_array = pa.array(py_data, type=pa_type)
    jvm_array = pa_jvm.array(jvm_vector)

    assert py_array.equals(jvm_array)


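# For orientation: pa_jvm.array() reconstructs the pyarrow array from the
# Java vector's underlying buffers (via pa.Array.from_buffers; cf. the xfail
# on the string test below), rather than copying element by element.

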
def test_jvm_array_empty(root_allocator):
    cls = "org.apache.arrow.vector.{}".format('IntVector')
    jvm_vector = jpype.JClass(cls)("vector", root_allocator)
    jvm_vector.allocateNew()
    jvm_array = pa_jvm.array(jvm_vector)
    assert len(jvm_array) == 0
    assert jvm_array.type == pa.int32()


# These test parameters mostly use an integer range as an input as this is
# often the only type that is understood by both Python and Java
# implementations of Arrow.
@pytest.mark.parametrize('pa_type,py_data,jvm_type,jvm_spec', [
    # TODO: null
    (pa.bool_(), [True, False, True, True], 'BitVector', '{"name":"bool"}'),
    (
        pa.uint8(),
        list(range(128)),
        'UInt1Vector',
        '{"name":"int","bitWidth":8,"isSigned":false}'
    ),
    (
        pa.uint16(),
        list(range(128)),
        'UInt2Vector',
        '{"name":"int","bitWidth":16,"isSigned":false}'
    ),
    (
        pa.uint32(),
        list(range(128)),
        'UInt4Vector',
        '{"name":"int","bitWidth":32,"isSigned":false}'
    ),
    (
        pa.uint64(),
        list(range(128)),
        'UInt8Vector',
        '{"name":"int","bitWidth":64,"isSigned":false}'
    ),
    (
        pa.int8(),
        list(range(128)),
        'TinyIntVector',
        '{"name":"int","bitWidth":8,"isSigned":true}'
    ),
    (
        pa.int16(),
        list(range(128)),
        'SmallIntVector',
        '{"name":"int","bitWidth":16,"isSigned":true}'
    ),
    (
        pa.int32(),
        list(range(128)),
        'IntVector',
        '{"name":"int","bitWidth":32,"isSigned":true}'
    ),
    (
        pa.int64(),
        list(range(128)),
        'BigIntVector',
        '{"name":"int","bitWidth":64,"isSigned":true}'
    ),
    # TODO: float16
    (
        pa.float32(),
        list(range(128)),
        'Float4Vector',
        '{"name":"floatingpoint","precision":"SINGLE"}'
    ),
    (
        pa.float64(),
        list(range(128)),
        'Float8Vector',
        '{"name":"floatingpoint","precision":"DOUBLE"}'
    ),
    (
        pa.timestamp('s'),
        list(range(128)),
        'TimeStampSecVector',
        '{"name":"timestamp","unit":"SECOND","timezone":null}'
    ),
    (
        pa.timestamp('ms'),
        list(range(128)),
        'TimeStampMilliVector',
        '{"name":"timestamp","unit":"MILLISECOND","timezone":null}'
    ),
    (
        pa.timestamp('us'),
        list(range(128)),
        'TimeStampMicroVector',
        '{"name":"timestamp","unit":"MICROSECOND","timezone":null}'
    ),
    (
        pa.timestamp('ns'),
        list(range(128)),
        'TimeStampNanoVector',
        '{"name":"timestamp","unit":"NANOSECOND","timezone":null}'
    ),
    # TODO(ARROW-2605): These types miss a conversion from pure Python objects
    # * pa.time32('s')
    # * pa.time32('ms')
    # * pa.time64('us')
    # * pa.time64('ns')
    (
        pa.date32(),
        list(range(128)),
        'DateDayVector',
        '{"name":"date","unit":"DAY"}'
    ),
    (
        pa.date64(),
        list(range(128)),
        'DateMilliVector',
        '{"name":"date","unit":"MILLISECOND"}'
    ),
    # TODO(ARROW-2606): pa.decimal128(19, 4)
])
def test_jvm_record_batch(root_allocator, pa_type, py_data, jvm_type,
                          jvm_spec):
    # Create vector
    cls = "org.apache.arrow.vector.{}".format(jvm_type)
    jvm_vector = jpype.JClass(cls)("vector", root_allocator)
    jvm_vector.allocateNew(len(py_data))
    for i, val in enumerate(py_data):
        # char and int are ambiguous overloads for these two setSafe calls
        if jvm_type in {'UInt1Vector', 'UInt2Vector'}:
            val = jpype.JInt(val)
        jvm_vector.setSafe(i, val)
    jvm_vector.setValueCount(len(py_data))

    # Create field
    spec = {
        'name': 'field_name',
        'nullable': False,
        'type': json.loads(jvm_spec),
        # TODO: This needs to be set for complex types
        'children': []
    }
    jvm_field = _jvm_field(json.dumps(spec))

    # Create VectorSchemaRoot
    jvm_fields = jpype.JClass('java.util.ArrayList')()
    jvm_fields.add(jvm_field)
    jvm_vectors = jpype.JClass('java.util.ArrayList')()
    jvm_vectors.add(jvm_vector)
    jvm_vsr = jpype.JClass('org.apache.arrow.vector.VectorSchemaRoot')
    jvm_vsr = jvm_vsr(jvm_fields, jvm_vectors, len(py_data))

    py_record_batch = pa.RecordBatch.from_arrays(
        [pa.array(py_data, type=pa_type)],
        ['col']
    )
    jvm_record_batch = pa_jvm.record_batch(jvm_vsr)

    assert py_record_batch.equals(jvm_record_batch)


def _string_to_varchar_holder(ra, string):
    """Build a Java NullableVarCharHolder for a Python string (or None)."""
    nvch_cls = "org.apache.arrow.vector.holders.NullableVarCharHolder"
    holder = jpype.JClass(nvch_cls)()
    if string is None:
        holder.isSet = 0
    else:
        holder.isSet = 1
        value = jpype.JClass("java.lang.String")(string)
        std_charsets = jpype.JClass("java.nio.charset.StandardCharsets")
        bytes_ = value.getBytes(std_charsets.UTF_8)
        holder.buffer = ra.buffer(len(bytes_))
        holder.buffer.setBytes(0, bytes_, 0, len(bytes_))
        holder.start = 0
        holder.end = len(bytes_)
    return holder


# TODO(ARROW-2607)
@pytest.mark.xfail(reason="from_buffers is so far only supported for "
                          "primitive arrays")
def test_jvm_string_array(root_allocator):
    data = ["string", None, "töst"]
    cls = "org.apache.arrow.vector.VarCharVector"
    jvm_vector = jpype.JClass(cls)("vector", root_allocator)
    jvm_vector.allocateNew()

    for i, string in enumerate(data):
        holder = _string_to_varchar_holder(root_allocator, string)
        jvm_vector.setSafe(i, holder)
        jvm_vector.setValueCount(i + 1)

    py_array = pa.array(data, type=pa.string())
    jvm_array = pa_jvm.array(jvm_vector)

    assert py_array.equals(jvm_array)