first commit

This commit is contained in:
Ayxan
2022-05-23 00:16:32 +04:00
commit d660f2a4ca
24786 changed files with 4428337 additions and 0 deletions

View File

@@ -0,0 +1,42 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from typing import Iterator
from .memo_decorator import MEMO_CALL_STACK, _memo_caches, MemoAPI
from .singleton_decorator import SINGLETON_CALL_STACK, _singleton_caches, SingletonAPI
def maybe_show_cached_st_function_warning(dg, st_func_name: str) -> None:
    """Ask both cache call stacks to warn about an `st` call, if appropriate.

    The request is forwarded first to the memo call stack and then to the
    singleton call stack; each stack decides on its own whether a warning
    is actually shown.

    Parameters
    ----------
    dg
        Passed through to each call stack as the target for the warning.
    st_func_name : str
        The name of the Streamlit function that was called.
    """
    for call_stack in (MEMO_CALL_STACK, SINGLETON_CALL_STACK):
        call_stack.maybe_show_cached_st_function_warning(dg, st_func_name)
@contextlib.contextmanager
def suppress_cached_st_function_warning() -> Iterator[None]:
    """Context manager that suppresses cached-st-function warnings on both
    the memo and the singleton call stacks for the duration of the block.
    """
    # The memo suppression is entered first and exited last, exactly as in a
    # single `with a, b:` statement.
    with MEMO_CALL_STACK.suppress_cached_st_function_warning():
        with SINGLETON_CALL_STACK.suppress_cached_st_function_warning():
            yield
# Explicitly export public symbols
from .memo_decorator import (
get_memo_stats_provider as get_memo_stats_provider,
)
from .singleton_decorator import (
get_singleton_stats_provider as get_singleton_stats_provider,
)
# Create and export public API singletons.
# `memo` and `singleton` are module-level instances of MemoAPI and
# SingletonAPI (imported above from the decorator modules).
memo = MemoAPI()
singleton = SingletonAPI()

View File

@@ -0,0 +1,119 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import types
from typing import Any, Optional
from streamlit import type_util
from streamlit.errors import (
StreamlitAPIWarning,
StreamlitAPIException,
)
class CacheType(enum.Enum):
    """The kind of cache a function is decorated with.

    Each member's value is the decorator name that is interpolated into
    user-facing messages as `@st.<value>`.
    """

    MEMO = "experimental_memo"
    SINGLETON = "experimental_singleton"
class UnhashableTypeError(Exception):
    """Internal error signalling that an object could not be hashed.

    NOTE(review): appears to be raised by the hashing module and translated
    into an UnhashableParamError for the user - confirm against callers.
    """

    pass
class UnhashableParamError(StreamlitAPIException):
    """User-facing error for a cached-function argument that cannot be hashed.

    The message names the offending argument and function, and suggests
    prefixing the argument name with an underscore so it is not hashed.
    """

    def __init__(
        self,
        cache_type: CacheType,
        func: types.FunctionType,
        arg_name: Optional[str],
        arg_value: Any,
        orig_exc: BaseException,
    ):
        super().__init__(
            self._create_message(cache_type, func, arg_name, arg_value)
        )
        # Carry over the traceback of the exception that triggered this one.
        self.with_traceback(orig_exc.__traceback__)

    @staticmethod
    def _create_message(
        cache_type: CacheType,
        func: types.FunctionType,
        arg_name: Optional[str],
        arg_value: Any,
    ) -> str:
        """Build the explanation shown to the user.

        `arg_name` may be None, in which case generic placeholder names
        are used in the message.
        """
        if arg_name is None:
            arg_name_str = "(unnamed)"
            arg_replacement_name = "_arg"
        else:
            arg_name_str = arg_name
            arg_replacement_name = f"_{arg_name}"

        arg_type = type_util.get_fqn_type(arg_value)
        func_name = func.__name__

        message = f"""
Cannot hash argument '{arg_name_str}' (of type `{arg_type}`) in '{func_name}'.
To address this, you can tell Streamlit not to hash this argument by adding a
leading underscore to the argument's name in the function signature:
```
@st.{cache_type.value}
def {func_name}({arg_replacement_name}, ...):
...
```
"""
        return message.strip("\n")
class CacheKeyNotFoundError(Exception):
    """Signals that a requested key is not present in a cache."""

    pass
class CacheError(Exception):
    """Generic cache-related error.

    NOTE(review): no raise site is visible in this module - presumably used
    by the cache implementations; confirm against callers.
    """

    pass
class CachedStFunctionWarning(StreamlitAPIWarning):
    """Warning about an `st` command being called from inside cached code.

    The message explains that cached code only runs on a cache miss and
    tells the user how to move the call or suppress the warning.
    """

    def __init__(
        self,
        cache_type: CacheType,
        st_func_name: str,
        cached_func: types.FunctionType,
    ):
        # Values substituted into the %-style template below.
        template_args = dict(
            st_func_name=f"`st.{st_func_name}()` or `st.write()`",
            func_name=self._get_cached_func_name_md(cached_func),
            decorator_name=cache_type.value,
        )

        template = """
Your script uses %(st_func_name)s to write to your Streamlit app from within
some cached code at %(func_name)s. This code will only be called when we detect
a cache "miss", which can lead to unexpected results.
How to fix this:
* Move the %(st_func_name)s call outside %(func_name)s.
* Or, if you know what you're doing, use `@st.%(decorator_name)s(suppress_st_warning=True)`
to suppress the warning.
"""
        super().__init__((template % template_args).strip("\n"))

    @staticmethod
    def _get_cached_func_name_md(func: types.FunctionType) -> str:
        """Return the function's name formatted as markdown code, or a
        generic description when the object has no `__name__`.
        """
        try:
            return "`%s()`" % func.__name__
        except AttributeError:
            return "a cached function"

View File

