# Mirror of https://github.com/aykhans/AzSuicideDataVisualization.git
# (synced 2025-11-04 11:59:58 +00:00) -- 375 lines, 13 KiB, Python.
# Copyright 2018-2022 Streamlit Inc.
 | 
						|
#
 | 
						|
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
# you may not use this file except in compliance with the License.
 | 
						|
# You may obtain a copy of the License at
 | 
						|
#
 | 
						|
#    http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
#
 | 
						|
# Unless required by applicable law or agreed to in writing, software
 | 
						|
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
# See the License for the specific language governing permissions and
 | 
						|
# limitations under the License.
 | 
						|
 | 
						|
"""Image marshalling."""
 | 
						|
 | 
						|
import imghdr
 | 
						|
import io
 | 
						|
import mimetypes
 | 
						|
from typing import cast
 | 
						|
from urllib.parse import urlparse
 | 
						|
import re
 | 
						|
 | 
						|
import numpy as np
 | 
						|
from PIL import Image, ImageFile
 | 
						|
 | 
						|
import streamlit
 | 
						|
from streamlit.errors import StreamlitAPIException
 | 
						|
from streamlit.logger import get_logger
 | 
						|
from streamlit.in_memory_file_manager import in_memory_file_manager
 | 
						|
from streamlit.proto.Image_pb2 import ImageList as ImageListProto
 | 
						|
 | 
						|
# Logger for this module, obtained from streamlit.logger and keyed by the
# module's dotted name.
LOGGER = get_logger(__name__)
 | 
						|
 | 
						|
# This constant is related to the frontend maximum content width specified
 | 
						|
# in App.jsx main container
 | 
						|
# 730 is the max width of element-container in the frontend, and 2x is for high
 | 
						|
# DPI.
 | 
						|
# Evaluates to 1460. _normalize_to_bytes uses it as the cap when a negative
# (sentinel) width is requested for an image wider than this value.
MAXIMUM_CONTENT_WIDTH = 2 * 730
 | 
						|
 | 
						|
 | 
						|
class ImageMixin:
    """Mixin that provides the ``image`` element command."""

    def image(
        self,
        image,
        caption=None,
        width=None,
        use_column_width=None,
        clamp=False,
        channels="RGB",
        output_format="auto",
    ):
        """Display an image or list of images.

        Parameters
        ----------
        image : numpy.ndarray, [numpy.ndarray], BytesIO, str, or [str]
            Monochrome image of shape (w,h) or (w,h,1)
            OR a color image of shape (w,h,3)
            OR an RGBA image of shape (w,h,4)
            OR a URL to fetch the image from
            OR a path of a local image file
            OR an SVG XML string like `<svg xmlns=...</svg>`
            OR a list of one of the above, to display multiple images.
        caption : str or list of str
            Image caption. If displaying multiple images, caption should be a
            list of captions (one for each image).
        width : int or None
            Image width. None means use the image width,
            but do not exceed the width of the column.
            Should be set for SVG images, as they have no default image width.
        use_column_width : 'auto' or 'always' or 'never' or bool
            If 'auto', set the image's width to its natural size,
            but do not exceed the width of the column.
            If 'always' or True, set the image's width to the column width.
            If 'never' or False, set the image's width to its natural size.
            Note: if set, `use_column_width` takes precedence over the `width` parameter.
        clamp : bool
            Clamp image pixel values to a valid range ([0-255] per channel).
            This is only meaningful for byte array images; the parameter is
            ignored for image URLs. If this is not set, and an image has an
            out-of-range value, an error will be thrown.
        channels : 'RGB' or 'BGR'
            If image is a numpy.ndarray, this parameter denotes the format
            used to represent color information. Defaults to 'RGB', meaning
            `image[:, :, 0]` is the red channel, `image[:, :, 1]` is green, and
            `image[:, :, 2]` is blue. For images coming from libraries like
            OpenCV you should set this to 'BGR', instead.
        output_format : 'JPEG', 'PNG', or 'auto'
            This parameter specifies the format to use when transferring the
            image data. Photos should use the JPEG format for lossy compression
            while diagrams should use the PNG format for lossless compression.
            Defaults to 'auto' which identifies the compression type based
            on the type and format of the image argument.

        Raises
        ------
        StreamlitAPIException
            If an explicit `width` is used and it is not positive.

        Example
        -------
        >>> from PIL import Image
        >>> image = Image.open('sunrise.jpg')
        >>>
        >>> st.image(image, caption='Sunrise by the mountains')

        .. output::
           https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/charts.image.py
           height: 710px

        """
        # Negative widths are sentinel codes rather than pixel sizes:
        #   -3 -> "auto", -2 -> fill the column, -1 -> natural image size.
        auto_requested = use_column_width == "auto"
        nothing_specified = use_column_width is None and width is None

        if auto_requested or nothing_specified:
            resolved_width = -3
        elif use_column_width in ("always", True):
            resolved_width = -2
        elif width is None:
            resolved_width = -1
        elif width <= 0:
            raise StreamlitAPIException("Image width must be positive.")
        else:
            resolved_width = width

        proto = ImageListProto()
        delta_path = self.dg._get_delta_path_str()
        marshall_images(
            delta_path,
            image,
            caption,
            resolved_width,
            proto,
            clamp,
            channels,
            output_format,
        )
        return self.dg._enqueue("imgs", proto)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Return this object viewed as a DeltaGenerator (typing cast only)."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)
 | 
						|
 | 
						|
 | 
						|
