first commit

This commit is contained in:
Ayxan
2022-05-23 00:16:32 +04:00
commit d660f2a4ca
24786 changed files with 4428337 additions and 0 deletions

View File

@ -0,0 +1,508 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Streamlit.
How to use Streamlit in 3 seconds:
1. Write an app
>>> import streamlit as st
>>> st.write(anything_you_want)
2. Run your app
$ streamlit run my_script.py
3. Use your app
A new tab will open on your browser. That's your Streamlit app!
4. Modify your code, save it, and watch changes live on your browser.
Take a look at the other commands in this module to find out what else
Streamlit can do:
>>> dir(streamlit)
Or try running our "Hello World":
$ streamlit hello
For more detailed info, see https://docs.streamlit.io.
"""
# IMPORTANT: Prefix with an underscore anything that the user shouldn't see.
# Must be at the top, to avoid circular dependency.
from streamlit import logger as _logger
from streamlit import config as _config
from streamlit.proto.RootContainer_pb2 import RootContainer
from streamlit.secrets import Secrets, SECRETS_FILE_LOC
_LOGGER = _logger.get_logger("root")
# Give the package a version.
from importlib_metadata import version as _version
__version__ = _version("streamlit")
from typing import NoReturn
import contextlib as _contextlib
import sys as _sys
import threading as _threading
import urllib.parse as _parse
import click as _click
from streamlit import code_util as _code_util
from streamlit import env_util as _env_util
from streamlit import source_util as _source_util
from streamlit import string_util as _string_util
from streamlit.delta_generator import DeltaGenerator as _DeltaGenerator
from streamlit.scriptrunner import (
add_script_run_ctx as _add_script_run_ctx,
get_script_run_ctx as _get_script_run_ctx,
StopException,
RerunException as _RerunException,
RerunData as _RerunData,
)
from streamlit.errors import StreamlitAPIException
from streamlit.proto import ForwardMsg_pb2 as _ForwardMsg_pb2
# Modules that the user should have access to. These are imported with "as"
# syntax pass mypy checking with implicit_reexport disabled.
from streamlit.echo import echo as echo
from streamlit.legacy_caching import cache as cache
from streamlit.caching import singleton as experimental_singleton
from streamlit.caching import memo as experimental_memo
# This is set to True inside cli._main_run(), and is False otherwise.
# If False, we should assume that DeltaGenerator functions are effectively
# no-ops, and adapt gracefully.
_is_running_with_streamlit = False
def _update_logger() -> None:
    """Re-apply logging configuration from the current config options."""
    level = _config.get_option("logger.level").upper()
    _logger.set_log_level(level)
    _logger.update_formatter()
    _logger.init_tornado_logs()


# Make this file only depend on config option in an asynchronous manner. This
# avoids a race condition when another file (such as a test file) tries to pass
# in an alternative config.
_config.on_config_parsed(_update_logger, True)
# The two root containers of every app: `_main` is the page body, and
# `sidebar` is attached to it as a child container.
_main = _DeltaGenerator(root_container=RootContainer.MAIN)
sidebar = _DeltaGenerator(root_container=RootContainer.SIDEBAR, parent=_main)

# `st.secrets` singleton, backed by the secrets file at SECRETS_FILE_LOC.
secrets = Secrets(SECRETS_FILE_LOC)

# DeltaGenerator methods:
# Each public `st.*` element function is an alias for the corresponding
# bound method on the `_main` container.
altair_chart = _main.altair_chart
area_chart = _main.area_chart
audio = _main.audio
balloons = _main.balloons
bar_chart = _main.bar_chart
bokeh_chart = _main.bokeh_chart
button = _main.button
caption = _main.caption
camera_input = _main.camera_input
checkbox = _main.checkbox
code = _main.code
columns = _main.columns
container = _main.container
dataframe = _main.dataframe
date_input = _main.date_input
download_button = _main.download_button
expander = _main.expander
pydeck_chart = _main.pydeck_chart
empty = _main.empty
error = _main.error
exception = _main.exception
file_uploader = _main.file_uploader
form = _main.form
form_submit_button = _main.form_submit_button
graphviz_chart = _main.graphviz_chart
header = _main.header
help = _main.help
image = _main.image
info = _main.info
json = _main.json
latex = _main.latex
line_chart = _main.line_chart
map = _main.map
markdown = _main.markdown
metric = _main.metric
multiselect = _main.multiselect
number_input = _main.number_input
plotly_chart = _main.plotly_chart
progress = _main.progress
pyplot = _main.pyplot
radio = _main.radio
selectbox = _main.selectbox
select_slider = _main.select_slider
slider = _main.slider
snow = _main.snow
subheader = _main.subheader
success = _main.success
table = _main.table
text = _main.text
text_area = _main.text_area
text_input = _main.text_input
time_input = _main.time_input
title = _main.title
vega_lite_chart = _main.vega_lite_chart
video = _main.video
warning = _main.warning
write = _main.write
color_picker = _main.color_picker

# Legacy (pre-Arrow) element implementations, kept private.
_legacy_dataframe = _main._legacy_dataframe
_legacy_table = _main._legacy_table
_legacy_altair_chart = _main._legacy_altair_chart
_legacy_area_chart = _main._legacy_area_chart
_legacy_bar_chart = _main._legacy_bar_chart
_legacy_line_chart = _main._legacy_line_chart
_legacy_vega_lite_chart = _main._legacy_vega_lite_chart

# Apache Arrow serialization-based element implementations, kept private.
_arrow_dataframe = _main._arrow_dataframe
_arrow_table = _main._arrow_table
_arrow_altair_chart = _main._arrow_altair_chart
_arrow_area_chart = _main._arrow_area_chart
_arrow_bar_chart = _main._arrow_bar_chart
_arrow_line_chart = _main._arrow_line_chart
_arrow_vega_lite_chart = _main._arrow_vega_lite_chart

# Config
get_option = _config.get_option
from streamlit.commands.page_config import set_page_config

# Session State
from streamlit.state import SessionStateProxy

session_state = SessionStateProxy()

# Beta APIs
beta_container = _main.beta_container
beta_expander = _main.beta_expander
beta_columns = _main.beta_columns
def set_option(key, value):
    """Set config option.

    Currently, only the following config options can be set within the script itself:
        * client.caching
        * client.displayEnabled
        * deprecation.*

    Calling with any other options will raise StreamlitAPIException.

    Run `streamlit config show` in the terminal to see all available options.

    Parameters
    ----------
    key : str
        The config option key of the form "section.optionName". To see all
        available options, run `streamlit config show` on a terminal.

    value
        The new value to assign to this config option.
    """
    # KeyError propagates for unknown keys, same as a direct subscript.
    opt = _config._config_options_template[key]
    if not opt.scriptable:
        # Non-scriptable options may only be set from the command line or
        # from config.toml.
        raise StreamlitAPIException(
            f"{key} cannot be set on the fly. Set as command line option, e.g. streamlit run script.py --{key}, or in config.toml instead."
        )
    _config.set_option(key, value)
def experimental_show(*args):
    """Write arguments and *argument names* to your app for debugging purposes.

    Show() has similar properties to write():
    1. You can pass in multiple arguments, all of which will be debugged.
    2. It returns None, so its "slot" in the app cannot be reused.

    Note: This is an experimental feature. See
    https://docs.streamlit.io/library/advanced-features/prerelease#experimental for more information.

    Parameters
    ----------
    *args : any
        One or many objects to debug in the App.

    Example
    -------
    >>> dataframe = pd.DataFrame({
    ...     'first column': [1, 2, 3, 4],
    ...     'second column': [10, 20, 30, 40],
    ... })
    >>> st.experimental_show(dataframe)

    Notes
    -----
    This is an experimental feature with usage limitations:

    - The method must be called with the name `show`.
    - Must be called in one line of code, and only once per line.
    - When passing multiple arguments the inclusion of `,` or `)` in a string
      argument may cause an error.
    """
    if not args:
        return

    try:
        import inspect

        # Get the calling line of code
        current_frame = inspect.currentframe()
        if current_frame is None:
            warning("`show` not enabled in the shell")
            return

        if current_frame.f_back is not None:
            lines = inspect.getframeinfo(current_frame.f_back)[3]
        else:
            lines = None

        if not lines:
            warning("`show` not enabled in the shell")
            return

        # Parse the argument expressions out of the source text of the
        # line that called `show(...)`.
        line = lines[0].split("show", 1)[1]
        inputs = _code_util.get_method_args_from_code(args, line)

        # Escape markdown and add deltas.
        # (Renamed from `input`, which shadowed the `input` builtin.)
        for idx, arg_expr in enumerate(inputs):
            escaped = _string_util.escape_markdown(arg_expr)

            markdown("**%s**" % escaped)
            write(args[idx])
    except Exception:
        # Best-effort: render the exception in the app rather than crashing
        # the caller's script over a debugging helper.
        _, exc, _ = _sys.exc_info()
        exception(exc)
def experimental_get_query_params():
    """Return the query parameters that is currently showing in the browser's URL bar.

    Returns
    -------
    dict
        The current query parameters as a dict. "Query parameters" are the part of the URL that comes
        after the first "?".

    Example
    -------
    Let's say the user's web browser is at
    `http://localhost:8501/?show_map=True&selected=asia&selected=america`.
    Then, you can get the query parameters using the following:

    >>> st.experimental_get_query_params()
    {"show_map": ["True"], "selected": ["asia", "america"]}

    Note that the values in the returned dict are *always* lists. This is
    because we internally use Python's urllib.parse.parse_qs(), which behaves
    this way. And this behavior makes sense when you consider that every item
    in a query string is potentially a 1-element array.
    """
    ctx = _get_script_run_ctx()
    if ctx is None:
        # No script run context (e.g. running outside `streamlit run`).
        # Return an empty dict to honor the documented return type.
        # (Previously returned "", which is not a dict.)
        return {}
    return _parse.parse_qs(ctx.query_string)
def experimental_set_query_params(**query_params):
    """Set the query parameters that are shown in the browser's URL bar.

    Parameters
    ----------
    **query_params : dict
        The query parameters to set, as key-value pairs.

    Example
    -------
    To point the user's web browser to something like
    "http://localhost:8501/?show_map=True&selected=asia&selected=america",
    you would do the following:

    >>> st.experimental_set_query_params(
    ...     show_map=True,
    ...     selected=["asia", "america"],
    ... )
    """
    ctx = _get_script_run_ctx()
    if ctx is None:
        # Not running inside a Streamlit script run: nothing to update.
        return

    # doseq=True expands list values into repeated query-string keys.
    encoded = _parse.urlencode(query_params, doseq=True)
    ctx.query_string = encoded

    # Tell the frontend to update the URL bar.
    forward_msg = _ForwardMsg_pb2.ForwardMsg()
    forward_msg.page_info_changed.query_string = encoded
    ctx.enqueue(forward_msg)
@_contextlib.contextmanager
def spinner(text="In progress..."):
    """Temporarily displays a message while executing a block of code.

    Parameters
    ----------
    text : str
        A message to display while executing that block

    Example
    -------
    >>> with st.spinner('Wait for it...'):
    >>>     time.sleep(5)
    >>> st.success('Done!')
    """
    # Lazily imported to avoid import cycles at module load time.
    import streamlit.legacy_caching.caching as legacy_caching
    import streamlit.caching as caching
    from streamlit.elements.utils import clean_text
    from streamlit.proto.Spinner_pb2 import Spinner as SpinnerProto

    # @st.cache optionally uses spinner for long-running computations.
    # Normally, streamlit warns the user when they call st functions
    # from within an @st.cache'd function. But we do *not* want to show
    # these warnings for spinner's message, so we create and mutate this
    # message delta within the "suppress_cached_st_function_warning"
    # context.
    with legacy_caching.suppress_cached_st_function_warning():
        with caching.suppress_cached_st_function_warning():
            message = empty()

    try:
        # Set the message 0.1 seconds in the future to avoid annoying
        # flickering if this spinner runs too quickly.
        DELAY_SECS = 0.1
        display_message = True
        display_message_lock = _threading.Lock()

        def set_message():
            # Runs on the Timer thread. The lock guards `display_message` so
            # the `finally` teardown below can't race with us.
            with display_message_lock:
                if display_message:
                    with legacy_caching.suppress_cached_st_function_warning():
                        with caching.suppress_cached_st_function_warning():
                            spinner_proto = SpinnerProto()
                            spinner_proto.text = clean_text(text)
                            message._enqueue("spinner", spinner_proto)

        # Attach the script-run context to the Timer thread so it is allowed
        # to enqueue elements from off the main script thread.
        _add_script_run_ctx(_threading.Timer(DELAY_SECS, set_message)).start()

        # Yield control back to the context.
        yield
    finally:
        # Cancel any pending message display, then clear the placeholder.
        if display_message_lock:
            with display_message_lock:
                display_message = False
        with legacy_caching.suppress_cached_st_function_warning():
            with caching.suppress_cached_st_function_warning():
                message.empty()
def _transparent_write(*args):
    """This is just st.write, but returns the arguments you passed to it."""
    write(*args)
    # A single argument is unwrapped; multiple arguments come back as a tuple.
    return args[0] if len(args) == 1 else args
# We want to show a warning when the user runs a Streamlit script without
# 'streamlit run', but we need to make sure the warning appears only once no
# matter how many times __init__ gets loaded.
_use_warning_has_been_displayed = False


def _maybe_print_use_warning() -> None:
    """Print a warning if Streamlit is imported but not being run with `streamlit run`.

    The warning is printed only once per process, tracked via the
    module-level `_use_warning_has_been_displayed` flag.
    """
    global _use_warning_has_been_displayed

    if not _use_warning_has_been_displayed:
        _use_warning_has_been_displayed = True

        # Renamed from `warning`, which shadowed the module-level
        # `st.warning` element function within this scope.
        warning_prefix = _click.style("Warning:", bold=True, fg="yellow")

        if _env_util.is_repl():
            _LOGGER.warning(
                f"\n {warning_prefix} to view a Streamlit app on a browser, use Streamlit in a file and\n run it with the following command:\n\n streamlit run [FILE_NAME] [ARGUMENTS]"
            )
        elif not _is_running_with_streamlit and _config.get_option(
            "global.showWarningOnDirectExecution"
        ):
            script_name = _sys.argv[0]

            _LOGGER.warning(
                f"\n {warning_prefix} to view this Streamlit app on a browser, run it with the following\n command:\n\n streamlit run {script_name} [ARGUMENTS]"
            )
def stop() -> NoReturn:
    """Stops execution immediately.

    Streamlit will not run any statements after `st.stop()`.
    We recommend rendering a message to explain why the script has stopped.
    When run outside of Streamlit, this will raise an Exception.

    Example
    -------
    >>> name = st.text_input('Name')
    >>> if not name:
    >>>     st.warning('Please input a name.')
    >>>     st.stop()
    >>> st.success('Thank you for inputting a name.')
    """
    # StopException is caught by the ScriptRunner, which ends the run cleanly.
    raise StopException()
def experimental_rerun():
    """Rerun the script immediately.

    When `st.experimental_rerun()` is called, the script is halted - no
    more statements will be run, and the script will be queued to re-run
    from the top.

    If this function is called outside of Streamlit, it will raise an
    Exception.
    """
    # Preserve the current query string across the rerun, when available.
    ctx = _get_script_run_ctx()
    if ctx is not None:
        query_string = ctx.query_string
    else:
        query_string = ""
    raise _RerunException(_RerunData(query_string=query_string))

View File

@ -0,0 +1,21 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from streamlit.cli import main
if __name__ == "__main__":
    # Set prog_name so that the Streamlit server sees the same command line
    # string whether streamlit is called directly or via `python -m streamlit`.
    main(prog_name="streamlit")

View File

@ -0,0 +1,642 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import threading
import uuid
from enum import Enum
from typing import TYPE_CHECKING, Callable, Optional, List
from streamlit.uploaded_file_manager import UploadedFileManager
import tornado.ioloop
import streamlit.elements.exception as exception_utils
from streamlit import __version__, caching, config, legacy_caching, secrets
from streamlit.case_converters import to_snake_case
from streamlit.credentials import Credentials
from streamlit.in_memory_file_manager import in_memory_file_manager
from streamlit.logger import get_logger
from streamlit.metrics_util import Installation
from streamlit.proto.ClientState_pb2 import ClientState
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.proto.GitInfo_pb2 import GitInfo
from streamlit.proto.NewSession_pb2 import Config, CustomThemeConfig, UserInfo
from streamlit.session_data import SessionData
from streamlit.scriptrunner import (
RerunData,
ScriptRunner,
ScriptRunnerEvent,
)
from streamlit.watcher import LocalSourcesWatcher
# Module-level logger, named after this module.
LOGGER = get_logger(__name__)

if TYPE_CHECKING:
    # Imported for type annotations only, to avoid a runtime import cycle.
    from streamlit.state import SessionState
class AppSessionState(Enum):
    """The possible run states of an AppSession."""

    # No script is currently running for this session.
    APP_NOT_RUNNING = "APP_NOT_RUNNING"
    # A script is currently running.
    APP_IS_RUNNING = "APP_IS_RUNNING"
    # The session has been asked to shut down.
    SHUTDOWN_REQUESTED = "SHUTDOWN_REQUESTED"
def _generate_scriptrun_id() -> str:
"""Randomly generate a unique ID for a script execution."""
return str(uuid.uuid4())
class AppSession:
"""
Contains session data for a single "user" of an active app
(that is, a connected browser tab).
Each AppSession has its own SessionData, root DeltaGenerator, ScriptRunner,
and widget state.
An AppSession is attached to each thread involved in running its script.
"""
def __init__(
self,
ioloop: tornado.ioloop.IOLoop,
session_data: SessionData,
uploaded_file_manager: UploadedFileManager,
message_enqueued_callback: Optional[Callable[[], None]],
local_sources_watcher: LocalSourcesWatcher,
):
"""Initialize the AppSession.
Parameters
----------
ioloop : tornado.ioloop.IOLoop
The Tornado IOLoop that we're running within.
session_data : SessionData
Object storing parameters related to running a script
uploaded_file_manager : UploadedFileManager
The server's UploadedFileManager.
message_enqueued_callback : Callable[[], None]
After enqueuing a message, this callable notification will be invoked.
local_sources_watcher: LocalSourcesWatcher
The file watcher that lets the session know local files have changed.
"""
# Each AppSession has a unique string ID.
self.id = str(uuid.uuid4())
self._ioloop = ioloop
self._session_data = session_data
self._uploaded_file_mgr = uploaded_file_manager
self._message_enqueued_callback = message_enqueued_callback
self._state = AppSessionState.APP_NOT_RUNNING
# Need to remember the client state here because when a script reruns
# due to the source code changing we need to pass in the previous client state.
self._client_state = ClientState()
self._local_sources_watcher = local_sources_watcher
self._local_sources_watcher.register_file_change_callback(
self._on_source_file_changed
)
self._stop_config_listener = config.on_config_parsed(
self._on_source_file_changed, force_connect=True
)
# The script should rerun when the `secrets.toml` file has been changed.
secrets._file_change_listener.connect(self._on_secrets_file_changed)
self._run_on_save = config.get_option("server.runOnSave")
self._scriptrunner: Optional[ScriptRunner] = None
# This needs to be lazily imported to avoid a dependency cycle.
from streamlit.state import SessionState
self._session_state = SessionState()
LOGGER.debug("AppSession initialized (id=%s)", self.id)
def flush_browser_queue(self) -> List[ForwardMsg]:
"""Clear the forward message queue and return the messages it contained.
The Server calls this periodically to deliver new messages
to the browser connected to this app.
Returns
-------
list[ForwardMsg]
The messages that were removed from the queue and should
be delivered to the browser.
"""
return self._session_data.flush_browser_queue()
def shutdown(self) -> None:
"""Shut down the AppSession.
It's an error to use a AppSession after it's been shut down.
"""
if self._state != AppSessionState.SHUTDOWN_REQUESTED:
LOGGER.debug("Shutting down (id=%s)", self.id)
# Clear any unused session files in upload file manager and media
# file manager
self._uploaded_file_mgr.remove_session_files(self.id)
in_memory_file_manager.clear_session_files(self.id)
in_memory_file_manager.del_expired_files()
# Shut down the ScriptRunner, if one is active.
# self._state must not be set to SHUTDOWN_REQUESTED until
# after this is called.
if self._scriptrunner is not None:
self._scriptrunner.request_stop()
self._state = AppSessionState.SHUTDOWN_REQUESTED
self._local_sources_watcher.close()
if self._stop_config_listener is not None:
self._stop_config_listener()
secrets._file_change_listener.disconnect(self._on_secrets_file_changed)
def _enqueue_forward_msg(self, msg: ForwardMsg) -> None:
"""Enqueue a new ForwardMsg to our browser queue.
This can be called on both the main thread and a ScriptRunner
run thread.
Parameters
----------
msg : ForwardMsg
The message to enqueue
"""
if not config.get_option("client.displayEnabled"):
return
self._session_data.enqueue(msg)
if self._message_enqueued_callback:
self._message_enqueued_callback()
def handle_backmsg_exception(self, e: BaseException) -> None:
"""Handle an Exception raised while processing a BackMsg from the browser."""
# This does a few things:
# 1) Clears the current app in the browser.
# 2) Marks the current app as "stopped" in the browser.
# 3) HACK: Resets any script params that may have been broken (e.g. the
# command-line when rerunning with wrong argv[0])
self._on_scriptrunner_event(
self._scriptrunner, ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
)
self._on_scriptrunner_event(
self._scriptrunner, ScriptRunnerEvent.SCRIPT_STARTED
)
self._on_scriptrunner_event(
self._scriptrunner, ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
)
# Send an Exception message to the frontend.
# Because _on_scriptrunner_event does its work in an ioloop callback,
# this exception ForwardMsg *must* also be enqueued in a callback,
# so that it will be enqueued *after* the various ForwardMsgs that
# _on_scriptrunner_event sends.
self._ioloop.spawn_callback(
lambda: self._enqueue_forward_msg(self._create_exception_message(e))
)
def request_rerun(self, client_state: Optional[ClientState]) -> None:
"""Signal that we're interested in running the script.
If the script is not already running, it will be started immediately.
Otherwise, a rerun will be requested.
Parameters
----------
client_state : streamlit.proto.ClientState_pb2.ClientState | None
The ClientState protobuf to run the script with, or None
to use previous client state.
"""
if self._state == AppSessionState.SHUTDOWN_REQUESTED:
LOGGER.warning("Discarding rerun request after shutdown")
return
if client_state:
rerun_data = RerunData(
client_state.query_string, client_state.widget_states
)
else:
rerun_data = RerunData()
if self._scriptrunner is not None:
if bool(config.get_option("runner.fastReruns")):
# If fastReruns is enabled, we don't send rerun requests to our
# existing ScriptRunner. Instead, we tell it to shut down. We'll
# then spin up a new ScriptRunner, below, to handle the rerun
# immediately.
self._scriptrunner.request_stop()
self._scriptrunner = None
else:
# fastReruns is not enabled. Send our ScriptRunner a rerun
# request. If the request is accepted, we're done.
success = self._scriptrunner.request_rerun(rerun_data)
if success:
return
# If we are here, then either we have no ScriptRunner, or our
# current ScriptRunner is shutting down and cannot handle a rerun
# request - so we'll create and start a new ScriptRunner.
self._create_scriptrunner(rerun_data)
def _create_scriptrunner(self, initial_rerun_data: RerunData) -> None:
"""Create and run a new ScriptRunner with the given RerunData."""
self._scriptrunner = ScriptRunner(
session_id=self.id,
session_data=self._session_data,
client_state=self._client_state,
session_state=self._session_state,
uploaded_file_mgr=self._uploaded_file_mgr,
initial_rerun_data=initial_rerun_data,
)
self._scriptrunner.on_event.connect(self._on_scriptrunner_event)
self._scriptrunner.start()
@property
def session_state(self) -> "SessionState":
return self._session_state
def _on_source_file_changed(self) -> None:
"""One of our source files changed. Schedule a rerun if appropriate."""
if self._run_on_save:
self.request_rerun(self._client_state)
else:
self._enqueue_forward_msg(self._create_file_change_message())
def _on_secrets_file_changed(self, _) -> None:
"""Called when `secrets._file_change_listener` emits a Signal."""
# NOTE: At the time of writing, this function only calls `_on_source_file_changed`.
# The reason behind creating this function instead of just passing `_on_source_file_changed`
# to `connect` / `disconnect` directly is that every function that is passed to `connect` / `disconnect`
# must have at least one argument for `sender` (in this case we don't really care about it, thus `_`),
# and introducing an unnecessary argument to `_on_source_file_changed` just for this purpose sounded finicky.
self._on_source_file_changed()
def _clear_queue(self) -> None:
self._session_data.clear_browser_queue()
def _on_scriptrunner_event(
self,
sender: Optional[ScriptRunner],
event: ScriptRunnerEvent,
forward_msg: Optional[ForwardMsg] = None,
exception: Optional[BaseException] = None,
client_state: Optional[ClientState] = None,
) -> None:
"""Called when our ScriptRunner emits an event.
This is generally called from the sender ScriptRunner's script thread.
We forward the event on to _handle_scriptrunner_event_on_main_thread,
which will be called on the main thread.
"""
self._ioloop.spawn_callback(
lambda: self._handle_scriptrunner_event_on_main_thread(
sender, event, forward_msg, exception, client_state
)
)
def _handle_scriptrunner_event_on_main_thread(
self,
sender: Optional[ScriptRunner],
event: ScriptRunnerEvent,
forward_msg: Optional[ForwardMsg] = None,
exception: Optional[BaseException] = None,
client_state: Optional[ClientState] = None,
) -> None:
"""Handle a ScriptRunner event.
This function must only be called on the main thread.
Parameters
----------
sender : ScriptRunner | None
The ScriptRunner that emitted the event. (This may be set to
None when called from `handle_backmsg_exception`, if no
ScriptRunner was active when the backmsg exception was raised.)
event : ScriptRunnerEvent
The event type.
forward_msg : ForwardMsg | None
The ForwardMsg to send to the frontend. Set only for the
ENQUEUE_FORWARD_MSG event.
exception : BaseException | None
An exception thrown during compilation. Set only for the
SCRIPT_STOPPED_WITH_COMPILE_ERROR event.
client_state : streamlit.proto.ClientState_pb2.ClientState | None
The ScriptRunner's final ClientState. Set only for the
SHUTDOWN event.
"""
assert (
threading.main_thread() == threading.current_thread()
), "This function must only be called on the main thread"
if sender is not self._scriptrunner:
# This event was sent by a non-current ScriptRunner; ignore it.
# This can happen after sppinng up a new ScriptRunner (to handle a
# rerun request, for example) while another ScriptRunner is still
# shutting down. The shutting-down ScriptRunner may still
# emit events.
LOGGER.debug("Ignoring event from non-current ScriptRunner: %s", event)
return
prev_state = self._state
if event == ScriptRunnerEvent.SCRIPT_STARTED:
if self._state != AppSessionState.SHUTDOWN_REQUESTED:
self._state = AppSessionState.APP_IS_RUNNING
self._clear_queue()
self._enqueue_forward_msg(self._create_new_session_message())
elif (
event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
or event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_COMPILE_ERROR
):
if self._state != AppSessionState.SHUTDOWN_REQUESTED:
self._state = AppSessionState.APP_NOT_RUNNING
script_succeeded = event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
script_finished_msg = self._create_script_finished_message(
ForwardMsg.FINISHED_SUCCESSFULLY
if script_succeeded
else ForwardMsg.FINISHED_WITH_COMPILE_ERROR
)
self._enqueue_forward_msg(script_finished_msg)
if script_succeeded:
# The script completed successfully: update our
# LocalSourcesWatcher to account for any source code changes
# that change which modules should be watched.
self._local_sources_watcher.update_watched_modules()
else:
# The script didn't complete successfully: send the exception
# to the frontend.
assert (
exception is not None
), "exception must be set for the SCRIPT_STOPPED_WITH_COMPILE_ERROR event"
msg = ForwardMsg()
exception_utils.marshall(
msg.session_event.script_compilation_exception, exception
)
self._enqueue_forward_msg(msg)
elif event == ScriptRunnerEvent.SHUTDOWN:
assert (
client_state is not None
), "client_state must be set for the SHUTDOWN event"
if self._state == AppSessionState.SHUTDOWN_REQUESTED:
# Only clear media files if the script is done running AND the
# session is actually shutting down.
in_memory_file_manager.clear_session_files(self.id)
self._client_state = client_state
self._scriptrunner = None
elif event == ScriptRunnerEvent.ENQUEUE_FORWARD_MSG:
assert (
forward_msg is not None
), "null forward_msg in ENQUEUE_FORWARD_MSG event"
self._enqueue_forward_msg(forward_msg)
# Send a message if our run state changed
app_was_running = prev_state == AppSessionState.APP_IS_RUNNING
app_is_running = self._state == AppSessionState.APP_IS_RUNNING
if app_is_running != app_was_running:
self._enqueue_forward_msg(self._create_session_state_changed_message())
def _create_session_state_changed_message(self) -> ForwardMsg:
"""Create and return a session_state_changed ForwardMsg."""
msg = ForwardMsg()
msg.session_state_changed.run_on_save = self._run_on_save
msg.session_state_changed.script_is_running = (
self._state == AppSessionState.APP_IS_RUNNING
)
return msg
def _create_file_change_message(self) -> ForwardMsg:
"""Create and return a 'script_changed_on_disk' ForwardMsg."""
msg = ForwardMsg()
msg.session_event.script_changed_on_disk = True
return msg
def _create_new_session_message(self) -> ForwardMsg:
"""Create and return a new_session ForwardMsg."""
msg = ForwardMsg()
msg.new_session.script_run_id = _generate_scriptrun_id()
msg.new_session.name = self._session_data.name
msg.new_session.main_script_path = self._session_data.main_script_path
_populate_config_msg(msg.new_session.config)
_populate_theme_msg(msg.new_session.custom_theme)
# Immutable session data. We send this every time a new session is
# started, to avoid having to track whether the client has already
# received it. It does not change from run to run; it's up to the
# to perform one-time initialization only once.
imsg = msg.new_session.initialize
_populate_user_info_msg(imsg.user_info)
imsg.environment_info.streamlit_version = __version__
imsg.environment_info.python_version = ".".join(map(str, sys.version_info))
imsg.session_state.run_on_save = self._run_on_save
imsg.session_state.script_is_running = (
self._state == AppSessionState.APP_IS_RUNNING
)
imsg.command_line = self._session_data.command_line
imsg.session_id = self.id
return msg
def _create_script_finished_message(
self, status: "ForwardMsg.ScriptFinishedStatus.ValueType"
) -> ForwardMsg:
"""Create and return a script_finished ForwardMsg."""
msg = ForwardMsg()
msg.script_finished = status
return msg
def _create_exception_message(self, e: BaseException) -> ForwardMsg:
"""Create and return an Exception ForwardMsg."""
msg = ForwardMsg()
exception_utils.marshall(msg.delta.new_element.exception, e)
return msg
def handle_git_information_request(self) -> None:
msg = ForwardMsg()
try:
from streamlit.git_util import GitRepo
repo = GitRepo(self._session_data.main_script_path)
repo_info = repo.get_repo_info()
if repo_info is None:
return
repository_name, branch, module = repo_info
msg.git_info_changed.repository = repository_name
msg.git_info_changed.branch = branch
msg.git_info_changed.module = module
msg.git_info_changed.untracked_files[:] = repo.untracked_files
msg.git_info_changed.uncommitted_files[:] = repo.uncommitted_files
if repo.is_head_detached:
msg.git_info_changed.state = GitInfo.GitStates.HEAD_DETACHED
elif len(repo.ahead_commits) > 0:
msg.git_info_changed.state = GitInfo.GitStates.AHEAD_OF_REMOTE
else:
msg.git_info_changed.state = GitInfo.GitStates.DEFAULT
self._enqueue_forward_msg(msg)
except Exception as e:
# Users may never even install Git in the first place, so this
# error requires no action. It can be useful for debugging.
LOGGER.debug("Obtaining Git information produced an error", exc_info=e)
def handle_rerun_script_request(
    self, client_state: Optional[ClientState] = None
) -> None:
    """Tell the ScriptRunner to re-run its script.

    Parameters
    ----------
    client_state : streamlit.proto.ClientState_pb2.ClientState | None
        The ClientState protobuf to run the script with, or None
        to use previous client state.
    """
    # Thin alias over request_rerun, kept for message-handler naming symmetry.
    self.request_rerun(client_state)
def handle_stop_script_request(self) -> None:
    """Tell the ScriptRunner to stop running its script."""
    # Snapshot the attribute once, then ask it to stop if it exists.
    scriptrunner = self._scriptrunner
    if scriptrunner is not None:
        scriptrunner.request_stop()
def handle_clear_cache_request(self) -> None:
    """Clear this app's cache.

    Because this cache is global, it will be cleared for all users.
    """
    # Clear every cache flavor: the legacy @st.cache store plus the
    # newer memo and singleton stores.
    legacy_caching.clear_cache()
    caching.memo.clear()
    caching.singleton.clear()
    # NOTE(review): session state is cleared here as well — presumably so
    # values derived from cached functions don't linger; confirm intent.
    self._session_state.clear()
def handle_set_run_on_save_request(self, new_value: bool) -> None:
    """Change our run_on_save flag to the given value.

    The browser will be notified of the change.

    Parameters
    ----------
    new_value : bool
        New run_on_save value
    """
    self._run_on_save = new_value
    # Push a session_state_changed message so the browser UI reflects
    # the new flag immediately.
    self._enqueue_forward_msg(self._create_session_state_changed_message())
def _populate_config_msg(msg: Config) -> None:
    """Fill a Config proto with the config options the browser needs."""
    msg.gather_usage_stats = config.get_option("browser.gatherUsageStats")
    msg.max_cached_message_age = config.get_option("global.maxCachedMessageAge")
    msg.mapbox_token = config.get_option("mapbox.token")
    msg.allow_run_on_save = config.get_option("server.allowRunOnSave")
    msg.hide_top_bar = config.get_option("ui.hideTopBar")
def _populate_theme_msg(msg: CustomThemeConfig) -> None:
    """Fill `msg` with the user's custom theme config, if any is set.

    Options in the "theme" config section are copied onto the proto
    directly, except for "base" and "font", which are enum-encoded and
    validated via _enum_theme_option below.
    """
    enum_encoded_options = {"base", "font"}
    theme_opts = config.get_options_for_section("theme")
    if not any(theme_opts.values()):
        # No theme options were set at all: leave the proto at its defaults.
        return
    for option_name, option_val in theme_opts.items():
        if option_name not in enum_encoded_options and option_val is not None:
            setattr(msg, to_snake_case(option_name), option_val)
    # NOTE: If unset, base and font will default to the protobuf enum zero
    # values, which are BaseTheme.LIGHT and FontFamily.SANS_SERIF,
    # respectively. This is why we both don't handle the cases explicitly and
    # also only log a warning when receiving invalid base/font options.
    base = _enum_theme_option(
        "base",
        theme_opts["base"],
        {
            "light": msg.BaseTheme.LIGHT,
            "dark": msg.BaseTheme.DARK,
        },
        "light",
    )
    if base is not None:
        msg.base = base
    font = _enum_theme_option(
        "font",
        theme_opts["font"],
        {
            "sans serif": msg.FontFamily.SANS_SERIF,
            "serif": msg.FontFamily.SERIF,
            "monospace": msg.FontFamily.MONOSPACE,
        },
        "sans serif",
    )
    if font is not None:
        msg.font = font


def _enum_theme_option(option_name, value, value_map, default_label):
    """Look up an enum-encoded theme option's proto value.

    Shared validation for theme.base / theme.font (was previously duplicated
    inline). Returns the mapped enum value, or None — logging a warning —
    when `value` is unset or not a recognized option.
    """
    if value is None:
        return None
    if value not in value_map:
        LOGGER.warning(
            f'"{value}" is an invalid value for theme.{option_name}.'
            f" Allowed values include {list(value_map.keys())}."
            f' Setting theme.{option_name} to "{default_label}".'
        )
        return None
    return value_map[value]
def _populate_user_info_msg(msg: UserInfo) -> None:
    """Fill `msg` with the user's installation IDs and activation email."""
    # Resolve the singletons once instead of re-resolving them per field
    # (the original called Installation.instance() and
    # Credentials.get_current() repeatedly).
    installation = Installation.instance()
    msg.installation_id = installation.installation_id
    msg.installation_id_v3 = installation.installation_id_v3
    activation = Credentials.get_current().activation
    if activation:
        msg.email = activation.email
    else:
        msg.email = ""

View File

@ -0,0 +1,125 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import streamlit
def _show_beta_warning(name: str, date: str) -> None:
    """Render an st.warning telling the user to drop the `beta_` prefix.

    Parameters
    ----------
    name : str
        The un-prefixed command name (e.g. "columns").
    date : str
        The date after which the beta_ alias will be removed.
    """
    streamlit.warning(
        f"Please replace `st.beta_{name}` with `st.{name}`.\n\n"
        f"`st.beta_{name}` will be removed after {date}."
    )
def function_beta_warning(func, date):
    """Wrapper for functions that are no longer in beta.

    Wrapped functions will run as normal, but then proceed to show an st.warning
    saying that the beta_ version will be removed in ~3 months.

    Parameters
    ----------
    func: callable
        The `st.` function that used to be in beta.
    date: str
        A date like "2020-01-01", indicating the last day we'll guarantee
        support for the beta_ prefix.
    """

    def beta_wrapper(*args, **kwargs):
        # Note: Since we use a wrapper, beta_ functions will not autocomplete
        # correctly on VSCode.
        return_value = func(*args, **kwargs)
        # Run the real command first, then nag about the deprecated alias.
        _show_beta_warning(func.__name__, date)
        return return_value

    # Update the wrapped func's name & docstring so st.help does the right thing
    beta_wrapper.__name__ = "beta_" + func.__name__
    beta_wrapper.__doc__ = func.__doc__
    return beta_wrapper
def object_beta_warning(obj, obj_name, date):
    """Wrapper for objects that are no longer in beta.

    Wrapped objects will run as normal, but then proceed to show an st.warning
    saying that the beta_ version will be removed in ~3 months.

    Parameters
    ----------
    obj: Any
        The `st.` object that used to be in beta.
    obj_name: str
        The name of the object within __init__.py
    date: str
        A date like "2020-01-01", indicating the last day we'll guarantee
        support for the beta_ prefix.
    """
    # Show the warning at most once per wrapped object, no matter how many
    # attributes or magic methods are accessed.
    has_shown_beta_warning = False

    def show_wrapped_obj_warning():
        nonlocal has_shown_beta_warning
        if not has_shown_beta_warning:
            has_shown_beta_warning = True
            _show_beta_warning(obj_name, date)

    class Wrapper:
        def __init__(self, obj):
            self._obj = obj
            # Override all the Wrapped object's magic functions
            for name in Wrapper._get_magic_functions(obj.__class__):
                setattr(
                    self.__class__,
                    name,
                    property(self._make_magic_function_proxy(name)),
                )

        def __getattr__(self, attr):
            # We handle __getattr__ separately from our other magic
            # functions. The wrapped class may not actually implement it,
            # but we still need to implement it to call all its normal
            # functions.
            if attr in self.__dict__:
                return getattr(self, attr)
            show_wrapped_obj_warning()
            return getattr(self._obj, attr)

        @staticmethod
        def _get_magic_functions(cls) -> List[str]:
            # ignore the handful of magic functions we cannot override without
            # breaking the Wrapper.
            ignore = ("__class__", "__dict__", "__getattribute__", "__getattr__")
            return [
                name
                for name in dir(cls)
                if name not in ignore and name.startswith("__")
            ]

        @staticmethod
        def _make_magic_function_proxy(name):
            def proxy(self, *args):
                show_wrapped_obj_warning()
                # NOTE(review): the proxy discards *args and returns the
                # wrapped attribute (it's installed as a property) rather
                # than forwarding a call — confirm this is intentional.
                return getattr(self._obj, name)

            return proxy

    return Wrapper(obj)

View File

@ -0,0 +1,358 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import signal
import sys
from typing import Any, Dict, List, Optional
import click
import tornado.ioloop
from streamlit import session_data
from streamlit.git_util import GitRepo, MIN_GIT_VERSION
from streamlit import version
from streamlit import config
from streamlit import net_util
from streamlit import url_util
from streamlit import env_util
from streamlit import secrets
from streamlit import util
from streamlit.config import CONFIG_FILENAMES
from streamlit.logger import get_logger
from streamlit.secrets import SECRETS_FILE_LOC
from streamlit.server.server import Server, server_address_is_unix_socket
from streamlit.watcher import report_watchdog_availability, watch_file
LOGGER = get_logger(__name__)
# Wait for 1 second before opening a browser. This gives old tabs a chance to
# reconnect.
# This must be >= 2 * WebSocketConnection.ts#RECONNECT_WAIT_TIME_MS.
BROWSER_WAIT_TIMEOUT_SEC = 1
NEW_VERSION_TEXT = """
%(new_version)s
See what's new at https://discuss.streamlit.io/c/announcements
Enter the following command to upgrade:
%(prompt)s %(command)s
""" % {
"new_version": click.style(
"A new version of Streamlit is available.", fg="blue", bold=True
),
"prompt": click.style("$", fg="blue"),
"command": click.style("pip install streamlit --upgrade", bold=True),
}
def _set_up_signal_handler() -> None:
    """Install OS signal handlers that cleanly shut down the server."""
    LOGGER.debug("Setting up signal handler")

    def signal_handler(signal_number, stack_frame):
        # The server will shut down its threads and stop the ioloop
        Server.get_current().stop(from_signal=True)

    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)
    if sys.platform == "win32":
        # Windows has no SIGQUIT; CTRL-BREAK raises SIGBREAK instead.
        signal.signal(signal.SIGBREAK, signal_handler)
    else:
        signal.signal(signal.SIGQUIT, signal_handler)
def _fix_sys_path(main_script_path: str) -> None:
"""Add the script's folder to the sys path.
Python normally does this automatically, but since we exec the script
ourselves we need to do it instead.
"""
sys.path.insert(0, os.path.dirname(main_script_path))
def _fix_matplotlib_crash() -> None:
    """Set Matplotlib backend to avoid a crash.

    The default Matplotlib backend crashes Python on OSX when run on a thread
    that's not the main thread, so here we set a safer backend as a fix.
    Users can always disable this behavior by setting the config
    runner.fixMatplotlib = false.

    This fix is OS-independent. We didn't see a good reason to make this
    Mac-only. Consistency within Streamlit seemed more important.
    """
    if config.get_option("runner.fixMatplotlib"):
        try:
            # TODO: a better option may be to set
            # os.environ["MPLBACKEND"] = "Agg". We'd need to do this towards
            # the top of __init__.py, before importing anything that imports
            # pandas (which imports matplotlib). Alternately, we could set
            # this environment variable in a new entrypoint defined in
            # setup.py. Both of these introduce additional trickiness: they
            # need to run without consulting streamlit.config.get_option,
            # because this would import streamlit, and therefore matplotlib.
            import matplotlib

            matplotlib.use("Agg")
        except ImportError:
            # Matplotlib isn't installed: nothing to fix.
            pass
def _fix_tornado_crash() -> None:
    """Set default asyncio policy to be compatible with Tornado 6.

    Tornado 6 (at least) is not compatible with the default
    asyncio implementation on Windows. So here we
    pick the older SelectorEventLoopPolicy when the OS is Windows
    if the known-incompatible default policy is in use.

    This has to happen as early as possible to make it a low priority and
    overrideable

    See: https://github.com/tornadoweb/tornado/issues/2608

    FIXME: if/when tornado supports the defaults in asyncio,
    remove and bump tornado requirement for py38
    """
    # Python 3.8+ on Windows switched the default policy to Proactor.
    if env_util.IS_WINDOWS and sys.version_info >= (3, 8):
        import asyncio

        try:
            from asyncio import (  # type: ignore[attr-defined]
                WindowsProactorEventLoopPolicy,
                WindowsSelectorEventLoopPolicy,
            )
        except ImportError:
            pass
            # Not affected
        else:
            if type(asyncio.get_event_loop_policy()) is WindowsProactorEventLoopPolicy:
                # WindowsProactorEventLoopPolicy is not compatible with
                # Tornado 6 fallback to the pre-3.8 default of Selector
                asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())
def _fix_sys_argv(main_script_path: str, args: List[str]) -> None:
"""sys.argv needs to exclude streamlit arguments and parameters
and be set to what a user's script may expect.
"""
import sys
sys.argv = [main_script_path] + list(args)
def _on_server_start(server: Server) -> None:
    """Run once the server is up: print startup info, load secrets, and
    schedule the browser to open (unless headless or already connected)."""
    _maybe_print_old_git_warning(server.main_script_path)
    _print_url(server.is_running_hello)
    report_watchdog_availability()
    _print_new_version_message()

    # Load secrets.toml if it exists. If the file doesn't exist, this
    # function will return without raising an exception. We catch any parse
    # errors and display them here.
    try:
        secrets.load_if_toml_exists()
    except BaseException as e:
        LOGGER.error(f"Failed to load {SECRETS_FILE_LOC}", exc_info=e)

    def maybe_open_browser():
        if config.get_option("server.headless"):
            # Don't open browser when in headless mode.
            return
        if server.browser_is_connected:
            # Don't auto-open browser if there's already a browser connected.
            # This can happen if there's an old tab repeatedly trying to
            # connect, and it happens to succeed before we launch the browser.
            return
        if config.is_manually_set("browser.serverAddress"):
            addr = config.get_option("browser.serverAddress")
        elif config.is_manually_set("server.address"):
            if server_address_is_unix_socket():
                # Don't open browser when server address is an unix socket
                return
            addr = config.get_option("server.address")
        else:
            addr = "localhost"
        util.open_browser(session_data.get_url(addr))

    # Schedule the browser to open using the IO Loop on the main thread, but
    # only if no other browser connects within 1s.
    ioloop = tornado.ioloop.IOLoop.current()
    ioloop.call_later(BROWSER_WAIT_TIMEOUT_SEC, maybe_open_browser)
def _fix_pydeck_mapbox_api_warning() -> None:
    """Sets MAPBOX_API_KEY environment variable needed for PyDeck otherwise it will throw an exception"""
    # NOTE(review): the config value may be "" — presumably PyDeck accepts
    # an empty token; confirm against pydeck's docs.
    os.environ["MAPBOX_API_KEY"] = config.get_option("mapbox.token")
def _print_new_version_message() -> None:
    """Print an upgrade prompt to the terminal if a newer release exists."""
    if version.should_show_new_version_notice():
        click.secho(NEW_VERSION_TEXT)
def _print_url(is_running_hello: bool) -> None:
    """Print the URL(s) where the running app can be reached.

    Which URLs are shown depends on configuration: an explicitly set
    browser/server address, headless mode (network/external IPs), or the
    local default (localhost plus the network IP when available).
    """
    if is_running_hello:
        title_message = "Welcome to Streamlit. Check out our demo in your browser."
    else:
        title_message = "You can now view your Streamlit app in your browser."
    named_urls = []
    if config.is_manually_set("browser.serverAddress"):
        named_urls = [
            ("URL", session_data.get_url(config.get_option("browser.serverAddress")))
        ]
    elif (
        config.is_manually_set("server.address") and not server_address_is_unix_socket()
    ):
        named_urls = [
            ("URL", session_data.get_url(config.get_option("server.address"))),
        ]
    elif config.get_option("server.headless"):
        # Headless: no browser will be auto-opened, so print the addresses
        # a remote browser could use.
        internal_ip = net_util.get_internal_ip()
        if internal_ip:
            named_urls.append(("Network URL", session_data.get_url(internal_ip)))
        external_ip = net_util.get_external_ip()
        if external_ip:
            named_urls.append(("External URL", session_data.get_url(external_ip)))
    else:
        named_urls = [
            ("Local URL", session_data.get_url("localhost")),
        ]
        internal_ip = net_util.get_internal_ip()
        if internal_ip:
            named_urls.append(("Network URL", session_data.get_url(internal_ip)))
    click.secho("")
    click.secho(" %s" % title_message, fg="blue", bold=True)
    click.secho("")
    for url_name, url in named_urls:
        url_util.print_url(url_name, url)
    click.secho("")
    if is_running_hello:
        click.secho(" Ready to create your own Python apps super quickly?")
        click.secho(" Head over to ", nl=False)
        click.secho("https://docs.streamlit.io", bold=True)
        click.secho("")
        click.secho(" May you create awesome apps!")
        click.secho("")
        click.secho("")
def _maybe_print_old_git_warning(main_script_path: str) -> None:
    """If our script is running in a Git repo, and we're running a very old
    Git version, print a warning that Git integration will be unavailable.
    """
    repo = GitRepo(main_script_path)
    # NOTE(review): is_valid() presumably returns False when the installed
    # Git is too old — confirm its semantics in git_util.
    if (
        not repo.is_valid()
        and repo.git_version is not None
        and repo.git_version < MIN_GIT_VERSION
    ):
        git_version_string = ".".join(str(val) for val in repo.git_version)
        min_version_string = ".".join(str(val) for val in MIN_GIT_VERSION)
        click.secho("")
        click.secho("  Git integration is disabled.", fg="yellow", bold=True)
        click.secho("")
        click.secho(
            f"  Streamlit requires Git {min_version_string} or later, "
            f"but you have {git_version_string}.",
            fg="yellow",
        )
        click.secho(
            "  Git is used by Streamlit Cloud (https://streamlit.io/cloud).",
            fg="yellow",
        )
        click.secho("  To enable this feature, please update Git.", fg="yellow")
def load_config_options(flag_options: Dict[str, Any]) -> None:
    """Load config options from config.toml files, then overlay the ones set by
    flag_options.

    The "streamlit run" command supports passing Streamlit's config options
    as flags. This function reads through the config options set via flag,
    massages them, and passes them to get_config_options() so that they
    overwrite config option defaults and those loaded from config.toml files.

    Parameters
    ----------
    flag_options : Dict[str, Any]
        A dict of config options where the keys are the CLI flag version of the
        config option names.
    """
    # CLI flag names use "_" where config keys use "." (e.g. server_port ->
    # server.port); unset (None) flags are dropped entirely.
    options_from_flags = {
        name.replace("_", "."): val
        for name, val in flag_options.items()
        if val is not None
    }
    # Force a reparse of config files (if they exist). The result is cached
    # for future calls.
    config.get_config_options(force_reparse=True, options_from_flags=options_from_flags)
def _install_config_watchers(flag_options: Dict[str, Any]) -> None:
    """Re-load config options whenever a config.toml file changes on disk."""

    def on_config_changed(_path):
        load_config_options(flag_options)

    for filename in CONFIG_FILENAMES:
        # Only files that exist at startup are watched.
        if os.path.exists(filename):
            watch_file(filename, on_config_changed)
def run(
    main_script_path: str,
    command_line: Optional[str],
    args: List[str],
    flag_options: Dict[str, Any],
) -> None:
    """Run a script in a separate thread and start a server for the app.

    This starts a blocking ioloop.

    Parameters
    ----------
    main_script_path : str
        Path of the user's script to run.
    command_line : str | None
        The command line the app was launched with.
    args : List[str]
        Arguments exposed to the user's script via sys.argv.
    flag_options : Dict[str, Any]
        Config options passed as CLI flags.
    """
    # Apply environment fixes before anything else runs.
    _fix_sys_path(main_script_path)
    _fix_matplotlib_crash()
    _fix_tornado_crash()
    _fix_sys_argv(main_script_path, args)
    _fix_pydeck_mapbox_api_warning()
    _install_config_watchers(flag_options)
    # Install a signal handler that will shut down the ioloop
    # and close all our threads
    _set_up_signal_handler()
    ioloop = tornado.ioloop.IOLoop.current()
    # Create and start the server.
    server = Server(ioloop, main_script_path, command_line)
    server.start(_on_server_start)
    # Start the ioloop. This function will not return until the
    # server is shut down.
    ioloop.start()

View File

@ -0,0 +1,42 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from typing import Iterator
from .memo_decorator import MEMO_CALL_STACK, _memo_caches, MemoAPI
from .singleton_decorator import SINGLETON_CALL_STACK, _singleton_caches, SingletonAPI
def maybe_show_cached_st_function_warning(dg, st_func_name: str) -> None:
    """Show a warning on `dg` if `st_func_name` was invoked from inside a
    function cached with @st.memo or @st.singleton (unless suppressed)."""
    MEMO_CALL_STACK.maybe_show_cached_st_function_warning(dg, st_func_name)
    SINGLETON_CALL_STACK.maybe_show_cached_st_function_warning(dg, st_func_name)
@contextlib.contextmanager
def suppress_cached_st_function_warning() -> Iterator[None]:
    """Context manager that silences cached-st-function warnings for both
    the memo and singleton call stacks while active."""
    with MEMO_CALL_STACK.suppress_cached_st_function_warning(), SINGLETON_CALL_STACK.suppress_cached_st_function_warning():
        yield
# Explicitly export public symbols
from .memo_decorator import (
get_memo_stats_provider as get_memo_stats_provider,
)
from .singleton_decorator import (
get_singleton_stats_provider as get_singleton_stats_provider,
)
# Create and export public API singletons.
memo = MemoAPI()
singleton = SingletonAPI()

View File

@ -0,0 +1,119 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import types
from typing import Any, Optional
from streamlit import type_util
from streamlit.errors import (
StreamlitAPIWarning,
StreamlitAPIException,
)
class CacheType(enum.Enum):
    """The two cache flavors; values match their `st.` decorator names."""

    MEMO = "experimental_memo"
    SINGLETON = "experimental_singleton"
class UnhashableTypeError(Exception):
    """Internal error: hashing encountered a value it couldn't hash."""

    pass
class UnhashableParamError(StreamlitAPIException):
    """User-facing error: a cached function received an argument that the
    cache-key hasher couldn't hash (fix: prefix the parameter with "_")."""

    def __init__(
        self,
        cache_type: CacheType,
        func: types.FunctionType,
        arg_name: Optional[str],
        arg_value: Any,
        orig_exc: BaseException,
    ):
        msg = self._create_message(cache_type, func, arg_name, arg_value)
        super().__init__(msg)
        # Keep the original hashing failure's traceback for debugging.
        self.with_traceback(orig_exc.__traceback__)

    @staticmethod
    def _create_message(
        cache_type: CacheType,
        func: types.FunctionType,
        arg_name: Optional[str],
        arg_value: Any,
    ) -> str:
        """Build the formatted, user-facing explanation string."""
        arg_name_str = arg_name if arg_name is not None else "(unnamed)"
        arg_type = type_util.get_fqn_type(arg_value)
        func_name = func.__name__
        arg_replacement_name = f"_{arg_name}" if arg_name is not None else "_arg"
        return (
            f"""
Cannot hash argument '{arg_name_str}' (of type `{arg_type}`) in '{func_name}'.
To address this, you can tell Streamlit not to hash this argument by adding a
leading underscore to the argument's name in the function signature:
```
@st.{cache_type.value}
def {func_name}({arg_replacement_name}, ...):
    ...
```
            """
        ).strip("\n")
class CacheKeyNotFoundError(Exception):
    """Raised by Cache.read_value when no entry exists for a key."""

    pass
class CacheError(Exception):
    """Generic error raised by the caching machinery."""

    pass
class CachedStFunctionWarning(StreamlitAPIWarning):
    """Warning shown when an `st.` command runs inside a function cached
    with @st.memo or @st.singleton."""

    def __init__(
        self,
        cache_type: CacheType,
        st_func_name: str,
        cached_func: types.FunctionType,
    ):
        args = {
            "st_func_name": f"`st.{st_func_name}()` or `st.write()`",
            "func_name": self._get_cached_func_name_md(cached_func),
            "decorator_name": cache_type.value,
        }
        msg = (
            """
Your script uses %(st_func_name)s to write to your Streamlit app from within
some cached code at %(func_name)s. This code will only be called when we detect
a cache "miss", which can lead to unexpected results.
How to fix this:
* Move the %(st_func_name)s call outside %(func_name)s.
* Or, if you know what you're doing, use `@st.%(decorator_name)s(suppress_st_warning=True)`
to suppress the warning.
            """
            % args
        ).strip("\n")
        super().__init__(msg)

    @staticmethod
    def _get_cached_func_name_md(func: types.FunctionType) -> str:
        """Get markdown representation of the function name."""
        if hasattr(func, "__name__"):
            return "`%s()`" % func.__name__
        else:
            return "a cached function"

View File

@ -0,0 +1,343 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common cache logic shared by st.memo and st.singleton."""
import contextlib
import functools
import hashlib
import inspect
import threading
import types
from abc import abstractmethod
from typing import Callable, List, Iterator, Tuple, Optional, Any, Union
import streamlit as st
from streamlit import util
from streamlit.caching.cache_errors import CacheKeyNotFoundError
from streamlit.logger import get_logger
from .cache_errors import (
CacheType,
CachedStFunctionWarning,
UnhashableParamError,
UnhashableTypeError,
)
from .hashing import update_hash
_LOGGER = get_logger(__name__)
class Cache:
    """Function cache interface. Caches persist across script runs.

    NOTE: this class doesn't inherit from abc.ABC, so @abstractmethod is
    documentation-only here — instantiation isn't blocked at runtime; the
    methods simply raise NotImplementedError.
    """

    @abstractmethod
    def read_value(self, value_key: str) -> Any:
        """Read a value from the cache.

        Raises
        ------
        CacheKeyNotFoundError
            Raised if value_key is not in the cache.
        """
        raise NotImplementedError

    @abstractmethod
    def write_value(self, value_key: str, value: Any) -> None:
        """Write a value to the cache, overwriting any existing value that
        uses the value_key.
        """
        raise NotImplementedError

    @abstractmethod
    def clear(self) -> None:
        """Clear all values from this function cache."""
        raise NotImplementedError
class CachedFunction:
    """Encapsulates data for a cached function instance.

    CachedFunction instances are scoped to a single script run - they're not
    persistent.
    """

    def __init__(
        self, func: types.FunctionType, show_spinner: bool, suppress_st_warning: bool
    ):
        self.func = func  # the user function being cached
        self.show_spinner = show_spinner  # show st.spinner while computing misses
        self.suppress_st_warning = suppress_st_warning  # silence st.* warnings

    @property
    def cache_type(self) -> CacheType:
        """Which cache flavor this is (MEMO or SINGLETON). Subclass hook."""
        raise NotImplementedError

    @property
    def call_stack(self) -> "CachedFunctionCallStack":
        """The call stack tracker for this cache flavor. Subclass hook."""
        raise NotImplementedError

    def get_function_cache(self, function_key: str) -> Cache:
        """Get or create the function cache for the given key."""
        raise NotImplementedError
def create_cache_wrapper(cached_func: CachedFunction) -> Callable[..., Any]:
    """Create a wrapper for a CachedFunction. This implements the common
    plumbing for both st.memo and st.singleton.
    """
    func = cached_func.func
    # The function key is computed once per wrapper: it depends only on the
    # function's identity/source, not on call arguments.
    function_key = _make_function_key(cached_func.cache_type, func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        """This function wrapper will only call the underlying function in
        the case of a cache miss.
        """
        # Retrieve the function's cache object. We must do this inside the
        # wrapped function, because caches can be invalidated at any time.
        cache = cached_func.get_function_cache(function_key)
        name = func.__qualname__
        # Spinner text shown while the (possibly slow) miss path runs.
        if len(args) == 0 and len(kwargs) == 0:
            message = f"Running `{name}()`."
        else:
            message = f"Running `{name}(...)`."

        def get_or_create_cached_value():
            # Generate the key for the cached value. This is based on the
            # arguments passed to the function.
            value_key = _make_value_key(cached_func.cache_type, func, *args, **kwargs)
            try:
                return_value = cache.read_value(value_key)
                _LOGGER.debug("Cache hit: %s", func)
            except CacheKeyNotFoundError:
                _LOGGER.debug("Cache miss: %s", func)
                # Track that we're executing inside this cached function so
                # st.* calls within it can trigger (or suppress) a warning.
                with cached_func.call_stack.calling_cached_function(func):
                    if cached_func.suppress_st_warning:
                        with cached_func.call_stack.suppress_cached_st_function_warning():
                            return_value = func(*args, **kwargs)
                    else:
                        return_value = func(*args, **kwargs)
                cache.write_value(value_key, return_value)
            return return_value

        if cached_func.show_spinner:
            with st.spinner(message):
                return get_or_create_cached_value()
        else:
            return get_or_create_cached_value()

    def clear():
        """Clear the wrapped function's associated cache."""
        cache = cached_func.get_function_cache(function_key)
        cache.clear()

    # Mypy doesn't support declaring attributes of function objects,
    # so we have to suppress a warning here. We can remove this suppression
    # when this issue is resolved: https://github.com/python/mypy/issues/2087
    wrapper.clear = clear  # type: ignore
    return wrapper
class CachedFunctionCallStack(threading.local):
    """A utility for warning users when they call `st` commands inside
    a cached function. Internally, this is just a counter that's incremented
    when we enter a cache function, and decremented when we exit.

    Data is stored in a thread-local object, so it's safe to use an instance
    of this class across multiple threads.
    """

    def __init__(self, cache_type: CacheType):
        # Because this subclasses threading.local, __init__ re-runs the
        # first time the instance is used from each thread, giving each
        # thread its own stack and counter.
        self._cached_func_stack: List[types.FunctionType] = []
        self._suppress_st_function_warning = 0
        self._cache_type = cache_type

    def __repr__(self) -> str:
        return util.repr_(self)

    @contextlib.contextmanager
    def calling_cached_function(self, func: types.FunctionType) -> Iterator[None]:
        """Mark `func` as the currently-executing cached function."""
        self._cached_func_stack.append(func)
        try:
            yield
        finally:
            self._cached_func_stack.pop()

    @contextlib.contextmanager
    def suppress_cached_st_function_warning(self) -> Iterator[None]:
        """Temporarily disable the warning; re-entrant via a counter."""
        self._suppress_st_function_warning += 1
        try:
            yield
        finally:
            self._suppress_st_function_warning -= 1
            assert self._suppress_st_function_warning >= 0

    def maybe_show_cached_st_function_warning(
        self, dg: "st.delta_generator.DeltaGenerator", st_func_name: str
    ) -> None:
        """If appropriate, warn about calling st.foo inside @memo.

        DeltaGenerator's @_with_element and @_widget wrappers use this to warn
        the user when they're calling st.foo() from within a function that is
        wrapped in @st.cache.

        Parameters
        ----------
        dg : DeltaGenerator
            The DeltaGenerator to publish the warning to.

        st_func_name : str
            The name of the Streamlit function that was called.
        """
        if len(self._cached_func_stack) > 0 and self._suppress_st_function_warning <= 0:
            cached_func = self._cached_func_stack[-1]
            self._show_cached_st_function_warning(dg, st_func_name, cached_func)

    def _show_cached_st_function_warning(
        self,
        dg: "st.delta_generator.DeltaGenerator",
        st_func_name: str,
        cached_func: types.FunctionType,
    ) -> None:
        # Avoid infinite recursion by suppressing additional cached
        # function warnings from within the cached function warning.
        with self.suppress_cached_st_function_warning():
            e = CachedStFunctionWarning(self._cache_type, st_func_name, cached_func)
            dg.exception(e)
def _make_value_key(
    cache_type: CacheType, func: types.FunctionType, *args, **kwargs
) -> str:
    """Create the key for a value within a cache.

    This key is generated from the function's arguments. All arguments
    will be hashed, except for those named with a leading "_".

    Raises
    ------
    StreamlitAPIException
        Raised (with a nicely-formatted explanation message) if we encounter
        an un-hashable arg.
    """
    # Create a (name, value) list of all *args and **kwargs passed to the
    # function. (enumerate replaces the old range(len(args)) indexing.)
    arg_pairs: List[Tuple[Optional[str], Any]] = []
    for arg_idx, arg_value in enumerate(args):
        arg_name = _get_positional_arg_name(func, arg_idx)
        arg_pairs.append((arg_name, arg_value))
    for kw_name, kw_val in kwargs.items():
        # **kwargs ordering is preserved, per PEP 468
        # https://www.python.org/dev/peps/pep-0468/, so this iteration is
        # deterministic.
        arg_pairs.append((kw_name, kw_val))
    # Create the hash from each arg value, except for those args whose name
    # starts with "_". (Underscore-prefixed args are deliberately excluded from
    # hashing.)
    args_hasher = hashlib.new("md5")
    for arg_name, arg_value in arg_pairs:
        if arg_name is not None and arg_name.startswith("_"):
            _LOGGER.debug("Not hashing %s because it starts with _", arg_name)
            continue
        try:
            update_hash(
                (arg_name, arg_value),
                hasher=args_hasher,
                cache_type=cache_type,
            )
        except UnhashableTypeError as exc:
            # Chain the original hashing error so debuggers see the cause.
            raise UnhashableParamError(
                cache_type, func, arg_name, arg_value, exc
            ) from exc
    value_key = args_hasher.hexdigest()
    _LOGGER.debug("Cache key: %s", value_key)
    return value_key
def _make_function_key(cache_type: CacheType, func: types.FunctionType) -> str:
    """Create the unique key for a function's cache.

    A function's key is stable across reruns of the app, and changes when
    the function's source code changes.
    """
    func_hasher = hashlib.new("md5")
    # Include the function's __module__ and __qualname__ strings in the hash.
    # This means that two identical functions in different modules
    # will not share a hash; it also means that two identical *nested*
    # functions in the same module will not share a hash.
    update_hash(
        (func.__module__, func.__qualname__),
        hasher=func_hasher,
        cache_type=cache_type,
    )
    # Include the function's source code in its hash. If the source code can't
    # be retrieved, fall back to the function's bytecode instead.
    source_code: Union[str, bytes]
    try:
        source_code = inspect.getsource(func)
    except OSError as e:
        # Bug fix: the original used a "{0}" placeholder here, but logging
        # performs lazy %-style formatting, so the exception was never
        # interpolated into the message.
        _LOGGER.debug(
            "Failed to retrieve function's source code when building its key; "
            "falling back to bytecode. err=%s",
            e,
        )
        source_code = func.__code__.co_code
    update_hash(
        source_code,
        hasher=func_hasher,
        cache_type=cache_type,
    )
    cache_key = func_hasher.hexdigest()
    return cache_key
def _get_positional_arg_name(func: types.FunctionType, arg_index: int) -> Optional[str]:
"""Return the name of a function's positional argument.
If arg_index is out of range, or refers to a parameter that is not a
named positional argument (e.g. an *args, **kwargs, or keyword-only param),
return None instead.
"""
if arg_index < 0:
return None
params: List[inspect.Parameter] = list(inspect.signature(func).parameters.values())
if arg_index >= len(params):
return None
if params[arg_index].kind in (
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.POSITIONAL_ONLY,
):
return params[arg_index].name
return None

View File

@ -0,0 +1,389 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hashing for st.memo and st.singleton."""
import collections
import functools
import hashlib
import inspect
import io
import os
import pickle
import sys
import tempfile
import threading
import unittest.mock
import weakref
from typing import Any, Pattern, Optional, Dict, List
from streamlit import type_util
from streamlit import util
from streamlit.logger import get_logger
from streamlit.uploaded_file_manager import UploadedFile
from .cache_errors import (
CacheType,
UnhashableTypeError,
)
_LOGGER = get_logger(__name__)
# If a dataframe has more than this many rows, we consider it large and hash a sample.
_PANDAS_ROWS_LARGE = 100000
_PANDAS_SAMPLE_SIZE = 10000
# Similar to dataframes, we also sample large numpy arrays.
_NP_SIZE_LARGE = 1000000
_NP_SAMPLE_SIZE = 100000
# Arbitrary item to denote where we found a cycle in a hashed object.
# This allows us to hash self-referencing lists, dictionaries, etc.
_CYCLE_PLACEHOLDER = b"streamlit-57R34ML17-hesamagicalponyflyingthroughthesky-CYCLE"
def update_hash(val: Any, hasher, cache_type: CacheType) -> None:
    """Feed the hash of ``val`` into the given hashlib ``hasher``.

    This is the main entrypoint to hashing.py.
    """
    _CacheFuncHasher(cache_type).update(hasher, val)
class _HashStack:
"""Stack of what has been hashed, for debug and circular reference detection.
This internally keeps 1 stack per thread.
Internally, this stores the ID of pushed objects rather than the objects
themselves because otherwise the "in" operator inside __contains__ would
fail for objects that don't return a boolean for "==" operator. For
example, arr == 10 where arr is a NumPy array returns another NumPy array.
This causes the "in" to crash since it expects a boolean.
"""
def __init__(self):
self._stack: collections.OrderedDict[int, List[Any]] = collections.OrderedDict()
def __repr__(self) -> str:
return util.repr_(self)
def push(self, val: Any):
self._stack[id(val)] = val
def pop(self):
self._stack.popitem()
def __contains__(self, val: Any):
return id(val) in self._stack
class _HashStacks:
    """Stacks of what has been hashed, with at most 1 stack per thread."""

    def __init__(self):
        # Weak keys: a thread's stack becomes collectable once the thread
        # itself is garbage-collected.
        self._stacks: "weakref.WeakKeyDictionary[threading.Thread, _HashStack]" = (
            weakref.WeakKeyDictionary()
        )

    def __repr__(self) -> str:
        return util.repr_(self)

    @property
    def current(self) -> _HashStack:
        """Return (creating on first access) the calling thread's stack."""
        thread = threading.current_thread()
        try:
            return self._stacks[thread]
        except KeyError:
            stack = _HashStack()
            self._stacks[thread] = stack
            return stack
# Module-level singleton: lazily creates one _HashStack per thread.
hash_stacks = _HashStacks()
def _int_to_bytes(i: int) -> bytes:
num_bytes = (i.bit_length() + 8) // 8
return i.to_bytes(num_bytes, "little", signed=True)
def _key(obj: Optional[Any]) -> Any:
"""Return key for memoization."""
if obj is None:
return None
def is_simple(obj):
return (
isinstance(obj, bytes)
or isinstance(obj, bytearray)
or isinstance(obj, str)
or isinstance(obj, float)
or isinstance(obj, int)
or isinstance(obj, bool)
or obj is None
)
if is_simple(obj):
return obj
if isinstance(obj, tuple):
if all(map(is_simple, obj)):
return obj
if isinstance(obj, list):
if all(map(is_simple, obj)):
return ("__l", tuple(obj))
if (
type_util.is_type(obj, "pandas.core.frame.DataFrame")
or type_util.is_type(obj, "numpy.ndarray")
or inspect.isbuiltin(obj)
or inspect.isroutine(obj)
or inspect.iscode(obj)
):
return id(obj)
return NoResult
class _CacheFuncHasher:
    """A hasher that can hash objects with cycles."""

    def __init__(self, cache_type: CacheType):
        # Memoization table: (type-name, _key(obj)) -> hashed bytes.
        self._hashes: Dict[Any, bytes] = {}

        # The number of the bytes in the hash.
        self.size = 0

        self.cache_type = cache_type

    def __repr__(self) -> str:
        return util.repr_(self)

    def to_bytes(self, obj: Any) -> bytes:
        """Add memoization to _to_bytes and protect against cycles in data structures."""
        tname = type(obj).__qualname__.encode()
        key = (tname, _key(obj))

        # Memoize if possible.
        if key[1] is not NoResult:
            if key in self._hashes:
                return self._hashes[key]

        # Break recursive cycles.
        if obj in hash_stacks.current:
            return _CYCLE_PLACEHOLDER

        hash_stacks.current.push(obj)

        try:
            # Hash the input
            b = b"%s:%s" % (tname, self._to_bytes(obj))

            # Hmmm... It's possible that the size calculation is wrong. When we
            # call to_bytes inside _to_bytes things get double-counted.
            self.size += sys.getsizeof(b)

            if key[1] is not NoResult:
                self._hashes[key] = b

        finally:
            # In case an UnhashableTypeError (or other) error is thrown, clean up the
            # stack so we don't get false positives in future hashing calls
            hash_stacks.current.pop()

        return b

    def update(self, hasher, obj: Any) -> None:
        """Update the provided hasher with the hash of an object."""
        b = self.to_bytes(obj)
        hasher.update(b)

    def _to_bytes(self, obj: Any) -> bytes:
        """Hash objects to bytes, including code with dependencies.

        Python's built in `hash` does not produce consistent results across
        runs.
        """
        if isinstance(obj, unittest.mock.Mock):
            # Mock objects can appear to be infinitely
            # deep, so we don't try to hash them at all.
            return self.to_bytes(id(obj))

        elif isinstance(obj, bytes) or isinstance(obj, bytearray):
            return obj

        elif isinstance(obj, str):
            return obj.encode()

        elif isinstance(obj, float):
            return self.to_bytes(hash(obj))

        elif isinstance(obj, int):
            # NOTE(review): bool is a subclass of int, so True/False take this
            # branch; the `obj is True` / `obj is False` branches below appear
            # unreachable.
            return _int_to_bytes(obj)

        elif isinstance(obj, (list, tuple)):
            h = hashlib.new("md5")
            for item in obj:
                self.update(h, item)
            return h.digest()

        elif isinstance(obj, dict):
            h = hashlib.new("md5")
            for item in obj.items():
                self.update(h, item)
            return h.digest()

        elif obj is None:
            # Collides with False's b"0" only superficially: to_bytes()
            # prefixes the type name, so the final hashes differ.
            return b"0"

        elif obj is True:
            return b"1"

        elif obj is False:
            return b"0"

        elif type_util.is_type(obj, "pandas.core.frame.DataFrame") or type_util.is_type(
            obj, "pandas.core.series.Series"
        ):
            import pandas as pd

            # Large frames are sampled (seeded for determinism) before hashing.
            if len(obj) >= _PANDAS_ROWS_LARGE:
                obj = obj.sample(n=_PANDAS_SAMPLE_SIZE, random_state=0)
            try:
                return b"%s" % pd.util.hash_pandas_object(obj).sum()
            except TypeError:
                # Use pickle if pandas cannot hash the object for example if
                # it contains unhashable objects.
                return b"%s" % pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)

        elif type_util.is_type(obj, "numpy.ndarray"):
            h = hashlib.new("md5")
            self.update(h, obj.shape)

            # Large arrays are sampled (seeded for determinism) before hashing.
            if obj.size >= _NP_SIZE_LARGE:
                import numpy as np

                state = np.random.RandomState(0)
                obj = state.choice(obj.flat, size=_NP_SAMPLE_SIZE)

            self.update(h, obj.tobytes())
            return h.digest()

        elif inspect.isbuiltin(obj):
            return bytes(obj.__name__.encode())

        elif type_util.is_type(obj, "builtins.mappingproxy") or type_util.is_type(
            obj, "builtins.dict_items"
        ):
            return self.to_bytes(dict(obj))

        elif type_util.is_type(obj, "builtins.getset_descriptor"):
            return bytes(obj.__qualname__.encode())

        elif isinstance(obj, UploadedFile):
            # UploadedFile is a BytesIO (thus IOBase) but has a name.
            # It does not have a timestamp so this must come before
            # temproary files
            h = hashlib.new("md5")
            self.update(h, obj.name)
            self.update(h, obj.tell())
            self.update(h, obj.getvalue())
            return h.digest()

        elif hasattr(obj, "name") and (
            isinstance(obj, io.IOBase)
            # Handle temporary files used during testing
            or isinstance(obj, tempfile._TemporaryFileWrapper)
        ):
            # Hash files as name + last modification date + offset.
            # NB: we're using hasattr("name") to differentiate between
            # on-disk and in-memory StringIO/BytesIO file representations.
            # That means that this condition must come *before* the next
            # condition, which just checks for StringIO/BytesIO.
            h = hashlib.new("md5")
            obj_name = getattr(obj, "name", "wonthappen")  # Just to appease MyPy.
            self.update(h, obj_name)
            self.update(h, os.path.getmtime(obj_name))
            self.update(h, obj.tell())
            return h.digest()

        elif isinstance(obj, Pattern):
            # A compiled regex is keyed by its pattern string and flags.
            return self.to_bytes([obj.pattern, obj.flags])

        elif isinstance(obj, io.StringIO) or isinstance(obj, io.BytesIO):
            # Hash in-memory StringIO/BytesIO by their full contents
            # and seek position.
            h = hashlib.new("md5")
            self.update(h, obj.tell())
            self.update(h, obj.getvalue())
            return h.digest()

        elif type_util.is_type(obj, "numpy.ufunc"):
            # For numpy.remainder, this returns remainder.
            return bytes(obj.__name__.encode())

        elif inspect.ismodule(obj):
            # TODO: Figure out how to best show this kind of warning to the
            # user. In the meantime, show nothing. This scenario is too common,
            # so the current warning is quite annoying...
            # st.warning(('Streamlit does not support hashing modules. '
            #             'We did not hash `%s`.') % obj.__name__)
            # TODO: Hash more than just the name for internal modules.
            return self.to_bytes(obj.__name__)

        elif inspect.isclass(obj):
            # TODO: Figure out how to best show this kind of warning to the
            # user. In the meantime, show nothing. This scenario is too common,
            # (e.g. in every "except" statement) so the current warning is
            # quite annoying...
            # st.warning(('Streamlit does not support hashing classes. '
            #             'We did not hash `%s`.') % obj.__name__)
            # TODO: Hash more than just the name of classes.
            return self.to_bytes(obj.__name__)

        elif isinstance(obj, functools.partial):
            # The return value of functools.partial is not a plain function:
            # it's a callable object that remembers the original function plus
            # the values you pickled into it. So here we need to special-case it.
            h = hashlib.new("md5")
            self.update(h, obj.args)
            self.update(h, obj.func)
            self.update(h, obj.keywords)
            return h.digest()

        else:
            # As a last resort, hash the output of the object's __reduce__ method
            h = hashlib.new("md5")
            try:
                reduce_data = obj.__reduce__()
            except BaseException as e:
                raise UnhashableTypeError() from e

            for item in reduce_data:
                self.update(h, item)
            return h.digest()
class NoResult:
    """Placeholder class for return values when None is meaningful.

    The class object itself (not an instance) is used as the sentinel, e.g.
    ``_key`` returns ``NoResult`` when it cannot build a memoization key.
    """

View File

@ -0,0 +1,495 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""@st.memo: pickle-based caching"""
import os
import pickle
import shutil
import threading
import time
import types
from typing import Optional, Any, Dict, cast, List, Callable, TypeVar, overload
from typing import Union
import math
from cachetools import TTLCache
from streamlit import util
from streamlit.errors import StreamlitAPIException
from streamlit.file_util import (
streamlit_read,
streamlit_write,
get_streamlit_file_path,
)
from streamlit.logger import get_logger
from streamlit.stats import CacheStatsProvider, CacheStat
from .cache_errors import (
CacheError,
CacheKeyNotFoundError,
CacheType,
)
from .cache_utils import (
Cache,
create_cache_wrapper,
CachedFunctionCallStack,
CachedFunction,
)
_LOGGER = get_logger(__name__)
# The timer function we use with TTLCache. This is the default timer func, but
# is exposed here as a constant so that it can be patched in unit tests.
_TTLCACHE_TIMER = time.monotonic
# Streamlit directory where persisted memoized items live.
# (This is the same directory that @st.cache persisted items live. But memoized
# items have a different extension, so they don't overlap.)
_CACHE_DIR_NAME = "cache"
MEMO_CALL_STACK = CachedFunctionCallStack(CacheType.MEMO)
class MemoCaches(CacheStatsProvider):
    """Manages all MemoCache instances"""

    def __init__(self):
        # Guards _function_caches; every public method takes it.
        self._caches_lock = threading.Lock()
        self._function_caches: Dict[str, "MemoCache"] = {}

    def get_cache(
        self,
        key: str,
        persist: Optional[str],
        max_entries: Optional[Union[int, float]],
        ttl: Optional[Union[int, float]],
        display_name: str,
    ) -> "MemoCache":
        """Return the mem cache for the given key.

        If it doesn't exist, create a new one with the given params.
        """
        # Normalize "no limit" to infinity for both knobs.
        max_entries = math.inf if max_entries is None else max_entries
        ttl = math.inf if ttl is None else ttl

        with self._caches_lock:
            # Reuse the existing cache only if its params are unchanged.
            existing = self._function_caches.get(key)
            if (
                existing is not None
                and existing.ttl == ttl
                and existing.max_entries == max_entries
                and existing.persist == persist
            ):
                return existing

            _LOGGER.debug(
                "Creating new MemoCache (key=%s, persist=%s, max_entries=%s, ttl=%s)",
                key,
                persist,
                max_entries,
                ttl,
            )
            new_cache = MemoCache(
                key=key,
                persist=persist,
                max_entries=max_entries,
                ttl=ttl,
                display_name=display_name,
            )
            self._function_caches[key] = new_cache
            return new_cache

    def clear_all(self) -> None:
        """Clear all in-memory and on-disk caches."""
        with self._caches_lock:
            self._function_caches = {}

            # TODO: Only delete disk cache for functions related to the user's
            # current script.
            cache_path = get_cache_path()
            if os.path.isdir(cache_path):
                shutil.rmtree(cache_path)

    def get_stats(self) -> List[CacheStat]:
        # Snapshot under the lock; gather stats outside it so we don't hold
        # the global lock while each cache computes its stats.
        with self._caches_lock:
            caches = list(self._function_caches.values())

        stats: List[CacheStat] = []
        for cache in caches:
            stats.extend(cache.get_stats())
        return stats
# Singleton MemoCaches instance
# (module-level: shared by every MemoizedFunction in this process)
_memo_caches = MemoCaches()
def get_memo_stats_provider() -> CacheStatsProvider:
    """Return the StatsProvider for all memoized functions.

    This is the module-level ``_memo_caches`` singleton.
    """
    return _memo_caches
class MemoizedFunction(CachedFunction):
    """Implements the CachedFunction protocol for @st.memo"""

    def __init__(
        self,
        func: types.FunctionType,
        show_spinner: bool,
        suppress_st_warning: bool,
        persist: Optional[str],
        max_entries: Optional[int],
        ttl: Optional[float],
    ):
        super().__init__(func, show_spinner, suppress_st_warning)
        # Memo-specific settings; see MemoAPI.__call__ for their meaning.
        self.persist = persist
        self.max_entries = max_entries
        self.ttl = ttl

    @property
    def cache_type(self) -> CacheType:
        """This cached-function flavor is @st.memo."""
        return CacheType.MEMO

    @property
    def call_stack(self) -> CachedFunctionCallStack:
        """The call stack shared by all memoized functions."""
        return MEMO_CALL_STACK

    @property
    def display_name(self) -> str:
        """A human-readable name for the cached function"""
        return "{}.{}".format(self.func.__module__, self.func.__qualname__)

    def get_function_cache(self, function_key: str) -> Cache:
        """Fetch (or lazily create) this function's MemoCache."""
        return _memo_caches.get_cache(
            key=function_key,
            persist=self.persist,
            max_entries=self.max_entries,
            ttl=self.ttl,
            display_name=self.display_name,
        )
class MemoAPI:
    """Implements the public st.memo API: the @st.memo decorator, and
    st.memo.clear().
    """

    # Type-annotate the decorator function.
    # (See https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories)
    F = TypeVar("F", bound=Callable[..., Any])

    # Bare decorator usage
    @overload
    @staticmethod
    def __call__(func: F) -> F:
        ...

    # Decorator with arguments
    @overload
    @staticmethod
    def __call__(
        *,
        persist: Optional[str] = None,
        show_spinner: bool = True,
        suppress_st_warning: bool = False,
        max_entries: Optional[int] = None,
        ttl: Optional[float] = None,
    ) -> Callable[[F], F]:
        ...

    # Implementation. __call__ is a staticmethod so the class instance itself
    # can be used directly as the decorator.
    @staticmethod
    def __call__(
        func: Optional[F] = None,
        *,
        persist: Optional[str] = None,
        show_spinner: bool = True,
        suppress_st_warning: bool = False,
        max_entries: Optional[int] = None,
        ttl: Optional[float] = None,
    ):
        """Function decorator to memoize function executions.

        Memoized data is stored in "pickled" form, which means that the return
        value of a memoized function must be pickleable.

        Each caller of a memoized function gets its own copy of the cached data.

        You can clear a memoized function's cache with f.clear().

        Parameters
        ----------
        func : callable
            The function to memoize. Streamlit hashes the function's source code.

        persist : str or None
            Optional location to persist cached data to. Currently, the only
            valid value is "disk", which will persist to the local disk.

        show_spinner : boolean
            Enable the spinner. Default is True to show a spinner when there is
            a cache miss.

        suppress_st_warning : boolean
            Suppress warnings about calling Streamlit functions from within
            the cached function.

        max_entries : int or None
            The maximum number of entries to keep in the cache, or None
            for an unbounded cache. (When a new entry is added to a full cache,
            the oldest cached entry will be removed.) The default is None.

        ttl : float or None
            The maximum number of seconds to keep an entry in the cache, or
            None if cache entries should not expire. The default is None.

        Example
        -------
        >>> @st.experimental_memo
        ... def fetch_and_clean_data(url):
        ...     # Fetch data from URL here, and then clean it up.
        ...     return data
        ...
        >>> d1 = fetch_and_clean_data(DATA_URL_1)
        >>> # Actually executes the function, since this is the first time it was
        >>> # encountered.
        >>>
        >>> d2 = fetch_and_clean_data(DATA_URL_1)
        >>> # Does not execute the function. Instead, returns its previously computed
        >>> # value. This means that now the data in d1 is the same as in d2.
        >>>
        >>> d3 = fetch_and_clean_data(DATA_URL_2)
        >>> # This is a different URL, so the function executes.

        To set the ``persist`` parameter, use this command as follows:

        >>> @st.experimental_memo(persist="disk")
        ... def fetch_and_clean_data(url):
        ...     # Fetch data from URL here, and then clean it up.
        ...     return data

        By default, all parameters to a memoized function must be hashable.
        Any parameter whose name begins with ``_`` will not be hashed. You can use
        this as an "escape hatch" for parameters that are not hashable:

        >>> @st.experimental_memo
        ... def fetch_and_clean_data(_db_connection, num_rows):
        ...     # Fetch data from _db_connection here, and then clean it up.
        ...     return data
        ...
        >>> connection = make_database_connection()
        >>> d1 = fetch_and_clean_data(connection, num_rows=10)
        >>> # Actually executes the function, since this is the first time it was
        >>> # encountered.
        >>>
        >>> another_connection = make_database_connection()
        >>> d2 = fetch_and_clean_data(another_connection, num_rows=10)
        >>> # Does not execute the function. Instead, returns its previously computed
        >>> # value - even though the _database_connection parameter was different
        >>> # in both calls.

        A memoized function's cache can be procedurally cleared:

        >>> @st.experimental_memo
        ... def fetch_and_clean_data(_db_connection, num_rows):
        ...     # Fetch data from _db_connection here, and then clean it up.
        ...     return data
        ...
        >>> fetch_and_clean_data.clear()
        >>> # Clear all cached entries for this function.
        """
        if persist not in (None, "disk"):
            # We'll eventually have more persist options.
            raise StreamlitAPIException(
                f"Unsupported persist option '{persist}'. Valid values are 'disk' or None."
            )

        # Support passing the params via function decorator, e.g.
        # @st.memo(persist=True, show_spinner=False)
        if func is None:
            return lambda f: create_cache_wrapper(
                MemoizedFunction(
                    func=f,
                    persist=persist,
                    show_spinner=show_spinner,
                    suppress_st_warning=suppress_st_warning,
                    max_entries=max_entries,
                    ttl=ttl,
                )
            )

        return create_cache_wrapper(
            MemoizedFunction(
                func=cast(types.FunctionType, func),
                persist=persist,
                show_spinner=show_spinner,
                suppress_st_warning=suppress_st_warning,
                max_entries=max_entries,
                ttl=ttl,
            )
        )

    @staticmethod
    def clear() -> None:
        """Clear all in-memory and on-disk memo caches."""
        _memo_caches.clear_all()
class MemoCache(Cache):
    """Manages cached values for a single st.memo-ized function."""

    def __init__(
        self,
        key: str,
        persist: Optional[str],
        max_entries: float,
        ttl: float,
        display_name: str,
    ):
        self.key = key
        self.display_name = display_name
        # "disk" to also persist entries on disk; None for memory-only.
        self.persist = persist
        # In-memory layer: TTLCache evicts by age (ttl) and size (maxsize).
        self._mem_cache = TTLCache(maxsize=max_entries, ttl=ttl, timer=_TTLCACHE_TIMER)
        self._mem_cache_lock = threading.Lock()

    @property
    def max_entries(self) -> float:
        """Max number of in-memory entries (math.inf if unbounded)."""
        return cast(float, self._mem_cache.maxsize)

    @property
    def ttl(self) -> float:
        """Entry time-to-live in seconds (math.inf if entries never expire)."""
        return cast(float, self._mem_cache.ttl)

    def get_stats(self) -> List[CacheStat]:
        stats: List[CacheStat] = []
        with self._mem_cache_lock:
            for item_key, item_value in self._mem_cache.items():
                stats.append(
                    CacheStat(
                        category_name="st_memo",
                        cache_name=self.display_name,
                        # Values are pickled bytes, so len() is the byte size.
                        byte_length=len(item_value),
                    )
                )
        return stats

    def read_value(self, key: str) -> Any:
        """Read a value from the cache. Raise `CacheKeyNotFoundError` if the
        value doesn't exist, and `CacheError` if the value exists but can't
        be unpickled.
        """
        try:
            pickled_value = self._read_from_mem_cache(key)

        except CacheKeyNotFoundError as e:
            if self.persist == "disk":
                # Fall back to disk, then repopulate the memory layer.
                pickled_value = self._read_from_disk_cache(key)
                self._write_to_mem_cache(key, pickled_value)
            else:
                raise e

        try:
            return pickle.loads(pickled_value)
        except pickle.UnpicklingError as exc:
            raise CacheError(f"Failed to unpickle {key}") from exc

    def write_value(self, key: str, value: Any) -> None:
        """Write a value to the cache. It must be pickleable."""
        try:
            pickled_value = pickle.dumps(value)
        except pickle.PicklingError as exc:
            raise CacheError(f"Failed to pickle {key}") from exc

        self._write_to_mem_cache(key, pickled_value)
        if self.persist == "disk":
            self._write_to_disk_cache(key, pickled_value)

    def clear(self) -> None:
        with self._mem_cache_lock:
            # We keep a lock for the entirety of the clear operation to avoid
            # disk cache race conditions.
            for key in self._mem_cache.keys():
                self._remove_from_disk_cache(key)

            self._mem_cache.clear()

    def _read_from_mem_cache(self, key: str) -> bytes:
        with self._mem_cache_lock:
            if key in self._mem_cache:
                # bytes() makes a defensive copy — presumably so callers can't
                # mutate the cached buffer; confirm before relying on it.
                entry = bytes(self._mem_cache[key])
                _LOGGER.debug("Memory cache HIT: %s", key)
                return entry

            else:
                _LOGGER.debug("Memory cache MISS: %s", key)
                raise CacheKeyNotFoundError("Key not found in mem cache")

    def _read_from_disk_cache(self, key: str) -> bytes:
        path = self._get_file_path(key)
        try:
            # NOTE(review): `input` shadows the builtin of the same name.
            with streamlit_read(path, binary=True) as input:
                value = input.read()
                _LOGGER.debug("Disk cache HIT: %s", key)
                return bytes(value)
        except FileNotFoundError:
            raise CacheKeyNotFoundError("Key not found in disk cache")
        except BaseException as e:
            _LOGGER.error(e)
            raise CacheError("Unable to read from cache") from e

    def _write_to_mem_cache(self, key: str, pickled_value: bytes) -> None:
        with self._mem_cache_lock:
            self._mem_cache[key] = pickled_value

    def _write_to_disk_cache(self, key: str, pickled_value: bytes) -> None:
        path = self._get_file_path(key)
        try:
            with streamlit_write(path, binary=True) as output:
                output.write(pickled_value)
        except util.Error as e:
            _LOGGER.debug(e)
            # Clean up file so we don't leave zero byte files.
            try:
                os.remove(path)
            except (FileNotFoundError, IOError, OSError):
                pass
            raise CacheError("Unable to write to cache") from e

    def _remove_from_disk_cache(self, key: str) -> None:
        """Delete a cache file from disk. If the file does not exist on disk,
        return silently. If another exception occurs, log it. Does not throw.
        """
        path = self._get_file_path(key)
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except BaseException as e:
            _LOGGER.exception("Unable to remove a file from the disk cache", e)

    def _get_file_path(self, value_key: str) -> str:
        """Return the path of the disk cache file for the given value."""
        return get_streamlit_file_path(_CACHE_DIR_NAME, f"{self.key}-{value_key}.memo")
def get_cache_path() -> str:
    """Return the directory where persisted memo cache files are stored."""
    return get_streamlit_file_path(_CACHE_DIR_NAME)

View File

@ -0,0 +1,289 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""@st.singleton implementation"""
import threading
import types
from typing import Optional, Any, Dict, List, TypeVar, Callable, overload, cast
from pympler import asizeof
from streamlit.logger import get_logger
from streamlit.stats import CacheStatsProvider, CacheStat
from .cache_errors import CacheKeyNotFoundError, CacheType
from .cache_utils import (
Cache,
create_cache_wrapper,
CachedFunctionCallStack,
CachedFunction,
)
_LOGGER = get_logger(__name__)
SINGLETON_CALL_STACK = CachedFunctionCallStack(CacheType.SINGLETON)
class SingletonCaches(CacheStatsProvider):
    """Manages all SingletonCache instances"""

    def __init__(self):
        # Guards _function_caches; every public method takes it.
        self._caches_lock = threading.Lock()
        self._function_caches: Dict[str, "SingletonCache"] = {}

    def get_cache(self, key: str, display_name: str) -> "SingletonCache":
        """Return the mem cache for the given key.

        If it doesn't exist, create a new one with the given params.
        """
        with self._caches_lock:
            existing = self._function_caches.get(key)
            if existing is not None:
                return existing

            # First request for this key: build a fresh cache and remember it.
            _LOGGER.debug("Creating new SingletonCache (key=%s)", key)
            new_cache = SingletonCache(key=key, display_name=display_name)
            self._function_caches[key] = new_cache
            return new_cache

    def clear_all(self) -> None:
        """Clear all singleton caches."""
        with self._caches_lock:
            self._function_caches = {}

    def get_stats(self) -> List[CacheStat]:
        # Snapshot under the lock; gather stats outside it so we don't hold
        # the global lock while each cache computes its stats.
        with self._caches_lock:
            caches = list(self._function_caches.values())

        stats: List[CacheStat] = []
        for cache in caches:
            stats.extend(cache.get_stats())
        return stats
# Singleton SingletonCaches instance
# (module-level: shared by every SingletonFunction in this process)
_singleton_caches = SingletonCaches()
def get_singleton_stats_provider() -> CacheStatsProvider:
    """Return the StatsProvider for all singleton functions.

    This is the module-level ``_singleton_caches`` singleton.
    """
    return _singleton_caches
class SingletonFunction(CachedFunction):
    """Implements the CachedFunction protocol for @st.singleton"""

    @property
    def cache_type(self) -> CacheType:
        """This cached-function flavor is @st.singleton."""
        return CacheType.SINGLETON

    @property
    def call_stack(self) -> CachedFunctionCallStack:
        """The call stack shared by all singleton functions."""
        return SINGLETON_CALL_STACK

    @property
    def display_name(self) -> str:
        """A human-readable name for the cached function"""
        return "{}.{}".format(self.func.__module__, self.func.__qualname__)

    def get_function_cache(self, function_key: str) -> Cache:
        """Fetch (or lazily create) this function's SingletonCache."""
        return _singleton_caches.get_cache(
            key=function_key, display_name=self.display_name
        )
class SingletonAPI:
    """Implements the public st.singleton API: the @st.singleton decorator,
    and st.singleton.clear().
    """

    # Type-annotate the decorator function.
    # (See https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories)
    F = TypeVar("F", bound=Callable[..., Any])

    # Bare decorator usage
    @overload
    @staticmethod
    def __call__(func: F) -> F:
        ...

    # Decorator with arguments
    @overload
    @staticmethod
    def __call__(
        *,
        show_spinner: bool = True,
        suppress_st_warning: bool = False,
    ) -> Callable[[F], F]:
        ...

    # Implementation. __call__ is a staticmethod so the class instance itself
    # can be used directly as the decorator.
    @staticmethod
    def __call__(
        func: Optional[F] = None,
        *,
        show_spinner: bool = True,
        suppress_st_warning: bool = False,
    ):
        """Function decorator to store singleton objects.

        Each singleton object is shared across all users connected to the app.
        Singleton objects *must* be thread-safe, because they can be accessed from
        multiple threads concurrently.

        (If thread-safety is an issue, consider using ``st.session_state`` to
        store per-session singleton objects instead.)

        You can clear a memoized function's cache with f.clear().

        Parameters
        ----------
        func : callable
            The function that creates the singleton. Streamlit hashes the
            function's source code.

        show_spinner : boolean
            Enable the spinner. Default is True to show a spinner when there is
            a "cache miss" and the singleton is being created.

        suppress_st_warning : boolean
            Suppress warnings about calling Streamlit functions from within
            the singleton function.

        Example
        -------
        >>> @st.experimental_singleton
        ... def get_database_session(url):
        ...     # Create a database session object that points to the URL.
        ...     return session
        ...
        >>> s1 = get_database_session(SESSION_URL_1)
        >>> # Actually executes the function, since this is the first time it was
        >>> # encountered.
        >>>
        >>> s2 = get_database_session(SESSION_URL_1)
        >>> # Does not execute the function. Instead, returns its previously computed
        >>> # value. This means that now the connection object in s1 is the same as in s2.
        >>>
        >>> s3 = get_database_session(SESSION_URL_2)
        >>> # This is a different URL, so the function executes.

        By default, all parameters to a singleton function must be hashable.
        Any parameter whose name begins with ``_`` will not be hashed. You can use
        this as an "escape hatch" for parameters that are not hashable:

        >>> @st.experimental_singleton
        ... def get_database_session(_sessionmaker, url):
        ...     # Create a database connection object that points to the URL.
        ...     return connection
        ...
        >>> s1 = get_database_session(create_sessionmaker(), DATA_URL_1)
        >>> # Actually executes the function, since this is the first time it was
        >>> # encountered.
        >>>
        >>> s2 = get_database_session(create_sessionmaker(), DATA_URL_1)
        >>> # Does not execute the function. Instead, returns its previously computed
        >>> # value - even though the _sessionmaker parameter was different
        >>> # in both calls.

        A singleton function's cache can be procedurally cleared:

        >>> @st.experimental_singleton
        ... def get_database_session(_sessionmaker, url):
        ...     # Create a database connection object that points to the URL.
        ...     return connection
        ...
        >>> get_database_session.clear()
        >>> # Clear all cached entries for this function.
        """
        # Support passing the params via function decorator, e.g.
        # @st.singleton(show_spinner=False)
        if func is None:
            return lambda f: create_cache_wrapper(
                SingletonFunction(
                    func=f,
                    show_spinner=show_spinner,
                    suppress_st_warning=suppress_st_warning,
                )
            )

        return create_cache_wrapper(
            SingletonFunction(
                func=cast(types.FunctionType, func),
                show_spinner=show_spinner,
                suppress_st_warning=suppress_st_warning,
            )
        )

    @staticmethod
    def clear() -> None:
        """Clear all singleton caches."""
        _singleton_caches.clear_all()
class SingletonCache(Cache):
    """Manages cached values for a single st.singleton function."""

    def __init__(self, key: str, display_name: str):
        self.key = key
        self.display_name = display_name
        # All cached values for this function, keyed by value-key and guarded
        # by a single lock.
        self._mem_cache: Dict[str, Any] = {}
        self._mem_cache_lock = threading.Lock()

    def read_value(self, key: str) -> Any:
        """Read a value from the cache. Raise `CacheKeyNotFoundError` if the
        value doesn't exist.
        """
        with self._mem_cache_lock:
            if key not in self._mem_cache:
                raise CacheKeyNotFoundError()
            return self._mem_cache[key]

    def write_value(self, key: str, value: Any) -> None:
        """Write a value to the cache."""
        with self._mem_cache_lock:
            self._mem_cache[key] = value

    def clear(self) -> None:
        """Drop every cached value for this function."""
        with self._mem_cache_lock:
            self._mem_cache.clear()

    def get_stats(self) -> List[CacheStat]:
        # Shallow clone our cache. Computing item sizes is potentially
        # expensive, and we want to minimize the time we spend holding
        # the lock.
        with self._mem_cache_lock:
            snapshot = self._mem_cache.copy()

        return [
            CacheStat(
                category_name="st_singleton",
                cache_name=self.display_name,
                byte_length=asizeof.asizeof(item_value),
            )
            for item_value in snapshot.values()
        ]

View File

@ -0,0 +1,79 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
def to_upper_camel_case(snake_case_str):
    """Convert a snake_case string to UpperCamelCase.

    Example
    -------
        foo_bar -> FooBar

    """
    return "".join(word.title() for word in snake_case_str.split("_"))
def to_lower_camel_case(snake_case_str):
    """Convert a snake_case string to lowerCamelCase.

    A string containing no underscore is returned unchanged.

    Example
    -------
        foo_bar -> fooBar

    """
    first, *rest = snake_case_str.split("_")
    if not rest:
        return snake_case_str
    return first + "".join(word.title() for word in rest)
def to_snake_case(camel_case_str):
    """Convert UpperCamelCase or lowerCamelCase to snake_case.

    Examples
    --------
        fooBar -> foo_bar
        BazBang -> baz_bang

    """
    # First break "...xYz..." runs apart, then split remaining
    # lower/digit-to-upper boundaries, and lowercase everything.
    partially_split = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", camel_case_str)
    return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", partially_split).lower()
def convert_dict_keys(func, in_dict):
    """Return a copy of `in_dict` with every key passed through `func`.

    Parameters
    ----------
    func : callable
        The conversion function. Takes a str and returns a str.
    in_dict : dict
        The dictionary to convert. Values that are themselves plain dicts
        are converted recursively.

    Returns
    -------
    dict
        A new dict with all the contents of `in_dict`, but with the keys
        converted by `func`.
    """
    # Note: only values whose exact type is dict recurse (dict subclasses
    # are copied as-is), matching the original `type(v) is dict` check.
    return {
        func(key): convert_dict_keys(func, value) if type(value) is dict else value
        for key, value in in_dict.items()
    }

View File

@ -0,0 +1,335 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This is a script which is run when the Streamlit package is executed."""
from streamlit import config as _config
import os
from typing import Optional
import click
import streamlit
from streamlit.credentials import Credentials, check_credentials
import streamlit.bootstrap as bootstrap
from streamlit.case_converters import to_snake_case
# File extensions `streamlit run` will accept as a script target.
ACCEPTED_FILE_EXTENSIONS = ("py", "py3")

# Values accepted by the (unsupported) --log_level flag on the `main` group.
LOG_LEVELS = ("error", "warning", "info", "debug")
def _convert_config_option_to_click_option(config_option):
    """Translate one Streamlit config option into kwargs for a click option."""
    # "server.port" becomes the click param "server_port" and the
    # "STREAMLIT_SERVER_PORT" environment variable.
    param = config_option.key.replace(".", "_")

    description = config_option.description
    if config_option.deprecated:
        description += "\n {} - {}".format(
            config_option.deprecation_text, config_option.expiration_date
        )

    return {
        "param": param,
        "description": description,
        "type": config_option.type,
        "option": "--{}".format(config_option.key),
        "envvar": "STREAMLIT_{}".format(to_snake_case(param).upper()),
    }
def configurator_options(func):
    """Decorator that attaches a click option for every Streamlit config key."""
    # Apply in reverse so click displays options in template order.
    for template_option in reversed(list(_config._config_options_template.values())):
        parsed = _convert_config_option_to_click_option(template_option)
        func = click.option(
            parsed["option"],
            parsed["param"],
            help=parsed["description"],
            type=parsed["type"],
            show_envvar=True,
            envvar=parsed["envvar"],
        )(func)
    return func
# Fetch remote file at url_path to main_script_path
def _download_remote(main_script_path, url_path):
    """Download the script at `url_path` and save it to `main_script_path`.

    Raises
    ------
    click.BadParameter
        If the URL cannot be fetched.
    """
    import requests

    try:
        resp = requests.get(url_path)
        resp.raise_for_status()
    except requests.exceptions.RequestException as e:
        # Fail before touching the filesystem: the original opened the output
        # file first, which left an empty file behind on a failed fetch.
        raise click.BadParameter("Unable to fetch {}.\n{}".format(url_path, e))

    with open(main_script_path, "wb") as fp:
        fp.write(resp.content)
@click.group(context_settings={"auto_envvar_prefix": "STREAMLIT"})
@click.option("--log_level", show_default=True, type=click.Choice(LOG_LEVELS))
@click.version_option(prog_name="Streamlit")
@click.pass_context
def main(ctx, log_level="info"):
    """Try out a demo with:

        $ streamlit hello

    Or use the line below to run your own script:

        $ streamlit run your_script.py
    """
    if not log_level:
        return

    # --log_level is accepted for backward compatibility but has no effect;
    # point users at the supported --logger.level flag instead.
    from streamlit.logger import get_logger

    get_logger(__name__).warning(
        "Setting the log level using the --log_level flag is unsupported."
        "\nUse the --logger.level flag (after your streamlit command) instead."
    )
@main.command("help")
@click.pass_context
def help(ctx):
    """Print this help message."""
    # NOTE(review): `help` shadows the builtin, but click needs a callable
    # here and the name is never used for lookup.
    # Pretend user typed 'streamlit --help' instead of 'streamlit help'.
    import sys

    # We use _get_command_line_as_string to run some error checks but don't do
    # anything with its return value.
    _get_command_line_as_string()

    assert len(sys.argv) == 2  # This is always true, but let's assert anyway.
    sys.argv[1] = "--help"
    main(prog_name="streamlit")
@main.command("version")
@click.pass_context
def main_version(ctx):
    """Print Streamlit's version number."""
    # Pretend user typed 'streamlit --version' instead of 'streamlit version'
    import sys

    # We use _get_command_line_as_string to run some error checks but don't do
    # anything with its return value.
    _get_command_line_as_string()

    assert len(sys.argv) == 2  # This is always true, but let's assert anyway.
    sys.argv[1] = "--version"
    main()
@main.command("docs")
def main_docs():
    """Show help in browser."""
    print("Showing help page in browser...")
    # Imported lazily — presumably to keep CLI startup light; confirm.
    from streamlit import util

    util.open_browser("https://docs.streamlit.io")
@main.command("hello")
@configurator_options
def main_hello(**kwargs):
    """Runs the Hello World script."""
    from streamlit.hello import hello

    # Apply any CLI config flags, then run the bundled demo script.
    bootstrap.load_config_options(flag_options=kwargs)
    _main_run(hello.__file__, flag_options=kwargs)
@main.command("run")
@configurator_options
@click.argument("target", required=True, envvar="STREAMLIT_RUN_TARGET")
@click.argument("args", nargs=-1)
def main_run(target, args=None, **kwargs):
    """Run a Python script, piping stderr to Streamlit.

    The script can be local or it can be an url. In the latter case, Streamlit
    will download the script to a temporary file and runs this file.

    """
    from validators import url

    bootstrap.load_config_options(flag_options=kwargs)

    _, extension = os.path.splitext(target)
    ext = extension[1:]
    if ext not in ACCEPTED_FILE_EXTENSIONS:
        if ext == "":
            raise click.BadArgumentUsage(
                "Streamlit requires raw Python (.py) files, but the provided file has no extension.\nFor more information, please see https://docs.streamlit.io"
            )
        raise click.BadArgumentUsage(
            "Streamlit requires raw Python (.py) files, not %s.\nFor more information, please see https://docs.streamlit.io"
            % extension
        )

    if not url(target):
        # Local file: it has to exist before we can run it.
        if not os.path.exists(target):
            raise click.BadParameter("File does not exist: {}".format(target))
        _main_run(target, args, flag_options=kwargs)
        return

    # Remote script: download into a temp dir, then run from there.
    from streamlit.temporary_directory import TemporaryDirectory

    with TemporaryDirectory() as temp_dir:
        from urllib.parse import urlparse
        from streamlit import url_util

        path = urlparse(target).path
        main_script_path = os.path.join(temp_dir, path.strip("/").rsplit("/", 1)[-1])

        # if this is a GitHub/Gist blob url, convert to a raw URL first.
        target = url_util.process_gitblob_url(target)
        _download_remote(main_script_path, target)

        _main_run(main_script_path, args, flag_options=kwargs)
def _get_command_line_as_string() -> Optional[str]:
    """Reconstruct the command line that invoked streamlit, as one string.

    Returns
    -------
    str or None
        The command line, or None when there is no parent click context
        (e.g. when a command is invoked programmatically).

    Raises
    ------
    RuntimeError
        If streamlit was invoked via the unsupported `python -m streamlit.cli`.
    """
    import subprocess
    import sys

    parent = click.get_current_context().parent
    if parent is None:
        return None

    if "streamlit.cli" in parent.command_path:
        raise RuntimeError(
            "Running streamlit via `python -m streamlit.cli <command>` is"
            " unsupported. Please use `python -m streamlit <command>` instead."
        )

    cmd_line_as_list = [parent.command_path]
    # click.get_os_args() was deprecated in Click 8.0 and removed in 8.1;
    # sys.argv[1:] is its documented replacement.
    cmd_line_as_list.extend(sys.argv[1:])
    return subprocess.list2cmdline(cmd_line_as_list)
def _main_run(file, args=None, flag_options=None):
    """Hand `file` off to the bootstrap machinery as a Streamlit script."""
    script_args = [] if args is None else args
    options = {} if flag_options is None else flag_options

    command_line = _get_command_line_as_string()

    # Global flag indicating that we're "within" streamlit.
    streamlit._is_running_with_streamlit = True

    check_credentials()
    bootstrap.run(file, command_line, script_args, options)
# SUBCOMMAND: cache


@main.group("cache")
def cache():
    """Manage the Streamlit cache."""
@cache.command("clear")
def cache_clear():
    """Clear st.cache, st.memo, and st.singleton caches."""
    import streamlit.legacy_caching
    import streamlit.caching

    # Legacy st.cache persists to a directory on disk.
    did_clear = streamlit.legacy_caching.clear_cache()
    cache_path = streamlit.legacy_caching.get_cache_path()
    message = "Cleared directory %s." if did_clear else "Nothing to clear at %s."
    print(message % cache_path)

    # st.memo and st.singleton are in-memory only.
    streamlit.caching.memo.clear()
    streamlit.caching.singleton.clear()
# SUBCOMMAND: config


@main.group("config")
def config():
    """Manage Streamlit's config settings."""
    pass
@config.command("show")
@configurator_options
def config_show(**kwargs):
    """Show all of Streamlit's config settings."""
    # Merge any CLI config flags into the config before printing it.
    bootstrap.load_config_options(flag_options=kwargs)

    _config.show_config()
# SUBCOMMAND: activate


@main.group("activate", invoke_without_command=True)
@click.pass_context
def activate(ctx):
    """Activate Streamlit by entering your email."""
    # With no subcommand, run the interactive activation flow directly.
    if not ctx.invoked_subcommand:
        Credentials.get_current().activate()
@activate.command("reset")
def activate_reset():
    """Reset Activation Credentials."""
    # Presumably clears the stored credentials so the next run re-prompts —
    # confirm in streamlit.credentials.
    Credentials.get_current().reset()
# SUBCOMMAND: test


@main.group("test", hidden=True)
def test():
    """Internal-only commands used for testing.

    These commands are not included in the output of `streamlit help`.
    """
    pass
@test.command("prog_name")
def test_prog_name():
    """Assert that the program name is set to `streamlit test`.

    This is used by our cli-smoke-tests to verify that the program name is set
    to `streamlit ...` whether the streamlit binary is invoked directly or via
    `python -m streamlit ...`.
    """
    # Runs the shared command-line sanity checks; the result is unused.
    _get_command_line_as_string()

    parent_ctx = click.get_current_context().parent
    assert parent_ctx is not None
    assert parent_ctx.command_path == "streamlit test"
# Allow direct execution of this module to invoke the click CLI.
if __name__ == "__main__":
    main()

View File

@ -0,0 +1,87 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A bunch of useful code utilities."""
import re
def extract_args(line):
    """Parse argument strings from all outer parentheses in a line of code.

    Parameters
    ----------
    line : str
        A line of code

    Returns
    -------
    list of strings
        Contents of the outer parentheses

    Example
    -------
    >>> line = 'foo(bar, baz), "a", my(func)'
    >>> extract_args(line)
    ['bar, baz', 'func']

    """
    depth = 0
    start_index = None
    results = []

    for index, char in enumerate(line):
        if char == "(":
            if depth == 0:
                # Remember where this outermost argument list begins.
                start_index = index + 1
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                results.append(line[start_index:index])

    return results
def get_method_args_from_code(args, line):
    """Parse arguments from a stringified arguments list inside parentheses

    Parameters
    ----------
    args : list
        A list where it's size matches the expected number of parsed arguments
    line : str
        Stringified line of code with method arguments inside parentheses

    Returns
    -------
    list of strings
        Parsed arguments

    Example
    -------
    >>> line = 'foo(bar, baz, my(func, tion))'
    >>> get_method_args_from_code(range(0, 3), line)
    ['bar', 'baz', 'my(func, tion)']

    """
    # Only the first outer parenthesis group holds the method's arguments.
    line_args = extract_args(line)[0]

    if len(args) <= 1:
        return [line_args]

    # Split on top-level commas only, https://stackoverflow.com/a/26634150
    inputs = re.split(r",\s*(?![^(){}[\]]*\))", line_args)
    assert len(inputs) == len(args), "Could not split arguments"
    return inputs

View File

@ -0,0 +1,13 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,227 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from urllib.parse import urlparse
from textwrap import dedent
from streamlit.scriptrunner import get_script_run_ctx
from streamlit.proto import ForwardMsg_pb2
from streamlit.proto import PageConfig_pb2
from streamlit.elements import image
from streamlit.errors import StreamlitAPIException
from streamlit.util import lower_clean_dict_keys
# Lowercased keys accepted in st.set_page_config(menu_items=...); user input
# is lower-cleaned before being compared against these.
GET_HELP_KEY = "get help"
REPORT_A_BUG_KEY = "report a bug"
ABOUT_KEY = "about"
def set_page_config(
    page_title=None,
    page_icon=None,
    layout="centered",
    initial_sidebar_state="auto",
    menu_items=None,
):
    """
    Configures the default settings of the page.

    .. note::
        This must be the first Streamlit command used in your app, and must only
        be set once.

    Parameters
    ----------
    page_title: str or None
        The page title, shown in the browser tab. If None, defaults to the
        filename of the script ("app.py" would show "app • Streamlit").
    page_icon : Anything supported by st.image or str or None
        The page favicon.
        Besides the types supported by `st.image` (like URLs or numpy arrays),
        you can pass in an emoji as a string ("🦈") or a shortcode (":shark:").
        If you're feeling lucky, try "random" for a random emoji!
        Emoji icons are courtesy of Twemoji and loaded from MaxCDN.
    layout: "centered" or "wide"
        How the page content should be laid out. Defaults to "centered",
        which constrains the elements into a centered column of fixed width;
        "wide" uses the entire screen.
    initial_sidebar_state: "auto" or "expanded" or "collapsed"
        How the sidebar should start out. Defaults to "auto",
        which hides the sidebar on mobile-sized devices, and shows it otherwise.
        "expanded" shows the sidebar initially; "collapsed" hides it.
    menu_items: dict
        Configure the menu that appears on the top-right side of this app.
        The keys in this dict denote the menu item you'd like to configure:

        - "Get help": str or None
            The URL this menu item should point to.
            If None, hides this menu item.
        - "Report a Bug": str or None
            The URL this menu item should point to.
            If None, hides this menu item.
        - "About": str or None
            A markdown string to show in the About dialog.
            If None, only shows Streamlit's default About text.

    Example
    -------
    >>> st.set_page_config(
    ...     page_title="Ex-stream-ly Cool App",
    ...     page_icon="🧊",
    ...     layout="wide",
    ...     initial_sidebar_state="expanded",
    ...     menu_items={
    ...         'Get Help': 'https://www.extremelycoolapp.com/help',
    ...         'Report a bug': "https://www.extremelycoolapp.com/bug",
    ...         'About': "# This is a header. This is an *extremely* cool app!"
    ...     }
    ... )
    """
    msg = ForwardMsg_pb2.ForwardMsg()

    if page_title:
        msg.page_config_changed.title = page_title

    if page_icon:
        if page_icon == "random":
            page_icon = get_random_emoji()

        # Favicons go through the image machinery so URLs, arrays, and
        # emoji are all accepted.
        msg.page_config_changed.favicon = image.image_to_url(
            page_icon,
            width=-1,  # Always use full width for favicons
            clamp=False,
            channels="RGB",
            output_format="auto",
            image_id="favicon",
            allow_emoji=True,
        )

    # Map the user-facing layout string onto the PageConfig proto enum.
    if layout == "centered":
        layout = PageConfig_pb2.PageConfig.CENTERED
    elif layout == "wide":
        layout = PageConfig_pb2.PageConfig.WIDE
    else:
        raise StreamlitAPIException(
            f'`layout` must be "centered" or "wide" (got "{layout}")'
        )
    msg.page_config_changed.layout = layout

    # Same enum mapping for the sidebar's initial state.
    if initial_sidebar_state == "auto":
        initial_sidebar_state = PageConfig_pb2.PageConfig.AUTO
    elif initial_sidebar_state == "expanded":
        initial_sidebar_state = PageConfig_pb2.PageConfig.EXPANDED
    elif initial_sidebar_state == "collapsed":
        initial_sidebar_state = PageConfig_pb2.PageConfig.COLLAPSED
    else:
        raise StreamlitAPIException(
            '`initial_sidebar_state` must be "auto" or "expanded" or "collapsed" '
            + f'(got "{initial_sidebar_state}")'
        )
    msg.page_config_changed.initial_sidebar_state = initial_sidebar_state

    if menu_items is not None:
        # Menu keys are matched case-insensitively ("Get Help" == "get help").
        lowercase_menu_items = lower_clean_dict_keys(menu_items)
        validate_menu_items(lowercase_menu_items)
        menu_items_proto = msg.page_config_changed.menu_items
        set_menu_items_proto(lowercase_menu_items, menu_items_proto)

    ctx = get_script_run_ctx()
    if ctx is None:
        # No active script run (e.g. bare-module execution): nothing to send.
        return
    ctx.enqueue(msg)
def get_random_emoji():
    """Return one random emoji, with the team's vanity emojis weighted 10x."""
    import random

    # Emojis recommended by https://share.streamlit.io/rensdimmendaal/emoji-recommender/main/app/streamlit.py
    # for the term "streamlit". Watch out for zero-width joiners,
    # as they won't parse correctly in the list() call!
    RANDOM_EMOJIS = list(
        "🔥™🎉🚀🌌💣✨🌙🎆🎇💥🤩🤙🌛🤘⬆💡🤪🥂⚡💨🌠🎊🍿😛🔮🤟🌃🍃🍾💫▪🌴🎈🎬🌀🎄😝☔⛽🍂💃😎🍸🎨🥳☀😍🅱🌞😻🌟😜💦💅🦄😋😉👻🍁🤤👯🌻‼🌈👌🎃💛😚🔫🙌👽🍬🌅☁🍷👭☕🌚💁👅🥰🍜😌🎥🕺❕🧡☄💕🍻✅🌸🚬🤓🍹®☺💪😙☘🤠✊🤗🍵🤞😂💯😏📻🎂💗💜🌊❣🌝😘💆🤑🌿🦋😈⛄🚿😊🌹🥴😽💋😭🖤🙆👐⚪💟☃🙈🍭💻🥀🚗🤧🍝💎💓🤝💄💖🔞⁉⏰🕊🎧☠♥🌳🏾🙉⭐💊🍳🌎🙊💸❤🔪😆🌾✈📚💀🏠✌🏃🌵🚨💂🤫🤭😗😄🍒👏🙃🖖💞😅🎅🍄🆓👉💩🔊🤷⌚👸😇🚮💏👳🏽💘💿💉👠🎼🎶🎤👗❄🔐🎵🤒🍰👓🏄🌲🎮🙂📈🚙📍😵🗣❗🌺🙄👄🚘🥺🌍🏡♦💍🌱👑👙☑👾🍩🥶📣🏼🤣☯👵🍫➡🎀😃✋🍞🙇😹🙏👼🐝⚫🎁🍪🔨🌼👆👀😳🌏📖👃🎸👧💇🔒💙😞⛅🏻🍴😼🗿🍗♠🦁✔🤖☮🐢🐎💤😀🍺😁😴📺☹😲👍🎭💚🍆🍋🔵🏁🔴🔔🧐👰☎🏆🤡🐠📲🙋📌🐬✍🔑📱💰🐱💧🎓🍕👟🐣👫🍑😸🍦👁🆗🎯📢🚶🦅🐧💢🏀🚫💑🐟🌽🏊🍟💝💲🐍🍥🐸☝♣👊⚓❌🐯🏈📰🌧👿🐳💷🐺📞🆒🍀🤐🚲🍔👹🙍🌷🙎🐥💵🔝📸⚠❓🎩✂🍼😑⬇⚾🍎💔🐔⚽💭🏌🐷🍍✖🍇📝🍊🐙👋🤔🥊🗽🐑🐘🐰💐🐴♀🐦🍓✏👂🏴👇🆘😡🏉👩💌😺✝🐼🐒🐶👺🖕👬🍉🐻🐾⬅⏬▶👮🍌♂🔸👶🐮👪⛳🐐🎾🐕👴🐨🐊🔹©🎣👦👣👨👈💬⭕📹📷"
    )

    # Also pick out some vanity emojis.
    ENG_EMOJIS = [
        "🎈",  # st.balloons 🎈🎈
        "🤓",  # Abhi
        "🏈",  # Amey
        "🚲",  # Thiago
        "🐧",  # Matteo
        "🦒",  # Ken
        "🐳",  # Karrie
        "🕹️",  # Jonathan
        "🇦🇲",  # Henrikh
        "🎸",  # Guido
        "🦈",  # Austin
        "💎",  # Emiliano
        "👩‍🎤",  # Naomi
        "🧙‍♂️",  # Jon
        "🐻",  # Brandon
        "🎎",  # James
        # TODO: Solicit emojis from the rest of Streamlit
    ]

    # Weigh our emojis 10x, cuz we're awesome!
    # TODO: fix the random seed with a hash of the user's app code, for stability?
    return random.choice(RANDOM_EMOJIS + 10 * ENG_EMOJIS)
def set_menu_items_proto(lowercase_menu_items, menu_items_proto):
    """Copy validated menu-item settings into the PageConfig menu proto.

    For "get help" and "report a bug", a None value hides the menu item;
    for "about", a None value leaves Streamlit's default text in place.
    """
    if GET_HELP_KEY in lowercase_menu_items:
        help_url = lowercase_menu_items[GET_HELP_KEY]
        if help_url is None:
            menu_items_proto.hide_get_help = True
        else:
            menu_items_proto.get_help_url = help_url

    if REPORT_A_BUG_KEY in lowercase_menu_items:
        bug_url = lowercase_menu_items[REPORT_A_BUG_KEY]
        if bug_url is None:
            menu_items_proto.hide_report_a_bug = True
        else:
            menu_items_proto.report_a_bug_url = bug_url

    if ABOUT_KEY in lowercase_menu_items:
        about_md = lowercase_menu_items[ABOUT_KEY]
        if about_md is not None:
            menu_items_proto.about_section_md = dedent(about_md)
def validate_menu_items(dict):
    """Validate the (lower-cleaned) menu_items dict for st.set_page_config.

    Parameters
    ----------
    dict : dict
        Lowercased menu-items mapping.

    Raises
    ------
    StreamlitAPIException
        If a key is not one of the supported menu-item keys, or a non-About
        value is not a valid URL.
    """
    # NOTE(review): the parameter name shadows the builtin `dict`; kept
    # unchanged for backward compatibility with keyword callers.
    for k, v in dict.items():
        if not valid_menu_item_key(k):
            raise StreamlitAPIException(
                "We only accept the keys: "
                f'"Get help", "Report a bug", and "About" ("{k}" is not a valid key.)'
            )
        if v is not None:
            # "About" accepts arbitrary markdown; the other entries must be URLs.
            # Fixed the error message grammar ("is a not a" -> "is not a").
            if not valid_url(v) and k != ABOUT_KEY:
                raise StreamlitAPIException(f'"{v}" is not a valid URL!')
def valid_menu_item_key(key):
    """Return True if `key` is one of the supported menu-item keys."""
    return key in (GET_HELP_KEY, REPORT_A_BUG_KEY, ABOUT_KEY)
def valid_url(url):
    """Return True if `url` parses with both a scheme and a network location.

    This code is adapted from:
    https://stackoverflow.com/questions/7160737/how-to-validate-a-url-in-python-malformed-or-not
    """
    try:
        result = urlparse(url)
        return all([result.scheme, result.netloc])
    except (ValueError, AttributeError, TypeError):
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are no longer swallowed; urlparse raises ValueError on malformed
        # input and Attribute/TypeError on non-string arguments.
        return False

View File

@ -0,0 +1,13 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,24 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modules that the user should have access to. These are imported with "as"
# syntax pass mypy checking with implicit_reexport disabled.
from .components import declare_component as declare_component
# `html` and `iframe` are part of Custom Components, so they appear in this
# `streamlit.components.v1` namespace.
import streamlit

# Re-export the private implementations living on the main DeltaGenerator.
html = streamlit._main._html
iframe = streamlit._main._iframe

View File

@ -0,0 +1,380 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data marshalling utilities for ArrowTable protobufs, which are used by
CustomComponent for dataframe serialization.
"""
import pandas as pd
import pyarrow as pa
from streamlit import type_util, util
def marshall(proto, data, default_uuid=None):
    """Marshall data into an ArrowTable proto.

    Parameters
    ----------
    proto : proto.ArrowTable
        Output. The protobuf for a Streamlit ArrowTable proto.
    data : pandas.DataFrame, pandas.Styler, numpy.ndarray, Iterable, dict, or None
        Something that is or can be converted to a dataframe.
    default_uuid : str or None
        Fallback uuid for a Styler that doesn't define its own.
    """
    # Stylers carry extra styling data that is marshalled separately from
    # the underlying dataframe.
    if type_util.is_pandas_styler(data):
        _marshall_styler(proto, data, default_uuid)

    df = type_util.convert_anything_to_df(data)

    _marshall_index(proto, df.index)
    _marshall_columns(proto, df.columns)
    _marshall_data(proto, df.to_numpy())
def _marshall_styler(proto, styler, default_uuid):
    """Marshall pandas.Styler styling data into an ArrowTable proto.

    Parameters
    ----------
    proto : proto.ArrowTable
        Output. The protobuf for a Streamlit ArrowTable proto.
    styler : pandas.Styler
        Styler holding styling data for the dataframe.
    default_uuid : str
        If Styler custom uuid is not provided, this value will be used.
    """
    # NB: UUID should be set before _compute is called.
    _marshall_uuid(proto, styler, default_uuid)

    # NB: We're using protected members of Styler to get styles,
    # which is non-ideal and could break if Styler's interface changes.
    styler._compute()

    # In Pandas 1.3.0, styler._translate() signature was changed.
    # 2 arguments were added: sparse_index and sparse_columns.
    # The functionality that they provide is not yet supported.
    if type_util.is_pandas_version_less_than("1.3.0"):
        pandas_styles = styler._translate()
    else:
        pandas_styles = styler._translate(False, False)

    _marshall_caption(proto, styler)
    _marshall_styles(proto, styler, pandas_styles)
    _marshall_display_values(proto, styler.data, pandas_styles)
def _marshall_uuid(proto, styler, default_uuid):
"""Marshall pandas.Styler UUID into an ArrowTable proto.
Parameters
----------
proto : proto.ArrowTable
Output. The protobuf for a Streamlit ArrowTable proto.
styler : pandas.Styler
Styler holding styling data for the dataframe.
default_uuid : str
If Styler custom uuid is not provided, this value will be used.
"""
if styler.uuid is None:
styler.set_uuid(default_uuid)
proto.styler.uuid = str(styler.uuid)
def _marshall_caption(proto, styler):
"""Marshall pandas.Styler caption into an ArrowTable proto.
Parameters
----------
proto : proto.ArrowTable
Output. The protobuf for a Streamlit ArrowTable proto.
styler : pandas.Styler
Styler holding styling data for the dataframe.
"""
if styler.caption is not None:
proto.styler.caption = styler.caption
def _marshall_styles(proto, styler, styles):
    """Marshall pandas.Styler CSS styles into an ArrowTable proto.

    Parameters
    ----------
    proto : proto.ArrowTable
        Output. The protobuf for a Streamlit ArrowTable proto.
    styler : pandas.Styler
        Styler holding styling data for the dataframe.
    styles : dict
        pandas.Styler translated styles.
    """
    css_rules = []

    # NB: entries in "table_styles" need a space between the UUID and
    # the selector.
    for style in _trim_pandas_styles(styles.get("table_styles", [])):
        css_rules.append(
            _pandas_style_to_css("table_styles", style, styler.uuid, separator=" ")
        )

    for style in _trim_pandas_styles(styles.get("cellstyle", [])):
        css_rules.append(_pandas_style_to_css("cell_style", style, styler.uuid))

    if css_rules:
        proto.styler.styles = "\n".join(css_rules)
def _trim_pandas_styles(styles):
"""Trim pandas styles dict.
Parameters
----------
styles : dict
pandas.Styler translated styles.
"""
# Filter out empty styles, as every cell will have a class
# but the list of props may just be [['', '']].
return [x for x in styles if any(any(y) for y in x["props"])]
def _pandas_style_to_css(style_type, style, uuid, separator=""):
"""Convert pandas.Styler translated styles entry to CSS.
Parameters
----------
style : dict
pandas.Styler translated styles entry.
uuid: str
pandas.Styler UUID.
separator: str
A string separator used between table and cell selectors.
"""
declarations = []
for css_property, css_value in style["props"]:
declaration = css_property.strip() + ": " + css_value.strip()
declarations.append(declaration)
table_selector = "#T_" + str(uuid)
# In pandas < 1.1.0
# translated_style["cellstyle"] has the following shape:
# [
# {
# "props": [["color", " black"], ["background-color", "orange"], ["", ""]],
# "selector": "row0_col0"
# }
# ...
# ]
#
# In pandas >= 1.1.0
# translated_style["cellstyle"] has the following shape:
# [
# {
# "props": [("color", " black"), ("background-color", "orange"), ("", "")],
# "selectors": ["row0_col0"]
# }
# ...
# ]
if style_type == "table_styles" or (
style_type == "cell_style" and type_util.is_pandas_version_less_than("1.1.0")
):
cell_selectors = [style["selector"]]
else:
cell_selectors = style["selectors"]
selectors = []
for cell_selector in cell_selectors:
selectors.append(table_selector + separator + cell_selector)
selector = ", ".join(selectors)
declaration_block = "; ".join(declarations)
rule_set = selector + " { " + declaration_block + " }"
return rule_set
def _marshall_display_values(proto, df, styles):
    """Marshall pandas.Styler display values into an ArrowTable proto.

    Parameters
    ----------
    proto : proto.ArrowTable
        Output. The protobuf for a Streamlit ArrowTable proto.
    df : pandas.DataFrame
        A dataframe with original values.
    styles : dict
        pandas.Styler translated styles.
    """
    styled_df = _use_display_values(df, styles)
    proto.styler.display_values = _dataframe_to_pybytes(styled_df)
def _use_display_values(df, styles):
"""Create a new pandas.DataFrame where display values are used instead of original ones.
Parameters
----------
df : pandas.DataFrame
A dataframe with original values.
styles : dict
pandas.Styler translated styles.
"""
# (HK) TODO: Rewrite this method without using regex.
import re
# If values in a column are not of the same type, Arrow Table
# serialization would fail. Thus, we need to cast all values
# of the dataframe to strings before assigning them display values.
new_df = df.astype(str)
cell_selector_regex = re.compile(r"row(\d+)_col(\d+)")
if "body" in styles:
rows = styles["body"]
for row in rows:
for cell in row:
cell_id = cell["id"]
match = cell_selector_regex.match(cell_id)
if match:
r, c = map(int, match.groups())
new_df.iat[r, c] = str(cell["display_value"])
return new_df
def _dataframe_to_pybytes(df):
    """Serialize a pandas.DataFrame to Arrow IPC-stream bytes.

    Parameters
    ----------
    df : pandas.DataFrame
        A dataframe to convert.

    Returns
    -------
    bytes
        The Arrow stream-format serialization of `df`.
    """
    table = pa.Table.from_pandas(df)
    sink = pa.BufferOutputStream()
    # Use the writer as a context manager so it's always closed (writing the
    # stream footer) even if write_table raises; the original leaked the
    # writer on failure.
    with pa.RecordBatchStreamWriter(sink, table.schema) as writer:
        writer.write_table(table)
    return sink.getvalue().to_pybytes()
def _marshall_index(proto, index):
    """Marshall pandas.DataFrame index into an ArrowTable proto.

    Parameters
    ----------
    proto : proto.ArrowTable
        Output. The protobuf for a Streamlit ArrowTable proto.
    index : Index or array-like
        Index to use for resulting frame.
        Will default to RangeIndex (0, 1, 2, ..., n) if no index is provided.
    """
    # MultiIndex entries are tuples; convert them to lists so Arrow can
    # serialize them.
    index_values = map(util._maybe_tuple_to_list, index.values)
    proto.index = _dataframe_to_pybytes(pd.DataFrame(index_values))
def _marshall_columns(proto, columns):
    """Marshall pandas.DataFrame columns into an ArrowTable proto.

    Parameters
    ----------
    proto : proto.ArrowTable
        Output. The protobuf for a Streamlit ArrowTable proto.
    columns : Index or array-like
        Column labels to use for resulting frame.
        Will default to RangeIndex (0, 1, 2, ..., n) if no column labels are provided.
    """
    # MultiIndex column labels are tuples; convert them to lists so Arrow
    # can serialize them.
    column_values = map(util._maybe_tuple_to_list, columns.values)
    proto.columns = _dataframe_to_pybytes(pd.DataFrame(column_values))
def _marshall_data(proto, data):
    """Marshall the dataframe body into an ArrowTable proto.

    Parameters
    ----------
    proto : proto.ArrowTable
        Output. The protobuf for a Streamlit ArrowTable proto.
    data : array-like
        The dataframe's values, to be wrapped and serialized.
    """
    proto.data = _dataframe_to_pybytes(pd.DataFrame(data))
def arrow_proto_to_dataframe(proto):
    """Reassemble a pandas.DataFrame from an ArrowTable proto.

    Parameters
    ----------
    proto : proto.ArrowTable
        The proto whose data/index/columns buffers are deserialized.

    Returns
    -------
    pandas.DataFrame
    """
    body = _pybytes_to_dataframe(proto.data)
    index = _pybytes_to_dataframe(proto.index)
    columns = _pybytes_to_dataframe(proto.columns)

    return pd.DataFrame(
        body.values,
        index=index.values.T.tolist(),
        columns=columns.values.T.tolist(),
    )
def _pybytes_to_dataframe(source):
    """Deserialize Arrow IPC-stream bytes back into a pandas.DataFrame.

    Parameters
    ----------
    source : bytes
        Arrow stream-format bytes produced by `_dataframe_to_pybytes`.
    """
    return pa.RecordBatchStreamReader(source).read_pandas()

View File

@ -0,0 +1,442 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import mimetypes
import os
import threading
from typing import Any, Dict, Optional, Type, Union
import tornado.web
from streamlit.scriptrunner import get_script_run_ctx
import streamlit.server.routes
from streamlit import type_util
from streamlit.elements.form import current_form_id
from streamlit import util
from streamlit.errors import StreamlitAPIException
from streamlit.logger import get_logger
from streamlit.proto.Components_pb2 import SpecialArg, ArrowTable as ArrowTableProto
from streamlit.proto.Element_pb2 import Element
from streamlit.state import NoValue, register_widget
from streamlit.type_util import to_bytes
# Module-level logger used by the component marshalling/serving code below.
LOGGER = get_logger(__name__)
class MarshallComponentException(StreamlitAPIException):
    """Raised when a custom component's arguments cannot be marshalled."""
class CustomComponent:
    """A Custom Component declaration.

    Instances are callable; calling one creates a new component instance in
    the running Streamlit app (see `create_instance`).
    """

    def __init__(
        self,
        name: str,
        path: Optional[str] = None,
        url: Optional[str] = None,
    ):
        # A component is served either from a local directory (`path`) or
        # from a dev server (`url`); exactly one of the two must be given.
        if (path is None and url is None) or (path is not None and url is not None):
            raise StreamlitAPIException(
                "Either 'path' or 'url' must be set, but not both."
            )
        self.name = name
        self.path = path
        self.url = url

    def __repr__(self) -> str:
        return util.repr_(self)

    @property
    def abspath(self) -> Optional[str]:
        """The absolute path that the component is served from."""
        if self.path is None:
            return None
        return os.path.abspath(self.path)

    def __call__(
        self,
        *args,
        default: Any = None,
        key: Optional[str] = None,
        **kwargs,
    ) -> Any:
        """An alias for create_instance."""
        return self.create_instance(*args, default=default, key=key, **kwargs)

    def create_instance(
        self,
        *args,
        default: Any = None,
        key: Optional[str] = None,
        **kwargs,
    ) -> Any:
        """Create a new instance of the component.

        Parameters
        ----------
        *args
            Must be empty; all args must be named. (This parameter exists to
            enforce correct use of the function.)
        default: any or None
            The default return value for the component. This is returned when
            the component's frontend hasn't yet specified a value with
            `setComponentValue`.
        key: str or None
            If not None, this is the user key we use to generate the
            component's "widget ID".
        **kwargs
            Keyword args to pass to the component.

        Returns
        -------
        any or None
            The component's widget value.
        """
        if len(args) > 0:
            raise MarshallComponentException(f"Argument '{args[0]}' needs a label")
        try:
            import pyarrow
            from streamlit.components.v1 import component_arrow
        except ImportError:
            raise StreamlitAPIException(
                """To use Custom Components in Streamlit, you need to install
PyArrow. To do so locally:
`pip install pyarrow`
And if you're using Streamlit Cloud, add "pyarrow" to your requirements.txt."""
            )
        # In addition to the custom kwargs passed to the component, we also
        # send the special 'default' and 'key' params to the component
        # frontend.
        all_args = dict(kwargs, **{"default": default, "key": key})
        # Partition args: bytes-like and dataframe-like values travel as
        # SpecialArg protos; everything else must be JSON-serializable.
        json_args = {}
        special_args = []
        for arg_name, arg_val in all_args.items():
            if type_util.is_bytes_like(arg_val):
                bytes_arg = SpecialArg()
                bytes_arg.key = arg_name
                bytes_arg.bytes = to_bytes(arg_val)
                special_args.append(bytes_arg)
            elif type_util.is_dataframe_like(arg_val):
                dataframe_arg = SpecialArg()
                dataframe_arg.key = arg_name
                component_arrow.marshall(dataframe_arg.arrow_dataframe.data, arg_val)
                special_args.append(dataframe_arg)
            else:
                json_args[arg_name] = arg_val
        try:
            serialized_json_args = json.dumps(json_args)
        except BaseException as e:
            raise MarshallComponentException(
                "Could not convert component args to JSON", e
            )

        def marshall_component(dg, element: Element) -> Union[Any, Type[NoValue]]:
            # Fill in the component_instance proto and register the widget;
            # returns the widget's current value (or NoValue).
            element.component_instance.component_name = self.name
            element.component_instance.form_id = current_form_id(dg)
            if self.url is not None:
                element.component_instance.url = self.url
            # Normally, a widget's element_hash (which determines
            # its identity across multiple runs of an app) is computed
            # by hashing the entirety of its protobuf. This means that,
            # if any of the arguments to the widget are changed, Streamlit
            # considers it a new widget instance and it loses its previous
            # state.
            #
            # However! If a *component* has a `key` argument, then the
            # component's hash identity is determined by entirely by
            # `component_name + url + key`. This means that, when `key`
            # exists, the component will maintain its identity even when its
            # other arguments change, and the component's iframe won't be
            # remounted on the frontend.
            #
            # So: if `key` is None, we marshall the element's arguments
            # *before* computing its widget_ui_value (which creates its hash).
            # If `key` is not None, we marshall the arguments *after*.
            def marshall_element_args():
                element.component_instance.json_args = serialized_json_args
                element.component_instance.special_args.extend(special_args)

            if key is None:
                marshall_element_args()

            def deserialize_component(ui_value, widget_id=""):
                # ui_value is an object from json, an ArrowTable proto, or a bytearray
                return ui_value

            ctx = get_script_run_ctx()
            widget_value, _ = register_widget(
                element_type="component_instance",
                element_proto=element.component_instance,
                user_key=key,
                widget_func_name=self.name,
                deserializer=deserialize_component,
                serializer=lambda x: x,
                ctx=ctx,
            )
            if key is not None:
                marshall_element_args()
            if widget_value is None:
                widget_value = default
            elif isinstance(widget_value, ArrowTableProto):
                # The frontend sent back a dataframe in Arrow wire format;
                # convert it to a pandas.DataFrame for the caller.
                widget_value = component_arrow.arrow_proto_to_dataframe(widget_value)
            # widget_value will be either None or whatever the component's most
            # recent setWidgetValue value is. We coerce None -> NoValue,
            # because that's what DeltaGenerator._enqueue expects.
            return widget_value if widget_value is not None else NoValue

        # We currently only support writing to st._main, but this will change
        # when we settle on an improved API in a post-layout world.
        dg = streamlit._main
        element = Element()
        return_value = marshall_component(dg, element)
        result = dg._enqueue(
            "component_instance", element.component_instance, return_value
        )
        return result

    def __eq__(self, other) -> bool:
        """Equality operator."""
        return (
            isinstance(other, CustomComponent)
            and self.name == other.name
            and self.path == other.path
            and self.url == other.url
        )

    def __ne__(self, other) -> bool:
        """Inequality operator."""
        return not self == other

    def __str__(self) -> str:
        return f"'{self.name}': {self.path if self.path is not None else self.url}"
def declare_component(
    name: str,
    path: Optional[str] = None,
    url: Optional[str] = None,
) -> CustomComponent:
    """Create and register a custom component.

    Parameters
    ----------
    name: str
        A short, descriptive name for the component. Like, "slider".
    path: str or None
        The path to serve the component's frontend files from. Either
        `path` or `url` must be specified, but not both.
    url: str or None
        The URL that the component is served from. Either `path` or `url`
        must be specified, but not both.

    Returns
    -------
    CustomComponent
        A CustomComponent that can be called like a function.
        Calling the component will create a new instance of the component
        in the Streamlit app.
    """
    # Resolve the stack frame of whoever called declare_component().
    this_frame = inspect.currentframe()
    assert this_frame is not None
    caller = this_frame.f_back
    assert caller is not None

    # `__name__` on the caller's module gives its fully-qualified,
    # package-prefixed name.
    caller_module = inspect.getmodule(caller)
    assert caller_module is not None
    module_name = caller_module.__name__

    # When the caller's script was executed directly (`python my_component.py`),
    # `__name__` is "__main__" rather than a real package name; fall back to
    # the script's filename without its extension.
    if module_name == "__main__":
        script_name = os.path.basename(inspect.getfile(caller))
        module_name = os.path.splitext(script_name)[0]

    # Register the component under "<module>.<name>" and hand it back.
    component = CustomComponent(name=f"{module_name}.{name}", path=path, url=url)
    ComponentRegistry.instance().register_component(component)
    return component
class ComponentRequestHandler(tornado.web.RequestHandler):
    """Serves a custom component's static frontend files over HTTP.

    Request paths look like "<component_name>/<path/to/file>"; the component
    name is resolved to a local directory via the ComponentRegistry.
    """

    def initialize(self, registry: "ComponentRegistry"):
        self._registry = registry

    def get(self, path: str) -> None:
        """Serve one file from the named component's directory."""
        parts = path.split("/")
        component_name = parts[0]
        component_root = self._registry.get_component_path(component_name)
        if component_root is None:
            self.write("not found")
            self.set_status(404)
            return

        filename = "/".join(parts[1:])

        # SECURITY: resolve the requested file and verify it actually lives
        # inside the component's directory. Without this check a request
        # like "my_component/../../../etc/passwd" (or one using an absolute
        # path or a symlink) could read arbitrary files on the server.
        component_root = os.path.realpath(component_root)
        abspath = os.path.realpath(os.path.join(component_root, filename))
        try:
            in_root = os.path.commonpath([component_root, abspath]) == component_root
        except ValueError:
            # commonpath raises when the paths can't be compared (e.g. they
            # are on different Windows drives) -- definitely outside root.
            in_root = False
        if not in_root:
            self.write("forbidden")
            self.set_status(403)
            return

        LOGGER.debug("ComponentRequestHandler: GET: %s -> %s", path, abspath)
        try:
            with open(abspath, "rb") as file:
                contents = file.read()
        except OSError as e:
            LOGGER.error(f"ComponentRequestHandler: GET {path} read error", exc_info=e)
            self.write("read error")
            self.set_status(404)
            return

        self.write(contents)
        self.set_header("Content-Type", self.get_content_type(abspath))
        self.set_extra_headers(path)

    def set_extra_headers(self, path) -> None:
        """Disable cache for HTML files.

        Other assets like JS and CSS are suffixed with their hash, so they can
        be cached indefinitely.
        """
        is_index_url = len(path) == 0
        if is_index_url or path.endswith(".html"):
            self.set_header("Cache-Control", "no-cache")
        else:
            self.set_header("Cache-Control", "public")

    def set_default_headers(self) -> None:
        # Cross-origin access is opt-in via server config.
        if streamlit.server.routes.allow_cross_origin_requests():
            self.set_header("Access-Control-Allow-Origin", "*")

    def options(self) -> None:
        """/OPTIONS handler for preflight CORS checks."""
        self.set_status(204)
        self.finish()

    @staticmethod
    def get_content_type(abspath) -> str:
        """Returns the ``Content-Type`` header to be used for this request.

        From tornado.web.StaticFileHandler.
        """
        mime_type, encoding = mimetypes.guess_type(abspath)
        # per RFC 6713, use the appropriate type for a gzip compressed file
        if encoding == "gzip":
            return "application/gzip"
        # As of 2015-07-21 there is no bzip2 encoding defined at
        # http://www.iana.org/assignments/media-types/media-types.xhtml
        # So for that (and any other encoding), use octet-stream.
        elif encoding is not None:
            return "application/octet-stream"
        elif mime_type is not None:
            return mime_type
        # if mime_type not detected, use application/octet-stream
        else:
            return "application/octet-stream"

    @staticmethod
    def get_url(file_id: str) -> str:
        """Return the URL for a component file with the given ID."""
        return "components/{}".format(file_id)
class ComponentRegistry:
    """Process-wide singleton that maps component names to declarations."""

    _instance_lock = threading.Lock()
    _instance = None  # type: Optional[ComponentRegistry]

    @classmethod
    def instance(cls) -> "ComponentRegistry":
        """Return the singleton ComponentRegistry, creating it on first use."""
        # Double-checked locking: skip acquiring the lock entirely once the
        # singleton exists.
        # https://en.wikipedia.org/wiki/Double-checked_locking
        if cls._instance is None:
            with cls._instance_lock:
                if cls._instance is None:
                    cls._instance = ComponentRegistry()
        return cls._instance

    def __init__(self):
        self._components = {}  # type: Dict[str, CustomComponent]
        self._lock = threading.Lock()

    def __repr__(self) -> str:
        return util.repr_(self)

    def register_component(self, component: CustomComponent) -> None:
        """Register a CustomComponent, replacing any prior registration.

        Parameters
        ----------
        component : CustomComponent
            The component to register.
        """
        # A path-based component must point at a real directory.
        dirpath = component.abspath
        if dirpath is not None and not os.path.isdir(dirpath):
            raise StreamlitAPIException(f"No such component directory: '{dirpath}'")
        with self._lock:
            previous = self._components.get(component.name)
            self._components[component.name] = component
            if previous is not None and component != previous:
                LOGGER.warning(
                    "%s overriding previously-registered %s",
                    component,
                    previous,
                )
            LOGGER.debug("Registered component %s", component)

    def get_component_path(self, name: str) -> Optional[str]:
        """Return the filesystem path for the component with the given name.

        Returns None when no such component is registered, or when the
        component is served from a URL rather than a local directory.
        """
        registered = self._components.get(name, None)
        if registered is None:
            return None
        return registered.abspath

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,291 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This class stores a key-value pair for the config system."""
import datetime
import re
import textwrap
from typing import Any, Callable, Optional
from streamlit.errors import DeprecationError
from streamlit import util
class ConfigOption:
    '''Stores a Streamlit configuration option.

    A configuration option, like 'browser.serverPort', which indicates which port
    to use when connecting to the proxy. There are two ways to create a
    ConfigOption:

    Simple ConfigOptions are created as follows:

        ConfigOption('browser.serverPort',
            description = 'Connect to the proxy at this port.',
            default_val = 8501)

    More complex config options resolve their values at runtime as follows:

        @ConfigOption('browser.serverPort')
        def _proxy_port():
            """Connect to the proxy at this port.
            Defaults to 8501.
            """
            return 8501

    NOTE: For complex config options, the function is called each time the
    option.value is evaluated!

    Attributes
    ----------
    key : str
        The fully qualified section.name
    value : any
        The value for this option. If this is a complex config option then
        the callback is called EACH TIME value is evaluated.
    section : str
        The section of this option. Example: 'global'.
    name : str
        See __init__.
    description : str
        See __init__.
    where_defined : str
        Indicates which file set this config option.
        ConfigOption.DEFAULT_DEFINITION means this file.
    visibility : {"visible", "hidden"}
        See __init__.
    scriptable : bool
        See __init__.
    deprecated: bool
        See __init__.
    deprecation_text : str or None
        See __init__.
    expiration_date : str or None
        See __init__.
    replaced_by : str or None
        See __init__.
    '''

    # This is a special value for ConfigOption.where_defined which indicates
    # that the option default was not overridden.
    DEFAULT_DEFINITION = "<default>"

    # This is a special value for ConfigOption.where_defined which indicates
    # that the option was defined by Streamlit's own code.
    STREAMLIT_DEFINITION = "<streamlit>"

    def __init__(
        self,
        key: str,
        description: Optional[str] = None,
        default_val: Optional[Any] = None,
        visibility: str = "visible",
        scriptable: bool = False,
        deprecated: bool = False,
        deprecation_text: Optional[str] = None,
        expiration_date: Optional[str] = None,
        replaced_by: Optional[str] = None,
        type_: type = str,
    ):
        """Create a ConfigOption with the given name.

        Parameters
        ----------
        key : str
            Should be of the form "section.optionName"
            Examples: server.name, deprecation.v1_0_featureName
        description : str
            Like a comment for the config option.
        default_val : any
            The value for this config option.
        visibility : {"visible", "hidden"}
            Whether this option should be shown to users.
        scriptable : bool
            Whether this config option can be set within a user script.
        deprecated: bool
            Whether this config option is deprecated.
        deprecation_text : str or None
            Required if deprecated == True. Set this to a string explaining
            what to use instead.
        expiration_date : str or None
            Required if deprecated == True. Set this to the date at which it
            will no longer be accepted. Format: 'YYYY-MM-DD'.
        replaced_by : str or None
            If this option has been deprecated in favor of another option,
            set this to the path to the new option. Example:
            'server.runOnSave'. If this is set, the 'deprecated' option
            will automatically be set to True, and deprecation_text will have a
            meaningful default (unless you override it).
        type_ : one of str, int, float or bool
            Useful to cast the config params sent by cmd option parameter.
        """
        # Parse out the section and name.
        self.key = key
        key_format = (
            # Capture a group called "section"
            r"(?P<section>"
            # Matching text comprised of letters and numbers that begins
            # with a lowercase letter with an optional "_" preceding it.
            # Examples: "_section", "section1"
            r"\_?[a-z][a-zA-Z0-9]*"
            r")"
            # Separator between groups
            r"\."
            # Capture a group called "name"
            r"(?P<name>"
            # Match text comprised of letters and numbers beginning with a
            # lowercase letter.
            # Examples: "name", "nameOfConfig", "config1"
            r"[a-z][a-zA-Z0-9]*"
            r")$"
        )
        match = re.match(key_format, self.key)
        assert match, f'Key "{self.key}" has invalid format.'
        self.section, self.name = match.group("section"), match.group("name")
        self.description = description
        self.visibility = visibility
        self.scriptable = scriptable
        self.default_val = default_val
        self.deprecated = deprecated
        self.replaced_by = replaced_by
        # Set by the decorator form (__call__) for complex options; None for
        # simple options until set_value() installs a constant-returning lambda.
        self._get_val_func: Optional[Callable[[], Any]] = None
        self.where_defined = ConfigOption.DEFAULT_DEFINITION
        self.type = type_
        # An option that is replaced by another is implicitly deprecated.
        if self.replaced_by:
            self.deprecated = True
            if deprecation_text is None:
                deprecation_text = "Replaced by %s." % self.replaced_by
        # NOTE: `expiration_date` and `deprecation_text` attributes exist
        # only on deprecated options; is_expired() guards on self.deprecated.
        if self.deprecated:
            assert expiration_date, "expiration_date is required for deprecated items"
            assert deprecation_text, "deprecation_text is required for deprecated items"
            self.expiration_date = expiration_date
            self.deprecation_text = textwrap.dedent(deprecation_text)
        self.set_value(default_val)

    def __repr__(self) -> str:
        return util.repr_(self)

    def __call__(self, get_val_func: Callable[[], Any]) -> "ConfigOption":
        """Assign a function to compute the value for this option.

        This method is called when ConfigOption is used as a decorator.

        Parameters
        ----------
        get_val_func : function
            A function which will be called to get the value of this parameter.
            We will use its docString as the description.

        Returns
        -------
        ConfigOption
            Returns self, which makes testing easier. See config_test.py.
        """
        assert (
            get_val_func.__doc__
        ), "Complex config options require doc strings for their description."
        self.description = get_val_func.__doc__
        self._get_val_func = get_val_func
        return self

    @property
    def value(self) -> Any:
        """Get the value of this config option."""
        # _get_val_func can only be None transiently (it is installed by
        # set_value() in __init__ and by the decorator form).
        if self._get_val_func is None:
            return None
        return self._get_val_func()

    def set_value(self, value: Any, where_defined: Optional[str] = None) -> None:
        """Set the value of this option.

        Parameters
        ----------
        value
            The new value for this parameter.
        where_defined : str
            New value to remember where this parameter was set.
        """
        self._get_val_func = lambda: value
        if where_defined is None:
            self.where_defined = ConfigOption.DEFAULT_DEFINITION
        else:
            self.where_defined = where_defined
        # Complain (or fail) only when a deprecated option is explicitly set
        # by the user, not when its default is installed.
        if self.deprecated and self.where_defined != ConfigOption.DEFAULT_DEFINITION:
            details = {
                "key": self.key,
                "file": self.where_defined,
                "explanation": self.deprecation_text,
                "date": self.expiration_date,
            }
            if self.is_expired():
                raise DeprecationError(
                    textwrap.dedent(
                        """
                    ════════════════════════════════════════════════
                    %(key)s IS NO LONGER SUPPORTED.
                    %(explanation)s
                    Please update %(file)s.
                    ════════════════════════════════════════════════
                    """
                    )
                    % details
                )
            else:
                # NOTE(review): imported locally rather than at module
                # scope -- presumably to avoid an import cycle; confirm.
                from streamlit.logger import get_logger

                LOGGER = get_logger(__name__)
                LOGGER.warning(
                    textwrap.dedent(
                        """
                    ════════════════════════════════════════════════
                    %(key)s IS DEPRECATED.
                    %(explanation)s
                    This option will be removed on or after %(date)s.
                    Please update %(file)s.
                    ════════════════════════════════════════════════
                    """
                    )
                    % details
                )

    def is_expired(self) -> bool:
        """Returns true if expiration_date is in the past."""
        # Non-deprecated options never expire (and have no expiration_date).
        if not self.deprecated:
            return False
        expiration_date = _parse_yyyymmdd_str(self.expiration_date)
        now = datetime.datetime.now()
        return now > expiration_date
def _parse_yyyymmdd_str(date_str: str) -> datetime.datetime:
year, month, day = [int(token) for token in date_str.split("-", 2)]
return datetime.datetime(year, month, day)

View File

@ -0,0 +1,134 @@
import toml
from typing import Dict
import click
from streamlit.config_option import ConfigOption
def server_option_changed(
    old_options: Dict[str, ConfigOption], new_options: Dict[str, ConfigOption]
) -> bool:
    """Return True if and only if an option in the server section differs
    between old_options and new_options."""
    # Only keys in the "server" section can force a server restart.
    server_keys = (name for name in old_options if name.startswith("server"))
    return any(
        old_options[name].value != new_options[name].value for name in server_keys
    )
def show_config(
    section_descriptions: Dict[str, str],
    config_options: Dict[str, ConfigOption],
) -> None:
    """Print the given config sections/options to the terminal.

    Emits a commented, TOML-formatted listing of every visible, non-expired
    option, suitable for pasting into ~/.streamlit/config.toml.
    """
    # Sections that are internal-only and never shown to users.
    SKIP_SECTIONS = {"_test", "ui"}
    # All output is accumulated into `out` and echoed once at the end.
    out = []
    out.append(
        _clean(
            """
        # Below are all the sections and options you can have in
        ~/.streamlit/config.toml.
        """
        )
    )

    # Styling helpers; each appends one already-styled line to `out`.
    def append_desc(text):
        out.append(click.style(text, bold=True))

    def append_comment(text):
        out.append(click.style(text))

    def append_section(text):
        out.append(click.style(text, bold=True, fg="green"))

    def append_setting(text):
        out.append(click.style(text, fg="green"))

    def append_newline():
        out.append("")

    for section, section_description in section_descriptions.items():
        if section in SKIP_SECTIONS:
            continue
        append_newline()
        append_section("[%s]" % section)
        append_newline()
        for key, option in config_options.items():
            if option.section != section:
                continue
            if option.visibility == "hidden":
                continue
            if option.is_expired():
                continue
            # NOTE: rebinds the loop variable -- from here on, `key` is the
            # option name without its section prefix.
            key = option.key.split(".")[1]
            description_paragraphs = _clean_paragraphs(option.description)
            for i, txt in enumerate(description_paragraphs):
                if i == 0:
                    append_desc("# %s" % txt)
                else:
                    append_comment("# %s" % txt)
            # toml.dumps({"default": v}) renders "default = <v>"; slicing off
            # the first 10 characters drops the 'default = ' prefix.
            toml_default = toml.dumps({"default": option.default_val})
            toml_default = toml_default[10:].strip()
            if len(toml_default) > 0:
                append_comment("# Default: %s" % toml_default)
            else:
                # Don't say "Default: (unset)" here because this branch applies
                # to complex config settings too.
                pass
            if option.deprecated:
                append_comment("#")
                append_comment("# " + click.style("DEPRECATED.", fg="yellow"))
                append_comment(
                    "# %s" % "\n".join(_clean_paragraphs(option.deprecation_text))
                )
                append_comment(
                    "# This option will be removed on or after %s."
                    % option.expiration_date
                )
                append_comment("#")
            option_is_manually_set = (
                option.where_defined != ConfigOption.DEFAULT_DEFINITION
            )
            if option_is_manually_set:
                append_comment("# The value below was set in %s" % option.where_defined)
            toml_setting = toml.dumps({key: option.value})
            # NOTE(review): an empty dump presumably means a complex option
            # whose value toml can't serialize -- confirm; emit a commented,
            # empty assignment in that case.
            if len(toml_setting) == 0:
                toml_setting = f"# {key} =\n"
            elif not option_is_manually_set:
                # Defaults are shown commented-out so pasting them is a no-op.
                toml_setting = f"# {toml_setting}"
            append_setting(toml_setting)

    click.echo("\n".join(out))
def _clean(txt):
"""Replace all whitespace with a single space."""
return " ".join(txt.split()).strip()
def _clean_paragraphs(txt):
    """Split *txt* on blank lines and whitespace-normalize each paragraph."""
    return [_clean(paragraph) for paragraph in txt.split("\n\n")]

View File

@ -0,0 +1,286 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Manage the user's Streamlit credentials."""
import os
import sys
import textwrap
from collections import namedtuple
from typing import Optional
import click
import toml
from streamlit import util
from streamlit import env_util
from streamlit import file_util
from streamlit.logger import get_logger
# Module-level logger shared by the credential helpers below.
LOGGER = get_logger(__name__)
# WT_SESSION is a Windows Terminal specific environment variable. If it exists,
# we are on the latest Windows Terminal that supports emojis
_SHOW_EMOJIS = not env_util.IS_WINDOWS or os.environ.get("WT_SESSION")

# Path of the user's config file, as shown verbatim in CLI messages below.
if env_util.IS_WINDOWS:
    _CONFIG_FILE_PATH = r"%userprofile%/.streamlit/config.toml"
else:
    _CONFIG_FILE_PATH = "~/.streamlit/config.toml"

# Result of one activation attempt; produced by _verify_email().
_Activation = namedtuple(
    "_Activation",
    [
        "email",  # str : the user's email.
        "is_valid",  # boolean : whether the email is valid.
    ],
)

# IMPORTANT: Break the text below at 80 chars.
_EMAIL_PROMPT = """
{0}%(welcome)s
If you're one of our development partners or you're interested in getting
personal technical support or Streamlit updates, please enter your email
address below. Otherwise, you may leave the field blank.
%(email)s""".format(
    "👋 " if _SHOW_EMOJIS else ""
) % {
    "welcome": click.style("Welcome to Streamlit!", bold=True),
    "email": click.style("Email: ", fg="blue"),
}

# IMPORTANT: Break the text below at 80 chars.
_TELEMETRY_TEXT = """
%(privacy)s
As an open source project, we collect usage statistics. We cannot see and do
not store information contained in Streamlit apps. You can find out more by
reading our privacy policy at: %(link)s
If you'd like to opt out of usage statistics, add the following to
%(config)s, creating that file if necessary:
[browser]
gatherUsageStats = false
""" % {
    "privacy": click.style("Privacy Policy:", bold=True),
    "link": click.style("https://streamlit.io/privacy-policy", underline=True),
    "config": click.style(_CONFIG_FILE_PATH),
}

# IMPORTANT: Break the text below at 80 chars.
_INSTRUCTIONS_TEXT = """
%(start)s
%(prompt)s %(hello)s
""" % {
    "start": click.style("Get started by typing:", fg="blue", bold=True),
    "prompt": click.style("$", fg="blue"),
    "hello": click.style("streamlit hello", bold=True),
}
class Credentials(object):
    """Credentials class.

    Process-wide singleton holding the user's activation state, backed by
    the credentials.toml file on disk.
    """

    # The singleton instance; created lazily by get_current().
    _singleton = None  # type: Optional[Credentials]

    @classmethod
    def get_current(cls):
        """Return the singleton instance."""
        if cls._singleton is None:
            # Constructing the instance registers it in Credentials._singleton.
            Credentials()
        return Credentials._singleton

    def __init__(self):
        """Initialize class."""
        if Credentials._singleton is not None:
            raise RuntimeError(
                "Credentials already initialized. Use .get_current() instead"
            )
        # Parsed _Activation, or None until load()/activate() runs.
        self.activation = None
        self._conf_file = _get_credential_file_path()
        Credentials._singleton = self

    def __repr__(self) -> str:
        return util.repr_(self)

    def load(self, auto_resolve=False) -> None:
        """Load from toml file.

        Parameters
        ----------
        auto_resolve : bool
            If True, fall back to (re-)activation when the file is missing
            or unreadable, instead of raising.
        """
        if self.activation is not None:
            LOGGER.error("Credentials already loaded. Not rereading file.")
            return
        try:
            with open(self._conf_file, "r") as f:
                data = toml.load(f).get("general")
            if data is None:
                # NOTE(review): bare Exception used as a jump into the
                # generic handler below; a dedicated exception type would
                # be clearer.
                raise Exception
            self.activation = _verify_email(data.get("email"))
        except FileNotFoundError:
            if auto_resolve:
                return self.activate(show_instructions=not auto_resolve)
            raise RuntimeError(
                'Credentials not found. Please run "streamlit activate".'
            )
        except Exception as e:
            # Any other failure (bad TOML, missing section, ...).
            if auto_resolve:
                self.reset()
                return self.activate(show_instructions=not auto_resolve)
            raise Exception(
                textwrap.dedent(
                    """
                Unable to load credentials from %s.
                Run "streamlit reset" and try again.
                """
                )
                % (self._conf_file)
            )

    def _check_activated(self, auto_resolve=True):
        """Check if streamlit is activated.

        Used by `streamlit run script.py`
        """
        try:
            self.load(auto_resolve)
        except (Exception, RuntimeError) as e:
            # NOTE(review): RuntimeError is already a subclass of Exception,
            # so this tuple is redundant; `e` is only used via str().
            _exit(str(e))
        if self.activation is None or not self.activation.is_valid:
            _exit("Activation email not valid.")

    @classmethod
    def reset(cls):
        """Reset credentials by removing file.

        This is used by `streamlit activate reset` in case a user wants
        to start over.
        """
        c = Credentials.get_current()
        c.activation = None
        try:
            os.remove(c._conf_file)
        except OSError as e:
            # Best-effort delete: a missing file is fine for "reset".
            LOGGER.error("Error removing credentials file: %s" % e)

    def save(self):
        """Save to toml file."""
        if self.activation is None:
            return
        # Create intermediate directories if necessary
        os.makedirs(os.path.dirname(self._conf_file), exist_ok=True)
        # Write the file
        data = {"email": self.activation.email}
        with open(self._conf_file, "w") as f:
            toml.dump({"general": data}, f)

    def activate(self, show_instructions: bool = True) -> None:
        """Activate Streamlit.

        Used by `streamlit activate`.
        """
        try:
            self.load()
        except RuntimeError:
            # No credentials file yet; fall through to the prompt below.
            pass
        if self.activation:
            if self.activation.is_valid:
                _exit("Already activated")
            else:
                _exit(
                    "Activation not valid. Please run "
                    "`streamlit activate reset` then `streamlit activate`"
                )
        else:
            activated = False
            # Keep prompting until the entered email validates (empty is OK).
            while not activated:
                email = click.prompt(
                    text=_EMAIL_PROMPT, prompt_suffix="", default="", show_default=False
                )
                self.activation = _verify_email(email)
                if self.activation.is_valid:
                    self.save()
                    click.secho(_TELEMETRY_TEXT)
                    if show_instructions:
                        click.secho(_INSTRUCTIONS_TEXT)
                    activated = True
                else:  # pragma: nocover
                    LOGGER.error("Please try again.")
def _verify_email(email: str) -> _Activation:
    """Verify the user's email address.

    The email can either be an empty string (if the user chooses not to
    enter it), or a string with a single '@' somewhere in it.

    Parameters
    ----------
    email : str

    Returns
    -------
    _Activation
        An _Activation object. Its 'is_valid' property will be True only if
        the email was validated.
    """
    stripped = email.strip()
    # Validation is deliberately loose: we never send mail to this address,
    # so "anything@anything" is good enough.
    if stripped and stripped.count("@") != 1:
        LOGGER.error("That doesn't look like an email :(")
        return _Activation(None, False)
    return _Activation(stripped, True)
def _exit(message):  # pragma: nocover
    """Exit program with error."""
    # Surface the reason through the logger, then abort with a non-zero
    # status so shell callers observe the failure.
    LOGGER.error(message)
    sys.exit(-1)
def _get_credential_file_path():
    """Return the path of the user's credentials.toml file."""
    return file_util.get_streamlit_file_path("credentials.toml")
def _check_credential_file_exists():
    """Return True when a credentials file already exists on disk."""
    return os.path.exists(_get_credential_file_path())
def check_credentials():
    """Check credentials and potentially activate.

    Note
    ----
    If there is no credential file and we are in headless mode, we should not
    check, since credential would be automatically set to an empty string.
    """
    from streamlit import config

    # Evaluation order matters: check the file first, consult config only
    # when the file is absent (matching the original short-circuit).
    missing_file = not _check_credential_file_exists()
    if missing_file and config.get_option("server.headless"):
        return
    Credentials.get_current()._check_activated()

View File

@ -0,0 +1,207 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Any, List
from streamlit import util
from streamlit.scriptrunner import get_script_run_ctx
def make_delta_path(
    root_container: int, parent_path: Tuple[int, ...], index: int
) -> List[int]:
    """Build a full delta path: [root_container, *parent_path, index]."""
    return [root_container, *parent_path, index]
def get_container_cursor(
    root_container: Optional[int],
) -> Optional["RunningCursor"]:
    """Return the top-level RunningCursor for the given container.

    This is the cursor that is used when user code calls something like
    `st.foo` (which uses the main container) or `st.sidebar.foo` (which uses
    the sidebar container).
    """
    if root_container is None:
        return None
    ctx = get_script_run_ctx()
    if ctx is None:
        return None
    # Reuse the cached cursor for this container when one exists; otherwise
    # create one and remember it on the script-run context.
    if root_container not in ctx.cursors:
        ctx.cursors[root_container] = RunningCursor(root_container=root_container)
    return ctx.cursors[root_container]
class Cursor:
    """A pointer to a delta location in the app.

    When adding an element to the app, you should always call
    get_locked_cursor() on that element's respective Cursor.

    This is an abstract base; the concrete subclasses below override the
    NotImplementedError members.
    """

    def __repr__(self) -> str:
        return util.repr_(self)

    @property
    def root_container(self) -> int:
        """The top-level container this cursor lives within - either
        RootContainer.MAIN or RootContainer.SIDEBAR."""
        raise NotImplementedError()

    @property
    def parent_path(self) -> Tuple[int, ...]:
        """The cursor's parent's path within its container."""
        raise NotImplementedError()

    @property
    def index(self) -> int:
        """The index of the Delta within its parent block."""
        raise NotImplementedError()

    @property
    def delta_path(self) -> List[int]:
        """The complete path of the delta pointed to by this cursor - its
        container, parent path, and index.
        """
        # Concrete: composed from the three abstract properties above.
        return make_delta_path(self.root_container, self.parent_path, self.index)

    @property
    def is_locked(self) -> bool:
        # Whether get_locked_cursor() returns a fixed position (True) or
        # advances this cursor (False).
        raise NotImplementedError()

    def get_locked_cursor(self, **props) -> "LockedCursor":
        raise NotImplementedError()

    @property
    def props(self) -> Any:
        """Other data in this cursor. This is a temporary measure that will go
        away when we implement improved return values for elements.

        This is only implemented in LockedCursor.
        """
        raise NotImplementedError()
class RunningCursor(Cursor):
    def __init__(self, root_container: int, parent_path: Tuple[int, ...] = ()):
        """A moving pointer to a delta location in the app.

        Each call to get_locked_cursor() hands out the current position as a
        LockedCursor and advances this cursor to the next available one.

        Parameters
        ----------
        root_container: int
            The root container this cursor lives in.
        parent_path: tuple of ints
            The full path of this cursor, consisting of the IDs of all
            ancestors. The 0th item is the topmost ancestor.
        """
        self._root_container = root_container
        self._parent_path = parent_path
        self._index = 0

    @property
    def root_container(self) -> int:
        return self._root_container

    @property
    def parent_path(self) -> Tuple[int, ...]:
        return self._parent_path

    @property
    def index(self) -> int:
        return self._index

    @property
    def is_locked(self) -> bool:
        # A RunningCursor is never locked; it always auto-advances.
        return False

    def get_locked_cursor(self, **props) -> "LockedCursor":
        """Return a LockedCursor at the current position, then advance."""
        position = self._index
        self._index = position + 1
        return LockedCursor(
            root_container=self._root_container,
            parent_path=self._parent_path,
            index=position,
            **props,
        )
class LockedCursor(Cursor):
    def __init__(
        self,
        root_container: int,
        parent_path: Tuple[int, ...] = (),
        index: int = 0,
        **props,
    ):
        """A fixed pointer to a location in the app.

        Unlike a RunningCursor, a LockedCursor always points to the same
        location, even when you call get_locked_cursor() on it.

        Parameters
        ----------
        root_container: int
            The root container this cursor lives in.
        parent_path: tuple of ints
            The full path of this cursor, consisting of the IDs of all
            ancestors. The 0th item is the topmost ancestor.
        index: int
            The index of the Delta within its parent block.
        **props: any
            Anything else you want to store in this cursor. This is a
            temporary measure that will go away when we implement improved
            return values for elements.
        """
        self._root_container = root_container
        self._parent_path = parent_path
        self._index = index
        self._props = props

    @property
    def root_container(self) -> int:
        return self._root_container

    @property
    def parent_path(self) -> Tuple[int, ...]:
        return self._parent_path

    @property
    def index(self) -> int:
        return self._index

    @property
    def is_locked(self) -> bool:
        # Locked by definition.
        return True

    @property
    def props(self) -> Any:
        return self._props

    def get_locked_cursor(self, **props) -> "LockedCursor":
        # Stay in place, but absorb the new props.
        self._props = props
        return self

View File

@ -0,0 +1,764 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Allows us to create and absorb changes (aka Deltas) to elements."""
from typing import Optional, Iterable
import streamlit as st
from streamlit import cursor, caching
from streamlit import legacy_caching
from streamlit import type_util
from streamlit import util
from streamlit.cursor import Cursor
from streamlit.scriptrunner import get_script_run_ctx
from streamlit.errors import StreamlitAPIException
from streamlit.errors import NoSessionContext
from streamlit.proto import Block_pb2
from streamlit.proto import ForwardMsg_pb2
from streamlit.proto.RootContainer_pb2 import RootContainer
from streamlit.logger import get_logger
from streamlit.elements.balloons import BalloonsMixin
from streamlit.elements.button import ButtonMixin
from streamlit.elements.markdown import MarkdownMixin
from streamlit.elements.text import TextMixin
from streamlit.elements.alert import AlertMixin
from streamlit.elements.json import JsonMixin
from streamlit.elements.doc_string import HelpMixin
from streamlit.elements.exception import ExceptionMixin
from streamlit.elements.bokeh_chart import BokehMixin
from streamlit.elements.graphviz_chart import GraphvizMixin
from streamlit.elements.plotly_chart import PlotlyMixin
from streamlit.elements.deck_gl_json_chart import PydeckMixin
from streamlit.elements.map import MapMixin
from streamlit.elements.iframe import IframeMixin
from streamlit.elements.media import MediaMixin
from streamlit.elements.checkbox import CheckboxMixin
from streamlit.elements.multiselect import MultiSelectMixin
from streamlit.elements.metric import MetricMixin
from streamlit.elements.radio import RadioMixin
from streamlit.elements.selectbox import SelectboxMixin
from streamlit.elements.text_widgets import TextWidgetsMixin
from streamlit.elements.time_widgets import TimeWidgetsMixin
from streamlit.elements.progress import ProgressMixin
from streamlit.elements.empty import EmptyMixin
from streamlit.elements.number_input import NumberInputMixin
from streamlit.elements.camera_input import CameraInputMixin
from streamlit.elements.color_picker import ColorPickerMixin
from streamlit.elements.file_uploader import FileUploaderMixin
from streamlit.elements.select_slider import SelectSliderMixin
from streamlit.elements.slider import SliderMixin
from streamlit.elements.snow import SnowMixin
from streamlit.elements.image import ImageMixin
from streamlit.elements.pyplot import PyplotMixin
from streamlit.elements.write import WriteMixin
from streamlit.elements.layouts import LayoutsMixin
from streamlit.elements.form import FormMixin, FormData, current_form_id
from streamlit.state import NoValue
# DataFrame elements come in two flavors: "Legacy" and "Arrow".
# We select between them with the DataFrameElementSelectorMixin.
from streamlit.elements.arrow import ArrowMixin
from streamlit.elements.arrow_altair import ArrowAltairMixin
from streamlit.elements.arrow_vega_lite import ArrowVegaLiteMixin
from streamlit.elements.legacy_data_frame import LegacyDataFrameMixin
from streamlit.elements.legacy_altair import LegacyAltairMixin
from streamlit.elements.legacy_vega_lite import LegacyVegaLiteMixin
from streamlit.elements.dataframe_selector import DataFrameSelectorMixin
# Module-level logger for this file.
LOGGER = get_logger(__name__)
# Save the type built-in for when we override the name "type".
_type = type
# Maximum allowed size of a single Delta message.
MAX_DELTA_BYTES = 14 * 1024 * 1024  # 14MB
# List of Streamlit commands that perform a Pandas "melt" operation on
# input dataframes.
DELTA_TYPES_THAT_MELT_DATAFRAMES = ("line_chart", "area_chart", "bar_chart")
# Arrow-flavored equivalents of the legacy melt-performing commands above.
ARROW_DELTA_TYPES_THAT_MELT_DATAFRAMES = (
    "arrow_line_chart",
    "arrow_area_chart",
    "arrow_bar_chart",
)
class DeltaGenerator(
    AlertMixin,
    BalloonsMixin,
    BokehMixin,
    ButtonMixin,
    CameraInputMixin,
    CheckboxMixin,
    ColorPickerMixin,
    EmptyMixin,
    ExceptionMixin,
    FileUploaderMixin,
    FormMixin,
    GraphvizMixin,
    HelpMixin,
    IframeMixin,
    ImageMixin,
    LayoutsMixin,
    MarkdownMixin,
    MapMixin,
    MediaMixin,
    MetricMixin,
    MultiSelectMixin,
    NumberInputMixin,
    PlotlyMixin,
    ProgressMixin,
    PydeckMixin,
    PyplotMixin,
    RadioMixin,
    SelectboxMixin,
    SelectSliderMixin,
    SliderMixin,
    SnowMixin,
    JsonMixin,
    TextMixin,
    TextWidgetsMixin,
    TimeWidgetsMixin,
    WriteMixin,
    ArrowMixin,
    ArrowAltairMixin,
    ArrowVegaLiteMixin,
    LegacyDataFrameMixin,
    LegacyAltairMixin,
    LegacyVegaLiteMixin,
    DataFrameSelectorMixin,
):
    """Creator of Delta protobuf messages.

    Parameters
    ----------
    root_container: BlockPath_pb2.BlockPath.ContainerValue or None
      The root container for this DeltaGenerator. If None, this is a null
      DeltaGenerator which doesn't print to the app at all (useful for
      testing).

    cursor: cursor.Cursor or None
      This is either:
      - None: if this is the running DeltaGenerator for a top-level
        container (MAIN or SIDEBAR)
      - RunningCursor: if this is the running DeltaGenerator for a
        non-top-level container (created with dg.container())
      - LockedCursor: if this is a locked DeltaGenerator returned by some
        other DeltaGenerator method. E.g. the dg returned in dg =
        st.text("foo").

    parent: DeltaGenerator
      To support the `with dg` notation, DGs are arranged as a tree. Each DG
      remembers its own parent, and the root of the tree is the main DG.

    block_type: None or "vertical" or "horizontal" or "column" or "expandable"
      If this is a block DG, we track its type to prevent nested columns/expanders
    """

    # The pydoc below is for user consumption, so it doesn't talk about
    # DeltaGenerator constructor parameters (which users should never use). For
    # those, see above.
    def __init__(
        self,
        root_container: Optional[int] = RootContainer.MAIN,
        cursor: Optional[Cursor] = None,
        parent: Optional["DeltaGenerator"] = None,
        block_type: Optional[str] = None,
    ):
        """Inserts or updates elements in Streamlit apps.

        As a user, you should never initialize this object by hand. Instead,
        DeltaGenerator objects are initialized for you in two places:

        1) When you call `dg = st.foo()` for some method "foo", sometimes `dg`
        is a DeltaGenerator object. You can call methods on the `dg` object to
        update the element `foo` that appears in the Streamlit app.

        2) This is an internal detail, but `st.sidebar` itself is a
        DeltaGenerator. That's why you can call `st.sidebar.foo()` to place
        an element `foo` inside the sidebar.
        """
        # Sanity check our Container + Cursor, to ensure that our Cursor
        # is using the same Container that we are.
        if (
            root_container is not None
            and cursor is not None
            and root_container != cursor.root_container
        ):
            raise RuntimeError(
                "DeltaGenerator root_container and cursor.root_container must be the same"
            )

        # Whether this DeltaGenerator is nested in the main area or sidebar.
        # No relation to `st.container()`.
        self._root_container = root_container

        # NOTE: You should never use this directly! Instead, use self._cursor,
        # which is a computed property that fetches the right cursor.
        self._provided_cursor = cursor

        self._parent = parent
        self._block_type = block_type

        # If this an `st.form` block, this will get filled in.
        self._form_data: Optional[FormData] = None

        # Change the module of all mixin'ed functions to be st.delta_generator,
        # instead of the original module (e.g. st.elements.markdown).
        # NOTE(review): this mutates the shared class functions on every
        # instantiation; it is idempotent, but runs more often than needed.
        for mixin in self.__class__.__bases__:
            for (name, func) in mixin.__dict__.items():
                if callable(func):
                    func.__module__ = self.__module__

    def __repr__(self) -> str:
        return util.repr_(self)

    def __enter__(self):
        # with block started: push ourselves onto the context's DG stack so
        # that `st.foo` calls inside the block target this container.
        ctx = get_script_run_ctx()
        if ctx:
            ctx.dg_stack.append(self)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # with block ended: pop ourselves off the DG stack.
        ctx = get_script_run_ctx()
        if ctx is not None:
            ctx.dg_stack.pop()

        # Re-raise any exceptions
        return False

    @property
    def _active_dg(self) -> "DeltaGenerator":
        """Return the DeltaGenerator that's currently 'active'.

        If we are the main DeltaGenerator, and are inside a `with` block that
        creates a container, our active_dg is that container. Otherwise,
        our active_dg is self.
        """
        if self == self._main_dg:
            # We're being invoked via an `st.foo` pattern - use the current
            # `with` dg (aka the top of the stack).
            ctx = get_script_run_ctx()
            if ctx and len(ctx.dg_stack) > 0:
                return ctx.dg_stack[-1]

        # We're being invoked via an `st.sidebar.foo` pattern - ignore the
        # current `with` dg.
        return self

    @property
    def _main_dg(self) -> "DeltaGenerator":
        """Return this DeltaGenerator's root - that is, the top-level ancestor
        DeltaGenerator that we belong to (this generally means the st._main
        DeltaGenerator).
        """
        return self._parent._main_dg if self._parent else self

    def __getattr__(self, name):
        """Raise a helpful StreamlitAPIException for unknown attributes.

        Only called when normal attribute lookup fails, i.e. for methods that
        don't exist on DeltaGenerator.
        """
        import streamlit as st

        streamlit_methods = [
            method_name for method_name in dir(st) if callable(getattr(st, method_name))
        ]

        def wrapper(*args, **kwargs):
            if name in streamlit_methods:
                if self._root_container == RootContainer.SIDEBAR:
                    message = (
                        "Method `%(name)s()` does not exist for "
                        "`st.sidebar`. Did you mean `st.%(name)s()`?" % {"name": name}
                    )
                else:
                    message = (
                        "Method `%(name)s()` does not exist for "
                        "`DeltaGenerator` objects. Did you mean "
                        "`st.%(name)s()`?" % {"name": name}
                    )
            else:
                message = "`%(name)s()` is not a valid Streamlit command." % {
                    "name": name
                }

            raise StreamlitAPIException(message)

        return wrapper

    @property
    def _parent_block_types(self) -> Iterable[str]:
        """Iterate all the block types used by this DeltaGenerator and all
        its ancestor DeltaGenerators.
        """
        current_dg: Optional[DeltaGenerator] = self
        while current_dg is not None:
            if current_dg._block_type is not None:
                yield current_dg._block_type
            current_dg = current_dg._parent

    @property
    def _cursor(self) -> Optional[Cursor]:
        """Return our Cursor. This will be None if we're not running in a
        ScriptThread - e.g., if we're running a "bare" script outside of
        Streamlit.
        """
        if self._provided_cursor is None:
            return cursor.get_container_cursor(self._root_container)
        else:
            return self._provided_cursor

    @property
    def _is_top_level(self) -> bool:
        """True when this DG is a container root (MAIN/SIDEBAR), i.e. was
        not created with an explicit cursor."""
        return self._provided_cursor is None

    def _get_delta_path_str(self) -> str:
        """Returns the element's delta path as a string like "[0, 2, 3, 1]".

        This uniquely identifies the element's position in the front-end,
        which allows (among other potential uses) the InMemoryFileManager to maintain
        session-specific maps of InMemoryFile objects placed with their "coordinates".

        This way, users can (say) use st.image with a stream of different images,
        and Streamlit will expire the older images and replace them in place.
        """
        # Operate on the active DeltaGenerator, in case we're in a `with` block.
        dg = self._active_dg
        return str(dg._cursor.delta_path) if dg._cursor is not None else "[]"

    def _enqueue(
        self,
        delta_type,
        element_proto,
        return_value=None,
        last_index=None,
        element_width=None,
        element_height=None,
    ):
        """Create NewElement delta, fill it, and enqueue it.

        Parameters
        ----------
        delta_type: string
            The name of the streamlit method being called
        element_proto: proto
            The actual proto in the NewElement type e.g. Alert/Button/Slider
        return_value: any or None
            The value to return to the calling script (for widgets)
        last_index: any or None
            The last index of the element's data (used by add_rows)
        element_width : int or None
            Desired width for the element
        element_height : int or None
            Desired height for the element

        Returns
        -------
        DeltaGenerator or any
            If this element is NOT an interactive widget, return a
            DeltaGenerator that can be used to modify the newly-created
            element. Otherwise, if the element IS a widget, return the
            `return_value` parameter.
        """
        # Operate on the active DeltaGenerator, in case we're in a `with` block.
        dg = self._active_dg

        # Warn if we're called from within a legacy @st.cache function
        legacy_caching.maybe_show_cached_st_function_warning(dg, delta_type)

        # Warn if we're called from within @st.memo or @st.singleton
        caching.maybe_show_cached_st_function_warning(dg, delta_type)

        # Warn if an element is being changed but the user isn't running the streamlit server.
        st._maybe_print_use_warning()

        # Some elements have a method.__name__ != delta_type in proto.
        # This really matters for line_chart, bar_chart & area_chart,
        # since add_rows() relies on method.__name__ == delta_type
        # TODO: Fix for all elements (or the cache warning above will be wrong)
        proto_type = delta_type
        if proto_type in DELTA_TYPES_THAT_MELT_DATAFRAMES:
            proto_type = "vega_lite_chart"

        # Mirror the logic for arrow_ elements.
        if proto_type in ARROW_DELTA_TYPES_THAT_MELT_DATAFRAMES:
            proto_type = "arrow_vega_lite_chart"

        # Copy the marshalled proto into the overall msg proto
        msg = ForwardMsg_pb2.ForwardMsg()
        msg_el_proto = getattr(msg.delta.new_element, proto_type)
        msg_el_proto.CopyFrom(element_proto)

        # Only enqueue message and fill in metadata if there's a container.
        msg_was_enqueued = False
        if dg._root_container is not None and dg._cursor is not None:
            msg.metadata.delta_path[:] = dg._cursor.delta_path

            if element_width is not None:
                msg.metadata.element_dimension_spec.width = element_width
            if element_height is not None:
                msg.metadata.element_dimension_spec.height = element_height

            _enqueue_message(msg)
            msg_was_enqueued = True

        if msg_was_enqueued:
            # Get a DeltaGenerator that is locked to the current element
            # position.
            new_cursor = (
                dg._cursor.get_locked_cursor(
                    delta_type=delta_type, last_index=last_index
                )
                if dg._cursor is not None
                else None
            )

            output_dg = DeltaGenerator(
                root_container=dg._root_container,
                cursor=new_cursor,
                parent=dg,
            )
        else:
            # If the message was not enqueued, just return self since it's a
            # no-op from the point of view of the app.
            output_dg = dg

        return _value_or_dg(return_value, output_dg)

    def _block(self, block_proto=None) -> "DeltaGenerator":
        """Enqueue an add_block message and return a DeltaGenerator scoped to
        the new block (used by st.container, st.columns, st.expander, etc.).

        Parameters
        ----------
        block_proto: Block_pb2.Block or None
            The proto describing the block. If None, an empty Block proto is
            used. (A fresh proto is created per call instead of a mutable
            default argument, so one call can never leak state into another.)
        """
        if block_proto is None:
            block_proto = Block_pb2.Block()

        # Operate on the active DeltaGenerator, in case we're in a `with` block.
        dg = self._active_dg

        # Prevent nested columns & expanders by checking all parents.
        block_type = block_proto.WhichOneof("type")
        # Convert the generator to a frozenset, so we can use it multiple times.
        parent_block_types = frozenset(dg._parent_block_types)
        if block_type == "column" and block_type in parent_block_types:
            raise StreamlitAPIException(
                "Columns may not be nested inside other columns."
            )
        if block_type == "expandable" and block_type in parent_block_types:
            raise StreamlitAPIException(
                "Expanders may not be nested inside other expanders."
            )

        if dg._root_container is None or dg._cursor is None:
            return dg

        msg = ForwardMsg_pb2.ForwardMsg()
        msg.metadata.delta_path[:] = dg._cursor.delta_path
        msg.delta.add_block.CopyFrom(block_proto)

        # Normally we'd return a new DeltaGenerator that uses the locked cursor
        # below. But in this case we want to return a DeltaGenerator that uses
        # a brand new cursor for this new block we're creating.
        block_cursor = cursor.RunningCursor(
            root_container=dg._root_container,
            parent_path=dg._cursor.parent_path + (dg._cursor.index,),
        )
        block_dg = DeltaGenerator(
            root_container=dg._root_container,
            cursor=block_cursor,
            parent=dg,
            block_type=block_type,
        )
        # Blocks inherit their parent form ids.
        # NOTE: Container form ids aren't set in proto.
        block_dg._form_data = FormData(current_form_id(dg))

        # Must be called to increment this cursor's index.
        dg._cursor.get_locked_cursor(last_index=None)
        _enqueue_message(msg)

        return block_dg

    def _legacy_add_rows(self, data=None, **kwargs):
        """Concatenate a dataframe to the bottom of the current one.

        Parameters
        ----------
        data : pandas.DataFrame, pandas.Styler, numpy.ndarray, Iterable, dict,
        or None
            Table to concat. Optional.

        **kwargs : pandas.DataFrame, numpy.ndarray, Iterable, dict, or None
            The named dataset to concat. Optional. You can only pass in 1
            dataset (including the one in the data parameter).

        Example
        -------
        >>> df1 = pd.DataFrame(
        ...    np.random.randn(50, 20),
        ...    columns=('col %d' % i for i in range(20)))
        ...
        >>> my_table = st._legacy_table(df1)
        >>>
        >>> df2 = pd.DataFrame(
        ...    np.random.randn(50, 20),
        ...    columns=('col %d' % i for i in range(20)))
        ...
        >>> my_table._legacy_add_rows(df2)
        >>> # Now the table shown in the Streamlit app contains the data for
        >>> # df1 followed by the data for df2.

        You can do the same thing with plots. For example, if you want to add
        more data to a line chart:

        >>> # Assuming df1 and df2 from the example above still exist...
        >>> my_chart = st._legacy_line_chart(df1)
        >>> my_chart._legacy_add_rows(df2)
        >>> # Now the chart shown in the Streamlit app contains the data for
        >>> # df1 followed by the data for df2.

        And for plots whose datasets are named, you can pass the data with a
        keyword argument where the key is the name:

        >>> my_chart = st._legacy_vega_lite_chart({
        ...     'mark': 'line',
        ...     'encoding': {'x': 'a', 'y': 'b'},
        ...     'datasets': {
        ...       'some_fancy_name': df1,  # <-- named dataset
        ...      },
        ...     'data': {'name': 'some_fancy_name'},
        ... }),
        >>> my_chart._legacy_add_rows(some_fancy_name=df2)  # <-- name used as keyword

        """
        if self._root_container is None or self._cursor is None:
            return self

        if not self._cursor.is_locked:
            raise StreamlitAPIException("Only existing elements can `add_rows`.")

        # Accept syntax st._legacy_add_rows(df).
        if data is not None and len(kwargs) == 0:
            name = ""
        # Accept syntax st._legacy_add_rows(foo=df).
        elif len(kwargs) == 1:
            name, data = kwargs.popitem()
        # Raise error otherwise.
        else:
            raise StreamlitAPIException(
                "Wrong number of arguments to add_rows()."
                "Command requires exactly one dataset"
            )

        # When doing _legacy_add_rows on an element that does not already have data
        # (for example, st._legacy_line_chart() without any args), call the original
        # st._legacy_foo() element with new data instead of doing an _legacy_add_rows().
        if (
            self._cursor.props["delta_type"] in DELTA_TYPES_THAT_MELT_DATAFRAMES
            and self._cursor.props["last_index"] is None
        ):
            # IMPORTANT: This assumes delta types and st method names always
            # match!
            # delta_type doesn't have any prefix, but st_method_name starts with "_legacy_".
            st_method_name = "_legacy_" + self._cursor.props["delta_type"]
            st_method = getattr(self, st_method_name)
            st_method(data, **kwargs)
            return

        data, self._cursor.props["last_index"] = _maybe_melt_data_for_add_rows(
            data, self._cursor.props["delta_type"], self._cursor.props["last_index"]
        )

        msg = ForwardMsg_pb2.ForwardMsg()
        msg.metadata.delta_path[:] = self._cursor.delta_path

        import streamlit.elements.legacy_data_frame as data_frame

        data_frame.marshall_data_frame(data, msg.delta.add_rows.data)

        if name:
            msg.delta.add_rows.name = name
            msg.delta.add_rows.has_name = True

        _enqueue_message(msg)

        return self

    def _arrow_add_rows(self, data=None, **kwargs):
        """Concatenate a dataframe to the bottom of the current one.

        Parameters
        ----------
        data : pandas.DataFrame, pandas.Styler, numpy.ndarray, Iterable, dict, or None
            Table to concat. Optional.

        **kwargs : pandas.DataFrame, numpy.ndarray, Iterable, dict, or None
            The named dataset to concat. Optional. You can only pass in 1
            dataset (including the one in the data parameter).

        Example
        -------
        >>> df1 = pd.DataFrame(
        ...    np.random.randn(50, 20),
        ...    columns=('col %d' % i for i in range(20)))
        ...
        >>> my_table = st._arrow_table(df1)
        >>>
        >>> df2 = pd.DataFrame(
        ...    np.random.randn(50, 20),
        ...    columns=('col %d' % i for i in range(20)))
        ...
        >>> my_table._arrow_add_rows(df2)
        >>> # Now the table shown in the Streamlit app contains the data for
        >>> # df1 followed by the data for df2.

        You can do the same thing with plots. For example, if you want to add
        more data to a line chart:

        >>> # Assuming df1 and df2 from the example above still exist...
        >>> my_chart = st._arrow_line_chart(df1)
        >>> my_chart._arrow_add_rows(df2)
        >>> # Now the chart shown in the Streamlit app contains the data for
        >>> # df1 followed by the data for df2.

        And for plots whose datasets are named, you can pass the data with a
        keyword argument where the key is the name:

        >>> my_chart = st._arrow_vega_lite_chart({
        ...     'mark': 'line',
        ...     'encoding': {'x': 'a', 'y': 'b'},
        ...     'datasets': {
        ...       'some_fancy_name': df1,  # <-- named dataset
        ...      },
        ...     'data': {'name': 'some_fancy_name'},
        ... }),
        >>> my_chart._arrow_add_rows(some_fancy_name=df2)  # <-- name used as keyword

        """
        if self._root_container is None or self._cursor is None:
            return self

        if not self._cursor.is_locked:
            raise StreamlitAPIException("Only existing elements can `add_rows`.")

        # Accept syntax st._arrow_add_rows(df).
        if data is not None and len(kwargs) == 0:
            name = ""
        # Accept syntax st._arrow_add_rows(foo=df).
        elif len(kwargs) == 1:
            name, data = kwargs.popitem()
        # Raise error otherwise.
        else:
            raise StreamlitAPIException(
                "Wrong number of arguments to add_rows()."
                "Command requires exactly one dataset"
            )

        # When doing _arrow_add_rows on an element that does not already have data
        # (for example, st._arrow_line_chart() without any args), call the original
        # st._arrow_foo() element with new data instead of doing a _arrow_add_rows().
        if (
            self._cursor.props["delta_type"] in ARROW_DELTA_TYPES_THAT_MELT_DATAFRAMES
            and self._cursor.props["last_index"] is None
        ):
            # IMPORTANT: This assumes delta types and st method names always
            # match!
            # delta_type starts with "arrow_", but st_method_name starts with "_arrow_".
            st_method_name = "_" + self._cursor.props["delta_type"]
            st_method = getattr(self, st_method_name)
            st_method(data, **kwargs)
            return

        data, self._cursor.props["last_index"] = _maybe_melt_data_for_add_rows(
            data, self._cursor.props["delta_type"], self._cursor.props["last_index"]
        )

        msg = ForwardMsg_pb2.ForwardMsg()
        msg.metadata.delta_path[:] = self._cursor.delta_path

        import streamlit.elements.arrow as arrow_proto

        default_uuid = str(hash(self._get_delta_path_str()))
        arrow_proto.marshall(msg.delta.arrow_add_rows.data, data, default_uuid)

        if name:
            msg.delta.arrow_add_rows.name = name
            msg.delta.arrow_add_rows.has_name = True

        _enqueue_message(msg)

        return self
def _maybe_melt_data_for_add_rows(data, delta_type, last_index):
    """Reshape `data` into the long ("melted") format used by the chart type,
    when `delta_type` is one of the chart commands that melt their input.

    Returns the (possibly melted) data and the updated `last_index`, which
    tracks where the chart's cumulative index left off so that successive
    add_rows calls produce a continuous index.
    """
    import pandas as pd

    # For some delta types we have to reshape the data structure
    # otherwise the input data and the actual data used
    # by vega_lite will be different and it will throw an error.
    if (
        delta_type in DELTA_TYPES_THAT_MELT_DATAFRAMES
        or delta_type in ARROW_DELTA_TYPES_THAT_MELT_DATAFRAMES
    ):
        if not isinstance(data, pd.DataFrame):
            data = type_util.convert_anything_to_df(data)
        if type(data.index) is pd.RangeIndex:
            # NOTE: "step" is read from the *original* RangeIndex, while
            # "stop" is read *after* reset_index (so it equals the row count
            # of the new 0-based index). This ordering is deliberate.
            old_step = _get_pandas_index_attr(data, "step")
            # We have to drop the predefined index
            data = data.reset_index(drop=True)
            old_stop = _get_pandas_index_attr(data, "stop")
            if old_step is None or old_stop is None:
                raise StreamlitAPIException(
                    "'RangeIndex' object has no attribute 'step'"
                )
            # Continue the index from where the chart's existing data ended.
            start = last_index + old_step
            stop = last_index + old_step + old_stop
            data.index = pd.RangeIndex(start=start, stop=stop, step=old_step)
            last_index = stop - 1
        index_name = data.index.name
        if index_name is None:
            index_name = "index"
        # Long format: one row per (index, variable, value) triple.
        data = pd.melt(data.reset_index(), id_vars=[index_name])
    return data, last_index
def _get_pandas_index_attr(data, attr):
return getattr(data.index, attr, None)
def _value_or_dg(value, dg):
    """Return either value, or None, or dg.

    Widgets have meaningful return values, unlike other elements, which
    always return None; internally that None gets replaced by a
    DeltaGenerator instance. A widget that genuinely wants to return None
    passes the special NoValue sentinel, which is translated back to None
    here instead of being replaced by a DeltaGenerator.
    """
    if value is None:
        return dg
    if value is NoValue:
        return None
    return value
def _enqueue_message(msg):
    """Enqueues a ForwardMsg proto to send to the app.

    Raises NoSessionContext when called outside a script run.
    """
    ctx = get_script_run_ctx()
    if ctx is not None:
        ctx.enqueue(msg)
    else:
        raise NoSessionContext()

View File

@ -0,0 +1,21 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Variables for dev purposes.
The main purpose of this module (right now at least) is to avoid a dependency
cycle between streamlit.config and streamlit.logger.
"""
is_development_mode = False

View File

@ -0,0 +1,116 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import re
import textwrap
import traceback
from typing import List, Iterable, Optional
_SPACES_RE = re.compile("\\s*")
_EMPTY_LINE_RE = re.compile("\\s*\n")
@contextlib.contextmanager
def echo(code_location="above"):
    """Use in a `with` block to draw some code on the app, then execute it.
    Parameters
    ----------
    code_location : "above" or "below"
        Whether to show the echoed code before or after the results of the
        executed code block.
    Example
    -------
    >>> with st.echo():
    >>>     st.write('This code will be printed')
    """
    # Imported locally — presumably to avoid a circular import with the main
    # streamlit package; confirm before moving to module level.
    from streamlit import code, warning, empty, source_util

    if code_location == "below":
        # Code is drawn after the block runs, so draw directly at the end.
        show_code = code
        show_warning = warning
    else:
        # Reserve a placeholder now so the code appears *above* whatever the
        # echoed block renders, even though we only know the code text later.
        placeholder = empty()
        show_code = placeholder.code
        show_warning = placeholder.warning
    try:
        # Get stack frame *before* running the echoed code. The frame's
        # line number will point to the `st.echo` statement we're running.
        # NOTE(review): the -3 offset assumes exactly two wrapper frames
        # (contextmanager machinery) between here and user code — confirm.
        frame = traceback.extract_stack()[-3]
        filename, start_line = frame.filename, frame.lineno
        # Read the file containing the source code of the echoed statement.
        with source_util.open_python_file(filename) as source_file:
            source_lines = source_file.readlines()
        # Get the indent of the first line in the echo block, skipping over any
        # empty lines.
        initial_indent = _get_initial_indent(source_lines[start_line:])
        # Iterate over the remaining lines in the source file
        # until we find one that's indented less than the rest of the
        # block. That's our end line.
        #
        # Note that this is *not* a perfect strategy, because
        # de-denting is not guaranteed to signal "end of block". (A
        # triple-quoted string might be dedented but still in the
        # echo block, for example.)
        # TODO: rewrite this to parse the AST to get the *actual* end of the block.
        lines_to_display: List[str] = []
        for line in source_lines[start_line:]:
            indent = _get_indent(line)
            if indent is not None and indent < initial_indent:
                break
            lines_to_display.append(line)
        code_string = textwrap.dedent("".join(lines_to_display))
        # Run the echoed code...
        yield
        # And draw the code string to the app!
        show_code(code_string, "python")
    except FileNotFoundError as err:
        # Source file isn't available (e.g. running from a REPL or a
        # packaged script) — warn instead of crashing the app.
        show_warning("Unable to display code. %s" % err)
def _get_initial_indent(lines: Iterable[str]) -> int:
    """Return the indent of the first non-empty line in the list.

    If all lines are empty, return 0.
    """
    indents = (_get_indent(line) for line in lines)
    return next((indent for indent in indents if indent is not None), 0)
def _get_indent(line: str) -> Optional[int]:
    """Get the number of whitespaces at the beginning of the given line.

    If the line is empty, or if it contains just whitespace and a newline,
    return None.
    """
    if _EMPTY_LINE_RE.match(line) is not None:
        return None
    spaces = _SPACES_RE.match(line)
    if spaces is None:
        return 0
    return spaces.end()

View File

@ -0,0 +1,13 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,98 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast
import streamlit
from streamlit.proto.Alert_pb2 import Alert as AlertProto
from .utils import clean_text
class AlertMixin:
    """Mixin providing the st.error/warning/info/success alert commands."""

    def _send_alert(self, body, alert_format):
        """Marshall an Alert proto with the given format and enqueue it."""
        proto = AlertProto()
        proto.body = clean_text(body)
        proto.format = alert_format
        return self.dg._enqueue("alert", proto)

    def error(self, body):
        """Display error message.

        Parameters
        ----------
        body : str
            The error text to display.

        Example
        -------
        >>> st.error('This is an error')

        """
        return self._send_alert(body, AlertProto.ERROR)

    def warning(self, body):
        """Display warning message.

        Parameters
        ----------
        body : str
            The warning text to display.

        Example
        -------
        >>> st.warning('This is a warning')

        """
        return self._send_alert(body, AlertProto.WARNING)

    def info(self, body):
        """Display an informational message.

        Parameters
        ----------
        body : str
            The info text to display.

        Example
        -------
        >>> st.info('This is a purely informational message')

        """
        return self._send_alert(body, AlertProto.INFO)

    def success(self, body):
        """Display a success message.

        Parameters
        ----------
        body : str
            The success text to display.

        Example
        -------
        >>> st.success('This is a success message!')

        """
        return self._send_alert(body, AlertProto.SUCCESS)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)

View File

@ -0,0 +1,406 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable
from typing import Any, Dict, List, Optional, Union, cast
from numpy import ndarray
from pandas import DataFrame
from pandas.io.formats.style import Styler
import pyarrow as pa
import streamlit
from streamlit import type_util
from streamlit.proto.Arrow_pb2 import Arrow as ArrowProto
Data = Optional[
Union[DataFrame, Styler, pa.Table, ndarray, Iterable, Dict[str, List[Any]]]
]
class ArrowMixin:
    def _arrow_dataframe(
        self,
        data: Data = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
    ) -> "streamlit.delta_generator.DeltaGenerator":
        """Display a dataframe as an interactive table.

        Parameters
        ----------
        data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, Iterable, dict, or None
            The data to display. A pandas.Styler is used to style its
            underlying DataFrame.
        width : int or None
            Desired width of the UI element in pixels. If None, a default
            width based on the page width is used.
        height : int or None
            Desired height of the UI element in pixels. If None, a default
            height is used.

        Examples
        --------
        >>> df = pd.DataFrame(
        ...    np.random.randn(50, 20),
        ...    columns=('col %d' % i for i in range(20)))
        ...
        >>> st._arrow_dataframe(df)
        >>> st._arrow_dataframe(df.style.highlight_max(axis=0))
        """
        proto = ArrowProto()
        marshall(proto, data, self._default_styler_uuid())
        dg = self.dg._enqueue(
            "arrow_data_frame", proto, element_width=width, element_height=height
        )
        return cast("streamlit.delta_generator.DeltaGenerator", dg)

    def _arrow_table(
        self, data: Data = None
    ) -> "streamlit.delta_generator.DeltaGenerator":
        """Display a static table.

        Unlike `st._arrow_dataframe`, the table is static: its entire
        contents are laid out directly on the page.

        Parameters
        ----------
        data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, Iterable, dict, or None
            The table data.

        Example
        -------
        >>> st._arrow_table(df)
        """
        proto = ArrowProto()
        marshall(proto, data, self._default_styler_uuid())
        return cast(
            "streamlit.delta_generator.DeltaGenerator",
            self.dg._enqueue("arrow_table", proto),
        )

    def _default_styler_uuid(self) -> str:
        """Fallback pandas.Styler uuid: a hash of the element's delta path,
        so the table re-renders whenever the element's position changes."""
        return str(hash(self.dg._get_delta_path_str()))

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)
def marshall(proto: ArrowProto, data: Data, default_uuid: Optional[str] = None) -> None:
    """Marshall pandas.DataFrame into an Arrow proto.

    Parameters
    ----------
    proto : proto.Arrow
        Output. The protobuf for Streamlit Arrow proto.
    data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, Iterable, dict, or None
        Something that is or can be converted to a dataframe.
    default_uuid : Optional[str]
        Fallback uuid when the data is a pandas.Styler without one; other
        data types ignore this value.
    """
    if type_util.is_pandas_styler(data):
        # Styler data must always arrive with a fallback uuid string.
        assert isinstance(
            default_uuid, str
        ), "Default UUID must be a string for Styler data."
        _marshall_styler(proto, data, default_uuid)

    if isinstance(data, pa.Table):
        proto.data = type_util.pyarrow_table_to_bytes(data)
    else:
        proto.data = type_util.data_frame_to_bytes(
            type_util.convert_anything_to_df(data)
        )
def _marshall_styler(proto: ArrowProto, styler: Styler, default_uuid: str) -> None:
    """Serialize a pandas.Styler (uuid, caption, CSS, display values) into proto.

    Parameters
    ----------
    proto : proto.Arrow
        Output. The protobuf for Streamlit Arrow proto.
    styler : pandas.Styler
        Helps style a DataFrame or Series according to the data with HTML and CSS.
    default_uuid : str
        Used when the Styler carries no uuid of its own.
    """
    # The uuid must be in place before _compute() runs, because the
    # generated CSS selectors embed it.
    _marshall_uuid(proto, styler, default_uuid)

    # We're using protected members of pandas.Styler to get styles,
    # which is not ideal and could break if the interface changes.
    styler._compute()

    # In Pandas 1.3.0, styler._translate() signature was changed:
    # sparse_index and sparse_columns arguments were added. The
    # functionality that they provide is not yet supported.
    if type_util.is_pandas_version_less_than("1.3.0"):
        translated = styler._translate()
    else:
        translated = styler._translate(False, False)

    _marshall_caption(proto, styler)
    _marshall_styles(proto, styler, translated)
    _marshall_display_values(proto, styler.data, translated)
def _marshall_uuid(proto: ArrowProto, styler: Styler, default_uuid: str) -> None:
    """Copy the Styler's uuid into the proto.

    A Styler with no uuid of its own is first assigned `default_uuid`.
    """
    if styler.uuid is None:
        styler.set_uuid(default_uuid)

    proto.styler.uuid = str(styler.uuid)
def _marshall_caption(proto: ArrowProto, styler: Styler) -> None:
    """Copy the Styler's caption into the proto, if one is set."""
    caption = styler.caption
    if caption is not None:
        proto.styler.caption = caption
def _marshall_styles(proto: ArrowProto, styler: Styler, styles: Dict[str, Any]) -> None:
    """Convert the Styler's translated styles to CSS rules and store them in proto.

    Parameters
    ----------
    proto : proto.Arrow
        Output. The protobuf for Streamlit Arrow proto.
    styler : pandas.Styler
        Helps style a DataFrame or Series according to the data with HTML and CSS.
    styles : dict
        pandas.Styler translated styles.
    """
    css_rules = []

    # Rules in "table_styles" put a space between the uuid and selector.
    for style in _trim_pandas_styles(styles.get("table_styles", [])):
        css_rules.append(
            _pandas_style_to_css("table_styles", style, styler.uuid, separator=" ")
        )

    for style in _trim_pandas_styles(styles.get("cellstyle", [])):
        css_rules.append(_pandas_style_to_css("cell_style", style, styler.uuid))

    if css_rules:
        proto.styler.styles = "\n".join(css_rules)
def _trim_pandas_styles(styles: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Filter out empty styles.
Every cell will have a class, but the list of props
may just be [['', '']].
Parameters
----------
styles : list
pandas.Styler translated styles.
"""
return [x for x in styles if any(any(y) for y in x["props"])]
def _pandas_style_to_css(
style_type: str,
style: Dict[str, Any],
uuid: str,
separator: str = "",
) -> str:
"""Convert pandas.Styler translated style to CSS.
Parameters
----------
style_type : str
Either "table_styles" or "cell_style".
style : dict
pandas.Styler translated style.
uuid : str
pandas.Styler uuid.
separator : str
A string separator used between table and cell selectors.
"""
declarations = []
for css_property, css_value in style["props"]:
declaration = css_property.strip() + ": " + css_value.strip()
declarations.append(declaration)
table_selector = f"#T_{uuid}"
# In pandas < 1.1.0
# translated_style["cellstyle"] has the following shape:
# [
# {
# "props": [["color", " black"], ["background-color", "orange"], ["", ""]],
# "selector": "row0_col0"
# }
# ...
# ]
#
# In pandas >= 1.1.0
# translated_style["cellstyle"] has the following shape:
# [
# {
# "props": [("color", " black"), ("background-color", "orange"), ("", "")],
# "selectors": ["row0_col0"]
# }
# ...
# ]
if style_type == "table_styles" or (
style_type == "cell_style" and type_util.is_pandas_version_less_than("1.1.0")
):
cell_selectors = [style["selector"]]
else:
cell_selectors = style["selectors"]
selectors = []
for cell_selector in cell_selectors:
selectors.append(table_selector + separator + cell_selector)
selector = ", ".join(selectors)
declaration_block = "; ".join(declarations)
rule_set = selector + " { " + declaration_block + " }"
return rule_set
def _marshall_display_values(
    proto: ArrowProto, df: DataFrame, styles: Dict[str, Any]
) -> None:
    """Serialize the Styler's formatted cell values into the proto.

    Parameters
    ----------
    proto : proto.Arrow
        Output. The protobuf for Streamlit Arrow proto.
    df : pandas.DataFrame
        A dataframe with original values.
    styles : dict
        pandas.Styler translated styles.
    """
    proto.styler.display_values = type_util.data_frame_to_bytes(
        _use_display_values(df, styles)
    )
def _use_display_values(df: DataFrame, styles: Dict[str, Any]) -> DataFrame:
"""Create a new pandas.DataFrame where display values are used instead of original ones.
Parameters
----------
df : pandas.DataFrame
A dataframe with original values.
styles : dict
pandas.Styler translated styles.
"""
import re
# If values in a column are not of the same type, Arrow
# serialization would fail. Thus, we need to cast all values
# of the dataframe to strings before assigning them display values.
new_df = df.astype(str)
cell_selector_regex = re.compile(r"row(\d+)_col(\d+)")
if "body" in styles:
rows = styles["body"]
for row in rows:
for cell in row:
match = cell_selector_regex.match(cell["id"])
if match:
r, c = map(int, match.groups())
new_df.iat[r, c] = str(cell["display_value"])
return new_df

View File

@ -0,0 +1,380 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper around Altair.
Altair is a Python visualization library based on Vega-Lite,
a nice JSON schema for expressing graphs and charts."""
from datetime import date
from enum import Enum
from typing import cast
import altair as alt
import pandas as pd
from altair.vegalite.v4.api import Chart
import streamlit
import streamlit.elements.arrow_vega_lite as arrow_vega_lite
from streamlit import type_util
from streamlit.proto.ArrowVegaLiteChart_pb2 import (
ArrowVegaLiteChart as ArrowVegaLiteChartProto,
)
from .arrow import Data
from .utils import last_index_for_melted_dataframes
class ChartType(Enum):
    """Chart flavors supported by the `_arrow_*_chart` helpers.

    Each value is the suffix of the Altair `mark_*` method used to draw it
    (see `_generate_chart`, which calls `"mark_" + chart_type.value`).
    """

    AREA = "area"
    BAR = "bar"
    LINE = "line"
class ArrowAltairMixin:
    def _arrow_line_chart(
        self,
        data: Data = None,
        width: int = 0,
        height: int = 0,
        use_container_width: bool = True,
    ) -> "streamlit.delta_generator.DeltaGenerator":
        """Display a line chart.

        Syntax-sugar around st._arrow_altair_chart: the chart's spec is
        inferred from the data's own columns and index, which is easier for
        "just plot this" scenarios but less customizable. If the guessed
        spec is wrong, use st._arrow_altair_chart directly.

        Parameters
        ----------
        data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, Iterable, dict or None
            Data to be plotted.
        width : int
            The chart width in pixels. If 0, selects the width automatically.
        height : int
            The chart height in pixels. If 0, selects the height automatically.
        use_container_width : bool
            If True, set the chart width to the column width. This takes
            precedence over the width argument.

        Example
        -------
        >>> st._arrow_line_chart(chart_data)
        """
        return self._enqueue_simple_chart(
            ChartType.LINE, "arrow_line_chart", data, width, height, use_container_width
        )

    def _arrow_area_chart(
        self,
        data: Data = None,
        width: int = 0,
        height: int = 0,
        use_container_width: bool = True,
    ) -> "streamlit.delta_generator.DeltaGenerator":
        """Display an area chart.

        Syntax-sugar around st._arrow_altair_chart: the chart's spec is
        inferred from the data's own columns and index. If the guessed spec
        is wrong, use st._arrow_altair_chart directly.

        Parameters
        ----------
        data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, Iterable, or dict
            Data to be plotted.
        width : int
            The chart width in pixels. If 0, selects the width automatically.
        height : int
            The chart height in pixels. If 0, selects the height automatically.
        use_container_width : bool
            If True, set the chart width to the column width. This takes
            precedence over the width argument.

        Example
        -------
        >>> st._arrow_area_chart(chart_data)
        """
        return self._enqueue_simple_chart(
            ChartType.AREA, "arrow_area_chart", data, width, height, use_container_width
        )

    def _arrow_bar_chart(
        self,
        data: Data = None,
        width: int = 0,
        height: int = 0,
        use_container_width: bool = True,
    ) -> "streamlit.delta_generator.DeltaGenerator":
        """Display a bar chart.

        Syntax-sugar around st._arrow_altair_chart: the chart's spec is
        inferred from the data's own columns and index. If the guessed spec
        is wrong, use st._arrow_altair_chart directly.

        Parameters
        ----------
        data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, Iterable, or dict
            Data to be plotted.
        width : int
            The chart width in pixels. If 0, selects the width automatically.
        height : int
            The chart height in pixels. If 0, selects the height automatically.
        use_container_width : bool
            If True, set the chart width to the column width. This takes
            precedence over the width argument.

        Example
        -------
        >>> st._arrow_bar_chart(chart_data)
        """
        return self._enqueue_simple_chart(
            ChartType.BAR, "arrow_bar_chart", data, width, height, use_container_width
        )

    def _arrow_altair_chart(
        self, altair_chart: Chart, use_container_width: bool = False
    ) -> "streamlit.delta_generator.DeltaGenerator":
        """Display a chart using the Altair library.

        Parameters
        ----------
        altair_chart : altair.vegalite.v2.api.Chart
            The Altair chart object to display.
        use_container_width : bool
            If True, set the chart width to the column width. This takes
            precedence over Altair's native `width` value.

        Example
        -------
        >>> c = alt.Chart(df).mark_circle().encode(
        ...     x='a', y='b', size='c', color='c', tooltip=['a', 'b', 'c'])
        >>> st._arrow_altair_chart(c, use_container_width=True)

        Examples of Altair charts can be found at
        https://altair-viz.github.io/gallery/.
        """
        proto = ArrowVegaLiteChartProto()
        marshall(proto, altair_chart, use_container_width=use_container_width)
        return cast(
            "streamlit.delta_generator.DeltaGenerator",
            self.dg._enqueue("arrow_vega_lite_chart", proto),
        )

    def _enqueue_simple_chart(
        self,
        chart_type: ChartType,
        delta_type: str,
        data: Data,
        width: int,
        height: int,
        use_container_width: bool,
    ) -> "streamlit.delta_generator.DeltaGenerator":
        """Shared implementation of the line/area/bar chart helpers."""
        proto = ArrowVegaLiteChartProto()
        chart = _generate_chart(chart_type, data, width, height)
        marshall(proto, chart, use_container_width)
        last_index = last_index_for_melted_dataframes(data)
        return cast(
            "streamlit.delta_generator.DeltaGenerator",
            self.dg._enqueue(delta_type, proto, last_index=last_index),
        )

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)
def _is_date_column(df: pd.DataFrame, name: str) -> bool:
"""True if the column with the given name stores datetime.date values.
This function just checks the first value in the given column, so
it's meaningful only for columns whose values all share the same type.
Parameters
----------
df : pd.DataFrame
name : str
The column name
Returns
-------
bool
"""
column = df[name]
if column.size == 0:
return False
return isinstance(column[0], date)
def _generate_chart(
    chart_type: ChartType, data: Data, width: int = 0, height: int = 0
) -> Chart:
    """This function uses the chart's type, data columns and indices to figure out the chart's spec.

    Parameters
    ----------
    chart_type : ChartType
        The mark to draw (area, bar or line).
    data : Data
        Anything convertible to a DataFrame, or None for an empty chart.
    width : int
        Chart width in pixels; 0 lets Altair choose.
    height : int
        Chart height in pixels; 0 lets Altair choose.

    Returns
    -------
    altair.Chart
        An interactive Altair chart built from the melted (long-form) data.
    """
    if data is None:
        # Use an empty-ish dict because if we use None the x axis labels rotate
        # 90 degrees. No idea why. Need to debug.
        data = {"": []}

    if not isinstance(data, pd.DataFrame):
        data = type_util.convert_anything_to_df(data)

    index_name = data.index.name
    if index_name is None:
        index_name = "index"

    # Melt wide-form data into long form: one row per (index, variable, value).
    data = pd.melt(data.reset_index(), id_vars=[index_name])

    if chart_type == ChartType.AREA:
        opacity = {"value": 0.7}
    else:
        opacity = {"value": 1.0}

    # Set the X and Y axes' scale to "utc" if they contain date values.
    # This causes time data to be displayed in UTC, rather the user's local
    # time zone. (By default, vega-lite displays time data in the browser's
    # local time zone, regardless of which time zone the data specifies:
    # https://vega.github.io/vega-lite/docs/timeunit.html#output).
    x_scale = (
        alt.Scale(type="utc") if _is_date_column(data, index_name) else alt.Undefined
    )
    y_scale = alt.Scale(type="utc") if _is_date_column(data, "value") else alt.Undefined

    x_type = alt.Undefined
    # Bar charts should have a discrete (ordinal) x-axis, UNLESS type is date/time
    # https://github.com/streamlit/streamlit/pull/2097#issuecomment-714802475
    if chart_type == ChartType.BAR and not _is_date_column(data, index_name):
        x_type = "ordinal"

    chart = (
        getattr(
            alt.Chart(data, width=width, height=height), "mark_" + chart_type.value
        )()
        .encode(
            alt.X(index_name, title="", scale=x_scale, type=x_type),
            alt.Y("value", title="", scale=y_scale),
            alt.Color("variable", title="", type="nominal"),
            alt.Tooltip([index_name, "value", "variable"]),
            opacity=opacity,
        )
        .interactive()
    )
    return chart
def marshall(
    vega_lite_chart: ArrowVegaLiteChartProto,
    altair_chart: Chart,
    use_container_width: bool = False,
    **kwargs,
):
    """Marshall chart's data into proto.

    Parameters
    ----------
    vega_lite_chart : ArrowVegaLiteChartProto
        Output. The proto to fill in.
    altair_chart : altair.Chart
        The Altair chart to serialize.
    use_container_width : bool
        If True, the chart width is set to the container width.
    **kwargs : any
        Forwarded to arrow_vega_lite.marshall.
    """
    import altair as alt

    # Normally altair_chart.to_dict() would transform the dataframe used by the
    # chart into an array of dictionaries. To avoid that, we install a
    # transformer that replaces datasets with a reference by the object id of
    # the dataframe. We then fill in the dataset manually later on.
    datasets = {}

    def id_transform(data):
        """Altair data transformer that returns a fake named dataset with the
        object id."""
        # NOTE: closure over `datasets`; each dataframe is stashed by id so it
        # can be re-attached to the chart dict below.
        datasets[id(data)] = data
        return {"name": str(id(data))}

    alt.data_transformers.register("id", id_transform)

    with alt.data_transformers.enable("id"):
        chart_dict = altair_chart.to_dict()

        # Put datasets back into the chart dict but note how they weren't
        # transformed.
        chart_dict["datasets"] = datasets

        arrow_vega_lite.marshall(
            vega_lite_chart,
            chart_dict,
            use_container_width=use_container_width,
            **kwargs,
        )

View File

@ -0,0 +1,207 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper around Vega-Lite."""
import json
from typing import Any, Dict, Optional, cast
import streamlit
import streamlit.elements.lib.dicttools as dicttools
from streamlit.logger import get_logger
from streamlit.proto.ArrowVegaLiteChart_pb2 import (
ArrowVegaLiteChart as ArrowVegaLiteChartProto,
)
from . import arrow
from .arrow import Data
LOGGER = get_logger(__name__)
class ArrowVegaLiteMixin:
    def _arrow_vega_lite_chart(
        self,
        data: Data = None,
        spec: Optional[Dict[str, Any]] = None,
        use_container_width: bool = False,
        **kwargs,
    ) -> "streamlit.delta_generator.DeltaGenerator":
        """Display a chart using the Vega-Lite library.

        Parameters
        ----------
        data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, Iterable, dict, or None
            Either the data to be plotted or a Vega-Lite spec containing the
            data (which more closely follows the Vega-Lite API).
        spec : dict or None
            The Vega-Lite spec for the chart. If the spec was already passed
            in the previous argument, this must be set to None. See
            https://vega.github.io/vega-lite/docs/ for more info.
        use_container_width : bool
            If True, set the chart width to the column width. This takes
            precedence over Vega-Lite's native `width` value.
        **kwargs : any
            Same as spec, but as keywords.

        Example
        -------
        >>> st._arrow_vega_lite_chart(df, {
        ...     'mark': {'type': 'circle', 'tooltip': True},
        ...     'encoding': {
        ...         'x': {'field': 'a', 'type': 'quantitative'},
        ...         'y': {'field': 'b', 'type': 'quantitative'},
        ...     },
        ... })

        Examples of Vega-Lite usage without Streamlit can be found at
        https://vega.github.io/vega-lite/examples/.
        """
        vega_lite_proto = ArrowVegaLiteChartProto()
        marshall(
            vega_lite_proto,
            data,
            spec,
            use_container_width=use_container_width,
            **kwargs,
        )
        dg = self.dg._enqueue("arrow_vega_lite_chart", vega_lite_proto)
        return cast("streamlit.delta_generator.DeltaGenerator", dg)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)
def marshall(
    proto: ArrowVegaLiteChartProto,
    data: Data = None,
    spec: Optional[Dict[str, Any]] = None,
    use_container_width: bool = False,
    **kwargs,
):
    """Construct a Vega-Lite chart object.

    See DeltaGenerator.vega_lite_chart for docs.
    """
    # Support passing data inside spec['datasets'] and spec['data'].
    # (The data gets pulled out of the spec dict later on.)
    if isinstance(data, dict) and spec is None:
        spec = data
        data = None

    # Support passing no spec arg, but filling it with kwargs.
    # Example:
    #   marshall(proto, baz='boz')
    if spec is None:
        spec = dict()
    else:
        # Clone the spec dict, since we may be mutating it.
        spec = dict(spec)

    # Support passing in kwargs. Example:
    #   marshall(proto, {foo: 'bar'}, baz='boz')
    if len(kwargs):
        # Merge spec with unflattened kwargs, where kwargs take precedence.
        # This only works for string keys, but kwarg keys are strings anyways.
        spec = dict(spec, **dicttools.unflatten(kwargs, _CHANNELS))

    if len(spec) == 0:
        raise ValueError("Vega-Lite charts require a non-empty spec dict.")

    if "autosize" not in spec:
        spec["autosize"] = {"type": "fit", "contains": "padding"}

    # Pull data out of spec dict when it's in a 'datasets' key:
    #   marshall(proto, {datasets: {foo: df1, bar: df2}, ...})
    if "datasets" in spec:
        for k, v in spec["datasets"].items():
            dataset = proto.datasets.add()
            dataset.name = str(k)
            dataset.has_name = True
            arrow.marshall(dataset.data, v)
        del spec["datasets"]

    # Pull data out of spec dict when it's in a top-level 'data' key:
    #   marshall(proto, {data: df})
    #   marshall(proto, {data: {values: df, ...}})
    #   marshall(proto, {data: {url: 'url'}})
    #   marshall(proto, {data: {name: 'foo'}})
    if "data" in spec:
        data_spec = spec["data"]

        if isinstance(data_spec, dict):
            if "values" in data_spec:
                data = data_spec["values"]
                del spec["data"]
        else:
            data = data_spec
            del spec["data"]
    # NOTE(review): a dict 'data' entry without 'values' (e.g. {'url': ...}
    # or {'name': ...}) stays inside the spec — presumably resolved by the
    # frontend; confirm against the frontend's spec handling.

    proto.spec = json.dumps(spec)
    proto.use_container_width = use_container_width

    if data is not None:
        arrow.marshall(proto.data, data)
# See https://vega.github.io/vega-lite/docs/encoding.html
_CHANNELS = set(
[
"x",
"y",
"x2",
"y2",
"xError",
"yError2",
"xError",
"yError2",
"longitude",
"latitude",
"color",
"opacity",
"fillOpacity",
"strokeOpacity",
"strokeWidth",
"size",
"shape",
"text",
"tooltip",
"href",
"key",
"order",
"detail",
"facet",
"row",
"column",
]
)

View File

@ -0,0 +1,39 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast
import streamlit
from streamlit.proto.Balloons_pb2 import Balloons as BalloonsProto
class BalloonsMixin:
    def balloons(self):
        """Draw celebratory balloons.

        Example
        -------
        >>> st.balloons()

        ...then watch your app and get ready for a celebration!
        """
        proto = BalloonsProto()
        proto.show = True
        return self.dg._enqueue("balloons", proto)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)

View File

@ -0,0 +1,104 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper around Bokeh."""
import hashlib
import json
from typing import cast
import streamlit
from streamlit.errors import StreamlitAPIException
from streamlit.proto.BokehChart_pb2 import BokehChart as BokehChartProto
ST_BOKEH_VERSION = "2.4.1"
class BokehMixin:
    def bokeh_chart(self, figure, use_container_width=False):
        """Display an interactive Bokeh chart.

        Bokeh is a charting library for Python; the arguments here closely
        follow the ones for Bokeh's `show` function. Call `st.bokeh_chart`
        wherever you would call Bokeh's `show`. More about Bokeh at
        https://bokeh.pydata.org.

        Parameters
        ----------
        figure : bokeh.plotting.figure.Figure
            A Bokeh figure to plot.
        use_container_width : bool
            If True, set the chart width to the column width. This takes
            precedence over Bokeh's native `width` value.

        Example
        -------
        >>> from bokeh.plotting import figure
        >>> p = figure(title='simple line example', x_axis_label='x', y_axis_label='y')
        >>> p.line([1, 2, 3], [6, 7, 2], legend_label='Trend', line_width=2)
        >>> st.bokeh_chart(p, use_container_width=True)
        """
        import bokeh

        # Only one pinned Bokeh version is accepted; other versions raise.
        if bokeh.__version__ != ST_BOKEH_VERSION:
            raise StreamlitAPIException(
                f"Streamlit only supports Bokeh version {ST_BOKEH_VERSION}, "
                f"but you have version {bokeh.__version__} installed. Please "
                f"run `pip install --force-reinstall --no-deps bokeh=="
                f"{ST_BOKEH_VERSION}` to install the correct version."
            )

        # Generate element ID from delta path
        element_id = hashlib.md5(self.dg._get_delta_path_str().encode()).hexdigest()

        proto = BokehChartProto()
        marshall(proto, figure, use_container_width, element_id)
        return self.dg._enqueue("bokeh_chart", proto)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)
def marshall(proto, figure, use_container_width, element_id):
    """Serialize a Bokeh figure into the given proto.

    See DeltaGenerator.bokeh_chart for docs.
    """
    from bokeh.embed import json_item

    proto.figure = json.dumps(json_item(figure))
    proto.use_container_width = use_container_width
    proto.element_id = element_id

View File

@ -0,0 +1,386 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
from streamlit.scriptrunner import ScriptRunContext, get_script_run_ctx
from streamlit.type_util import Key, to_key
from typing import cast, Optional, Union, BinaryIO, TextIO
from textwrap import dedent
import streamlit
from streamlit.errors import StreamlitAPIException
from streamlit.proto.Button_pb2 import Button as ButtonProto
from streamlit.in_memory_file_manager import in_memory_file_manager
from streamlit.proto.DownloadButton_pb2 import DownloadButton as DownloadButtonProto
from streamlit.state import (
register_widget,
WidgetArgs,
WidgetCallback,
WidgetKwargs,
)
from .form import current_form_id, is_in_form
from .utils import check_callback_rules, check_session_state_rules
# Appended to form-related error messages so users can find the form docs.
FORM_DOCS_INFO = """
For more information, refer to the
[documentation for forms](https://docs.streamlit.io/library/api-reference/control-flow/st.form).
"""

# Payload types accepted by st.download_button's `data` argument.
DownloadButtonDataType = Union[str, bytes, TextIO, BinaryIO]
class ButtonMixin:
    """Mixin providing the ``st.button`` and ``st.download_button`` widgets."""

    def button(
        self,
        label: str,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_click: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        disabled: bool = False,
    ) -> bool:
        """Display a button widget.

        Parameters
        ----------
        label : str
            A short label explaining to the user what this button is for.
        key : str or int
            An optional string or integer to use as the unique key for the widget.
            If this is omitted, a key will be generated for the widget
            based on its content. Multiple widgets of the same type may
            not share the same key.
        help : str
            An optional tooltip that gets displayed when the button is
            hovered over.
        on_click : callable
            An optional callback invoked when this button is clicked.
        args : tuple
            An optional tuple of args to pass to the callback.
        kwargs : dict
            An optional dict of kwargs to pass to the callback.
        disabled : bool
            An optional boolean, which disables the button if set to True. The
            default is False. This argument can only be supplied by keyword.

        Returns
        -------
        bool
            True if the button was clicked on the last run of the app,
            False otherwise.

        Example
        -------
        >>> if st.button('Say hello'):
        ...     st.write('Why hello there')
        ... else:
        ...     st.write('Goodbye')

        .. output::
           https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/widget.button.py
           height: 220px
        """
        key = to_key(key)
        ctx = get_script_run_ctx()
        # Regular buttons are never form submitters; _button also implements
        # the auto-created "Submit" button of st.form (is_form_submitter=True).
        return self.dg._button(
            label,
            key,
            help,
            is_form_submitter=False,
            on_click=on_click,
            args=args,
            kwargs=kwargs,
            disabled=disabled,
            ctx=ctx,
        )

    def download_button(
        self,
        label: str,
        data: DownloadButtonDataType,
        file_name: Optional[str] = None,
        mime: Optional[str] = None,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_click: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        disabled: bool = False,
    ) -> bool:
        """Display a download button widget.

        This is useful when you would like to provide a way for your users
        to download a file directly from your app.

        Note that the data to be downloaded is stored in-memory while the
        user is connected, so it's a good idea to keep file sizes under a
        couple hundred megabytes to conserve memory.

        Parameters
        ----------
        label : str
            A short label explaining to the user what this button is for.
        data : str or bytes or file
            The contents of the file to be downloaded. See example below for
            caching techniques to avoid recomputing this data unnecessarily.
        file_name: str
            An optional string to use as the name of the file to be downloaded,
            such as 'my_file.csv'. If not specified, the name will be
            automatically generated.
        mime : str or None
            The MIME type of the data. If None, defaults to "text/plain"
            (if data is of type *str* or is a textual *file*) or
            "application/octet-stream" (if data is of type *bytes* or is a
            binary *file*).
        key : str or int
            An optional string or integer to use as the unique key for the widget.
            If this is omitted, a key will be generated for the widget
            based on its content. Multiple widgets of the same type may
            not share the same key.
        help : str
            An optional tooltip that gets displayed when the button is
            hovered over.
        on_click : callable
            An optional callback invoked when this button is clicked.
        args : tuple
            An optional tuple of args to pass to the callback.
        kwargs : dict
            An optional dict of kwargs to pass to the callback.
        disabled : bool
            An optional boolean, which disables the download button if set to
            True. The default is False. This argument can only be supplied by
            keyword.

        Returns
        -------
        bool
            True if the button was clicked on the last run of the app,
            False otherwise.

        Examples
        --------
        Download a large DataFrame as a CSV:

        >>> @st.cache
        ... def convert_df(df):
        ...     # IMPORTANT: Cache the conversion to prevent computation on every rerun
        ...     return df.to_csv().encode('utf-8')
        >>>
        >>> csv = convert_df(my_large_df)
        >>>
        >>> st.download_button(
        ...     label="Download data as CSV",
        ...     data=csv,
        ...     file_name='large_df.csv',
        ...     mime='text/csv',
        ... )

        Download a string as a file:

        >>> text_contents = '''This is some text'''
        >>> st.download_button('Download some text', text_contents)

        Download a binary file:

        >>> binary_contents = b'example content'
        >>> # Defaults to 'application/octet-stream'
        >>> st.download_button('Download binary file', binary_contents)

        Download an image:

        >>> with open("flower.png", "rb") as file:
        ...     btn = st.download_button(
        ...             label="Download image",
        ...             data=file,
        ...             file_name="flower.png",
        ...             mime="image/png"
        ...           )

        .. output::
           https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/widget.download_button.py
           height: 335px
        """
        ctx = get_script_run_ctx()
        return self._download_button(
            label=label,
            data=data,
            file_name=file_name,
            mime=mime,
            key=key,
            help=help,
            on_click=on_click,
            args=args,
            kwargs=kwargs,
            disabled=disabled,
            ctx=ctx,
        )

    def _download_button(
        self,
        label: str,
        data: DownloadButtonDataType,
        file_name: Optional[str] = None,
        mime: Optional[str] = None,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_click: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        disabled: bool = False,
        ctx: Optional[ScriptRunContext] = None,
    ) -> bool:
        """Internal implementation of download_button.

        Stores `data` with the in-memory file manager (via marshall_file),
        registers the widget, and enqueues the proto.
        """
        key = to_key(key)
        # Writing a download button's state via st.session_state is disallowed
        # (writes_allowed=False).
        check_session_state_rules(default_value=None, key=key, writes_allowed=False)
        if is_in_form(self.dg):
            raise StreamlitAPIException(
                f"`st.download_button()` can't be used in an `st.form()`.{FORM_DOCS_INFO}"
            )
        download_button_proto = DownloadButtonProto()
        download_button_proto.label = label
        download_button_proto.default = False
        # Registers the payload bytes and fills in download_button_proto.url.
        marshall_file(
            self.dg._get_delta_path_str(), data, download_button_proto, mime, file_name
        )
        if help is not None:
            download_button_proto.help = dedent(help)

        def deserialize_button(ui_value, widget_id=""):
            # A missing frontend value means "not clicked".
            return ui_value or False

        current_value, _ = register_widget(
            "download_button",
            download_button_proto,
            user_key=key,
            on_change_handler=on_click,
            args=args,
            kwargs=kwargs,
            deserializer=deserialize_button,
            serializer=bool,
            ctx=ctx,
        )
        # This needs to be done after register_widget because we don't want
        # the following proto fields to affect a widget's ID.
        download_button_proto.disabled = disabled
        self.dg._enqueue("download_button", download_button_proto)
        return cast(bool, current_value)

    def _button(
        self,
        label: str,
        key: Optional[str],
        help: Optional[str],
        is_form_submitter: bool,
        on_click: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        disabled: bool = False,
        ctx: Optional[ScriptRunContext] = None,
    ) -> bool:
        """Shared implementation for st.button and st.form_submit_button."""
        if not is_form_submitter:
            # The callback rules are only enforced for regular buttons; the
            # check is skipped for form submitters.
            check_callback_rules(self.dg, on_click)
        # Button state cannot be written via st.session_state
        # (writes_allowed=False).
        check_session_state_rules(default_value=None, key=key, writes_allowed=False)

        # It doesn't make sense to create a button inside a form (except
        # for the "Form Submitter" button that's automatically created in
        # every form). We throw an error to warn the user about this.
        # We omit this check for scripts running outside streamlit, because
        # they will have no script_run_ctx.
        if streamlit._is_running_with_streamlit:
            if is_in_form(self.dg) and not is_form_submitter:
                raise StreamlitAPIException(
                    f"`st.button()` can't be used in an `st.form()`.{FORM_DOCS_INFO}"
                )
            elif not is_in_form(self.dg) and is_form_submitter:
                raise StreamlitAPIException(
                    f"`st.form_submit_button()` must be used inside an `st.form()`.{FORM_DOCS_INFO}"
                )

        button_proto = ButtonProto()
        button_proto.label = label
        button_proto.default = False
        button_proto.is_form_submitter = is_form_submitter
        button_proto.form_id = current_form_id(self.dg)
        if help is not None:
            button_proto.help = dedent(help)

        def deserialize_button(ui_value: bool, widget_id: str = "") -> bool:
            # A missing frontend value means "not clicked".
            return ui_value or False

        current_value, _ = register_widget(
            "button",
            button_proto,
            user_key=key,
            on_change_handler=on_click,
            args=args,
            kwargs=kwargs,
            deserializer=deserialize_button,
            serializer=bool,
            ctx=ctx,
        )
        # This needs to be done after register_widget because we don't want
        # the following proto fields to affect a widget's ID.
        button_proto.disabled = disabled
        self.dg._enqueue("button", button_proto)
        return cast(bool, current_value)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator.

        The mixin is always mixed into DeltaGenerator, so the cast is safe.
        """
        return cast("streamlit.delta_generator.DeltaGenerator", self)
def marshall_file(coordinates, data, proto_download_button, mimetype, file_name=None):
    """Serialize `data` for download and attach its URL to the proto.

    Normalizes the supported payload types (str, bytes, and text/binary
    file-like objects) to raw bytes, infers a default MIME type when none
    is given, registers the bytes with the in-memory file manager, and
    stores the resulting download URL on `proto_download_button`.

    Parameters
    ----------
    coordinates : str
        Delta-path string identifying the element; passed through to the
        in-memory file manager as the file's coordinates.
    data : str, bytes, or file-like
        The payload to make downloadable.
    proto_download_button : DownloadButton proto
        Proto whose `url` field receives the generated download URL.
    mimetype : str or None
        Explicit MIME type. If None, defaults to "text/plain" for textual
        data and "application/octet-stream" for binary data.
    file_name : str or None
        Optional name suggested to the browser for the downloaded file.

    Raises
    ------
    RuntimeError
        If `data` is not one of the supported types.
    """
    if isinstance(data, str):
        data = data.encode()
        mimetype = mimetype or "text/plain"
    elif isinstance(data, (io.TextIOWrapper, io.StringIO)):
        # Textual streams (opened text files, StringIO): read everything
        # and encode to UTF-8 bytes. StringIO is accepted here because the
        # documented DownloadButtonDataType includes TextIO.
        string_data = data.read()
        data = string_data.encode()
        mimetype = mimetype or "text/plain"
    elif isinstance(data, bytes):
        mimetype = mimetype or "application/octet-stream"
    elif isinstance(data, io.BytesIO):
        data.seek(0)
        data = data.getvalue()
        mimetype = mimetype or "application/octet-stream"
    elif isinstance(data, (io.RawIOBase, io.BufferedReader)):
        # Binary file handles: rewind and read the full contents.
        data.seek(0)
        data = data.read()
        mimetype = mimetype or "application/octet-stream"
    else:
        raise RuntimeError("Invalid binary data format: %s" % type(data))

    this_file = in_memory_file_manager.add(
        data, mimetype, coordinates, file_name=file_name, is_for_static_download=True
    )
    proto_download_button.url = this_file.url

View File

@ -0,0 +1,236 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from streamlit.type_util import Key, to_key
from textwrap import dedent
from typing import Optional, cast, List
import streamlit
from streamlit.proto.CameraInput_pb2 import (
CameraInput as CameraInputProto,
)
from streamlit.scriptrunner import ScriptRunContext, get_script_run_ctx
from streamlit.state import (
register_widget,
WidgetArgs,
WidgetCallback,
WidgetKwargs,
)
from ..proto.Common_pb2 import (
FileUploaderState as FileUploaderStateProto,
UploadedFileInfo as UploadedFileInfoProto,
)
from ..uploaded_file_manager import UploadedFile, UploadedFileRec
from .form import current_form_id
from .utils import check_callback_rules, check_session_state_rules
# Value type of st.camera_input: the snapshot as an UploadedFile, or None
# when no picture has been taken yet.
SomeUploadedSnapshotFile = Optional[UploadedFile]
class CameraInputMixin:
    """Mixin providing the ``st.camera_input`` widget."""

    def camera_input(
        self,
        label: str,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        disabled: bool = False,
    ) -> SomeUploadedSnapshotFile:
        """Display a widget that returns pictures from the user's webcam.

        Parameters
        ----------
        label : str
            A short label explaining to the user what this widget is used for.
        key : str or int
            An optional string or integer to use as the unique key for the widget.
            If this is omitted, a key will be generated for the widget
            based on its content. Multiple widgets of the same type may
            not share the same key.
        help : str
            A tooltip that gets displayed next to the camera input.
        on_change : callable
            An optional callback invoked when this camera_input's value
            changes.
        args : tuple
            An optional tuple of args to pass to the callback.
        kwargs : dict
            An optional dict of kwargs to pass to the callback.
        disabled : bool
            An optional boolean, which disables the camera input if set to
            True. The default is False. This argument can only be supplied by
            keyword.

        Returns
        -------
        None or UploadedFile
            The UploadedFile class is a subclass of BytesIO, and therefore
            it is "file-like". This means you can pass them anywhere where
            a file is expected.

        Examples
        --------
        >>> import streamlit as st
        >>>
        >>> picture = st.camera_input("Take a picture")
        >>>
        >>> if picture:
        ...     st.image(picture)
        """
        ctx = get_script_run_ctx()
        return self._camera_input(
            label=label,
            key=key,
            help=help,
            on_change=on_change,
            args=args,
            kwargs=kwargs,
            disabled=disabled,
            ctx=ctx,
        )

    def _camera_input(
        self,
        label: str,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        disabled: bool = False,
        ctx: Optional[ScriptRunContext] = None,
    ) -> SomeUploadedSnapshotFile:
        """Internal implementation of camera_input.

        Builds the proto, registers the widget with serializer/deserializer
        closures, and asks the uploaded-file manager to drop files no longer
        referenced by this widget's state.
        """
        key = to_key(key)
        check_callback_rules(self.dg, on_change)
        # Camera-input state cannot be written via st.session_state
        # (writes_allowed=False).
        check_session_state_rules(default_value=None, key=key, writes_allowed=False)
        camera_input_proto = CameraInputProto()
        camera_input_proto.label = label
        camera_input_proto.form_id = current_form_id(self.dg)
        if help is not None:
            camera_input_proto.help = dedent(help)

        def serialize_camera_image_input(
            snapshot: SomeUploadedSnapshotFile,
        ) -> FileUploaderStateProto:
            # Serializer: widget value -> FileUploaderState proto.
            state_proto = FileUploaderStateProto()
            ctx = get_script_run_ctx()
            if ctx is None:
                return state_proto
            # ctx.uploaded_file_mgr._file_id_counter stores the id to use for
            # the *next* uploaded file, so the current highest file id is the
            # counter minus 1.
            state_proto.max_file_id = ctx.uploaded_file_mgr._file_id_counter - 1
            if not snapshot:
                return state_proto
            file_info: UploadedFileInfoProto = state_proto.uploaded_file_info.add()
            file_info.id = snapshot.id
            file_info.name = snapshot.name
            file_info.size = snapshot.size
            return state_proto

        def deserialize_camera_image_input(
            ui_value: Optional[FileUploaderStateProto], widget_id: str
        ) -> SomeUploadedSnapshotFile:
            # Deserializer: frontend state -> UploadedFile (or None when no
            # snapshot is stored for this widget).
            file_recs = self._get_file_recs_for_camera_input_widget(widget_id, ui_value)
            if len(file_recs) == 0:
                return_value = None
            else:
                return_value = UploadedFile(file_recs[0])
            return return_value

        widget_value, _ = register_widget(
            "camera_input",
            camera_input_proto,
            user_key=key,
            on_change_handler=on_change,
            args=args,
            kwargs=kwargs,
            deserializer=deserialize_camera_image_input,
            serializer=serialize_camera_image_input,
            ctx=ctx,
        )
        # This needs to be done after register_widget because we don't want
        # the following proto fields to affect a widget's ID.
        camera_input_proto.disabled = disabled

        ctx = get_script_run_ctx()
        camera_image_input_state = serialize_camera_image_input(widget_value)
        uploaded_shapshot_info = camera_image_input_state.uploaded_file_info
        if ctx is not None and len(uploaded_shapshot_info) != 0:
            newest_file_id = camera_image_input_state.max_file_id
            active_file_ids = [f.id for f in uploaded_shapshot_info]
            # Ask the file manager to discard snapshots this widget no longer
            # references — presumably to reclaim memory (see
            # UploadedFileManager.remove_orphaned_files).
            ctx.uploaded_file_mgr.remove_orphaned_files(
                session_id=ctx.session_id,
                widget_id=camera_input_proto.id,
                newest_file_id=newest_file_id,
                active_file_ids=active_file_ids,
            )

        self.dg._enqueue("camera_input", camera_input_proto)
        return cast(SomeUploadedSnapshotFile, widget_value)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator.

        The mixin is always mixed into DeltaGenerator, so the cast is safe.
        """
        return cast("streamlit.delta_generator.DeltaGenerator", self)

    @staticmethod
    def _get_file_recs_for_camera_input_widget(
        widget_id: str, widget_value: Optional[FileUploaderStateProto]
    ) -> List[UploadedFileRec]:
        """Return the stored file records for this widget's active file IDs.

        Returns an empty list when there is no widget state, no script-run
        context, or no uploaded-file info.
        """
        if widget_value is None:
            return []
        ctx = get_script_run_ctx()
        if ctx is None:
            return []
        uploaded_file_info = widget_value.uploaded_file_info
        if len(uploaded_file_info) == 0:
            return []
        active_file_ids = [f.id for f in uploaded_file_info]
        # Grab the files that correspond to our active file IDs.
        return ctx.uploaded_file_mgr.get_files(
            session_id=ctx.session_id,
            widget_id=widget_id,
            file_ids=active_file_ids,
        )

View File

@ -0,0 +1,155 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from streamlit.scriptrunner import ScriptRunContext, get_script_run_ctx
from streamlit.type_util import Key, to_key
from textwrap import dedent
from typing import cast, Optional
import streamlit
from streamlit.proto.Checkbox_pb2 import Checkbox as CheckboxProto
from streamlit.state import (
register_widget,
WidgetArgs,
WidgetCallback,
WidgetKwargs,
)
from .form import current_form_id
from .utils import check_callback_rules, check_session_state_rules
class CheckboxMixin:
    """Mixin providing the ``st.checkbox`` widget."""

    def checkbox(
        self,
        label: str,
        value: bool = False,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        disabled: bool = False,
    ) -> bool:
        """Display a checkbox widget.

        Parameters
        ----------
        label : str
            A short label explaining to the user what this checkbox is for.
        value : bool
            Preselect the checkbox when it first renders. This will be
            cast to bool internally.
        key : str or int
            An optional string or integer to use as the unique key for the
            widget. If omitted, a key is generated for the widget based on
            its content. Multiple widgets of the same type may not share
            the same key.
        help : str
            An optional tooltip that gets displayed next to the checkbox.
        on_change : callable
            An optional callback invoked when this checkbox's value changes.
        args : tuple
            An optional tuple of args to pass to the callback.
        kwargs : dict
            An optional dict of kwargs to pass to the callback.
        disabled : bool
            An optional boolean, which disables the checkbox if set to True.
            The default is False. This argument can only be supplied by keyword.

        Returns
        -------
        bool
            Whether or not the checkbox is checked.

        Example
        -------
        >>> agree = st.checkbox('I agree')
        >>>
        >>> if agree:
        ...     st.write('Great!')

        .. output::
           https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/widget.checkbox.py
           height: 220px
        """
        return self._checkbox(
            label=label,
            value=value,
            key=key,
            help=help,
            on_change=on_change,
            args=args,
            kwargs=kwargs,
            disabled=disabled,
            ctx=get_script_run_ctx(),
        )

    def _checkbox(
        self,
        label: str,
        value: bool = False,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        disabled: bool = False,
        ctx: Optional[ScriptRunContext] = None,
    ) -> bool:
        """Internal implementation of checkbox: build the proto, register
        the widget, and enqueue it."""
        key = to_key(key)
        check_callback_rules(self.dg, on_change)
        check_session_state_rules(
            default_value=None if value is False else value, key=key
        )

        proto = CheckboxProto()
        proto.label = label
        proto.default = bool(value)
        proto.form_id = current_form_id(self.dg)
        if help is not None:
            proto.help = dedent(help)

        def deserialize_checkbox(ui_value: Optional[bool], widget_id: str = "") -> bool:
            # Fall back to the declared default when the frontend has not
            # reported a value yet.
            if ui_value is None:
                return bool(value)
            return bool(ui_value)

        checked, set_frontend_value = register_widget(
            "checkbox",
            proto,
            user_key=key,
            on_change_handler=on_change,
            args=args,
            kwargs=kwargs,
            deserializer=deserialize_checkbox,
            serializer=bool,
            ctx=ctx,
        )

        # Set these only after register_widget, so they don't feed into the
        # widget's ID.
        proto.disabled = disabled
        if set_frontend_value:
            proto.value = checked
            proto.set_value = True

        self.dg._enqueue("checkbox", proto)
        return cast(bool, checked)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """The enclosing DeltaGenerator instance."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)

View File

@ -0,0 +1,183 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from streamlit.scriptrunner import ScriptRunContext, get_script_run_ctx
from streamlit.type_util import Key, to_key
from textwrap import dedent
from typing import Optional, cast
import streamlit
from streamlit.errors import StreamlitAPIException
from streamlit.proto.ColorPicker_pb2 import ColorPicker as ColorPickerProto
from streamlit.state import register_widget
from streamlit.state import (
WidgetArgs,
WidgetCallback,
WidgetKwargs,
)
from .form import current_form_id
from .utils import check_callback_rules, check_session_state_rules
class ColorPickerMixin:
    """Mixin providing the ``st.color_picker`` widget."""

    def color_picker(
        self,
        label: str,
        value: Optional[str] = None,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        disabled: bool = False,
    ) -> str:
        """Display a color picker widget.

        Parameters
        ----------
        label : str
            A short label explaining to the user what this input is for.
        value : str
            The hex value of this widget when it first renders. If None,
            defaults to black.
        key : str or int
            An optional string or integer to use as the unique key for the
            widget. If omitted, a key is generated for the widget based on
            its content. Multiple widgets of the same type may not share
            the same key.
        help : str
            An optional tooltip that gets displayed next to the color picker.
        on_change : callable
            An optional callback invoked when this color_picker's value
            changes.
        args : tuple
            An optional tuple of args to pass to the callback.
        kwargs : dict
            An optional dict of kwargs to pass to the callback.
        disabled : bool
            An optional boolean, which disables the color picker if set to
            True. The default is False. This argument can only be supplied by
            keyword.

        Returns
        -------
        str
            The selected color as a hex string.

        Example
        -------
        >>> color = st.color_picker('Pick A Color', '#00f900')
        >>> st.write('The current color is', color)

        .. output::
           https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/widget.color_picker.py
           height: 335px
        """
        return self._color_picker(
            label=label,
            value=value,
            key=key,
            help=help,
            on_change=on_change,
            args=args,
            kwargs=kwargs,
            disabled=disabled,
            ctx=get_script_run_ctx(),
        )

    def _color_picker(
        self,
        label: str,
        value: Optional[str] = None,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        disabled: bool = False,
        ctx: Optional[ScriptRunContext] = None,
    ) -> str:
        """Internal implementation of color_picker: validate the default,
        build the proto, register the widget, and enqueue it."""
        key = to_key(key)
        check_callback_rules(self.dg, on_change)
        check_session_state_rules(default_value=value, key=key)

        # A missing default means black.
        if value is None:
            value = "#000000"

        # The default must be a string...
        if not isinstance(value, str):
            raise StreamlitAPIException(
                """
                Color Picker Value has invalid type: %s. Expects a hex string
                like '#00FFAA' or '#000'.
                """
                % type(value).__name__
            )

        # ...and that string must be a 3- or 6-digit hex color code.
        if not re.match(r"^#(?:[0-9a-fA-F]{3}){1,2}$", value):
            raise StreamlitAPIException(
                """
                '%s' is not a valid hex code for colors. Valid ones are like
                '#00FFAA' or '#000'.
                """
                % value
            )

        proto = ColorPickerProto()
        proto.label = label
        proto.default = str(value)
        proto.form_id = current_form_id(self.dg)
        if help is not None:
            proto.help = dedent(help)

        def deserialize_color_picker(
            ui_value: Optional[str], widget_id: str = ""
        ) -> str:
            # Fall back to the declared default when the frontend has not
            # reported a value yet.
            if ui_value is None:
                return str(value)
            return str(ui_value)

        selected, set_frontend_value = register_widget(
            "color_picker",
            proto,
            user_key=key,
            on_change_handler=on_change,
            args=args,
            kwargs=kwargs,
            deserializer=deserialize_color_picker,
            serializer=str,
            ctx=ctx,
        )

        # Set these only after register_widget, so they don't feed into the
        # widget's ID.
        proto.disabled = disabled
        if set_frontend_value:
            proto.value = selected
            proto.set_value = True

        self.dg._enqueue("color_picker", proto)
        return cast(str, selected)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """The enclosing DeltaGenerator instance."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)

View File

@ -0,0 +1,441 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Selects between our two DataFrame serialization methods ("legacy" and
"arrow") based on a config option"""
from typing import cast
import streamlit
from streamlit import config
def _use_arrow() -> bool:
    """True if we're using Apache Arrow for DataFrame serialization."""
    serialization_format = config.get_option("global.dataFrameSerialization")
    # Explicitly coerce to bool here because mypy is (incorrectly) complaining
    # that we're trying to return 'Any'.
    return bool(serialization_format == "arrow")
class DataFrameSelectorMixin:
def dataframe(self, data=None, width=None, height=None):
    """Display a dataframe as an interactive table.

    Parameters
    ----------
    data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, Iterable, dict, or None
        The data to display.

        If 'data' is a pandas.Styler, it will be used to style its
        underlying DataFrame. Streamlit supports custom cell
        values and colors. (It does not support some of the more exotic
        pandas styling features, like bar charts, hovering, and captions.)
        Styler support is experimental!
        Pyarrow tables are not supported by Streamlit's legacy DataFrame serialization
        (i.e. with `config.dataFrameSerialization = "legacy"`).
        To use pyarrow tables, please enable pyarrow by changing the config setting,
        `config.dataFrameSerialization = "arrow"`.

    width : int or None
        Desired width of the UI element expressed in pixels. If None, a
        default width based on the page width is used.

    height : int or None
        Desired height of the UI element expressed in pixels. If None, a
        default height is used.

    Examples
    --------
    >>> df = pd.DataFrame(
    ...    np.random.randn(50, 20),
    ...    columns=('col %d' % i for i in range(20)))
    ...
    >>> st.dataframe(df)  # Same as st.write(df)

    .. output::
       https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/data.dataframe.py
       height: 410px

    >>> st.dataframe(df, 200, 100)

    You can also pass a Pandas Styler object to change the style of
    the rendered DataFrame:

    >>> df = pd.DataFrame(
    ...    np.random.randn(10, 20),
    ...    columns=('col %d' % i for i in range(20)))
    ...
    >>> st.dataframe(df.style.highlight_max(axis=0))

    .. output::
       https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/data.dataframe1.py
       height: 410px
    """
    # Dispatch to the serialization backend chosen by the config option.
    render = self.dg._arrow_dataframe if _use_arrow() else self.dg._legacy_dataframe
    return render(data, width, height)
def table(self, data=None):
    """Display a static table.

    This differs from `st.dataframe` in that the table in this case is
    static: its entire contents are laid out directly on the page.

    Parameters
    ----------
    data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, Iterable, dict, or None
        The table data.
        Pyarrow tables are not supported by Streamlit's legacy DataFrame serialization
        (i.e. with `config.dataFrameSerialization = "legacy"`).
        To use pyarrow tables, please enable pyarrow by changing the config setting,
        `config.dataFrameSerialization = "arrow"`.

    Example
    -------
    >>> df = pd.DataFrame(
    ...    np.random.randn(10, 5),
    ...    columns=('col %d' % i for i in range(5)))
    ...
    >>> st.table(df)

    .. output::
       https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/data.table.py
       height: 480px
    """
    # Dispatch to the serialization backend chosen by the config option.
    render = self.dg._arrow_table if _use_arrow() else self.dg._legacy_table
    return render(data)
def line_chart(self, data=None, width=0, height=0, use_container_width=True):
    """Display a line chart.

    This is syntax-sugar around st.altair_chart. The main difference
    is this command uses the data's own column and indices to figure out
    the chart's spec. As a result this is easier to use for many "just plot
    this" scenarios, while being less customizable.

    If st.line_chart does not guess the data specification
    correctly, try specifying your desired chart using st.altair_chart.

    Parameters
    ----------
    data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, Iterable, dict or None
        Data to be plotted.
        Pyarrow tables are not supported by Streamlit's legacy DataFrame serialization
        (i.e. with `config.dataFrameSerialization = "legacy"`).
        To use pyarrow tables, please enable pyarrow by changing the config setting,
        `config.dataFrameSerialization = "arrow"`.

    width : int
        The chart width in pixels. If 0, selects the width automatically.

    height : int
        The chart height in pixels. If 0, selects the height automatically.

    use_container_width : bool
        If True, set the chart width to the column width. This takes
        precedence over the width argument.

    Example
    -------
    >>> chart_data = pd.DataFrame(
    ...     np.random.randn(20, 3),
    ...     columns=['a', 'b', 'c'])
    ...
    >>> st.line_chart(chart_data)

    .. output::
       https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/charts.line_chart.py
       height: 400px
    """
    # Dispatch to the serialization backend chosen by the config option.
    draw = self.dg._arrow_line_chart if _use_arrow() else self.dg._legacy_line_chart
    return draw(data, width, height, use_container_width)
def area_chart(self, data=None, width=0, height=0, use_container_width=True):
"""Display an area chart.
This is just syntax-sugar around st.altair_chart. The main difference
is this command uses the data's own column and indices to figure out
the chart's spec. As a result this is easier to use for many "just plot
this" scenarios, while being less customizable.
If st.area_chart does not guess the data specification
correctly, try specifying your desired chart using st.altair_chart.
Parameters
----------
data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, Iterable, or dict
Data to be plotted.
Pyarrow tables are not supported by Streamlit's legacy DataFrame serialization
(i.e. with `config.dataFrameSerialization = "legacy"`).
To use pyarrow tables, please enable pyarrow by changing the config setting,
`config.dataFrameSerialization = "arrow"`.
width : int
The chart width in pixels. If 0, selects the width automatically.
height : int
The chart height in pixels. If 0, selects the height automatically.
use_container_width : bool
If True, set the chart width to the column width. This takes
precedence over the width argument.
Example
-------
>>> chart_data = pd.DataFrame(
... np.random.randn(20, 3),
... columns=['a', 'b', 'c'])
...
>>> st.area_chart(chart_data)
.. output::
https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/charts.area_chart.py
height: 400px
"""
if _use_arrow():
return self.dg._arrow_area_chart(data, width, height, use_container_width)
else:
return self.dg._legacy_area_chart(data, width, height, use_container_width)
def bar_chart(self, data=None, width=0, height=0, use_container_width=True):
"""Display a bar chart.
This is just syntax-sugar around st.altair_chart. The main difference
is this command uses the data's own column and indices to figure out
the chart's spec. As a result this is easier to use for many "just plot
this" scenarios, while being less customizable.
If st.bar_chart does not guess the data specification
correctly, try specifying your desired chart using st.altair_chart.
Parameters
----------
data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, Iterable, or dict
Data to be plotted.
Pyarrow tables are not supported by Streamlit's legacy DataFrame serialization
(i.e. with `config.dataFrameSerialization = "legacy"`).
To use pyarrow tables, please enable pyarrow by changing the config setting,
`config.dataFrameSerialization = "arrow"`.
width : int
The chart width in pixels. If 0, selects the width automatically.
height : int
The chart height in pixels. If 0, selects the height automatically.
use_container_width : bool
If True, set the chart width to the column width. This takes
precedence over the width argument.
Example
-------
>>> chart_data = pd.DataFrame(
... np.random.randn(50, 3),
... columns=["a", "b", "c"])
...
>>> st.bar_chart(chart_data)
.. output::
https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/charts.bar_chart.py
height: 400px
"""
if _use_arrow():
return self.dg._arrow_bar_chart(data, width, height, use_container_width)
else:
return self.dg._legacy_bar_chart(data, width, height, use_container_width)
def altair_chart(self, altair_chart, use_container_width=False):
"""Display a chart using the Altair library.
Parameters
----------
altair_chart : altair.vegalite.v2.api.Chart
The Altair chart object to display.
use_container_width : bool
If True, set the chart width to the column width. This takes
precedence over Altair's native `width` value.
Example
-------
>>> import pandas as pd
>>> import numpy as np
>>> import altair as alt
>>>
>>> df = pd.DataFrame(
... np.random.randn(200, 3),
... columns=['a', 'b', 'c'])
...
>>> c = alt.Chart(df).mark_circle().encode(
... x='a', y='b', size='c', color='c', tooltip=['a', 'b', 'c'])
>>>
>>> st.altair_chart(c, use_container_width=True)
Examples of Altair charts can be found at
https://altair-viz.github.io/gallery/.
.. output::
https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/charts.vega_lite_chart.py
height: 300px
"""
if _use_arrow():
return self.dg._arrow_altair_chart(altair_chart, use_container_width)
else:
return self.dg._legacy_altair_chart(altair_chart, use_container_width)
def vega_lite_chart(
self,
data=None,
spec=None,
use_container_width=False,
**kwargs,
):
"""Display a chart using the Vega-Lite library.
Parameters
----------
data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, Iterable, dict, or None
Either the data to be plotted or a Vega-Lite spec containing the
data (which more closely follows the Vega-Lite API).
Pyarrow tables are not supported by Streamlit's legacy DataFrame serialization
(i.e. with `config.dataFrameSerialization = "legacy"`).
To use pyarrow tables, please enable pyarrow by changing the config setting,
`config.dataFrameSerialization = "arrow"`.
spec : dict or None
The Vega-Lite spec for the chart. If the spec was already passed in
the previous argument, this must be set to None. See
https://vega.github.io/vega-lite/docs/ for more info.
use_container_width : bool
If True, set the chart width to the column width. This takes
precedence over Vega-Lite's native `width` value.
**kwargs : any
Same as spec, but as keywords.
Example
-------
>>> import pandas as pd
>>> import numpy as np
>>>
>>> df = pd.DataFrame(
... np.random.randn(200, 3),
... columns=['a', 'b', 'c'])
>>>
>>> st.vega_lite_chart(df, {
... 'mark': {'type': 'circle', 'tooltip': True},
... 'encoding': {
... 'x': {'field': 'a', 'type': 'quantitative'},
... 'y': {'field': 'b', 'type': 'quantitative'},
... 'size': {'field': 'c', 'type': 'quantitative'},
... 'color': {'field': 'c', 'type': 'quantitative'},
... },
... })
.. output::
https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/charts.vega_lite_chart.py
height: 300px
Examples of Vega-Lite usage without Streamlit can be found at
https://vega.github.io/vega-lite/examples/. Most of those can be easily
translated to the syntax shown above.
"""
if _use_arrow():
return self.dg._arrow_vega_lite_chart(
data, spec, use_container_width, **kwargs
)
else:
return self.dg._legacy_vega_lite_chart(
data, spec, use_container_width, **kwargs
)
def add_rows(self, data=None, **kwargs):
"""Concatenate a dataframe to the bottom of the current one.
Parameters
----------
data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, Iterable, dict, or None
Table to concat. Optional.
Pyarrow tables are not supported by Streamlit's legacy DataFrame serialization
(i.e. with `config.dataFrameSerialization = "legacy"`).
To use pyarrow tables, please enable pyarrow by changing the config setting,
`config.dataFrameSerialization = "arrow"`.
**kwargs : pandas.DataFrame, numpy.ndarray, Iterable, dict, or None
The named dataset to concat. Optional. You can only pass in 1
dataset (including the one in the data parameter).
Example
-------
>>> df1 = pd.DataFrame(
... np.random.randn(50, 20),
... columns=('col %d' % i for i in range(20)))
...
>>> my_table = st.table(df1)
>>>
>>> df2 = pd.DataFrame(
... np.random.randn(50, 20),
... columns=('col %d' % i for i in range(20)))
...
>>> my_table.add_rows(df2)
>>> # Now the table shown in the Streamlit app contains the data for
>>> # df1 followed by the data for df2.
You can do the same thing with plots. For example, if you want to add
more data to a line chart:
>>> # Assuming df1 and df2 from the example above still exist...
>>> my_chart = st.line_chart(df1)
>>> my_chart.add_rows(df2)
>>> # Now the chart shown in the Streamlit app contains the data for
>>> # df1 followed by the data for df2.
And for plots whose datasets are named, you can pass the data with a
keyword argument where the key is the name:
>>> my_chart = st.vega_lite_chart({
... 'mark': 'line',
... 'encoding': {'x': 'a', 'y': 'b'},
... 'datasets': {
... 'some_fancy_name': df1, # <-- named dataset
... },
... 'data': {'name': 'some_fancy_name'},
... }),
>>> my_chart.add_rows(some_fancy_name=df2) # <-- name used as keyword
"""
if _use_arrow():
return self.dg._arrow_add_rows(data, **kwargs)
else:
return self.dg._legacy_add_rows(data, **kwargs)
@property
def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
"""Get our DeltaGenerator."""
return cast("streamlit.delta_generator.DeltaGenerator", self)

View File

@ -0,0 +1,117 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast, Any, Dict
import streamlit
import json
from streamlit.proto.DeckGlJsonChart_pb2 import DeckGlJsonChart as PydeckProto
class PydeckMixin:
    def pydeck_chart(self, pydeck_obj=None, use_container_width=False):
        """Draw a chart using the PyDeck library.

        This supports 3D maps, point clouds, and more! More info about PyDeck
        at https://deckgl.readthedocs.io/en/latest/.

        These docs are also quite useful:

        - DeckGL docs: https://github.com/uber/deck.gl/tree/master/docs
        - DeckGL JSON docs: https://github.com/uber/deck.gl/tree/master/modules/json

        When using this command, we advise all users to use a personal Mapbox
        token. This ensures the map tiles used in this chart are more
        robust. You can do this with the mapbox.token config option.

        To get a token for yourself, create an account at
        https://mapbox.com. It's free! (for moderate usage levels). For more info
        on how to set config options, see
        https://docs.streamlit.io/library/advanced-features/configuration#set-configuration-options

        Parameters
        ----------
        pydeck_obj : pydeck.Deck or None
            Object specifying the PyDeck chart to draw. If None, an empty
            world map is drawn instead.
        use_container_width : bool
            If True, set the chart width to the column width.

        Example
        -------
        Here's a chart using a HexagonLayer and a ScatterplotLayer on top of
        the light map style:

        >>> df = pd.DataFrame(
        ...    np.random.randn(1000, 2) / [50, 50] + [37.76, -122.4],
        ...    columns=['lat', 'lon'])
        >>>
        >>> st.pydeck_chart(pdk.Deck(
        ...     map_style='mapbox://styles/mapbox/light-v9',
        ...     initial_view_state=pdk.ViewState(
        ...         latitude=37.76,
        ...         longitude=-122.4,
        ...         zoom=11,
        ...         pitch=50,
        ...     ),
        ...     layers=[
        ...         pdk.Layer(
        ...            'HexagonLayer',
        ...            data=df,
        ...            get_position='[lon, lat]',
        ...            radius=200,
        ...            elevation_scale=4,
        ...            elevation_range=[0, 1000],
        ...            pickable=True,
        ...            extruded=True,
        ...         ),
        ...         pdk.Layer(
        ...             'ScatterplotLayer',
        ...             data=df,
        ...             get_position='[lon, lat]',
        ...             get_color='[200, 30, 0, 160]',
        ...             get_radius=200,
        ...         ),
        ...     ],
        ... ))

        .. output::
           https://static.streamlit.io/0.25.0-2JkNY/index.html?id=ASTdExBpJ1WxbGceneKN1i
           height: 530px

        """
        pydeck_proto = PydeckProto()
        # Serialize the Deck (or the default empty map when None) into the
        # proto before enqueueing it for the frontend.
        marshall(pydeck_proto, pydeck_obj, use_container_width)
        return self.dg._enqueue("deck_gl_json_chart", pydeck_proto)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)
# Spec used when the caller passes no Deck object: an empty world map.
EMPTY_MAP: Dict[str, Any] = {
    "initialViewState": {"latitude": 0, "longitude": 0, "pitch": 0, "zoom": 1}
}


def marshall(pydeck_proto, pydeck_obj, use_container_width):
    """Fill a DeckGlJsonChart proto from a pydeck.Deck (or the empty map)."""
    # Serialize the deck spec to JSON; fall back to the empty map when no
    # Deck object was supplied.
    pydeck_proto.json = (
        json.dumps(EMPTY_MAP) if pydeck_obj is None else pydeck_obj.to_json()
    )
    pydeck_proto.use_container_width = use_container_width

    # pydeck keeps the tooltip on its underlying widget object; copy it over
    # only when it is an actual dict spec.
    if pydeck_obj is not None:
        tooltip = pydeck_obj.deck_widget.tooltip
        if isinstance(tooltip, dict):
            pydeck_proto.tooltip = json.dumps(tooltip)

View File

@ -0,0 +1,166 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Allows us to create and absorb changes (aka Deltas) to elements."""
import inspect
from typing import cast
import streamlit
from streamlit.proto.DocString_pb2 import DocString as DocStringProto
from streamlit.logger import get_logger
LOGGER = get_logger(__name__)
# Internal module names that would confuse users if shown verbatim in
# st.help() output; _marshall() reports them simply as "streamlit".
CONFUSING_STREAMLIT_MODULES = (
    "streamlit.echo",
    "streamlit.delta_generator",
    "streamlit.legacy_caching.caching",
)

# Signature prefixes that leak DeltaGenerator internals; _get_signature()
# strips them so users see the signature they would actually call.
CONFUSING_STREAMLIT_SIG_PREFIXES = ("(element, ",)
class HelpMixin:
    def help(self, obj):
        """Display an object's doc string, nicely formatted.

        Parameters
        ----------
        obj : Object
            The object whose docstring should be displayed.

        Example
        -------
        Don't remember how to initialize a dataframe? Try this:

        >>> st.help(pandas.DataFrame)

        Want to quickly check what datatype is output by a certain
        function? Try:

        >>> x = my_poorly_documented_function()
        >>> st.help(x)

        """
        # Extract name/module/signature/docs into the proto, then enqueue
        # it for rendering.
        proto = DocStringProto()
        _marshall(proto, obj)
        return self.dg._enqueue("doc_string", proto)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """The DeltaGenerator this mixin is mixed into."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)
def _marshall(doc_string_proto, obj):
    """Fill a DocString proto with name, module, type, signature and docs
    extracted from ``obj``.

    See DeltaGenerator.help for user-facing docs.
    """
    _missing = object()

    # Not every object has a __name__ (e.g. instances); set it only when
    # present.
    name = getattr(obj, "__name__", _missing)
    if name is not _missing:
        doc_string_proto.name = name

    # Internal streamlit module names would confuse users, so those are
    # reported simply as "streamlit". When there is no __module__ at all,
    # doc_string_proto.module keeps its default (empty string).
    module_name = getattr(obj, "__module__", None)
    if module_name in CONFUSING_STREAMLIT_MODULES:
        doc_string_proto.module = "streamlit"
    elif module_name is not None:
        doc_string_proto.module = module_name

    obj_type = type(obj)
    doc_string_proto.type = str(obj_type)

    if callable(obj):
        doc_string_proto.signature = _get_signature(obj)

    doc_string = inspect.getdoc(obj)
    # Sometimes an object has no docstring but its type does; fall back to
    # the type's docs in that case. Skip the fallback for plain types and
    # for functions/methods, whose missing docs should stay missing.
    if (
        doc_string is None
        and obj_type is not type
        and not inspect.isfunction(obj)
        and not inspect.ismethod(obj)
    ):
        doc_string = inspect.getdoc(obj_type)

    doc_string_proto.doc_string = (
        doc_string if doc_string is not None else "No docs available."
    )
def _get_signature(f):
is_delta_gen = False
try:
is_delta_gen = f.__module__ == "streamlit.delta_generator"
if is_delta_gen:
# DeltaGenerator functions are doubly wrapped, and their function
# signatures are useless unless we unwrap them.
f = _unwrap_decorated_func(f)
# Functions such as numpy.minimum don't have a __module__ attribute,
# since we're only using it to check if its a DeltaGenerator, its ok
# to continue
except AttributeError:
pass
sig = ""
try:
sig = str(inspect.signature(f))
except ValueError:
# f is a builtin.
pass
if is_delta_gen:
for prefix in CONFUSING_STREAMLIT_SIG_PREFIXES:
if sig.startswith(prefix):
sig = sig.replace(prefix, "(")
break
return sig
def _unwrap_decorated_func(f):
if hasattr(f, "__wrapped__"):
try:
while getattr(f, "__wrapped__"):
contents = f.__wrapped__
if not callable(contents):
break
f = contents
return f
except AttributeError:
pass
# Fall back to original function, though it's unlikely we'll reach
# this part of the code.
return f

View File

@ -0,0 +1,71 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast
import streamlit
from streamlit.proto.Empty_pb2 import Empty as EmptyProto
class EmptyMixin:
    def empty(self):
        """Insert a single-element container.

        The returned container holds one element at a time: inserting a new
        element replaces the previous one, which lets you overwrite output
        in place, clear it, or swap in a child multi-element container.
        Elements can be added with "with" notation or by calling methods
        directly on the returned object. See examples below.

        Examples
        --------
        Overwriting elements in-place using "with" notation:

        >>> import time
        >>>
        >>> with st.empty():
        ...     for seconds in range(60):
        ...         st.write(f"{seconds} seconds have passed")
        ...         time.sleep(1)
        ...     st.write("✔️ 1 minute over!")

        Replacing several elements, then clearing them:

        >>> placeholder = st.empty()
        >>>
        >>> # Replace the placeholder with some text:
        >>> placeholder.text("Hello")
        >>>
        >>> # Replace the text with a chart:
        >>> placeholder.line_chart({"data": [1, 5, 2, 6]})
        >>>
        >>> # Replace the chart with several elements:
        >>> with placeholder.container():
        ...     st.write("This is one element")
        ...     st.write("This is another")
        ...
        >>> # Clear all those elements:
        >>> placeholder.empty()

        """
        # The Empty proto carries no payload; it just marks the slot.
        return self.dg._enqueue("empty", EmptyProto())

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """The DeltaGenerator this mixin is mixed into."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)

View File

@ -0,0 +1,250 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import traceback
import typing
from typing import Optional, cast, List
import streamlit
from streamlit.proto.Exception_pb2 import Exception as ExceptionProto
from streamlit.errors import MarkdownFormattedException
from streamlit.errors import StreamlitAPIException
from streamlit.errors import StreamlitAPIWarning
from streamlit.errors import StreamlitDeprecationWarning
from streamlit.errors import UncaughtAppException
from streamlit.logger import get_logger
LOGGER = get_logger(__name__)
# When client.showErrorDetails is False, we show a generic warning in the
# frontend when we encounter an uncaught app exception.
_GENERIC_UNCAUGHT_EXCEPTION_TEXT = "This app has encountered an error. The original error message is redacted to prevent data leaks. Full error details have been recorded in the logs (if you're on Streamlit Cloud, click on 'Manage app' in the lower right of your app)."

# Extract the streamlit package path.
_STREAMLIT_DIR = os.path.dirname(streamlit.__file__)

# Make it absolute, resolve aliases, and ensure there's a trailing path
# separator. The trailing separator matters: the prefix comparison in
# _is_in_streamlit_package() must not match sibling directories that merely
# share the name prefix.
_STREAMLIT_DIR = os.path.join(os.path.realpath(_STREAMLIT_DIR), "")
class ExceptionMixin:
    def exception(self, exception):
        """Display an exception.

        Parameters
        ----------
        exception : Exception
            The exception to display.

        Example
        -------
        >>> e = RuntimeError('This is an exception of type RuntimeError')
        >>> st.exception(e)

        """
        # marshall() extracts the type, message and stack trace into the
        # proto that the frontend renders.
        proto = ExceptionProto()
        marshall(proto, exception)
        return self.dg._enqueue("exception", proto)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """The DeltaGenerator this mixin is mixed into."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)
def marshall(exception_proto: ExceptionProto, exception: BaseException) -> None:
    """Marshalls an Exception.proto message.

    Parameters
    ----------
    exception_proto : Exception.proto
        The Exception protobuf to fill out
    exception : BaseException
        The exception whose data we're extracting
    """
    # If this is a StreamlitAPIException, we prune all Streamlit entries
    # from the exception's stack trace.
    is_api_exception = isinstance(exception, StreamlitAPIException)
    is_deprecation_exception = isinstance(exception, StreamlitDeprecationWarning)
    is_markdown_exception = isinstance(exception, MarkdownFormattedException)
    is_uncaught_app_exception = isinstance(exception, UncaughtAppException)

    # Deprecation warnings are rendered without any stack trace at all;
    # everything else gets a trace, with Streamlit-internal frames stripped
    # for API exceptions so users only see their own code.
    stack_trace = (
        []
        if is_deprecation_exception
        else _get_stack_trace_str_list(
            exception, strip_streamlit_stack_entries=is_api_exception
        )
    )

    # Some exceptions (like UserHashError) have an alternate_name attribute so
    # we can pretend to the user that the exception is called something else.
    if getattr(exception, "alternate_name", None) is not None:
        exception_proto.type = getattr(exception, "alternate_name")
    else:
        exception_proto.type = type(exception).__name__

    exception_proto.stack_trace.extend(stack_trace)
    exception_proto.is_warning = isinstance(exception, Warning)

    try:
        if isinstance(exception, SyntaxError):
            # SyntaxErrors have additional fields (filename, text, lineno,
            # offset) that we can use for a nicely-formatted message telling
            # the user what to fix.
            exception_proto.message = _format_syntax_error_message(exception)
        else:
            exception_proto.message = str(exception).strip()
            exception_proto.message_is_markdown = is_markdown_exception
    except Exception as str_exception:
        # Sometimes the exception's __str__/__unicode__ method itself
        # raises an error.
        exception_proto.message = ""
        LOGGER.warning(
            """
Streamlit was unable to parse the data from an exception in the user's script.
This is usually due to a bug in the Exception object itself. Here is some info
about that Exception object, so you can report a bug to the original author:
Exception type:
  %(etype)s
Problem:
  %(str_exception)s
Traceback:
%(str_exception_tb)s
            """
            % {
                "etype": type(exception).__name__,
                "str_exception": str_exception,
                "str_exception_tb": "\n".join(_get_stack_trace_str_list(str_exception)),
            }
        )

    if is_uncaught_app_exception:
        # NOTE: this runs last on purpose — it overrides whatever message
        # was set above, replacing it with redacted generic text so
        # potentially sensitive details don't reach viewers; the wrapped
        # exception's real type is still surfaced.
        uae = typing.cast(UncaughtAppException, exception)
        exception_proto.message = _GENERIC_UNCAUGHT_EXCEPTION_TEXT
        type_str = str(type(uae.exc))
        exception_proto.type = type_str.replace("<class '", "").replace("'>", "")
def _format_syntax_error_message(exception: SyntaxError) -> str:
"""Returns a nicely formatted SyntaxError message that emulates
what the Python interpreter outputs, e.g.:
> File "raven.py", line 3
> st.write('Hello world!!'))
> ^
> SyntaxError: invalid syntax
"""
if exception.text:
if exception.offset is not None:
caret_indent = " " * max(exception.offset - 1, 0)
else:
caret_indent = ""
return (
'File "%(filename)s", line %(lineno)s\n'
" %(text)s\n"
" %(caret_indent)s^\n"
"%(errname)s: %(msg)s"
% {
"filename": exception.filename,
"lineno": exception.lineno,
"text": exception.text.rstrip(),
"caret_indent": caret_indent,
"errname": type(exception).__name__,
"msg": exception.msg,
}
)
# If a few edge cases, SyntaxErrors don't have all these nice fields. So we
# have a fall back here.
# Example edge case error message: encoding declaration in Unicode string
return str(exception)
def _get_stack_trace_str_list(
    exception: BaseException, strip_streamlit_stack_entries: bool = False
) -> List[str]:
    """Return the formatted stack trace of ``exception`` as a list of strings.

    Parameters
    ----------
    exception : BaseException
        The exception to extract the traceback from
    strip_streamlit_stack_entries : bool
        If True, traceback entries that live inside the Streamlit package
        are dropped. We do this for exceptions caused by incorrect use of
        Streamlit APIs, so users aren't shown internal noise (ScriptRunner,
        DeltaGenerator, etc.).

    Returns
    -------
    list
        The exception traceback as a list of strings
    """
    extracted_traceback: Optional[traceback.StackSummary] = None
    if isinstance(exception, StreamlitAPIWarning):
        extracted_traceback = exception.tacked_on_stack
    elif hasattr(exception, "__traceback__"):
        extracted_traceback = traceback.extract_tb(exception.__traceback__)

    # For uncaught app exceptions, show the trace of the wrapped original
    # exception instead of the wrapper's.
    if isinstance(exception, UncaughtAppException):
        extracted_traceback = traceback.extract_tb(exception.exc.__traceback__)

    if extracted_traceback is None:
        return [
            "Cannot extract the stack trace for this exception. "
            "Try calling exception() within the `catch` block."
        ]

    frames = (
        _get_nonstreamlit_traceback(extracted_traceback)
        if strip_streamlit_stack_entries
        else extracted_traceback
    )
    return [line.strip() for line in traceback.format_list(frames)]
def _is_in_streamlit_package(file: str) -> bool:
    """True if the given file lives inside the streamlit package directory."""
    try:
        prefix = os.path.commonprefix([os.path.realpath(file), _STREAMLIT_DIR])
        return prefix == _STREAMLIT_DIR
    except ValueError:
        # Raised (on Windows) when the two paths sit on different drives.
        return False
def _get_nonstreamlit_traceback(
    extracted_tb: traceback.StackSummary,
) -> List[traceback.FrameSummary]:
    """Return only the traceback frames that come from user code, dropping
    any frame whose file lives inside the streamlit package."""
    user_frames = []
    for frame in extracted_tb:
        if not _is_in_streamlit_package(frame.filename):
            user_frames.append(frame)
    return user_frames

View File

@ -0,0 +1,397 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from streamlit.type_util import Key, to_key
from typing import cast, overload, List, Optional, Union
from textwrap import dedent
import sys
if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal
import streamlit
from streamlit import config
from streamlit.logger import get_logger
from streamlit.proto.FileUploader_pb2 import FileUploader as FileUploaderProto
from streamlit.scriptrunner import ScriptRunContext, get_script_run_ctx
from streamlit.state import (
register_widget,
WidgetArgs,
WidgetCallback,
WidgetKwargs,
)
from .form import current_form_id
from ..proto.Common_pb2 import (
FileUploaderState as FileUploaderStateProto,
UploadedFileInfo as UploadedFileInfoProto,
)
from ..uploaded_file_manager import UploadedFile, UploadedFileRec
from .utils import check_callback_rules, check_session_state_rules
LOGGER = get_logger(__name__)
# Return type of a file_uploader widget: a single file, a list of files
# (when accept_multiple_files=True), or None when nothing was uploaded.
SomeUploadedFiles = Optional[Union[UploadedFile, List[UploadedFile]]]
class FileUploaderMixin:
# Multiple overloads are defined on `file_uploader()` below to represent
# the different return types of `file_uploader()`.
# These return types differ according to the value of the `accept_multiple_files` argument.
# There are 2 associated variables, each with 2 options.
# 1. The `accept_multiple_files` argument is set as `True`,
# or it is set as `False` or omitted, in which case the default value `False`.
# 2. The `type` argument may or may not be provided as a keyword-only argument.
# There must be 2x2=4 overloads to cover all the possible arguments,
# as these overloads must be mutually exclusive for mypy.
# 1. type is given as not a keyword-only argument
# 2. accept_multiple_files = True
    @overload
    def file_uploader(
        self,
        label: str,
        type: Optional[Union[str, List[str]]],
        accept_multiple_files: Literal[True],
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,
        disabled: bool = False,
    ) -> Optional[List[UploadedFile]]:
        """Overload: positional ``type`` with ``accept_multiple_files=True``
        returns an optional list of UploadedFile."""
        ...
# 1. type is given as not a keyword-only argument
# 2. accept_multiple_files = False or omitted
    @overload
    def file_uploader(
        self,
        label: str,
        type: Optional[Union[str, List[str]]],
        accept_multiple_files: Literal[False] = False,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,
        disabled: bool = False,
    ) -> Optional[UploadedFile]:
        """Overload: positional ``type`` with ``accept_multiple_files`` False
        or omitted returns an optional single UploadedFile."""
        ...
# The following 2 overloads represent the cases where
# the `type` argument is a keyword-only argument.
# See https://github.com/python/mypy/issues/4020#issuecomment-737600893
# for the related discussions and examples.
# 1. type is skipped or a keyword argument
# 2. accept_multiple_files = True
    @overload
    def file_uploader(
        self,
        label: str,
        *,
        accept_multiple_files: Literal[True],
        type: Optional[Union[str, List[str]]] = None,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        disabled: bool = False,
    ) -> Optional[List[UploadedFile]]:
        """Overload: keyword-only ``type`` with ``accept_multiple_files=True``
        returns an optional list of UploadedFile."""
        ...
# 1. type is skipped or a keyword argument
# 2. accept_multiple_files = False or omitted
    @overload
    def file_uploader(
        self,
        label: str,
        *,
        accept_multiple_files: Literal[False] = False,
        type: Optional[Union[str, List[str]]] = None,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        disabled: bool = False,
    ) -> Optional[UploadedFile]:
        """Overload: keyword-only ``type`` with ``accept_multiple_files``
        False or omitted returns an optional single UploadedFile."""
        ...
def file_uploader(
self,
label: str,
type: Optional[Union[str, List[str]]] = None,
accept_multiple_files: bool = False,
key: Optional[Key] = None,
help: Optional[str] = None,
on_change: Optional[WidgetCallback] = None,
args: Optional[WidgetArgs] = None,
kwargs: Optional[WidgetKwargs] = None,
*, # keyword-only arguments:
disabled: bool = False,
):
"""Display a file uploader widget.
By default, uploaded files are limited to 200MB. You can configure
this using the `server.maxUploadSize` config option. For more info
on how to set config options, see
https://docs.streamlit.io/library/advanced-features/configuration#set-configuration-options
Parameters
----------
label : str
A short label explaining to the user what this file uploader is for.
type : str or list of str or None
Array of allowed extensions. ['png', 'jpg']
The default is None, which means all extensions are allowed.
accept_multiple_files : bool
If True, allows the user to upload multiple files at the same time,
in which case the return value will be a list of files.
Default: False
key : str or int
An optional string or integer to use as the unique key for the widget.
If this is omitted, a key will be generated for the widget
based on its content. Multiple widgets of the same type may
not share the same key.
help : str
A tooltip that gets displayed next to the file uploader.
on_change : callable
An optional callback invoked when this file_uploader's value
changes.
args : tuple
An optional tuple of args to pass to the callback.
kwargs : dict
An optional dict of kwargs to pass to the callback.
disabled : bool
An optional boolean, which disables the file uploader if set to
True. The default is False. This argument can only be supplied by
keyword.
Returns
-------
None or UploadedFile or list of UploadedFile
- If accept_multiple_files is False, returns either None or
an UploadedFile object.
- If accept_multiple_files is True, returns a list with the
uploaded files as UploadedFile objects. If no files were
uploaded, returns an empty list.
The UploadedFile class is a subclass of BytesIO, and therefore
it is "file-like". This means you can pass them anywhere where
a file is expected.
Examples
--------
Insert a file uploader that accepts a single file at a time:
>>> uploaded_file = st.file_uploader("Choose a file")
>>> if uploaded_file is not None:
... # To read file as bytes:
... bytes_data = uploaded_file.getvalue()
... st.write(bytes_data)
>>>
... # To convert to a string based IO:
... stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
... st.write(stringio)
>>>
... # To read file as string:
... string_data = stringio.read()
... st.write(string_data)
>>>
... # Can be used wherever a "file-like" object is accepted:
... dataframe = pd.read_csv(uploaded_file)
... st.write(dataframe)
Insert a file uploader that accepts multiple files at a time:
>>> uploaded_files = st.file_uploader("Choose a CSV file", accept_multiple_files=True)
>>> for uploaded_file in uploaded_files:
... bytes_data = uploaded_file.read()
... st.write("filename:", uploaded_file.name)
... st.write(bytes_data)
.. output::
https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/widget.file_uploader.py
height: 375px
"""
ctx = get_script_run_ctx()
return self._file_uploader(
label=label,
type=type,
accept_multiple_files=accept_multiple_files,
key=key,
help=help,
on_change=on_change,
args=args,
kwargs=kwargs,
disabled=disabled,
ctx=ctx,
)
def _file_uploader(
    self,
    label: str,
    type: Optional[Union[str, List[str]]] = None,
    accept_multiple_files: bool = False,
    key: Optional[Key] = None,
    help: Optional[str] = None,
    on_change: Optional[WidgetCallback] = None,
    args: Optional[WidgetArgs] = None,
    kwargs: Optional[WidgetKwargs] = None,
    *,  # keyword-only arguments:
    disabled: bool = False,
    ctx: Optional[ScriptRunContext] = None,
):
    """Shared implementation behind ``st.file_uploader``.

    Builds the FileUploader proto, registers the widget (wiring its
    value to/from session state), prunes files that are no longer
    referenced by the widget, and enqueues the element.
    """
    key = to_key(key)
    check_callback_rules(self.dg, on_change)
    # A file uploader's value can't be set via st.session_state.
    check_session_state_rules(default_value=None, key=key, writes_allowed=False)

    if type:
        if isinstance(type, str):
            type = [type]

        # May need a regex or a library to validate file types are valid
        # extensions.
        # Normalize every accepted type to a ".ext" form.
        type = [
            file_type if file_type[0] == "." else f".{file_type}"
            for file_type in type
        ]

    file_uploader_proto = FileUploaderProto()
    file_uploader_proto.label = label
    file_uploader_proto.type[:] = type if type is not None else []
    file_uploader_proto.max_upload_size_mb = config.get_option(
        "server.maxUploadSize"
    )
    file_uploader_proto.multiple_files = accept_multiple_files
    file_uploader_proto.form_id = current_form_id(self.dg)
    if help is not None:
        file_uploader_proto.help = dedent(help)

    def deserialize_file_uploader(
        ui_value: Optional[FileUploaderStateProto], widget_id: str
    ) -> SomeUploadedFiles:
        # Convert the widget's proto state into the value returned to the
        # user: a list (possibly empty) when accept_multiple_files, else a
        # single UploadedFile or None.
        file_recs = self._get_file_recs(widget_id, ui_value)
        if len(file_recs) == 0:
            return_value: Optional[Union[List[UploadedFile], UploadedFile]] = (
                [] if accept_multiple_files else None
            )
        else:
            files = [UploadedFile(rec) for rec in file_recs]
            return_value = files if accept_multiple_files else files[0]
        return return_value

    def serialize_file_uploader(files: SomeUploadedFiles) -> FileUploaderStateProto:
        # Convert the user-facing value back into a proto describing the
        # currently-uploaded files.
        state_proto = FileUploaderStateProto()

        ctx = get_script_run_ctx()
        if ctx is None:
            return state_proto

        # ctx.uploaded_file_mgr._file_id_counter stores the id to use for
        # the *next* uploaded file, so the current highest file id is the
        # counter minus 1.
        state_proto.max_file_id = ctx.uploaded_file_mgr._file_id_counter - 1

        if not files:
            return state_proto
        elif not isinstance(files, list):
            files = [files]

        for f in files:
            file_info: UploadedFileInfoProto = state_proto.uploaded_file_info.add()
            file_info.id = f.id
            file_info.name = f.name
            file_info.size = f.size

        return state_proto

    # FileUploader's widget value is a list of file IDs
    # representing the current set of files that this uploader should
    # know about.
    widget_value, _ = register_widget(
        "file_uploader",
        file_uploader_proto,
        user_key=key,
        on_change_handler=on_change,
        args=args,
        kwargs=kwargs,
        deserializer=deserialize_file_uploader,
        serializer=serialize_file_uploader,
        ctx=ctx,
    )

    # This needs to be done after register_widget because we don't want
    # the following proto fields to affect a widget's ID.
    file_uploader_proto.disabled = disabled

    file_uploader_state = serialize_file_uploader(widget_value)
    uploaded_file_info = file_uploader_state.uploaded_file_info
    if ctx is not None and len(uploaded_file_info) != 0:
        # Drop files the uploader no longer references so they don't
        # accumulate in the UploadedFileManager.
        newest_file_id = file_uploader_state.max_file_id
        active_file_ids = [f.id for f in uploaded_file_info]

        ctx.uploaded_file_mgr.remove_orphaned_files(
            session_id=ctx.session_id,
            widget_id=file_uploader_proto.id,
            newest_file_id=newest_file_id,
            active_file_ids=active_file_ids,
        )

    self.dg._enqueue("file_uploader", file_uploader_proto)
    return cast(SomeUploadedFiles, widget_value)
@staticmethod
def _get_file_recs(
    widget_id: str, widget_value: Optional[FileUploaderStateProto]
) -> List[UploadedFileRec]:
    """Return the UploadedFileRecs for the widget's currently-active files.

    Returns an empty list when there is no widget state, no script-run
    context, or no uploaded files.
    """
    if widget_value is None:
        return []

    ctx = get_script_run_ctx()
    if ctx is None:
        return []

    file_info = widget_value.uploaded_file_info
    if not file_info:
        return []

    # Fetch only the files whose IDs the widget still references.
    return ctx.uploaded_file_mgr.get_files(
        session_id=ctx.session_id,
        widget_id=widget_id,
        file_ids=[f.id for f in file_info],
    )
@property
def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
    """Get our DeltaGenerator."""
    # Mixin classes are combined into DeltaGenerator, so `self` is one;
    # the cast just informs the type checker.
    return cast("streamlit.delta_generator.DeltaGenerator", self)

View File

@ -0,0 +1,274 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import textwrap
from typing import cast, Optional, NamedTuple
import streamlit
from streamlit.errors import StreamlitAPIException
from streamlit.proto import Block_pb2
from streamlit.scriptrunner import ScriptRunContext, get_script_run_ctx
class FormData(NamedTuple):
    """Form data stored on a DeltaGenerator.

    Attached to a block's DeltaGenerator when that block is a form,
    so descendants can discover the enclosing form.
    """

    # The form's unique ID.
    form_id: str
def _current_form(
    this_dg: "streamlit.delta_generator.DeltaGenerator",
) -> Optional[FormData]:
    """Find the FormData for the given DeltaGenerator.

    Forms are blocks, and can have other blocks nested inside them.
    To find the current form, we walk up the dg_stack until we find
    a DeltaGenerator that has FormData.
    """
    if not streamlit._is_running_with_streamlit:
        return None

    if this_dg._form_data is not None:
        return this_dg._form_data

    if this_dg == this_dg._main_dg:
        # Created via an `st.foo` call: walk up the dg_stack to see if
        # we're nested inside a `with st.form` statement.
        ctx = get_script_run_ctx()
        if ctx is not None:
            for ancestor in reversed(ctx.dg_stack):
                if ancestor._form_data is not None:
                    return ancestor._form_data
        return None

    # Created via a `dg.foo` call: our parent's form data (if any) is ours.
    parent = this_dg._parent
    if parent is not None:
        return parent._form_data
    return None
def current_form_id(dg: "streamlit.delta_generator.DeltaGenerator") -> str:
    """Return the form_id for the current form, or the empty string if we're
    not inside an `st.form` block.

    (We return the empty string, instead of None, because this value is
    assigned to protobuf message fields, and None is not valid.)
    """
    form_data = _current_form(dg)
    return "" if form_data is None else form_data.form_id
def is_in_form(dg: "streamlit.delta_generator.DeltaGenerator") -> bool:
    """True if the DeltaGenerator is inside an st.form block."""
    # current_form_id returns "" when there is no enclosing form.
    return current_form_id(dg) != ""
def _build_duplicate_form_message(user_key: Optional[str] = None) -> str:
if user_key is not None:
message = textwrap.dedent(
f"""
There are multiple identical forms with `key='{user_key}'`.
To fix this, please make sure that the `key` argument is unique for
each `st.form` you create.
"""
)
else:
message = textwrap.dedent(
"""
There are multiple identical forms with the same generated key.
When a form is created, it's assigned an internal key based on
its structure. Multiple forms with an identical structure will
result in the same internal key, which causes this error.
To fix this error, please pass a unique `key` argument to
`st.form`.
"""
)
return message.strip("\n")
class FormMixin:
    def form(self, key: str, clear_on_submit: bool = False):
        """Create a form that batches elements together with a "Submit" button.

        A form is a container that visually groups other elements and
        widgets together, and contains a Submit button. When the form's
        Submit button is pressed, all widget values inside the form will be
        sent to Streamlit in a batch.

        To add elements to a form object, you can use "with" notation
        (preferred) or just call methods directly on the form. See
        examples below.

        Forms have a few constraints:

        * Every form must contain a ``st.form_submit_button``.
        * ``st.button`` and ``st.download_button`` cannot be added to a form.
        * Forms can appear anywhere in your app (sidebar, columns, etc),
          but they cannot be embedded inside other forms.

        For more information about forms, check out our
        `blog post <https://blog.streamlit.io/introducing-submit-button-and-forms/>`_.

        Parameters
        ----------
        key : str
            A string that identifies the form. Each form must have its own
            key. (This key is not displayed to the user in the interface.)
        clear_on_submit : bool
            If True, all widgets inside the form will be reset to their default
            values after the user presses the Submit button. Defaults to False.
            (Note that Custom Components are unaffected by this flag, and
            will not be reset to their defaults on form submission.)

        Examples
        --------
        Inserting elements using "with" notation:

        >>> with st.form("my_form"):
        ...    st.write("Inside the form")
        ...    slider_val = st.slider("Form slider")
        ...    checkbox_val = st.checkbox("Form checkbox")
        ...
        ...    # Every form must have a submit button.
        ...    submitted = st.form_submit_button("Submit")
        ...    if submitted:
        ...        st.write("slider", slider_val, "checkbox", checkbox_val)
        ...
        >>> st.write("Outside the form")

        Inserting elements out of order:

        >>> form = st.form("my_form")
        >>> form.slider("Inside the form")
        >>> st.slider("Outside the form")
        >>>
        >>> # Now add a submit button to the form:
        >>> form.form_submit_button("Submit")
        """
        from .utils import check_session_state_rules

        # Forms may not be nested inside other forms.
        if is_in_form(self.dg):
            raise StreamlitAPIException("Forms cannot be nested in other forms.")

        # A form's value can't be set via st.session_state.
        check_session_state_rules(default_value=None, key=key, writes_allowed=False)

        # A form is uniquely identified by its key.
        form_id = key

        ctx = get_script_run_ctx()
        if ctx is not None:
            # Reject two forms with the same ID in a single script run.
            new_form_id = form_id not in ctx.form_ids_this_run
            if new_form_id:
                ctx.form_ids_this_run.add(form_id)
            else:
                raise StreamlitAPIException(_build_duplicate_form_message(key))

        block_proto = Block_pb2.Block()
        block_proto.form.form_id = form_id
        block_proto.form.clear_on_submit = clear_on_submit
        block_dg = self.dg._block(block_proto)

        # Attach the form's button info to the newly-created block's
        # DeltaGenerator.
        block_dg._form_data = FormData(form_id)
        return block_dg

    def form_submit_button(
        self,
        label: str = "Submit",
        help: Optional[str] = None,
        on_click=None,
        args=None,
        kwargs=None,
    ) -> bool:
        """Display a form submit button.

        When this button is clicked, all widget values inside the form will be
        sent to Streamlit in a batch.

        Every form must have a form_submit_button. A form_submit_button
        cannot exist outside a form.

        For more information about forms, check out our
        `blog post <https://blog.streamlit.io/introducing-submit-button-and-forms/>`_.

        Parameters
        ----------
        label : str
            A short label explaining to the user what this button is for.
            Defaults to "Submit".
        help : str or None
            A tooltip that gets displayed when the button is hovered over.
            Defaults to None.
        on_click : callable
            An optional callback invoked when this button is clicked.
        args : tuple
            An optional tuple of args to pass to the callback.
        kwargs : dict
            An optional dict of kwargs to pass to the callback.

        Returns
        -------
        bool
            True if the button was clicked.
        """
        ctx = get_script_run_ctx()
        return self._form_submit_button(
            label=label,
            help=help,
            on_click=on_click,
            args=args,
            kwargs=kwargs,
            ctx=ctx,
        )

    def _form_submit_button(
        self,
        label: str = "Submit",
        help: Optional[str] = None,
        on_click=None,
        args=None,
        kwargs=None,
        ctx: Optional[ScriptRunContext] = None,
    ) -> bool:
        """Internal implementation of form_submit_button."""
        form_id = current_form_id(self.dg)
        # NOTE(review): the key is derived from form_id + label, so two
        # submit buttons with identical labels in the same form would share
        # a key — confirm that's intended.
        submit_button_key = f"FormSubmitter:{form_id}-{label}"
        return self.dg._button(
            label=label,
            key=submit_button_key,
            help=help,
            is_form_submitter=True,
            on_click=on_click,
            args=args,
            kwargs=kwargs,
            ctx=ctx,
        )

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)

View File

@ -0,0 +1,124 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Streamlit support for GraphViz charts."""
import hashlib
from typing import cast
import streamlit
from streamlit import type_util
from streamlit.errors import StreamlitAPIException
from streamlit.logger import get_logger
from streamlit.proto.GraphVizChart_pb2 import GraphVizChart as GraphVizChartProto
LOGGER = get_logger(__name__)
class GraphvizMixin:
    def graphviz_chart(self, figure_or_dot, use_container_width=False):
        """Display a graph using the dagre-d3 library.

        Parameters
        ----------
        figure_or_dot : graphviz.dot.Graph, graphviz.dot.Digraph, str
            The Graphlib graph object or dot string to display

        use_container_width : bool
            If True, set the chart width to the column width. This takes
            precedence over the figure's native `width` value.

        Example
        -------
        >>> import streamlit as st
        >>> import graphviz as graphviz
        >>>
        >>> # Create a graphlib graph object
        >>> graph = graphviz.Digraph()
        >>> graph.edge('run', 'intr')
        >>> graph.edge('intr', 'runbl')
        >>> graph.edge('runbl', 'run')
        >>> graph.edge('run', 'kernel')
        >>> graph.edge('kernel', 'zombie')
        >>> graph.edge('kernel', 'sleep')
        >>> graph.edge('kernel', 'runmem')
        >>> graph.edge('sleep', 'swap')
        >>> graph.edge('swap', 'runswap')
        >>> graph.edge('runswap', 'new')
        >>> graph.edge('runswap', 'runmem')
        >>> graph.edge('new', 'runmem')
        >>> graph.edge('sleep', 'runmem')
        >>>
        >>> st.graphviz_chart(graph)

        Or you can render the chart from the graph using GraphViz's Dot
        language:

        >>> st.graphviz_chart('''
            digraph {
                run -> intr
                intr -> runbl
                runbl -> run
                run -> kernel
                kernel -> zombie
                kernel -> sleep
                kernel -> runmem
                sleep -> swap
                swap -> runswap
                runswap -> new
                runswap -> runmem
                new -> runmem
                sleep -> runmem
            }
        ''')

        .. output::
           https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/charts.graphviz_chart.py
           height: 600px

        """
        # Generate element ID from delta path. The md5 digest gives the
        # frontend a fixed-length identifier for this element.
        delta_path = self.dg._get_delta_path_str()
        element_id = hashlib.md5(delta_path.encode()).hexdigest()

        graphviz_chart_proto = GraphVizChartProto()
        marshall(graphviz_chart_proto, figure_or_dot, use_container_width, element_id)
        return self.dg._enqueue("graphviz_chart", graphviz_chart_proto)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)
def marshall(proto, figure_or_dot, use_container_width, element_id):
    """Fill a GraphVizChart proto from a graphviz object or DOT string.

    See DeltaGenerator.graphviz_chart for docs.
    """
    if isinstance(figure_or_dot, str):
        # Already DOT source.
        dot = figure_or_dot
    elif type_util.is_graphviz_chart(figure_or_dot):
        # A graphviz Graph/Digraph: pull out its DOT source.
        dot = figure_or_dot.source
    else:
        raise StreamlitAPIException(
            "Unhandled type for graphviz chart: %s" % type(figure_or_dot)
        )

    proto.spec = dot
    proto.use_container_width = use_container_width
    proto.element_id = element_id

View File

@ -0,0 +1,142 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import cast
import streamlit
from streamlit.proto.IFrame_pb2 import IFrame as IFrameProto
class IframeMixin:
    def _iframe(
        self,
        src,
        width=None,
        height=None,
        scrolling=False,
    ):
        """Embed a remote page in an iframe.

        Parameters
        ----------
        src : str
            The URL of the page to embed.
        width : int
            The width of the frame in CSS pixels. Defaults to the app's
            default element width.
        height : int
            The height of the frame in CSS pixels. Defaults to 150.
        scrolling : bool
            If True, show a scrollbar when the content is larger than the iframe.
            Otherwise, do not show a scrollbar. Defaults to False.
        """
        proto = IFrameProto()
        marshall(proto, src=src, width=width, height=height, scrolling=scrolling)
        return self.dg._enqueue("iframe", proto)

    def _html(
        self,
        html,
        width=None,
        height=None,
        scrolling=False,
    ):
        """Embed an HTML string in an iframe.

        Parameters
        ----------
        html : str
            The HTML string to embed in the iframe.
        width : int
            The width of the frame in CSS pixels. Defaults to the app's
            default element width.
        height : int
            The height of the frame in CSS pixels. Defaults to 150.
        scrolling : bool
            If True, show a scrollbar when the content is larger than the iframe.
            Otherwise, do not show a scrollbar. Defaults to False.
        """
        proto = IFrameProto()
        marshall(proto, srcdoc=html, width=width, height=height, scrolling=scrolling)
        return self.dg._enqueue("iframe", proto)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)
def marshall(
    proto,
    src: Optional[str] = None,
    srcdoc: Optional[str] = None,
    width: Optional[int] = None,
    height: Optional[int] = None,
    scrolling: bool = False,
) -> None:
    """Marshalls data into an IFrame proto.

    These parameters correspond directly to <iframe> attributes, which are
    described in more detail at
    https://developer.mozilla.org/en-US/docs/Web/HTML/Element/iframe.

    Parameters
    ----------
    proto : IFrame protobuf
        The protobuf object to marshall data into.
    src : str
        The URL of the page to embed.
    srcdoc : str
        Inline HTML to embed. Overrides src.
    width : int
        The width of the frame in CSS pixels. Defaults to the app's
        default element width.
    height : int
        The height of the frame in CSS pixels. Defaults to 150.
    scrolling : bool
        If true, show a scrollbar when the content is larger than the iframe.
        Otherwise, never show a scrollbar.
    """
    if src is not None:
        proto.src = src

    if srcdoc is not None:
        proto.srcdoc = srcdoc

    if width is not None:
        proto.width = width
        proto.has_width = True

    # Height always gets a value; 150 matches the browser's default iframe height.
    proto.height = height if height is not None else 150

    proto.scrolling = scrolling

View File

@ -0,0 +1,374 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image marshalling."""
import imghdr
import io
import mimetypes
from typing import cast
from urllib.parse import urlparse
import re
import numpy as np
from PIL import Image, ImageFile
import streamlit
from streamlit.errors import StreamlitAPIException
from streamlit.logger import get_logger
from streamlit.in_memory_file_manager import in_memory_file_manager
from streamlit.proto.Image_pb2 import ImageList as ImageListProto
LOGGER = get_logger(__name__)
# This constant is related to the frontend maximum content width specified
# in App.jsx main container.
# 730 is the max width of element-container in the frontend, and 2x is for
# high-DPI (retina) displays.
MAXIMUM_CONTENT_WIDTH = 2 * 730
class ImageMixin:
    def image(
        self,
        image,
        caption=None,
        width=None,
        use_column_width=None,
        clamp=False,
        channels="RGB",
        output_format="auto",
    ):
        """Display an image or list of images.

        Parameters
        ----------
        image : numpy.ndarray, [numpy.ndarray], BytesIO, str, or [str]
            Monochrome image of shape (w,h) or (w,h,1)
            OR a color image of shape (w,h,3)
            OR an RGBA image of shape (w,h,4)
            OR a URL to fetch the image from
            OR a path of a local image file
            OR an SVG XML string like `<svg xmlns=...</svg>`
            OR a list of one of the above, to display multiple images.
        caption : str or list of str
            Image caption. If displaying multiple images, caption should be a
            list of captions (one for each image).
        width : int or None
            Image width. None means use the image width,
            but do not exceed the width of the column.
            Should be set for SVG images, as they have no default image width.
        use_column_width : 'auto' or 'always' or 'never' or bool
            If 'auto', set the image's width to its natural size,
            but do not exceed the width of the column.
            If 'always' or True, set the image's width to the column width.
            If 'never' or False, set the image's width to its natural size.
            Note: if set, `use_column_width` takes precedence over the `width` parameter.
        clamp : bool
            Clamp image pixel values to a valid range ([0-255] per channel).
            This is only meaningful for byte array images; the parameter is
            ignored for image URLs. If this is not set, and an image has an
            out-of-range value, an error will be thrown.
        channels : 'RGB' or 'BGR'
            If image is an nd.array, this parameter denotes the format used to
            represent color information. Defaults to 'RGB', meaning
            `image[:, :, 0]` is the red channel, `image[:, :, 1]` is green, and
            `image[:, :, 2]` is blue. For images coming from libraries like
            OpenCV you should set this to 'BGR', instead.
        output_format : 'JPEG', 'PNG', or 'auto'
            This parameter specifies the format to use when transferring the
            image data. Photos should use the JPEG format for lossy compression
            while diagrams should use the PNG format for lossless compression.
            Defaults to 'auto' which identifies the compression type based
            on the type and format of the image argument.

        Example
        -------
        >>> from PIL import Image
        >>> image = Image.open('sunrise.jpg')
        >>>
        >>> st.image(image, caption='Sunrise by the mountains')

        .. output::
           https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/charts.image.py
           height: 710px

        """
        # Negative widths are sentinels interpreted downstream:
        #   -3: "auto"  — natural size, capped at column width
        #   -2: "always"/True — stretch to column width
        #   -1: natural size
        if use_column_width == "auto" or (use_column_width is None and width is None):
            width = -3
        elif use_column_width == "always" or use_column_width == True:
            width = -2
        elif width is None:
            width = -1
        elif width <= 0:
            raise StreamlitAPIException("Image width must be positive.")

        image_list_proto = ImageListProto()
        marshall_images(
            self.dg._get_delta_path_str(),
            image,
            caption,
            width,
            image_list_proto,
            clamp,
            channels,
            output_format,
        )
        return self.dg._enqueue("imgs", image_list_proto)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)
def _image_may_have_alpha_channel(image):
if image.mode in ("RGBA", "LA", "P"):
return True
else:
return False
def _format_from_image_type(image, output_format):
output_format = output_format.upper()
if output_format == "JPEG" or output_format == "PNG":
return output_format
# We are forgiving on the spelling of JPEG
if output_format == "JPG":
return "JPEG"
if _image_may_have_alpha_channel(image):
return "PNG"
return "JPEG"
def _PIL_to_bytes(image, format="JPEG", quality=100):
tmp = io.BytesIO()
# User must have specified JPEG, so we must convert it
if format == "JPEG" and _image_may_have_alpha_channel(image):
image = image.convert("RGB")
image.save(tmp, format=format, quality=quality)
return tmp.getvalue()
def _BytesIO_to_bytes(data):
    """Return the full contents of a BytesIO buffer.

    Also rewinds the buffer so subsequent readers start at the beginning.
    """
    data.seek(0)
    return data.getvalue()
def _np_array_to_bytes(array, output_format="JPEG"):
    """Encode a numpy image array to bytes by round-tripping through PIL."""
    pil_img = Image.fromarray(array.astype(np.uint8))
    fmt = _format_from_image_type(pil_img, output_format)
    return _PIL_to_bytes(pil_img, fmt)
def _4d_to_list_3d(array):
return [array[i, :, :, :] for i in range(0, array.shape[0])]
def _verify_np_shape(array):
if len(array.shape) not in (2, 3):
raise StreamlitAPIException("Numpy shape has to be of length 2 or 3.")
if len(array.shape) == 3 and array.shape[-1] not in (1, 3, 4):
raise StreamlitAPIException(
"Channel can only be 1, 3, or 4 got %d. Shape is %s"
% (array.shape[-1], str(array.shape))
)
# If there's only one channel, convert is to x, y
if len(array.shape) == 3 and array.shape[-1] == 1:
array = array[:, :, 0]
return array
def _normalize_to_bytes(data, width, output_format):
    """Re-encode image bytes for transfer, resizing if wider than *width*.

    Returns a (bytes, mimetype) tuple.
    """
    image = Image.open(io.BytesIO(data))
    actual_width, actual_height = image.size
    format = _format_from_image_type(image, output_format)
    if output_format.lower() == "auto":
        # Sniff the real format of the input bytes to derive the mimetype.
        # NOTE(review): imghdr is deprecated in recent Python versions —
        # consider using PIL's `image.format` instead; confirm before changing.
        ext = imghdr.what(None, data)
        mimetype = mimetypes.guess_type("image.%s" % ext)[0]
        # if no other options, attempt to convert
        if mimetype is None:
            mimetype = "image/" + format.lower()
    else:
        mimetype = "image/" + format.lower()

    # A negative width is a sentinel (see ImageMixin.image); cap oversized
    # images at the frontend's maximum content width.
    if width < 0 and actual_width > MAXIMUM_CONTENT_WIDTH:
        width = MAXIMUM_CONTENT_WIDTH

    if width > 0 and actual_width > width:
        # Downscale, preserving the aspect ratio, and re-encode.
        new_height = int(1.0 * actual_height * width / actual_width)
        image = image.resize((width, new_height), resample=Image.BILINEAR)
        data = _PIL_to_bytes(image, format=format, quality=90)
        mimetype = "image/" + format.lower()

    return data, mimetype
def _clip_image(image, clamp):
data = image
if issubclass(image.dtype.type, np.floating):
if clamp:
data = np.clip(image, 0, 1.0)
else:
if np.amin(image) < 0.0 or np.amax(image) > 1.0:
raise RuntimeError("Data is outside [0.0, 1.0] and clamp is not set.")
data = data * 255
else:
if clamp:
data = np.clip(image, 0, 255)
else:
if np.amin(image) < 0 or np.amax(image) > 255:
raise RuntimeError("Data is outside [0, 255] and clamp is not set.")
return data
def image_to_url(
    image, width, clamp, channels, output_format, image_id, allow_emoji=False
):
    """Convert any supported image input into a URL the frontend can load.

    Accepts PIL images, BytesIO, numpy arrays, URLs, file paths, raw bytes
    (and, with allow_emoji, arbitrary short strings). Encoded images are
    stored in the in-memory file manager and their URL is returned; URL
    and emoji inputs are returned unchanged.
    """
    # PIL Images
    if isinstance(image, ImageFile.ImageFile) or isinstance(image, Image.Image):
        format = _format_from_image_type(image, output_format)
        data = _PIL_to_bytes(image, format)

    # BytesIO
    # Note: This doesn't support SVG. We could convert to png (cairosvg.svg2png)
    # or just decode BytesIO to string and handle that way.
    elif isinstance(image, io.BytesIO):
        data = _BytesIO_to_bytes(image)

    # Numpy Arrays (ie opencv)
    elif type(image) is np.ndarray:
        data = _verify_np_shape(image)
        data = _clip_image(data, clamp)

        if channels == "BGR":
            if len(data.shape) == 3:
                # Reorder BGR -> RGB.
                data = data[:, :, [2, 1, 0]]
            else:
                raise StreamlitAPIException(
                    'When using `channels="BGR"`, the input image should '
                    "have exactly 3 color channels"
                )

        data = _np_array_to_bytes(data, output_format=output_format)

    # Strings
    elif isinstance(image, str):
        # If it's a url, then set the protobuf and continue
        try:
            p = urlparse(image)
            if p.scheme:
                return image
        except UnicodeDecodeError:
            pass

        # Finally, see if it's a file.
        try:
            with open(image, "rb") as f:
                data = f.read()
        except:
            if allow_emoji:
                # This might be an emoji string, so just pass it to the frontend
                return image
            else:
                # Allow OS filesystem errors to raise
                raise

    # Assume input in bytes.
    else:
        data = image

    (data, mimetype) = _normalize_to_bytes(data, width, output_format)
    this_file = in_memory_file_manager.add(data, mimetype, image_id)
    return this_file.url
def marshall_images(
    coordinates,
    image,
    caption,
    width,
    proto_imgs,
    clamp,
    channels="RGB",
    output_format="auto",
):
    """Fill an ImageList proto from one image (or a list of images) + captions.

    `coordinates` identifies the element's position in the app and is used
    to build stable per-image IDs for the InMemoryFileManager.
    """
    channels = channels.upper()

    # Turn single image and caption into one element list.
    if type(image) is list:
        images = image
    else:
        if type(image) == np.ndarray and len(image.shape) == 4:
            # A 4-D array is treated as a batch of 3-D images.
            images = _4d_to_list_3d(image)
        else:
            images = [image]

    if type(caption) is list:
        captions = caption
    else:
        if isinstance(caption, str):
            captions = [caption]
        # You can pass in a 1-D Numpy array as captions.
        elif type(caption) == np.ndarray and len(caption.shape) == 1:
            captions = caption.tolist()
        # If there are no captions then make the captions list the same size
        # as the images list.
        elif caption is None:
            captions = [None] * len(images)
        else:
            captions = [str(caption)]

    assert type(captions) == list, "If image is a list then caption should be as well"
    assert len(captions) == len(images), "Cannot pair %d captions with %d images." % (
        len(captions),
        len(images),
    )

    proto_imgs.width = width
    # Each image in an image list needs to be kept track of at its own coordinates.
    for coord_suffix, (image, caption) in enumerate(zip(images, captions)):
        proto_img = proto_imgs.imgs.add()
        if caption is not None:
            proto_img.caption = str(caption)

        # We use the index of the image in the input image list to identify this image inside
        # InMemoryFileManager. For this, we just add the index to the image's "coordinates".
        image_id = "%s-%i" % (coordinates, coord_suffix)

        is_svg = False
        if isinstance(image, str):
            # Unpack local SVG image file to an SVG string
            if image.endswith(".svg") and not image.startswith(("http://", "https://")):
                with open(image) as textfile:
                    image = textfile.read()

            # Following regex allows svg image files to start either via a "<?xml...>" tag
            # eventually followed by a "<svg...>" tag or directly starting with a "<svg>" tag
            if re.search(r"(^\s?(<\?xml[\s\S]*<svg )|^\s?<svg )", image):
                # SVGs are sent inline as a data URI rather than through the
                # file manager.
                proto_img.markup = f"data:image/svg+xml,{image}"
                is_svg = True

        if not is_svg:
            proto_img.url = image_to_url(
                image, width, clamp, channels, output_format, image_id
            )

View File

@ -0,0 +1,85 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import cast
import streamlit
from streamlit.proto.Json_pb2 import Json as JsonProto
from streamlit.state import SessionStateProxy
class JsonMixin:
    def json(
        self,
        body,
        *,  # keyword-only arguments:
        expanded=True,
    ):
        """Display object or string as a pretty-printed JSON string.

        Parameters
        ----------
        body : Object or str
            The object to print as JSON. All referenced objects should be
            serializable to JSON as well. If object is a string, we assume it
            contains serialized JSON.
        expanded : bool
            An optional boolean that allows the user to set whether the initial
            state of this json element should be expanded. Defaults to True.
            This argument can only be supplied by keyword.

        Example
        -------
        >>> st.json({
        ...     'foo': 'bar',
        ...     'baz': 'boz',
        ...     'stuff': [
        ...         'stuff 1',
        ...         'stuff 2',
        ...         'stuff 3',
        ...         'stuff 5',
        ...     ],
        ... })

        .. output::
           https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/data.json.py
           height: 385px

        """
        import streamlit as st

        # Session state must be converted to a plain dict before serializing.
        if isinstance(body, SessionStateProxy):
            body = body.to_dict()

        if not isinstance(body, str):
            try:
                # `default=repr` stringifies values json can't handle natively.
                body = json.dumps(body, default=repr)
            except TypeError as err:
                st.warning(
                    "Warning: this data structure was not fully serializable as "
                    "JSON due to one or more unexpected keys. (Error was: %s)" % err
                )
                # Retry, silently skipping the keys that caused the failure.
                body = json.dumps(body, skipkeys=True, default=repr)

        json_proto = JsonProto()
        json_proto.body = body
        json_proto.expanded = expanded
        return self.dg._enqueue("json", json_proto)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)

View File

@ -0,0 +1,252 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast, Sequence, Union
from streamlit.beta_util import function_beta_warning
from streamlit.errors import StreamlitAPIException
from streamlit.proto.Block_pb2 import Block as BlockProto
import streamlit
# Column-layout spec accepted by st.columns: either a column count (int) or
# a sequence of positive per-column relative weights (ints or floats).
SpecType = Union[int, Sequence[Union[int, float]]]
class LayoutsMixin:
    """Mixin providing the layout primitives: container, columns, expander.

    Mixed into DeltaGenerator at runtime, so ``self.dg`` is always a
    DeltaGenerator and each method returns a nested DeltaGenerator block.
    """

    def container(self):
        """Insert a multi-element container.

        Inserts an invisible container into your app that can be used to hold
        multiple elements. This allows you to, for example, insert multiple
        elements into your app out of order.

        To add elements to the returned container, you can use "with" notation
        (preferred) or just call methods directly on the returned object. See
        examples below.

        Examples
        --------
        Inserting elements using "with" notation:

        >>> with st.container():
        ...     st.write("This is inside the container")
        ...
        ...     # You can call any Streamlit command, including custom components:
        ...     st.bar_chart(np.random.randn(50, 3))
        ...
        >>> st.write("This is outside the container")

        .. output ::
            https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/layout.container1.py
            height: 520px

        Inserting elements out of order:

        >>> container = st.container()
        >>> container.write("This is inside the container")
        >>> st.write("This is outside the container")
        >>>
        >>> # Now insert some more in the container
        >>> container.write("This is inside too")

        .. output ::
            https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/layout.container2.py
            height: 480px
        """
        # A plain block with no special proto configuration.
        return self.dg._block()

    # TODO: Enforce that columns are not nested or in Sidebar
    def columns(self, spec: SpecType):
        """Insert containers laid out as side-by-side columns.

        Inserts a number of multi-element containers laid out side-by-side and
        returns a list of container objects.

        To add elements to the returned containers, you can use "with" notation
        (preferred) or just call methods directly on the returned object. See
        examples below.

        .. warning::
            Currently, you may not put columns inside another column.

        Parameters
        ----------
        spec : int or list of numbers
            If an int
                Specifies the number of columns to insert, and all columns
                have equal width.

            If a list of numbers
                Creates a column for each number, and each
                column's width is proportional to the number provided. Numbers can
                be ints or floats, but they must be positive.

                For example, `st.columns([3, 1, 2])` creates 3 columns where
                the first column is 3 times the width of the second, and the last
                column is 2 times that width.

        Returns
        -------
        list of containers
            A list of container objects.

        Examples
        --------
        You can use `with` notation to insert any element into a column:

        >>> col1, col2, col3 = st.columns(3)
        >>>
        >>> with col1:
        ...     st.header("A cat")
        ...     st.image("https://static.streamlit.io/examples/cat.jpg")
        ...
        >>> with col2:
        ...     st.header("A dog")
        ...     st.image("https://static.streamlit.io/examples/dog.jpg")
        ...
        >>> with col3:
        ...     st.header("An owl")
        ...     st.image("https://static.streamlit.io/examples/owl.jpg")

        .. output ::
            https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/layout.columns1.py
            height: 620px

        Or you can just call methods directly in the returned objects:

        >>> col1, col2 = st.columns([3, 1])
        >>> data = np.random.randn(10, 1)
        >>>
        >>> col1.subheader("A wide column with a chart")
        >>> col1.line_chart(data)
        >>>
        >>> col2.subheader("A narrow column with the data")
        >>> col2.write(data)

        .. output ::
            https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/layout.columns2.py
            height: 550px
        """
        weights = spec
        weights_exception = StreamlitAPIException(
            "The input argument to st.columns must be either a "
            + "positive integer or a list of positive numeric weights. "
            + "See [documentation](https://docs.streamlit.io/library/api-reference/layout/st.columns) "
            + "for more information."
        )

        if isinstance(weights, int):
            # If the user provided a single number, expand into equal weights.
            # E.g. (1,) * 3 => (1, 1, 1)
            # NOTE: A negative/zero spec will expand into an empty tuple.
            weights = (1,) * weights

        # Reject empty specs and non-positive weights in one place; the empty
        # tuple produced above by a non-positive int lands here too.
        if len(weights) == 0 or any(weight <= 0 for weight in weights):
            raise weights_exception

        def column_proto(normalized_weight):
            # One child block per column, carrying its weight as a fraction
            # of the total so the frontend can size it.
            col_proto = BlockProto()
            col_proto.column.weight = normalized_weight
            col_proto.allow_empty = True
            return col_proto

        # Parent block laid out horizontally; each column is a child block.
        block_proto = BlockProto()
        block_proto.horizontal.SetInParent()

        row = self.dg._block(block_proto)
        total_weight = sum(weights)
        return [row._block(column_proto(w / total_weight)) for w in weights]

    def expander(self, label: str, expanded: bool = False):
        """Insert a multi-element container that can be expanded/collapsed.

        Inserts a container into your app that can be used to hold multiple elements
        and can be expanded or collapsed by the user. When collapsed, all that is
        visible is the provided label.

        To add elements to the returned container, you can use "with" notation
        (preferred) or just call methods directly on the returned object. See
        examples below.

        .. warning::
            Currently, you may not put expanders inside another expander.

        Parameters
        ----------
        label : str
            A string to use as the header for the expander.
        expanded : bool
            If True, initializes the expander in "expanded" state. Defaults to
            False (collapsed).

        Examples
        --------
        You can use `with` notation to insert any element into an expander

        >>> st.bar_chart({"data": [1, 5, 2, 6, 2, 1]})
        >>>
        >>> with st.expander("See explanation"):
        ...     st.write(\"\"\"
        ...         The chart above shows some numbers I picked for you.
        ...         I rolled actual dice for these, so they're *guaranteed* to
        ...         be random.
        ...     \"\"\")
        ...     st.image("https://static.streamlit.io/examples/dice.jpg")

        .. output ::
            https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/layout.expander.py
            height: 750px

        Or you can just call methods directly in the returned objects:

        >>> st.bar_chart({"data": [1, 5, 2, 6, 2, 1]})
        >>>
        >>> expander = st.expander("See explanation")
        >>> expander.write(\"\"\"
        ...     The chart above shows some numbers I picked for you.
        ...     I rolled actual dice for these, so they're *guaranteed* to
        ...     be random.
        ... \"\"\")
        >>> expander.image("https://static.streamlit.io/examples/dice.jpg")

        .. output ::
            https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/layout.expander.py
            height: 750px
        """
        if label is None:
            raise StreamlitAPIException("A label is required for an expander")

        expandable_proto = BlockProto.Expandable()
        expandable_proto.expanded = expanded
        expandable_proto.label = label

        block_proto = BlockProto()
        block_proto.allow_empty = True
        block_proto.expandable.CopyFrom(expandable_proto)

        return self.dg._block(block_proto=block_proto)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)

    # Deprecated beta_ functions: emit a deprecation warning, then delegate.
    beta_container = function_beta_warning(container, "2021-11-02")
    beta_expander = function_beta_warning(expander, "2021-11-02")
    beta_columns = function_beta_warning(columns, "2021-11-02")

View File

@ -0,0 +1,350 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper around Altair.
Altair is a Python visualization library based on Vega-Lite,
a nice JSON schema for expressing graphs and charts."""
from datetime import date
from typing import cast
import streamlit
from streamlit import errors, type_util
from streamlit.proto.VegaLiteChart_pb2 import VegaLiteChart as VegaLiteChartProto
import streamlit.elements.legacy_vega_lite as vega_lite
import altair as alt
import pandas as pd
import pyarrow as pa
from .utils import last_index_for_melted_dataframes
class LegacyAltairMixin:
    """Mixin providing the legacy (pre-Arrow) built-in chart commands."""

    def _legacy_line_chart(
        self, data=None, width=0, height=0, use_container_width=True
    ):
        """Display a line chart.

        This is syntax-sugar around st._legacy_altair_chart. The main difference
        is this command uses the data's own column and indices to figure out
        the chart's spec. As a result this is easier to use for many "just plot
        this" scenarios, while being less customizable.

        If st._legacy_line_chart does not guess the data specification
        correctly, try specifying your desired chart using st._legacy_altair_chart.

        Parameters
        ----------
        data : pandas.DataFrame, pandas.Styler, numpy.ndarray, Iterable, dict
            or None
            Data to be plotted.

        width : int
            The chart width in pixels. If 0, selects the width automatically.

        height : int
            The chart height in pixels. If 0, selects the height automatically.

        use_container_width : bool
            If True, set the chart width to the column width. This takes
            precedence over the width argument.

        Example
        -------
        >>> chart_data = pd.DataFrame(
        ...     np.random.randn(20, 3),
        ...     columns=['a', 'b', 'c'])
        ...
        >>> st._legacy_line_chart(chart_data)

        .. output::
           https://static.streamlit.io/0.50.0-td2L/index.html?id=BdxXG3MmrVBfJyqS2R2ki8
           height: 220px
        """
        vega_lite_chart_proto = VegaLiteChartProto()
        # Build an Altair spec from the raw data, then serialize it.
        chart = generate_chart("line", data, width, height)
        marshall(vega_lite_chart_proto, chart, use_container_width)
        # NOTE(review): last_index appears to support extending melted data
        # (e.g. via add_rows) — helper defined elsewhere; confirm.
        last_index = last_index_for_melted_dataframes(data)

        return self.dg._enqueue(
            "line_chart", vega_lite_chart_proto, last_index=last_index
        )

    def _legacy_area_chart(
        self, data=None, width=0, height=0, use_container_width=True
    ):
        """Display an area chart.

        This is just syntax-sugar around st._legacy_altair_chart. The main difference
        is this command uses the data's own column and indices to figure out
        the chart's spec. As a result this is easier to use for many "just plot
        this" scenarios, while being less customizable.

        If st._legacy_area_chart does not guess the data specification
        correctly, try specifying your desired chart using st._legacy_altair_chart.

        Parameters
        ----------
        data : pandas.DataFrame, pandas.Styler, numpy.ndarray, Iterable, or dict
            Data to be plotted.

        width : int
            The chart width in pixels. If 0, selects the width automatically.

        height : int
            The chart height in pixels. If 0, selects the height automatically.

        use_container_width : bool
            If True, set the chart width to the column width. This takes
            precedence over the width argument.

        Example
        -------
        >>> chart_data = pd.DataFrame(
        ...     np.random.randn(20, 3),
        ...     columns=['a', 'b', 'c'])
        ...
        >>> st._legacy_area_chart(chart_data)

        .. output::
           https://static.streamlit.io/0.50.0-td2L/index.html?id=Pp65STuFj65cJRDfhGh4Jt
           height: 220px
        """
        vega_lite_chart_proto = VegaLiteChartProto()
        chart = generate_chart("area", data, width, height)
        marshall(vega_lite_chart_proto, chart, use_container_width)
        last_index = last_index_for_melted_dataframes(data)

        return self.dg._enqueue(
            "area_chart", vega_lite_chart_proto, last_index=last_index
        )

    def _legacy_bar_chart(self, data=None, width=0, height=0, use_container_width=True):
        """Display a bar chart.

        This is just syntax-sugar around st._legacy_altair_chart. The main difference
        is this command uses the data's own column and indices to figure out
        the chart's spec. As a result this is easier to use for many "just plot
        this" scenarios, while being less customizable.

        If st._legacy_bar_chart does not guess the data specification
        correctly, try specifying your desired chart using st._legacy_altair_chart.

        Parameters
        ----------
        data : pandas.DataFrame, pandas.Styler, numpy.ndarray, Iterable, or dict
            Data to be plotted.

        width : int
            The chart width in pixels. If 0, selects the width automatically.

        height : int
            The chart height in pixels. If 0, selects the height automatically.

        use_container_width : bool
            If True, set the chart width to the column width. This takes
            precedence over the width argument.

        Example
        -------
        >>> chart_data = pd.DataFrame(
        ...     np.random.randn(50, 3),
        ...     columns=["a", "b", "c"])
        ...
        >>> st._legacy_bar_chart(chart_data)

        .. output::
           https://static.streamlit.io/0.66.0-2BLtg/index.html?id=GaYDn6vxskvBUkBwsGVEaL
           height: 220px
        """
        vega_lite_chart_proto = VegaLiteChartProto()
        chart = generate_chart("bar", data, width, height)
        marshall(vega_lite_chart_proto, chart, use_container_width)
        last_index = last_index_for_melted_dataframes(data)

        return self.dg._enqueue(
            "bar_chart", vega_lite_chart_proto, last_index=last_index
        )

    def _legacy_altair_chart(self, altair_chart, use_container_width=False):
        """Display a chart using the Altair library.

        Parameters
        ----------
        altair_chart : altair.vegalite.v2.api.Chart
            The Altair chart object to display.

        use_container_width : bool
            If True, set the chart width to the column width. This takes
            precedence over Altair's native `width` value.

        Example
        -------
        >>> import pandas as pd
        >>> import numpy as np
        >>> import altair as alt
        >>>
        >>> df = pd.DataFrame(
        ...     np.random.randn(200, 3),
        ...     columns=['a', 'b', 'c'])
        ...
        >>> c = alt.Chart(df).mark_circle().encode(
        ...     x='a', y='b', size='c', color='c', tooltip=['a', 'b', 'c'])
        >>>
        >>> st._legacy_altair_chart(c, use_container_width=True)

        .. output::
           https://static.streamlit.io/0.25.0-2JkNY/index.html?id=8jmmXR8iKoZGV4kXaKGYV5
           height: 200px

        Examples of Altair charts can be found at
        https://altair-viz.github.io/gallery/.
        """
        vega_lite_chart_proto = VegaLiteChartProto()
        marshall(
            vega_lite_chart_proto,
            altair_chart,
            use_container_width=use_container_width,
        )
        return self.dg._enqueue("vega_lite_chart", vega_lite_chart_proto)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)
def _is_date_column(df, name):
"""True if the column with the given name stores datetime.date values.
This function just checks the first value in the given column, so
it's meaningful only for columns whose values all share the same type.
Parameters
----------
df : pd.DataFrame
name : str
The column name
Returns
-------
bool
"""
column = df[name]
if column.size == 0:
return False
return isinstance(column[0], date)
def generate_chart(chart_type, data, width=0, height=0):
    """Build an interactive Altair chart of the given mark type.

    The data is melted into long format (index / variable / value) so that
    each original column becomes one colored series.

    Parameters
    ----------
    chart_type : str
        Altair mark name: "line", "area", or "bar".
    data : anything convertible to a DataFrame, or None
    width, height : int
        Pixel dimensions; 0 lets Altair pick automatically.
    """
    if data is None:
        # Use an empty-ish dict because if we use None the x axis labels rotate
        # 90 degrees. No idea why. Need to debug.
        data = {"": []}

    if isinstance(data, pa.Table):
        raise errors.StreamlitAPIException(
            """
pyarrow tables are not supported by Streamlit's legacy DataFrame serialization (i.e. with `config.dataFrameSerialization = "legacy"`).
To be able to use pyarrow tables, please enable pyarrow by changing the config setting,
`config.dataFrameSerialization = "arrow"`
"""
        )

    if not isinstance(data, pd.DataFrame):
        data = type_util.convert_anything_to_df(data)

    index_name = data.index.name
    if index_name is None:
        index_name = "index"

    # Melt to long format: one row per (index, variable) pair.
    data = pd.melt(data.reset_index(), id_vars=[index_name])

    # Area charts are drawn semi-transparent so overlapping series stay visible.
    if chart_type == "area":
        opacity = {"value": 0.7}
    else:
        opacity = {"value": 1.0}

    # Set the X and Y axes' scale to "utc" if they contain date values.
    # This causes time data to be displayed in UTC, rather the user's local
    # time zone. (By default, vega-lite displays time data in the browser's
    # local time zone, regardless of which time zone the data specifies:
    # https://vega.github.io/vega-lite/docs/timeunit.html#output).
    x_scale = (
        alt.Scale(type="utc") if _is_date_column(data, index_name) else alt.Undefined
    )
    y_scale = alt.Scale(type="utc") if _is_date_column(data, "value") else alt.Undefined

    x_type = alt.Undefined
    # Bar charts should have a discrete (ordinal) x-axis, UNLESS type is date/time
    # https://github.com/streamlit/streamlit/pull/2097#issuecomment-714802475
    if chart_type == "bar" and not _is_date_column(data, index_name):
        x_type = "ordinal"

    # getattr dispatches to mark_line / mark_area / mark_bar.
    chart = (
        getattr(alt.Chart(data, width=width, height=height), "mark_" + chart_type)()
        .encode(
            alt.X(index_name, title="", scale=x_scale, type=x_type),
            alt.Y("value", title="", scale=y_scale),
            alt.Color("variable", title="", type="nominal"),
            alt.Tooltip([index_name, "value", "variable"]),
            opacity=opacity,
        )
        .interactive()
    )
    return chart
def marshall(vega_lite_chart, altair_chart, use_container_width=False, **kwargs):
    """Serialize an Altair chart into a VegaLiteChart proto.

    Parameters
    ----------
    vega_lite_chart : proto.VegaLiteChart
        Output proto.
    altair_chart : altair chart object
        The chart to serialize.
    use_container_width : bool
        Forwarded to the vega-lite marshaller.
    **kwargs
        Extra arguments forwarded to vega_lite.marshall.
    """
    import altair as alt

    # Normally altair_chart.to_dict() would transform the dataframe used by the
    # chart into an array of dictionaries. To avoid that, we install a
    # transformer that replaces datasets with a reference by the object id of
    # the dataframe. We then fill in the dataset manually later on.

    datasets = {}

    def id_transform(data):
        """Altair data transformer that returns a fake named dataset with the
        object id."""
        datasets[id(data)] = data
        return {"name": str(id(data))}

    # NOTE(review): register() is global, but enable() scopes activation to
    # the `with` block below, so other threads/users of altair are unaffected
    # outside it — confirm this is the intended isolation.
    alt.data_transformers.register("id", id_transform)

    with alt.data_transformers.enable("id"):
        chart_dict = altair_chart.to_dict()

    # Put datasets back into the chart dict but note how they weren't
    # transformed.
    chart_dict["datasets"] = datasets

    vega_lite.marshall(
        vega_lite_chart,
        chart_dict,
        use_container_width=use_container_width,
        **kwargs,
    )

View File

@ -0,0 +1,441 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions to marshall a pandas.DataFrame into a proto.DataFrame."""
import datetime
import re
from collections import namedtuple
from typing import cast, Dict, Any, Optional
import pyarrow as pa
import tzlocal
from pandas import DataFrame
from pandas.io.formats.style import Styler
import streamlit
from streamlit import errors, type_util
from streamlit.logger import get_logger
from streamlit.proto.DataFrame_pb2 import (
DataFrame as DataFrameProto,
TableStyle as TableStyleProto,
)
LOGGER = get_logger(__name__)

# A single CSS declaration (e.g. property="color", value="black") attached to
# one styled DataFrame cell.
CSSStyle = namedtuple("CSSStyle", ["property", "value"])
class LegacyDataFrameMixin:
    """Mixin providing the legacy (pre-Arrow) dataframe/table commands."""

    def _legacy_dataframe(self, data=None, width=None, height=None):
        """Display a dataframe as an interactive table.

        Parameters
        ----------
        data : pandas.DataFrame, pandas.Styler, numpy.ndarray, Iterable, dict,
            or None
            The data to display.

            If 'data' is a pandas.Styler, it will be used to style its
            underlying DataFrame. Streamlit supports custom cell
            values and colors. (It does not support some of the more exotic
            pandas styling features, like bar charts, hovering, and captions.)
            Styler support is experimental!

        width : int or None
            Desired width of the UI element expressed in pixels. If None, a
            default width based on the page width is used.

        height : int or None
            Desired height of the UI element expressed in pixels. If None, a
            default height is used.

        Examples
        --------
        >>> df = pd.DataFrame(
        ...     np.random.randn(50, 20),
        ...     columns=('col %d' % i for i in range(20)))
        ...
        >>> st._legacy_dataframe(df)

        .. output::
           https://static.streamlit.io/0.25.0-2JkNY/index.html?id=165mJbzWdAC8Duf8a4tjyQ
           height: 330px

        >>> st._legacy_dataframe(df, 200, 100)

        You can also pass a Pandas Styler object to change the style of
        the rendered DataFrame:

        >>> df = pd.DataFrame(
        ...     np.random.randn(10, 20),
        ...     columns=('col %d' % i for i in range(20)))
        ...
        >>> st._legacy_dataframe(df.style.highlight_max(axis=0))

        .. output::
           https://static.streamlit.io/0.29.0-dV1Y/index.html?id=Hb6UymSNuZDzojUNybzPby
           height: 285px
        """
        data_frame_proto = DataFrameProto()
        marshall_data_frame(data, data_frame_proto)

        return self.dg._enqueue(
            "data_frame",
            data_frame_proto,
            element_width=width,
            element_height=height,
        )

    def _legacy_table(self, data=None):
        """Display a static table.

        This differs from `st._legacy_dataframe` in that the table in this case is
        static: its entire contents are laid out directly on the page.

        Parameters
        ----------
        data : pandas.DataFrame, pandas.Styler, numpy.ndarray, Iterable, dict,
            or None
            The table data.

        Example
        -------
        >>> df = pd.DataFrame(
        ...     np.random.randn(10, 5),
        ...     columns=('col %d' % i for i in range(5)))
        ...
        >>> st._legacy_table(df)

        .. output::
           https://static.streamlit.io/0.25.0-2JkNY/index.html?id=KfZvDMprL4JFKXbpjD3fpq
           height: 480px
        """
        # Same marshalling path as _legacy_dataframe; only the element type
        # ("table" vs "data_frame") differs.
        table_proto = DataFrameProto()
        marshall_data_frame(data, table_proto)
        return self.dg._enqueue("table", table_proto)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)
def marshall_data_frame(data: Any, proto_df: DataFrameProto) -> None:
    """Convert a pandas.DataFrame (or anything convertible) into a proto.DataFrame.

    Parameters
    ----------
    data : pandas.DataFrame, numpy.ndarray, Iterable, dict, DataFrame, Styler, or None
        Something that is or can be converted to a dataframe.

    proto_df : proto.DataFrame
        Output. The protobuf for a Streamlit DataFrame proto.
    """
    # Arrow tables must go through the "arrow" serialization path instead.
    if isinstance(data, pa.Table):
        raise errors.StreamlitAPIException(
            """
pyarrow tables are not supported by Streamlit's legacy DataFrame serialization (i.e. with `config.dataFrameSerialization = "legacy"`).
To be able to use pyarrow tables, please enable pyarrow by changing the config setting,
`config.dataFrameSerialization = "arrow"`
"""
        )

    df = type_util.convert_anything_to_df(data)

    # Marshall the body one column (Series) at a time, then both axes.
    columns_iter = (df.iloc[:, i] for i in range(len(df.columns)))
    _marshall_table(columns_iter, proto_df.data)
    _marshall_index(df.columns, proto_df.columns)
    _marshall_index(df.index, proto_df.index)

    # Only a pandas.Styler input carries styling information; for anything
    # else we still emit (empty) per-cell styles.
    styler = data if type_util.is_pandas_styler(data) else None
    _marshall_styles(proto_df.style, df, styler)
def _marshall_styles(
    proto_table_style: TableStyleProto, df: DataFrame, styler: Optional[Styler] = None
) -> None:
    """Adds pandas.Styler styling data to a proto.DataFrame

    Parameters
    ----------
    proto_table_style : proto.TableStyle
    df : pandas.DataFrame
    styler : pandas.Styler holding styling data for the data frame, or
        None if there's no style data to marshall
    """

    # NB: we're using protected members of Styler to get this data,
    # which is non-ideal and could break if Styler's interface changes.

    if styler is not None:
        styler._compute()

        # In Pandas 1.3.0, styler._translate() signature was changed.
        # 2 arguments were added: sparse_index and sparse_columns.
        # The functionality that they provide is not yet supported.
        if type_util.is_pandas_version_less_than("1.3.0"):
            translated_style = styler._translate()
        else:
            translated_style = styler._translate(False, False)

        css_styles = _get_css_styles(translated_style)
        display_values = _get_custom_display_values(translated_style)
    else:
        # If we have no Styler, we just make an empty CellStyle for each cell
        css_styles = {}
        display_values = {}

    # Emit one CellStyle per cell, column-major, even when it is empty —
    # the proto expects a full cols x rows grid of styles.
    nrows, ncols = df.shape
    for col in range(ncols):
        proto_col = proto_table_style.cols.add()
        for row in range(nrows):
            proto_cell_style = proto_col.styles.add()

            for css in css_styles.get((row, col), []):
                proto_css = proto_cell_style.css.add()
                proto_css.property = css.property
                proto_css.value = css.value

            display_value = display_values.get((row, col), None)
            if display_value is not None:
                proto_cell_style.display_value = display_value
                proto_cell_style.has_display_value = True
def _get_css_styles(translated_style: Dict[Any, Any]) -> Dict[Any, Any]:
    """Parses pandas.Styler style dictionary into a
    {(row, col): [CSSStyle]} dictionary
    """
    # In pandas < 1.1.0
    # translated_style["cellstyle"] has the following shape:
    # [
    #   {
    #       "props": [["color", " black"], ["background-color", "orange"], ["", ""]],
    #       "selector": "row0_col0"
    #   }
    #   ...
    # ]
    #
    # In pandas >= 1.1.0
    # translated_style["cellstyle"] has the following shape:
    # [
    #   {
    #       "props": [("color", " black"), ("background-color", "orange"), ("", "")],
    #       "selectors": ["row0_col0"]
    #   }
    #   ...
    # ]
    cell_selector_regex = re.compile(r"row(\d+)_col(\d+)")

    css_styles = {}
    for cell_style in translated_style["cellstyle"]:
        # Normalize the pre-1.1.0 single-selector shape to a list.
        if type_util.is_pandas_version_less_than("1.1.0"):
            cell_selectors = [cell_style["selector"]]
        else:
            cell_selectors = cell_style["selectors"]

        for cell_selector in cell_selectors:
            match = cell_selector_regex.match(cell_selector)
            if not match:
                raise RuntimeError(
                    f'Failed to parse cellstyle selector "{cell_selector}"'
                )
            row = int(match.group(1))
            col = int(match.group(2))

            css_declarations = []
            props = cell_style["props"]
            for prop in props:
                if not isinstance(prop, (tuple, list)) or len(prop) != 2:
                    raise RuntimeError(f'Unexpected cellstyle props "{prop}"')
                name = str(prop[0]).strip()
                value = str(prop[1]).strip()
                # Skip placeholder entries like ("", "").
                if name and value:
                    css_declarations.append(CSSStyle(property=name, value=value))
            # NOTE(review): a later cellstyle entry targeting the same cell
            # replaces (not merges with) an earlier one — confirm that the
            # Styler never emits two entries for one cell.
            css_styles[(row, col)] = css_declarations

    return css_styles
def _get_custom_display_values(translated_style: Dict[Any, Any]) -> Dict[Any, Any]:
"""Parses pandas.Styler style dictionary into a
{(row, col): display_value} dictionary for cells whose display format
has been customized.
"""
# Create {(row, col): display_value} from translated_style['body']
# translated_style['body'] has the shape:
# [
# [ // row
# { // cell or header
# 'id': 'level0_row0' (for row header) | 'row0_col0' (for cells)
# 'value': 1.329212
# 'display_value': '132.92%'
# ...
# }
# ]
# ]
def has_custom_display_value(cell: Dict[Any, Any]) -> bool:
# We'd prefer to only pass `display_value` data to the frontend
# when a DataFrame cell has been custom-formatted by the user, to
# save on bandwidth. However:
#
# Panda's Styler's internals are private, and it doesn't give us a
# consistent way of testing whether a cell has a custom display_value
# or not. Prior to Pandas 1.4, we could test whether a cell's
# `display_value` differed from its `value`, and only stick the
# `display_value` in the protobuf when that was the case. In 1.4, an
# unmodified Styler will contain `display_value` strings for all
# cells, regardless of whether any formatting has been applied to
# that cell, so we no longer have this ability.
#
# So we're only testing that a cell's `display_value` is not None.
# In Pandas 1.4, it seems that `display_value` is never None, so this
# is purely a defense against future Styler changes.
return cell.get("display_value") is not None
cell_selector_regex = re.compile(r"row(\d+)_col(\d+)")
header_selector_regex = re.compile(r"level(\d+)_row(\d+)")
display_values = {}
for row in translated_style["body"]:
# row is a List[Dict], containing format data for each cell in the row,
# plus an extra first entry for the row header, which we skip
found_row_header = False
for cell in row:
cell_id = cell["id"] # a string in the form 'row0_col0'
if header_selector_regex.match(cell_id):
if not found_row_header:
# We don't care about processing row headers, but as
# a sanity check, ensure we only see one per row
found_row_header = True
continue
else:
raise RuntimeError('Found unexpected row header "%s"' % cell)
match = cell_selector_regex.match(cell_id)
if not match:
raise RuntimeError('Failed to parse cell selector "%s"' % cell_id)
if has_custom_display_value(cell):
row = int(match.group(1))
col = int(match.group(2))
display_values[(row, col)] = str(cell["display_value"])
return display_values
def _marshall_index(pandas_index, proto_index):
    """Convert an pandas.Index into a proto.Index.

    pandas_index - Panda.Index or related (input)
    proto_index - proto.Index (output)
    """
    import pandas as pd
    import numpy as np

    # Exact type() comparisons (not isinstance) are used throughout so that
    # each Index subclass (RangeIndex, DatetimeIndex, ...) is routed to its
    # own branch instead of matching the pd.Index base class.
    if type(pandas_index) == pd.Index:
        _marshall_any_array(np.array(pandas_index), proto_index.plain_index.data)
    elif type(pandas_index) == pd.RangeIndex:
        min = pandas_index.min()
        max = pandas_index.max()
        # An empty RangeIndex has NaN min/max; encode it as the empty range.
        if pd.isna(min) or pd.isna(max):
            proto_index.range_index.start = 0
            proto_index.range_index.stop = 0
        else:
            proto_index.range_index.start = min
            proto_index.range_index.stop = max + 1
    elif type(pandas_index) == pd.MultiIndex:
        for level in pandas_index.levels:
            # Recurse: each MultiIndex level is itself an Index.
            _marshall_index(level, proto_index.multi_index.levels.add())
        if hasattr(pandas_index, "codes"):
            index_codes = pandas_index.codes
        else:
            # Deprecated in Pandas 0.24, so don't bother covering.
            index_codes = pandas_index.labels  # pragma: no cover
        for label in index_codes:
            proto_index.multi_index.labels.add().data.extend(label)
    elif type(pandas_index) == pd.DatetimeIndex:
        # Naive datetimes are localized to the server's time zone before
        # serializing, so the frontend always receives aware ISO strings.
        if pandas_index.tz is None:
            current_zone = tzlocal.get_localzone()
            pandas_index = pandas_index.tz_localize(current_zone)
        proto_index.datetime_index.data.data.extend(
            pandas_index.map(datetime.datetime.isoformat)
        )
    elif type(pandas_index) == pd.TimedeltaIndex:
        # Timedeltas are serialized as int64 nanoseconds.
        proto_index.timedelta_index.data.data.extend(pandas_index.astype(np.int64))
    elif type(pandas_index) == pd.Int64Index:
        proto_index.int_64_index.data.data.extend(pandas_index)
    elif type(pandas_index) == pd.Float64Index:
        proto_index.float_64_index.data.data.extend(pandas_index)
    else:
        raise NotImplementedError("Can't handle %s yet." % type(pandas_index))
def _marshall_table(pandas_table, proto_table):
    """Convert a sequence of 1D arrays into proto.Table.

    pandas_table - Sequence of 1D arrays which are AnyArray compatible (input).
    proto_table - proto.Table (output)
    """
    for column in pandas_table:
        # Skip empty columns rather than emitting an empty AnyArray.
        if len(column) == 0:
            continue
        _marshall_any_array(column, proto_table.cols.add())
def _marshall_any_array(pandas_array, proto_array):
    """Convert a 1D numpy.Array into a proto.AnyArray.

    pandas_array - 1D arrays which is AnyArray compatible (input).
    proto_array - proto.AnyArray (output)
    """
    import numpy as np

    # Convert to np.array as necessary.
    if not hasattr(pandas_array, "dtype"):
        pandas_array = np.array(pandas_array)

    # Only works on 1D arrays.
    if len(pandas_array.shape) != 1:
        raise ValueError("Array must be 1D.")

    # Perform type-conversion based on the array dtype.
    if issubclass(pandas_array.dtype.type, np.floating):
        proto_array.doubles.data.extend(pandas_array)
    elif issubclass(pandas_array.dtype.type, np.timedelta64):
        # Timedeltas are serialized as int64 nanoseconds.
        proto_array.timedeltas.data.extend(pandas_array.astype(np.int64))
    elif issubclass(pandas_array.dtype.type, np.integer):
        proto_array.int64s.data.extend(pandas_array)
    elif pandas_array.dtype == np.bool_:
        # Booleans share the int64 field (0/1).
        proto_array.int64s.data.extend(pandas_array)
    elif pandas_array.dtype == np.object_:
        # Object arrays are stringified element by element.
        proto_array.strings.data.extend(map(str, pandas_array))
    # dtype='string', <class 'pandas.core.arrays.string_.StringDtype'>
    # NOTE: StringDtype is considered experimental.
    # The implementation and parts of the API may change without warning.
    elif pandas_array.dtype.name == "string":
        proto_array.strings.data.extend(map(str, pandas_array))
    # Setting a timezone changes (dtype, dtype.type) from
    # 'datetime64[ns]', <class 'numpy.datetime64'>
    # to
    # datetime64[ns, UTC], <class 'pandas._libs.tslibs.timestamps.Timestamp'>
    elif pandas_array.dtype.name.startswith("datetime64"):
        # Just convert straight to ISO 8601, preserving timezone
        # awareness/unawareness. The frontend will render it correctly.
        proto_array.datetimes.data.extend(pandas_array.map(datetime.datetime.isoformat))
    else:
        raise NotImplementedError("Dtype %s not understood." % pandas_array.dtype)

View File

@ -0,0 +1,199 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper around Vega-Lite."""
import json
from typing import cast
import streamlit
import streamlit.elements.legacy_data_frame as data_frame
import streamlit.elements.lib.dicttools as dicttools
from streamlit.logger import get_logger
from streamlit.proto.VegaLiteChart_pb2 import VegaLiteChart as VegaLiteChartProto
LOGGER = get_logger(__name__)
class LegacyVegaLiteMixin:
    """Mixin providing the legacy (pre-Arrow) vega-lite chart command."""

    def _legacy_vega_lite_chart(
        self,
        data=None,
        spec=None,
        use_container_width=False,
        **kwargs,
    ):
        """Display a chart using the Vega-Lite library.

        Parameters
        ----------
        data : pandas.DataFrame, pandas.Styler, numpy.ndarray, Iterable, dict,
            or None
            Either the data to be plotted or a Vega-Lite spec containing the
            data (which more closely follows the Vega-Lite API).

        spec : dict or None
            The Vega-Lite spec for the chart. If the spec was already passed in
            the previous argument, this must be set to None. See
            https://vega.github.io/vega-lite/docs/ for more info.

        use_container_width : bool
            If True, set the chart width to the column width. This takes
            precedence over Vega-Lite's native `width` value.

        **kwargs : any
            Same as spec, but as keywords.

        Example
        -------
        >>> import pandas as pd
        >>> import numpy as np
        >>>
        >>> df = pd.DataFrame(
        ...     np.random.randn(200, 3),
        ...     columns=['a', 'b', 'c'])
        >>>
        >>> st._legacy_vega_lite_chart(df, {
        ...     'mark': {'type': 'circle', 'tooltip': True},
        ...     'encoding': {
        ...         'x': {'field': 'a', 'type': 'quantitative'},
        ...         'y': {'field': 'b', 'type': 'quantitative'},
        ...         'size': {'field': 'c', 'type': 'quantitative'},
        ...         'color': {'field': 'c', 'type': 'quantitative'},
        ...     },
        ... })

        .. output::
           https://static.streamlit.io/0.25.0-2JkNY/index.html?id=8jmmXR8iKoZGV4kXaKGYV5
           height: 200px

        Examples of Vega-Lite usage without Streamlit can be found at
        https://vega.github.io/vega-lite/examples/. Most of those can be easily
        translated to the syntax shown above.
        """
        # All spec/data reconciliation happens in the module-level marshall().
        vega_lite_chart_proto = VegaLiteChartProto()
        marshall(
            vega_lite_chart_proto,
            data,
            spec,
            use_container_width=use_container_width,
            **kwargs,
        )
        return self.dg._enqueue("vega_lite_chart", vega_lite_chart_proto)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)
def marshall(proto, data=None, spec=None, use_container_width=False, **kwargs):
    """Fill a VegaLiteChart proto from data, a spec dict, and/or kwargs.

    See DeltaGenerator._legacy_vega_lite_chart for the public docs.
    """
    # A dict passed as the first positional argument is really the spec
    # (its data gets pulled out of the dict further down).
    if spec is None and isinstance(data, dict):
        spec, data = data, None

    # Work on a copy (or a fresh dict) so the caller's spec is never mutated.
    spec = {} if spec is None else dict(spec)

    # Kwargs are an alternate way of expressing spec entries, e.g.
    #   marshall(proto, {foo: 'bar'}, baz='boz')
    # Unflattened kwargs take precedence over the spec dict (string keys
    # only, which kwarg keys always are).
    if kwargs:
        spec.update(dicttools.unflatten(kwargs, _CHANNELS))

    if not spec:
        raise ValueError("Vega-Lite charts require a non-empty spec dict.")

    spec.setdefault("autosize", {"type": "fit", "contains": "padding"})

    # Datasets embedded as spec['datasets'] are marshalled out separately:
    #   marshall(proto, {datasets: {foo: df1, bar: df2}, ...})
    if "datasets" in spec:
        for name, dataset_data in spec.pop("datasets").items():
            dataset = proto.datasets.add()
            dataset.name = str(name)
            dataset.has_name = True
            data_frame.marshall_data_frame(dataset_data, dataset.data)

    # Inline data under a top-level 'data' key is also pulled out:
    #   marshall(proto, {data: df})
    #   marshall(proto, {data: {values: df, ...}})
    # but {data: {url: ...}} / {data: {name: ...}} stay in the spec.
    if "data" in spec:
        inline = spec["data"]
        if not isinstance(inline, dict):
            data = spec.pop("data")
        elif "values" in inline:
            data = inline["values"]
            del spec["data"]

    proto.spec = json.dumps(spec)
    proto.use_container_width = use_container_width

    if data is not None:
        data_frame.marshall_data_frame(data, proto.data)
# Encoding channel names accepted as top-level kwargs; dicttools.unflatten
# nests anything in this set under the spec's "encoding" key.
# See https://vega.github.io/vega-lite/docs/encoding.html
# Bug fix: the original listed "xError" and "yError2" twice each while
# omitting "xError2" and "yError", so kwargs for those two channels were
# never routed into "encoding". All four error channels are now present.
_CHANNELS = {
    "x",
    "y",
    "x2",
    "y2",
    "xError",
    "xError2",
    "yError",
    "yError2",
    "longitude",
    "latitude",
    "color",
    "opacity",
    "fillOpacity",
    "strokeOpacity",
    "strokeWidth",
    "size",
    "shape",
    "text",
    "tooltip",
    "href",
    "key",
    "order",
    "detail",
    "facet",
    "row",
    "column",
}

View File

@ -0,0 +1,13 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,135 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for working with dicts."""
from typing import Any, Dict, Optional
def _unflatten_single_dict(flat_dict):
    """Expand one level of underscore-separated keys into a dict tree.

    Example
    -------
    _unflatten_single_dict({
        "foo_bar_baz": 123,
        "foo_bar_biz": 456,
        "x_bonks": "hi",
    })
    # Returns:
    # {"foo": {"bar": {"baz": 123, "biz": 456}}, "x": {"bonks": "hi"}}

    Parameters
    ----------
    flat_dict : dict
        A one-level dict whose keys are fully-qualified paths separated by
        underscores.

    Returns
    -------
    dict
        A tree made of dicts inside of dicts.
    """
    tree = {}  # type: Dict[str, Any]
    for joined_key, value in flat_dict.items():
        *parents, leaf = joined_key.split("_")
        node = tree
        for part in parents:
            if part not in node:
                node[part] = dict()
            node = node[part]
        node[leaf] = value
    return tree


def unflatten(flat_dict, encodings=None):
    """Convert underscore-flattened keys into a Vega-Lite-style spec tree.

    Top-level keys named in *encodings* are additionally relocated under an
    "encoding" sub-dict.

    Example
    -------
    unflatten({
        "foo_bar_baz": 123,
        "foo_bar_biz": 456,
        "x_bonks": "hi",
    }, ["x"])
    # Returns:
    # {
    #     "foo": {"bar": {"baz": 123, "biz": 456}},
    #     "encoding": {          # added automatically
    #         "x": {"bonks": "hi"},
    #     },
    # }

    Parameters
    ----------
    flat_dict : dict
        A flat dict whose keys are fully-qualified paths separated by
        underscores.
    encodings : set or None
        Key names that should be moved into the 'encoding' key.

    Returns
    -------
    dict
        A tree made of dicts inside of dicts.
    """
    if encodings is None:
        encodings = set()

    out = _unflatten_single_dict(flat_dict)

    for key in list(out):
        value = out[key]

        # Recurse into children. NOTE(review): for dict children the
        # recursed result is intentionally kept local (matching the original
        # behavior); it is only stored back when moved under "encoding".
        if isinstance(value, dict):
            value = unflatten(value, encodings)
        elif hasattr(value, "__iter__"):
            for idx, item in enumerate(value):
                if isinstance(item, dict):
                    value[idx] = unflatten(item, encodings)

        # Relocate encoding channels under the "encoding" key.
        if key in encodings:
            out.setdefault("encoding", dict())[key] = value
            del out[key]

    return out

View File

@ -0,0 +1,210 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A wrapper for simple PyDeck scatter charts."""
import copy
import json
from typing import Any, Dict
from typing import cast
import pandas as pd
import streamlit
import streamlit.elements.deck_gl_json_chart as deck_gl_json_chart
from streamlit.errors import StreamlitAPIException
from streamlit.proto.DeckGlJsonChart_pb2 import DeckGlJsonChart as DeckGlJsonChartProto
class MapMixin:
    def map(self, data=None, zoom=None, use_container_width=True):
        """Display a map with points on it.

        A thin wrapper around st.pydeck_chart for quickly building
        scatterplots on a map, with automatic centering and zoom.

        We advise using a personal Mapbox token with this command (via the
        mapbox.token config option) so the map tiles are more robust.
        Accounts at https://mapbox.com are free for moderate usage. See
        https://docs.streamlit.io/library/advanced-features/configuration#set-configuration-options
        for how to set config options.

        Parameters
        ----------
        data : pandas.DataFrame, pandas.Styler, numpy.ndarray, Iterable, dict,
            or None
            The data to plot. Must contain a column named 'lat', 'lon',
            'latitude', or 'longitude'.
        zoom : int
            Zoom level as specified in
            https://wiki.openstreetmap.org/wiki/Zoom_levels

        Example
        -------
        >>> import streamlit as st
        >>> import pandas as pd
        >>> import numpy as np
        >>>
        >>> df = pd.DataFrame(
        ...     np.random.randn(1000, 2) / [50, 50] + [37.76, -122.4],
        ...     columns=['lat', 'lon'])
        >>>
        >>> st.map(df)
        """
        proto = DeckGlJsonChartProto()
        proto.json = to_deckgl_json(data, zoom)
        proto.use_container_width = use_container_width
        return self.dg._enqueue("deck_gl_json_chart", proto)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """The DeltaGenerator this mixin is bound to."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)
# Map used as the basis for st.map.
_DEFAULT_MAP = dict(deck_gl_json_chart.EMPTY_MAP)  # type: Dict[str, Any]
# Light basemap style so the (red) data points stand out.
_DEFAULT_MAP["mapStyle"] = "mapbox://styles/mapbox/light-v10"

# Other default parameters for st.map.
# RGBA fill color for the scatterplot points (semi-transparent red).
_DEFAULT_COLOR = [200, 30, 0, 160]
# Zoom level used when the data's extent is too small to infer one
# (see _get_zoom_level).
_DEFAULT_ZOOM_LEVEL = 12
# Degrees of longitude visible at each OSM zoom level: index i holds the
# widest longitude span (in degrees) that still fits the viewport at zoom
# level i. Values roughly halve from one level to the next; see
# https://wiki.openstreetmap.org/wiki/Zoom_levels.
_ZOOM_LEVELS = [
    360,
    180,
    90,
    45,
    22.5,
    11.25,
    5.625,
    2.813,
    1.406,
    0.703,
    0.352,
    0.176,
    0.088,
    0.044,
    0.022,
    0.011,
    0.005,
    0.003,
    0.001,
    0.0005,
    0.00025,
]
def _get_zoom_level(distance):
    """Get the zoom level for a given distance in degrees.

    See https://wiki.openstreetmap.org/wiki/Zoom_levels for reference.

    Parameters
    ----------
    distance : float
        How many degrees of longitude should fit in the map.

    Returns
    -------
    int
        The zoom level, from 0 to 20.
    """
    # For small number of points the default zoom level will be used.
    if distance < _ZOOM_LEVELS[-1]:
        return _DEFAULT_ZOOM_LEVEL

    # Find the first zoom level whose span contains `distance`.
    for i in range(len(_ZOOM_LEVELS) - 1):
        if _ZOOM_LEVELS[i + 1] < distance <= _ZOOM_LEVELS[i]:
            return i

    # Bug fix: the original fell off the end here and implicitly returned
    # None whenever distance exceeded _ZOOM_LEVELS[0] (360 degrees), which
    # would later be serialized as `"zoom": null`. Unreachable for valid
    # lat/lon ranges, but clamp to the widest level to be safe.
    return 0
def to_deckgl_json(data, zoom):
    """Build the deck.gl chart spec (as a JSON string) used by st.map.

    Parameters
    ----------
    data : pandas.DataFrame (or DataFrame-compatible), or None
        Data containing a latitude column named "lat"/"latitude" and a
        longitude column named "lon"/"longitude". None or empty data
        yields the default (empty) map.
    zoom : int or None
        Zoom level; when None it is computed from the data's extent.

    Returns
    -------
    str
        JSON-encoded deck.gl spec with a ScatterplotLayer of the points.

    Raises
    ------
    StreamlitAPIException
        If the lat/lon columns are missing, or contain null values.
    """
    if data is None or data.empty:
        return json.dumps(_DEFAULT_MAP)

    if "lat" in data:
        lat = "lat"
    elif "latitude" in data:
        lat = "latitude"
    else:
        raise StreamlitAPIException(
            'Map data must contain a column named "latitude" or "lat".'
        )

    if "lon" in data:
        lon = "lon"
    elif "longitude" in data:
        lon = "longitude"
    else:
        raise StreamlitAPIException(
            'Map data must contain a column called "longitude" or "lon".'
        )

    # Null coordinates cannot be plotted and would poison the zoom math.
    if data[lon].isnull().values.any() or data[lat].isnull().values.any():
        raise StreamlitAPIException("Latitude and longitude data must be numeric.")

    data = pd.DataFrame(data)

    min_lat = data[lat].min()
    max_lat = data[lat].max()
    min_lon = data[lon].min()
    max_lon = data[lon].max()
    center_lat = (max_lat + min_lat) / 2.0
    center_lon = (max_lon + min_lon) / 2.0
    range_lon = abs(max_lon - min_lon)
    range_lat = abs(max_lat - min_lat)

    # Bug fix: was `zoom == None`; compare to None with `is` (PEP 8).
    if zoom is None:
        # Auto-zoom to fit whichever axis spans more degrees.
        zoom = _get_zoom_level(max(range_lon, range_lat))

    # "+1" because itertuples includes the row index as element 0.
    lon_col_index = data.columns.get_loc(lon) + 1
    lat_col_index = data.columns.get_loc(lat) + 1
    final_data = []
    for row in data.itertuples():
        final_data.append(
            {"lon": float(row[lon_col_index]), "lat": float(row[lat_col_index])}
        )

    # Deep-copy so the shared default map template is never mutated.
    default = copy.deepcopy(_DEFAULT_MAP)
    default["initialViewState"]["latitude"] = center_lat
    default["initialViewState"]["longitude"] = center_lon
    default["initialViewState"]["zoom"] = zoom
    default["layers"] = [
        {
            "@@type": "ScatterplotLayer",
            "getPosition": "@@=[lon, lat]",
            "getRadius": 10,
            "radiusScale": 10,
            "radiusMinPixels": 3,
            "getFillColor": _DEFAULT_COLOR,
            "data": final_data,
        }
    ]
    return json.dumps(default)

View File

@ -0,0 +1,265 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast
import streamlit
from streamlit import type_util
from streamlit.proto.Markdown_pb2 import Markdown as MarkdownProto
from .utils import clean_text
class MarkdownMixin:
    def markdown(self, body, unsafe_allow_html=False):
        """Render a string as GitHub-flavored Markdown.

        Parameters
        ----------
        body : str
            Markdown text (syntax: https://github.github.com/gfm). Emoji
            shortcodes such as ``:+1:`` and LaTeX wrapped in ``$``/``$$``
            (KaTeX-supported functions) are rendered too.
        unsafe_allow_html : bool
            When True, HTML in *body* is passed through unescaped instead of
            being treated as plain text. We strongly advise against it —
            hand-written HTML is a security risk for your users; see
            https://github.com/streamlit/streamlit/issues/152. This flag is
            a temporary measure and may be removed at any time.

        Example
        -------
        >>> st.markdown('Streamlit is **_really_ cool**.')
        """
        proto = MarkdownProto()
        proto.body = clean_text(body)
        proto.allow_html = unsafe_allow_html
        return self.dg._enqueue("markdown", proto)

    def header(self, body, anchor=None):
        """Render text as a header.

        Parameters
        ----------
        body : str
            The text to display.
        anchor : str
            URL anchor (``#anchor``) for the header; auto-generated from the
            body when omitted.

        Example
        -------
        >>> st.header('This is a header')
        """
        proto = MarkdownProto()
        if anchor is None:
            proto.body = "## " + clean_text(body)
        else:
            # An explicit anchor requires emitting real HTML.
            proto.body = f'<h2 data-anchor="{anchor}">{clean_text(body)}</h2>'
            proto.allow_html = True
        return self.dg._enqueue("markdown", proto)

    def subheader(self, body, anchor=None):
        """Render text as a subheader.

        Parameters
        ----------
        body : str
            The text to display.
        anchor : str
            URL anchor (``#anchor``) for the subheader; auto-generated from
            the body when omitted.

        Example
        -------
        >>> st.subheader('This is a subheader')
        """
        proto = MarkdownProto()
        if anchor is None:
            proto.body = "### " + clean_text(body)
        else:
            # An explicit anchor requires emitting real HTML.
            proto.body = f'<h3 data-anchor="{anchor}">{clean_text(body)}</h3>'
            proto.allow_html = True
        return self.dg._enqueue("markdown", proto)

    def code(self, body, language="python"):
        """Render a code block with optional syntax highlighting.

        (A convenience wrapper around ``st.markdown``.)

        Parameters
        ----------
        body : str
            The code to display.
        language : str
            Language used for syntax highlighting; falsy values produce an
            unstyled block.

        Example
        -------
        >>> code = '''def hello():
        ...     print("Hello, Streamlit!")'''
        >>> st.code(code, language='python')
        """
        proto = MarkdownProto()
        fenced = f"```{language or ''}\n{body}\n```"
        proto.body = clean_text(fenced)
        return self.dg._enqueue("markdown", proto)

    def title(self, body, anchor=None):
        """Render text as a title.

        Each document should contain a single ``st.title()``, although this
        is not enforced.

        Parameters
        ----------
        body : str
            The text to display.
        anchor : str
            URL anchor (``#anchor``) for the title; auto-generated from the
            body when omitted.

        Example
        -------
        >>> st.title('This is a title')
        """
        proto = MarkdownProto()
        if anchor is None:
            proto.body = "# " + clean_text(body)
        else:
            # An explicit anchor requires emitting real HTML.
            proto.body = f'<h1 data-anchor="{anchor}">{clean_text(body)}</h1>'
            proto.allow_html = True
        return self.dg._enqueue("markdown", proto)

    def caption(self, body, unsafe_allow_html=False):
        """Render text in a small, muted font.

        Intended for captions, asides, footnotes, sidenotes, and other
        explanatory text.

        Parameters
        ----------
        body : str
            The text to display.
        unsafe_allow_html : bool
            When True, HTML in *body* is passed through unescaped. We
            strongly advise against it for security reasons; see
            https://github.com/streamlit/streamlit/issues/152. This flag is
            a temporary measure and may be removed at any time.

        Example
        -------
        >>> st.caption('This is a string that explains something above.')
        """
        proto = MarkdownProto()
        proto.body = clean_text(body)
        proto.allow_html = unsafe_allow_html
        proto.is_caption = True
        return self.dg._enqueue("markdown", proto)

    def latex(self, body):
        # The docstring is raw because the example is full of backslashes.
        r"""Render a mathematical expression formatted as LaTeX.

        Supported LaTeX functions are listed at
        https://katex.org/docs/supported.html.

        Parameters
        ----------
        body : str or SymPy expression
            The expression to display. For string input, raw Python strings
            are recommended since LaTeX uses backslashes heavily.

        Example
        -------
        >>> st.latex(r'''
        ...     a + ar + a r^2 + a r^3 + \cdots + a r^{n-1} =
        ...     \sum_{k=0}^{n-1} ar^k =
        ...     a \left(\frac{1-r^{n}}{1-r}\right)
        ...     ''')
        """
        if type_util.is_sympy_expession(body):
            import sympy

            body = sympy.latex(body)
        proto = MarkdownProto()
        proto.body = f"$$\n{clean_text(body)}\n$$"
        return self.dg._enqueue("markdown", proto)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """The DeltaGenerator this mixin is bound to."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)

View File

@ -0,0 +1,246 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import re
from typing import cast
from validators import url
import streamlit
from streamlit import type_util
from streamlit.in_memory_file_manager import in_memory_file_manager
from streamlit.proto.Audio_pb2 import Audio as AudioProto
from streamlit.proto.Video_pb2 import Video as VideoProto
class MediaMixin:
    def audio(self, data, format="audio/wav", start_time=0):
        """Display an audio player.

        Parameters
        ----------
        data : str, bytes, BytesIO, numpy.ndarray, or file opened with
            io.open().
            Raw audio data, a filename, or a URL pointing to the file to
            load. Numpy arrays and raw data must include all file headers
            required by the specified format.
        format : str
            The mime type of the audio file; defaults to 'audio/wav'.
            See https://tools.ietf.org/html/rfc4281 for more info.
        start_time: int
            The time from which this element should start playing.

        Example
        -------
        >>> audio_file = open('myaudio.ogg', 'rb')
        >>> audio_bytes = audio_file.read()
        >>>
        >>> st.audio(audio_bytes, format='audio/ogg')
        """
        proto = AudioProto()
        coordinates = self.dg._get_delta_path_str()
        marshall_audio(coordinates, proto, data, format, start_time)
        return self.dg._enqueue("audio", proto)

    def video(self, data, format="video/mp4", start_time=0):
        """Display a video player.

        Parameters
        ----------
        data : str, bytes, BytesIO, numpy.ndarray, or file opened with
            io.open().
            Raw video data, a filename, or a URL pointing to a video to
            load (YouTube URLs included). Numpy arrays and raw data must
            include all file headers required by the specified format.
        format : str
            The mime type of the video file; defaults to 'video/mp4'.
            See https://tools.ietf.org/html/rfc4281 for more info.
        start_time: int
            The time from which this element should start playing.

        Example
        -------
        >>> video_file = open('myvideo.mp4', 'rb')
        >>> video_bytes = video_file.read()
        >>>
        >>> st.video(video_bytes)

        .. note::
            Videos encoded with MP4V (an export option in OpenCV) may not
            display, as that codec is not widely supported by browsers.
            Re-encode to H.264 to make the video playable in Streamlit; see
            https://stackoverflow.com/a/49535220/2394542 or
            https://discuss.streamlit.io/t/st-video-doesnt-show-opencv-generated-mp4/3193/2.
        """
        proto = VideoProto()
        coordinates = self.dg._get_delta_path_str()
        marshall_video(coordinates, proto, data, format, start_time)
        return self.dg._enqueue("video", proto)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """The DeltaGenerator this mixin is bound to."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)
# Regular expression explained at https://regexr.com/4n2l2 Covers any youtube
# URL (incl. shortlinks and embed links) and extracts its code.
# NOTE(review): only the "code" group is consumed by _reshape_youtube_url;
# the "watch" group is captured but unused by callers in this file.
YOUTUBE_RE = re.compile(
    # Protocol (http or https)
    r"http(?:s?):\/\/"
    # Domain: youtube.com or the youtu.be shortlink host
    r"(?:www\.)?youtu(?:be\.com|\.be)\/"
    # Path and query string: optional watch/embed prefix, the video code,
    # then any trailing query parameters.
    r"(?P<watch>(watch\?v=)|embed\/)?(?P<code>[\w\-\_]*)(&(amp;)?[\w\?=]*)?"
)
def _reshape_youtube_url(url):
    """Rewrite a YouTube watch/share/embed URL into an iframe embed URL.

    Returns None when *url* is not a YouTube link of any kind.

    Parameters
    ----------
    url : str or bytes

    Example
    -------
    >>> print(_reshape_youtube_url('https://youtu.be/_T8LGqJtuGc'))

    .. output::
        https://www.youtube.com/embed/_T8LGqJtuGc
    """
    match = YOUTUBE_RE.match(url)
    if match is None:
        return None
    return "https://www.youtube.com/embed/{code}".format(**match.groupdict())
def _marshall_av_media(coordinates, proto, data, mimetype):
    """Fill an audio or video proto's URL from raw media data or a filename.

    Strings are assumed to be filenames (the callers have already handled
    URLs before reaching this point); OS errors from opening them are
    allowed to propagate. Bytes-like inputs are normalized to plain bytes,
    registered with the in-memory file manager, and the generated
    Tornado-served URL is stored on the proto.
    """
    if isinstance(data, str):
        # Assume it's a filename or blank. Allow OS-based file errors.
        with open(data, "rb") as media_file:
            in_memory_file = in_memory_file_manager.add(
                media_file.read(), mimetype, coordinates
            )
        proto.url = in_memory_file.url
        return

    if data is None:
        # Empty values are allowed so media players can render without media.
        return

    # Normalize the various supported bytes-like inputs to plain bytes.
    if isinstance(data, bytes):
        raw = data
    elif isinstance(data, io.BytesIO):
        data.seek(0)
        raw = data.getvalue()
    elif isinstance(data, (io.RawIOBase, io.BufferedReader)):
        data.seek(0)
        raw = data.read()
    elif type_util.is_type(data, "numpy.ndarray"):
        raw = data.tobytes()
    else:
        raise RuntimeError("Invalid binary data format: %s" % type(data))

    proto.url = in_memory_file_manager.add(raw, mimetype, coordinates).url
def marshall_video(coordinates, proto, data, mimetype="video/mp4", start_time=0):
    """Marshall a Video proto from raw data, a filename, or a URL.

    Parameters
    ----------
    coordinates : str
    proto : the proto to fill. Must have a string field called "data".
    data : str, bytes, BytesIO, numpy.ndarray, or file opened with
        io.open().
        Raw video data or a URL pointing to the video to load (YouTube
        URLs included). Raw data must include file headers and any other
        bytes required in the actual file.
    mimetype : str
        The mime type of the video file; defaults to 'video/mp4'.
        See https://tools.ietf.org/html/rfc4281 for more info.
    start_time : int
        The time from which this element should start playing. (default: 0)
    """
    proto.start_time = start_time
    # "type" distinguishes between YouTube and non-YouTube links; assume
    # NATIVE until a YouTube URL proves otherwise.
    proto.type = VideoProto.Type.NATIVE

    if isinstance(data, str) and url(data):
        embed_url = _reshape_youtube_url(data)
        if embed_url is not None:
            proto.url = embed_url
            proto.type = VideoProto.Type.YOUTUBE_IFRAME
        else:
            proto.url = data
        return

    _marshall_av_media(coordinates, proto, data, mimetype)
def marshall_audio(coordinates, proto, data, mimetype="audio/wav", start_time=0):
    """Marshall an Audio proto from raw data, a filename, or a URL.

    Parameters
    ----------
    coordinates : str
    proto : The proto to fill. Must have a string field called "url".
    data : str, bytes, BytesIO, numpy.ndarray, or file opened with
        io.open()
        Raw audio data or a URL pointing to the file to load. Raw data
        must include file headers and any other bytes required in the
        actual file.
    mimetype : str
        The mime type of the audio file; defaults to "audio/wav".
        See https://tools.ietf.org/html/rfc4281 for more info.
    start_time : int
        The time from which this element should start playing. (default: 0)
    """
    proto.start_time = start_time

    is_remote = isinstance(data, str) and url(data)
    if is_remote:
        # Remote audio is referenced directly, never downloaded.
        proto.url = data
        return

    _marshall_av_media(coordinates, proto, data, mimetype)

View File

@ -0,0 +1,184 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from textwrap import dedent
from typing import Optional, cast
import attr
import streamlit
from streamlit.errors import StreamlitAPIException
from streamlit.proto.Metric_pb2 import Metric as MetricProto
from .utils import clean_text
@attr.s(auto_attribs=True, slots=True)
class MetricColorAndDirection:
    """Resolved enum values for a metric's delta arrow.

    Fields hold MetricProto enum ints (see
    MetricMixin.determine_delta_color_and_direction), or None while
    still undecided.
    """

    # MetricProto.MetricColor value (GREEN/RED/GRAY), or None.
    color: Optional[int]
    # MetricProto.MetricDirection value (UP/DOWN/NONE), or None.
    direction: Optional[int]
class MetricMixin:
def metric(self, label, value, delta=None, delta_color="normal"):
"""Display a metric in big bold font, with an optional indicator of how the metric changed.
Tip: If you want to display a large number, it may be a good idea to
shorten it using packages like `millify <https://github.com/azaitsev/millify>`_
or `numerize <https://github.com/davidsa03/numerize>`_. E.g. ``1234`` can be
displayed as ``1.2k`` using ``st.metric("Short number", millify(1234))``.
Parameters
----------
label : str
The header or Title for the metric
value : int, float, str, or None
Value of the metric. None is rendered as a long dash.
delta : int, float, str, or None
Indicator of how the metric changed, rendered with an arrow below
the metric. If delta is negative (int/float) or starts with a minus
sign (str), the arrow points down and the text is red; else the
arrow points up and the text is green. If None (default), no delta
indicator is shown.
delta_color : str
If "normal" (default), the delta indicator is shown as described
above. If "inverse", it is red when positive and green when
negative. This is useful when a negative change is considered
good, e.g. if cost decreased. If "off", delta is shown in gray
regardless of its value.
Example
-------
>>> st.metric(label="Temperature", value="70 °F", delta="1.2 °F")
.. output::
https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/metric.example1.py
height: 210px
``st.metric`` looks especially nice in combination with ``st.columns``:
>>> col1, col2, col3 = st.columns(3)
>>> col1.metric("Temperature", "70 °F", "1.2 °F")
>>> col2.metric("Wind", "9 mph", "-8%")
>>> col3.metric("Humidity", "86%", "4%")
.. output::
https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/metric.example2.py
height: 210px
The delta indicator color can also be inverted or turned off:
>>> st.metric(label="Gas price", value=4, delta=-0.5,
... delta_color="inverse")
>>>
>>> st.metric(label="Active developers", value=123, delta=123,
... delta_color="off")
.. output::
https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/metric.example3.py
height: 320px
"""
metric_proto = MetricProto()
metric_proto.body = self.parse_value(value)
metric_proto.label = self.parse_label(label)
metric_proto.delta = self.parse_delta(delta)
color_and_direction = self.determine_delta_color_and_direction(
clean_text(delta_color), delta
)
metric_proto.color = color_and_direction.color
metric_proto.direction = color_and_direction.direction
return str(self.dg._enqueue("metric", metric_proto))
def parse_label(self, label):
if not isinstance(label, str):
raise TypeError(
f"'{str(label)}' is of type {str(type(label))}, which is not an accepted type."
" label only accepts: str. Please convert the label to an accepted type."
)
return label
def parse_value(self, value):
if value is None:
return ""
if isinstance(value, float) or isinstance(value, int) or isinstance(value, str):
return str(value)
elif hasattr(value, "item"):
# Add support for numpy values (e.g. int16, float64, etc.)
try:
# Item could also be just a variable, so we use try, except
if isinstance(value.item(), float) or isinstance(value.item(), int):
return str(value.item())
except Exception:
pass
raise TypeError(
f"'{str(value)}' is of type {str(type(value))}, which is not an accepted type."
" value only accepts: int, float, str, or None."
" Please convert the value to an accepted type."
)
def parse_delta(self, delta):
if delta is None or delta == "":
return ""
if isinstance(delta, str):
return dedent(delta)
elif isinstance(delta, int) or isinstance(delta, float):
return str(delta)
else:
raise TypeError(
f"'{str(delta)}' is of type {str(type(delta))}, which is not an accepted type."
" delta only accepts: int, float, str, or None."
" Please convert the value to an accepted type."
)
def determine_delta_color_and_direction(self, delta_color, delta):
cd = MetricColorAndDirection(color=None, direction=None)
if delta is None or delta == "":
cd.color = MetricProto.MetricColor.GRAY
cd.direction = MetricProto.MetricDirection.NONE
return cd
if self.is_negative(delta):
if delta_color == "normal":
cd.color = MetricProto.MetricColor.RED
elif delta_color == "inverse":
cd.color = MetricProto.MetricColor.GREEN
elif delta_color == "off":
cd.color = MetricProto.MetricColor.GRAY
cd.direction = MetricProto.MetricDirection.DOWN
else:
if delta_color == "normal":
cd.color = MetricProto.MetricColor.GREEN
elif delta_color == "inverse":
cd.color = MetricProto.MetricColor.RED
elif delta_color == "off":
cd.color = MetricProto.MetricColor.GRAY
cd.direction = MetricProto.MetricDirection.UP
if cd.color is None or cd.direction is None:
raise StreamlitAPIException(
f"'{str(delta_color)}' is not an accepted value. delta_color only accepts: "
"'normal', 'inverse', or 'off'"
)
return cd
def is_negative(self, delta):
return dedent(str(delta)).startswith("-")
    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)

View File

@ -0,0 +1,218 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from textwrap import dedent
from typing import Any, Callable, Optional, cast, List
import streamlit
from streamlit.errors import StreamlitAPIException
from streamlit.proto.MultiSelect_pb2 import MultiSelect as MultiSelectProto
from streamlit.scriptrunner import ScriptRunContext, get_script_run_ctx
from streamlit.type_util import Key, OptionSequence, ensure_indexable, is_type, to_key
from streamlit.state import (
register_widget,
WidgetArgs,
WidgetCallback,
WidgetKwargs,
)
from .form import current_form_id
from .utils import check_callback_rules, check_session_state_rules
class MultiSelectMixin:
    def multiselect(
        self,
        label: str,
        options: OptionSequence,
        default: Optional[Any] = None,
        format_func: Callable[[Any], Any] = str,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        disabled: bool = False,
    ) -> List[Any]:
        """Display a multiselect widget.
        The multiselect widget starts as empty.
        Parameters
        ----------
        label : str
            A short label explaining to the user what this select widget is for.
        options : Sequence[V], numpy.ndarray, pandas.Series, pandas.DataFrame, or pandas.Index
            Labels for the select options. This will be cast to str internally
            by default. For pandas.DataFrame, the first column is selected.
        default: [V], V, or None
            List of default values. Can also be a single value.
        format_func : function
            Function to modify the display of selectbox options. It receives
            the raw option as an argument and should output the label to be
            shown for that option. This has no impact on the return value of
            the multiselect.
        key : str or int
            An optional string or integer to use as the unique key for the widget.
            If this is omitted, a key will be generated for the widget
            based on its content. Multiple widgets of the same type may
            not share the same key.
        help : str
            An optional tooltip that gets displayed next to the multiselect.
        on_change : callable
            An optional callback invoked when this multiselect's value changes.
        args : tuple
            An optional tuple of args to pass to the callback.
        kwargs : dict
            An optional dict of kwargs to pass to the callback.
        disabled : bool
            An optional boolean, which disables the multiselect widget if set
            to True. The default is False. This argument can only be supplied
            by keyword.
        Returns
        -------
        list
            A list with the selected options
        Example
        -------
        >>> options = st.multiselect(
        ...     'What are your favorite colors',
        ...     ['Green', 'Yellow', 'Red', 'Blue'],
        ...     ['Yellow', 'Red'])
        >>>
        >>> st.write('You selected:', options)
        .. output::
           https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/widget.multiselect.py
           height: 420px
        .. note::
           User experience can be degraded for large lists of `options` (100+), as this widget
           is not designed to handle arbitrary text search efficiently. See this
           `thread <https://discuss.streamlit.io/t/streamlit-loading-column-data-takes-too-much-time/1791>`_
           on the Streamlit community forum for more information and
           `GitHub issue #1059 <https://github.com/streamlit/streamlit/issues/1059>`_ for updates on the issue.
        """
        # Resolve the script-run context in the public entry point so the
        # internal implementation can take it as an injectable parameter.
        ctx = get_script_run_ctx()
        return self._multiselect(
            label=label,
            options=options,
            default=default,
            format_func=format_func,
            key=key,
            help=help,
            on_change=on_change,
            args=args,
            kwargs=kwargs,
            disabled=disabled,
            ctx=ctx,
        )
    def _multiselect(
        self,
        label: str,
        options: OptionSequence,
        default: Optional[Any] = None,
        format_func: Callable[[Any], Any] = str,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        disabled: bool = False,
        ctx: Optional[ScriptRunContext] = None,
    ) -> List[Any]:
        """Internal implementation of multiselect; see ``multiselect`` for docs."""
        key = to_key(key)
        check_callback_rules(self.dg, on_change)
        check_session_state_rules(default_value=default, key=key)
        # Normalize options into an indexable sequence (handles ndarray,
        # Series, DataFrame, Index, etc.).
        opt = ensure_indexable(options)
        # Perform validation checks and return indices based on the default values.
        def _check_and_convert_to_indices(opt, default_values):
            # None means "no default" — unless None is itself a valid option.
            if default_values is None and None not in opt:
                return None
            if not isinstance(default_values, list):
                # This if is done before others because calling if not x (done
                # right below) when x is of type pd.Series() or np.array() throws a
                # ValueError exception.
                if is_type(default_values, "numpy.ndarray") or is_type(
                    default_values, "pandas.core.series.Series"
                ):
                    default_values = list(default_values)
                elif not default_values or default_values in opt:
                    # A single scalar default (including falsy scalars like ""
                    # or 0) is wrapped into a one-element list.
                    default_values = [default_values]
                else:
                    default_values = list(default_values)
            for value in default_values:
                if value not in opt:
                    raise StreamlitAPIException(
                        "Every Multiselect default value must exist in options"
                    )
            return [opt.index(value) for value in default_values]
        indices = _check_and_convert_to_indices(opt, default)
        multiselect_proto = MultiSelectProto()
        multiselect_proto.label = label
        default_value = [] if indices is None else indices
        multiselect_proto.default[:] = default_value
        # Options are displayed through format_func but always stored as str.
        multiselect_proto.options[:] = [str(format_func(option)) for option in opt]
        multiselect_proto.form_id = current_form_id(self.dg)
        if help is not None:
            multiselect_proto.help = dedent(help)
        def deserialize_multiselect(
            ui_value: Optional[List[int]], widget_id: str = ""
        ) -> List[str]:
            # The frontend reports selected option *indices*; map them back
            # to the option values. Fall back to the default when unset.
            current_value = ui_value if ui_value is not None else default_value
            return [opt[i] for i in current_value]
        def serialize_multiselect(value):
            # Inverse of deserialize: option values -> indices.
            return _check_and_convert_to_indices(opt, value)
        current_value, set_frontend_value = register_widget(
            "multiselect",
            multiselect_proto,
            user_key=key,
            on_change_handler=on_change,
            args=args,
            kwargs=kwargs,
            deserializer=deserialize_multiselect,
            serializer=serialize_multiselect,
            ctx=ctx,
        )
        # This needs to be done after register_widget because we don't want
        # the following proto fields to affect a widget's ID.
        multiselect_proto.disabled = disabled
        if set_frontend_value:
            multiselect_proto.value[:] = _check_and_convert_to_indices(
                opt, current_value
            )
            multiselect_proto.set_value = True
        self.dg._enqueue("multiselect", multiselect_proto)
        return cast(List[str], current_value)
    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)

View File

@ -0,0 +1,306 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
from streamlit.scriptrunner import ScriptRunContext, get_script_run_ctx
from streamlit.type_util import Key, to_key
from textwrap import dedent
from typing import Optional, Union, cast
import streamlit
from streamlit.errors import StreamlitAPIException
from streamlit.js_number import JSNumber, JSNumberBoundsException
from streamlit.proto.NumberInput_pb2 import NumberInput as NumberInputProto
from streamlit.state import (
register_widget,
NoValue,
WidgetArgs,
WidgetCallback,
WidgetKwargs,
)
from .form import current_form_id
from .utils import check_callback_rules, check_session_state_rules
Number = Union[int, float]
class NumberInputMixin:
    def number_input(
        self,
        label: str,
        min_value: Optional[Number] = None,
        max_value: Optional[Number] = None,
        value: Union[NoValue, Number, None] = NoValue(),
        step: Optional[Number] = None,
        format: Optional[str] = None,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        disabled: bool = False,
    ) -> Number:
        """Display a numeric input widget.
        Parameters
        ----------
        label : str
            A short label explaining to the user what this input is for.
        min_value : int or float or None
            The minimum permitted value.
            If None, there will be no minimum.
        max_value : int or float or None
            The maximum permitted value.
            If None, there will be no maximum.
        value : int or float or None
            The value of this widget when it first renders.
            Defaults to min_value, or 0.0 if min_value is None
        step : int or float or None
            The stepping interval.
            Defaults to 1 if the value is an int, 0.01 otherwise.
            If the value is not specified, the format parameter will be used.
        format : str or None
            A printf-style format string controlling how the interface should
            display numbers. Output must be purely numeric. This does not impact
            the return value. Valid formatters: %d %e %f %g %i %u
        key : str or int
            An optional string or integer to use as the unique key for the widget.
            If this is omitted, a key will be generated for the widget
            based on its content. Multiple widgets of the same type may
            not share the same key.
        help : str
            An optional tooltip that gets displayed next to the input.
        on_change : callable
            An optional callback invoked when this number_input's value changes.
        args : tuple
            An optional tuple of args to pass to the callback.
        kwargs : dict
            An optional dict of kwargs to pass to the callback.
        disabled : bool
            An optional boolean, which disables the number input if set to
            True. The default is False. This argument can only be supplied by
            keyword.
        Returns
        -------
        int or float
            The current value of the numeric input widget. The return type
            will match the data type of the value parameter.
        Example
        -------
        >>> number = st.number_input('Insert a number')
        >>> st.write('The current number is ', number)
        .. output::
           https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/widget.number_input.py
           height: 260px
        """
        # Resolve the script-run context in the public entry point so the
        # internal implementation can take it as an injectable parameter.
        ctx = get_script_run_ctx()
        return self._number_input(
            label=label,
            min_value=min_value,
            max_value=max_value,
            value=value,
            step=step,
            format=format,
            key=key,
            help=help,
            on_change=on_change,
            args=args,
            kwargs=kwargs,
            disabled=disabled,
            ctx=ctx,
        )
    def _number_input(
        self,
        label: str,
        min_value: Optional[Number] = None,
        max_value: Optional[Number] = None,
        value: Union[NoValue, Number, None] = NoValue(),
        step: Optional[Number] = None,
        format: Optional[str] = None,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        disabled: bool = False,
        ctx: Optional[ScriptRunContext] = None,
    ) -> Number:
        """Internal implementation of number_input; see ``number_input`` for docs."""
        key = to_key(key)
        check_callback_rules(self.dg, on_change)
        check_session_state_rules(
            default_value=None if isinstance(value, NoValue) else value, key=key
        )
        # Ensure that all arguments are of the same type.
        number_input_args = [min_value, max_value, value, step]
        int_args = all(
            isinstance(a, (numbers.Integral, type(None), NoValue))
            for a in number_input_args
        )
        float_args = all(
            isinstance(a, (float, type(None), NoValue)) for a in number_input_args
        )
        if not int_args and not float_args:
            raise StreamlitAPIException(
                "All numerical arguments must be of the same type."
                f"\n`value` has {type(value).__name__} type."
                f"\n`min_value` has {type(min_value).__name__} type."
                f"\n`max_value` has {type(max_value).__name__} type."
                f"\n`step` has {type(step).__name__} type."
            )
        if isinstance(value, NoValue):
            if min_value is not None:
                value = min_value
            elif int_args and float_args:
                value = 0.0  # if no values are provided, defaults to float
            elif int_args:
                value = 0
            else:
                value = 0.0
        int_value = isinstance(value, numbers.Integral)
        float_value = isinstance(value, float)
        if value is None:
            raise StreamlitAPIException(
                "Default value for number_input should be an int or a float."
            )
        else:
            if format is None:
                format = "%d" if int_value else "%0.2f"
            # Warn user if they format an int type as a float or vice versa.
            if format in ["%d", "%u", "%i"] and float_value:
                import streamlit as st
                st.warning(
                    "Warning: NumberInput value below has type float,"
                    f" but format {format} displays as integer."
                )
            elif format[-1] == "f" and int_value:
                import streamlit as st
                st.warning(
                    "Warning: NumberInput value below has type int so is"
                    f" displayed as int despite format string {format}."
                )
        if step is None:
            step = 1 if int_value else 0.01
        # Reject format strings that can't render a number.
        try:
            float(format % 2)
        except (TypeError, ValueError):
            raise StreamlitAPIException(
                "Format string for st.number_input contains invalid characters: %s"
                % format
            )
        # Ensure that the value matches arguments' types.
        all_ints = int_value and int_args
        # BUGFIX: compare bounds against None explicitly. A bound of 0 (or
        # 0.0) is falsy, so the previous truthiness test silently skipped the
        # range check whenever min_value or max_value was zero.
        if (min_value is not None and min_value > value) or (
            max_value is not None and max_value < value
        ):
            raise StreamlitAPIException(
                "The default `value` of %(value)s "
                "must lie between the `min_value` of %(min)s "
                "and the `max_value` of %(max)s, inclusively."
                % {"value": value, "min": min_value, "max": max_value}
            )
        # Bounds checks. JSNumber produces human-readable exceptions that
        # we simply re-package as StreamlitAPIExceptions.
        try:
            if all_ints:
                if min_value is not None:
                    JSNumber.validate_int_bounds(min_value, "`min_value`")  # type: ignore
                if max_value is not None:
                    JSNumber.validate_int_bounds(max_value, "`max_value`")  # type: ignore
                if step is not None:
                    JSNumber.validate_int_bounds(step, "`step`")  # type: ignore
                JSNumber.validate_int_bounds(value, "`value`")  # type: ignore
            else:
                if min_value is not None:
                    JSNumber.validate_float_bounds(min_value, "`min_value`")
                if max_value is not None:
                    JSNumber.validate_float_bounds(max_value, "`max_value`")
                if step is not None:
                    JSNumber.validate_float_bounds(step, "`step`")
                JSNumber.validate_float_bounds(value, "`value`")
        except JSNumberBoundsException as e:
            raise StreamlitAPIException(str(e)) from e
        number_input_proto = NumberInputProto()
        number_input_proto.data_type = (
            NumberInputProto.INT if all_ints else NumberInputProto.FLOAT
        )
        number_input_proto.label = label
        number_input_proto.default = value
        number_input_proto.form_id = current_form_id(self.dg)
        if help is not None:
            number_input_proto.help = dedent(help)
        if min_value is not None:
            number_input_proto.min = min_value
            number_input_proto.has_min = True
        if max_value is not None:
            number_input_proto.max = max_value
            number_input_proto.has_max = True
        if step is not None:
            number_input_proto.step = step
        if format is not None:
            number_input_proto.format = format
        def deserialize_number_input(ui_value, widget_id=""):
            # Frontend value wins; otherwise fall back to the default.
            return ui_value if ui_value is not None else value
        current_value, set_frontend_value = register_widget(
            "number_input",
            number_input_proto,
            user_key=key,
            on_change_handler=on_change,
            args=args,
            kwargs=kwargs,
            deserializer=deserialize_number_input,
            serializer=lambda x: x,
            ctx=ctx,
        )
        # This needs to be done after register_widget because we don't want
        # the following proto fields to affect a widget's ID.
        number_input_proto.disabled = disabled
        if set_frontend_value:
            number_input_proto.value = current_value
            number_input_proto.set_value = True
        self.dg._enqueue("number_input", number_input_proto)
        return cast(Number, current_value)
    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)

View File

@ -0,0 +1,193 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Streamlit support for Plotly charts."""
import json
import urllib.parse
from typing import cast
import streamlit
from streamlit.legacy_caching import caching
from streamlit import type_util
from streamlit.logger import get_logger
from streamlit.proto.PlotlyChart_pb2 import PlotlyChart as PlotlyChartProto
LOGGER = get_logger(__name__)
SHARING_MODES = {
# This means the plot will be sent to the Streamlit app rather than to
# Plotly.
"streamlit",
# The three modes below are for plots that should be hosted in Plotly.
# These are the names Plotly uses for them.
"private",
"public",
"secret",
}
class PlotlyMixin:
    def plotly_chart(
        self,
        figure_or_data,
        use_container_width=False,
        sharing="streamlit",
        **kwargs,
    ):
        """Display an interactive Plotly chart.
        Plotly is a charting library for Python. The arguments to this function
        closely follow the ones for Plotly's `plot()` function. You can find
        more about Plotly at https://plot.ly/python.
        To show Plotly charts in Streamlit, call `st.plotly_chart` wherever you
        would call Plotly's `py.plot` or `py.iplot`.
        Parameters
        ----------
        figure_or_data : plotly.graph_objs.Figure, plotly.graph_objs.Data,
            dict/list of plotly.graph_objs.Figure/Data
            See https://plot.ly/python/ for examples of graph descriptions.
        use_container_width : bool
            If True, set the chart width to the column width. This takes
            precedence over the figure's native `width` value.
        sharing : {'streamlit', 'private', 'secret', 'public'}
            Use 'streamlit' to insert the plot and all its dependencies
            directly in the Streamlit app using plotly's offline mode (default).
            Use any other sharing mode to send the chart to Plotly chart studio, which
            requires an account. See https://plotly.com/chart-studio/ for more information.
        **kwargs
            Any argument accepted by Plotly's `plot()` function.
        Example
        -------
        The example below comes straight from the examples at
        https://plot.ly/python:
        >>> import streamlit as st
        >>> import plotly.figure_factory as ff
        >>> import numpy as np
        >>>
        >>> # Add histogram data
        >>> x1 = np.random.randn(200) - 2
        >>> x2 = np.random.randn(200)
        >>> x3 = np.random.randn(200) + 2
        >>>
        >>> # Group data together
        >>> hist_data = [x1, x2, x3]
        >>>
        >>> group_labels = ['Group 1', 'Group 2', 'Group 3']
        >>>
        >>> # Create distplot with custom bin_size
        >>> fig = ff.create_distplot(
        ...         hist_data, group_labels, bin_size=[.1, .25, .5])
        >>>
        >>> # Plot!
        >>> st.plotly_chart(fig, use_container_width=True)
        .. output::
           https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/charts.plotly_chart.py
           height: 400px
        """
        # "figure_or_data" mirrors the parameter name of Plotly's own .plot()
        # method so the two APIs stay easy to compare side by side.
        chart_proto = PlotlyChartProto()
        marshall(chart_proto, figure_or_data, use_container_width, sharing, **kwargs)
        return self.dg._enqueue("plotly_chart", chart_proto)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)
def marshall(proto, figure_or_data, use_container_width, sharing, **kwargs):
    """Marshall a proto with a Plotly spec.
    See DeltaGenerator.plotly_chart for docs.
    """
    # NOTE: "figure_or_data" is the name used in Plotly's .plot() method
    # for their main parameter. I don't like the name, but it's best to keep
    # it in sync with what Plotly calls it.
    import plotly.tools
    # Matplotlib figures are first converted into Plotly figures; anything
    # else is normalized (with validation) by Plotly's own helper.
    if type_util.is_type(figure_or_data, "matplotlib.figure.Figure"):
        figure = plotly.tools.mpl_to_plotly(figure_or_data)
    else:
        figure = plotly.tools.return_figure_from_figure_or_data(
            figure_or_data, validate_figure=True
        )
    # NOTE(review): validation is case-insensitive (sharing.lower()) but the
    # "streamlit" branch below compares case-sensitively, so e.g. "Streamlit"
    # passes validation yet falls into the Plotly-hosted path — confirm intended.
    if not isinstance(sharing, str) or sharing.lower() not in SHARING_MODES:
        raise ValueError("Invalid sharing mode for Plotly chart: %s" % sharing)
    proto.use_container_width = use_container_width
    if sharing == "streamlit":
        import plotly.utils
        config = dict(kwargs.get("config", {}))
        # Copy over some kwargs to config dict. Plotly does the same in plot().
        config.setdefault("showLink", kwargs.get("show_link", False))
        config.setdefault("linkText", kwargs.get("link_text", False))
        proto.figure.spec = json.dumps(figure, cls=plotly.utils.PlotlyJSONEncoder)
        proto.figure.config = json.dumps(config)
    else:
        # Chart Studio-hosted mode: upload the figure (or reuse the cached
        # upload) and embed it by URL.
        url = _plot_to_url_or_load_cached_url(
            figure, sharing=sharing, auto_open=False, **kwargs
        )
        proto.url = _get_embed_url(url)
@caching.cache
def _plot_to_url_or_load_cached_url(*args, **kwargs):
    """Call plotly.plot wrapped in st.cache.
    This is so we don't unnecessarily upload data to Plotly's SaaS if nothing
    changed since the previous upload.
    """
    try:
        # Plotly 4 changed its main package.
        import chart_studio.plotly as ply
    except ImportError:
        import plotly.plotly as ply
    return ply.plot(*args, **kwargs)
def _get_embed_url(url):
parsed_url = urllib.parse.urlparse(url)
# Plotly's embed URL is the normal URL plus ".embed".
# (Note that our use namedtuple._replace is fine because that's not a
# private method! It just has an underscore to avoid clashing with the
# tuple field names)
parsed_embed_url = parsed_url._replace(path=parsed_url.path + ".embed")
return urllib.parse.urlunparse(parsed_embed_url)

View File

@ -0,0 +1,74 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast
import streamlit
from streamlit.errors import StreamlitAPIException
from streamlit.proto.Progress_pb2 import Progress as ProgressProto
class ProgressMixin:
    def progress(self, value):
        """Display a progress bar.
        Parameters
        ----------
        value : int or float
            0 <= value <= 100 for int
            0.0 <= value <= 1.0 for float
        Example
        -------
        Here is an example of a progress bar increasing over time:
        >>> import time
        >>>
        >>> my_bar = st.progress(0)
        >>>
        >>> for percent_complete in range(100):
        ...     time.sleep(0.1)
        ...     my_bar.progress(percent_complete + 1)
        """
        # TODO: standardize numerical type checking across st.* functions.
        progress_proto = ProgressProto()
        # Floats are interpreted as a fraction in [0.0, 1.0]; ints as a
        # percentage in [0, 100]. The proto always stores a percentage.
        if isinstance(value, float):
            if not (0.0 <= value <= 1.0):
                raise StreamlitAPIException(
                    "Progress Value has invalid value [0.0, 1.0]: %f" % value
                )
            progress_proto.value = int(value * 100)
        elif isinstance(value, int):
            if not (0 <= value <= 100):
                raise StreamlitAPIException(
                    "Progress Value has invalid value [0, 100]: %d" % value
                )
            progress_proto.value = value
        else:
            raise StreamlitAPIException(
                "Progress Value has invalid type: %s" % type(value).__name__
            )
        return self.dg._enqueue("progress", progress_proto)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)

View File

@ -0,0 +1,170 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Streamlit support for Matplotlib PyPlot charts."""
import io
from typing import cast
import streamlit
import streamlit.elements.image as image_utils
from streamlit import config
from streamlit.errors import StreamlitDeprecationWarning
from streamlit.logger import get_logger
from streamlit.proto.Image_pb2 import ImageList as ImageListProto
LOGGER = get_logger(__name__)
class PyplotMixin:
    def pyplot(self, fig=None, clear_figure=None, **kwargs):
        """Display a matplotlib.pyplot figure.
        Parameters
        ----------
        fig : Matplotlib Figure
            The figure to plot. When this argument isn't specified, this
            function will render the global figure (but this is deprecated,
            as described below)
        clear_figure : bool
            If True, the figure will be cleared after being rendered.
            If False, the figure will not be cleared after being rendered.
            If left unspecified, we pick a default based on the value of `fig`.
            * If `fig` is set, defaults to `False`.
            * If `fig` is not set, defaults to `True`. This simulates Jupyter's
              approach to matplotlib rendering.
        **kwargs : any
            Arguments to pass to Matplotlib's savefig function.
        Example
        -------
        >>> import matplotlib.pyplot as plt
        >>> import numpy as np
        >>>
        >>> arr = np.random.normal(1, 1, size=100)
        >>> fig, ax = plt.subplots()
        >>> ax.hist(arr, bins=20)
        >>>
        >>> st.pyplot(fig)
        .. output::
           https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/charts.pyplot.py
           height: 630px
        Notes
        -----
        .. note::
           Deprecation warning. After December 1st, 2020, we will remove the ability
           to specify no arguments in `st.pyplot()`, as that requires the use of
           Matplotlib's global figure object, which is not thread-safe. So
           please always pass a figure object as shown in the example section
           above.
        Matplotlib support several different types of "backends". If you're
        getting an error using Matplotlib with Streamlit, try setting your
        backend to "TkAgg"::
            echo "backend: TkAgg" >> ~/.matplotlib/matplotlibrc
        For more information, see https://matplotlib.org/faq/usage_faq.html.
        """
        # Calling without a figure relies on Matplotlib's global figure,
        # which is deprecated here; surface the warning unless disabled.
        if not fig and config.get_option("deprecation.showPyplotGlobalUse"):
            self.dg.exception(PyplotGlobalUseWarning())
        proto = ImageListProto()
        marshall(self.dg._get_delta_path_str(), proto, fig, clear_figure, **kwargs)
        return self.dg._enqueue("imgs", proto)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)
def marshall(coordinates, image_list_proto, fig=None, clear_figure=True, **kwargs):
    """Render *fig* (or the global pyplot figure) to PNG and marshall it
    into *image_list_proto*."""
    try:
        import matplotlib
        import matplotlib.pyplot as plt

        # Stay in non-interactive mode so rendering never pops up a window.
        plt.ioff()
    except ImportError:
        raise ImportError("pyplot() command requires matplotlib")
    # .savefig() works both on a Figure object and on the pyplot module
    # itself, in which case it targets the most recent global figure.
    if not fig:
        if clear_figure is None:
            # Global-figure mode clears by default, mimicking Jupyter.
            clear_figure = True
        fig = plt
    # dpi normally defaults to 'figure' (whose dpi is 100); double it so
    # plots look crisp on high-DPI displays. Caller-supplied kwargs take
    # precedence over these defaults.
    defaults = {"bbox_inches": "tight", "dpi": 200, "format": "png"}
    for option, fallback in defaults.items():
        kwargs[option] = kwargs.get(option, fallback)
    image = io.BytesIO()
    fig.savefig(image, **kwargs)
    image_utils.marshall_images(
        coordinates,
        image,
        None,
        -2,
        image_list_proto,
        False,
        channels="RGB",
        output_format="PNG",
    )
    # Clear the figure after rendering it so that subsequent plt calls
    # start fresh.
    if clear_figure:
        fig.clf()
class PyplotGlobalUseWarning(StreamlitDeprecationWarning):
    """Deprecation warning shown when ``st.pyplot()`` is called with no figure,
    i.e. when it falls back to Matplotlib's (non-thread-safe) global figure."""
    def __init__(self):
        super(PyplotGlobalUseWarning, self).__init__(
            msg=self._get_message(), config_option="deprecation.showPyplotGlobalUse"
        )
    def _get_message(self):
        # User-facing markdown rendered in the app.
        return """
You are calling `st.pyplot()` without any arguments. After December 1st, 2020,
we will remove the ability to do this as it requires the use of Matplotlib's global
figure object, which is not thread-safe.
To future-proof this code, you should pass in a figure as shown below:
```python
>>> fig, ax = plt.subplots()
>>> ax.scatter([1, 2, 3], [1, 2, 3])
>>> ... other plotting actions ...
>>> st.pyplot(fig)
```
"""

View File

@ -0,0 +1,193 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from textwrap import dedent
from typing import Any, Callable, Optional, cast
import streamlit
from streamlit.errors import StreamlitAPIException
from streamlit.proto.Radio_pb2 import Radio as RadioProto
from streamlit.scriptrunner import ScriptRunContext, get_script_run_ctx
from streamlit.state import (
register_widget,
WidgetArgs,
WidgetCallback,
WidgetKwargs,
)
from streamlit.type_util import Key, OptionSequence, ensure_indexable, to_key
from streamlit.util import index_
from .form import current_form_id
from .utils import check_callback_rules, check_session_state_rules
class RadioMixin:
    def radio(
        self,
        label: str,
        options: OptionSequence,
        index: int = 0,
        format_func: Callable[[Any], Any] = str,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only args:
        disabled: bool = False,
    ) -> Any:
        """Display a radio button widget.

        Parameters
        ----------
        label : str
            A short label explaining to the user what this radio group is for.
        options : Sequence, numpy.ndarray, pandas.Series, pandas.DataFrame, or pandas.Index
            Labels for the radio options. This will be cast to str internally
            by default. For pandas.DataFrame, the first column is selected.
        index : int
            The index of the preselected option on first render.
        format_func : function
            Function to modify the display of radio options. It receives
            the raw option as an argument and should output the label to be
            shown for that option. This has no impact on the return value of
            the radio.
        key : str or int
            An optional string or integer to use as the unique key for the widget.
            If this is omitted, a key will be generated for the widget
            based on its content. Multiple widgets of the same type may
            not share the same key.
        help : str
            An optional tooltip that gets displayed next to the radio.
        on_change : callable
            An optional callback invoked when this radio's value changes.
        args : tuple
            An optional tuple of args to pass to the callback.
        kwargs : dict
            An optional dict of kwargs to pass to the callback.
        disabled : bool
            An optional boolean, which disables the radio button if set to
            True. The default is False. This argument can only be supplied by
            keyword.

        Returns
        -------
        any
            The selected option.

        Example
        -------
        >>> genre = st.radio(
        ...     "What\'s your favorite movie genre",
        ...     ('Comedy', 'Drama', 'Documentary'))
        >>>
        >>> if genre == 'Comedy':
        ...     st.write('You selected comedy.')
        ... else:
        ...     st.write("You didn\'t select comedy.")

        .. output::
           https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/widget.radio.py
           height: 260px
        """
        ctx = get_script_run_ctx()
        return self._radio(
            label=label,
            options=options,
            index=index,
            format_func=format_func,
            key=key,
            help=help,
            on_change=on_change,
            args=args,
            kwargs=kwargs,
            disabled=disabled,
            ctx=ctx,
        )

    def _radio(
        self,
        label: str,
        options: OptionSequence,
        index: int = 0,
        format_func: Callable[[Any], Any] = str,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only args:
        disabled: bool = False,
        ctx: Optional[ScriptRunContext],
    ) -> Any:
        """Internal implementation of st.radio; see the public docstring above."""
        key = to_key(key)
        check_callback_rules(self.dg, on_change)
        check_session_state_rules(default_value=None if index == 0 else index, key=key)

        # Normalize options (Sequence / ndarray / Series / DataFrame / Index)
        # into one indexable sequence; all later logic works on `opt`.
        opt = ensure_indexable(options)

        if not isinstance(index, int):
            raise StreamlitAPIException(
                "Radio Value has invalid type: %s" % type(index).__name__
            )
        if len(opt) > 0 and not 0 <= index < len(opt):
            raise StreamlitAPIException(
                "Radio index must be between 0 and length of options"
            )

        radio_proto = RadioProto()
        radio_proto.label = label
        radio_proto.default = index
        radio_proto.options[:] = [str(format_func(option)) for option in opt]
        radio_proto.form_id = current_form_id(self.dg)
        if help is not None:
            radio_proto.help = dedent(help)

        def deserialize_radio(ui_value, widget_id=""):
            # Fall back to the declared default index when the widget has no
            # frontend state yet.
            idx = ui_value if ui_value is not None else index
            return opt[idx] if len(opt) > 0 and opt[idx] is not None else None

        def serialize_radio(v):
            # BUGFIX: operate on the normalized `opt` (not the raw `options`
            # argument), mirroring deserialize_radio and the selectbox
            # serializer. For DataFrame/Series inputs, len()/index_() on the
            # raw argument do not line up with the displayed option list.
            if len(opt) == 0:
                return 0
            return index_(opt, v)

        current_value, set_frontend_value = register_widget(
            "radio",
            radio_proto,
            user_key=key,
            on_change_handler=on_change,
            args=args,
            kwargs=kwargs,
            deserializer=deserialize_radio,
            serializer=serialize_radio,
            ctx=ctx,
        )

        # This needs to be done after register_widget because we don't want
        # the following proto fields to affect a widget's ID.
        radio_proto.disabled = disabled
        if set_frontend_value:
            radio_proto.value = serialize_radio(current_value)
            radio_proto.set_value = True

        self.dg._enqueue("radio", radio_proto)
        return cast(str, current_value)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)

View File

@ -0,0 +1,234 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from textwrap import dedent
from typing import Any, Callable, Optional, cast
import streamlit
from streamlit.errors import StreamlitAPIException
from streamlit.proto.Slider_pb2 import Slider as SliderProto
from streamlit.scriptrunner import ScriptRunContext, get_script_run_ctx
from streamlit.state import (
register_widget,
WidgetArgs,
WidgetCallback,
WidgetKwargs,
)
from streamlit.type_util import Key, OptionSequence, ensure_indexable, to_key
from streamlit.util import index_
from .form import current_form_id
from .utils import check_callback_rules, check_session_state_rules
class SelectSliderMixin:
    def select_slider(
        self,
        label: str,
        options: OptionSequence = (),
        value: Any = None,
        format_func: Callable[[Any], Any] = str,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        disabled: bool = False,
    ) -> Any:
        """
        Display a slider widget to select items from a list.

        This also allows you to render a range slider by passing a two-element
        tuple or list as the `value`.

        The difference between `st.select_slider` and `st.slider` is that
        `select_slider` accepts any datatype and takes an iterable set of
        options, while `slider` only accepts numerical or date/time data and
        takes a range as input.

        Parameters
        ----------
        label : str
            A short label explaining to the user what this slider is for.
        options : Sequence, numpy.ndarray, pandas.Series, pandas.DataFrame, or pandas.Index
            Labels for the slider options. All options will be cast to str
            internally by default. For pandas.DataFrame, the first column is
            selected.
        value : a supported type or a tuple/list of supported types or None
            The value of the slider when it first renders. If a tuple/list
            of two values is passed here, then a range slider with those lower
            and upper bounds is rendered. For example, if set to `(1, 10)` the
            slider will have a selectable range between 1 and 10.
            Defaults to first option.
        format_func : function
            Function to modify the display of the labels from the options
            argument. It receives the option as an argument and its output
            will be cast to str.
        key : str or int
            An optional string or integer to use as the unique key for the widget.
            If this is omitted, a key will be generated for the widget
            based on its content. Multiple widgets of the same type may
            not share the same key.
        help : str
            An optional tooltip that gets displayed next to the select slider.
        on_change : callable
            An optional callback invoked when this select_slider's value changes.
        args : tuple
            An optional tuple of args to pass to the callback.
        kwargs : dict
            An optional dict of kwargs to pass to the callback.
        disabled : bool
            An optional boolean, which disables the select slider if set to True.
            The default is False. This argument can only be supplied by keyword.

        Returns
        -------
        any value or tuple of any value
            The current value of the slider widget. The return type will match
            the data type of the value parameter.

        Examples
        --------
        >>> color = st.select_slider(
        ...     'Select a color of the rainbow',
        ...     options=['red', 'orange', 'yellow', 'green', 'blue', 'indigo', 'violet'])
        >>> st.write('My favorite color is', color)

        And here's an example of a range select slider:

        >>> start_color, end_color = st.select_slider(
        ...     'Select a range of color wavelength',
        ...     options=['red', 'orange', 'yellow', 'green', 'blue', 'indigo', 'violet'],
        ...     value=('red', 'blue'))
        >>> st.write('You selected wavelengths between', start_color, 'and', end_color)

        .. output::
           https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/widget.select_slider.py
           height: 450px
        """
        ctx = get_script_run_ctx()
        return self._select_slider(
            label=label,
            options=options,
            value=value,
            format_func=format_func,
            key=key,
            help=help,
            on_change=on_change,
            args=args,
            kwargs=kwargs,
            disabled=disabled,
            ctx=ctx,
        )

    def _select_slider(
        self,
        label: str,
        options: OptionSequence = (),
        value: Any = None,
        format_func: Callable[[Any], Any] = str,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments, consistent with the sibling widget mixins:
        disabled: bool = False,
        ctx: Optional[ScriptRunContext] = None,
    ) -> Any:
        """Internal implementation of st.select_slider; see the public docstring above.

        Note: the default for ``options`` is an (immutable) empty tuple rather
        than a mutable ``[]`` list, which would be shared across calls.
        """
        key = to_key(key)
        check_callback_rules(self.dg, on_change)
        check_session_state_rules(default_value=value, key=key)

        opt = ensure_indexable(options)

        if len(opt) == 0:
            raise StreamlitAPIException("The `options` argument needs to be non-empty")

        # Whether the *declared* value is a range; captured by the
        # deserializer below to decide the return shape.
        is_range_value = isinstance(value, (list, tuple))

        def as_index_list(v):
            # Map option value(s) to their indices in `opt`, always as a list.
            if isinstance(v, (list, tuple)):
                slider_value = [index_(opt, val) for val in v]
                start, end = slider_value
                if start > end:
                    slider_value = [end, start]
                return slider_value
            else:
                # Simplify future logic by always making value a list
                try:
                    return [index_(opt, v)]
                except ValueError:
                    # NOTE(review): this intentionally checks the outer
                    # `value` (the declared default), not `v`: when no default
                    # was declared, an unmatchable value silently maps to the
                    # first option instead of raising.
                    if value is not None:
                        raise
                    return [0]

        # Convert element to index of the elements
        slider_value = as_index_list(value)

        slider_proto = SliderProto()
        slider_proto.label = label
        slider_proto.format = "%s"
        slider_proto.default[:] = slider_value
        slider_proto.min = 0
        slider_proto.max = len(opt) - 1
        slider_proto.step = 1  # default for index changes
        slider_proto.data_type = SliderProto.INT
        slider_proto.options[:] = [str(format_func(option)) for option in opt]
        slider_proto.form_id = current_form_id(self.dg)
        if help is not None:
            slider_proto.help = dedent(help)

        def deserialize_select_slider(ui_value, widget_id=""):
            if not ui_value:
                # Widget has not been used; fallback to the original value,
                ui_value = slider_value
            # The widget always returns floats, so convert to ints before indexing
            return_value = list(map(lambda x: opt[int(x)], ui_value))  # type: ignore[no-any-return]
            # If the original value was a list/tuple, so will be the output (and vice versa)
            return tuple(return_value) if is_range_value else return_value[0]

        def serialize_select_slider(v):
            return as_index_list(v)

        current_value, set_frontend_value = register_widget(
            "slider",
            slider_proto,
            user_key=key,
            on_change_handler=on_change,
            args=args,
            kwargs=kwargs,
            deserializer=deserialize_select_slider,
            serializer=serialize_select_slider,
            ctx=ctx,
        )

        # This needs to be done after register_widget because we don't want
        # the following proto fields to affect a widget's ID.
        slider_proto.disabled = disabled
        if set_frontend_value:
            slider_proto.value[:] = serialize_select_slider(current_value)
            slider_proto.set_value = True

        self.dg._enqueue("slider", slider_proto)
        return current_value

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)

View File

@ -0,0 +1,187 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from textwrap import dedent
from typing import Any, Callable, Optional, cast
import streamlit
from streamlit.errors import StreamlitAPIException
from streamlit.proto.Selectbox_pb2 import Selectbox as SelectboxProto
from streamlit.scriptrunner import ScriptRunContext, get_script_run_ctx
from streamlit.state import (
register_widget,
WidgetArgs,
WidgetCallback,
WidgetKwargs,
)
from streamlit.type_util import Key, OptionSequence, ensure_indexable, to_key
from streamlit.util import index_
from .form import current_form_id
from .utils import check_callback_rules, check_session_state_rules
class SelectboxMixin:
    def selectbox(
        self,
        label: str,
        options: OptionSequence,
        index: int = 0,
        format_func: Callable[[Any], Any] = str,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        disabled: bool = False,
    ) -> Any:
        """Display a select widget.

        Parameters
        ----------
        label : str
            A short label explaining to the user what this select widget is for.
        options : Sequence, numpy.ndarray, pandas.Series, pandas.DataFrame, or pandas.Index
            Labels for the select options. This will be cast to str internally
            by default. For pandas.DataFrame, the first column is selected.
        index : int
            The index of the preselected option on first render.
        format_func : function
            Function to modify the display of the labels. It receives the option
            as an argument and its output will be cast to str.
        key : str or int
            An optional string or integer to use as the unique key for the widget.
            If this is omitted, a key will be generated for the widget
            based on its content. Multiple widgets of the same type may
            not share the same key.
        help : str
            An optional tooltip that gets displayed next to the selectbox.
        on_change : callable
            An optional callback invoked when this selectbox's value changes.
        args : tuple
            An optional tuple of args to pass to the callback.
        kwargs : dict
            An optional dict of kwargs to pass to the callback.
        disabled : bool
            An optional boolean, which disables the selectbox if set to True.
            The default is False. This argument can only be supplied by keyword.

        Returns
        -------
        any
            The selected option

        Example
        -------
        >>> option = st.selectbox(
        ...     'How would you like to be contacted?',
        ...     ('Email', 'Home phone', 'Mobile phone'))
        >>>
        >>> st.write('You selected:', option)

        .. output::
           https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/widget.selectbox.py
           height: 320px
        """
        # Delegate to the internal implementation, attaching the current
        # script-run context.
        return self._selectbox(
            label=label,
            options=options,
            index=index,
            format_func=format_func,
            key=key,
            help=help,
            on_change=on_change,
            args=args,
            kwargs=kwargs,
            disabled=disabled,
            ctx=get_script_run_ctx(),
        )

    def _selectbox(
        self,
        label: str,
        options: OptionSequence,
        index: int = 0,
        format_func: Callable[[Any], Any] = str,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        disabled: bool = False,
        ctx: Optional[ScriptRunContext] = None,
    ) -> Any:
        """Internal implementation of st.selectbox; see the public docstring above."""
        key = to_key(key)
        check_callback_rules(self.dg, on_change)
        check_session_state_rules(default_value=None if index == 0 else index, key=key)

        # Normalize the accepted option containers into one indexable sequence.
        indexable_options = ensure_indexable(options)
        num_options = len(indexable_options)

        if not isinstance(index, int):
            raise StreamlitAPIException(
                "Selectbox Value has invalid type: %s" % type(index).__name__
            )
        if num_options > 0 and not 0 <= index < num_options:
            raise StreamlitAPIException(
                "Selectbox index must be between 0 and length of options"
            )

        proto = SelectboxProto()
        proto.label = label
        proto.default = index
        proto.options[:] = [str(format_func(option)) for option in indexable_options]
        proto.form_id = current_form_id(self.dg)
        if help is not None:
            proto.help = dedent(help)

        def deserialize_select_box(ui_value, widget_id=""):
            # No frontend state yet -> fall back to the declared default index.
            idx = index if ui_value is None else ui_value
            if len(indexable_options) > 0 and indexable_options[idx] is not None:
                return indexable_options[idx]
            return None

        def serialize_select_box(v):
            # Map the selected option back to its index (0 when empty).
            return index_(indexable_options, v) if len(indexable_options) > 0 else 0

        current_value, set_frontend_value = register_widget(
            "selectbox",
            proto,
            user_key=key,
            on_change_handler=on_change,
            args=args,
            kwargs=kwargs,
            deserializer=deserialize_select_box,
            serializer=serialize_select_box,
            ctx=ctx,
        )

        # Set only after register_widget so these fields do not feed into the
        # widget's ID.
        proto.disabled = disabled
        if set_frontend_value:
            proto.value = serialize_select_box(current_value)
            proto.set_value = True

        self.dg._enqueue("selectbox", proto)
        return cast(str, current_value)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)

View File

@ -0,0 +1,508 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import date, time, datetime, timedelta, timezone
from streamlit.scriptrunner import ScriptRunContext, get_script_run_ctx
from streamlit.type_util import Key, to_key
from typing import Any, List, cast, Optional
from textwrap import dedent
import streamlit
from streamlit.errors import StreamlitAPIException
from streamlit.js_number import JSNumber
from streamlit.js_number import JSNumberBoundsException
from streamlit.proto.Slider_pb2 import Slider as SliderProto
from streamlit.state import (
register_widget,
WidgetArgs,
WidgetCallback,
WidgetKwargs,
)
from .form import current_form_id
from .utils import check_callback_rules, check_session_state_rules
class SliderMixin:
    def slider(
        self,
        label: str,
        min_value=None,
        max_value=None,
        value=None,
        step=None,
        format=None,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        disabled: bool = False,
    ):
        """Display a slider widget.

        This supports int, float, date, time, and datetime types.

        This also allows you to render a range slider by passing a two-element
        tuple or list as the `value`.

        The difference between `st.slider` and `st.select_slider` is that
        `slider` only accepts numerical or date/time data and takes a range as
        input, while `select_slider` accepts any datatype and takes an iterable
        set of options.

        Parameters
        ----------
        label : str
            A short label explaining to the user what this slider is for.
        min_value : a supported type or None
            The minimum permitted value.
            Defaults to 0 if the value is an int, 0.0 if a float,
            value - timedelta(days=14) if a date/datetime, time.min if a time
        max_value : a supported type or None
            The maximum permitted value.
            Defaults to 100 if the value is an int, 1.0 if a float,
            value + timedelta(days=14) if a date/datetime, time.max if a time
        value : a supported type or a tuple/list of supported types or None
            The value of the slider when it first renders. If a tuple/list
            of two values is passed here, then a range slider with those lower
            and upper bounds is rendered. For example, if set to `(1, 10)` the
            slider will have a selectable range between 1 and 10.
            Defaults to min_value.
        step : int/float/timedelta or None
            The stepping interval.
            Defaults to 1 if the value is an int, 0.01 if a float,
            timedelta(days=1) if a date/datetime, timedelta(minutes=15) if a time
            (or if max_value - min_value < 1 day)
        format : str or None
            A printf-style format string controlling how the interface should
            display numbers. This does not impact the return value.
            Formatter for int/float supports: %d %e %f %g %i
            Formatter for date/time/datetime uses Moment.js notation:
            https://momentjs.com/docs/#/displaying/format/
        key : str or int
            An optional string or integer to use as the unique key for the widget.
            If this is omitted, a key will be generated for the widget
            based on its content. Multiple widgets of the same type may
            not share the same key.
        help : str
            An optional tooltip that gets displayed next to the slider.
        on_change : callable
            An optional callback invoked when this slider's value changes.
        args : tuple
            An optional tuple of args to pass to the callback.
        kwargs : dict
            An optional dict of kwargs to pass to the callback.
        disabled : bool
            An optional boolean, which disables the slider if set to True. The
            default is False. This argument can only be supplied by keyword.

        Returns
        -------
        int/float/date/time/datetime or tuple of int/float/date/time/datetime
            The current value of the slider widget. The return type will match
            the data type of the value parameter.

        Examples
        --------
        >>> age = st.slider('How old are you?', 0, 130, 25)
        >>> st.write("I'm ", age, 'years old')

        And here's an example of a range slider:

        >>> values = st.slider(
        ...     'Select a range of values',
        ...     0.0, 100.0, (25.0, 75.0))
        >>> st.write('Values:', values)

        This is a range time slider:

        >>> from datetime import time
        >>> appointment = st.slider(
        ...     "Schedule your appointment:",
        ...     value=(time(11, 30), time(12, 45)))
        >>> st.write("You're scheduled for:", appointment)

        Finally, a datetime slider:

        >>> from datetime import datetime
        >>> start_time = st.slider(
        ...     "When do you start?",
        ...     value=datetime(2020, 1, 1, 9, 30),
        ...     format="MM/DD/YY - hh:mm")
        >>> st.write("Start time:", start_time)

        .. output::
           https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/widget.slider.py
           height: 300px
        """
        ctx = get_script_run_ctx()
        return self._slider(
            label=label,
            min_value=min_value,
            max_value=max_value,
            value=value,
            step=step,
            format=format,
            key=key,
            help=help,
            on_change=on_change,
            args=args,
            kwargs=kwargs,
            disabled=disabled,
            ctx=ctx,
        )

    def _slider(
        self,
        label: str,
        min_value=None,
        max_value=None,
        value=None,
        step=None,
        format=None,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        disabled: bool = False,
        ctx: Optional[ScriptRunContext] = None,
    ):
        """Internal implementation of st.slider; see the public docstring above.

        Pipeline: validate value/bounds/step types, fill in per-type defaults,
        convert timelike values to UTC-based microsecond integers for the
        frontend, register the widget, and enqueue the proto.
        """
        key = to_key(key)
        check_callback_rules(self.dg, on_change)
        check_session_state_rules(default_value=value, key=key)

        # Set value default.
        if value is None:
            value = min_value if min_value is not None else 0

        # Maps each accepted Python value type to the proto's data_type enum.
        SUPPORTED_TYPES = {
            int: SliderProto.INT,
            float: SliderProto.FLOAT,
            datetime: SliderProto.DATETIME,
            date: SliderProto.DATE,
            time: SliderProto.TIME,
        }
        TIMELIKE_TYPES = (SliderProto.DATETIME, SliderProto.TIME, SliderProto.DATE)

        # Ensure that the value is either a single value or a range of values.
        single_value = isinstance(value, tuple(SUPPORTED_TYPES.keys()))
        range_value = isinstance(value, (list, tuple)) and len(value) in (0, 1, 2)
        if not single_value and not range_value:
            raise StreamlitAPIException(
                "Slider value should either be an int/float/datetime or a list/tuple of "
                "0 to 2 ints/floats/datetimes"
            )

        # Simplify future logic by always making value a list
        if single_value:
            value = [value]

        def all_same_type(items):
            # True for empty and single-element lists too (< 2 distinct types).
            return len(set(map(type, items))) < 2

        if not all_same_type(value):
            raise StreamlitAPIException(
                "Slider tuple/list components must be of the same type.\n"
                f"But were: {list(map(type, value))}"
            )

        if len(value) == 0:
            data_type = SliderProto.INT
        else:
            data_type = SUPPORTED_TYPES[type(value[0])]

        # Default bound window for timelike sliders: whole day for times,
        # +/- 14 days around the value for dates/datetimes.
        datetime_min = time.min
        datetime_max = time.max
        if data_type == SliderProto.TIME:
            datetime_min = time.min.replace(tzinfo=value[0].tzinfo)
            datetime_max = time.max.replace(tzinfo=value[0].tzinfo)
        if data_type in (SliderProto.DATETIME, SliderProto.DATE):
            datetime_min = value[0] - timedelta(days=14)
            datetime_max = value[0] + timedelta(days=14)

        # Per-type defaults for any of min_value/max_value/step/format left as None.
        DEFAULTS = {
            SliderProto.INT: {
                "min_value": 0,
                "max_value": 100,
                "step": 1,
                "format": "%d",
            },
            SliderProto.FLOAT: {
                "min_value": 0.0,
                "max_value": 1.0,
                "step": 0.01,
                "format": "%0.2f",
            },
            SliderProto.DATETIME: {
                "min_value": datetime_min,
                "max_value": datetime_max,
                "step": timedelta(days=1),
                "format": "YYYY-MM-DD",
            },
            SliderProto.DATE: {
                "min_value": datetime_min,
                "max_value": datetime_max,
                "step": timedelta(days=1),
                "format": "YYYY-MM-DD",
            },
            SliderProto.TIME: {
                "min_value": datetime_min,
                "max_value": datetime_max,
                "step": timedelta(minutes=15),
                "format": "HH:mm",
            },
        }
        if min_value is None:
            min_value = DEFAULTS[data_type]["min_value"]
        if max_value is None:
            max_value = DEFAULTS[data_type]["max_value"]
        if step is None:
            step = DEFAULTS[data_type]["step"]
            # A 1-day default step is unusable for ranges under a day.
            if (
                data_type
                in (
                    SliderProto.DATETIME,
                    SliderProto.DATE,
                )
                and max_value - min_value < timedelta(days=1)
            ):
                step = timedelta(minutes=15)
        if format is None:
            format = DEFAULTS[data_type]["format"]
        if step == 0:
            raise StreamlitAPIException(
                "Slider components cannot be passed a `step` of 0."
            )

        # Ensure that all arguments are of the same type.
        slider_args = [min_value, max_value, step]
        int_args = all(map(lambda a: isinstance(a, int), slider_args))
        float_args = all(map(lambda a: isinstance(a, float), slider_args))
        # When min and max_value are the same timelike, step should be a timedelta
        timelike_args = (
            data_type in TIMELIKE_TYPES
            and isinstance(step, timedelta)
            and type(min_value) == type(max_value)
        )
        if not int_args and not float_args and not timelike_args:
            raise StreamlitAPIException(
                "Slider value arguments must be of matching types."
                "\n`min_value` has %(min_type)s type."
                "\n`max_value` has %(max_type)s type."
                "\n`step` has %(step)s type."
                % {
                    "min_type": type(min_value).__name__,
                    "max_type": type(max_value).__name__,
                    "step": type(step).__name__,
                }
            )

        # Ensure that the value matches arguments' types.
        all_ints = data_type == SliderProto.INT and int_args
        all_floats = data_type == SliderProto.FLOAT and float_args
        all_timelikes = data_type in TIMELIKE_TYPES and timelike_args
        if not all_ints and not all_floats and not all_timelikes:
            raise StreamlitAPIException(
                "Both value and arguments must be of the same type."
                "\n`value` has %(value_type)s type."
                "\n`min_value` has %(min_type)s type."
                "\n`max_value` has %(max_type)s type."
                % {
                    "value_type": type(value).__name__,
                    "min_type": type(min_value).__name__,
                    "max_type": type(max_value).__name__,
                }
            )

        # Ensure that min <= value(s) <= max, adjusting the bounds as necessary.
        # NOTE(review): when min_value > max_value, the first line collapses
        # min_value to max_value, and the second is then a no-op (it reads the
        # already-updated min_value) — both bounds end up equal to the smaller
        # value rather than being swapped. Confirm whether a swap was intended.
        min_value = min(min_value, max_value)
        max_value = max(min_value, max_value)
        if len(value) == 1:
            min_value = min(value[0], min_value)
            max_value = max(value[0], max_value)
        elif len(value) == 2:
            start, end = value
            if start > end:
                # Swap start and end, since they seem reversed
                start, end = end, start
                value = start, end
            min_value = min(start, min_value)
            max_value = max(end, max_value)
        else:
            # Empty list, so let's just use the outer bounds
            value = [min_value, max_value]

        # Bounds checks. JSNumber produces human-readable exceptions that
        # we simply re-package as StreamlitAPIExceptions.
        # (We check `min_value` and `max_value` here; `value` and `step` are
        # already known to be in the [min_value, max_value] range.)
        try:
            if all_ints:
                JSNumber.validate_int_bounds(min_value, "`min_value`")
                JSNumber.validate_int_bounds(max_value, "`max_value`")
            elif all_floats:
                JSNumber.validate_float_bounds(min_value, "`min_value`")
                JSNumber.validate_float_bounds(max_value, "`max_value`")
            elif all_timelikes:
                # No validation yet. TODO: check between 0001-01-01 to 9999-12-31
                pass
        except JSNumberBoundsException as e:
            raise StreamlitAPIException(str(e))

        # Convert dates or times into datetimes
        if data_type == SliderProto.TIME:

            def _time_to_datetime(time):
                # Note, here we pick an arbitrary date well after Unix epoch.
                # This prevents pre-epoch timezone issues (https://bugs.python.org/issue36759)
                # We're dropping the date from datetime laters, anyways.
                return datetime.combine(date(2000, 1, 1), time)

            value = list(map(_time_to_datetime, value))
            min_value = _time_to_datetime(min_value)
            max_value = _time_to_datetime(max_value)
        if data_type == SliderProto.DATE:

            def _date_to_datetime(date):
                return datetime.combine(date, time())

            value = list(map(_date_to_datetime, value))
            min_value = _date_to_datetime(min_value)
            max_value = _date_to_datetime(max_value)

        # Now, convert to microseconds (so we can serialize datetime to a long)
        if data_type in TIMELIKE_TYPES:
            SECONDS_TO_MICROS = 1000 * 1000
            DAYS_TO_MICROS = 24 * 60 * 60 * SECONDS_TO_MICROS

            def _delta_to_micros(delta):
                return (
                    delta.microseconds
                    + delta.seconds * SECONDS_TO_MICROS
                    + delta.days * DAYS_TO_MICROS
                )

            UTC_EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)

            def _datetime_to_micros(dt):
                # The frontend is not aware of timezones and only expects a UTC-based timestamp (in microseconds).
                # Since we want to show the date/time exactly as it is in the given datetime object,
                # we just set the tzinfo to UTC and do not do any timezone conversions.
                # Only the backend knows about original timezone and will replace the UTC timestamp in the deserialization.
                utc_dt = dt.replace(tzinfo=timezone.utc)
                return _delta_to_micros(utc_dt - UTC_EPOCH)

            # Restore times/datetimes to original timezone (dates are always naive)
            orig_tz = (
                value[0].tzinfo
                if data_type in (SliderProto.TIME, SliderProto.DATETIME)
                else None
            )

            def _micros_to_datetime(micros):
                utc_dt = UTC_EPOCH + timedelta(microseconds=micros)
                # Add the original timezone. No conversion is required here,
                # since in the serialization, we also just replace the timestamp with UTC.
                return utc_dt.replace(tzinfo=orig_tz)

            value = list(map(_datetime_to_micros, value))
            min_value = _datetime_to_micros(min_value)
            max_value = _datetime_to_micros(max_value)
            step = _delta_to_micros(step)

        # It would be great if we could guess the number of decimal places from
        # the `step` argument, but this would only be meaningful if step were a
        # decimal. As a possible improvement we could make this function accept
        # decimals and/or use some heuristics for floats.
        slider_proto = SliderProto()
        slider_proto.label = label
        slider_proto.format = format
        slider_proto.default[:] = value
        slider_proto.min = min_value
        slider_proto.max = max_value
        slider_proto.step = step
        slider_proto.data_type = data_type
        slider_proto.options[:] = []
        slider_proto.form_id = current_form_id(self.dg)
        if help is not None:
            slider_proto.help = dedent(help)

        def deserialize_slider(ui_value: Optional[List[float]], widget_id=""):
            if ui_value is not None:
                val = ui_value
            else:
                # Widget has not been used; fallback to the original value,
                val = cast(List[float], value)
            # The widget always returns a float array, so fix the return type if necessary
            if data_type == SliderProto.INT:
                val = [int(v) for v in val]
            if data_type == SliderProto.DATETIME:
                val = [_micros_to_datetime(int(v)) for v in val]
            if data_type == SliderProto.DATE:
                val = [_micros_to_datetime(int(v)).date() for v in val]
            if data_type == SliderProto.TIME:
                val = [
                    _micros_to_datetime(int(v)).time().replace(tzinfo=orig_tz)
                    for v in val
                ]
            # Preserve the caller's shape: scalar in -> scalar out, range in -> tuple out.
            return val[0] if single_value else tuple(val)

        def serialize_slider(v: Any) -> List[Any]:
            # Inverse of deserialize_slider: always emit a list of proto-ready numbers.
            range_value = isinstance(v, (list, tuple))
            value = list(v) if range_value else [v]
            if data_type == SliderProto.DATE:
                value = [_datetime_to_micros(_date_to_datetime(v)) for v in value]
            if data_type == SliderProto.TIME:
                value = [_datetime_to_micros(_time_to_datetime(v)) for v in value]
            if data_type == SliderProto.DATETIME:
                value = [_datetime_to_micros(v) for v in value]
            return value

        current_value, set_frontend_value = register_widget(
            "slider",
            slider_proto,
            user_key=key,
            on_change_handler=on_change,
            args=args,
            kwargs=kwargs,
            deserializer=deserialize_slider,
            serializer=serialize_slider,
            ctx=ctx,
        )

        # This needs to be done after register_widget because we don't want
        # the following proto fields to affect a widget's ID.
        slider_proto.disabled = disabled
        if set_frontend_value:
            slider_proto.value[:] = serialize_slider(current_value)
            slider_proto.set_value = True

        self.dg._enqueue("slider", slider_proto)
        return current_value

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)

View File

@ -0,0 +1,39 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast
import streamlit
from streamlit.proto.Snow_pb2 import Snow as SnowProto
class SnowMixin:
    """Mixin that adds the ``st.snow`` celebration effect to DeltaGenerator."""

    def snow(self):
        """Draw celebratory snowfall.

        Example
        -------
        >>> st.snow()

        ...then watch your app and get ready for a cool celebration!
        """
        # Build the protobuf message that tells the frontend to play the
        # snowfall animation, then enqueue it on this delta generator.
        proto = SnowProto()
        proto.show = True
        return self.dg._enqueue("snow", proto)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Return this mixin instance viewed as a DeltaGenerator."""
        generator = cast("streamlit.delta_generator.DeltaGenerator", self)
        return generator

View File

@ -0,0 +1,43 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast
import streamlit
from streamlit.proto.Text_pb2 import Text as TextProto
from .utils import clean_text
class TextMixin:
    """Mixin that adds the ``st.text`` fixed-width text element to DeltaGenerator."""

    def text(self, body):
        """Write fixed-width and preformatted text.

        Parameters
        ----------
        body : str
            The string to display.

        Example
        -------
        >>> st.text('This is some text.')
        """
        # Normalize the body (dedent + strip) before sending it to the
        # frontend as a Text protobuf message.
        proto = TextProto()
        proto.body = clean_text(body)
        return self.dg._enqueue("text", proto)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Return this mixin instance viewed as a DeltaGenerator."""
        generator = cast("streamlit.delta_generator.DeltaGenerator", self)
        return generator

View File

@ -0,0 +1,347 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from streamlit.scriptrunner import ScriptRunContext, get_script_run_ctx
from streamlit.type_util import Key, to_key
from textwrap import dedent
from typing import Optional, cast
import streamlit
from streamlit.errors import StreamlitAPIException
from streamlit.proto.TextArea_pb2 import TextArea as TextAreaProto
from streamlit.proto.TextInput_pb2 import TextInput as TextInputProto
from streamlit.state import (
register_widget,
WidgetArgs,
WidgetCallback,
WidgetKwargs,
)
from .form import current_form_id
from .utils import check_callback_rules, check_session_state_rules
class TextWidgetsMixin:
    """Mixin providing the single-line (``st.text_input``) and multi-line
    (``st.text_area``) text widgets to DeltaGenerator."""

    def text_input(
        self,
        label: str,
        value: str = "",
        max_chars: Optional[int] = None,
        key: Optional[Key] = None,
        type: str = "default",
        help: Optional[str] = None,
        autocomplete: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        placeholder: Optional[str] = None,
        disabled: bool = False,
    ) -> str:
        """Display a single-line text input widget.
        Parameters
        ----------
        label : str
            A short label explaining to the user what this input is for.
        value : any
            The text value of this widget when it first renders. This will be
            cast to str internally.
        max_chars : int or None
            Max number of characters allowed in text input.
        key : str or int
            An optional string or integer to use as the unique key for the widget.
            If this is omitted, a key will be generated for the widget
            based on its content. Multiple widgets of the same type may
            not share the same key.
        type : str
            The type of the text input. This can be either "default" (for
            a regular text input), or "password" (for a text input that
            masks the user's typed value). Defaults to "default".
        help : str
            An optional tooltip that gets displayed next to the input.
        autocomplete : str
            An optional value that will be passed to the <input> element's
            autocomplete property. If unspecified, this value will be set to
            "new-password" for "password" inputs, and the empty string for
            "default" inputs. For more details, see https://developer.mozilla.org/en-US/docs/Web/HTML/Attributes/autocomplete
        on_change : callable
            An optional callback invoked when this text_input's value changes.
        args : tuple
            An optional tuple of args to pass to the callback.
        kwargs : dict
            An optional dict of kwargs to pass to the callback.
        placeholder : str or None
            An optional string displayed when the text input is empty. If None,
            no text is displayed. This argument can only be supplied by keyword.
        disabled : bool
            An optional boolean, which disables the text input if set to True.
            The default is False. This argument can only be supplied by keyword.
        Returns
        -------
        str
            The current value of the text input widget.
        Example
        -------
        >>> title = st.text_input('Movie title', 'Life of Brian')
        >>> st.write('The current movie title is', title)
        .. output::
           https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/widget.text_input.py
           height: 260px
        """
        # Capture the per-session script-run context; it is forwarded to
        # register_widget via _text_input.
        ctx = get_script_run_ctx()
        return self._text_input(
            label=label,
            value=value,
            max_chars=max_chars,
            key=key,
            type=type,
            help=help,
            autocomplete=autocomplete,
            on_change=on_change,
            args=args,
            kwargs=kwargs,
            placeholder=placeholder,
            disabled=disabled,
            ctx=ctx,
        )

    def _text_input(
        self,
        label: str,
        value: str = "",
        max_chars: Optional[int] = None,
        key: Optional[Key] = None,
        type: str = "default",
        help: Optional[str] = None,
        autocomplete: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        placeholder: Optional[str] = None,
        disabled: bool = False,
        ctx: Optional[ScriptRunContext] = None,
    ) -> str:
        # Internal implementation shared by text_input; ctx is injected so
        # the public wrapper stays testable.
        key = to_key(key)
        check_callback_rules(self.dg, on_change)
        # An empty-string default counts as "no default" for the
        # session-state double-set warning.
        check_session_state_rules(default_value=None if value == "" else value, key=key)
        text_input_proto = TextInputProto()
        text_input_proto.label = label
        text_input_proto.default = str(value)
        text_input_proto.form_id = current_form_id(self.dg)
        if help is not None:
            text_input_proto.help = dedent(help)
        if max_chars is not None:
            text_input_proto.max_chars = max_chars
        if placeholder is not None:
            text_input_proto.placeholder = str(placeholder)
        # Only "default" and "password" are valid input types.
        if type == "default":
            text_input_proto.type = TextInputProto.DEFAULT
        elif type == "password":
            text_input_proto.type = TextInputProto.PASSWORD
        else:
            raise StreamlitAPIException(
                "'%s' is not a valid text_input type. Valid types are 'default' and 'password'."
                % type
            )
        # Marshall the autocomplete param. If unspecified, this will be
        # set to "new-password" for password inputs.
        if autocomplete is None:
            autocomplete = "new-password" if type == "password" else ""
        text_input_proto.autocomplete = autocomplete

        def deserialize_text_input(ui_value, widget_id="") -> str:
            # Fall back to the declared default when the widget has no
            # frontend value yet.
            return str(ui_value if ui_value is not None else value)

        current_value, set_frontend_value = register_widget(
            "text_input",
            text_input_proto,
            user_key=key,
            on_change_handler=on_change,
            args=args,
            kwargs=kwargs,
            deserializer=deserialize_text_input,
            serializer=lambda x: x,
            ctx=ctx,
        )
        # This needs to be done after register_widget because we don't want
        # the following proto fields to affect a widget's ID.
        text_input_proto.disabled = disabled
        if set_frontend_value:
            text_input_proto.value = current_value
            text_input_proto.set_value = True
        self.dg._enqueue("text_input", text_input_proto)
        return cast(str, current_value)

    def text_area(
        self,
        label: str,
        value: str = "",
        height: Optional[int] = None,
        max_chars: Optional[int] = None,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        placeholder: Optional[str] = None,
        disabled: bool = False,
    ) -> str:
        """Display a multi-line text input widget.
        Parameters
        ----------
        label : str
            A short label explaining to the user what this input is for.
        value : any
            The text value of this widget when it first renders. This will be
            cast to str internally.
        height : int or None
            Desired height of the UI element expressed in pixels. If None, a
            default height is used.
        max_chars : int or None
            Maximum number of characters allowed in text area.
        key : str or int
            An optional string or integer to use as the unique key for the widget.
            If this is omitted, a key will be generated for the widget
            based on its content. Multiple widgets of the same type may
            not share the same key.
        help : str
            An optional tooltip that gets displayed next to the textarea.
        on_change : callable
            An optional callback invoked when this text_area's value changes.
        args : tuple
            An optional tuple of args to pass to the callback.
        kwargs : dict
            An optional dict of kwargs to pass to the callback.
        placeholder : str or None
            An optional string displayed when the text area is empty. If None,
            no text is displayed. This argument can only be supplied by keyword.
        disabled : bool
            An optional boolean, which disables the text area if set to True.
            The default is False. This argument can only be supplied by keyword.
        Returns
        -------
        str
            The current value of the text input widget.
        Example
        -------
        >>> txt = st.text_area('Text to analyze', '''
        ...     It was the best of times, it was the worst of times, it was
        ...     the age of wisdom, it was the age of foolishness, it was
        ...     the epoch of belief, it was the epoch of incredulity, it
        ...     was the season of Light, it was the season of Darkness, it
        ...     was the spring of hope, it was the winter of despair, (...)
        ...     ''')
        >>> st.write('Sentiment:', run_sentiment_analysis(txt))
        """
        # Capture the per-session script-run context; it is forwarded to
        # register_widget via _text_area.
        ctx = get_script_run_ctx()
        return self._text_area(
            label=label,
            value=value,
            height=height,
            max_chars=max_chars,
            key=key,
            help=help,
            on_change=on_change,
            args=args,
            kwargs=kwargs,
            placeholder=placeholder,
            disabled=disabled,
            ctx=ctx,
        )

    def _text_area(
        self,
        label: str,
        value: str = "",
        height: Optional[int] = None,
        max_chars: Optional[int] = None,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        placeholder: Optional[str] = None,
        disabled: bool = False,
        ctx: Optional[ScriptRunContext] = None,
    ) -> str:
        # Internal implementation shared by text_area; mirrors _text_input.
        key = to_key(key)
        check_callback_rules(self.dg, on_change)
        # An empty-string default counts as "no default" for the
        # session-state double-set warning.
        check_session_state_rules(default_value=None if value == "" else value, key=key)
        text_area_proto = TextAreaProto()
        text_area_proto.label = label
        text_area_proto.default = str(value)
        text_area_proto.form_id = current_form_id(self.dg)
        if help is not None:
            text_area_proto.help = dedent(help)
        if height is not None:
            text_area_proto.height = height
        if max_chars is not None:
            text_area_proto.max_chars = max_chars
        if placeholder is not None:
            text_area_proto.placeholder = str(placeholder)

        def deserialize_text_area(ui_value, widget_id="") -> str:
            # Fall back to the declared default when the widget has no
            # frontend value yet.
            return str(ui_value if ui_value is not None else value)

        current_value, set_frontend_value = register_widget(
            "text_area",
            text_area_proto,
            user_key=key,
            on_change_handler=on_change,
            args=args,
            kwargs=kwargs,
            deserializer=deserialize_text_area,
            serializer=lambda x: x,
            ctx=ctx,
        )
        # This needs to be done after register_widget because we don't want
        # the following proto fields to affect a widget's ID.
        text_area_proto.disabled = disabled
        if set_frontend_value:
            text_area_proto.value = current_value
            text_area_proto.set_value = True
        self.dg._enqueue("text_area", text_area_proto)
        return cast(str, current_value)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)

View File

@ -0,0 +1,370 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import datetime, date, time
from streamlit.scriptrunner import ScriptRunContext, get_script_run_ctx
from streamlit.type_util import Key, to_key
from typing import cast, Optional, Union, Tuple
from textwrap import dedent
from dateutil import relativedelta
import streamlit
from streamlit.errors import StreamlitAPIException
from streamlit.proto.DateInput_pb2 import DateInput as DateInputProto
from streamlit.proto.TimeInput_pb2 import TimeInput as TimeInputProto
from streamlit.state import (
register_widget,
WidgetArgs,
WidgetCallback,
WidgetKwargs,
)
from .form import current_form_id
from .utils import check_callback_rules, check_session_state_rules
class TimeWidgetsMixin:
    """Mixin providing the ``st.time_input`` and ``st.date_input`` widgets
    to DeltaGenerator."""

    def time_input(
        self,
        label: str,
        value=None,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        disabled: bool = False,
    ) -> time:
        """Display a time input widget.
        Parameters
        ----------
        label : str
            A short label explaining to the user what this time input is for.
        value : datetime.time/datetime.datetime
            The value of this widget when it first renders. This will be
            cast to str internally. Defaults to the current time.
        key : str or int
            An optional string or integer to use as the unique key for the widget.
            If this is omitted, a key will be generated for the widget
            based on its content. Multiple widgets of the same type may
            not share the same key.
        help : str
            An optional tooltip that gets displayed next to the input.
        on_change : callable
            An optional callback invoked when this time_input's value changes.
        args : tuple
            An optional tuple of args to pass to the callback.
        kwargs : dict
            An optional dict of kwargs to pass to the callback.
        disabled : bool
            An optional boolean, which disables the time input if set to True.
            The default is False. This argument can only be supplied by keyword.
        Returns
        -------
        datetime.time
            The current value of the time input widget.
        Example
        -------
        >>> t = st.time_input('Set an alarm for', datetime.time(8, 45))
        >>> st.write('Alarm is set for', t)
        .. output::
           https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/widget.time_input.py
           height: 260px
        """
        # Capture the per-session script-run context; it is forwarded to
        # register_widget via _time_input.
        ctx = get_script_run_ctx()
        return self._time_input(
            label=label,
            value=value,
            key=key,
            help=help,
            on_change=on_change,
            args=args,
            kwargs=kwargs,
            disabled=disabled,
            ctx=ctx,
        )

    def _time_input(
        self,
        label: str,
        value=None,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        disabled: bool = False,
        ctx: Optional[ScriptRunContext] = None,
    ) -> time:
        # Internal implementation shared by time_input; ctx is injected.
        key = to_key(key)
        check_callback_rules(self.dg, on_change)
        check_session_state_rules(default_value=value, key=key)
        # Set value default.
        if value is None:
            value = datetime.now().time().replace(second=0, microsecond=0)
        # Ensure that the value is either datetime/time
        if not isinstance(value, datetime) and not isinstance(value, time):
            raise StreamlitAPIException(
                "The type of the value should be either datetime or time."
            )
        # Convert datetime to time
        if isinstance(value, datetime):
            value = value.time().replace(second=0, microsecond=0)
        time_input_proto = TimeInputProto()
        time_input_proto.label = label
        # time.strftime is invoked unbound here, with the time instance as
        # the first argument; the widget only exposes HH:MM resolution.
        time_input_proto.default = time.strftime(value, "%H:%M")
        time_input_proto.form_id = current_form_id(self.dg)
        if help is not None:
            time_input_proto.help = dedent(help)

        def deserialize_time_input(ui_value, widget_id=""):
            # The frontend reports the value as an "HH:MM" string; fall back
            # to the declared default when the widget has no value yet.
            return (
                datetime.strptime(ui_value, "%H:%M").time()
                if ui_value is not None
                else value
            )

        def serialize_time_input(v):
            if isinstance(v, datetime):
                v = v.time()
            return time.strftime(v, "%H:%M")

        current_value, set_frontend_value = register_widget(
            "time_input",
            time_input_proto,
            user_key=key,
            on_change_handler=on_change,
            args=args,
            kwargs=kwargs,
            deserializer=deserialize_time_input,
            serializer=serialize_time_input,
            ctx=ctx,
        )
        # This needs to be done after register_widget because we don't want
        # the following proto fields to affect a widget's ID.
        time_input_proto.disabled = disabled
        if set_frontend_value:
            time_input_proto.value = serialize_time_input(current_value)
            time_input_proto.set_value = True
        self.dg._enqueue("time_input", time_input_proto)
        return cast(time, current_value)

    def date_input(
        self,
        label: str,
        value=None,
        min_value=None,
        max_value=None,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        disabled: bool = False,
    ) -> Union[date, Tuple[date, ...]]:
        """Display a date input widget.
        Parameters
        ----------
        label : str
            A short label explaining to the user what this date input is for.
        value : datetime.date or datetime.datetime or list/tuple of datetime.date or datetime.datetime or None
            The value of this widget when it first renders. If a list/tuple with
            0 to 2 date/datetime values is provided, the datepicker will allow
            users to provide a range. Defaults to today as a single-date picker.
        min_value : datetime.date or datetime.datetime
            The minimum selectable date. If value is a date, defaults to value - 10 years.
            If value is the interval [start, end], defaults to start - 10 years.
        max_value : datetime.date or datetime.datetime
            The maximum selectable date. If value is a date, defaults to value + 10 years.
            If value is the interval [start, end], defaults to end + 10 years.
        key : str or int
            An optional string or integer to use as the unique key for the widget.
            If this is omitted, a key will be generated for the widget
            based on its content. Multiple widgets of the same type may
            not share the same key.
        help : str
            An optional tooltip that gets displayed next to the input.
        on_change : callable
            An optional callback invoked when this date_input's value changes.
        args : tuple
            An optional tuple of args to pass to the callback.
        kwargs : dict
            An optional dict of kwargs to pass to the callback.
        disabled : bool
            An optional boolean, which disables the date input if set to True.
            The default is False. This argument can only be supplied by keyword.
        Returns
        -------
        datetime.date or a tuple with 0-2 dates
            The current value of the date input widget.
        Example
        -------
        >>> d = st.date_input(
        ...     "When\'s your birthday",
        ...     datetime.date(2019, 7, 6))
        >>> st.write('Your birthday is:', d)
        .. output::
           https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/widget.date_input.py
           height: 260px
        """
        # Capture the per-session script-run context; it is forwarded to
        # register_widget via _date_input.
        ctx = get_script_run_ctx()
        return self._date_input(
            label=label,
            value=value,
            min_value=min_value,
            max_value=max_value,
            key=key,
            help=help,
            on_change=on_change,
            args=args,
            kwargs=kwargs,
            disabled=disabled,
            ctx=ctx,
        )

    def _date_input(
        self,
        label: str,
        value=None,
        min_value=None,
        max_value=None,
        key: Optional[Key] = None,
        help: Optional[str] = None,
        on_change: Optional[WidgetCallback] = None,
        args: Optional[WidgetArgs] = None,
        kwargs: Optional[WidgetKwargs] = None,
        *,  # keyword-only arguments:
        disabled: bool = False,
        ctx: Optional[ScriptRunContext] = None,
    ) -> Union[date, Tuple[date, ...]]:
        # Internal implementation shared by date_input; ctx is injected.
        key = to_key(key)
        check_callback_rules(self.dg, on_change)
        check_session_state_rules(default_value=value, key=key)
        # Set value default.
        if value is None:
            value = datetime.now().date()
        # A bare date/datetime is a single-date picker; a list/tuple of 0-2
        # entries is a range picker.
        single_value = isinstance(value, (date, datetime))
        range_value = isinstance(value, (list, tuple)) and len(value) in (0, 1, 2)
        if not single_value and not range_value:
            raise StreamlitAPIException(
                "DateInput value should either be an date/datetime or a list/tuple of "
                "0 - 2 date/datetime values"
            )
        if single_value:
            value = [value]
        # Normalize everything to datetime.date for the rest of the function.
        value = [v.date() if isinstance(v, datetime) else v for v in value]
        # Default min/max to +/- 10 years around the provided value(s).
        if isinstance(min_value, datetime):
            min_value = min_value.date()
        elif min_value is None:
            if value:
                min_value = value[0] - relativedelta.relativedelta(years=10)
            else:
                min_value = date.today() - relativedelta.relativedelta(years=10)
        if isinstance(max_value, datetime):
            max_value = max_value.date()
        elif max_value is None:
            if value:
                max_value = value[-1] + relativedelta.relativedelta(years=10)
            else:
                max_value = date.today() + relativedelta.relativedelta(years=10)
        # Reject defaults that fall outside [min_value, max_value].
        if value:
            start_value = value[0]
            end_value = value[-1]
            if (start_value < min_value) or (end_value > max_value):
                raise StreamlitAPIException(
                    f"The default `value` of {value} "
                    f"must lie between the `min_value` of {min_value} "
                    f"and the `max_value` of {max_value}, inclusively."
                )
        date_input_proto = DateInputProto()
        date_input_proto.is_range = range_value
        if help is not None:
            date_input_proto.help = dedent(help)
        date_input_proto.label = label
        # Dates cross the wire as "YYYY/MM/DD" strings (date.strftime is
        # invoked unbound with each date instance as the first argument).
        date_input_proto.default[:] = [date.strftime(v, "%Y/%m/%d") for v in value]
        date_input_proto.min = date.strftime(min_value, "%Y/%m/%d")
        date_input_proto.max = date.strftime(max_value, "%Y/%m/%d")
        date_input_proto.form_id = current_form_id(self.dg)

        def deserialize_date_input(ui_value, widget_id=""):
            # Parse the frontend's string list back into date objects,
            # falling back to the declared default when unset.
            if ui_value is not None:
                return_value = [
                    datetime.strptime(v, "%Y/%m/%d").date() for v in ui_value
                ]
            else:
                return_value = value
            # Single-date pickers return a date; range pickers a tuple.
            return return_value[0] if single_value else tuple(return_value)

        def serialize_date_input(v):
            range_value = isinstance(v, (list, tuple))
            to_serialize = list(v) if range_value else [v]
            return [date.strftime(v, "%Y/%m/%d") for v in to_serialize]

        current_value, set_frontend_value = register_widget(
            "date_input",
            date_input_proto,
            user_key=key,
            on_change_handler=on_change,
            args=args,
            kwargs=kwargs,
            deserializer=deserialize_date_input,
            serializer=serialize_date_input,
            ctx=ctx,
        )
        # This needs to be done after register_widget because we don't want
        # the following proto fields to affect a widget's ID.
        date_input_proto.disabled = disabled
        if set_frontend_value:
            date_input_proto.value[:] = serialize_date_input(current_value)
            date_input_proto.set_value = True
        self.dg._enqueue("date_input", date_input_proto)
        return cast(date, current_value)

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)

View File

@ -0,0 +1,84 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import textwrap
from typing import Any, Optional, TYPE_CHECKING
import streamlit
from streamlit import type_util
from streamlit.elements.form import is_in_form
from streamlit.errors import StreamlitAPIException
from streamlit.state import get_session_state, WidgetCallback
if TYPE_CHECKING:
from streamlit.delta_generator import DeltaGenerator
def clean_text(text: Any) -> str:
    """Coerce *text* to a string, remove its common leading indentation,
    and trim surrounding whitespace."""
    as_string = str(text)
    return textwrap.dedent(as_string).strip()
def last_index_for_melted_dataframes(data):
    """Return the last index label of *data* once converted to a DataFrame,
    or None if it is empty or not dataframe-compatible."""
    if not type_util.is_dataframe_compatible(data):
        return None
    frame = type_util.convert_anything_to_df(data)
    if frame.index.size == 0:
        return None
    return frame.index[-1]
def check_callback_rules(
    dg: "DeltaGenerator", on_change: Optional[WidgetCallback]
) -> None:
    """Raise StreamlitAPIException if a change callback is attached to a
    widget that lives inside an st.form (only form_submit_button may have
    callbacks there). No-op outside a live Streamlit run."""
    if not streamlit._is_running_with_streamlit:
        return
    if is_in_form(dg) and on_change is not None:
        raise StreamlitAPIException(
            "With forms, callbacks can only be defined on the `st.form_submit_button`."
            " Defining callbacks on other widgets inside a form is not allowed."
        )
# Module-level latch: the default-value warning is shown at most once per
# process, no matter how many widgets trigger it.
_shown_default_value_warning = False


def check_session_state_rules(
    default_value: Any, key: Optional[str], writes_allowed: bool = True
) -> None:
    """Validate how a keyed widget interacts with st.session_state.

    Raises for widgets whose value may not be written via session state,
    and warns (once) when a widget has both a coded default and a
    session-state-provided value. No-op without a key or outside a live
    Streamlit run.
    """
    global _shown_default_value_warning

    if key is None or not streamlit._is_running_with_streamlit:
        return

    # Only newly-set session-state values are subject to these rules.
    if not get_session_state().is_new_state_value(key):
        return

    if not writes_allowed:
        raise StreamlitAPIException(
            "Values for st.button, st.download_button, st.file_uploader, and "
            "st.form cannot be set using st.session_state."
        )

    if default_value is not None and not _shown_default_value_warning:
        streamlit.warning(
            f'The widget with key "{key}" was created with a default value but'
            " also had its value set via the Session State API."
        )
        _shown_default_value_warning = True

View File

@ -0,0 +1,239 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json as json
import types
from typing import cast, Any, List, Tuple, Type
import numpy as np
import streamlit
from streamlit import type_util
from streamlit.errors import StreamlitAPIException
from streamlit.state import SessionStateProxy
# Callable/module types that st.write renders via st.help rather than as text.
HELP_TYPES = (
    types.BuiltinFunctionType,
    types.BuiltinMethodType,
    types.FunctionType,
    types.MethodType,
    types.ModuleType,
)  # type: Tuple[Type[Any], ...]
class WriteMixin:
    """Mixin providing the polymorphic ``st.write`` command to DeltaGenerator."""

    def write(self, *args, **kwargs):
        """Write arguments to the app.
        This is the Swiss Army knife of Streamlit commands: it does different
        things depending on what you throw at it. Unlike other Streamlit commands,
        write() has some unique properties:
        1. You can pass in multiple arguments, all of which will be written.
        2. Its behavior depends on the input types as follows.
        3. It returns None, so its "slot" in the App cannot be reused.
        Parameters
        ----------
        *args : any
            One or many objects to print to the App.
            Arguments are handled as follows:
            - write(string)     : Prints the formatted Markdown string, with
                support for LaTeX expression and emoji shortcodes.
                See docs for st.markdown for more.
            - write(data_frame) : Displays the DataFrame as a table.
            - write(error)      : Prints an exception specially.
            - write(func)       : Displays information about a function.
            - write(module)     : Displays information about the module.
            - write(dict)       : Displays dict in an interactive widget.
            - write(mpl_fig)    : Displays a Matplotlib figure.
            - write(altair)     : Displays an Altair chart.
            - write(keras)      : Displays a Keras model.
            - write(graphviz)   : Displays a Graphviz graph.
            - write(plotly_fig) : Displays a Plotly figure.
            - write(bokeh_fig)  : Displays a Bokeh figure.
            - write(sympy_expr) : Prints SymPy expression using LaTeX.
            - write(htmlable)   : Prints _repr_html_() for the object if available.
            - write(obj)        : Prints str(obj) if otherwise unknown.
        unsafe_allow_html : bool
            This is a keyword-only argument that defaults to False.
            By default, any HTML tags found in strings will be escaped and
            therefore treated as pure text. This behavior may be turned off by
            setting this argument to True.
            That said, *we strongly advise against it*. It is hard to write secure
            HTML, so by using this argument you may be compromising your users'
            security. For more information, see:
            https://github.com/streamlit/streamlit/issues/152
            **Also note that `unsafe_allow_html` is a temporary measure and may be
            removed from Streamlit at any time.**
            If you decide to turn on HTML anyway, we ask you to please tell us your
            exact use case here:
            https://discuss.streamlit.io/t/96 .
            This will help us come up with safe APIs that allow you to do what you
            want.
        Example
        -------
        Its basic use case is to draw Markdown-formatted text, whenever the
        input is a string:
        >>> write('Hello, *World!* :sunglasses:')
        .. output::
           https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/text.write1.py
           height: 150px
        As mentioned earlier, `st.write()` also accepts other data formats, such as
        numbers, data frames, styled data frames, and assorted objects:
        >>> st.write(1234)
        >>> st.write(pd.DataFrame({
        ...     'first column': [1, 2, 3, 4],
        ...     'second column': [10, 20, 30, 40],
        ... }))
        .. output::
           https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/text.write2.py
           height: 350px
        Finally, you can pass in multiple arguments to do things like:
        >>> st.write('1 + 1 = ', 2)
        >>> st.write('Below is a DataFrame:', data_frame, 'Above is a dataframe.')
        .. output::
           https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/text.write3.py
           height: 410px
        Oh, one more thing: `st.write` accepts chart objects too! For example:
        >>> import pandas as pd
        >>> import numpy as np
        >>> import altair as alt
        >>>
        >>> df = pd.DataFrame(
        ...     np.random.randn(200, 3),
        ...     columns=['a', 'b', 'c'])
        ...
        >>> c = alt.Chart(df).mark_circle().encode(
        ...     x='a', y='b', size='c', color='c', tooltip=['a', 'b', 'c'])
        >>>
        >>> st.write(c)
        .. output::
           https://share.streamlit.io/streamlit/docs/main/python/api-examples-source/charts.vega_lite_chart.py
           height: 300px
        """
        # Consecutive string args are accumulated here and emitted as a
        # single markdown element via flush_buffer().
        string_buffer = []  # type: List[str]
        unsafe_allow_html = kwargs.get("unsafe_allow_html", False)
        # This bans some valid cases like: e = st.empty(); e.write("a", "b").
        # BUT: 1) such cases are rare, 2) this rule is easy to understand,
        # and 3) this rule should be removed once we have st.container()
        if not self.dg._is_top_level and len(args) > 1:
            raise StreamlitAPIException(
                "Cannot replace a single element with multiple elements.\n\n"
                "The `write()` method only supports multiple elements when "
                "inserting elements rather than replacing. That is, only "
                "when called as `st.write()` or `st.sidebar.write()`."
            )

        def flush_buffer():
            # Emit any pending string args as one markdown block, then reset.
            if string_buffer:
                self.dg.markdown(
                    " ".join(string_buffer),
                    unsafe_allow_html=unsafe_allow_html,
                )
                string_buffer[:] = []

        for arg in args:
            # Order matters!
            if isinstance(arg, str):
                string_buffer.append(arg)
            elif type_util.is_dataframe_like(arg):
                flush_buffer()
                # >2-dimensional data can't be drawn as a table; fall back
                # to plain text.
                if len(np.shape(arg)) > 2:
                    self.dg.text(arg)
                else:
                    self.dg.dataframe(arg)
            elif isinstance(arg, Exception):
                flush_buffer()
                self.dg.exception(arg)
            elif isinstance(arg, HELP_TYPES):
                flush_buffer()
                self.dg.help(arg)
            elif type_util.is_altair_chart(arg):
                flush_buffer()
                self.dg.altair_chart(arg)
            elif type_util.is_type(arg, "matplotlib.figure.Figure"):
                flush_buffer()
                self.dg.pyplot(arg)
            elif type_util.is_plotly_chart(arg):
                flush_buffer()
                self.dg.plotly_chart(arg)
            elif type_util.is_type(arg, "bokeh.plotting.figure.Figure"):
                flush_buffer()
                self.dg.bokeh_chart(arg)
            elif type_util.is_graphviz_chart(arg):
                flush_buffer()
                self.dg.graphviz_chart(arg)
            elif type_util.is_sympy_expession(arg):
                flush_buffer()
                self.dg.latex(arg)
            elif type_util.is_keras_model(arg):
                # Imported lazily so TensorFlow is only required when a
                # Keras model is actually written.
                from tensorflow.python.keras.utils import vis_utils

                flush_buffer()
                dot = vis_utils.model_to_dot(arg)
                self.dg.graphviz_chart(dot.to_string())
            elif isinstance(arg, (dict, list, SessionStateProxy)):
                flush_buffer()
                self.dg.json(arg)
            elif type_util.is_namedtuple(arg):
                flush_buffer()
                self.dg.json(json.dumps(arg._asdict()))
            elif type_util.is_pydeck(arg):
                flush_buffer()
                self.dg.pydeck_chart(arg)
            elif inspect.isclass(arg):
                flush_buffer()
                self.dg.text(arg)
            elif hasattr(arg, "_repr_html_"):
                # NOTE(review): unlike every other branch, this one does not
                # call flush_buffer() first, so buffered strings preceding an
                # HTML-able object render after it — confirm this is intended.
                self.dg.markdown(
                    arg._repr_html_(),
                    unsafe_allow_html=True,
                )
            else:
                # Fallback: show repr-ish text as inline code, escaping
                # backticks so the markdown stays intact.
                string_buffer.append("`%s`" % str(arg).replace("`", "\\`"))
        flush_buffer()

    @property
    def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
        """Get our DeltaGenerator."""
        return cast("streamlit.delta_generator.DeltaGenerator", self)

View File

@ -0,0 +1,60 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import platform
import re
import sys
# Cache the platform name once at import time (e.g. "Windows", "Darwin",
# "Linux", "FreeBSD").
_system = platform.system()

# Convenience flags for the current operating system.
IS_WINDOWS = _system == "Windows"
IS_DARWIN = _system == "Darwin"
# Treat the BSD family (anything with "BSD" in its name) the same as Linux.
IS_LINUX_OR_BSD = (_system == "Linux") or ("BSD" in _system)
def is_pex():
    """Return True if Streamlit is running inside a pex file.

    Pex modifies sys.path so the pex file is the first path, and that's
    how we determine we're running in the pex file.
    """
    # Equivalent to the previous re.match(r".*pex$", sys.path[0]) check,
    # without the regex machinery.
    return sys.path[0].endswith("pex")
def is_repl():
    """Return True if running in the Python REPL."""
    import inspect

    # The outermost stack frame tells us what started this process.
    root_filename = inspect.stack()[-1][1]  # Field 1 is the frame's filename.

    # IPython's entry point lives at .../bin/ipython.
    if root_filename.endswith(os.path.join("bin", "ipython")):
        return True

    # <stdin> is what the basic Python REPL calls the root frame's
    # filename, and <string> is what iPython sometimes calls it.
    return root_filename in ("<stdin>", "<string>")
def is_executable_in_path(name):
    """Check if an executable named `name` is findable on the OS PATH.

    Returns True if found, False otherwise.
    """
    # distutils is deprecated and removed in Python 3.12 (PEP 632);
    # shutil.which is the supported stdlib replacement for
    # distutils.spawn.find_executable.
    import shutil

    return shutil.which(name) is not None

View File

@ -0,0 +1,39 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
import streamlit as st
from streamlit import config
from streamlit.logger import get_logger
from streamlit.errors import UncaughtAppException
LOGGER = get_logger(__name__)
def handle_uncaught_app_exception(e: BaseException) -> None:
    """Handle an exception that originated from a user app.

    By default, we show exceptions directly in the browser. However,
    if the user has disabled client error details, we display a generic
    warning in the frontend instead.

    Parameters
    ----------
    e : BaseException
        The uncaught exception raised by the user's script.
    """
    if config.get_option("client.showErrorDetails"):
        # Also log the full traceback server-side so it isn't lost if the
        # user only looks at the terminal.
        LOGGER.warning(traceback.format_exc())
        st.exception(e)
        # TODO: Clean up the stack trace, so it doesn't include ScriptRunner.
    else:
        # Use LOGGER.error, rather than LOGGER.debug, since we don't
        # show debug logs by default.
        LOGGER.error("Uncaught app exception", exc_info=e)
        # Wrap the original exception so the frontend renders only a
        # generic, redacted error message.
        st.exception(UncaughtAppException(e))

View File

@ -0,0 +1,130 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from streamlit import util
class Error(Exception):
    """Base class for Streamlit's own exceptions (see streamlit.util users)."""

    pass
class DeprecationError(Error):
    """Raised for uses of deprecated Streamlit functionality."""

    pass
class NoStaticFiles(Exception):
    """Raised when Streamlit's static frontend assets cannot be found.

    NOTE(review): description inferred from the class name — confirm at the
    raise sites.
    """

    pass
class NoSessionContext(Exception):
    """Raised when an operation requires a session context that isn't set.

    NOTE(review): description inferred from the class name — confirm at the
    raise sites.
    """

    pass
class MarkdownFormattedException(Exception):
    """Exceptions with Markdown in their description.

    Instances of this class can use markdown in their messages, which will get
    nicely formatted on the frontend.
    """
    # (Removed a redundant `pass` — a class with a docstring needs no body.)
class UncaughtAppException(Exception):
    """This will be used for Uncaught Exception within Streamlit Apps in order
    to say that the Streamlit app has an error"""

    def __init__(self, exc):
        # Keep a reference to the original exception so callers can inspect it.
        # NOTE(review): super().__init__ is never called, so `args` is empty
        # and str(instance) is "" — confirm consumers rely on `exc` only.
        self.exc = exc
class StreamlitAPIException(MarkdownFormattedException):
    """Base class for Streamlit API exceptions.

    An API exception should be thrown when user code interacts with the
    Streamlit API incorrectly. (That is, when we throw an exception as a
    result of a user's malformed `st.foo` call, it should be a
    StreamlitAPIException or subclass.)

    When displaying these exceptions on the frontend, we strip Streamlit
    entries from the stack trace so that the user doesn't see a bunch of
    noise related to Streamlit internals.
    """
    # (Removed a redundant `pass` that preceded __repr__.)

    def __repr__(self) -> str:
        return util.repr_(self)
class DuplicateWidgetID(StreamlitAPIException):
    """Raised when two widgets resolve to the same widget ID.

    NOTE(review): description inferred from the class name — confirm at the
    raise sites.
    """

    pass
class NumpyDtypeException(StreamlitAPIException):
    """Raised for problems related to NumPy dtypes.

    NOTE(review): description inferred from the class name — confirm at the
    raise sites.
    """

    pass
class StreamlitAPIWarning(StreamlitAPIException, Warning):
    """Used to display a warning.

    Note that this should not be "raised", but passed to st.exception
    instead.
    """

    def __init__(self, *args):
        super().__init__(*args)
        import inspect
        import traceback

        # Capture the stack at construction time so the frontend can show
        # where the warning was created.
        current_frame = inspect.currentframe()
        self.tacked_on_stack = traceback.extract_stack(current_frame)

    def __repr__(self) -> str:
        return util.repr_(self)
class StreamlitDeprecationWarning(StreamlitAPIWarning):
    """Used to display a warning.

    Note that this should not be "raised", but passed to st.exception
    instead.
    """

    def __init__(self, config_option, msg, *args):
        # Build a Markdown message telling the user how to silence this
        # deprecation warning, both at runtime and via config.toml.
        message = """
{0}

You can disable this warning by disabling the config option:
`{1}`

```
st.set_option('{1}', False)
```
or in your `.streamlit/config.toml`
```
[deprecation]
{2} = False
```
""".format(
            msg, config_option, config_option.split(".")[1]
        )
        # TODO: create a deprecation docs page to add to deprecation msg #1669
        # For more details, please see: https://docs.streamlit.io/path/to/deprecation/docs.html

        # NOTE(review): this calls __init__ of StreamlitAPIWarning's *parent*
        # (super(StreamlitAPIWarning, self)), skipping the stack capture in
        # StreamlitAPIWarning.__init__, so `tacked_on_stack` is never set on
        # deprecation warnings — confirm this is intentional.
        super(StreamlitAPIWarning, self).__init__(message, *args)

    def __repr__(self) -> str:
        return util.repr_(self)

View File

@ -0,0 +1,198 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import errno
import io
import os
import fnmatch
from streamlit import env_util
from streamlit import util
from streamlit.string_util import is_binary_string
# Configuration and credentials are stored inside the ~/.streamlit folder
# (project-level config lives in ${CWD}/.streamlit — see
# get_project_streamlit_file_path below).
CONFIG_FOLDER_NAME = ".streamlit"
def get_encoded_file_data(data, encoding="auto"):
    """Coerce bytes to a BytesIO or a StringIO.

    Parameters
    ----------
    data : bytes
    encoding : str

    Returns
    -------
    BytesIO or StringIO
        If the file's data is in a well-known textual format (or if the
        encoding parameter is set), return a StringIO. Otherwise, return
        BytesIO.
    """
    if encoding == "auto":
        # Sniff the payload: binary-looking data gets no encoding; anything
        # else is assumed to be utf-8. It would be great if we could guess
        # more smartly here, but it is what it is!
        encoding = None if is_binary_string(data) else "utf-8"

    if not encoding:
        return io.BytesIO(data)
    return io.StringIO(data.decode(encoding))
@contextlib.contextmanager
def streamlit_read(path, binary=False):
    """Open a context to read this file relative to the streamlit path.

    For example:

    with streamlit_read('foo.txt') as foo:
        ...

    opens the file `~/.streamlit/foo.txt`.

    path   - the path to read (within the streamlit directory)
    binary - set to True for binary IO

    Raises util.Error if the file exists but is empty.
    """
    filename = get_streamlit_file_path(path)
    if os.stat(filename).st_size == 0:
        raise util.Error('Read zero byte file: "%s"' % filename)

    mode = "r"
    if binary:
        mode += "b"
    # BUG FIX: this previously opened os.path.join(CONFIG_FOLDER_NAME, path),
    # i.e. ".streamlit/<path>" relative to the *current working directory*,
    # while the zero-byte check above stat'ed the file under the user's home
    # directory. Open the same file that was checked.
    with open(filename, mode) as handle:
        yield handle
@contextlib.contextmanager
def streamlit_write(path, binary=False):
    """Open a file for writing within the streamlit path, ensuring that the
    path exists. For example:

    with streamlit_write('foo/bar.txt') as bar:
        ...

    opens `~/.streamlit/foo/bar.txt` for writing, creating any necessary
    directories along the way.

    path   - the path to write to (within the streamlit directory)
    binary - set to True for binary IO

    Raises util.Error (chained from the OSError) if the file can't be written.
    """
    mode = "w"
    if binary:
        mode += "b"
    path = get_streamlit_file_path(path)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    try:
        with open(path, mode) as handle:
            yield handle
    except OSError as e:
        msg = ["Unable to write file: %s" % os.path.abspath(path)]
        if e.errno == errno.EINVAL and env_util.IS_DARWIN:
            msg.append(
                "Python is limited to files below 2GB on OSX. "
                "See https://bugs.python.org/issue24658"
            )
        # Chain the original OSError so the root cause isn't lost
        # (the original `raise` dropped the causal link).
        raise util.Error("\n".join(msg)) from e
def get_static_dir():
    """Get the folder where static HTML/JS/CSS files live."""
    module_dir = os.path.dirname(os.path.normpath(__file__))
    return os.path.normpath(os.path.join(module_dir, "static"))
def get_assets_dir():
    """Get the folder where static assets live."""
    module_dir = os.path.dirname(os.path.normpath(__file__))
    return os.path.normpath(os.path.join(module_dir, "static", "assets"))
def get_streamlit_file_path(*filepath) -> str:
    """Return the full path to a file in ~/.streamlit.

    This doesn't guarantee that the file (or its directory) exists.
    """
    # os.path.expanduser works on OSX, Linux and Windows.
    home_dir = os.path.expanduser("~")
    # Defensive guard kept from the original implementation.
    if home_dir is None:
        raise RuntimeError("No home directory.")
    return os.path.join(home_dir, CONFIG_FOLDER_NAME, *filepath)
def get_project_streamlit_file_path(*filepath):
    """Return the full path to a filepath in ${CWD}/.streamlit.

    This doesn't guarantee that the file (or its directory) exists.
    """
    cwd = os.getcwd()
    return os.path.join(cwd, CONFIG_FOLDER_NAME, *filepath)
def file_is_in_folder_glob(filepath, folderpath_glob) -> bool:
    """Test whether a file is in some folder with globbing support.

    Parameters
    ----------
    filepath : str
        A file path.
    folderpath_glob: str
        A path to a folder that may include globbing.
    """
    # Make the glob always end with "/*" so we match files inside subfolders
    # of folderpath_glob.
    pattern = folderpath_glob
    if not pattern.endswith("*"):
        pattern += "*" if pattern.endswith("/") else "/*"

    file_dir = os.path.dirname(filepath) + "/"
    return fnmatch.fnmatch(file_dir, pattern)
def file_in_pythonpath(filepath) -> bool:
    """Test whether a filepath is in the same folder of a path specified in the PYTHONPATH env variable.

    Parameters
    ----------
    filepath : str
        An absolute file path.

    Returns
    -------
    boolean
        True if contained in PYTHONPATH, False otherwise. False if PYTHONPATH is not defined or empty.
    """
    pythonpath = os.environ.get("PYTHONPATH", "")
    if not pythonpath:
        return False

    normalized = os.path.normpath(filepath)
    return any(
        file_is_in_folder_glob(normalized, os.path.abspath(entry))
        for entry in pythonpath.split(os.pathsep)
    )

View File

@ -0,0 +1,80 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from streamlit import util
from streamlit import file_util
from streamlit import config
# The files in the folders below should always be blacklisted.
# These are hidden folders plus common Python distribution / virtualenv /
# package-manager directories — files there are not the user's app code.
DEFAULT_FOLDER_BLACKLIST = [
    "**/.*",
    "**/anaconda",
    "**/anaconda2",
    "**/anaconda3",
    "**/dist-packages",
    "**/miniconda",
    "**/miniconda2",
    "**/miniconda3",
    "**/node_modules",
    "**/pyenv",
    "**/site-packages",
    "**/venv",
    "**/virtualenv",
]
class FolderBlackList(object):
    """Implement a black list object with globbing.

    Note
    ----
    Blacklist any path that matches a glob in `DEFAULT_FOLDER_BLACKLIST`.
    """

    def __init__(self, folder_blacklist):
        """Constructor.

        Parameters
        ----------
        folder_blacklist : list of str
            list of folder names with globbing to blacklist.
        """
        # Combine the caller's globs with the built-in defaults.
        self._folder_blacklist = list(folder_blacklist) + DEFAULT_FOLDER_BLACKLIST

        # Add the Streamlit lib folder when in dev mode, since otherwise we
        # end up with weird situations where the ID of a class in one run is
        # not the same as in another run.
        if config.get_option("global.developmentMode"):
            self._folder_blacklist.append(os.path.dirname(__file__))

    def __repr__(self) -> str:
        return util.repr_(self)

    def is_blacklisted(self, filepath):
        """Test if filepath is in the blacklist.

        Parameters
        ----------
        filepath : str
            File path that we intend to test.
        """
        for blacklisted_folder in self._folder_blacklist:
            if file_util.file_is_in_folder_glob(filepath, blacklisted_folder):
                return True
        return False

View File

@ -0,0 +1,271 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
from typing import MutableMapping, Dict, Optional, TYPE_CHECKING, List
from weakref import WeakKeyDictionary
from streamlit import config
from streamlit import util
from streamlit.logger import get_logger
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.stats import CacheStatsProvider, CacheStat
if TYPE_CHECKING:
from streamlit.app_session import AppSession
LOGGER = get_logger(__name__)
def populate_hash_if_needed(msg: ForwardMsg) -> str:
    """Computes and assigns the unique hash for a ForwardMsg.

    If the ForwardMsg already has a hash, this is a no-op.

    Parameters
    ----------
    msg : ForwardMsg

    Returns
    -------
    string
        The message's hash, returned here for convenience. (The hash
        will also be assigned to the ForwardMsg; callers do not need
        to do this.)
    """
    if msg.hash != "":
        return msg.hash

    # Move the message's metadata aside: it's not part of the hash
    # calculation, so it must not influence the digest.
    metadata = msg.metadata
    msg.ClearField("metadata")

    # MD5 is good enough for what we need, which is uniqueness.
    digest = hashlib.md5(msg.SerializeToString())
    msg.hash = digest.hexdigest()

    # Restore metadata.
    msg.metadata.CopyFrom(metadata)

    return msg.hash
def create_reference_msg(msg: ForwardMsg) -> ForwardMsg:
    """Create a ForwardMsg that refers to the given message via its hash.

    The reference message will also get a copy of the source message's
    metadata.

    Parameters
    ----------
    msg : ForwardMsg
        The ForwardMsg to create the reference to.

    Returns
    -------
    ForwardMsg
        A new ForwardMsg that "points" to the original message via the
        ref_hash field.
    """
    reference = ForwardMsg()
    # Hashing the source message is a side effect we rely on here.
    reference.ref_hash = populate_hash_if_needed(msg)
    reference.metadata.CopyFrom(msg.metadata)
    return reference
class ForwardMsgCache(CacheStatsProvider):
    """A cache of ForwardMsgs.

    Large ForwardMsgs (e.g. those containing big DataFrame payloads) are
    stored in this cache. The server can choose to send a ForwardMsg's hash,
    rather than the message itself, to a client. Clients can then
    request messages from this cache via another endpoint.

    This cache is *not* thread safe. It's intended to only be accessed by
    the server thread.
    """

    class Entry:
        """Cache entry.

        Stores the cached message, and the set of AppSessions
        that we've sent the cached message to.
        """

        def __init__(self, msg: ForwardMsg):
            # The cached message itself.
            self.msg = msg
            # Maps each referencing AppSession to the script_run_count at
            # the time of its most recent reference. A WeakKeyDictionary is
            # used so garbage-collected sessions don't keep entries alive.
            self._session_script_run_counts: MutableMapping[
                "AppSession", int
            ] = WeakKeyDictionary()

        def __repr__(self) -> str:
            return util.repr_(self)

        def add_session_ref(self, session: "AppSession", script_run_count: int) -> None:
            """Adds a reference to a AppSession that has referenced
            this Entry's message.

            Parameters
            ----------
            session : AppSession
            script_run_count : int
                The session's run count at the time of the call
            """
            prev_run_count = self._session_script_run_counts.get(session, 0)
            if script_run_count < prev_run_count:
                # A session's run count should only ever increase; log loudly
                # and clamp to the previous value if the invariant breaks.
                LOGGER.error(
                    "New script_run_count (%s) is < prev_run_count (%s). "
                    "This should never happen!" % (script_run_count, prev_run_count)
                )
                script_run_count = prev_run_count
            self._session_script_run_counts[session] = script_run_count

        def has_session_ref(self, session: "AppSession") -> bool:
            # True if the given session has referenced this entry at least once.
            return session in self._session_script_run_counts

        def get_session_ref_age(
            self, session: "AppSession", script_run_count: int
        ) -> int:
            """The age of the given session's reference to the Entry,
            given a new script_run_count.
            """
            # Raises KeyError if the session holds no reference.
            return script_run_count - self._session_script_run_counts[session]

        def remove_session_ref(self, session: "AppSession") -> None:
            # Raises KeyError if the session holds no reference.
            del self._session_script_run_counts[session]

        def has_refs(self) -> bool:
            """True if this Entry has references from any AppSession.

            If not, it can be removed from the cache.
            """
            return len(self._session_script_run_counts) > 0

    def __init__(self):
        # Maps a ForwardMsg's hash to its cache entry.
        self._entries: Dict[str, "ForwardMsgCache.Entry"] = {}

    def __repr__(self) -> str:
        return util.repr_(self)

    def add_message(
        self, msg: ForwardMsg, session: "AppSession", script_run_count: int
    ) -> None:
        """Add a ForwardMsg to the cache.

        The cache will also record a reference to the given AppSession,
        so that it can track which sessions have already received
        each given ForwardMsg.

        Parameters
        ----------
        msg : ForwardMsg
        session : AppSession
        script_run_count : int
            The number of times the session's script has run
        """
        populate_hash_if_needed(msg)
        entry = self._entries.get(msg.hash, None)
        if entry is None:
            entry = ForwardMsgCache.Entry(msg)
            self._entries[msg.hash] = entry
        entry.add_session_ref(session, script_run_count)

    def get_message(self, hash: str) -> Optional[ForwardMsg]:
        """Return the message with the given ID if it exists in the cache.

        Parameters
        ----------
        hash : string
            The id of the message to retrieve.

        Returns
        -------
        ForwardMsg | None
        """
        entry = self._entries.get(hash, None)
        return entry.msg if entry else None

    def has_message_reference(
        self, msg: ForwardMsg, session: "AppSession", script_run_count: int
    ) -> bool:
        """Return True if a session has a reference to a message."""
        populate_hash_if_needed(msg)
        entry = self._entries.get(msg.hash, None)
        if entry is None or not entry.has_session_ref(session):
            return False
        # Ensure we're not expired
        age = entry.get_session_ref_age(session, script_run_count)
        return age <= int(config.get_option("global.maxCachedMessageAge"))

    def remove_expired_session_entries(
        self, session: "AppSession", script_run_count: int
    ) -> None:
        """Remove any cached messages that have expired from the given session.

        This should be called each time a AppSession finishes executing.

        Parameters
        ----------
        session : AppSession
        script_run_count : int
            The number of times the session's script has run
        """
        max_age = config.get_option("global.maxCachedMessageAge")
        # Operate on a copy of our entries dict.
        # We may be deleting from it.
        for msg_hash, entry in self._entries.copy().items():
            if not entry.has_session_ref(session):
                continue
            age = entry.get_session_ref_age(session, script_run_count)
            if age > max_age:
                LOGGER.debug(
                    "Removing expired entry [session=%s, hash=%s, age=%s]",
                    id(session),
                    msg_hash,
                    age,
                )
                entry.remove_session_ref(session)
                if not entry.has_refs():
                    # The entry has no more references. Remove it from
                    # the cache completely.
                    del self._entries[msg_hash]

    def clear(self) -> None:
        """Remove all entries from the cache"""
        self._entries.clear()

    def get_stats(self) -> List[CacheStat]:
        """Report one CacheStat per cached message (for st.stats endpoints)."""
        stats: List[CacheStat] = []
        # NOTE: entry_hash is unused below; the items() iteration is kept
        # as-is to avoid changing code in this documentation pass.
        for entry_hash, entry in self._entries.items():
            stats.append(
                CacheStat(
                    category_name="ForwardMessageCache",
                    cache_name="",
                    byte_length=entry.msg.ByteSize(),
                )
            )
        return stats

View File

@ -0,0 +1,142 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, List, Dict, Any, Tuple
from streamlit.logger import get_logger
from streamlit.proto.Delta_pb2 import Delta
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
LOGGER = get_logger(__name__)
class ForwardMsgQueue:
    """Accumulates a session's outgoing ForwardMsgs.

    Each AppSession adds messages to its queue, and the Server periodically
    flushes all session queues and delivers their messages to the appropriate
    clients.

    ForwardMsgQueue is not thread-safe - a queue should only be used from
    a single thread.
    """

    def __init__(self):
        # Messages awaiting delivery, in enqueue order.
        self._queue: List[ForwardMsg] = []
        # A mapping of (delta_path -> _queue.indexof(msg)) for each
        # Delta message in the queue. We use this for coalescing
        # redundant outgoing Deltas (where a newer Delta supersedes
        # an older Delta, with the same delta_path, that's still in the
        # queue).
        self._delta_index_map: Dict[Tuple[int, ...], int] = dict()

    def get_debug(self) -> Dict[str, Any]:
        """Return a JSON-friendly snapshot of the queue, for debugging."""
        from google.protobuf.json_format import MessageToDict

        return {
            "queue": [MessageToDict(m) for m in self._queue],
            "ids": list(self._delta_index_map.keys()),
        }

    def is_empty(self) -> bool:
        # True if there are no messages awaiting delivery.
        return len(self._queue) == 0

    def enqueue(self, msg: ForwardMsg) -> None:
        """Add message into queue, possibly composing it with another message."""
        if not _is_composable_message(msg):
            self._queue.append(msg)
            return

        # If there's a Delta message with the same delta_path already in
        # the queue - meaning that it refers to the same location in
        # the app - we attempt to combine this new Delta into the old
        # one. This is an optimization that prevents redundant Deltas
        # from being sent to the frontend.
        delta_key = tuple(msg.metadata.delta_path)
        if delta_key in self._delta_index_map:
            index = self._delta_index_map[delta_key]
            old_msg = self._queue[index]
            composed_delta = _maybe_compose_deltas(old_msg.delta, msg.delta)
            if composed_delta is not None:
                # Replace the old message in place, keeping the new
                # message's metadata.
                new_msg = ForwardMsg()
                new_msg.delta.CopyFrom(composed_delta)
                new_msg.metadata.CopyFrom(msg.metadata)
                self._queue[index] = new_msg
                return

        # No composition occurred. Append this message to the queue, and
        # store its index for potential future composition.
        self._delta_index_map[delta_key] = len(self._queue)
        self._queue.append(msg)

    def clear(self) -> None:
        """Clear the queue."""
        self._queue = []
        self._delta_index_map = dict()

    def flush(self) -> List[ForwardMsg]:
        """Clear the queue and return a list of the messages it contained
        before being cleared."""
        queue = self._queue
        self.clear()
        return queue
def _is_composable_message(msg: ForwardMsg) -> bool:
    """True if the ForwardMsg is potentially composable with other ForwardMsgs."""
    if not msg.HasField("delta"):
        # Non-delta messages are never composable.
        return False

    # We never compose add_rows messages in Python, because the add_rows
    # operation can raise errors, and we don't have a good way of handling
    # those errors in the message queue.
    return msg.delta.WhichOneof("type") not in ("add_rows", "arrow_add_rows")
def _maybe_compose_deltas(old_delta: Delta, new_delta: Delta) -> Optional[Delta]:
    """Combines new_delta onto old_delta if possible.

    If the combination takes place, the function returns a new Delta that
    should replace old_delta in the queue.

    If the new_delta is incompatible with old_delta, the function returns None.
    In this case, the new_delta should just be appended to the queue as normal.
    """
    if old_delta.WhichOneof("type") == "add_block":
        # We never replace add_block deltas, because blocks can have
        # other dependent deltas later in the queue. For example:
        #
        #   placeholder = st.empty()
        #   placeholder.columns(1)
        #   placeholder.empty()
        #
        # The call to "placeholder.columns(1)" creates two blocks, a parent
        # container with delta_path (0, 0), and a column child with
        # delta_path (0, 0, 0). If the final "placeholder.empty()" Delta
        # is composed with the parent container Delta, the frontend will
        # throw an error when it tries to add that column child to what is
        # now just an element, and not a block.
        return None

    # A new element or block entirely supersedes the old delta at the same
    # delta_path; anything else can't be composed.
    if new_delta.WhichOneof("type") in ("new_element", "add_block"):
        return new_delta

    return None

View File

@ -0,0 +1,170 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from typing import Optional, Tuple, Any
from streamlit import util
# Github has two URLs, one that is https and one that is ssh
GITHUB_HTTP_URL = r"^https://(www\.)?github.com/(.+)/(.+)(?:.git)?$"
GITHUB_SSH_URL = r"^git@github.com:(.+)/(.+)(?:.git)?$"

# NOTE(review): in both patterns the greedy `(.+)` consumes everything, so
# the optional `(?:.git)?` group always matches empty — repo names captured
# from these regexes may keep a ".git" suffix. Also, the dots in
# "github.com" and ".git" are unescaped. Confirm whether callers depend on
# this before tightening.

# We don't support git < 2.7, because we can't get repo info without
# talking to the remote server, which results in the user being prompted
# for credentials.
MIN_GIT_VERSION = (2, 7, 0)
class GitRepo:
    """Introspects the git repository (if any) containing a given path.

    All accessors degrade gracefully: if git or GitPython is unavailable,
    git is older than MIN_GIT_VERSION, or `path` isn't inside a repo,
    `is_valid()` returns False and the other accessors return None/False
    instead of raising.
    """

    def __init__(self, path):
        # If we have a valid repo, git_version will be a tuple of 3+ ints:
        # (major, minor, patch, possible_additional_patch_number)
        self.git_version = None  # type: Optional[Tuple[int, ...]]

        try:
            import git

            # GitPython is not fully typed, and mypy is outputting inconsistent
            # type errors on Mac and Linux. We bypass type checking entirely
            # by re-declaring the `git` import as an "Any".
            git_package: Any = git

            self.repo = git_package.Repo(path, search_parent_directories=True)
            self.git_version = self.repo.git.version_info
            if self.git_version >= MIN_GIT_VERSION:
                git_root = self.repo.git.rev_parse("--show-toplevel")
                self.module = os.path.relpath(path, git_root)
        except Exception:
            # BUG FIX: this was a bare `except:`, which also swallowed
            # SystemExit/KeyboardInterrupt.
            # The git repo must be invalid for the following reasons:
            # * git binary or GitPython not installed
            # * No .git folder
            # * Corrupted .git folder
            # * Path is invalid
            self.repo = None

    def __repr__(self) -> str:
        return util.repr_(self)

    def is_valid(self) -> bool:
        """True if there's a git repo here, and git.version >= MIN_GIT_VERSION."""
        return (
            self.repo is not None
            and self.git_version is not None
            and self.git_version >= MIN_GIT_VERSION
        )

    @property
    def tracking_branch(self):
        """The remote-tracking branch of the active branch, or None."""
        if not self.is_valid():
            return None

        if self.is_head_detached:
            return None

        return self.repo.active_branch.tracking_branch()

    @property
    def untracked_files(self):
        """List of untracked file paths, or None if the repo is invalid."""
        if not self.is_valid():
            return None

        return self.repo.untracked_files

    @property
    def is_head_detached(self):
        """True if HEAD is detached; False for an invalid repo."""
        if not self.is_valid():
            return False

        return self.repo.head.is_detached

    @property
    def uncommitted_files(self):
        """Paths of files with unstaged changes, or None if invalid."""
        if not self.is_valid():
            return None

        return [item.a_path for item in self.repo.index.diff(None)]

    @property
    def ahead_commits(self):
        """Local commits not yet on the tracking remote, or None if invalid."""
        if not self.is_valid():
            return None

        try:
            remote, branch_name = self.get_tracking_branch_remote()
            remote_branch = "/".join([remote.name, branch_name])

            return list(self.repo.iter_commits(f"{remote_branch}..{branch_name}"))
        except Exception:
            # BUG FIX: was a bare `except:`. Any failure here (e.g. no
            # tracking branch, so the tuple unpack above raises) means we
            # can't compute ahead commits — report none.
            return list()

    def get_tracking_branch_remote(self):
        """Return (Remote, branch_name) for the tracking branch, or None."""
        if not self.is_valid():
            return None

        tracking_branch = self.tracking_branch
        if tracking_branch is None:
            return None

        # Tracking-branch names look like "<remote>/<branch>", where the
        # branch part may itself contain slashes.
        remote_name, *branch = tracking_branch.name.split("/")
        branch_name = "/".join(branch)

        return self.repo.remote(remote_name), branch_name

    def is_github_repo(self):
        """True if the tracking remote's URL points at github.com."""
        if not self.is_valid():
            return False

        remote_info = self.get_tracking_branch_remote()
        if remote_info is None:
            return False

        remote, _branch = remote_info

        for url in remote.urls:
            if (
                re.match(GITHUB_HTTP_URL, url) is not None
                or re.match(GITHUB_SSH_URL, url) is not None
            ):
                return True

        return False

    def get_repo_info(self):
        """Return ("owner/repo", branch_name, module) or None if unavailable."""
        if not self.is_valid():
            return None

        remote_info = self.get_tracking_branch_remote()
        if remote_info is None:
            return None

        remote, branch = remote_info

        repo = None
        for url in remote.urls:
            https_matches = re.match(GITHUB_HTTP_URL, url)
            ssh_matches = re.match(GITHUB_SSH_URL, url)
            if https_matches is not None:
                repo = f"{https_matches.group(2)}/{https_matches.group(3)}"
                break
            if ssh_matches is not None:
                repo = f"{ssh_matches.group(1)}/{ssh_matches.group(2)}"
                break

        if repo is None:
            return None

        return repo, branch, self.module

View File

@ -0,0 +1,13 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,271 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
def intro():
    """Render the landing page shown when no demo is selected."""
    import streamlit as st

    st.sidebar.success("Select a demo above.")

    st.markdown(
        """
        Streamlit is an open-source app framework built specifically for
        Machine Learning and Data Science projects.

        **👈 Select a demo from the dropdown on the left** to see some examples
        of what Streamlit can do!

        ### Want to learn more?

        - Check out [streamlit.io](https://streamlit.io)
        - Jump into our [documentation](https://docs.streamlit.io)
        - Ask a question in our [community
          forums](https://discuss.streamlit.io)

        ### See more complex demos

        - Use a neural net to [analyze the Udacity Self-driving Car Image
          Dataset](https://github.com/streamlit/demo-self-driving)
        - Explore a [New York City rideshare dataset](https://github.com/streamlit/demo-uber-nyc-pickups)
    """
    )
# Turn off black formatting for this function to present the user with more
# compact code.
# fmt: off
def mapping_demo():
    """Demo: overlay user-selectable pydeck layers on a map of San Francisco."""
    import streamlit as st
    import pandas as pd
    import pydeck as pdk

    from urllib.error import URLError

    @st.cache
    def from_data_file(filename):
        # NOTE(review): plain http URL; presumably served via redirect to
        # https by GitHub -- confirm it still resolves.
        url = (
            "http://raw.githubusercontent.com/streamlit/"
            "example-data/master/hello/v1/%s" % filename)
        return pd.read_json(url)

    try:
        # Each entry maps a human-readable layer name to a pydeck Layer.
        ALL_LAYERS = {
            "Bike Rentals": pdk.Layer(
                "HexagonLayer",
                data=from_data_file("bike_rental_stats.json"),
                get_position=["lon", "lat"],
                radius=200,
                elevation_scale=4,
                elevation_range=[0, 1000],
                extruded=True,
            ),
            "Bart Stop Exits": pdk.Layer(
                "ScatterplotLayer",
                data=from_data_file("bart_stop_stats.json"),
                get_position=["lon", "lat"],
                get_color=[200, 30, 0, 160],
                get_radius="[exits]",
                radius_scale=0.05,
            ),
            "Bart Stop Names": pdk.Layer(
                "TextLayer",
                data=from_data_file("bart_stop_stats.json"),
                get_position=["lon", "lat"],
                get_text="name",
                get_color=[0, 0, 0, 200],
                get_size=15,
                get_alignment_baseline="'bottom'",
            ),
            "Outbound Flow": pdk.Layer(
                "ArcLayer",
                data=from_data_file("bart_path_stats.json"),
                get_source_position=["lon", "lat"],
                get_target_position=["lon2", "lat2"],
                get_source_color=[200, 30, 0, 160],
                get_target_color=[200, 30, 0, 160],
                auto_highlight=True,
                width_scale=0.0001,
                get_width="outbound",
                width_min_pixels=3,
                width_max_pixels=30,
            ),
        }
        # One sidebar checkbox per layer; only the checked ones are drawn.
        st.sidebar.markdown('### Map Layers')
        selected_layers = [
            layer for layer_name, layer in ALL_LAYERS.items()
            if st.sidebar.checkbox(layer_name, True)]
        if selected_layers:
            st.pydeck_chart(pdk.Deck(
                map_style="mapbox://styles/mapbox/light-v9",
                initial_view_state={"latitude": 37.76,
                                    "longitude": -122.4, "zoom": 11, "pitch": 50},
                layers=selected_layers,
            ))
        else:
            st.error("Please choose at least one layer above.")
    except URLError as e:
        st.error("""
**This demo requires internet access.**
Connection error: %s
""" % e.reason)
# fmt: on
# Turn off black formatting for this function to present the user with more
# compact code.
# fmt: off
def fractal_demo():
    """Demo: animate a Julia-set fractal driven by two sidebar sliders."""
    import streamlit as st
    import numpy as np

    # Interactive Streamlit elements, like these sliders, return their value.
    # This gives you an extremely simple interaction model.
    iterations = st.sidebar.slider("Level of detail", 2, 20, 10, 1)
    separation = st.sidebar.slider("Separation", 0.7, 2.0, 0.7885)

    # Non-interactive elements return a placeholder to their location
    # in the app. Here we're storing progress_bar to update it later.
    progress_bar = st.sidebar.progress(0)

    # These two elements will be filled in later, so we create a placeholder
    # for them using st.empty()
    frame_text = st.sidebar.empty()
    image = st.empty()

    # m x n pixel grid; s controls the zoom (coordinate range is +-m/s, +-n/s).
    m, n, s = 960, 640, 400
    x = np.linspace(-m / s, m / s, num=m).reshape((1, m))
    y = np.linspace(-n / s, n / s, num=n).reshape((n, 1))

    # Sweep the Julia parameter c around a circle of radius `separation`.
    for frame_num, a in enumerate(np.linspace(0.0, 4 * np.pi, 100)):
        # Here we're setting value for these two elements.
        progress_bar.progress(frame_num)
        frame_text.text("Frame %i/100" % (frame_num + 1))

        # Performing some fractal wizardry.
        c = separation * np.exp(1j * a)
        Z = np.tile(x, (n, 1)) + 1j * np.tile(y, (1, m))
        C = np.full((n, m), c)
        M: Any = np.full((n, m), True, dtype=bool)
        N = np.zeros((n, m))

        # Iterate z <- z^2 + c, but only where the orbit hasn't escaped
        # (|z| <= 2); N records the last iteration each pixel survived.
        for i in range(iterations):
            Z[M] = Z[M] * Z[M] + C[M]
            M[np.abs(Z) > 2] = False
            N[M] = i

        # Update the image placeholder by calling the image() function on it.
        image.image(1.0 - (N / N.max()), use_column_width=True)

    # We clear elements by calling empty on them.
    progress_bar.empty()
    frame_text.empty()

    # Streamlit widgets automatically run the script from top to bottom. Since
    # this button is not connected to any other logic, it just causes a plain
    # rerun.
    st.button("Re-run")
# fmt: on
# Turn off black formatting for this function to present the user with more
# compact code.
# fmt: off
def plotting_demo():
    """Demo: stream random-walk data into a live line chart for ~5 seconds."""
    import streamlit as st
    import time
    import numpy as np

    progress_bar = st.sidebar.progress(0)
    status_text = st.sidebar.empty()
    last_rows = np.random.randn(1, 1)
    chart = st.line_chart(last_rows)

    for i in range(1, 101):
        # Each frame appends 5 points continuing the random walk from the
        # chart's last value.
        new_rows = last_rows[-1, :] + np.random.randn(5, 1).cumsum(axis=0)
        status_text.text("%i%% Complete" % i)
        chart.add_rows(new_rows)
        progress_bar.progress(i)
        last_rows = new_rows
        time.sleep(0.05)

    progress_bar.empty()

    # Streamlit widgets automatically run the script from top to bottom. Since
    # this button is not connected to any other logic, it just causes a plain
    # rerun.
    st.button("Re-run")
# fmt: on
# Turn off black formatting for this function to present the user with more
# compact code.
# fmt: off
def data_frame_demo():
    """Demo: filter a UN agriculture DataFrame and chart it with Altair."""
    import streamlit as st
    import pandas as pd
    import altair as alt

    from urllib.error import URLError

    @st.cache
    def get_UN_data():
        AWS_BUCKET_URL = "http://streamlit-demo-data.s3-us-west-2.amazonaws.com"
        df = pd.read_csv(AWS_BUCKET_URL + "/agri.csv.gz")
        return df.set_index("Region")

    try:
        df = get_UN_data()
        countries = st.multiselect(
            "Choose countries", list(df.index), ["China", "United States of America"]
        )
        if not countries:
            st.error("Please select at least one country.")
        else:
            data = df.loc[countries]
            # Scale raw values down to billions of dollars for display.
            data /= 1000000.0
            st.write("### Gross Agricultural Production ($B)", data.sort_index())

            # Reshape wide (one column per year) to long form for Altair.
            data = data.T.reset_index()
            data = pd.melt(data, id_vars=["index"]).rename(
                columns={"index": "year", "value": "Gross Agricultural Product ($B)"}
            )
            chart = (
                alt.Chart(data)
                .mark_area(opacity=0.3)
                .encode(
                    x="year:T",
                    y=alt.Y("Gross Agricultural Product ($B):Q", stack=None),
                    color="Region:N",
                )
            )
            st.altair_chart(chart, use_container_width=True)
    except URLError as e:
        st.error(
            """
**This demo requires internet access.**
Connection error: %s
"""
            % e.reason
        )
# fmt: on

View File

@ -0,0 +1,105 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import textwrap
from collections import OrderedDict
import streamlit as st
from streamlit.logger import get_logger
from streamlit.hello import demos
LOGGER = get_logger(__name__)

# Dictionary of
# demo_name -> (demo_function, demo_description)
# The empty-string key is the intro page, which has no description.
DEMOS = OrderedDict(
    [
        ("", (demos.intro, None)),
        (
            "Animation Demo",
            (
                demos.fractal_demo,
                """
This app shows how you can use Streamlit to build cool animations.
It displays an animated fractal based on the the Julia Set. Use the slider
to tune different parameters.
""",
            ),
        ),
        (
            "Plotting Demo",
            (
                demos.plotting_demo,
                """
This demo illustrates a combination of plotting and animation with
Streamlit. We're generating a bunch of random numbers in a loop for around
5 seconds. Enjoy!
""",
            ),
        ),
        (
            "Mapping Demo",
            (
                demos.mapping_demo,
                """
This demo shows how to use
[`st.pydeck_chart`](https://docs.streamlit.io/library/api-reference/charts/st.pydeck_chart)
to display geospatial data.
""",
            ),
        ),
        (
            "DataFrame Demo",
            (
                demos.data_frame_demo,
                """
This demo shows how to use `st.write` to visualize Pandas DataFrames.
(Data courtesy of the [UN Data Explorer](http://data.un.org/Explorer.aspx).)
""",
            ),
        ),
    ]
)
def run():
    """Entry point for the Hello app: dispatch to the demo chosen in the
    sidebar and optionally display its source code."""
    demo_name = st.sidebar.selectbox("Choose a demo", list(DEMOS.keys()), 0)
    demo = DEMOS[demo_name][0]

    if demo_name == "":
        # Intro page: no code toggle, just the welcome header.
        show_code = False
        st.write("# Welcome to Streamlit! 👋")
    else:
        show_code = st.sidebar.checkbox("Show code", True)
        st.markdown("# %s" % demo_name)
        description = DEMOS[demo_name][1]
        if description:
            st.write(description)
        # Clear everything from the intro page.
        # We only have 4 elements in the page so this is intentional overkill.
        for i in range(10):
            st.empty()

    demo()

    if show_code:
        st.markdown("## Code")
        sourcelines, _ = inspect.getsourcelines(demo)
        # Drop the `def` line and dedent so only the body is displayed.
        st.code(textwrap.dedent("".join(sourcelines[1:])))


if __name__ == "__main__":
    run()

View File

@ -0,0 +1,328 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides global InMemoryFileManager object as `in_memory_file_manager`."""
from typing import Dict, Set, Optional, List
import collections
import hashlib
import mimetypes
from streamlit.logger import get_logger
from streamlit import util
from streamlit.stats import CacheStatsProvider, CacheStat
LOGGER = get_logger(__name__)

# URL path prefix under which in-memory files are served over HTTP.
STATIC_MEDIA_ENDPOINT = "/media"

# Extensions pinned explicitly because mimetypes.guess_extension's answer
# varies across Python versions (see https://bugs.python.org/issue4963).
PREFERRED_MIMETYPE_EXTENSION_MAP = {
    "image/jpeg": ".jpeg",
    "audio/wav": ".wav",
}

# used for images and videos in st.image() and st.video()
FILE_TYPE_MEDIA = "media_file"

# used for st.download_button files
FILE_TYPE_DOWNLOADABLE = "downloadable_file"
def _get_session_id() -> str:
    """Semantic wrapper to retrieve current AppSession ID."""
    from streamlit.scriptrunner import get_script_run_ctx

    ctx = get_script_run_ctx()
    # Outside `streamlit run` (i.e. plain `python myscript.py`) there is no
    # script-run context; any constant ID works because there's only ever
    # one "session" in that mode.
    if ctx is None:
        return "dontcare"
    return ctx.session_id
def _calculate_file_id(
data: bytes, mimetype: str, file_name: Optional[str] = None
) -> str:
"""Return an ID by hashing the data and mime.
Parameters
----------
data : bytes
Content of in-memory file in bytes. Other types will throw TypeError.
mimetype : str
Any string. Will be converted to bytes and used to compute a hash.
None will be converted to empty string. [default: None]
file_name : str
Any string. Will be converted to bytes and used to compute a hash.
None will be converted to empty string. [default: None]
"""
filehash = hashlib.new("sha224")
filehash.update(data)
filehash.update(bytes(mimetype.encode()))
if file_name is not None:
filehash.update(bytes(file_name.encode()))
return filehash.hexdigest()
def _get_extension_for_mimetype(mimetype: str) -> str:
    """Map a mimetype to a file extension (including the leading dot).

    Consults our explicit preference table first, then falls back to
    Python's mimetypes library; returns "" when no extension is known.
    """
    # Python mimetypes preference was changed in Python versions, so we
    # pin a few choices explicitly before letting the library guess.
    # See https://bugs.python.org/issue4963
    preferred = PREFERRED_MIMETYPE_EXTENSION_MAP.get(mimetype)
    if preferred is not None:
        return preferred
    return mimetypes.guess_extension(mimetype) or ""
class InMemoryFile:
    """Abstraction for file objects.

    Instances are effectively immutable records (content, mimetype, name,
    type) plus a deletion flag managed by InMemoryFileManager.
    """

    def __init__(
        self,
        file_id: str,
        content: bytes,
        mimetype: str,
        file_name: Optional[str] = None,
        file_type: str = FILE_TYPE_MEDIA,
    ):
        self._file_id = file_id
        self._content = content
        self._mimetype = mimetype
        self._file_name = file_name
        self._file_type = file_type
        # Deletion is two-phase for downloadable files: the manager flags the
        # file first, then physically drops it on a later cleanup pass.
        self._is_marked_for_delete = False

    def __repr__(self) -> str:
        return util.repr_(self)

    def _mark_for_delete(self) -> None:
        """Flag this file for removal on the manager's next expiry pass."""
        self._is_marked_for_delete = True

    # --- Read-only accessors -----------------------------------------------

    @property
    def id(self) -> str:
        """Stable, hash-based identifier for this file."""
        return self._file_id

    @property
    def url(self) -> str:
        """Relative URL under which this file is served."""
        suffix = _get_extension_for_mimetype(self._mimetype)
        return f"{STATIC_MEDIA_ENDPOINT}/{self.id}{suffix}"

    @property
    def content(self) -> bytes:
        return self._content

    @property
    def content_size(self) -> int:
        return len(self._content)

    @property
    def mimetype(self) -> str:
        return self._mimetype

    @property
    def file_type(self) -> str:
        return self._file_type

    @property
    def file_name(self) -> Optional[str]:
        return self._file_name
class InMemoryFileManager(CacheStatsProvider):
    """In-memory file manager for InMemoryFile objects.

    This keeps track of:
    - Which files exist, and what their IDs are. This is important so we can
      serve files by ID -- that's the whole point of this class!
    - Which files are being used by which AppSession (by ID). This is
      important so we can remove files from memory when no more sessions need
      them.
    - The exact location in the app where each file is being used (i.e. the
      file's "coordinates"). This is important so we can mark a file as "not
      being used by a certain session" if it gets replaced by another file at
      the same coordinates. For example, when doing an animation where the
      same image is constantly replaced with new frames. (This doesn't solve
      the case where the file's coordinates keep changing for some reason,
      though! e.g. if new elements keep being prepended to the app. Unlikely
      to happen, but we should address it at some point.)
    """

    def __init__(self):
        # Dict of file ID to InMemoryFile.
        self._files_by_id: Dict[str, InMemoryFile] = dict()

        # Dict[session ID][coordinates] -> InMemoryFile.
        self._files_by_session_and_coord: Dict[
            str, Dict[str, InMemoryFile]
        ] = collections.defaultdict(dict)

    def __repr__(self) -> str:
        return util.repr_(self)

    def del_expired_files(self) -> None:
        """Delete files that no session references anymore.

        Media files are dropped immediately; downloadable files get one grace
        pass (marked first, deleted on the following call) so an in-flight
        download can still be served.
        """
        LOGGER.debug("Deleting expired files...")

        # Get a flat set of every file ID in the session ID map.
        active_file_ids: Set[str] = set()
        for files_by_coord in self._files_by_session_and_coord.values():
            file_ids = map(lambda imf: imf.id, files_by_coord.values())
            active_file_ids = active_file_ids.union(file_ids)

        # Iterate over a copied item list since we delete during iteration.
        for file_id, imf in list(self._files_by_id.items()):
            if imf.id not in active_file_ids:
                if imf.file_type == FILE_TYPE_MEDIA:
                    LOGGER.debug(f"Deleting File: {file_id}")
                    del self._files_by_id[file_id]
                elif imf.file_type == FILE_TYPE_DOWNLOADABLE:
                    if imf._is_marked_for_delete:
                        LOGGER.debug(f"Deleting File: {file_id}")
                        del self._files_by_id[file_id]
                    else:
                        imf._mark_for_delete()

    def clear_session_files(self, session_id: Optional[str] = None) -> None:
        """Removes AppSession-coordinate mapping immediately, and id-file
        mapping later (via del_expired_files).

        Should be called whenever ScriptRunner completes and when a session
        ends.
        """
        if session_id is None:
            session_id = _get_session_id()

        LOGGER.debug("Disconnecting files for session with ID %s", session_id)

        if session_id in self._files_by_session_and_coord:
            del self._files_by_session_and_coord[session_id]

        LOGGER.debug(
            "Sessions still active: %r", self._files_by_session_and_coord.keys()
        )
        LOGGER.debug(
            "Files: %s; Sessions with files: %s",
            len(self._files_by_id),
            len(self._files_by_session_and_coord),
        )

    def add(
        self,
        content: bytes,
        mimetype: str,
        coordinates: str,
        file_name: Optional[str] = None,
        is_for_static_download: bool = False,
    ) -> InMemoryFile:
        """Adds new InMemoryFile with given parameters; returns the object.

        If an identical file already exists, returns the existing object
        and registers the current session as a user.

        mimetype must be set, as this string will be used in the
        "Content-Type" header when the file is sent via HTTP GET.

        coordinates should look like this: "1.(3.-14).5"

        Parameters
        ----------
        content : bytes
            Raw data to store in file object.
        mimetype : str
            The mime type for the in-memory file. E.g. "audio/mpeg"
        coordinates : str
            Unique string identifying an element's location.
            Prevents memory leak of "forgotten" file IDs when element media
            is being replaced-in-place (e.g. an st.image stream).
        file_name : str
            Optional file_name. Used to set filename in response header.
            [default: None]
        is_for_static_download : bool
            Indicate that data stored for downloading as a file,
            not as a media for rendering at page. [default: False]
        """
        # The ID is a content hash, so identical uploads dedupe naturally.
        file_id = _calculate_file_id(content, mimetype, file_name=file_name)
        imf = self._files_by_id.get(file_id, None)

        if imf is None:
            LOGGER.debug("Adding media file %s", file_id)
            if is_for_static_download:
                file_type = FILE_TYPE_DOWNLOADABLE
            else:
                file_type = FILE_TYPE_MEDIA

            imf = InMemoryFile(
                file_id=file_id,
                content=content,
                mimetype=mimetype,
                file_name=file_name,
                file_type=file_type,
            )
        else:
            LOGGER.debug("Overwriting media file %s", file_id)

        session_id = _get_session_id()
        self._files_by_id[imf.id] = imf
        self._files_by_session_and_coord[session_id][coordinates] = imf

        LOGGER.debug(
            "Files: %s; Sessions with files: %s",
            len(self._files_by_id),
            len(self._files_by_session_and_coord),
        )

        return imf

    def get(self, inmemory_filename: str) -> InMemoryFile:
        """Returns InMemoryFile object for given file_id or InMemoryFile
        object.

        Raises KeyError if not found.
        """
        # Filename is {requested_hash}.{extension} but InMemoryFileManager
        # is indexed by requested_hash.
        hash = inmemory_filename.split(".")[0]
        return self._files_by_id[hash]

    def get_stats(self) -> List[CacheStat]:
        # We operate on a copy of our dict, to avoid race conditions
        # with other threads that may be manipulating the cache.
        files_by_id = self._files_by_id.copy()

        stats: List[CacheStat] = []
        for file_id, file in files_by_id.items():
            stats.append(
                CacheStat(
                    category_name="st_in_memory_file_manager",
                    cache_name="",
                    byte_length=file.content_size,
                )
            )
        return stats

    def __contains__(self, inmemory_file_or_id):
        return inmemory_file_or_id in self._files_by_id

    def __len__(self):
        return len(self._files_by_id)


# Global singleton shared by the whole server process.
in_memory_file_manager = InMemoryFileManager()

View File

@ -0,0 +1,110 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
from typing import Optional, Union
class JSNumberBoundsException(Exception):
    """Raised when a value cannot be represented as a JavaScript Number."""
    pass
class JSNumber(object):
    """
    Utility class. Exposes JavaScript Number constants.
    """

    # Largest and smallest ints representable with perfect precision in
    # JavaScript (i.e. +-(2^53 - 1)).
    MAX_SAFE_INTEGER = (1 << 53) - 1
    MIN_SAFE_INTEGER = -((1 << 53) - 1)

    # Extremes of the JavaScript float range.
    MAX_VALUE = 1.7976931348623157e308
    MIN_VALUE = 5e-324  # closest representable number to zero
    MIN_NEGATIVE_VALUE = -MAX_VALUE

    @classmethod
    def validate_int_bounds(cls, value: int, value_name: Optional[str] = None) -> None:
        """Validate that an int value can be represented with perfect
        precision by a JavaScript Number.

        Parameters
        ----------
        value : int
        value_name : str or None
            The name of the value parameter. If specified, this will be used
            in any exception that is thrown.

        Raises
        -------
        JSNumberBoundsException
            Raised with a human-readable explanation if the value falls
            outside JavaScript int bounds.
        """
        name = "value" if value_name is None else value_name
        if value < cls.MIN_SAFE_INTEGER:
            raise JSNumberBoundsException(
                "%s (%s) must be >= -((1 << 53) - 1)" % (name, value)
            )
        if value > cls.MAX_SAFE_INTEGER:
            raise JSNumberBoundsException(
                "%s (%s) must be <= (1 << 53) - 1" % (name, value)
            )

    @classmethod
    def validate_float_bounds(
        cls, value: Union[int, float], value_name: Optional[str]
    ) -> None:
        """Validate that a float value can be represented by a JavaScript
        Number.

        Parameters
        ----------
        value : float
        value_name : str or None
            The name of the value parameter. If specified, this will be used
            in any exception that is thrown.

        Raises
        -------
        JSNumberBoundsException
            Raised with a human-readable explanation if the value falls
            outside JavaScript float bounds.
        """
        name = "value" if value_name is None else value_name
        if not isinstance(value, (numbers.Integral, float)):
            raise JSNumberBoundsException(
                "%s (%s) is not a float" % (name, value)
            )
        if value < cls.MIN_NEGATIVE_VALUE:
            raise JSNumberBoundsException(
                "%s (%s) must be >= -1.797e+308" % (name, value)
            )
        if value > cls.MAX_VALUE:
            raise JSNumberBoundsException(
                "%s (%s) must be <= 1.797e+308" % (name, value)
            )

View File

@ -0,0 +1,20 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .caching import cache as cache
from .caching import clear_cache as clear_cache
from .caching import get_cache_path as get_cache_path
from .caching import (
maybe_show_cached_st_function_warning as maybe_show_cached_st_function_warning,
)

View File

@ -0,0 +1,754 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A library of caching utilities."""
import contextlib
import functools
import hashlib
import inspect
import math
import os
import pickle
import shutil
import threading
import time
import types
from collections import namedtuple
from typing import Dict, Optional, List, Iterator, Any, Callable
import attr
from cachetools import TTLCache
from pympler.asizeof import asizeof
from streamlit import config
from streamlit import file_util
from streamlit import util
from streamlit.error_util import handle_uncaught_app_exception
from streamlit.errors import StreamlitAPIWarning
from streamlit.legacy_caching.hashing import update_hash, HashFuncsDict
from streamlit.legacy_caching.hashing import HashReason
from streamlit.logger import get_logger
import streamlit as st
from streamlit.stats import CacheStat, CacheStatsProvider
_LOGGER = get_logger(__name__)

# The timer function we use with TTLCache. This is the default timer func, but
# is exposed here as a constant so that it can be patched in unit tests.
_TTLCACHE_TIMER = time.monotonic

# (value, output-hash) pair stored in the in-memory cache; `hash` is None
# when allow_output_mutation disables the mutation check.
_CacheEntry = namedtuple("_CacheEntry", ["value", "hash"])

# Wrapper pickled to disk for persisted cache entries.
_DiskCacheEntry = namedtuple("_DiskCacheEntry", ["value"])
@attr.s(auto_attribs=True, slots=True)
class MemCache:
    """The in-memory cache for a single st.cache'd function."""

    # TTLCache holding this function's cached entries.
    cache: TTLCache
    # Human-readable "<module>.<qualname>" name, used in cache stats.
    display_name: str
class _MemCaches(CacheStatsProvider):
    """Manages all in-memory st.cache caches"""

    def __init__(self):
        # Contains a cache object for each st.cache'd function
        self._lock = threading.RLock()
        self._function_caches: Dict[str, MemCache] = {}

    def __repr__(self) -> str:
        return util.repr_(self)

    def get_cache(
        self,
        key: str,
        max_entries: Optional[float],
        ttl: Optional[float],
        display_name: str = "",
    ) -> MemCache:
        """Return the mem cache for the given key.

        If it doesn't exist, create a new one with the given params.
        A cached MemCache is reused only when its ttl/max_entries still
        match; otherwise it is replaced (discarding its entries).
        """
        # None means "unbounded" for both knobs; TTLCache accepts math.inf.
        if max_entries is None:
            max_entries = math.inf
        if ttl is None:
            ttl = math.inf

        if not isinstance(max_entries, (int, float)):
            raise RuntimeError("max_entries must be an int")
        if not isinstance(ttl, (int, float)):
            raise RuntimeError("ttl must be a float")

        # Get the existing cache, if it exists, and validate that its params
        # haven't changed.
        with self._lock:
            mem_cache = self._function_caches.get(key)
            if (
                mem_cache is not None
                and mem_cache.cache.ttl == ttl
                and mem_cache.cache.maxsize == max_entries
            ):
                return mem_cache

            # Create a new cache object and put it in our dict
            _LOGGER.debug(
                "Creating new mem_cache (key=%s, max_entries=%s, ttl=%s)",
                key,
                max_entries,
                ttl,
            )
            ttl_cache = TTLCache(maxsize=max_entries, ttl=ttl, timer=_TTLCACHE_TIMER)
            mem_cache = MemCache(ttl_cache, display_name)
            self._function_caches[key] = mem_cache
            return mem_cache

    def clear(self) -> None:
        """Clear all caches"""
        with self._lock:
            self._function_caches = {}

    def get_stats(self) -> List[CacheStat]:
        with self._lock:
            # Shallow-clone our caches. We don't want to hold the global
            # lock during stats-gathering.
            function_caches = self._function_caches.copy()

        stats = [
            CacheStat("st_cache", cache.display_name, asizeof(c))
            for cache in function_caches.values()
            for c in cache.cache
        ]
        return stats


# Our singleton _MemCaches instance
_mem_caches = _MemCaches()
# A thread-local counter that's incremented when we enter @st.cache
# and decremented when we exit.
class ThreadLocalCacheInfo(threading.local):
    """Per-thread bookkeeping of cached-function execution state."""

    def __init__(self):
        # Stack of @st.cache'd functions currently executing on this thread
        # (outermost first); used to attribute st.* warnings to a function.
        self.cached_func_stack: List[types.FunctionType] = []
        # When > 0, st.* calls inside cached functions don't warn.
        # A counter (not a bool) so nested suppressions compose.
        self.suppress_st_function_warning = 0

    def __repr__(self) -> str:
        return util.repr_(self)


_cache_info = ThreadLocalCacheInfo()
@contextlib.contextmanager
def _calling_cached_function(func: types.FunctionType) -> Iterator[None]:
    """Push `func` onto this thread's cached-function stack for the duration
    of the `with` block; popped even if the body raises."""
    _cache_info.cached_func_stack.append(func)
    try:
        yield
    finally:
        _cache_info.cached_func_stack.pop()
@contextlib.contextmanager
def suppress_cached_st_function_warning() -> Iterator[None]:
    """Temporarily disable the "st.foo inside @st.cache" warning.

    Uses a counter rather than a flag so nested suppressions compose.
    """
    _cache_info.suppress_st_function_warning += 1
    try:
        yield
    finally:
        _cache_info.suppress_st_function_warning -= 1
        assert _cache_info.suppress_st_function_warning >= 0
def _show_cached_st_function_warning(
    dg: "st.delta_generator.DeltaGenerator",
    st_func_name: str,
    cached_func: types.FunctionType,
) -> None:
    """Render a CachedStFunctionWarning for `st_func_name` onto `dg`."""
    # Rendering the warning itself calls Streamlit functions; suppress
    # further warnings while doing so to avoid infinite recursion.
    with suppress_cached_st_function_warning():
        warning = CachedStFunctionWarning(st_func_name, cached_func)
        dg.exception(warning)
def maybe_show_cached_st_function_warning(
    dg: "st.delta_generator.DeltaGenerator", st_func_name: str
) -> None:
    """If appropriate, warn about calling st.foo inside @cache.

    DeltaGenerator's @_with_element and @_widget wrappers use this to warn
    the user when they're calling st.foo() from within a function that is
    wrapped in @st.cache.

    Parameters
    ----------
    dg : DeltaGenerator
        The DeltaGenerator to publish the warning to.
    st_func_name : str
        The name of the Streamlit function that was called.
    """
    stack = _cache_info.cached_func_stack
    # Warn only when we're inside at least one cached function AND no
    # suppression context is active; attribute the warning to the innermost
    # cached function on the stack.
    if stack and _cache_info.suppress_st_function_warning <= 0:
        _show_cached_st_function_warning(dg, st_func_name, stack[-1])
def _read_from_mem_cache(
    mem_cache: MemCache,
    key: str,
    allow_output_mutation: bool,
    func_or_code: Callable[..., Any],
    hash_funcs: Optional[HashFuncsDict],
) -> Any:
    """Look up `key` in the memory cache and return the stored value.

    Raises CacheKeyNotFoundError on a miss, and CachedObjectMutationError
    when the stored value's hash no longer matches (i.e. the caller mutated
    the cached object) unless allow_output_mutation is set.
    """
    cache = mem_cache.cache
    if key not in cache:
        _LOGGER.debug("Memory cache MISS: %s", key)
        raise CacheKeyNotFoundError("Key not found in mem cache")

    entry = cache[key]
    if not allow_output_mutation:
        # Re-hash the stored value; a mismatch means it was mutated in place
        # since it was cached.
        fresh_hash = _get_output_hash(entry.value, func_or_code, hash_funcs)
        if fresh_hash != entry.hash:
            _LOGGER.debug("Cached object was mutated: %s", key)
            raise CachedObjectMutationError(entry.value, func_or_code)

    _LOGGER.debug("Memory cache HIT: %s", type(entry.value))
    return entry.value
def _write_to_mem_cache(
    mem_cache: MemCache,
    key: str,
    value: Any,
    allow_output_mutation: bool,
    func_or_code: Callable[..., Any],
    hash_funcs: Optional[HashFuncsDict],
) -> None:
    """Store `value` under `key` in the memory cache.

    An output hash is recorded alongside the value unless mutation is
    allowed; a None hash disables the mutation check on later reads.
    """
    if allow_output_mutation:
        value_hash = None
    else:
        value_hash = _get_output_hash(value, func_or_code, hash_funcs)

    # Keep a human-readable function name around for cache stats.
    mem_cache.display_name = f"{func_or_code.__module__}.{func_or_code.__qualname__}"
    mem_cache.cache[key] = _CacheEntry(value=value, hash=value_hash)
def _get_output_hash(
    value: Any, func_or_code: Callable[..., Any], hash_funcs: Optional[HashFuncsDict]
) -> bytes:
    """Hash a cached function's output so later reads can detect mutation.

    NOTE(review): md5 appears to be used here as a fast fingerprint, not for
    security -- confirm before relying on it in a security context.
    """
    hasher = hashlib.new("md5")
    update_hash(
        value,
        hasher=hasher,
        hash_funcs=hash_funcs,
        hash_reason=HashReason.CACHING_FUNC_OUTPUT,
        hash_source=func_or_code,
    )
    return hasher.digest()
def _read_from_disk_cache(key: str) -> Any:
    """Load a pickled cache entry for `key` from disk.

    Raises CacheKeyNotFoundError when no file exists for the key, and
    CacheError when the file exists but cannot be read.
    """
    path = file_util.get_streamlit_file_path("cache", "%s.pickle" % key)

    try:
        with file_util.streamlit_read(path, binary=True) as cache_file:
            entry = pickle.load(cache_file)
    except util.Error as e:
        _LOGGER.error(e)
        raise CacheError("Unable to read from cache: %s" % e)
    except FileNotFoundError:
        raise CacheKeyNotFoundError("Key not found in disk cache")

    _LOGGER.debug("Disk cache HIT: %s", type(entry.value))
    return entry.value
def _write_to_disk_cache(key: str, value: Any) -> None:
    """Pickle `value` to the on-disk cache under `key`.

    Raises CacheError when the write fails; any partial file is removed so
    we don't leave zero-byte entries behind.
    """
    path = file_util.get_streamlit_file_path("cache", "%s.pickle" % key)

    try:
        with file_util.streamlit_write(path, binary=True) as output:
            pickle.dump(_DiskCacheEntry(value=value), output, pickle.HIGHEST_PROTOCOL)
    except util.Error as e:
        _LOGGER.debug(e)
        # Clean up file so we don't leave zero byte files.
        try:
            os.remove(path)
        except (FileNotFoundError, IOError, OSError):
            pass
        raise CacheError("Unable to write to cache: %s" % e)
def _read_from_cache(
    mem_cache: MemCache,
    key: str,
    persist: bool,
    allow_output_mutation: bool,
    func_or_code: Callable[..., Any],
    hash_funcs: Optional[HashFuncsDict] = None,
) -> Any:
    """Read a value from the cache.

    Our goal is to read from memory if possible. If the data was mutated (hash
    changed), we show a warning. If reading from memory fails, we either read
    from disk or rerun the code.
    """
    try:
        return _read_from_mem_cache(
            mem_cache, key, allow_output_mutation, func_or_code, hash_funcs
        )
    except CachedObjectMutationError as e:
        # Surface the mutation as an app warning, but still return the
        # (mutated) cached value instead of failing the script.
        handle_uncaught_app_exception(CachedObjectMutationWarning(e))
        return e.cached_value
    except CacheKeyNotFoundError as e:
        if persist:
            value = _read_from_disk_cache(key)
            # Promote the disk hit into the memory cache for next time.
            _write_to_mem_cache(
                mem_cache, key, value, allow_output_mutation, func_or_code, hash_funcs
            )
            return value
        # No disk fallback configured: propagate the miss to the caller,
        # which will re-run the function.
        raise e
def _write_to_cache(
    mem_cache: MemCache,
    key: str,
    value: Any,
    persist: bool,
    allow_output_mutation: bool,
    func_or_code: Callable[..., Any],
    hash_funcs: Optional[HashFuncsDict] = None,
):
    """Write `value` to the in-memory cache and, when `persist` is True,
    also to the on-disk pickle cache."""
    _write_to_mem_cache(
        mem_cache, key, value, allow_output_mutation, func_or_code, hash_funcs
    )
    if persist:
        _write_to_disk_cache(key, value)
def cache(
    func=None,
    persist=False,
    allow_output_mutation=False,
    show_spinner=True,
    suppress_st_warning=False,
    hash_funcs=None,
    max_entries=None,
    ttl=None,
):
    """Function decorator to memoize function executions.

    Parameters
    ----------
    func : callable
        The function to cache. Streamlit hashes the function and dependent code.

    persist : boolean
        Whether to persist the cache on disk.

    allow_output_mutation : boolean
        Streamlit shows a warning when return values are mutated, as that
        can have unintended consequences. This is done by hashing the return value internally.

        If you know what you're doing and would like to override this warning, set this to True.

    show_spinner : boolean
        Enable the spinner. Default is True to show a spinner when there is
        a cache miss.

    suppress_st_warning : boolean
        Suppress warnings about calling Streamlit functions from within
        the cached function.

    hash_funcs : dict or None
        Mapping of types or fully qualified names to hash functions. This is used to override
        the behavior of the hasher inside Streamlit's caching mechanism: when the hasher
        encounters an object, it will first check to see if its type matches a key in this
        dict and, if so, will use the provided function to generate a hash for it. See below
        for an example of how this can be used.

    max_entries : int or None
        The maximum number of entries to keep in the cache, or None
        for an unbounded cache. (When a new entry is added to a full cache,
        the oldest cached entry will be removed.) The default is None.

    ttl : float or None
        The maximum number of seconds to keep an entry in the cache, or
        None if cache entries should not expire. The default is None.

    Example
    -------
    >>> @st.cache
    ... def fetch_and_clean_data(url):
    ...     # Fetch data from URL here, and then clean it up.
    ...     return data
    ...
    >>> d1 = fetch_and_clean_data(DATA_URL_1)
    >>> # Actually executes the function, since this is the first time it was
    >>> # encountered.
    >>>
    >>> d2 = fetch_and_clean_data(DATA_URL_1)
    >>> # Does not execute the function. Instead, returns its previously computed
    >>> # value. This means that now the data in d1 is the same as in d2.
    >>>
    >>> d3 = fetch_and_clean_data(DATA_URL_2)
    >>> # This is a different URL, so the function executes.

    To set the ``persist`` parameter, use this command as follows:

    >>> @st.cache(persist=True)
    ... def fetch_and_clean_data(url):
    ...     # Fetch data from URL here, and then clean it up.
    ...     return data

    To disable hashing return values, set the ``allow_output_mutation`` parameter to ``True``:

    >>> @st.cache(allow_output_mutation=True)
    ... def fetch_and_clean_data(url):
    ...     # Fetch data from URL here, and then clean it up.
    ...     return data

    To override the default hashing behavior, pass a custom hash function.
    You can do that by mapping a type (e.g. ``MongoClient``) to a hash function (``id``) like this:

    >>> @st.cache(hash_funcs={MongoClient: id})
    ... def connect_to_database(url):
    ...     return MongoClient(url)

    Alternatively, you can map the type's fully-qualified name
    (e.g. ``"pymongo.mongo_client.MongoClient"``) to the hash function instead:

    >>> @st.cache(hash_funcs={"pymongo.mongo_client.MongoClient": id})
    ... def connect_to_database(url):
    ...     return MongoClient(url)

    """
    _LOGGER.debug("Entering st.cache: %s", func)

    # Support passing the params via function decorator, e.g.
    # @st.cache(persist=True, allow_output_mutation=True)
    if func is None:
        return lambda f: cache(
            func=f,
            persist=persist,
            allow_output_mutation=allow_output_mutation,
            show_spinner=show_spinner,
            suppress_st_warning=suppress_st_warning,
            hash_funcs=hash_funcs,
            max_entries=max_entries,
            ttl=ttl,
        )

    # The function's cache key; computed lazily on the first call (see
    # get_or_create_cached_value below for why).
    cache_key = None

    @functools.wraps(func)
    def wrapped_func(*args, **kwargs):
        """This function wrapper will only call the underlying function in
        the case of a cache miss. Cached objects are stored in the cache/
        directory."""
        if not config.get_option("client.caching"):
            _LOGGER.debug("Purposefully skipping cache")
            return func(*args, **kwargs)

        name = func.__qualname__

        # Spinner message shown while computing on a cache miss.
        if len(args) == 0 and len(kwargs) == 0:
            message = "Running `%s()`." % name
        else:
            message = "Running `%s(...)`." % name

        def get_or_create_cached_value():
            nonlocal cache_key
            if cache_key is None:
                # Delay generating the cache key until the first call.
                # This way we can see values of globals, including functions
                # defined after this one.
                # If we generated the key earlier we would only hash those
                # globals by name, and miss changes in their code or value.
                cache_key = _hash_func(func, hash_funcs)

            # First, get the cache that's attached to this function.
            # This cache's key is generated (above) from the function's code.
            mem_cache = _mem_caches.get_cache(cache_key, max_entries, ttl)

            # Next, calculate the key for the value we'll be searching for
            # within that cache. This key is generated from both the function's
            # code and the arguments that are passed into it. (Even though this
            # key is used to index into a per-function cache, it must be
            # globally unique, because it is *also* used for a global on-disk
            # cache that is *not* per-function.)
            value_hasher = hashlib.new("md5")

            if args:
                update_hash(
                    args,
                    hasher=value_hasher,
                    hash_funcs=hash_funcs,
                    hash_reason=HashReason.CACHING_FUNC_ARGS,
                    hash_source=func,
                )

            if kwargs:
                update_hash(
                    kwargs,
                    hasher=value_hasher,
                    hash_funcs=hash_funcs,
                    hash_reason=HashReason.CACHING_FUNC_ARGS,
                    hash_source=func,
                )

            value_key = value_hasher.hexdigest()

            # Avoid recomputing the body's hash by just appending the
            # previously-computed hash to the arg hash.
            value_key = "%s-%s" % (value_key, cache_key)

            _LOGGER.debug("Cache key: %s", value_key)

            try:
                return_value = _read_from_cache(
                    mem_cache=mem_cache,
                    key=value_key,
                    persist=persist,
                    allow_output_mutation=allow_output_mutation,
                    func_or_code=func,
                    hash_funcs=hash_funcs,
                )
                _LOGGER.debug("Cache hit: %s", func)
            except CacheKeyNotFoundError:
                _LOGGER.debug("Cache miss: %s", func)

                # Cache miss: run the real function, tracking that we're inside
                # a cached function so nested st.* calls can be warned about.
                with _calling_cached_function(func):
                    if suppress_st_warning:
                        with suppress_cached_st_function_warning():
                            return_value = func(*args, **kwargs)
                    else:
                        return_value = func(*args, **kwargs)

                _write_to_cache(
                    mem_cache=mem_cache,
                    key=value_key,
                    value=return_value,
                    persist=persist,
                    allow_output_mutation=allow_output_mutation,
                    func_or_code=func,
                    hash_funcs=hash_funcs,
                )

            return return_value

        if show_spinner:
            with st.spinner(message):
                return get_or_create_cached_value()
        else:
            return get_or_create_cached_value()

    # Make this a well-behaved decorator by preserving important function
    # attributes.
    try:
        wrapped_func.__dict__.update(func.__dict__)
    except AttributeError:
        pass

    return wrapped_func
def _hash_func(func: types.FunctionType, hash_funcs: HashFuncsDict) -> str:
    """Compute the cache key identifying `func`'s cache.

    The key mixes the function's module + qualified name (so identically
    named functions in different scopes get distinct keys) with a hash of
    its body (so code changes invalidate old entries).

    The cache object itself must be looked up lazily inside the wrapper
    rather than captured at decoration time: Streamlit reloads the user's
    modules on each rerun, and caches can also be destroyed and recreated
    (e.g. by cache clearing) between decoration time and call time.
    """
    hasher = hashlib.new("md5")

    # Hash the identity (module + qualname) first. `hash_funcs` is
    # deliberately NOT applied here: the name must always hash the same way.
    update_hash(
        (func.__module__, func.__qualname__),
        hasher=hasher,
        hash_funcs=None,
        hash_reason=HashReason.CACHING_FUNC_BODY,
        hash_source=func,
    )

    # Then hash the body. Here user-supplied `hash_funcs` DO apply, because
    # the body may reference arbitrary objects.
    update_hash(
        func,
        hasher=hasher,
        hash_funcs=hash_funcs,
        hash_reason=HashReason.CACHING_FUNC_BODY,
        hash_source=func,
    )

    cache_key = hasher.hexdigest()
    _LOGGER.debug(
        "mem_cache key for %s.%s: %s", func.__module__, func.__qualname__, cache_key
    )
    return cache_key
def clear_cache() -> bool:
    """Clear the memoization cache.

    Clears both the in-memory cache and the on-disk cache.

    Returns
    -------
    boolean
        True if the disk cache was cleared. False otherwise (e.g. cache file
        doesn't exist on disk).
    """
    # Memory first, then disk; only the disk result is reported.
    _clear_mem_cache()
    return _clear_disk_cache()
def get_cache_path() -> str:
    """Return the directory path used for the persisted (on-disk) cache."""
    return file_util.get_streamlit_file_path("cache")
def _clear_disk_cache() -> bool:
    """Delete the on-disk cache directory, if it exists.

    Returns True when a cache directory was actually removed.
    """
    # TODO: Only delete disk cache for functions related to the user's current
    # script.
    cache_dir = get_cache_path()
    if not os.path.isdir(cache_dir):
        return False
    shutil.rmtree(cache_dir)
    return True
def _clear_mem_cache() -> None:
    """Drop every in-memory cache entry, across all cached functions."""
    _mem_caches.clear()
class CacheError(Exception):
    """Raised when reading from or writing to the disk cache fails."""

    pass
class CacheKeyNotFoundError(Exception):
    """Raised on a cache miss: the key is in neither memory nor disk cache."""

    pass
class CachedObjectMutationError(ValueError):
    """Raised when a cached return value's hash no longer matches.

    This is used internally, but never shown to the user.
    Users see CachedObjectMutationWarning instead.
    """

    def __init__(self, cached_value, func_or_code):
        self.cached_value = cached_value
        # Cached code blocks (as opposed to functions) have no usable name.
        self.cached_func_name = (
            "a code block"
            if inspect.iscode(func_or_code)
            else _get_cached_func_name_md(func_or_code)
        )

    def __repr__(self) -> str:
        return util.repr_(self)
class CachedStFunctionWarning(StreamlitAPIWarning):
    """Warning shown when an st.* display function is called from inside a
    cached function (the call only happens on cache misses)."""

    def __init__(self, st_func_name, cached_func):
        msg = self._get_message(st_func_name, cached_func)
        super(CachedStFunctionWarning, self).__init__(msg)

    def _get_message(self, st_func_name, cached_func):
        # Values interpolated into the message template below.
        args = {
            "st_func_name": "`st.%s()` or `st.write()`" % st_func_name,
            "func_name": _get_cached_func_name_md(cached_func),
        }
        return (
            """
Your script uses %(st_func_name)s to write to your Streamlit app from within
some cached code at %(func_name)s. This code will only be called when we detect
a cache "miss", which can lead to unexpected results.
How to fix this:
* Move the %(st_func_name)s call outside %(func_name)s.
* Or, if you know what you're doing, use `@st.cache(suppress_st_warning=True)`
to suppress the warning.
"""
            % args
        ).strip("\n")
class CachedObjectMutationWarning(StreamlitAPIWarning):
    """User-facing warning built from a CachedObjectMutationError."""

    def __init__(self, orig_exc):
        msg = self._get_message(orig_exc)
        super(CachedObjectMutationWarning, self).__init__(msg)

    def _get_message(self, orig_exc):
        return (
            """
Return value of %(func_name)s was mutated between runs.
By default, Streamlit's cache should be treated as immutable, or it may behave
in unexpected ways. You received this warning because Streamlit detected
that an object returned by %(func_name)s was mutated outside of %(func_name)s.
How to fix this:
* If you did not mean to mutate that return value:
  - If possible, inspect your code to find and remove that mutation.
  - Otherwise, you could also clone the returned value so you can freely
    mutate it.
* If you actually meant to mutate the return value and know the consequences of
doing so, annotate the function with `@st.cache(allow_output_mutation=True)`.
For more information and detailed solutions check out [our documentation.]
(https://docs.streamlit.io/library/advanced-features/caching)
"""
            % {"func_name": orig_exc.cached_func_name}
        ).strip("\n")
def _get_cached_func_name_md(func: types.FunctionType) -> str:
"""Get markdown representation of the function name."""
if hasattr(func, "__name__"):
return "`%s()`" % func.__name__
else:
return "a cached function"

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,130 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Logging module."""
import logging
import sys
from typing import Dict, Union
from streamlit import config
# Loggers for each name are saved here.
LOGGERS: Dict[str, logging.Logger] = {}

# The global log level is set here across all names.
LOG_LEVEL = logging.INFO

# Default console format; overridden by the "logger.messageFormat" config
# option once the config file has been parsed (see setup_formatter).
DEFAULT_LOG_MESSAGE = "%(asctime)s %(levelname) -7s " "%(name)s: %(message)s"
def set_log_level(level: Union[str, int]) -> None:
    """Set the log level for every registered Streamlit logger.

    `level` may be a level name ("DEBUG" ... "CRITICAL", case-insensitive)
    or the corresponding `logging` integer constant. An unrecognized level
    is fatal: it is logged and the process exits with status 1.
    """
    logger = get_logger(__name__)

    if isinstance(level, str):
        level = level.upper()

    # Accept either the symbolic name or the numeric constant.
    for level_name, numeric_level in (
        ("CRITICAL", logging.CRITICAL),
        ("ERROR", logging.ERROR),
        ("WARNING", logging.WARNING),
        ("INFO", logging.INFO),
        ("DEBUG", logging.DEBUG),
    ):
        if level == level_name or level == numeric_level:
            log_level = numeric_level
            break
    else:
        logger.critical('undefined log level "%s"' % level)
        sys.exit(1)

    for registered_logger in LOGGERS.values():
        registered_logger.setLevel(log_level)

    global LOG_LEVEL
    LOG_LEVEL = log_level
def setup_formatter(logger: logging.Logger) -> None:
    """Set up the console formatter for a given logger.

    The StreamHandler is stashed on the logger object itself
    (`streamlit_console_handler`) so a later call can find and replace it.
    """
    # Deregister any previous console loggers.
    if hasattr(logger, "streamlit_console_handler"):
        logger.removeHandler(logger.streamlit_console_handler)  # type: ignore[attr-defined]

    logger.streamlit_console_handler = logging.StreamHandler()  # type: ignore[attr-defined]

    if config._config_options:
        # logger is required in ConfigOption.set_value
        # Getting the config option before the config file has been parsed
        # can create an infinite loop
        message_format = config.get_option("logger.messageFormat")
    else:
        message_format = DEFAULT_LOG_MESSAGE
    formatter = logging.Formatter(fmt=message_format)
    formatter.default_msec_format = "%s.%03d"
    logger.streamlit_console_handler.setFormatter(formatter)  # type: ignore[attr-defined]

    # Register the new console logger.
    logger.addHandler(logger.streamlit_console_handler)  # type: ignore[attr-defined]
def update_formatter() -> None:
    """Reapply the console formatter to every registered logger.

    Called after config changes so loggers pick up the configured format.
    """
    for registered_logger in LOGGERS.values():
        setup_formatter(registered_logger)
def init_tornado_logs() -> None:
    """Create Streamlit-formatted loggers for tornado's three log streams.

    Registering them via get_logger() disables propagation and attaches our
    console formatter, so tornado output matches Streamlit's.
    """
    # Fix: the original declared `global LOGGER`, but no module global named
    # LOGGER exists (the registry is LOGGERS) and nothing was assigned — the
    # statement was dead and misleading, so it is removed.
    # http://www.tornadoweb.org/en/stable/log.html
    for log_name in ("access", "application", "general"):
        # get_logger registers the logger as a side effect.
        get_logger("tornado.%s" % log_name)

    logger = get_logger(__name__)
    logger.debug("Initialized tornado logs")
def get_logger(name: str) -> logging.Logger:
    """Return a logger.

    Parameters
    ----------
    name : str
        The name of the logger to use. You should just pass in __name__.

    Returns
    -------
    Logger

    """
    # Idiom fix: membership test directly on the dict instead of .keys().
    if name in LOGGERS:
        return LOGGERS[name]

    # "root" maps to the root logger; any other name gets a named logger.
    if name == "root":
        logger = logging.getLogger()
    else:
        logger = logging.getLogger(name)

    logger.setLevel(LOG_LEVEL)
    # Don't propagate to ancestor loggers — our own handler does the output.
    logger.propagate = False
    setup_formatter(logger)

    LOGGERS[name] = logger

    return logger

View File

@ -0,0 +1,179 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import sys
def add_magic(code, script_path):
    """Rewrite `code` so bare expressions become Streamlit magic writes.

    Parameters
    ----------
    code : str
        The Python code.
    script_path : str
        The path to the script file.

    Returns
    -------
    ast.Module
        The syntax tree for the code.
    """
    # Parsing with the real script path makes exception tracebacks point at
    # the user's file.
    module = ast.parse(code, script_path, "exec")
    return _modify_ast_subtree(module, is_root=True)
def _modify_ast_subtree(tree, body_attr="body", is_root=False):
    """Parses magic commands and modifies the given AST (sub)tree.

    Mutates `tree` in place (and returns it): standalone expression
    statements are wrapped in __streamlit__._transparent_write(...) calls,
    recursing into nested bodies. When `is_root` is True, also inserts the
    streamlit import and fixes up node locations so the tree can compile.

    Parameters
    ----------
    tree : ast.AST
        Node whose statement list is to be rewritten.
    body_attr : str
        Which statement-list attribute of `tree` to walk ("body",
        "orelse", or "finalbody").
    is_root : bool
        True only for the top-level module call.
    """
    body = getattr(tree, body_attr)

    for i, node in enumerate(body):
        node_type = type(node)

        # Parse the contents of functions, With statements, and for statements
        if (
            node_type is ast.FunctionDef
            or node_type is ast.With
            or node_type is ast.For
            or node_type is ast.While
            or node_type is ast.AsyncFunctionDef
            or node_type is ast.AsyncWith
            or node_type is ast.AsyncFor
        ):
            _modify_ast_subtree(node)

        # Parse the contents of try statements
        elif node_type is ast.Try:
            # Each exception handler has its own body to rewrite.
            for j, inner_node in enumerate(node.handlers):
                node.handlers[j] = _modify_ast_subtree(inner_node)
            # The finally-block lives in a separate attribute ("finalbody").
            finally_node = _modify_ast_subtree(node, body_attr="finalbody")
            node.finalbody = finally_node.finalbody
            # Finally, rewrite the try-block's own "body".
            _modify_ast_subtree(node)

        # Convert if expressions to st.write
        elif node_type is ast.If:
            _modify_ast_subtree(node)
            _modify_ast_subtree(node, "orelse")

        # Convert standalone expression nodes to st.write
        elif node_type is ast.Expr:
            value = _get_st_write_from_expr(node, i, parent_type=type(tree))
            if value is not None:
                node.value = value

    if is_root:
        # Import Streamlit so we can use it in the new_value above.
        _insert_import_statement(tree)

    # Newly created nodes have no line/col info; copy it from their parents.
    ast.fix_missing_locations(tree)

    return tree
def _insert_import_statement(tree):
    """Insert Streamlit import statement at the top(ish) of the tree."""
    st_import = _build_st_import_statement()
    body = tree.body

    def _is_import_node(node):
        return type(node) in (ast.ImportFrom, ast.Import)

    if body and _is_import_node(body[0]):
        # Keep an existing leading import first, so we don't break
        # "from __future__ import ..." (which must come first).
        insert_at = 1
    elif (
        len(body) > 1
        and type(body[0]) is ast.Expr
        and _is_docstring_node(body[0].value)
        and _is_import_node(body[1])
    ):
        # Module docstring followed by an import: go below both, again to
        # keep a possible __future__ import in its mandatory position.
        insert_at = 2
    else:
        insert_at = 0

    body.insert(insert_at, st_import)
def _build_st_import_statement():
"""Build AST node for `import streamlit as __streamlit__`."""
return ast.Import(names=[ast.alias(name="streamlit", asname="__streamlit__")])
def _build_st_write_call(nodes):
"""Build AST node for `__streamlit__._transparent_write(*nodes)`."""
return ast.Call(
func=ast.Attribute(
attr="_transparent_write",
value=ast.Name(id="__streamlit__", ctx=ast.Load()),
ctx=ast.Load(),
),
args=nodes,
keywords=[],
kwargs=None,
starargs=None,
)
def _get_st_write_from_expr(node, i, parent_type):
    """Return an st.write Call wrapping `node.value`, or None to leave the
    expression untouched.

    Fix: the original had three structurally identical `elif` branches
    (ast.Str, ast.Name, and the fallback all built the same single-argument
    call); they are collapsed into one, with identical behavior.

    Parameters
    ----------
    node : ast.Expr
        A standalone expression statement.
    i : int
        The statement's index in its parent body (used to spot docstrings).
    parent_type : type
        The AST type of the enclosing node.
    """
    # Don't change function calls
    if type(node.value) is ast.Call:
        return None

    # Don't change docstring nodes
    if (
        i == 0
        and _is_docstring_node(node.value)
        and parent_type in (ast.FunctionDef, ast.Module)
    ):
        return None

    # Don't change yield nodes
    if type(node.value) in (ast.Yield, ast.YieldFrom):
        return None

    if type(node.value) is ast.Tuple:
        # If tuple, call st.write on the elements (rather than the whole
        # tuple). This allows us to add a comma at the end of a statement
        # to turn it into an expression that should be st-written. Ex:
        # "np.random.randn(1000, 2),"
        args = node.value.elts
    else:
        # Everything else (strings, names, arbitrary expressions) is
        # st-written as a single argument.
        args = [node.value]

    return _build_st_write_call(args)
def _is_docstring_node(node):
if sys.version_info >= (3, 8, 0):
return type(node) is ast.Constant and type(node.value) is str
else:
return type(node) is ast.Str

View File

@ -0,0 +1,74 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import threading
import uuid
from typing import Optional
from streamlit import util
_ETC_MACHINE_ID_PATH = "/etc/machine-id"
_DBUS_MACHINE_ID_PATH = "/var/lib/dbus/machine-id"
def _get_machine_id_v3() -> str:
"""Get the machine ID
This is a unique identifier for a user for tracking metrics in Segment,
that is broken in different ways in some Linux distros and Docker images.
- at times just a hash of '', which means many machines map to the same ID
- at times a hash of the same string, when running in a Docker container
"""
machine_id = str(uuid.getnode())
if os.path.isfile(_ETC_MACHINE_ID_PATH):
with open(_ETC_MACHINE_ID_PATH, "r") as f:
machine_id = f.read()
elif os.path.isfile(_DBUS_MACHINE_ID_PATH):
with open(_DBUS_MACHINE_ID_PATH, "r") as f:
machine_id = f.read()
return machine_id
class Installation:
    """Process-wide singleton holding this installation's identifier."""

    # Guards lazy creation of the singleton.
    _instance_lock = threading.Lock()
    _instance = None  # type: Optional[Installation]

    @classmethod
    def instance(cls) -> "Installation":
        """Returns the singleton Installation"""
        # We use a double-checked locking optimization to avoid the overhead
        # of acquiring the lock in the common case:
        # https://en.wikipedia.org/wiki/Double-checked_locking
        if cls._instance is None:
            with cls._instance_lock:
                if cls._instance is None:
                    cls._instance = Installation()
        return cls._instance

    def __init__(self):
        # v3 id: a UUID5 derived from the machine id — stable across runs
        # on the same machine.
        self.installation_id_v3 = str(
            uuid.uuid5(uuid.NAMESPACE_DNS, _get_machine_id_v3())
        )

    def __repr__(self) -> str:
        return util.repr_(self)

    @property
    def installation_id(self):
        # Currently just an alias for the v3 id.
        return self.installation_id_v3
View File

@ -0,0 +1,122 @@
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socket
from typing import Optional
import requests
from streamlit import util
from streamlit.logger import get_logger
LOGGER = get_logger(__name__)

# URL for checking the current machine's external IP address.
_AWS_CHECK_IP = "http://checkip.amazonaws.com"

# Cached external IP; populated lazily by get_external_ip().
_external_ip = None  # type: Optional[str]
def get_external_ip():
    """Get the *external* IP address of the current machine.

    Returns
    -------
    string
        The external IPv4 address of the current machine. May be None when
        detection fails — callers must handle that.
    """
    global _external_ip

    # Return the cached value after the first (successful) lookup.
    if _external_ip is not None:
        return _external_ip

    # Blocking HTTP GET to AWS's check-ip service (None on any failure).
    response = _make_blocking_http_get(_AWS_CHECK_IP, timeout=5)

    if _looks_like_an_ip_adress(response):
        _external_ip = response
    else:
        LOGGER.warning(
            # fmt: off
            "Did not auto detect external IP.\n"
            "Please go to %s for debugging hints.",
            # fmt: on
            util.HELP_DOC
        )
        _external_ip = None

    return _external_ip
# Cached internal IP; populated lazily by get_internal_ip().
_internal_ip = None  # type: Optional[str]
def get_internal_ip():
    """Get the *local* IP address of the current machine.

    From: https://stackoverflow.com/a/28950776

    Returns
    -------
    string
        The local IPv4 address of the current machine.
    """
    global _internal_ip

    # Return the cached value after the first lookup.
    if _internal_ip is not None:
        return _internal_ip

    probe = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        # A UDP "connect" sends no packets — the target just has to be
        # routable for the OS to choose an outgoing interface.
        s = probe
        s.connect(("8.8.8.8", 1))
        _internal_ip = s.getsockname()[0]
    except Exception:
        _internal_ip = "127.0.0.1"
    finally:
        probe.close()

    return _internal_ip
def _make_blocking_http_get(url, timeout=5):
    """GET `url` synchronously; return the stripped body text, or None on
    any failure.

    Fix: the original bound the exception to an unused name (`as e`); the
    binding is removed. The broad except is deliberate — this is a
    best-effort probe where any network error simply yields None.
    """
    try:
        text = requests.get(url, timeout=timeout).text
        if isinstance(text, str):
            text = text.strip()
        return text
    except Exception:
        return None
def _looks_like_an_ip_adress(address):
if address is None:
return False
try:
socket.inet_pton(socket.AF_INET, address)
return True # Yup, this is an IPv4 address!
except (AttributeError, OSError):
pass
try:
socket.inet_pton(socket.AF_INET6, address)
return True # Yup, this is an IPv6 address!
except (AttributeError, OSError):
pass
# Nope, this is not an IP address.
return False

View File

@ -0,0 +1,120 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: streamlit/proto/Alert.proto
# NOTE(review): generated with the pre-protobuf-3.20 descriptor API
# (explicit Descriptor construction with create_key). To change this
# module, edit Alert.proto and re-run protoc — do not hand-edit.
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


DESCRIPTOR = _descriptor.FileDescriptor(
  name='streamlit/proto/Alert.proto',
  package='',
  syntax='proto3',
  serialized_options=None,
  create_key=_descriptor._internal_create_key,
  serialized_pb=b'\n\x1bstreamlit/proto/Alert.proto\"y\n\x05\x41lert\x12\x0c\n\x04\x62ody\x18\x01 \x01(\t\x12\x1d\n\x06\x66ormat\x18\x02 \x01(\x0e\x32\r.Alert.Format\"C\n\x06\x46ormat\x12\n\n\x06UNUSED\x10\x00\x12\t\n\x05\x45RROR\x10\x01\x12\x0b\n\x07WARNING\x10\x02\x12\x08\n\x04INFO\x10\x03\x12\x0b\n\x07SUCCESS\x10\x04\x62\x06proto3'
)


# Alert.Format enum: UNUSED/ERROR/WARNING/INFO/SUCCESS.
_ALERT_FORMAT = _descriptor.EnumDescriptor(
  name='Format',
  full_name='Alert.Format',
  filename=None,
  file=DESCRIPTOR,
  create_key=_descriptor._internal_create_key,
  values=[
    _descriptor.EnumValueDescriptor(
      name='UNUSED', index=0, number=0,
      serialized_options=None,
      type=None,
      create_key=_descriptor._internal_create_key),
    _descriptor.EnumValueDescriptor(
      name='ERROR', index=1, number=1,
      serialized_options=None,
      type=None,
      create_key=_descriptor._internal_create_key),
    _descriptor.EnumValueDescriptor(
      name='WARNING', index=2, number=2,
      serialized_options=None,
      type=None,
      create_key=_descriptor._internal_create_key),
    _descriptor.EnumValueDescriptor(
      name='INFO', index=3, number=3,
      serialized_options=None,
      type=None,
      create_key=_descriptor._internal_create_key),
    _descriptor.EnumValueDescriptor(
      name='SUCCESS', index=4, number=4,
      serialized_options=None,
      type=None,
      create_key=_descriptor._internal_create_key),
  ],
  containing_type=None,
  serialized_options=None,
  serialized_start=85,
  serialized_end=152,
)
_sym_db.RegisterEnumDescriptor(_ALERT_FORMAT)


# Alert message: a `body` string plus a display `format` enum.
_ALERT = _descriptor.Descriptor(
  name='Alert',
  full_name='Alert',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  create_key=_descriptor._internal_create_key,
  fields=[
    _descriptor.FieldDescriptor(
      name='body', full_name='Alert.body', index=0,
      number=1, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=b"".decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='format', full_name='Alert.format', index=1,
      number=2, type=14, cpp_type=8, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
    _ALERT_FORMAT,
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto3',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=31,
  serialized_end=152,
)

# Wire up cross-references and register everything with the symbol database.
_ALERT.fields_by_name['format'].enum_type = _ALERT_FORMAT
_ALERT_FORMAT.containing_type = _ALERT
DESCRIPTOR.message_types_by_name['Alert'] = _ALERT
_sym_db.RegisterFileDescriptor(DESCRIPTOR)

Alert = _reflection.GeneratedProtocolMessageType('Alert', (_message.Message,), {
  'DESCRIPTOR' : _ALERT,
  '__module__' : 'streamlit.proto.Alert_pb2'
  # @@protoc_insertion_point(class_scope:Alert)
  })
_sym_db.RegisterMessage(Alert)

# @@protoc_insertion_point(module_scope)

View File

@ -0,0 +1,87 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: streamlit/proto/ArrowNamedDataSet.proto
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from streamlit.proto import Arrow_pb2 as streamlit_dot_proto_dot_Arrow__pb2
DESCRIPTOR = _descriptor.FileDescriptor(
name='streamlit/proto/ArrowNamedDataSet.proto',
package='',
syntax='proto3',
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_pb=b'\n\'streamlit/proto/ArrowNamedDataSet.proto\x1a\x1bstreamlit/proto/Arrow.proto\"I\n\x11\x41rrowNamedDataSet\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08has_name\x18\x03 \x01(\x08\x12\x14\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x06.Arrowb\x06proto3'
,
dependencies=[streamlit_dot_proto_dot_Arrow__pb2.DESCRIPTOR,])
_ARROWNAMEDDATASET = _descriptor.Descriptor(
name='ArrowNamedDataSet',
full_name='ArrowNamedDataSet',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='name', full_name='ArrowNamedDataSet.name', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='has_name', full_name='ArrowNamedDataSet.has_name', index=1,
number=3, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='data', full_name='ArrowNamedDataSet.data', index=2,
number=2, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=72,
serialized_end=145,
)
_ARROWNAMEDDATASET.fields_by_name['data'].message_type = streamlit_dot_proto_dot_Arrow__pb2._ARROW
DESCRIPTOR.message_types_by_name['ArrowNamedDataSet'] = _ARROWNAMEDDATASET
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
ArrowNamedDataSet = _reflection.GeneratedProtocolMessageType('ArrowNamedDataSet', (_message.Message,), {
'DESCRIPTOR' : _ARROWNAMEDDATASET,
'__module__' : 'streamlit.proto.ArrowNamedDataSet_pb2'
# @@protoc_insertion_point(class_scope:ArrowNamedDataSet)
})
_sym_db.RegisterMessage(ArrowNamedDataSet)
# @@protoc_insertion_point(module_scope)

View File

@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: streamlit/proto/ArrowVegaLiteChart.proto
#
# Defines the ArrowVegaLiteChart proto3 message: a Vega-Lite chart 'spec'
# string, an Arrow-serialized 'data' payload, repeated named 'datasets',
# and a 'use_container_width' flag. The message class is constructed at
# import time from the serialized descriptor below via protobuf reflection.
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
# Process-wide registry of generated message types.
_sym_db = _symbol_database.Default()
# Generated modules for the message types referenced by the 'data' and
# 'datasets' fields.
from streamlit.proto import Arrow_pb2 as streamlit_dot_proto_dot_Arrow__pb2
from streamlit.proto import ArrowNamedDataSet_pb2 as streamlit_dot_proto_dot_ArrowNamedDataSet__pb2
# File descriptor, deserialized from the compiled .proto definition.
DESCRIPTOR = _descriptor.FileDescriptor(
name='streamlit/proto/ArrowVegaLiteChart.proto',
package='',
syntax='proto3',
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_pb=b'\n(streamlit/proto/ArrowVegaLiteChart.proto\x1a\x1bstreamlit/proto/Arrow.proto\x1a\'streamlit/proto/ArrowNamedDataSet.proto\"\x81\x01\n\x12\x41rrowVegaLiteChart\x12\x0c\n\x04spec\x18\x01 \x01(\t\x12\x14\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x06.Arrow\x12$\n\x08\x64\x61tasets\x18\x04 \x03(\x0b\x32\x12.ArrowNamedDataSet\x12\x1b\n\x13use_container_width\x18\x05 \x01(\x08J\x04\x08\x03\x10\x04\x62\x06proto3'
,
dependencies=[streamlit_dot_proto_dot_Arrow__pb2.DESCRIPTOR,streamlit_dot_proto_dot_ArrowNamedDataSet__pb2.DESCRIPTOR,])
# Descriptor for the ArrowVegaLiteChart message and its four fields.
_ARROWVEGALITECHART = _descriptor.Descriptor(
name='ArrowVegaLiteChart',
full_name='ArrowVegaLiteChart',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
# string spec = 1; the Vega-Lite chart specification (JSON string).
_descriptor.FieldDescriptor(
name='spec', full_name='ArrowVegaLiteChart.spec', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
# Arrow data = 2; message_type is patched to _ARROW after this call.
_descriptor.FieldDescriptor(
name='data', full_name='ArrowVegaLiteChart.data', index=1,
number=2, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
# repeated ArrowNamedDataSet datasets = 4; patched below as well.
_descriptor.FieldDescriptor(
name='datasets', full_name='ArrowVegaLiteChart.datasets', index=2,
number=4, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
# bool use_container_width = 5;
_descriptor.FieldDescriptor(
name='use_container_width', full_name='ArrowVegaLiteChart.use_container_width', index=3,
number=5, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=115,
serialized_end=244,
)
# Resolve cross-file message references that were left as None above.
_ARROWVEGALITECHART.fields_by_name['data'].message_type = streamlit_dot_proto_dot_Arrow__pb2._ARROW
_ARROWVEGALITECHART.fields_by_name['datasets'].message_type = streamlit_dot_proto_dot_ArrowNamedDataSet__pb2._ARROWNAMEDDATASET
# Register the descriptor, then build and register the concrete Python
# message class via the protobuf reflection metaclass.
DESCRIPTOR.message_types_by_name['ArrowVegaLiteChart'] = _ARROWVEGALITECHART
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
ArrowVegaLiteChart = _reflection.GeneratedProtocolMessageType('ArrowVegaLiteChart', (_message.Message,), {
'DESCRIPTOR' : _ARROWVEGALITECHART,
'__module__' : 'streamlit.proto.ArrowVegaLiteChart_pb2'
# @@protoc_insertion_point(class_scope:ArrowVegaLiteChart)
})
_sym_db.RegisterMessage(ArrowVegaLiteChart)
# @@protoc_insertion_point(module_scope)

View File

@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: streamlit/proto/Arrow.proto
#
# Defines two proto3 messages built at import time from the serialized
# descriptor below:
#   Arrow  - an Arrow-IPC payload ('data' bytes) plus an optional Styler.
#   Styler - pandas-Styler metadata: uuid, caption, CSS styles string, and
#            Arrow-serialized display values.
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
# Process-wide registry of generated message types.
_sym_db = _symbol_database.Default()
# File descriptor, deserialized from the compiled .proto definition.
DESCRIPTOR = _descriptor.FileDescriptor(
name='streamlit/proto/Arrow.proto',
package='',
syntax='proto3',
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_pb=b'\n\x1bstreamlit/proto/Arrow.proto\".\n\x05\x41rrow\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x17\n\x06styler\x18\x02 \x01(\x0b\x32\x07.Styler\"O\n\x06Styler\x12\x0c\n\x04uuid\x18\x01 \x01(\t\x12\x0f\n\x07\x63\x61ption\x18\x02 \x01(\t\x12\x0e\n\x06styles\x18\x03 \x01(\t\x12\x16\n\x0e\x64isplay_values\x18\x04 \x01(\x0c\x62\x06proto3'
)
# Descriptor for the Arrow message.
_ARROW = _descriptor.Descriptor(
name='Arrow',
full_name='Arrow',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
# bytes data = 1;
_descriptor.FieldDescriptor(
name='data', full_name='Arrow.data', index=0,
number=1, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
# Styler styler = 2; message_type is patched to _STYLER after both
# descriptors are constructed.
_descriptor.FieldDescriptor(
name='styler', full_name='Arrow.styler', index=1,
number=2, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=31,
serialized_end=77,
)
# Descriptor for the Styler message.
_STYLER = _descriptor.Descriptor(
name='Styler',
full_name='Styler',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
# string uuid = 1;
_descriptor.FieldDescriptor(
name='uuid', full_name='Styler.uuid', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
# string caption = 2;
_descriptor.FieldDescriptor(
name='caption', full_name='Styler.caption', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
# string styles = 3;
_descriptor.FieldDescriptor(
name='styles', full_name='Styler.styles', index=2,
number=3, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
# bytes display_values = 4;
_descriptor.FieldDescriptor(
name='display_values', full_name='Styler.display_values', index=3,
number=4, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=79,
serialized_end=158,
)
# Resolve the in-file message reference left as None above.
_ARROW.fields_by_name['styler'].message_type = _STYLER
# Register both descriptors with the file descriptor and symbol database.
DESCRIPTOR.message_types_by_name['Arrow'] = _ARROW
DESCRIPTOR.message_types_by_name['Styler'] = _STYLER
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
# Build the concrete Python message classes via protobuf reflection.
Arrow = _reflection.GeneratedProtocolMessageType('Arrow', (_message.Message,), {
'DESCRIPTOR' : _ARROW,
'__module__' : 'streamlit.proto.Arrow_pb2'
# @@protoc_insertion_point(class_scope:Arrow)
})
_sym_db.RegisterMessage(Arrow)
Styler = _reflection.GeneratedProtocolMessageType('Styler', (_message.Message,), {
'DESCRIPTOR' : _STYLER,
'__module__' : 'streamlit.proto.Arrow_pb2'
# @@protoc_insertion_point(class_scope:Styler)
})
_sym_db.RegisterMessage(Styler)
# @@protoc_insertion_point(module_scope)

View File

@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: streamlit/proto/Audio.proto
#
# Defines the Audio proto3 message: a media 'url' string and an int32
# 'start_time'. The serialized descriptor below also carries reserved
# field ranges and the reserved names 'data' and 'format' (fields removed
# from an earlier revision of the .proto).
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
# Process-wide registry of generated message types.
_sym_db = _symbol_database.Default()
# File descriptor, deserialized from the compiled .proto definition.
DESCRIPTOR = _descriptor.FileDescriptor(
name='streamlit/proto/Audio.proto',
package='',
syntax='proto3',
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_pb=b'\n\x1bstreamlit/proto/Audio.proto\"H\n\x05\x41udio\x12\x0b\n\x03url\x18\x05 \x01(\t\x12\x12\n\nstart_time\x18\x03 \x01(\x05J\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x04\x10\x05R\x04\x64\x61taR\x06\x66ormatb\x06proto3'
)
# Descriptor for the Audio message and its two fields.
_AUDIO = _descriptor.Descriptor(
name='Audio',
full_name='Audio',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
# string url = 5;
_descriptor.FieldDescriptor(
name='url', full_name='Audio.url', index=0,
number=5, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
# int32 start_time = 3;
_descriptor.FieldDescriptor(
name='start_time', full_name='Audio.start_time', index=1,
number=3, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=31,
serialized_end=103,
)
# Register the descriptor, then build and register the concrete Python
# message class via the protobuf reflection metaclass.
DESCRIPTOR.message_types_by_name['Audio'] = _AUDIO
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
Audio = _reflection.GeneratedProtocolMessageType('Audio', (_message.Message,), {
'DESCRIPTOR' : _AUDIO,
'__module__' : 'streamlit.proto.Audio_pb2'
# @@protoc_insertion_point(class_scope:Audio)
})
_sym_db.RegisterMessage(Audio)
# @@protoc_insertion_point(module_scope)

Some files were not shown because too many files have changed in this diff Show More