@@ -0,0 +1,343 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common cache logic shared by st.memo and st.singleton."""
import contextlib
import functools
import hashlib
import inspect
import threading
import types
from abc import abstractmethod
from typing import Callable, List, Iterator, Tuple, Optional, Any, Union
import streamlit as st
from streamlit import util
from streamlit.caching.cache_errors import CacheKeyNotFoundError
from streamlit.logger import get_logger
from .cache_errors import (
CacheType,
CachedStFunctionWarning,
UnhashableParamError,
UnhashableTypeError,
)
from .hashing import update_hash
# Module-level logger, named after this module.
_LOGGER = get_logger(__name__)
class Cache:
    """Interface for a single function's cache.

    Caches persist across script runs. Concrete caches override all three
    methods; the implementations here only raise NotImplementedError.
    """

    @abstractmethod
    def read_value(self, value_key: str) -> Any:
        """Return the value stored under `value_key`.

        Raises
        ------
        CacheKeyNotFoundError
            Raised if value_key is not in the cache.
        """
        raise NotImplementedError()

    @abstractmethod
    def write_value(self, value_key: str, value: Any) -> None:
        """Store `value` under `value_key`, replacing any value already
        stored under that key.
        """
        raise NotImplementedError()

    @abstractmethod
    def clear(self) -> None:
        """Remove every value from this function cache."""
        raise NotImplementedError()
class CachedFunction:
    """Bundle of a user function and its caching options.

    Instances are scoped to a single script run - they're not persistent.
    Subclasses supply the cache type, the call stack and the function cache;
    the base implementations raise NotImplementedError.
    """

    def __init__(
        self, func: types.FunctionType, show_spinner: bool, suppress_st_warning: bool
    ):
        # The wrapped user function and the two display-related options.
        self.func = func
        self.show_spinner = show_spinner
        self.suppress_st_warning = suppress_st_warning

    @property
    def cache_type(self) -> CacheType:
        """The CacheType this function is cached with."""
        raise NotImplementedError()

    @property
    def call_stack(self) -> "CachedFunctionCallStack":
        """The call stack used to track calls into this cached function."""
        raise NotImplementedError()

    def get_function_cache(self, function_key: str) -> Cache:
        """Return the function cache for `function_key`, creating it if
        it does not exist yet.
        """
        raise NotImplementedError()
def create_cache_wrapper(cached_func: CachedFunction) -> Callable[..., Any]:
    """Build the caching wrapper around a CachedFunction's user function.

    This is the plumbing shared by st.memo and st.singleton. The returned
    wrapper looks up the result for the given arguments in the function's
    cache and only calls the user function on a cache miss. It also exposes
    a `clear()` attribute that empties the function's cache.
    """
    func = cached_func.func
    function_key = _make_function_key(cached_func.cache_type, func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        """This function wrapper will only call the underlying function in
        the case of a cache miss.
        """
        # The cache object is fetched on every call rather than once up
        # front, because caches can be invalidated at any time.
        cache = cached_func.get_function_cache(function_key)

        name = func.__qualname__
        message = (
            f"Running `{name}(...)`." if args or kwargs else f"Running `{name}()`."
        )

        def get_or_create_cached_value():
            # The value key is derived from the call's arguments.
            value_key = _make_value_key(cached_func.cache_type, func, *args, **kwargs)

            try:
                cached_value = cache.read_value(value_key)
            except CacheKeyNotFoundError:
                _LOGGER.debug("Cache miss: %s", func)
            else:
                _LOGGER.debug("Cache hit: %s", func)
                return cached_value

            # Cache miss: run the user function while it is registered on
            # the call stack, optionally with st-warnings suppressed.
            call_stack = cached_func.call_stack
            with call_stack.calling_cached_function(func):
                if cached_func.suppress_st_warning:
                    warning_scope = call_stack.suppress_cached_st_function_warning()
                else:
                    warning_scope = contextlib.nullcontext()
                with warning_scope:
                    computed_value = func(*args, **kwargs)

            cache.write_value(value_key, computed_value)
            return computed_value

        if cached_func.show_spinner:
            spinner_scope = st.spinner(message)
        else:
            spinner_scope = contextlib.nullcontext()
        with spinner_scope:
            return get_or_create_cached_value()

    def clear():
        """Clear the wrapped function's associated cache."""
        cached_func.get_function_cache(function_key).clear()

    # Mypy doesn't support declaring attributes of function objects,
    # so we have to suppress a warning here. We can remove this suppression
    # when this issue is resolved: https://github.com/python/mypy/issues/2087
    wrapper.clear = clear  # type: ignore

    return wrapper
class CachedFunctionCallStack(threading.local):
    """Tracks which cached functions are currently executing, so users can be
    warned when they call `st` commands inside a cached function.

    The state is a stack of the cached functions being called plus a counter
    of active "suppress warning" scopes. The class derives from
    `threading.local`, so an instance can be shared across threads.
    """

    def __init__(self, cache_type: CacheType):
        self._cached_func_stack: List[types.FunctionType] = []
        # Number of currently active suppression scopes.
        self._suppress_st_function_warning = 0
        self._cache_type = cache_type

    def __repr__(self) -> str:
        return util.repr_(self)

    @contextlib.contextmanager
    def calling_cached_function(self, func: types.FunctionType) -> Iterator[None]:
        """Register `func` as the innermost running cached function for the
        duration of the `with` block.
        """
        active_funcs = self._cached_func_stack
        active_funcs.append(func)
        try:
            yield
        finally:
            active_funcs.pop()

    @contextlib.contextmanager
    def suppress_cached_st_function_warning(self) -> Iterator[None]:
        """Suppress cached-st-function warnings for the duration of the
        `with` block. Scopes may be nested.
        """
        self._suppress_st_function_warning += 1
        try:
            yield
        finally:
            self._suppress_st_function_warning -= 1
            assert self._suppress_st_function_warning >= 0

    def maybe_show_cached_st_function_warning(
        self, dg: "st.delta_generator.DeltaGenerator", st_func_name: str
    ) -> None:
        """If appropriate, warn about calling st.foo inside a cached function.

        A warning is shown only while at least one cached function is
        running and no suppression scope is active.

        Parameters
        ----------
        dg : DeltaGenerator
            The DeltaGenerator to publish the warning to.
        st_func_name : str
            The name of the Streamlit function that was called.
        """
        if not self._cached_func_stack:
            return
        if self._suppress_st_function_warning > 0:
            return

        innermost_func = self._cached_func_stack[-1]
        self._show_cached_st_function_warning(dg, st_func_name, innermost_func)

    def _show_cached_st_function_warning(
        self,
        dg: "st.delta_generator.DeltaGenerator",
        st_func_name: str,
        cached_func: types.FunctionType,
    ) -> None:
        """Publish a CachedStFunctionWarning for `cached_func` to `dg`."""
        # Publishing the warning must not trigger further cached-function
        # warnings, otherwise we would recurse forever.
        with self.suppress_cached_st_function_warning():
            warning = CachedStFunctionWarning(
                self._cache_type, st_func_name, cached_func
            )
            dg.exception(warning)
def _make_value_key(
    cache_type: CacheType, func: types.FunctionType, *args, **kwargs
) -> str:
    """Compute the cache key for one call of a cached function.

    The key is an MD5 hex digest built from the call's arguments. Arguments
    whose name starts with "_" are left out of the hash.

    Raises
    ------
    StreamlitAPIException
        Raised (with a nicely-formatted explanation message) if we encounter
        an un-hashable arg.
    """
    # Pair every positional value with its parameter name (None when the
    # name cannot be determined), then append the keyword arguments.
    # **kwargs ordering is preserved, per PEP 468
    # https://www.python.org/dev/peps/pep-0468/, so the result is
    # deterministic.
    arg_pairs: List[Tuple[Optional[str], Any]] = [
        (_get_positional_arg_name(func, position), value)
        for position, value in enumerate(args)
    ]
    arg_pairs.extend(kwargs.items())

    args_hasher = hashlib.new("md5")
    for arg_name, arg_value in arg_pairs:
        # Underscore-prefixed args are deliberately excluded from hashing.
        excluded = arg_name is not None and arg_name.startswith("_")
        if excluded:
            _LOGGER.debug("Not hashing %s because it starts with _", arg_name)
            continue

        try:
            update_hash(
                (arg_name, arg_value),
                hasher=args_hasher,
                cache_type=cache_type,
            )
        except UnhashableTypeError as exc:
            raise UnhashableParamError(cache_type, func, arg_name, arg_value, exc)

    value_key = args_hasher.hexdigest()
    _LOGGER.debug("Cache key: %s", value_key)
    return value_key