def _image_may_have_alpha_channel(image):
    """Return True when ``image.mode`` is one of "RGBA", "LA" or "P"."""
    return image.mode in ("RGBA", "LA", "P")
 | 
						|
 | 
						|
 | 
						|
def _format_from_image_type(image, output_format):
    """Pick the encoding format ("JPEG" or "PNG") for an image.

    An explicit request (case-insensitive "JPEG", "JPG" or "PNG") wins.
    Any other value, such as "auto", falls back to inspecting the image:
    "PNG" when it may carry an alpha channel, otherwise "JPEG".
    """
    requested = output_format.upper()

    if requested in ("JPEG", "PNG"):
        return requested

    # "JPG" is accepted as an alternative spelling of JPEG.
    if requested == "JPG":
        return "JPEG"

    return "PNG" if _image_may_have_alpha_channel(image) else "JPEG"
 | 
						|
 | 
						|
 | 
						|
def _PIL_to_bytes(image, format="JPEG", quality=100):
    """Encode ``image`` with its ``save`` method and return the raw bytes.

    When JPEG is requested for an image that may have an alpha channel, the
    image is first converted to "RGB" before saving.
    """
    needs_rgb = format == "JPEG" and _image_may_have_alpha_channel(image)
    encodable = image.convert("RGB") if needs_rgb else image

    with io.BytesIO() as buffer:
        encodable.save(buffer, format=format, quality=quality)
        return buffer.getvalue()
 | 
						|
 | 
						|
 | 
						|
def _BytesIO_to_bytes(data):
    """Rewind ``data`` to its start and return its full contents as bytes."""
    data.seek(0, io.SEEK_SET)
    contents = data.getvalue()
    return contents
 | 
						|
 | 
						|
 | 
						|
def _np_array_to_bytes(array, output_format="JPEG"):
    """Encode a numpy array as image bytes.

    The array is cast to ``uint8``, wrapped in a PIL image, and encoded in the
    format chosen by ``_format_from_image_type`` for ``output_format``.
    """
    pil_image = Image.fromarray(array.astype(np.uint8))
    chosen_format = _format_from_image_type(pil_image, output_format)
    return _PIL_to_bytes(pil_image, chosen_format)
 | 
						|
 | 
						|
 | 
						|
def _4d_to_list_3d(array):
    """Split a 4-D array along its first axis into a list of 3-D slices."""
    count = array.shape[0]
    return [array[index, :, :, :] for index in range(count)]
 | 
						|
 | 
						|
 | 
						|
def _verify_np_shape(array):
    """Validate an image array's shape and squeeze a single-channel axis.

    Accepts 2-D arrays, or 3-D arrays whose last axis has length 1, 3 or 4.
    A 3-D array with a last axis of length 1 is returned as its 2-D slice;
    every other accepted array is returned unchanged.

    Raises
    ------
    StreamlitAPIException
        If the array is not 2-D or 3-D, or its channel count is unsupported.
    """
    rank = len(array.shape)
    if rank not in (2, 3):
        raise StreamlitAPIException("Numpy shape has to be of length 2 or 3.")

    if rank == 2:
        return array

    n_channels = array.shape[-1]
    if n_channels not in (1, 3, 4):
        raise StreamlitAPIException(
            "Channel can only be 1, 3, or 4 got %d. Shape is %s"
            % (n_channels, str(array.shape))
        )

    # If there's only one channel, convert it to a plain 2-D (x, y) array.
    if n_channels == 1:
        return array[:, :, 0]

    return array
 | 
						|
 | 
						|
 | 
						|
