2022-05-23 00:16:32 +04:00

180 lines
5.4 KiB
Python

# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import sys
def add_magic(code, script_path):
    """Parse a script and rewrite it to support magic Streamlit commands.

    Parameters
    ----------
    code : str
        The Python source code of the script.
    script_path : str
        The path to the script file. It is handed to ``ast.parse`` as the
        filename so that exceptions point at the right file.

    Returns
    -------
    ast.Module
        The syntax tree for the code, after ``_modify_ast_subtree`` has
        processed it as the root tree.

    """
    parsed = ast.parse(code, filename=script_path, mode="exec")
    rewritten = _modify_ast_subtree(parsed, is_root=True)
    return rewritten
def _modify_ast_subtree(tree, body_attr="body", is_root=False):
    """Parse magic commands and modify the given AST (sub)tree in place.

    Walks the statement list stored at ``getattr(tree, body_attr)``, recurses
    into the compound statements it knows about, and replaces the value of
    each standalone ``ast.Expr`` statement with the call returned by
    ``_get_st_write_from_expr`` (when that helper returns one).

    Parameters
    ----------
    tree : ast.AST
        The node that owns the statement list (module, function, loop, ...).
    body_attr : str
        Name of the attribute on ``tree`` holding the statements to process,
        e.g. "body", "orelse" or "finalbody".
    is_root : bool
        Pass True only for the top-level module; the Streamlit import
        statement is then inserted into ``tree``.

    Returns
    -------
    ast.AST
        The same ``tree`` object that was passed in.

    """
    body = getattr(tree, body_attr)

    for i, node in enumerate(body):
        node_type = type(node)

        # Function definitions and with-blocks only carry a body.
        if node_type in (
            ast.FunctionDef,
            ast.AsyncFunctionDef,
            ast.With,
            ast.AsyncWith,
        ):
            _modify_ast_subtree(node)

        # Loops carry a body and an (often empty) else-clause. The else-clause
        # must be visited too, otherwise bare expressions in it are skipped.
        elif node_type in (ast.For, ast.AsyncFor, ast.While):
            _modify_ast_subtree(node)
            _modify_ast_subtree(node, "orelse")

        # Try statements: the try body, every except handler, the else-clause
        # and the finally-clause are all independent statement lists.
        elif node_type is ast.Try:
            _modify_ast_subtree(node)
            for handler in node.handlers:
                _modify_ast_subtree(handler)
            _modify_ast_subtree(node, "orelse")
            _modify_ast_subtree(node, "finalbody")

        # If statements: elif/else branches are both stored under "orelse".
        elif node_type is ast.If:
            _modify_ast_subtree(node)
            _modify_ast_subtree(node, "orelse")

        # Convert standalone expression nodes to st.write
        elif node_type is ast.Expr:
            value = _get_st_write_from_expr(node, i, parent_type=type(tree))
            if value is not None:
                node.value = value

    if is_root:
        # Import Streamlit so the calls inserted above can resolve
        # `__streamlit__`.
        _insert_import_statement(tree)

    # Newly created nodes have no line/column info; fill it in so the tree
    # can be compiled.
    ast.fix_missing_locations(tree)

    return tree
def _insert_import_statement(tree):
    """Insert the Streamlit import statement at the top(ish) of the tree.

    The import is placed:

    * after the first statement when that statement is an import,
    * after the second statement when the first is a string constant
      (docstring) and the second is an import,
    * at the very top otherwise.

    In the first two cases any ``from __future__ import ...`` statements that
    immediately follow are skipped as well, because Python rejects a module
    in which another statement sits between ``__future__`` imports.

    Parameters
    ----------
    tree : ast.Module
        The module whose ``body`` list is modified in place.

    """
    st_import = _build_st_import_statement()
    body = tree.body
    import_types = (ast.ImportFrom, ast.Import)

    # A leading docstring only counts when an import follows it directly.
    docstring_then_import = (
        len(body) > 1
        and type(body[0]) is ast.Expr
        and _is_docstring_node(body[0].value)
        and type(body[1]) in import_types
    )
    index = 1 if docstring_then_import else 0

    if index < len(body) and type(body[index]) in import_types:
        # Go below the leading import so we don't break
        # "from __future__ import".
        index += 1
        # Several __future__ imports may be stacked on separate lines; the
        # Streamlit import must come after the last one of that run.
        while (
            index < len(body)
            and type(body[index]) is ast.ImportFrom
            and body[index].module == "__future__"
        ):
            index += 1
    else:
        index = 0

    body.insert(index, st_import)