def _make_function_key(cache_type: CacheType, func: types.FunctionType) -> str:
    """Create the unique key for a function's cache.

    The key is an MD5 hex digest over the function's module, qualified name
    and source code, so it changes when the function's source code changes.
    If the source cannot be retrieved, the function's bytecode is hashed
    instead.

    Parameters
    ----------
    cache_type : CacheType
        Forwarded to `update_hash`.
    func : types.FunctionType
        The function to build a key for.

    Returns
    -------
    str
        The hex digest identifying the function's cache.
    """
    func_hasher = hashlib.new("md5")

    # Include the function's __module__ and __qualname__ strings in the hash.
    # This means that two identical functions in different modules
    # will not share a hash; it also means that two identical *nested*
    # functions in the same module will not share a hash.
    update_hash(
        (func.__module__, func.__qualname__),
        hasher=func_hasher,
        cache_type=cache_type,
    )

    # Include the function's source code in its hash. If the source code can't
    # be retrieved, fall back to the function's bytecode instead.
    source_code: Union[str, bytes]
    try:
        source_code = inspect.getsource(func)
    except OSError as e:
        # Logging formats its arguments lazily with the % operator, so the
        # placeholder must be %-style; a "{0}" placeholder would make the
        # handler fail with "not all arguments converted" and drop the message.
        _LOGGER.debug(
            "Failed to retrieve function's source code when building its key; "
            "falling back to bytecode. err=%s",
            e,
        )
        source_code = func.__code__.co_code

    update_hash(
        source_code,
        hasher=func_hasher,
        cache_type=cache_type,
    )

    return func_hasher.hexdigest()
def _get_positional_arg_name(func: types.FunctionType, arg_index: int) -> Optional[str]:
"""Return the name of a function's positional argument.
If arg_index is out of range, or refers to a parameter that is not a
named positional argument (e.g. an *args, **kwargs, or keyword-only param),
return None instead.
"""
if arg_index < 0:
return None
params: List[inspect.Parameter] = list(inspect.signature(func).parameters.values())
if arg_index >= len(params):
return None
if params[arg_index].kind in (
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.POSITIONAL_ONLY,
):
return params[arg_index].name
return None

View File