def _normalize_to_bytes(data, width, output_format):
    """Return ``(data, mimetype)`` for encoded image bytes, downscaling if needed.

    Parameters
    ----------
    data : bytes
        Encoded image bytes that PIL can open.
    width : int
        Target width. A negative value means no explicit width was requested;
        in that case images wider than MAXIMUM_CONTENT_WIDTH are capped to it.
    output_format : str
        "JPEG", "JPG", "PNG" or "auto" (case-insensitive).

    Returns
    -------
    tuple
        The bytes (re-encoded only when the image was resized, otherwise the
        input unchanged) and the mimetype string.
    """
    image = Image.open(io.BytesIO(data))
    actual_width, actual_height = image.size
    image_format = _format_from_image_type(image, output_format)
    format_mimetype = "image/" + image_format.lower()

    if output_format.lower() == "auto":
        # NOTE(review): imghdr is deprecated since Python 3.11 and removed in
        # 3.13 (PEP 594); this sniffing needs a replacement before upgrading.
        ext = imghdr.what(None, data)
        mimetype = mimetypes.guess_type("image.%s" % ext)[0]
        # If the type could not be sniffed, fall back to the chosen format.
        if mimetype is None:
            mimetype = format_mimetype
    else:
        mimetype = format_mimetype

    if width < 0 and actual_width > MAXIMUM_CONTENT_WIDTH:
        width = MAXIMUM_CONTENT_WIDTH

    if width > 0 and actual_width > width:
        # Keep the aspect ratio. Clamp to 1 so a very wide, flat image does
        # not truncate to a zero height, which PIL's resize rejects.
        new_height = max(1, int(1.0 * actual_height * width / actual_width))
        image = image.resize((width, new_height), resample=Image.BILINEAR)
        data = _PIL_to_bytes(image, format=image_format, quality=90)
        mimetype = format_mimetype

    return data, mimetype
 | 
						|
 | 
						|
 | 
						|
def _clip_image(image, clamp):
    """Range-check (or clamp) pixel values and scale float data by 255.

    Floating-point arrays must lie in [0.0, 1.0] and are returned multiplied
    by 255. All other arrays must lie in [0, 255] and are returned unscaled.
    With ``clamp`` set, out-of-range values are clipped into the valid range;
    otherwise they are an error.

    Raises
    ------
    RuntimeError
        If a value is out of range and ``clamp`` is not set.
    """
    is_floating = issubclass(image.dtype.type, np.floating)
    upper_bound = 1.0 if is_floating else 255

    if clamp:
        checked = np.clip(image, 0, upper_bound)
    else:
        if np.amin(image) < 0 or np.amax(image) > upper_bound:
            if is_floating:
                raise RuntimeError("Data is outside [0.0, 1.0] and clamp is not set.")
            raise RuntimeError("Data is outside [0, 255] and clamp is not set.")
        checked = image

    return checked * 255 if is_floating else checked
 | 
						|
 | 
						|
 | 
						|
