2022-05-23 00:16:32 +04:00

311 lines
11 KiB
Python

# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import threading
from typing import Dict, NamedTuple, List, Tuple
from blinker import Signal
from streamlit import util
from streamlit.logger import get_logger
from streamlit.stats import CacheStatsProvider, CacheStat
# Module-level logger; used below to report how many orphaned uploads
# were dropped by UploadedFileManager.remove_orphaned_files().
LOGGER = get_logger(__name__)
class UploadedFileRec(NamedTuple):
    """Immutable record for one uploaded file: its metadata together with
    the raw bytes of its content.

    Being a NamedTuple, a record cannot be modified in place; a changed
    copy has to be built as a new record.
    """

    # Numeric file ID. UploadedFileManager.add_file() stores a copy of the
    # record with a freshly generated, unique ID in this field.
    id: int

    # Name of the uploaded file.
    name: str

    # Type of the uploaded file, as a string (presumably the MIME type
    # reported by the browser -- confirm against the upload handler).
    type: str

    # Raw content of the uploaded file.
    data: bytes
class UploadedFile(io.BytesIO):
    """A mutable uploaded file.

    This class extends BytesIO, which has copy-on-write semantics when
    initialized with `bytes`.

    Two UploadedFile objects compare equal when they carry the same file
    ID, regardless of their current buffer content or read position, and
    they hash consistently with that equality.
    """

    def __init__(self, record: UploadedFileRec):
        """Create a file-like view over the bytes held by `record`.

        Parameters
        ----------
        record
            The uploaded-file record whose id, name, type and data are
            exposed by this object.
        """
        # BytesIO's copy-on-write semantics doesn't seem to be mentioned in
        # the Python docs - possibly because it's a CPython-only optimization
        # and not guaranteed to be in other Python runtimes. But it's detailed
        # here: https://hg.python.org/cpython/rev/79a5fbe2c78f
        super().__init__(record.data)
        self.id = record.id
        self.name = record.name
        self.type = record.type
        # Size of the originally uploaded data; not updated if the buffer
        # is written to afterwards.
        self.size = len(record.data)

    def __eq__(self, other: object) -> bool:
        """Return True if `other` is an UploadedFile with the same ID."""
        if not isinstance(other, UploadedFile):
            return NotImplemented
        return self.id == other.id

    def __hash__(self) -> int:
        """Hash by file ID, matching `__eq__`.

        Overriding `__eq__` without `__hash__` makes Python set
        `__hash__` to None, which would leave instances unhashable
        (unusable in sets or as dict keys).
        """
        return hash(self.id)

    def __repr__(self) -> str:
        return util.repr_(self)