@@ -0,0 +1,389 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hashing for st.memo and st.singleton."""
import collections
import functools
import hashlib
import inspect
import io
import os
import pickle
import sys
import tempfile
import threading
import unittest.mock
import weakref
from typing import Any, Pattern, Optional, Dict, List
from streamlit import type_util
from streamlit import util
from streamlit.logger import get_logger
from streamlit.uploaded_file_manager import UploadedFile
from .cache_errors import (
CacheType,
UnhashableTypeError,
)
# Module-level logger, named after this module.
_LOGGER = get_logger(__name__)
# If a dataframe has at least this many rows, we consider it large and hash a sample.
_PANDAS_ROWS_LARGE = 100000
# Number of rows sampled from a large dataframe before hashing.
_PANDAS_SAMPLE_SIZE = 10000
# Similar to dataframes, we also sample large numpy arrays.
# Arrays with at least this many elements are considered large.
_NP_SIZE_LARGE = 1000000
# Number of elements sampled from a large numpy array before hashing.
_NP_SAMPLE_SIZE = 100000
# Arbitrary item to denote where we found a cycle in a hashed object.
# This allows us to hash self-referencing lists, dictionaries, etc.
_CYCLE_PLACEHOLDER = b"streamlit-57R34ML17-hesamagicalponyflyingthroughthesky-CYCLE"
def update_hash(val: Any, hasher, cache_type: CacheType) -> None:
    """Feed the hash of `val` into `hasher`.

    This is the main entrypoint to hashing.py. A fresh _CacheFuncHasher is
    created for every call, so nothing is memoized between calls.

    Parameters
    ----------
    val : Any
        The object to hash.
    hasher
        The object whose `update()` method receives the resulting bytes.
    cache_type : CacheType
        Passed to the _CacheFuncHasher.
    """
    _CacheFuncHasher(cache_type).update(hasher, val)
class _HashStack:
"""Stack of what has been hashed, for debug and circular reference detection.
This internally keeps 1 stack per thread.
Internally, this stores the ID of pushed objects rather than the objects
themselves because otherwise the "in" operator inside __contains__ would
fail for objects that don't return a boolean for "==" operator. For
example, arr == 10 where arr is a NumPy array returns another NumPy array.
This causes the "in" to crash since it expects a boolean.
"""
def __init__(self):
self._stack: collections.OrderedDict[int, List[Any]] = collections.OrderedDict()
def __repr__(self) -> str:
return util.repr_(self)
def push(self, val: Any):
self._stack[id(val)] = val
def pop(self):
self._stack.popitem()
def __contains__(self, val: Any):
return id(val) in self._stack
class _HashStacks:
    """Holds one _HashStack per thread, created on first use.

    Stacks are kept in a WeakKeyDictionary keyed by the thread object.
    """

    def __init__(self):
        self._stacks: weakref.WeakKeyDictionary[
            threading.Thread, _HashStack
        ] = weakref.WeakKeyDictionary()

    def __repr__(self) -> str:
        return util.repr_(self)

    @property
    def current(self) -> _HashStack:
        """The _HashStack belonging to the calling thread."""
        this_thread = threading.current_thread()
        try:
            return self._stacks[this_thread]
        except KeyError:
            # First access from this thread: create and remember its stack.
            new_stack = _HashStack()
            self._stacks[this_thread] = new_stack
            return new_stack
# Module-level registry of per-thread hash stacks, used by _CacheFuncHasher
# to detect cycles while hashing.
hash_stacks = _HashStacks()
def _int_to_bytes(i: int) -> bytes:
num_bytes = (i.bit_length() + 8) // 8
return i.to_bytes(num_bytes, "little", signed=True)
def _key(obj: Optional[Any]) -> Any:
"""Return key for memoization."""
if obj is None:
return None
def is_simple(obj):
return (
isinstance(obj, bytes)
or isinstance(obj, bytearray)
or isinstance(obj, str)
or isinstance(obj, float)
or isinstance(obj, int)
or isinstance(obj, bool)
or obj is None
)
if is_simple(obj):
return obj
if isinstance(obj, tuple):
if all(map(is_simple, obj)):
return obj
if isinstance(obj, list):
if all(map(is_simple, obj)):
return ("__l", tuple(obj))
if (
type_util.is_type(obj, "pandas.core.frame.DataFrame")
or type_util.is_type(obj, "numpy.ndarray")
or inspect.isbuiltin(obj)
or inspect.isroutine(obj)
or inspect.iscode(obj)
):
return id(obj)
return NoResult
class _CacheFuncHasher:
    """A hasher that can hash objects with cycles.

    Converts Python objects to bytes (see `_to_bytes`) and feeds them to a
    hashlib-style hasher. Byte representations are memoized per instance,
    and the module-level `hash_stacks` is used to detect self-referencing
    structures.
    """

    def __init__(self, cache_type: CacheType):
        # Memoized byte representations, keyed by (type name, _key(obj)).
        self._hashes: Dict[Any, bytes] = {}

        # The number of the bytes in the hash.
        self.size = 0

        self.cache_type = cache_type

    def __repr__(self) -> str:
        return util.repr_(self)

    def to_bytes(self, obj: Any) -> bytes:
        """Add memoization to _to_bytes and protect against cycles in data structures.

        Returns the object's type name and its `_to_bytes` representation,
        joined by ":". If `obj` is already being hashed further up the
        stack, `_CYCLE_PLACEHOLDER` is returned instead.
        """
        tname = type(obj).__qualname__.encode()
        key = (tname, _key(obj))

        # Memoize if possible.
        # NOTE(review): keys are compared by equality, so e.g. the tuples
        # (1, 2) and (1.0, 2.0) look like they would share a memo entry
        # within one hasher instance - confirm whether that is intended.
        if key[1] is not NoResult:
            if key in self._hashes:
                return self._hashes[key]

        # Break recursive cycles.
        if obj in hash_stacks.current:
            return _CYCLE_PLACEHOLDER

        hash_stacks.current.push(obj)

        try:
            # Hash the input
            b = b"%s:%s" % (tname, self._to_bytes(obj))

            # Hmmm... It's possible that the size calculation is wrong. When we
            # call to_bytes inside _to_bytes things get double-counted.
            self.size += sys.getsizeof(b)

            if key[1] is not NoResult:
                self._hashes[key] = b

        finally:
            # In case an UnhashableTypeError (or other) error is thrown, clean up the
            # stack so we don't get false positives in future hashing calls
            hash_stacks.current.pop()

        return b

    def update(self, hasher, obj: Any) -> None:
        """Update the provided hasher with the hash of an object."""
        b = self.to_bytes(obj)
        hasher.update(b)

    def _to_bytes(self, obj: Any) -> bytes:
        """Hash objects to bytes, including code with dependencies.

        Python's built in `hash` does not produce consistent results across
        runs.

        The branches below are order-dependent: the first matching type
        check wins.

        Raises
        ------
        UnhashableTypeError
            If no branch matches and calling the object's `__reduce__`
            fails.
        """

        if isinstance(obj, unittest.mock.Mock):
            # Mock objects can appear to be infinitely
            # deep, so we don't try to hash them at all.
            return self.to_bytes(id(obj))

        elif isinstance(obj, bytes) or isinstance(obj, bytearray):
            return obj

        elif isinstance(obj, str):
            return obj.encode()

        elif isinstance(obj, float):
            return self.to_bytes(hash(obj))

        elif isinstance(obj, int):
            # bool is a subclass of int, so True and False are handled here.
            return _int_to_bytes(obj)

        elif isinstance(obj, (list, tuple)):
            h = hashlib.new("md5")
            for item in obj:
                self.update(h, item)
            return h.digest()

        elif isinstance(obj, dict):
            # Each (key, value) pair is hashed as a tuple, in iteration order.
            h = hashlib.new("md5")
            for item in obj.items():
                self.update(h, item)
            return h.digest()

        elif obj is None:
            return b"0"

        # NOTE(review): the two bool branches below look unreachable, since
        # the isinstance(obj, int) branch above already matches True/False.
        elif obj is True:
            return b"1"

        elif obj is False:
            return b"0"

        elif type_util.is_type(obj, "pandas.core.frame.DataFrame") or type_util.is_type(
            obj, "pandas.core.series.Series"
        ):
            import pandas as pd

            # Large objects are hashed from a fixed-seed sample of their rows.
            if len(obj) >= _PANDAS_ROWS_LARGE:
                obj = obj.sample(n=_PANDAS_SAMPLE_SIZE, random_state=0)
            try:
                return b"%s" % pd.util.hash_pandas_object(obj).sum()
            except TypeError:
                # Use pickle if pandas cannot hash the object for example if
                # it contains unhashable objects.
                return b"%s" % pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)

        elif type_util.is_type(obj, "numpy.ndarray"):
            h = hashlib.new("md5")
            self.update(h, obj.shape)

            # Large arrays are hashed from a fixed-seed sample of their elements.
            if obj.size >= _NP_SIZE_LARGE:
                import numpy as np

                state = np.random.RandomState(0)
                obj = state.choice(obj.flat, size=_NP_SAMPLE_SIZE)

            self.update(h, obj.tobytes())
            return h.digest()

        elif inspect.isbuiltin(obj):
            return bytes(obj.__name__.encode())

        elif type_util.is_type(obj, "builtins.mappingproxy") or type_util.is_type(
            obj, "builtins.dict_items"
        ):
            return self.to_bytes(dict(obj))

        elif type_util.is_type(obj, "builtins.getset_descriptor"):
            return bytes(obj.__qualname__.encode())

        elif isinstance(obj, UploadedFile):
            # UploadedFile is a BytesIO (thus IOBase) but has a name.
            # It does not have a timestamp so this must come before
            # temporary files
            h = hashlib.new("md5")
            self.update(h, obj.name)
            self.update(h, obj.tell())
            self.update(h, obj.getvalue())
            return h.digest()

        elif hasattr(obj, "name") and (
            isinstance(obj, io.IOBase)
            # Handle temporary files used during testing
            or isinstance(obj, tempfile._TemporaryFileWrapper)
        ):
            # Hash files as name + last modification date + offset.
            # NB: we're using hasattr("name") to differentiate between
            # on-disk and in-memory StringIO/BytesIO file representations.
            # That means that this condition must come *before* the next
            # condition, which just checks for StringIO/BytesIO.
            h = hashlib.new("md5")
            obj_name = getattr(obj, "name", "wonthappen")  # Just to appease MyPy.
            self.update(h, obj_name)
            self.update(h, os.path.getmtime(obj_name))
            self.update(h, obj.tell())
            return h.digest()

        elif isinstance(obj, Pattern):
            return self.to_bytes([obj.pattern, obj.flags])

        elif isinstance(obj, io.StringIO) or isinstance(obj, io.BytesIO):
            # Hash in-memory StringIO/BytesIO by their full contents
            # and seek position.
            h = hashlib.new("md5")
            self.update(h, obj.tell())
            self.update(h, obj.getvalue())
            return h.digest()

        elif type_util.is_type(obj, "numpy.ufunc"):
            # For numpy.remainder, this returns remainder.
            return bytes(obj.__name__.encode())

        elif inspect.ismodule(obj):
            # TODO: Figure out how to best show this kind of warning to the
            # user. In the meantime, show nothing. This scenario is too common,
            # so the current warning is quite annoying...
            # st.warning(('Streamlit does not support hashing modules. '
            #             'We did not hash `%s`.') % obj.__name__)
            # TODO: Hash more than just the name for internal modules.
            return self.to_bytes(obj.__name__)

        elif inspect.isclass(obj):
            # TODO: Figure out how to best show this kind of warning to the
            # user. In the meantime, show nothing. This scenario is too common,
            # (e.g. in every "except" statement) so the current warning is
            # quite annoying...
            # st.warning(('Streamlit does not support hashing classes. '
            #             'We did not hash `%s`.') % obj.__name__)
            # TODO: Hash more than just the name of classes.
            return self.to_bytes(obj.__name__)

        elif isinstance(obj, functools.partial):
            # The return value of functools.partial is not a plain function:
            # it's a callable object that remembers the original function plus
            # the values you passed into it. So here we need to special-case it.
            h = hashlib.new("md5")
            self.update(h, obj.args)
            self.update(h, obj.func)
            self.update(h, obj.keywords)
            return h.digest()

        else:
            # As a last resort, hash the output of the object's __reduce__ method
            h = hashlib.new("md5")
            try:
                reduce_data = obj.__reduce__()
            except BaseException as e:
                raise UnhashableTypeError() from e

            for item in reduce_data:
                self.update(h, item)
            return h.digest()
class NoResult:
    """Placeholder class for return values when None is meaningful.

    The class object itself (not an instance) is used as the sentinel:
    `_key` returns it when an object has no memoization key, and callers
    compare against it with `is`.
    """

    pass

View File

@@ -0,0 +1,495 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""@st.memo: pickle-based caching"""
import os
import pickle
import shutil
import threading
import time
import types
from typing import Optional, Any, Dict, cast, List, Callable, TypeVar, overload
from typing import Union
import math
from cachetools import TTLCache
from streamlit import util
from streamlit.errors import StreamlitAPIException
from streamlit.file_util import (
streamlit_read,
streamlit_write,
get_streamlit_file_path,
)
from streamlit.logger import get_logger
from streamlit.stats import CacheStatsProvider, CacheStat
from .cache_errors import (
CacheError,
CacheKeyNotFoundError,
CacheType,
)
from .cache_utils import (
Cache,
create_cache_wrapper,
CachedFunctionCallStack,
CachedFunction,
)
# Module-level logger, named after this module.
_LOGGER = get_logger(__name__)
# The timer function we use with TTLCache. This is the default timer func, but
# is exposed here as a constant so that it can be patched in unit tests.
_TTLCACHE_TIMER = time.monotonic
# Streamlit directory where persisted memoized items live.
# (This is the same directory that @st.cache persisted items live. But memoized
# items have a different extension, so they don't overlap.)
_CACHE_DIR_NAME = "cache"
# Call stack shared by all memoized functions (see MemoizedFunction.call_stack).
MEMO_CALL_STACK = CachedFunctionCallStack(CacheType.MEMO)
class MemoCaches(CacheStatsProvider):
    """Registry of MemoCache instances, one per function key.

    Access to the registry is guarded by a lock.
    """

    def __init__(self):
        self._caches_lock = threading.Lock()
        self._function_caches: Dict[str, "MemoCache"] = {}

    def get_cache(
        self,
        key: str,
        persist: Optional[str],
        max_entries: Optional[Union[int, float]],
        ttl: Optional[Union[int, float]],
        display_name: str,
    ) -> "MemoCache":
        """Return the MemoCache registered under `key`.

        If no cache exists for the key, or the existing one was created with
        a different `persist`, `max_entries` or `ttl`, a new cache is created
        and registered in its place.
        """
        # None means "no limit" for both settings.
        max_entries = math.inf if max_entries is None else max_entries
        ttl = math.inf if ttl is None else ttl

        with self._caches_lock:
            existing = self._function_caches.get(key)
            if (
                existing is not None
                and existing.ttl == ttl
                and existing.max_entries == max_entries
                and existing.persist == persist
            ):
                return existing

            # No usable cache for this key: build one and register it.
            _LOGGER.debug(
                "Creating new MemoCache (key=%s, persist=%s, max_entries=%s, ttl=%s)",
                key,
                persist,
                max_entries,
                ttl,
            )
            fresh = MemoCache(
                key=key,
                persist=persist,
                max_entries=max_entries,
                ttl=ttl,
                display_name=display_name,
            )
            self._function_caches[key] = fresh
            return fresh

    def clear_all(self) -> None:
        """Clear all in-memory and on-disk caches."""
        with self._caches_lock:
            self._function_caches = {}

            # TODO: Only delete disk cache for functions related to the user's
            # current script.
            cache_path = get_cache_path()
            if os.path.isdir(cache_path):
                shutil.rmtree(cache_path)

    def get_stats(self) -> List[CacheStat]:
        """Collect the stats reported by every registered cache."""
        with self._caches_lock:
            # Work on a shallow copy so the global lock is not held while
            # the individual caches gather their stats.
            snapshot = self._function_caches.copy()

        stats: List[CacheStat] = [
            stat for cache in snapshot.values() for stat in cache.get_stats()
        ]
        return stats