def image_to_url(
    image, width, clamp, channels, output_format, image_id, allow_emoji=False
):
    """Turn an image-like value into a URL string.

    Parameters
    ----------
    image : str, PIL image, io.BytesIO, numpy.ndarray, or bytes
        A string is treated as a URL (returned as-is when it has a scheme) or
        as a local file path. Anything that is not a string, PIL image,
        BytesIO or ndarray is assumed to already be encoded image bytes.
    width : int
        Requested width, forwarded to ``_normalize_to_bytes``.
    clamp : bool
        Whether out-of-range ndarray values are clipped instead of rejected.
    channels : str
        "BGR" reverses the channel order of ndarray input.
    output_format : str
        "JPEG", "JPG", "PNG" or "auto".
    image_id : str
        Identifier passed to the in-memory file manager for this image.
    allow_emoji : bool
        When True, a string that cannot be opened as a file is returned
        unchanged (it might be an emoji) instead of raising.

    Returns
    -------
    str
        The original string (URL or emoji case), or the URL of the file
        registered with ``in_memory_file_manager``.

    Raises
    ------
    StreamlitAPIException
        If ``channels`` is "BGR" and the array does not have 3 dimensions.
    Exception
        Whatever ``open`` raised (typically ``OSError``) when a string is not
        an openable file and ``allow_emoji`` is False.
    """
    # Strings are handled first; the input types are mutually exclusive, so
    # the dispatch order does not change which branch runs.
    if isinstance(image, str):
        # If it's a URL (has a scheme), hand it back untouched.
        try:
            if urlparse(image).scheme:
                return image
        except UnicodeDecodeError:
            pass

        # Otherwise, see if it's a file.
        try:
            with open(image, "rb") as f:
                data = f.read()
        except Exception:
            # `Exception` rather than a bare `except`, so KeyboardInterrupt
            # and SystemExit are never swallowed by the emoji fallback.
            if allow_emoji:
                # This might be an emoji string, so just pass it to the frontend
                return image
            # Allow OS filesystem errors to raise
            raise

    # PIL Images
    elif isinstance(image, (ImageFile.ImageFile, Image.Image)):
        image_format = _format_from_image_type(image, output_format)
        data = _PIL_to_bytes(image, image_format)

    # BytesIO
    # Note: This doesn't support SVG. We could convert to png (cairosvg.svg2png)
    # or just decode BytesIO to string and handle that way.
    elif isinstance(image, io.BytesIO):
        data = _BytesIO_to_bytes(image)

    # Numpy Arrays (ie opencv)
    elif type(image) is np.ndarray:
        data = _verify_np_shape(image)
        data = _clip_image(data, clamp)

        if channels == "BGR":
            if len(data.shape) != 3:
                raise StreamlitAPIException(
                    'When using `channels="BGR"`, the input image should '
                    "have exactly 3 color channels"
                )
            # NOTE(review): selecting channels [2, 1, 0] keeps only three
            # channels, so a 4-channel input loses its last channel here.
            data = data[:, :, [2, 1, 0]]

        data = _np_array_to_bytes(data, output_format=output_format)

    # Assume input in bytes.
    else:
        data = image

    (data, mimetype) = _normalize_to_bytes(data, width, output_format)
    this_file = in_memory_file_manager.add(data, mimetype, image_id)
    return this_file.url
 | 
						|
 | 
						|
 | 
						|
def marshall_images(
    coordinates,
    image,
    caption,
    width,
    proto_imgs,
    clamp,
    channels="RGB",
    output_format="auto",
):
    """Fill ``proto_imgs`` with one entry per image.

    ``image`` may be a single image, a list of images, or a 4-D numpy array
    (split along its first axis). ``caption`` may be None, a string, a list,
    or a 1-D numpy array; any other value is stringified into one caption.
    Each entry gets its caption (when not None) and either ``markup`` (for SVG
    strings and local ``.svg`` files) or ``url`` (everything else, produced by
    ``image_to_url``). ``proto_imgs.width`` is set to ``width``.
    """
    channels = channels.upper()

    # Normalize the image argument to a list.
    if type(image) is list:
        images = image
    elif type(image) is np.ndarray and len(image.shape) == 4:
        images = _4d_to_list_3d(image)
    else:
        images = [image]

    # Normalize the caption argument to a list.
    if type(caption) is list:
        captions = caption
    elif isinstance(caption, str):
        captions = [caption]
    elif type(caption) is np.ndarray and len(caption.shape) == 1:
        # You can pass in a 1-D Numpy array as captions.
        captions = caption.tolist()
    elif caption is None:
        # No captions: pad with None so the list matches the images list.
        captions = [None] * len(images)
    else:
        captions = [str(caption)]

    assert type(captions) is list, "If image is a list then caption should be as well"
    assert len(captions) == len(images), "Cannot pair %d captions with %d images." % (
        len(captions),
        len(images),
    )

    proto_imgs.width = width

    # Accepts SVG content that starts with a "<?xml...>" tag eventually
    # followed by a "<svg ...>" tag, or that starts directly with "<svg ".
    svg_pattern = r"(^\s?(<\?xml[\s\S]*<svg )|^\s?<svg )"

    # Each image in an image list needs to be kept track of at its own coordinates.
    for position, (item, item_caption) in enumerate(zip(images, captions)):
        entry = proto_imgs.imgs.add()
        if item_caption is not None:
            entry.caption = str(item_caption)

        # The index of the image in the input list identifies it inside the
        # in-memory file manager: append it to the image's "coordinates".
        image_id = "%s-%i" % (coordinates, position)

        if isinstance(item, str):
            # Unpack a local SVG image file to an SVG string.
            is_remote = item.startswith(("http://", "https://"))
            if item.endswith(".svg") and not is_remote:
                with open(item) as svg_file:
                    item = svg_file.read()

            if re.search(svg_pattern, item):
                entry.markup = f"data:image/svg+xml,{item}"
                continue

        entry.url = image_to_url(
            item, width, clamp, channels, output_format, image_id
        )
 |