# Mirror of https://github.com/aykhans/AzSuicideDataVisualization.git
# Synced 2025-04-22 18:32:15 +00:00 (311 lines, 11 KiB, Python)
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
import io
|
|
import threading
|
|
from typing import Dict, NamedTuple, List, Tuple
|
|
from blinker import Signal
|
|
from streamlit import util
|
|
from streamlit.logger import get_logger
|
|
from streamlit.stats import CacheStatsProvider, CacheStat
|
|
|
|
# Module-level logger obtained from streamlit.logger.get_logger, keyed by
# this module's __name__. Used by UploadedFileManager for debug output.
LOGGER = get_logger(__name__)
|
|
|
|
|
|
# Immutable record pairing an uploaded file's metadata with its raw bytes.
# Declared with the functional NamedTuple form; field order defines the
# positional constructor order and tuple layout.
UploadedFileRec = NamedTuple(
    "UploadedFileRec",
    [
        ("id", int),  # Unique file ID.
        ("name", str),  # File name.
        ("type", str),  # File type string (presumably a MIME type - confirm).
        ("data", bytes),  # Raw file contents.
    ],
)
UploadedFileRec.__doc__ = (
    "Immutable record holding an uploaded file's metadata and raw bytes."
)
|
|
|
|
|
|
class UploadedFile(io.BytesIO):
    """A mutable uploaded file.

    This class extends BytesIO, which has copy-on-write semantics when
    initialized with `bytes`.

    Two UploadedFile instances compare equal when their ``id`` attributes are
    equal, and they hash on that same ``id``, so instances can be stored in
    sets or used as dict keys.
    """

    def __init__(self, record: "UploadedFileRec"):
        # BytesIO's copy-on-write semantics doesn't seem to be mentioned in
        # the Python docs - possibly because it's a CPython-only optimization
        # and not guaranteed to be in other Python runtimes. But it's detailed
        # here: https://hg.python.org/cpython/rev/79a5fbe2c78f
        super().__init__(record.data)
        self.id = record.id
        self.name = record.name
        self.type = record.type
        self.size = len(record.data)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, UploadedFile):
            return NotImplemented
        return self.id == other.id

    def __hash__(self) -> int:
        # Defining __eq__ alone implicitly sets __hash__ to None, which would
        # make instances unhashable (unlike the BytesIO base class). Hash on
        # the same attribute __eq__ compares so equal files hash equally.
        return hash(self.id)

    def __repr__(self) -> str:
        return util.repr_(self)