# Singleton MemoCaches instance, shared by every memoized function in the
# process (see MemoizedFunction.get_function_cache).
_memo_caches = MemoCaches()
def get_memo_stats_provider() -> CacheStatsProvider:
    """Return the StatsProvider for all memoized functions.

    This is the module-level `_memo_caches` instance.
    """
    return _memo_caches
class MemoizedFunction(CachedFunction):
    """CachedFunction implementation for @st.memo.

    Besides the common options it stores the memo-specific `persist`,
    `max_entries` and `ttl` settings, which are handed to the cache registry
    whenever the function's cache is looked up.
    """

    def __init__(
        self,
        func: types.FunctionType,
        show_spinner: bool,
        suppress_st_warning: bool,
        persist: Optional[str],
        max_entries: Optional[int],
        ttl: Optional[float],
    ):
        super().__init__(func, show_spinner, suppress_st_warning)
        self.persist, self.max_entries, self.ttl = persist, max_entries, ttl

    @property
    def cache_type(self) -> CacheType:
        return CacheType.MEMO

    @property
    def call_stack(self) -> CachedFunctionCallStack:
        return MEMO_CALL_STACK

    @property
    def display_name(self) -> str:
        """A human-readable name for the cached function"""
        return f"{self.func.__module__}.{self.func.__qualname__}"

    def get_function_cache(self, function_key: str) -> Cache:
        """Fetch this function's cache from the module-level registry."""
        cache_settings = dict(
            key=function_key,
            persist=self.persist,
            max_entries=self.max_entries,
            ttl=self.ttl,
            display_name=self.display_name,
        )
        return _memo_caches.get_cache(**cache_settings)
