# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import collections
import warnings

import numpy as np

import pyarrow as pa
from pyarrow.lib import SerializationContext, py_buffer, builtin_pickle

try:
    import cloudpickle
except ImportError:
    cloudpickle = builtin_pickle


try:
    # This function is available from numpy-1.16.0 onward.
    # See also: https://github.com/numpy/numpy/blob/master/numpy/lib/format.py
    from numpy.lib.format import descr_to_dtype
except ImportError:
    def descr_to_dtype(descr):
        '''
        descr may be stored as dtype.descr, which is a list of (name, format,
        [shape]) tuples where format may be a str or a tuple. Offsets are not
        explicitly saved, rather empty fields with name, format == '', '|Vn'
        are added as padding. This function reverses the process, eliminating
        the empty padding fields.
        '''
        if isinstance(descr, str):
            # No padding removal needed
            return np.dtype(descr)
        elif isinstance(descr, tuple):
            # subtype, will always have a shape descr[1]
            dt = descr_to_dtype(descr[0])
            return np.dtype((dt, descr[1]))
        fields = []
        offset = 0
        for field in descr:
            if len(field) == 2:
                name, descr_str = field
                dt = descr_to_dtype(descr_str)
            else:
                name, descr_str, shape = field
                dt = np.dtype((descr_to_dtype(descr_str), shape))

            # Ignore padding bytes, which will be void bytes with '' as name
            # (once support for blank names is removed, only "if name == ''"
            # is needed)
            is_pad = (name == '' and dt.type is np.void and dt.names is None)
            if not is_pad:
                fields.append((name, dt, offset))

            offset += dt.itemsize

        names, formats, offsets = zip(*fields)
        # names may be (title, names) tuples
        nametups = (n if isinstance(n, tuple) else (None, n) for n in names)
        titles, names = zip(*nametups)
        return np.dtype({'names': names, 'formats': formats, 'titles': titles,
                         'offsets': offsets, 'itemsize': offset})


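# A quick round-trip sketch for the fallback above (illustrative only; the
# structured dtype chosen here is arbitrary):
#
#     dt = np.dtype([('x', '<i4'), ('y', '<f8')])
#     descr = np.lib.format.dtype_to_descr(dt)
#     assert descr_to_dtype(descr) == dt

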
def _deprecate_serialization(name):
    msg = (
        "'pyarrow.{}' is deprecated as of 2.0.0 and will be removed in a "
        "future version. Use pickle or the pyarrow IPC functionality instead."
    ).format(name)
    warnings.warn(msg, FutureWarning, stacklevel=3)


# ----------------------------------------------------------------------
# Set up serialization for numpy with dtype object (primitive types are
# handled efficiently with Arrow's Tensor facilities, see
# python_to_arrow.cc)

def _serialize_numpy_array_list(obj):
    if obj.dtype.str != '|O':
        # Make the array c_contiguous if necessary so that we can
        # change the view.
        if not obj.flags.c_contiguous:
            obj = np.ascontiguousarray(obj)
        return obj.view('uint8'), np.lib.format.dtype_to_descr(obj.dtype)
    else:
        return obj.tolist(), np.lib.format.dtype_to_descr(obj.dtype)


def _deserialize_numpy_array_list(data):
    if data[1] != '|O':
        assert data[0].dtype == np.uint8
        return data[0].view(descr_to_dtype(data[1]))
    else:
        return np.array(data[0], dtype=np.dtype(data[1]))


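# Round-trip sketch for the pair above (illustrative only; this exercises the
# non-object branch with a small contiguous array):
#
#     arr = np.arange(6, dtype='int32')
#     data = _serialize_numpy_array_list(arr)
#     restored = _deserialize_numpy_array_list(data)
#     assert (restored == arr).all()

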
def _serialize_numpy_matrix(obj):
    if obj.dtype.str != '|O':
        # Make the underlying ndarray (.A) c_contiguous if necessary so that
        # we can change the view. Take the ndarray view up front: calling .A
        # on the result of np.ascontiguousarray would fail, since that
        # returns a plain ndarray.
        arr = obj.A
        if not arr.flags.c_contiguous:
            arr = np.ascontiguousarray(arr)
        return arr.view('uint8'), np.lib.format.dtype_to_descr(obj.dtype)
    else:
        return obj.A.tolist(), np.lib.format.dtype_to_descr(obj.dtype)


def _deserialize_numpy_matrix(data):
    if data[1] != '|O':
        assert data[0].dtype == np.uint8
        return np.matrix(data[0].view(descr_to_dtype(data[1])),
                         copy=False)
    else:
        return np.matrix(data[0], dtype=np.dtype(data[1]), copy=False)


# ----------------------------------------------------------------------
# pyarrow.RecordBatch-specific serialization matters

def _serialize_pyarrow_recordbatch(batch):
    output_stream = pa.BufferOutputStream()
    with pa.RecordBatchStreamWriter(output_stream, schema=batch.schema) as wr:
        wr.write_batch(batch)
    return output_stream.getvalue()  # This will also close the stream.


def _deserialize_pyarrow_recordbatch(buf):
    with pa.RecordBatchStreamReader(buf) as reader:
        return reader.read_next_batch()


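# Round-trip sketch for the stream-based helpers above (illustrative only;
# the column data is arbitrary):
#
#     batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], ['col'])
#     buf = _serialize_pyarrow_recordbatch(batch)
#     assert _deserialize_pyarrow_recordbatch(buf).equals(batch)

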
# ----------------------------------------------------------------------
# pyarrow.Array-specific serialization matters

def _serialize_pyarrow_array(array):
    # TODO(suquark): implement more efficient array serialization.
    batch = pa.RecordBatch.from_arrays([array], [''])
    return _serialize_pyarrow_recordbatch(batch)


def _deserialize_pyarrow_array(buf):
    # TODO(suquark): implement more efficient array deserialization.
    batch = _deserialize_pyarrow_recordbatch(buf)
    return batch.columns[0]


# ----------------------------------------------------------------------
# pyarrow.Table-specific serialization matters

def _serialize_pyarrow_table(table):
    output_stream = pa.BufferOutputStream()
    with pa.RecordBatchStreamWriter(output_stream, schema=table.schema) as wr:
        wr.write_table(table)
    return output_stream.getvalue()  # This will also close the stream.


def _deserialize_pyarrow_table(buf):
    with pa.RecordBatchStreamReader(buf) as reader:
        return reader.read_all()


def _pickle_to_buffer(x):
    pickled = builtin_pickle.dumps(x, protocol=builtin_pickle.HIGHEST_PROTOCOL)
    return py_buffer(pickled)


def _load_pickle_from_buffer(data):
    as_memoryview = memoryview(data)
    return builtin_pickle.loads(as_memoryview)


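# Round-trip sketch for the pickle helpers above (illustrative only):
#
#     buf = _pickle_to_buffer({'key': [1, 2, 3]})
#     assert _load_pickle_from_buffer(buf) == {'key': [1, 2, 3]}

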
# ----------------------------------------------------------------------
# pandas-specific serialization matters

def _register_custom_pandas_handlers(context):
    # ARROW-1784, faster path for pandas-only visibility

    try:
        import pandas as pd
    except ImportError:
        return

    import pyarrow.pandas_compat as pdcompat

    sparse_type_error_msg = (
        '{0} serialization is not supported.\n'
        'Note that {0} is planned to be deprecated '
        'in future pandas releases.\n'
        'See https://github.com/pandas-dev/pandas/issues/19239 '
        'for more information.'
    )

    def _serialize_pandas_dataframe(obj):
        if (pdcompat._pandas_api.has_sparse and
                isinstance(obj, pd.SparseDataFrame)):
            raise NotImplementedError(
                sparse_type_error_msg.format('SparseDataFrame')
            )

        return pdcompat.dataframe_to_serialized_dict(obj)

    def _deserialize_pandas_dataframe(data):
        return pdcompat.serialized_dict_to_dataframe(data)

    def _serialize_pandas_series(obj):
        if (pdcompat._pandas_api.has_sparse and
                isinstance(obj, pd.SparseSeries)):
            raise NotImplementedError(
                sparse_type_error_msg.format('SparseSeries')
            )

        return _serialize_pandas_dataframe(pd.DataFrame({obj.name: obj}))

    def _deserialize_pandas_series(data):
        deserialized = _deserialize_pandas_dataframe(data)
        return deserialized[deserialized.columns[0]]

    context.register_type(
        pd.Series, 'pd.Series',
        custom_serializer=_serialize_pandas_series,
        custom_deserializer=_deserialize_pandas_series)

    context.register_type(
        pd.Index, 'pd.Index',
        custom_serializer=_pickle_to_buffer,
        custom_deserializer=_load_pickle_from_buffer)

    if hasattr(pd.core, 'arrays'):
        if hasattr(pd.core.arrays, 'interval'):
            context.register_type(
                pd.core.arrays.interval.IntervalArray,
                'pd.core.arrays.interval.IntervalArray',
                custom_serializer=_pickle_to_buffer,
                custom_deserializer=_load_pickle_from_buffer)

        if hasattr(pd.core.arrays, 'period'):
            context.register_type(
                pd.core.arrays.period.PeriodArray,
                'pd.core.arrays.period.PeriodArray',
                custom_serializer=_pickle_to_buffer,
                custom_deserializer=_load_pickle_from_buffer)

        if hasattr(pd.core.arrays, 'datetimes'):
            context.register_type(
                pd.core.arrays.datetimes.DatetimeArray,
                'pd.core.arrays.datetimes.DatetimeArray',
                custom_serializer=_pickle_to_buffer,
                custom_deserializer=_load_pickle_from_buffer)

    context.register_type(
        pd.DataFrame, 'pd.DataFrame',
        custom_serializer=_serialize_pandas_dataframe,
        custom_deserializer=_deserialize_pandas_dataframe)


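# The register_type pattern used above also works for user-defined classes.
# A minimal sketch (the Point class and the 'my.Point' tag are hypothetical):
#
#     class Point:
#         def __init__(self, x, y):
#             self.x, self.y = x, y
#
#     context = default_serialization_context()
#     context.register_type(
#         Point, 'my.Point',
#         custom_serializer=lambda p: (p.x, p.y),
#         custom_deserializer=lambda data: Point(*data))

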
def register_torch_serialization_handlers(serialization_context):
    # ----------------------------------------------------------------------
    # Set up serialization for pytorch tensors
    _deprecate_serialization("register_torch_serialization_handlers")

    try:
        import torch

        def _serialize_torch_tensor(obj):
            if obj.is_sparse:
                return pa.SparseCOOTensor.from_numpy(
                    obj._values().detach().numpy(),
                    obj._indices().detach().numpy().T,
                    shape=list(obj.shape))
            else:
                return obj.detach().numpy()

        def _deserialize_torch_tensor(data):
            if isinstance(data, pa.SparseCOOTensor):
                return torch.sparse_coo_tensor(
                    indices=data.to_numpy()[1].T,
                    values=data.to_numpy()[0][:, 0],
                    size=data.shape)
            else:
                return torch.from_numpy(data)

        for t in [torch.FloatTensor, torch.DoubleTensor, torch.HalfTensor,
                  torch.ByteTensor, torch.CharTensor, torch.ShortTensor,
                  torch.IntTensor, torch.LongTensor, torch.Tensor]:
            serialization_context.register_type(
                t, "torch." + t.__name__,
                custom_serializer=_serialize_torch_tensor,
                custom_deserializer=_deserialize_torch_tensor)
    except ImportError:
        # no torch
        pass


def _register_collections_serialization_handlers(serialization_context):
    def _serialize_deque(obj):
        return list(obj)

    def _deserialize_deque(data):
        return collections.deque(data)

    serialization_context.register_type(
        collections.deque, "collections.deque",
        custom_serializer=_serialize_deque,
        custom_deserializer=_deserialize_deque)

    def _serialize_ordered_dict(obj):
        return list(obj.keys()), list(obj.values())

    def _deserialize_ordered_dict(data):
        return collections.OrderedDict(zip(data[0], data[1]))

    serialization_context.register_type(
        collections.OrderedDict, "collections.OrderedDict",
        custom_serializer=_serialize_ordered_dict,
        custom_deserializer=_deserialize_ordered_dict)

    def _serialize_default_dict(obj):
        return list(obj.keys()), list(obj.values()), obj.default_factory

    def _deserialize_default_dict(data):
        return collections.defaultdict(data[2], zip(data[0], data[1]))

    serialization_context.register_type(
        collections.defaultdict, "collections.defaultdict",
        custom_serializer=_serialize_default_dict,
        custom_deserializer=_deserialize_default_dict)

    def _serialize_counter(obj):
        return list(obj.keys()), list(obj.values())

    def _deserialize_counter(data):
        return collections.Counter(dict(zip(data[0], data[1])))

    serialization_context.register_type(
        collections.Counter, "collections.Counter",
        custom_serializer=_serialize_counter,
        custom_deserializer=_deserialize_counter)


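# Sketch of how the defaultdict pair above round-trips (illustrative only):
#
#     d = collections.defaultdict(list, {'a': [1], 'b': [2]})
#     keys, values, factory = list(d.keys()), list(d.values()), d.default_factory
#     restored = collections.defaultdict(factory, zip(keys, values))
#     assert restored == d and restored.default_factory is list

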
# ----------------------------------------------------------------------
# Set up serialization for scipy sparse matrices. Primitive types are handled
# efficiently with Arrow's SparseTensor facilities, see numpy_convert.cc

def _register_scipy_handlers(serialization_context):
    try:
        from scipy.sparse import (csr_matrix, csc_matrix, coo_matrix,
                                  isspmatrix_coo, isspmatrix_csr,
                                  isspmatrix_csc, isspmatrix)

        def _serialize_scipy_sparse(obj):
            if isspmatrix_coo(obj):
                return 'coo', pa.SparseCOOTensor.from_scipy(obj)

            elif isspmatrix_csr(obj):
                return 'csr', pa.SparseCSRMatrix.from_scipy(obj)

            elif isspmatrix_csc(obj):
                return 'csc', pa.SparseCSCMatrix.from_scipy(obj)

            elif isspmatrix(obj):
                # Other sparse formats go through COO (scipy spells the
                # conversion "tocoo", and the tag should match the format).
                return 'coo', pa.SparseCOOTensor.from_scipy(obj.tocoo())

            else:
                raise NotImplementedError(
                    "Serialization of {} is not supported.".format(type(obj)))

        def _deserialize_scipy_sparse(data):
            # Every supported tag ('coo', 'csr', 'csc') deserializes the same
            # way: convert the Arrow sparse tensor back to its scipy form.
            return data[1].to_scipy()

        serialization_context.register_type(
            coo_matrix, 'scipy.sparse.coo.coo_matrix',
            custom_serializer=_serialize_scipy_sparse,
            custom_deserializer=_deserialize_scipy_sparse)

        serialization_context.register_type(
            csr_matrix, 'scipy.sparse.csr.csr_matrix',
            custom_serializer=_serialize_scipy_sparse,
            custom_deserializer=_deserialize_scipy_sparse)

        serialization_context.register_type(
            csc_matrix, 'scipy.sparse.csc.csc_matrix',
            custom_serializer=_serialize_scipy_sparse,
            custom_deserializer=_deserialize_scipy_sparse)

    except ImportError:
        # no scipy
        pass


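# Round-trip sketch for the scipy handlers above (illustrative only; requires
# scipy and a pyarrow build with sparse tensor support):
#
#     from scipy.sparse import coo_matrix
#     m = coo_matrix(([1.0, 2.0], ([0, 1], [1, 0])), shape=(2, 2))
#     tensor = pa.SparseCOOTensor.from_scipy(m)
#     restored = tensor.to_scipy()
#     assert (restored != m).nnz == 0

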
# ----------------------------------------------------------------------
# Set up serialization for pydata/sparse tensors.

def _register_pydata_sparse_handlers(serialization_context):
    try:
        import sparse

        def _serialize_pydata_sparse(obj):
            if isinstance(obj, sparse.COO):
                return 'coo', pa.SparseCOOTensor.from_pydata_sparse(obj)
            else:
                raise NotImplementedError(
                    "Serialization of {} is not supported.".format(type(obj)))

        def _deserialize_pydata_sparse(data):
            if data[0] == 'coo':
                data_array, coords = data[1].to_numpy()
                return sparse.COO(
                    data=data_array[:, 0],
                    coords=coords.T, shape=data[1].shape)

        serialization_context.register_type(
            sparse.COO, 'sparse.COO',
            custom_serializer=_serialize_pydata_sparse,
            custom_deserializer=_deserialize_pydata_sparse)

    except ImportError:
        # no pydata/sparse
        pass


def _register_default_serialization_handlers(serialization_context):

    # ----------------------------------------------------------------------
    # Set up serialization for primitive datatypes

    # TODO(pcm): This is currently a workaround until arrow supports
    # arbitrary precision integers. This is only called on long integers,
    # see the associated case in the append method in python_to_arrow.cc
    serialization_context.register_type(
        int, "int",
        custom_serializer=lambda obj: str(obj),
        custom_deserializer=lambda data: int(data))

    serialization_context.register_type(
        type(lambda: 0), "function",
        pickle=True)

    serialization_context.register_type(type, "type", pickle=True)

    serialization_context.register_type(
        np.matrix, 'np.matrix',
        custom_serializer=_serialize_numpy_matrix,
        custom_deserializer=_deserialize_numpy_matrix)

    serialization_context.register_type(
        np.ndarray, 'np.array',
        custom_serializer=_serialize_numpy_array_list,
        custom_deserializer=_deserialize_numpy_array_list)

    serialization_context.register_type(
        pa.Array, 'pyarrow.Array',
        custom_serializer=_serialize_pyarrow_array,
        custom_deserializer=_deserialize_pyarrow_array)

    serialization_context.register_type(
        pa.RecordBatch, 'pyarrow.RecordBatch',
        custom_serializer=_serialize_pyarrow_recordbatch,
        custom_deserializer=_deserialize_pyarrow_recordbatch)

    serialization_context.register_type(
        pa.Table, 'pyarrow.Table',
        custom_serializer=_serialize_pyarrow_table,
        custom_deserializer=_deserialize_pyarrow_table)

    _register_collections_serialization_handlers(serialization_context)
    _register_custom_pandas_handlers(serialization_context)
    _register_scipy_handlers(serialization_context)
    _register_pydata_sparse_handlers(serialization_context)


def register_default_serialization_handlers(serialization_context):
    _deprecate_serialization("register_default_serialization_handlers")
    _register_default_serialization_handlers(serialization_context)


def default_serialization_context():
    _deprecate_serialization("default_serialization_context")
    context = SerializationContext()
    _register_default_serialization_handlers(context)
    return context
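

# End-to-end usage sketch for this deprecated serialization API (illustrative
# only; pa.serialize/pa.deserialize were removed in later pyarrow releases,
# so this assumes a version where they still exist):
#
#     context = default_serialization_context()
#     value = {'arr': np.arange(3), 'tag': collections.Counter('aab')}
#     buf = pa.serialize(value, context=context).to_buffer()
#     restored = pa.deserialize(buf, context=context)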