|
|
|
|
|
|
class UploadedFileManager(CacheStatsProvider):
    """Holds files uploaded by users of the running Streamlit app,
    and emits an event signal when a file is added.
    """

    def __init__(self):
        # List of files for a given widget in a given session, keyed by
        # (session_id, widget_id).
        self._files_by_id: Dict[Tuple[str, str], List[UploadedFileRec]] = {}

        # A counter that generates unique file IDs. Each file ID is greater
        # than the previous ID, which means we can use IDs to compare files
        # by age.
        self._file_id_counter = 1
        self._file_id_lock = threading.Lock()

        # Prevents concurrent access to the _files_by_id dict.
        # In remove_session_files(), we iterate over the dict's keys. It's
        # an error to mutate a dict while iterating; this lock prevents that.
        self._files_lock = threading.Lock()
        self.on_files_updated = Signal(
            doc="""Emitted when a file list is added to the manager or updated.

            Parameters
            ----------
            session_id : str
                The session_id for the session whose files were updated.
            """
        )

    def __repr__(self) -> str:
        return util.repr_(self)

    def add_file(
        self,
        session_id: str,
        widget_id: str,
        file: UploadedFileRec,
    ) -> UploadedFileRec:
        """Add a file to the FileManager, and return a new UploadedFileRec
        with its ID assigned.

        The "on_files_updated" Signal will be emitted.

        Parameters
        ----------
        session_id
            The ID of the session that owns the file.
        widget_id
            The widget ID of the FileUploader that created the file.
        file
            The file to add. Its ``id`` is ignored and replaced.

        Returns
        -------
        UploadedFileRec
            The added file, which has its unique ID assigned.
        """
        file_list_id = (session_id, widget_id)

        # Assign the file a unique ID
        file_id = self._get_next_file_id()
        file = UploadedFileRec(
            id=file_id, name=file.name, type=file.type, data=file.data
        )

        with self._files_lock:
            self._files_by_id.setdefault(file_list_id, []).append(file)

        # Emitted outside the lock: signal receivers may call back into this
        # manager, and threading.Lock is not reentrant.
        self.on_files_updated.send(session_id)
        return file

    def get_all_files(self, session_id: str, widget_id: str) -> List[UploadedFileRec]:
        """Return all the files stored for the given widget.

        Parameters
        ----------
        session_id
            The ID of the session that owns the files.
        widget_id
            The widget ID of the FileUploader that created the files.

        Returns
        -------
        List[UploadedFileRec]
            A shallow copy of the stored list (empty if none), so callers can
            iterate it without holding the lock.
        """
        file_list_id = (session_id, widget_id)
        with self._files_lock:
            return self._files_by_id.get(file_list_id, []).copy()

    def get_files(
        self, session_id: str, widget_id: str, file_ids: List[int]
    ) -> List[UploadedFileRec]:
        """Return the files with the given widget_id and file_ids.

        Parameters
        ----------
        session_id
            The ID of the session that owns the files.
        widget_id
            The widget ID of the FileUploader that created the files.
        file_ids
            List of file IDs. Only files whose IDs are in this list will be
            returned.
        """
        # Build the lookup set once instead of scanning file_ids per file.
        wanted_ids = set(file_ids)
        return [
            f for f in self.get_all_files(session_id, widget_id) if f.id in wanted_ids
        ]

    def remove_orphaned_files(
        self,
        session_id: str,
        widget_id: str,
        newest_file_id: int,
        active_file_ids: List[int],
    ) -> None:
        """Remove 'orphaned' files: files that have been uploaded and
        subsequently deleted, but haven't yet been removed from memory.

        Because FileUploader can live inside forms, file deletion is made a
        bit tricky: a file deletion should only happen after the form is
        submitted.

        FileUploader's widget value is an array of numbers that has two parts:
        - The first number is always 'this.state.newestServerFileId'.
        - The remaining 0 or more numbers are the file IDs of all the
          uploader's uploaded files.

        When the server receives the widget value, it deletes "orphaned"
        uploaded files. An orphaned file is any file associated with a given
        FileUploader whose file ID is not in the active_file_ids, and whose
        ID is <= `newestServerFileId`.

        This logic ensures that a FileUploader within a form doesn't have any
        of its "unsubmitted" uploads prematurely deleted when the script is
        re-run.
        """
        file_list_id = (session_id, widget_id)
        active_ids = set(active_file_ids)

        with self._files_lock:
            file_list = self._files_by_id.get(file_list_id)
            if file_list is None:
                return

            # Remove orphaned files from the list:
            # - `f.id in active_ids`:
            #   File is currently tracked by the widget. DON'T remove.
            # - `f.id > newest_file_id`:
            #   file was uploaded *after* the widget was most recently
            #   updated. (It's probably in a form.) DON'T remove.
            # - `f.id <= newest_file_id and f.id not in active_ids`:
            #   File is not currently tracked by the widget, and was uploaded
            #   no later than this most recent update. This means it's been
            #   deleted by the user on the frontend, and is now "orphaned".
            #   Remove!
            new_list = [
                f for f in file_list if f.id > newest_file_id or f.id in active_ids
            ]
            self._files_by_id[file_list_id] = new_list
            num_removed = len(file_list) - len(new_list)

        if num_removed > 0:
            LOGGER.debug("Removed %s orphaned files", num_removed)

    def remove_file(self, session_id: str, widget_id: str, file_id: int) -> bool:
        """Remove the file with the given ID from the given widget's file
        list, if it exists.

        The "on_files_updated" Signal will be emitted whenever the widget's
        file list exists, even if it did not contain ``file_id``.

        Parameters
        ----------
        session_id
            The ID of the session that owns the file.
        widget_id
            The widget ID of the FileUploader that created the file.
        file_id
            The ID of the file to remove.

        Returns
        -------
        bool
            True if the file was removed, or False if no such file exists.
        """
        file_list_id = (session_id, widget_id)
        with self._files_lock:
            file_list = self._files_by_id.get(file_list_id, None)
            if file_list is None:
                return False

            # Remove the file from its list.
            new_file_list = [f for f in file_list if f.id != file_id]
            self._files_by_id[file_list_id] = new_file_list
            # A shorter list means at least one matching file was dropped.
            removed = len(new_file_list) < len(file_list)

        self.on_files_updated.send(session_id)
        return removed

    def _remove_files(self, session_id: str, widget_id: str) -> None:
        """Remove the file list for the provided widget in the
        provided session, if it exists.

        Does not emit any signals.
        """
        file_list_id = (session_id, widget_id)
        with self._files_lock:
            self._files_by_id.pop(file_list_id, None)

    def remove_files(self, session_id: str, widget_id: str) -> None:
        """Remove the file list for the provided widget in the
        provided session, if it exists.

        The "on_files_updated" Signal will be emitted.

        Parameters
        ----------
        session_id : str
            The ID of the session that owns the files.
        widget_id : str
            The widget ID of the FileUploader that created the files.
        """
        self._remove_files(session_id, widget_id)
        self.on_files_updated.send(session_id)

    def remove_session_files(self, session_id: str) -> None:
        """Remove all files that belong to the given session.

        Emits "on_files_updated" once per removed widget file list (via
        remove_files).

        Parameters
        ----------
        session_id : str
            The ID of the session whose files we're removing.
        """
        # Copy the keys into a list, because we'll be mutating the dictionary.
        with self._files_lock:
            all_ids = list(self._files_by_id.keys())

        for files_id in all_ids:
            if files_id[0] == session_id:
                self.remove_files(*files_id)

    def _get_next_file_id(self) -> int:
        """Return the next file ID and increment our ID counter."""
        with self._file_id_lock:
            file_id = self._file_id_counter
            self._file_id_counter += 1
            return file_id

    def get_stats(self) -> List[CacheStat]:
        """Return one CacheStat per stored file, sized by its data length."""
        with self._files_lock:
            # Flatten all files into a single list
            all_files: List[UploadedFileRec] = []
            for file_list in self._files_by_id.values():
                all_files.extend(file_list)

        return [
            CacheStat(
                category_name="UploadedFileManager",
                cache_name="",
                byte_length=len(file.data),
            )
            for file in all_files
        ]
|