class MemoAPI:
"""Implements the public st.memo API: the @st.memo decorator, and
st.memo.clear().
"""
# Type-annotate the decorator function.
# (See https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories)
F = TypeVar("F", bound=Callable[..., Any])
# Bare decorator usage
@overload
@staticmethod
def __call__(func: F) -> F:
...
# Decorator with arguments
@overload
@staticmethod
def __call__(
*,
persist: Optional[str] = None,
show_spinner: bool = True,
suppress_st_warning: bool = False,
max_entries: Optional[int] = None,
ttl: Optional[float] = None,
) -> Callable[[F], F]:
...
@staticmethod
def __call__(
func: Optional[F] = None,
*,
persist: Optional[str] = None,
show_spinner: bool = True,
suppress_st_warning: bool = False,
max_entries: Optional[int] = None,
ttl: Optional[float] = None,
):
"""Function decorator to memoize function executions.
Memoized data is stored in "pickled" form, which means that the return
value of a memoized function must be pickleable.
Each caller of a memoized function gets its own copy of the cached data.
You can clear a memoized function's cache with f.clear().
Parameters
----------
func : callable
The function to memoize. Streamlit hashes the function's source code.
persist : str or None
Optional location to persist cached data to. Currently, the only
valid value is "disk", which will persist to the local disk.
show_spinner : boolean
Enable the spinner. Default is True to show a spinner when there is
a cache miss.
suppress_st_warning : boolean
Suppress warnings about calling Streamlit functions from within
the cached function.
max_entries : int or None
The maximum number of entries to keep in the cache, or None
for an unbounded cache. (When a new entry is added to a full cache,
the oldest cached entry will be removed.) The default is None.
ttl : float or None
The maximum number of seconds to keep an entry in the cache, or
None if cache entries should not expire. The default is None.
Example
-------
>>> @st.experimental_memo
... def fetch_and_clean_data(url):
... # Fetch data from URL here, and then clean it up.
... return data
...
>>> d1 = fetch_and_clean_data(DATA_URL_1)
>>> # Actually executes the function, since this is the first time it was
>>> # encountered.
>>>
>>> d2 = fetch_and_clean_data(DATA_URL_1)
>>> # Does not execute the function. Instead, returns its previously computed
>>> # value. This means that now the data in d1 is the same as in d2.
>>>
>>> d3 = fetch_and_clean_data(DATA_URL_2)
>>> # This is a different URL, so the function executes.
To set the ``persist`` parameter, use this command as follows:
>>> @st.experimental_memo(persist="disk")
... def fetch_and_clean_data(url):
... # Fetch data from URL here, and then clean it up.
... return data
By default, all parameters to a memoized function must be hashable.
Any parameter whose name begins with ``_`` will not be hashed. You can use
this as an "escape hatch" for parameters that are not hashable:
>>> @st.experimental_memo
... def fetch_and_clean_data(_db_connection, num_rows):
... # Fetch data from _db_connection here, and then clean it up.
... return data
...
>>> connection = make_database_connection()
>>> d1 = fetch_and_clean_data(connection, num_rows=10)
>>> # Actually executes the function, since this is the first time it was
>>> # encountered.
>>>
>>> another_connection = make_database_connection()
>>> d2 = fetch_and_clean_data(another_connection, num_rows=10)
>>> # Does not execute the function. Instead, returns its previously computed
>>> # value - even though the _database_connection parameter was different
>>> # in both calls.
A memoized function's cache can be procedurally cleared:
>>> @st.experimental_memo
... def fetch_and_clean_data(_db_connection, num_rows):
... # Fetch data from _db_connection here, and then clean it up.
... return data
...
>>> fetch_and_clean_data.clear()
>>> # Clear all cached entries for this function.
"""
if persist not in (None, "disk"):
# We'll eventually have more persist options.
raise StreamlitAPIException(
f"Unsupported persist option '{persist}'. Valid values are 'disk' or None."
)
# Support passing the params via function decorator, e.g.
# @st.memo(persist=True, show_spinner=False)
if func is None:
return lambda f: create_cache_wrapper(
MemoizedFunction(
func=f,
persist=persist,
show_spinner=show_spinner,
suppress_st_warning=suppress_st_warning,
max_entries=max_entries,
ttl=ttl,
)
)
return create_cache_wrapper(
MemoizedFunction(
func=cast(types.FunctionType, func),
persist=persist,
show_spinner=show_spinner,
suppress_st_warning=suppress_st_warning,
max_entries=max_entries,
ttl=ttl,
)
)
@staticmethod
def clear() -> None:
"""Clear all in-memory and on-disk memo caches."""
_memo_caches.clear_all()
class MemoCache(Cache):
    """Manages cached values for a single st.memo-ized function.

    Values are held in pickled form in an in-memory ``TTLCache``. When
    ``persist == "disk"`` each entry is additionally written to its own file
    (see ``_get_file_path``), and memory misses fall back to that file.
    """

    def __init__(
        self,
        key: str,
        persist: Optional[str],
        max_entries: float,
        ttl: float,
        display_name: str,
    ):
        """Create an empty cache.

        Parameters
        ----------
        key : str
            Identifier of the cached function; used as the file-name prefix
            for on-disk entries.
        persist : str or None
            ``"disk"`` to mirror entries to disk; any other value keeps the
            cache memory-only.
        max_entries : float
            Passed to ``TTLCache`` as ``maxsize``.
        ttl : float
            Passed to ``TTLCache`` as ``ttl``.
        display_name : str
            Name reported as ``cache_name`` in ``get_stats``.
        """
        self.key = key
        self.display_name = display_name
        self.persist = persist
        self._mem_cache = TTLCache(maxsize=max_entries, ttl=ttl, timer=_TTLCACHE_TIMER)
        # Guards every access to ``_mem_cache``.
        self._mem_cache_lock = threading.Lock()

    @property
    def max_entries(self) -> float:
        """The ``maxsize`` of the underlying in-memory cache."""
        return cast(float, self._mem_cache.maxsize)

    @property
    def ttl(self) -> float:
        """The ``ttl`` of the underlying in-memory cache."""
        return cast(float, self._mem_cache.ttl)

    def get_stats(self) -> List[CacheStat]:
        """Return one ``CacheStat`` per in-memory entry.

        ``byte_length`` is the length of the entry's pickled bytes.
        """
        with self._mem_cache_lock:
            return [
                CacheStat(
                    category_name="st_memo",
                    cache_name=self.display_name,
                    byte_length=len(pickled_value),
                )
                for pickled_value in self._mem_cache.values()
            ]

    def read_value(self, key: str) -> Any:
        """Read a value from the cache. Raise `CacheKeyNotFoundError` if the
        value doesn't exist, and `CacheError` if the value exists but can't
        be unpickled.
        """
        try:
            pickled_value = self._read_from_mem_cache(key)
        except CacheKeyNotFoundError:
            if self.persist != "disk":
                # Bare raise: re-raise the original miss unchanged.
                raise
            # Memory miss: fall back to disk, then repopulate the memory cache
            # so the next read is served from memory.
            pickled_value = self._read_from_disk_cache(key)
            self._write_to_mem_cache(key, pickled_value)

        # NOTE(review): pickle.loads can execute arbitrary code for a crafted
        # payload. The bytes come from this object's memory cache or from the
        # local disk cache file, so that cache directory must be trusted.
        try:
            return pickle.loads(pickled_value)
        except pickle.UnpicklingError as exc:
            raise CacheError(f"Failed to unpickle {key}") from exc

    def write_value(self, key: str, value: Any) -> None:
        """Write a value to the cache. It must be pickleable.

        Raises `CacheError` if pickling fails with `pickle.PicklingError`, or
        if the value cannot be written to disk when ``persist == "disk"``.
        """
        try:
            pickled_value = pickle.dumps(value)
        except pickle.PicklingError as exc:
            raise CacheError(f"Failed to pickle {key}") from exc

        self._write_to_mem_cache(key, pickled_value)
        if self.persist == "disk":
            self._write_to_disk_cache(key, pickled_value)

    def clear(self) -> None:
        """Remove every in-memory entry, and the disk file of each of them."""
        with self._mem_cache_lock:
            # We keep a lock for the entirety of the clear operation to avoid
            # disk cache race conditions.
            for key in self._mem_cache.keys():
                self._remove_from_disk_cache(key)

            self._mem_cache.clear()

    def _read_from_mem_cache(self, key: str) -> bytes:
        """Return the pickled bytes stored in memory for ``key``.

        Raises `CacheKeyNotFoundError` if the key is not in the memory cache.
        """
        with self._mem_cache_lock:
            # Look the key up directly instead of testing ``in`` first: a
            # TTLCache entry can expire between a membership test and the
            # lookup, which would surface as a raw KeyError instead of a miss.
            try:
                entry = bytes(self._mem_cache[key])
            except KeyError:
                _LOGGER.debug("Memory cache MISS: %s", key)
                raise CacheKeyNotFoundError("Key not found in mem cache") from None

            _LOGGER.debug("Memory cache HIT: %s", key)
            return entry

    def _read_from_disk_cache(self, key: str) -> bytes:
        """Return the pickled bytes stored on disk for ``key``.

        Raises `CacheKeyNotFoundError` if the cache file does not exist and
        `CacheError` if reading it fails for any other reason.
        """
        path = self._get_file_path(key)
        try:
            with streamlit_read(path, binary=True) as cache_file:
                value = cache_file.read()
                _LOGGER.debug("Disk cache HIT: %s", key)
                return bytes(value)
        except FileNotFoundError:
            raise CacheKeyNotFoundError("Key not found in disk cache")
        except Exception as e:
            # Exception rather than BaseException, so KeyboardInterrupt and
            # SystemExit are not converted into cache errors.
            _LOGGER.error(e)
            raise CacheError("Unable to read from cache") from e

    def _write_to_mem_cache(self, key: str, pickled_value: bytes) -> None:
        """Store ``pickled_value`` in the memory cache under ``key``."""
        with self._mem_cache_lock:
            self._mem_cache[key] = pickled_value

    def _write_to_disk_cache(self, key: str, pickled_value: bytes) -> None:
        """Write ``pickled_value`` to the disk cache file for ``key``.

        Raises `CacheError` if the write fails with ``util.Error``; the
        partially written file is removed first.
        """
        path = self._get_file_path(key)
        try:
            with streamlit_write(path, binary=True) as output:
                output.write(pickled_value)
        except util.Error as e:
            _LOGGER.debug(e)
            # Clean up file so we don't leave zero byte files.
            try:
                os.remove(path)
            except OSError:
                # Best effort: FileNotFoundError and IOError are both OSError.
                pass
            raise CacheError("Unable to write to cache") from e

    def _remove_from_disk_cache(self, key: str) -> None:
        """Delete a cache file from disk. If the file does not exist on disk,
        return silently. If another error occurs, log it instead of raising.
        """
        path = self._get_file_path(key)
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except Exception:
            # Logger.exception() attaches the active exception itself. Passing
            # it as an extra positional argument with no placeholder in the
            # message makes the logging module fail while formatting.
            _LOGGER.exception("Unable to remove a file from the disk cache")

    def _get_file_path(self, value_key: str) -> str:
        """Return the path of the disk cache file for the given value."""
        return get_streamlit_file_path(_CACHE_DIR_NAME, f"{self.key}-{value_key}.memo")
