# Copyright 2018-2022 Streamlit Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """A library of caching utilities.""" import contextlib import functools import hashlib import inspect import math import os import pickle import shutil import threading import time import types from collections import namedtuple from typing import Dict, Optional, List, Iterator, Any, Callable import attr from cachetools import TTLCache from pympler.asizeof import asizeof from streamlit import config from streamlit import file_util from streamlit import util from streamlit.error_util import handle_uncaught_app_exception from streamlit.errors import StreamlitAPIWarning from streamlit.legacy_caching.hashing import update_hash, HashFuncsDict from streamlit.legacy_caching.hashing import HashReason from streamlit.logger import get_logger import streamlit as st from streamlit.stats import CacheStat, CacheStatsProvider _LOGGER = get_logger(__name__) # The timer function we use with TTLCache. This is the default timer func, but # is exposed here as a constant so that it can be patched in unit tests. _TTLCACHE_TIMER = time.monotonic _CacheEntry = namedtuple("_CacheEntry", ["value", "hash"]) _DiskCacheEntry = namedtuple("_DiskCacheEntry", ["value"]) @attr.s(auto_attribs=True, slots=True) class MemCache: cache: TTLCache display_name: str class _MemCaches(CacheStatsProvider): """Manages all in-memory st.cache caches""" def __init__(self): # Contains a cache object for each st.cache'd function self._lock = threading.RLock() self._function_caches: Dict[str, MemCache] = {} def __repr__(self) -> str: return util.repr_(self) def get_cache( self, key: str, max_entries: Optional[float], ttl: Optional[float], display_name: str = "", ) -> MemCache: """Return the mem cache for the given key. If it doesn't exist, create a new one with the given params. """ if max_entries is None: max_entries = math.inf if ttl is None: ttl = math.inf if not isinstance(max_entries, (int, float)): raise RuntimeError("max_entries must be an int") if not isinstance(ttl, (int, float)): raise RuntimeError("ttl must be a float") # Get the existing cache, if it exists, and validate that its params # haven't changed. with self._lock: mem_cache = self._function_caches.get(key) if ( mem_cache is not None and mem_cache.cache.ttl == ttl and mem_cache.cache.maxsize == max_entries ): return mem_cache # Create a new cache object and put it in our dict _LOGGER.debug( "Creating new mem_cache (key=%s, max_entries=%s, ttl=%s)", key, max_entries, ttl, ) ttl_cache = TTLCache(maxsize=max_entries, ttl=ttl, timer=_TTLCACHE_TIMER) mem_cache = MemCache(ttl_cache, display_name) self._function_caches[key] = mem_cache return mem_cache def clear(self) -> None: """Clear all caches""" with self._lock: self._function_caches = {} def get_stats(self) -> List[CacheStat]: with self._lock: # Shallow-clone our caches. We don't want to hold the global # lock during stats-gathering. function_caches = self._function_caches.copy() stats = [ CacheStat("st_cache", cache.display_name, asizeof(c)) for cache in function_caches.values() for c in cache.cache ] return stats # Our singleton _MemCaches instance _mem_caches = _MemCaches() # A thread-local counter that's incremented when we enter @st.cache # and decremented when we exit. class ThreadLocalCacheInfo(threading.local): def __init__(self): self.cached_func_stack: List[types.FunctionType] = [] self.suppress_st_function_warning = 0 def __repr__(self) -> str: return util.repr_(self) _cache_info = ThreadLocalCacheInfo() @contextlib.contextmanager def _calling_cached_function(func: types.FunctionType) -> Iterator[None]: _cache_info.cached_func_stack.append(func) try: yield finally: _cache_info.cached_func_stack.pop() @contextlib.contextmanager def suppress_cached_st_function_warning() -> Iterator[None]: _cache_info.suppress_st_function_warning += 1 try: yield finally: _cache_info.suppress_st_function_warning -= 1 assert _cache_info.suppress_st_function_warning >= 0 def _show_cached_st_function_warning( dg: "st.delta_generator.DeltaGenerator", st_func_name: str, cached_func: types.FunctionType, ) -> None: # Avoid infinite recursion by suppressing additional cached # function warnings from within the cached function warning. with suppress_cached_st_function_warning(): e = CachedStFunctionWarning(st_func_name, cached_func) dg.exception(e) def maybe_show_cached_st_function_warning( dg: "st.delta_generator.DeltaGenerator", st_func_name: str ) -> None: """If appropriate, warn about calling st.foo inside @cache. DeltaGenerator's @_with_element and @_widget wrappers use this to warn the user when they're calling st.foo() from within a function that is wrapped in @st.cache. Parameters ---------- dg : DeltaGenerator The DeltaGenerator to publish the warning to. st_func_name : str The name of the Streamlit function that was called. """ if ( len(_cache_info.cached_func_stack) > 0 and _cache_info.suppress_st_function_warning <= 0 ): cached_func = _cache_info.cached_func_stack[-1] _show_cached_st_function_warning(dg, st_func_name, cached_func) def _read_from_mem_cache( mem_cache: MemCache, key: str, allow_output_mutation: bool, func_or_code: Callable[..., Any], hash_funcs: Optional[HashFuncsDict], ) -> Any: cache = mem_cache.cache if key in cache: entry = cache[key] if not allow_output_mutation: computed_output_hash = _get_output_hash( entry.value, func_or_code, hash_funcs ) stored_output_hash = entry.hash if computed_output_hash != stored_output_hash: _LOGGER.debug("Cached object was mutated: %s", key) raise CachedObjectMutationError(entry.value, func_or_code) _LOGGER.debug("Memory cache HIT: %s", type(entry.value)) return entry.value else: _LOGGER.debug("Memory cache MISS: %s", key) raise CacheKeyNotFoundError("Key not found in mem cache") def _write_to_mem_cache( mem_cache: MemCache, key: str, value: Any, allow_output_mutation: bool, func_or_code: Callable[..., Any], hash_funcs: Optional[HashFuncsDict], ) -> None: if allow_output_mutation: hash = None else: hash = _get_output_hash(value, func_or_code, hash_funcs) mem_cache.display_name = f"{func_or_code.__module__}.{func_or_code.__qualname__}" mem_cache.cache[key] = _CacheEntry(value=value, hash=hash) def _get_output_hash( value: Any, func_or_code: Callable[..., Any], hash_funcs: Optional[HashFuncsDict] ) -> bytes: hasher = hashlib.new("md5") update_hash( value, hasher=hasher, hash_funcs=hash_funcs, hash_reason=HashReason.CACHING_FUNC_OUTPUT, hash_source=func_or_code, ) return hasher.digest() def _read_from_disk_cache(key: str) -> Any: path = file_util.get_streamlit_file_path("cache", "%s.pickle" % key) try: with file_util.streamlit_read(path, binary=True) as input: entry = pickle.load(input) value = entry.value _LOGGER.debug("Disk cache HIT: %s", type(value)) except util.Error as e: _LOGGER.error(e) raise CacheError("Unable to read from cache: %s" % e) except FileNotFoundError: raise CacheKeyNotFoundError("Key not found in disk cache") return value def _write_to_disk_cache(key: str, value: Any) -> None: path = file_util.get_streamlit_file_path("cache", "%s.pickle" % key) try: with file_util.streamlit_write(path, binary=True) as output: entry = _DiskCacheEntry(value=value) pickle.dump(entry, output, pickle.HIGHEST_PROTOCOL) except util.Error as e: _LOGGER.debug(e) # Clean up file so we don't leave zero byte files. try: os.remove(path) except (FileNotFoundError, IOError, OSError): pass raise CacheError("Unable to write to cache: %s" % e) def _read_from_cache( mem_cache: MemCache, key: str, persist: bool, allow_output_mutation: bool, func_or_code: Callable[..., Any], hash_funcs: Optional[HashFuncsDict] = None, ) -> Any: """Read a value from the cache. Our goal is to read from memory if possible. If the data was mutated (hash changed), we show a warning. If reading from memory fails, we either read from disk or rerun the code. """ try: return _read_from_mem_cache( mem_cache, key, allow_output_mutation, func_or_code, hash_funcs ) except CachedObjectMutationError as e: handle_uncaught_app_exception(CachedObjectMutationWarning(e)) return e.cached_value except CacheKeyNotFoundError as e: if persist: value = _read_from_disk_cache(key) _write_to_mem_cache( mem_cache, key, value, allow_output_mutation, func_or_code, hash_funcs ) return value raise e def _write_to_cache( mem_cache: MemCache, key: str, value: Any, persist: bool, allow_output_mutation: bool, func_or_code: Callable[..., Any], hash_funcs: Optional[HashFuncsDict] = None, ): _write_to_mem_cache( mem_cache, key, value, allow_output_mutation, func_or_code, hash_funcs ) if persist: _write_to_disk_cache(key, value) def cache( func=None, persist=False, allow_output_mutation=False, show_spinner=True, suppress_st_warning=False, hash_funcs=None, max_entries=None, ttl=None, ): """Function decorator to memoize function executions. Parameters ---------- func : callable The function to cache. Streamlit hashes the function and dependent code. persist : boolean Whether to persist the cache on disk. allow_output_mutation : boolean Streamlit shows a warning when return values are mutated, as that can have unintended consequences. This is done by hashing the return value internally. If you know what you're doing and would like to override this warning, set this to True. show_spinner : boolean Enable the spinner. Default is True to show a spinner when there is a cache miss. suppress_st_warning : boolean Suppress warnings about calling Streamlit functions from within the cached function. hash_funcs : dict or None Mapping of types or fully qualified names to hash functions. This is used to override the behavior of the hasher inside Streamlit's caching mechanism: when the hasher encounters an object, it will first check to see if its type matches a key in this dict and, if so, will use the provided function to generate a hash for it. See below for an example of how this can be used. max_entries : int or None The maximum number of entries to keep in the cache, or None for an unbounded cache. (When a new entry is added to a full cache, the oldest cached entry will be removed.) The default is None. ttl : float or None The maximum number of seconds to keep an entry in the cache, or None if cache entries should not expire. The default is None. Example ------- >>> @st.cache ... def fetch_and_clean_data(url): ... # Fetch data from URL here, and then clean it up. ... return data ... >>> d1 = fetch_and_clean_data(DATA_URL_1) >>> # Actually executes the function, since this is the first time it was >>> # encountered. >>> >>> d2 = fetch_and_clean_data(DATA_URL_1) >>> # Does not execute the function. Instead, returns its previously computed >>> # value. This means that now the data in d1 is the same as in d2. >>> >>> d3 = fetch_and_clean_data(DATA_URL_2) >>> # This is a different URL, so the function executes. To set the ``persist`` parameter, use this command as follows: >>> @st.cache(persist=True) ... def fetch_and_clean_data(url): ... # Fetch data from URL here, and then clean it up. ... return data To disable hashing return values, set the ``allow_output_mutation`` parameter to ``True``: >>> @st.cache(allow_output_mutation=True) ... def fetch_and_clean_data(url): ... # Fetch data from URL here, and then clean it up. ... return data To override the default hashing behavior, pass a custom hash function. You can do that by mapping a type (e.g. ``MongoClient``) to a hash function (``id``) like this: >>> @st.cache(hash_funcs={MongoClient: id}) ... def connect_to_database(url): ... return MongoClient(url) Alternatively, you can map the type's fully-qualified name (e.g. ``"pymongo.mongo_client.MongoClient"``) to the hash function instead: >>> @st.cache(hash_funcs={"pymongo.mongo_client.MongoClient": id}) ... def connect_to_database(url): ... return MongoClient(url) """ _LOGGER.debug("Entering st.cache: %s", func) # Support passing the params via function decorator, e.g. # @st.cache(persist=True, allow_output_mutation=True) if func is None: return lambda f: cache( func=f, persist=persist, allow_output_mutation=allow_output_mutation, show_spinner=show_spinner, suppress_st_warning=suppress_st_warning, hash_funcs=hash_funcs, max_entries=max_entries, ttl=ttl, ) cache_key = None @functools.wraps(func) def wrapped_func(*args, **kwargs): """This function wrapper will only call the underlying function in the case of a cache miss. Cached objects are stored in the cache/ directory.""" if not config.get_option("client.caching"): _LOGGER.debug("Purposefully skipping cache") return func(*args, **kwargs) name = func.__qualname__ if len(args) == 0 and len(kwargs) == 0: message = "Running `%s()`." % name else: message = "Running `%s(...)`." % name def get_or_create_cached_value(): nonlocal cache_key if cache_key is None: # Delay generating the cache key until the first call. # This way we can see values of globals, including functions # defined after this one. # If we generated the key earlier we would only hash those # globals by name, and miss changes in their code or value. cache_key = _hash_func(func, hash_funcs) # First, get the cache that's attached to this function. # This cache's key is generated (above) from the function's code. mem_cache = _mem_caches.get_cache(cache_key, max_entries, ttl) # Next, calculate the key for the value we'll be searching for # within that cache. This key is generated from both the function's # code and the arguments that are passed into it. (Even though this # key is used to index into a per-function cache, it must be # globally unique, because it is *also* used for a global on-disk # cache that is *not* per-function.) value_hasher = hashlib.new("md5") if args: update_hash( args, hasher=value_hasher, hash_funcs=hash_funcs, hash_reason=HashReason.CACHING_FUNC_ARGS, hash_source=func, ) if kwargs: update_hash( kwargs, hasher=value_hasher, hash_funcs=hash_funcs, hash_reason=HashReason.CACHING_FUNC_ARGS, hash_source=func, ) value_key = value_hasher.hexdigest() # Avoid recomputing the body's hash by just appending the # previously-computed hash to the arg hash. value_key = "%s-%s" % (value_key, cache_key) _LOGGER.debug("Cache key: %s", value_key) try: return_value = _read_from_cache( mem_cache=mem_cache, key=value_key, persist=persist, allow_output_mutation=allow_output_mutation, func_or_code=func, hash_funcs=hash_funcs, ) _LOGGER.debug("Cache hit: %s", func) except CacheKeyNotFoundError: _LOGGER.debug("Cache miss: %s", func) with _calling_cached_function(func): if suppress_st_warning: with suppress_cached_st_function_warning(): return_value = func(*args, **kwargs) else: return_value = func(*args, **kwargs) _write_to_cache( mem_cache=mem_cache, key=value_key, value=return_value, persist=persist, allow_output_mutation=allow_output_mutation, func_or_code=func, hash_funcs=hash_funcs, ) return return_value if show_spinner: with st.spinner(message): return get_or_create_cached_value() else: return get_or_create_cached_value() # Make this a well-behaved decorator by preserving important function # attributes. try: wrapped_func.__dict__.update(func.__dict__) except AttributeError: pass return wrapped_func def _hash_func(func: types.FunctionType, hash_funcs: HashFuncsDict) -> str: # Create the unique key for a function's cache. The cache will be retrieved # from inside the wrapped function. # # A naive implementation would involve simply creating the cache object # right in the wrapper, which in a normal Python script would be executed # only once. But in Streamlit, we reload all modules related to a user's # app when the app is re-run, which means that - among other things - all # function decorators in the app will be re-run, and so any decorator-local # objects will be recreated. # # Furthermore, our caches can be destroyed and recreated (in response to # cache clearing, for example), which means that retrieving the function's # cache in the decorator (so that the wrapped function can save a lookup) # is incorrect: the cache itself may be recreated between # decorator-evaluation time and decorated-function-execution time. So we # must retrieve the cache object *and* perform the cached-value lookup # inside the decorated function. func_hasher = hashlib.new("md5") # Include the function's __module__ and __qualname__ strings in the hash. # This means that two identical functions in different modules # will not share a hash; it also means that two identical *nested* # functions in the same module will not share a hash. # We do not pass `hash_funcs` here, because we don't want our function's # name to get an unexpected hash. update_hash( (func.__module__, func.__qualname__), hasher=func_hasher, hash_funcs=None, hash_reason=HashReason.CACHING_FUNC_BODY, hash_source=func, ) # Include the function's body in the hash. We *do* pass hash_funcs here, # because this step will be hashing any objects referenced in the function # body. update_hash( func, hasher=func_hasher, hash_funcs=hash_funcs, hash_reason=HashReason.CACHING_FUNC_BODY, hash_source=func, ) cache_key = func_hasher.hexdigest() _LOGGER.debug( "mem_cache key for %s.%s: %s", func.__module__, func.__qualname__, cache_key ) return cache_key def clear_cache() -> bool: """Clear the memoization cache. Returns ------- boolean True if the disk cache was cleared. False otherwise (e.g. cache file doesn't exist on disk). """ _clear_mem_cache() return _clear_disk_cache() def get_cache_path() -> str: return file_util.get_streamlit_file_path("cache") def _clear_disk_cache() -> bool: # TODO: Only delete disk cache for functions related to the user's current # script. cache_path = get_cache_path() if os.path.isdir(cache_path): shutil.rmtree(cache_path) return True return False def _clear_mem_cache() -> None: _mem_caches.clear() class CacheError(Exception): pass class CacheKeyNotFoundError(Exception): pass class CachedObjectMutationError(ValueError): """This is used internally, but never shown to the user. Users see CachedObjectMutationWarning instead. """ def __init__(self, cached_value, func_or_code): self.cached_value = cached_value if inspect.iscode(func_or_code): self.cached_func_name = "a code block" else: self.cached_func_name = _get_cached_func_name_md(func_or_code) def __repr__(self) -> str: return util.repr_(self) class CachedStFunctionWarning(StreamlitAPIWarning): def __init__(self, st_func_name, cached_func): msg = self._get_message(st_func_name, cached_func) super(CachedStFunctionWarning, self).__init__(msg) def _get_message(self, st_func_name, cached_func): args = { "st_func_name": "`st.%s()` or `st.write()`" % st_func_name, "func_name": _get_cached_func_name_md(cached_func), } return ( """ Your script uses %(st_func_name)s to write to your Streamlit app from within some cached code at %(func_name)s. This code will only be called when we detect a cache "miss", which can lead to unexpected results. How to fix this: * Move the %(st_func_name)s call outside %(func_name)s. * Or, if you know what you're doing, use `@st.cache(suppress_st_warning=True)` to suppress the warning. """ % args ).strip("\n") class CachedObjectMutationWarning(StreamlitAPIWarning): def __init__(self, orig_exc): msg = self._get_message(orig_exc) super(CachedObjectMutationWarning, self).__init__(msg) def _get_message(self, orig_exc): return ( """ Return value of %(func_name)s was mutated between runs. By default, Streamlit's cache should be treated as immutable, or it may behave in unexpected ways. You received this warning because Streamlit detected that an object returned by %(func_name)s was mutated outside of %(func_name)s. How to fix this: * If you did not mean to mutate that return value: - If possible, inspect your code to find and remove that mutation. - Otherwise, you could also clone the returned value so you can freely mutate it. * If you actually meant to mutate the return value and know the consequences of doing so, annotate the function with `@st.cache(allow_output_mutation=True)`. For more information and detailed solutions check out [our documentation.] (https://docs.streamlit.io/library/advanced-features/caching) """ % {"func_name": orig_exc.cached_func_name} ).strip("\n") def _get_cached_func_name_md(func: types.FunctionType) -> str: """Get markdown representation of the function name.""" if hasattr(func, "__name__"): return "`%s()`" % func.__name__ else: return "a cached function"