class UploadedFileManager(CacheStatsProvider):
    """Holds files uploaded by users of the running Streamlit app,
    and emits an event signal when a session's files are added to,
    updated, or removed.

    Files are stored in lists keyed by (session_id, widget_id).
    """

    def __init__(self):
        # Maps (session_id, widget_id) to the list of files stored for
        # that widget in that session.
        self._files_by_id: Dict[Tuple[str, str], List[UploadedFileRec]] = {}

        # A counter that generates unique file IDs. Each file ID is greater
        # than the previous ID, which means we can use IDs to compare files
        # by age.
        self._file_id_counter = 1
        self._file_id_lock = threading.Lock()

        # Prevents concurrent access to the _files_by_id dict.
        # In remove_session_files(), we iterate over the dict's keys. It's
        # an error to mutate a dict while iterating; this lock prevents that.
        self._files_lock = threading.Lock()

        self.on_files_updated = Signal(
            doc="""Emitted when a file list is added to the manager or updated.

            Parameters
            ----------
            session_id : str
                The session_id for the session whose files were updated.
            """
        )

    def __repr__(self) -> str:
        return util.repr_(self)

    def add_file(
        self,
        session_id: str,
        widget_id: str,
        file: UploadedFileRec,
    ) -> UploadedFileRec:
        """Add a file to the FileManager, and return a new UploadedFileRec
        with its ID assigned.

        The "on_files_updated" Signal will be emitted.

        Parameters
        ----------
        session_id
            The ID of the session that owns the file.
        widget_id
            The widget ID of the FileUploader that created the file.
        file
            The file to add. Its own `id` is ignored; a fresh one is
            generated.

        Returns
        -------
        UploadedFileRec
            The added file, which has its unique ID assigned.
        """
        list_key = (session_id, widget_id)

        # Records are immutable, so build a copy carrying the unique ID.
        stored = UploadedFileRec(
            id=self._get_next_file_id(),
            name=file.name,
            type=file.type,
            data=file.data,
        )

        with self._files_lock:
            self._files_by_id.setdefault(list_key, []).append(stored)

        # Notify outside the lock so that handlers can call back into the
        # manager without deadlocking on _files_lock.
        self.on_files_updated.send(session_id)
        return stored

    def get_all_files(self, session_id: str, widget_id: str) -> List[UploadedFileRec]:
        """Return all the files stored for the given widget.

        The returned list is a copy; mutating it does not affect the
        manager.

        Parameters
        ----------
        session_id
            The ID of the session that owns the files.
        widget_id
            The widget ID of the FileUploader that created the files.
        """
        list_key = (session_id, widget_id)
        with self._files_lock:
            return self._files_by_id.get(list_key, []).copy()

    def get_files(
        self, session_id: str, widget_id: str, file_ids: List[int]
    ) -> List[UploadedFileRec]:
        """Return the files with the given widget_id and file_ids.

        Parameters
        ----------
        session_id
            The ID of the session that owns the files.
        widget_id
            The widget ID of the FileUploader that created the files.
        file_ids
            List of file IDs. Only files whose IDs are in this list will be
            returned.
        """
        # Build the lookup set once instead of scanning the list per file.
        wanted_ids = set(file_ids)
        return [
            f for f in self.get_all_files(session_id, widget_id) if f.id in wanted_ids
        ]

    def remove_orphaned_files(
        self,
        session_id: str,
        widget_id: str,
        newest_file_id: int,
        active_file_ids: List[int],
    ) -> None:
        """Remove 'orphaned' files: files that have been uploaded and
        subsequently deleted, but haven't yet been removed from memory.

        Because FileUploader can live inside forms, file deletion is made a
        bit tricky: a file deletion should only happen after the form is
        submitted.

        FileUploader's widget value is an array of numbers that has two parts:
        - The first number is always 'this.state.newestServerFileId'.
        - The remaining 0 or more numbers are the file IDs of all the
          uploader's uploaded files.

        When the server receives the widget value, it deletes "orphaned"
        uploaded files. An orphaned file is any file associated with a given
        FileUploader whose file ID is not in the active_file_ids, and whose
        ID is <= `newestServerFileId`.

        This logic ensures that a FileUploader within a form doesn't have any
        of its "unsubmitted" uploads prematurely deleted when the script is
        re-run.

        No signal is emitted by this method.

        Parameters
        ----------
        session_id
            The ID of the session that owns the files.
        widget_id
            The widget ID of the FileUploader that created the files.
        newest_file_id
            The newest file ID known to the widget when it last updated.
        active_file_ids
            IDs of the files the widget currently tracks.
        """
        list_key = (session_id, widget_id)
        with self._files_lock:
            file_list = self._files_by_id.get(list_key)
            if file_list is None:
                return

            active_ids = set(active_file_ids)

            # Remove orphaned files from the list:
            # - `f.id in active_file_ids`:
            #   File is currently tracked by the widget. DON'T remove.
            # - `f.id > newest_file_id`:
            #   file was uploaded *after* the widget was most recently
            #   updated. (It's probably in a form.) DON'T remove.
            # - `f.id <= newest_file_id and f.id not in active_file_ids`:
            #   File is not currently tracked by the widget, and was uploaded
            #   *before* this most recent update. This means it's been deleted
            #   by the user on the frontend, and is now "orphaned". Remove!
            kept = [
                f for f in file_list if f.id > newest_file_id or f.id in active_ids
            ]
            self._files_by_id[list_key] = kept
            num_removed = len(file_list) - len(kept)

        if num_removed > 0:
            # Pass the value as a logger argument so that formatting is
            # skipped when debug logging is disabled.
            LOGGER.debug("Removed %s orphaned files", num_removed)

    def remove_file(self, session_id: str, widget_id: str, file_id: int) -> bool:
        """Remove the file with the given ID from the widget's file list,
        if it exists.

        The "on_files_updated" Signal is emitted only when a file was
        actually removed.

        Parameters
        ----------
        session_id
            The ID of the session that owns the file.
        widget_id
            The widget ID of the FileUploader that created the file.
        file_id
            The ID of the file to remove.

        Returns
        -------
        bool
            True if the file was removed, or False if no such file exists.
        """
        list_key = (session_id, widget_id)
        with self._files_lock:
            file_list = self._files_by_id.get(list_key)
            if file_list is None:
                return False

            remaining = [f for f in file_list if f.id != file_id]
            if len(remaining) == len(file_list):
                # Nothing matched file_id: leave the list untouched and
                # report that no file was removed.
                return False

            self._files_by_id[list_key] = remaining

        self.on_files_updated.send(session_id)
        return True

    def _remove_files(self, session_id: str, widget_id: str) -> None:
        """Remove the file list for the provided widget in the
        provided session, if it exists.

        Does not emit any signals.
        """
        list_key = (session_id, widget_id)
        with self._files_lock:
            self._files_by_id.pop(list_key, None)

    def remove_files(self, session_id: str, widget_id: str) -> None:
        """Remove the file list for the provided widget in the
        provided session, if it exists.

        The "on_files_updated" Signal will be emitted, whether or not a
        list existed.

        Parameters
        ----------
        session_id : str
            The ID of the session that owns the files.
        widget_id : str
            The widget ID of the FileUploader that created the files.
        """
        self._remove_files(session_id, widget_id)
        self.on_files_updated.send(session_id)

    def remove_session_files(self, session_id: str) -> None:
        """Remove all files that belong to the given session.

        Each of the session's widget file lists is removed through
        remove_files(), so "on_files_updated" is emitted once per list.

        Parameters
        ----------
        session_id : str
            The ID of the session whose files we're removing.
        """
        # Snapshot the matching keys under the lock, because removing the
        # lists mutates the dictionary.
        with self._files_lock:
            session_keys = [key for key in self._files_by_id if key[0] == session_id]

        for key in session_keys:
            self.remove_files(*key)

    def _get_next_file_id(self) -> int:
        """Return the next file ID and increment our ID counter."""
        with self._file_id_lock:
            file_id = self._file_id_counter
            self._file_id_counter += 1
            return file_id

    def get_stats(self) -> List[CacheStat]:
        """Return one CacheStat per stored file, reporting the byte length
        of that file's data.
        """
        with self._files_lock:
            # Flatten all files into a single list
            all_files: List[UploadedFileRec] = []
            for file_list in self._files_by_id.values():
                all_files.extend(file_list)

        return [
            CacheStat(
                category_name="UploadedFileManager",
                cache_name="",
                byte_length=len(file.data),
            )
            for file in all_files
        ]