mirror of
https://github.com/aykhans/AzSuicideDataVisualization.git
synced 2025-04-21 18:23:35 +00:00
105 lines
3.4 KiB
Python
105 lines
3.4 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
|
|
def run_tensorflow_test_with_dtype(tf, plasma, plasma_store_name,
                                   client, use_gpu, dtype):
    """Round-trip two tensors of ``dtype`` through Plasma via the TF ops.

    A random tensor and a tensor of ones are written into the Plasma store
    with ``plasma.tf_plasma_op.tensor_to_plasma``.  They are then read back
    twice: once with ``plasma.tf_plasma_op.plasma_to_tensor`` inside a
    TensorFlow session, and once directly through the Python Plasma
    ``client``.  Both read paths must reproduce the original arrays.

    Args:
        tf: The imported ``tensorflow`` module (uses the TF1-style
            ``Session``/``ConfigProto`` API).
        plasma: The imported ``pyarrow.plasma`` module; its ``tf_plasma_op``
            attribute must already be populated.
        plasma_store_name: Socket name of the running Plasma store.
        client: Plasma client connected to that store.
        use_gpu: Place the Plasma ops on ``/gpu`` if true, else ``/cpu``.
        dtype: NumPy dtype used for both tensors.

    Raises:
        AssertionError: If any read-back path returns different data, or if
            the buffer returned by ``client.get`` is not 64-byte aligned.
    """
    force_device = '/gpu' if use_gpu else '/cpu'
    shape = (3, 244, 244)

    object_id = np.random.bytes(20)

    data = np.random.randn(*shape).astype(dtype)
    ones = np.ones(shape).astype(dtype)

    def _build_to_plasma():
        # Both tensors are written under the same object id.
        data_tensor = tf.constant(data)
        ones_tensor = tf.constant(ones)
        return plasma.tf_plasma_op.tensor_to_plasma(
            [data_tensor, ones_tensor],
            object_id,
            plasma_store_socket_name=plasma_store_name)

    def _build_from_plasma():
        return plasma.tf_plasma_op.plasma_to_tensor(
            object_id,
            dtype=tf.as_dtype(dtype),
            plasma_store_socket_name=plasma_store_name)

    with tf.device(force_device):
        to_plasma = _build_to_plasma()
        from_plasma = _build_from_plasma()

    z = from_plasma + 1

    # Use the session as a context manager so it is closed (and its
    # resources released) even if one of the runs raises.
    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True, log_device_placement=True)) as sess:
        sess.run(to_plasma)
        # NOTE(zongheng): currently it returns a flat 1D tensor.
        # So reshape manually.
        out = sess.run(from_plasma)
        sess.run(z)

    out0, out1 = (part.reshape(shape) for part in np.split(out, 2))

    assert np.array_equal(data, out0), "Data not equal!"
    assert np.array_equal(ones, out1), "Data not equal!"

    # Try getting the data from Python
    plasma_object_id = plasma.ObjectID(object_id)
    obj = client.get(plasma_object_id)

    # Deserialized Tensor should be 64-byte aligned.
    assert obj.ctypes.data % 64 == 0

    result0, result1 = (part.reshape(shape) for part in np.split(obj, 2))

    assert np.array_equal(data, result0), "Data not equal!"
    assert np.array_equal(ones, result1), "Data not equal!"
|
|
|
|
|
|
@pytest.mark.plasma
@pytest.mark.tensorflow
@pytest.mark.skip(reason='Until ARROW-4259 is resolved')
def test_plasma_tf_op(use_gpu=False):
    """Exercise the Plasma TensorFlow op for several numeric dtypes.

    Builds the Plasma TF op and skips the test if it is unavailable.  It
    then starts a Plasma store, runs ``run_tensorflow_test_with_dtype`` for
    each dtype, and finally checks that no object in the store is still
    referenced.

    Args:
        use_gpu: Forwarded to ``run_tensorflow_test_with_dtype`` to choose
            GPU vs CPU placement for the ops.
    """
    import pyarrow.plasma as plasma
    import tensorflow as tf

    plasma.build_plasma_tensorflow_op()

    if plasma.tf_plasma_op is None:
        pytest.skip("TensorFlow Op not found")

    dtypes = (np.float32, np.float64,
              np.int8, np.int16, np.int32, np.int64)

    with plasma.start_plasma_store(10**8) as (plasma_store_name, _):
        client = plasma.connect(plasma_store_name)
        try:
            for dtype in dtypes:
                run_tensorflow_test_with_dtype(tf, plasma, plasma_store_name,
                                               client, use_gpu, dtype)

            # Make sure the objects have been released.
            for info in client.list().values():
                assert info['ref_count'] == 0
        finally:
            # Release the client's connection before the store shuts down,
            # even if an assertion above failed.
            client.disconnect()
|