def _build_st_import_statement():
    """Return the AST node for `import streamlit as __streamlit__`."""
    streamlit_alias = ast.alias(name="streamlit", asname="__streamlit__")
    import_node = ast.Import(names=[streamlit_alias])
    return import_node
def _build_st_write_call(nodes):
    """Build AST node for `__streamlit__._transparent_write(*nodes)`.

    Parameters
    ----------
    nodes : list of ast.expr
        Expression nodes used as the positional arguments of the call.

    Returns
    -------
    ast.Call
        The call node. It carries no location info; the caller is expected to
        run ``ast.fix_missing_locations`` before compiling.

    """
    # NOTE: `starargs` and `kwargs` are intentionally not passed. They stopped
    # being fields of ast.Call in Python 3.5, and passing unknown keyword
    # arguments to AST node constructors is deprecated since Python 3.13.
    return ast.Call(
        func=ast.Attribute(
            attr="_transparent_write",
            value=ast.Name(id="__streamlit__", ctx=ast.Load()),
            ctx=ast.Load(),
        ),
        args=nodes,
        keywords=[],
    )
def _get_st_write_from_expr(node, i, parent_type):
    """Return the st.write call that should replace an expression statement.

    Parameters
    ----------
    node : ast.Expr
        The standalone expression statement being examined.
    i : int
        Index of ``node`` within its parent's statement list; used to
        recognise docstrings, which are always the first statement.
    parent_type : type
        The AST class of the node that owns the statement list.

    Returns
    -------
    ast.Call or None
        The `__streamlit__._transparent_write(...)` call to use as the new
        value of ``node``, or None when the statement must be left untouched
        (function calls, docstrings and yield expressions).

    """
    value = node.value
    value_type = type(value)

    # Don't change function calls
    if value_type is ast.Call:
        return None

    # Don't change docstring nodes. Async functions have docstrings too, so
    # they are treated exactly like regular functions and modules.
    if (
        i == 0
        and _is_docstring_node(value)
        and parent_type in (ast.FunctionDef, ast.AsyncFunctionDef, ast.Module)
    ):
        return None

    # Don't change yield nodes
    if value_type is ast.Yield or value_type is ast.YieldFrom:
        return None

    # If tuple, call st.write with the tuple's elements as separate arguments
    # (rather than with the whole tuple). This allows us to add a comma at
    # the end of a statement to turn it into an expression that should be
    # st-written. Ex:
    # "np.random.randn(1000, 2),"
    if value_type is ast.Tuple:
        args = value.elts
    # Everything else (strings, variables, any other expression) is written
    # as a single argument. No per-type dispatch is needed, which also avoids
    # touching ast.Str: it is deprecated and absent from newer Python
    # versions.
    else:
        args = [value]

    return _build_st_write_call(args)
def _is_docstring_node(node):
    """Return True if ``node`` is a string-literal AST node.

    Only the node itself is inspected; whether the string actually sits in a
    docstring position is for the caller to decide.
    """
    # Before Python 3.8, string literals were parsed into ast.Str nodes.
    if sys.version_info < (3, 8, 0):
        return type(node) is ast.Str

    # From Python 3.8 on, literals are ast.Constant nodes; check the payload.
    if type(node) is not ast.Constant:
        return False
    return type(node.value) is str