def get_cache_path() -> str:
    """Return the Streamlit file path of the ``_CACHE_DIR_NAME`` directory,
    the same directory name ``MemoCache._get_file_path`` builds its paths from.
    """
    cache_dir = get_streamlit_file_path(_CACHE_DIR_NAME)
    return cache_dir

View File

@@ -0,0 +1,289 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""@st.singleton implementation"""
import threading
import types
from typing import Optional, Any, Dict, List, TypeVar, Callable, overload, cast
from pympler import asizeof
from streamlit.logger import get_logger
from streamlit.stats import CacheStatsProvider, CacheStat
from .cache_errors import CacheKeyNotFoundError, CacheType
from .cache_utils import (
Cache,
create_cache_wrapper,
CachedFunctionCallStack,
CachedFunction,
)
# Logger for this module, obtained under the module's own name.
_LOGGER = get_logger(__name__)
# Call stack for the SINGLETON cache type; SingletonFunction.call_stack
# returns this object.
SINGLETON_CALL_STACK = CachedFunctionCallStack(CacheType.SINGLETON)
class SingletonCaches(CacheStatsProvider):
    """Registry of all SingletonCache instances, keyed by function key."""

    def __init__(self):
        # Maps a function key to that function's SingletonCache.
        self._function_caches: Dict[str, "SingletonCache"] = {}
        # Guards reads and writes of ``_function_caches``.
        self._caches_lock = threading.Lock()

    def get_cache(self, key: str, display_name: str) -> "SingletonCache":
        """Return the cache registered under ``key``.

        When no cache exists for the key yet, a new ``SingletonCache`` is
        created with the given ``display_name``, registered, and returned.
        """
        with self._caches_lock:
            found = self._function_caches.get(key)
            if found is None:
                # First request for this key: build and register the cache.
                _LOGGER.debug("Creating new SingletonCache (key=%s)", key)
                found = SingletonCache(key=key, display_name=display_name)
                self._function_caches[key] = found
            return found

    def clear_all(self) -> None:
        """Clear all singleton caches."""
        with self._caches_lock:
            # Rebind to a fresh dict rather than emptying the old one.
            self._function_caches = {}

    def get_stats(self) -> List[CacheStat]:
        """Collect the stats of every registered cache into one flat list."""
        with self._caches_lock:
            # Snapshot the registered caches so the registry lock is not held
            # while the individual caches compute their stats.
            snapshot = list(self._function_caches.values())

        return [stat for cache in snapshot for stat in cache.get_stats()]
# Module-level SingletonCaches instance: the one registry used by
# SingletonFunction.get_function_cache, SingletonAPI.clear() and
# get_singleton_stats_provider().
_singleton_caches = SingletonCaches()
def get_singleton_stats_provider() -> CacheStatsProvider:
    """Return the stats provider covering all singleton functions.

    This is the module-level ``_singleton_caches`` registry itself.
    """
    provider = _singleton_caches
    return provider
class SingletonFunction(CachedFunction):
    """CachedFunction implementation backing the @st.singleton decorator."""

    @property
    def cache_type(self) -> CacheType:
        """The kind of cache backing this function: ``CacheType.SINGLETON``."""
        return CacheType.SINGLETON

    @property
    def call_stack(self) -> CachedFunctionCallStack:
        """The call stack used for singleton functions."""
        return SINGLETON_CALL_STACK

    @property
    def display_name(self) -> str:
        """Human-readable name of the cached function: ``<module>.<qualname>``."""
        wrapped = self.func
        return f"{wrapped.__module__}.{wrapped.__qualname__}"

    def get_function_cache(self, function_key: str) -> Cache:
        """Fetch, creating it if needed, the cache registered under
        ``function_key`` in the module-level singleton registry.
        """
        name = self.display_name
        return _singleton_caches.get_cache(key=function_key, display_name=name)
class SingletonAPI:
    """Implements the public st.singleton API: the @st.singleton decorator,
    and st.singleton.clear().
    """

    # Type-annotate the decorator function.
    # (See https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories)
    F = TypeVar("F", bound=Callable[..., Any])

    # Bare decorator usage
    @overload
    @staticmethod
    def __call__(func: F) -> F:
        ...

    # Decorator with arguments
    @overload
    @staticmethod
    def __call__(
        *,
        show_spinner: bool = True,
        suppress_st_warning: bool = False,
    ) -> Callable[[F], F]:
        ...

    @staticmethod
    def __call__(
        func: Optional[F] = None,
        *,
        show_spinner: bool = True,
        suppress_st_warning: bool = False,
    ):
        """Function decorator to store singleton objects.

        Each singleton object is shared across all users connected to the app.
        Singleton objects *must* be thread-safe, because they can be accessed from
        multiple threads concurrently.

        (If thread-safety is an issue, consider using ``st.session_state`` to
        store per-session singleton objects instead.)

        You can clear a singleton function's cache with f.clear().

        Parameters
        ----------
        func : callable
            The function that creates the singleton. Streamlit hashes the
            function's source code.

        show_spinner : boolean
            Enable the spinner. Default is True to show a spinner when there is
            a "cache miss" and the singleton is being created.

        suppress_st_warning : boolean
            Suppress warnings about calling Streamlit functions from within
            the singleton function.

        Example
        -------
        >>> @st.experimental_singleton
        ... def get_database_session(url):
        ...     # Create a database session object that points to the URL.
        ...     return session
        ...
        >>> s1 = get_database_session(SESSION_URL_1)
        >>> # Actually executes the function, since this is the first time it was
        >>> # encountered.
        >>>
        >>> s2 = get_database_session(SESSION_URL_1)
        >>> # Does not execute the function. Instead, returns its previously computed
        >>> # value. This means that now the connection object in s1 is the same as in s2.
        >>>
        >>> s3 = get_database_session(SESSION_URL_2)
        >>> # This is a different URL, so the function executes.

        By default, all parameters to a singleton function must be hashable.
        Any parameter whose name begins with ``_`` will not be hashed. You can use
        this as an "escape hatch" for parameters that are not hashable:

        >>> @st.experimental_singleton
        ... def get_database_session(_sessionmaker, url):
        ...     # Create a database connection object that points to the URL.
        ...     return connection
        ...
        >>> s1 = get_database_session(create_sessionmaker(), DATA_URL_1)
        >>> # Actually executes the function, since this is the first time it was
        >>> # encountered.
        >>>
        >>> s2 = get_database_session(create_sessionmaker(), DATA_URL_1)
        >>> # Does not execute the function. Instead, returns its previously computed
        >>> # value - even though the _sessionmaker parameter was different
        >>> # in both calls.

        A singleton function's cache can be procedurally cleared:

        >>> @st.experimental_singleton
        ... def get_database_session(_sessionmaker, url):
        ...     # Create a database connection object that points to the URL.
        ...     return connection
        ...
        >>> get_database_session.clear()
        >>> # Clear all cached entries for this function.
        """

        def wrap(f):
            # Single place where the cached wrapper is built, shared by the
            # bare-decorator and decorator-with-arguments code paths.
            return create_cache_wrapper(
                SingletonFunction(
                    func=f,
                    show_spinner=show_spinner,
                    suppress_st_warning=suppress_st_warning,
                )
            )

        # Support passing the params via function decorator, e.g.
        # @st.singleton(show_spinner=False)
        if func is None:
            return wrap

        return wrap(cast(types.FunctionType, func))

    @staticmethod
    def clear() -> None:
        """Clear all singleton caches."""
        _singleton_caches.clear_all()
class SingletonCache(Cache):
    """Holds the cached objects of one st.singleton function.

    Values are stored as-is in a plain dict; every access to that dict is
    made while holding ``_mem_cache_lock``.
    """

    def __init__(self, key: str, display_name: str):
        self.key = key
        self.display_name = display_name
        self._mem_cache_lock = threading.Lock()
        self._mem_cache: Dict[str, Any] = {}

    def read_value(self, key: str) -> Any:
        """Return the object stored under ``key``.

        Raises `CacheKeyNotFoundError` when nothing is stored for the key.
        """
        with self._mem_cache_lock:
            if key not in self._mem_cache:
                raise CacheKeyNotFoundError()
            return self._mem_cache[key]

    def write_value(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key``, replacing any previous object."""
        with self._mem_cache_lock:
            self._mem_cache[key] = value

    def clear(self) -> None:
        """Drop every cached object."""
        with self._mem_cache_lock:
            self._mem_cache.clear()

    def get_stats(self) -> List[CacheStat]:
        """Return one ``CacheStat`` per cached object, sized with ``asizeof``."""
        # Sizing objects can be slow, so only a shallow copy is taken under
        # the lock and the measuring happens after it is released.
        with self._mem_cache_lock:
            snapshot = self._mem_cache.copy()

        return [
            CacheStat(
                category_name="st_singleton",
                cache_name=self.display_name,
                byte_length=asizeof.asizeof(cached_object),
            )
            for cached_object in snapshot.values()
        ]