mirror of
https://github.com/aykhans/sarin.git
synced 2026-01-14 04:21:21 +00:00
758 lines
23 KiB
Go
758 lines
23 KiB
Go
package config
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net/url"
|
|
"os"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/charmbracelet/bubbles/viewport"
|
|
tea "github.com/charmbracelet/bubbletea"
|
|
"github.com/charmbracelet/glamour"
|
|
"github.com/charmbracelet/glamour/styles"
|
|
"github.com/charmbracelet/lipgloss"
|
|
"github.com/charmbracelet/x/term"
|
|
"go.aykhans.me/sarin/internal/types"
|
|
"go.aykhans.me/sarin/internal/version"
|
|
"go.aykhans.me/utils/common"
|
|
utilsErr "go.aykhans.me/utils/errors"
|
|
"go.yaml.in/yaml/v4"
|
|
)
|
|
|
|
var Defaults = struct {
|
|
UserAgent string
|
|
Method string
|
|
RequestTimeout time.Duration
|
|
Concurrency uint
|
|
ShowConfig bool
|
|
Quiet bool
|
|
Insecure bool
|
|
Output ConfigOutputType
|
|
DryRun bool
|
|
}{
|
|
UserAgent: "Sarin/" + version.Version,
|
|
Method: "GET",
|
|
RequestTimeout: time.Second * 10,
|
|
Concurrency: 1,
|
|
ShowConfig: false,
|
|
Quiet: false,
|
|
Insecure: false,
|
|
Output: ConfigOutputTypeTable,
|
|
DryRun: false,
|
|
}
|
|
|
|
// ValidProxySchemes lists the URL schemes Validate accepts for proxies.
var ValidProxySchemes = []string{"http", "https", "socks5", "socks5h"}

// ValidRequestURLSchemes lists the URL schemes Validate accepts for the
// target URL.
var ValidRequestURLSchemes = []string{"http", "https"}
|
|
|
|
var (
|
|
StyleYellow = lipgloss.NewStyle().Foreground(lipgloss.Color("220"))
|
|
StyleRed = lipgloss.NewStyle().Foreground(lipgloss.Color("196"))
|
|
)
|
|
|
|
// IParser is the common shape of a configuration source parser.
//
// NOTE(review): presumably implemented by the ENV, CLI and config-file parsers
// that ReadAllConfigs and parseConfigFile call Parse on — confirm.
type IParser interface {
	// Parse reads its source and returns the configuration it describes.
	Parse() (*Config, error)
}
|
|
|
|
// ConfigOutputType names an output format. Validate accepts only the values
// declared below.
type ConfigOutputType string

// Supported output formats.
//
// NOTE(review): these are variables rather than constants (callers elsewhere
// may take their address), so they are kept as variables for compatibility.
var (
	ConfigOutputTypeTable = ConfigOutputType("table")
	ConfigOutputTypeJSON  = ConfigOutputType("json")
	ConfigOutputTypeYAML  = ConfigOutputType("yaml")
	ConfigOutputTypeNone  = ConfigOutputType("none")
)
|
|
|
|
// Config is the full set of run settings gathered from ENV, config files and
// CLI. Optional scalar settings are pointers so "unset" (nil) can be told
// apart from a zero value; Merge only overrides non-nil fields and
// SetDefaults fills the nil ones.
type Config struct {
	// ShowConfig: whether the configuration should be displayed before running
	// (presumably via Print — confirm at the call site).
	ShowConfig  *bool              `yaml:"showConfig,omitempty"`
	// Files lists config files to load; ReadAllConfigs parses each one,
	// following the Files they list in turn.
	Files       []types.ConfigFile `yaml:"files,omitempty"`
	// Methods are the HTTP methods to use; SetDefaults falls back to
	// Defaults.Method.
	Methods     []string           `yaml:"methods,omitempty"`
	// URL is the request target. SetDefaults moves its query string into Params.
	URL         *url.URL           `yaml:"url,omitempty"`
	// Timeout is the per-request timeout; must be greater than 0.
	Timeout     *time.Duration     `yaml:"timeout,omitempty"`
	// Concurrency must be between 1 and 100,000,000 (see Validate).
	Concurrency *uint              `yaml:"concurrency,omitempty"`
	// Requests and Duration bound the run; at least one must be set and they
	// cannot both be zero (see Validate).
	Requests    *uint64            `yaml:"requests,omitempty"`
	Duration    *time.Duration     `yaml:"duration,omitempty"`
	Quiet       *bool              `yaml:"quiet,omitempty"`
	// Output must be one of the ConfigOutputType values.
	Output      *ConfigOutputType  `yaml:"output,omitempty"`
	Insecure    *bool              `yaml:"insecure,omitempty"`
	DryRun      *bool              `yaml:"dryRun,omitempty"`
	// Params, Headers, Cookies, Bodies, Proxies and Values accumulate across
	// sources: Merge appends rather than replaces them.
	Params      types.Params       `yaml:"params,omitempty"`
	Headers     types.Headers      `yaml:"headers,omitempty"`
	Cookies     types.Cookies      `yaml:"cookies,omitempty"`
	Bodies      []string           `yaml:"bodies,omitempty"`
	// Proxies must use one of ValidProxySchemes.
	Proxies     types.Proxies      `yaml:"proxies,omitempty"`
	// Values: NOTE(review) presumably template values consumed by
	// ValidateTemplates / request generation — confirm.
	Values      []string           `yaml:"values,omitempty"`
}
|
|
|
|
func NewConfig() *Config {
|
|
return &Config{}
|
|
}
|
|
|
|
func (config Config) MarshalYAML() (any, error) {
|
|
const randomValueComment = "Cycles through all values, with a new random start each round"
|
|
|
|
toNode := func(v any) *yaml.Node {
|
|
node := &yaml.Node{}
|
|
_ = node.Encode(v)
|
|
return node
|
|
}
|
|
|
|
addField := func(content *[]*yaml.Node, key string, value *yaml.Node, comment string) {
|
|
if value.Kind == 0 || (value.Kind == yaml.ScalarNode && value.Value == "") ||
|
|
(value.Kind == yaml.SequenceNode && len(value.Content) == 0) {
|
|
return
|
|
}
|
|
keyNode := &yaml.Node{Kind: yaml.ScalarNode, Value: key, LineComment: comment}
|
|
*content = append(*content, keyNode, value)
|
|
}
|
|
|
|
addStringSlice := func(content *[]*yaml.Node, key string, items []string, withComment bool) {
|
|
comment := ""
|
|
if withComment && len(items) > 1 {
|
|
comment = randomValueComment
|
|
}
|
|
switch len(items) {
|
|
case 1:
|
|
addField(content, key, toNode(items[0]), "")
|
|
default:
|
|
addField(content, key, toNode(items), comment)
|
|
}
|
|
}
|
|
|
|
marshalKeyValues := func(items []types.KeyValue[string, []string]) *yaml.Node {
|
|
seqNode := &yaml.Node{Kind: yaml.SequenceNode}
|
|
for _, item := range items {
|
|
keyNode := &yaml.Node{Kind: yaml.ScalarNode, Value: item.Key}
|
|
var valueNode *yaml.Node
|
|
|
|
switch len(item.Value) {
|
|
case 1:
|
|
valueNode = &yaml.Node{Kind: yaml.ScalarNode, Value: item.Value[0]}
|
|
default:
|
|
valueNode = &yaml.Node{Kind: yaml.SequenceNode}
|
|
for _, v := range item.Value {
|
|
valueNode.Content = append(valueNode.Content, &yaml.Node{Kind: yaml.ScalarNode, Value: v})
|
|
}
|
|
if len(item.Value) > 1 {
|
|
keyNode.LineComment = randomValueComment
|
|
}
|
|
}
|
|
|
|
mapNode := &yaml.Node{Kind: yaml.MappingNode, Content: []*yaml.Node{keyNode, valueNode}}
|
|
seqNode.Content = append(seqNode.Content, mapNode)
|
|
}
|
|
return seqNode
|
|
}
|
|
|
|
root := &yaml.Node{Kind: yaml.MappingNode}
|
|
content := &root.Content
|
|
|
|
if config.ShowConfig != nil {
|
|
addField(content, "showConfig", toNode(*config.ShowConfig), "")
|
|
}
|
|
|
|
addStringSlice(content, "method", config.Methods, true)
|
|
|
|
if config.URL != nil {
|
|
addField(content, "url", toNode(config.URL.String()), "")
|
|
}
|
|
if config.Timeout != nil {
|
|
addField(content, "timeout", toNode(*config.Timeout), "")
|
|
}
|
|
if config.Concurrency != nil {
|
|
addField(content, "concurrency", toNode(*config.Concurrency), "")
|
|
}
|
|
if config.Requests != nil {
|
|
addField(content, "requests", toNode(*config.Requests), "")
|
|
}
|
|
if config.Duration != nil {
|
|
addField(content, "duration", toNode(*config.Duration), "")
|
|
}
|
|
if config.Quiet != nil {
|
|
addField(content, "quiet", toNode(*config.Quiet), "")
|
|
}
|
|
if config.Output != nil {
|
|
addField(content, "output", toNode(string(*config.Output)), "")
|
|
}
|
|
if config.Insecure != nil {
|
|
addField(content, "insecure", toNode(*config.Insecure), "")
|
|
}
|
|
if config.DryRun != nil {
|
|
addField(content, "dryRun", toNode(*config.DryRun), "")
|
|
}
|
|
|
|
if len(config.Params) > 0 {
|
|
items := make([]types.KeyValue[string, []string], len(config.Params))
|
|
for i, p := range config.Params {
|
|
items[i] = types.KeyValue[string, []string](p)
|
|
}
|
|
addField(content, "params", marshalKeyValues(items), "")
|
|
}
|
|
if len(config.Headers) > 0 {
|
|
items := make([]types.KeyValue[string, []string], len(config.Headers))
|
|
for i, h := range config.Headers {
|
|
items[i] = types.KeyValue[string, []string](h)
|
|
}
|
|
addField(content, "headers", marshalKeyValues(items), "")
|
|
}
|
|
if len(config.Cookies) > 0 {
|
|
items := make([]types.KeyValue[string, []string], len(config.Cookies))
|
|
for i, c := range config.Cookies {
|
|
items[i] = types.KeyValue[string, []string](c)
|
|
}
|
|
addField(content, "cookies", marshalKeyValues(items), "")
|
|
}
|
|
|
|
addStringSlice(content, "body", config.Bodies, true)
|
|
|
|
if len(config.Proxies) > 0 {
|
|
proxyStrings := make([]string, len(config.Proxies))
|
|
for i, p := range config.Proxies {
|
|
proxyStrings[i] = p.String()
|
|
}
|
|
addStringSlice(content, "proxy", proxyStrings, true)
|
|
}
|
|
|
|
addStringSlice(content, "values", config.Values, false)
|
|
|
|
return root, nil
|
|
}
|
|
|
|
func (config Config) Print() bool {
|
|
configYAML, err := yaml.Marshal(config)
|
|
if err != nil {
|
|
fmt.Fprintln(os.Stderr, StyleRed.Render("Error marshaling config to yaml: "+err.Error()))
|
|
os.Exit(1)
|
|
}
|
|
|
|
// Pipe mode: output raw content directly
|
|
if !term.IsTerminal(os.Stdout.Fd()) {
|
|
fmt.Println(string(configYAML))
|
|
os.Exit(0)
|
|
}
|
|
|
|
style := styles.TokyoNightStyleConfig
|
|
style.Document.Margin = common.ToPtr[uint](0)
|
|
style.CodeBlock.Margin = common.ToPtr[uint](0)
|
|
|
|
renderer, err := glamour.NewTermRenderer(
|
|
glamour.WithStyles(style),
|
|
glamour.WithWordWrap(0),
|
|
)
|
|
if err != nil {
|
|
fmt.Fprintln(os.Stderr, StyleRed.Render(err.Error()))
|
|
os.Exit(1)
|
|
}
|
|
|
|
content, err := renderer.Render("```yaml\n" + string(configYAML) + "```")
|
|
if err != nil {
|
|
fmt.Fprintln(os.Stderr, StyleRed.Render(err.Error()))
|
|
os.Exit(1)
|
|
}
|
|
|
|
p := tea.NewProgram(
|
|
printConfigModel{content: strings.Trim(content, "\n"), rawContent: configYAML},
|
|
tea.WithAltScreen(),
|
|
tea.WithMouseCellMotion(),
|
|
)
|
|
|
|
m, err := p.Run()
|
|
if err != nil {
|
|
fmt.Fprintln(os.Stderr, StyleRed.Render(err.Error()))
|
|
os.Exit(1)
|
|
}
|
|
|
|
return m.(printConfigModel).start //nolint:forcetypeassert // m is guaranteed to be of type printConfigModel as it was the only model passed to tea.NewProgram
|
|
}
|
|
|
|
func (config *Config) Merge(newConfig *Config) {
|
|
config.Files = append(config.Files, newConfig.Files...)
|
|
if len(newConfig.Methods) > 0 {
|
|
config.Methods = append(config.Methods, newConfig.Methods...)
|
|
}
|
|
if newConfig.URL != nil {
|
|
config.URL = newConfig.URL
|
|
}
|
|
if newConfig.Timeout != nil {
|
|
config.Timeout = newConfig.Timeout
|
|
}
|
|
if newConfig.Concurrency != nil {
|
|
config.Concurrency = newConfig.Concurrency
|
|
}
|
|
if newConfig.Requests != nil {
|
|
config.Requests = newConfig.Requests
|
|
}
|
|
if newConfig.Duration != nil {
|
|
config.Duration = newConfig.Duration
|
|
}
|
|
if newConfig.ShowConfig != nil {
|
|
config.ShowConfig = newConfig.ShowConfig
|
|
}
|
|
if newConfig.Quiet != nil {
|
|
config.Quiet = newConfig.Quiet
|
|
}
|
|
if newConfig.Output != nil {
|
|
config.Output = newConfig.Output
|
|
}
|
|
if newConfig.Insecure != nil {
|
|
config.Insecure = newConfig.Insecure
|
|
}
|
|
if newConfig.DryRun != nil {
|
|
config.DryRun = newConfig.DryRun
|
|
}
|
|
if len(newConfig.Params) != 0 {
|
|
config.Params = append(config.Params, newConfig.Params...)
|
|
}
|
|
if len(newConfig.Headers) != 0 {
|
|
config.Headers = append(config.Headers, newConfig.Headers...)
|
|
}
|
|
if len(newConfig.Cookies) != 0 {
|
|
config.Cookies = append(config.Cookies, newConfig.Cookies...)
|
|
}
|
|
if len(newConfig.Bodies) != 0 {
|
|
config.Bodies = append(config.Bodies, newConfig.Bodies...)
|
|
}
|
|
if len(newConfig.Proxies) != 0 {
|
|
config.Proxies.Append(newConfig.Proxies...)
|
|
}
|
|
if len(newConfig.Values) != 0 {
|
|
config.Values = append(config.Values, newConfig.Values...)
|
|
}
|
|
}
|
|
|
|
func (config *Config) SetDefaults() {
|
|
if config.URL != nil && len(config.URL.Query()) > 0 {
|
|
urlParams := types.Params{}
|
|
for key, values := range config.URL.Query() {
|
|
for _, value := range values {
|
|
urlParams = append(urlParams, types.Param{
|
|
Key: key,
|
|
Value: []string{value},
|
|
})
|
|
}
|
|
}
|
|
|
|
config.Params = append(urlParams, config.Params...)
|
|
config.URL.RawQuery = ""
|
|
}
|
|
|
|
if len(config.Methods) == 0 {
|
|
config.Methods = []string{Defaults.Method}
|
|
}
|
|
if config.Timeout == nil {
|
|
config.Timeout = &Defaults.RequestTimeout
|
|
}
|
|
if config.Concurrency == nil {
|
|
config.Concurrency = common.ToPtr(Defaults.Concurrency)
|
|
}
|
|
if config.ShowConfig == nil {
|
|
config.ShowConfig = common.ToPtr(Defaults.ShowConfig)
|
|
}
|
|
if config.Quiet == nil {
|
|
config.Quiet = common.ToPtr(Defaults.Quiet)
|
|
}
|
|
if config.Insecure == nil {
|
|
config.Insecure = common.ToPtr(Defaults.Insecure)
|
|
}
|
|
if config.DryRun == nil {
|
|
config.DryRun = common.ToPtr(Defaults.DryRun)
|
|
}
|
|
if !config.Headers.Has("User-Agent") {
|
|
config.Headers = append(config.Headers, types.Header{Key: "User-Agent", Value: []string{Defaults.UserAgent}})
|
|
}
|
|
|
|
if config.Output == nil {
|
|
config.Output = common.ToPtr(Defaults.Output)
|
|
}
|
|
}
|
|
|
|
// Validate validates the config fields.
|
|
// It can return the following errors:
|
|
// - types.FieldValidationErrors
|
|
func (config Config) Validate() error {
|
|
validationErrors := make([]types.FieldValidationError, 0)
|
|
|
|
if len(config.Methods) == 0 {
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("Method", "", errors.New("method is required")))
|
|
}
|
|
|
|
switch {
|
|
case config.URL == nil:
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("URL", "", errors.New("URL is required")))
|
|
case !slices.Contains(ValidRequestURLSchemes, config.URL.Scheme):
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("URL", config.URL.String(), fmt.Errorf("URL scheme must be one of: %s", strings.Join(ValidRequestURLSchemes, ", "))))
|
|
case config.URL.Host == "":
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("URL", config.URL.String(), errors.New("URL must have a host")))
|
|
}
|
|
|
|
switch {
|
|
case config.Concurrency == nil:
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("Concurrency", "", errors.New("concurrency count is required")))
|
|
case *config.Concurrency == 0:
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("Concurrency", "0", errors.New("concurrency must be greater than 0")))
|
|
case *config.Concurrency > 100_000_000:
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("Concurrency", strconv.FormatUint(uint64(*config.Concurrency), 10), errors.New("concurrency must not exceed 100,000,000")))
|
|
}
|
|
|
|
switch {
|
|
case config.Requests == nil && config.Duration == nil:
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("Requests / Duration", "", errors.New("either request count or duration must be specified")))
|
|
case (config.Requests != nil && config.Duration != nil) && (*config.Requests == 0 && *config.Duration == 0):
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("Requests / Duration", "0", errors.New("both request count and duration cannot be zero")))
|
|
case config.Requests != nil && config.Duration == nil && *config.Requests == 0:
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("Requests", "0", errors.New("request count must be greater than 0")))
|
|
case config.Requests == nil && config.Duration != nil && *config.Duration == 0:
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("Duration", "0", errors.New("duration must be greater than 0")))
|
|
}
|
|
|
|
if *config.Timeout < 1 {
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("Timeout", "0", errors.New("timeout must be greater than 0")))
|
|
}
|
|
|
|
if config.ShowConfig == nil {
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("ShowConfig", "", errors.New("showConfig field is required")))
|
|
}
|
|
|
|
if config.Quiet == nil {
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("Quiet", "", errors.New("quiet field is required")))
|
|
}
|
|
|
|
if config.Output == nil {
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("Output", "", errors.New("output field is required")))
|
|
} else {
|
|
switch *config.Output {
|
|
case "":
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("Output", "", errors.New("output field is required")))
|
|
case ConfigOutputTypeTable, ConfigOutputTypeJSON, ConfigOutputTypeYAML, ConfigOutputTypeNone:
|
|
default:
|
|
validOutputs := []string{string(ConfigOutputTypeTable), string(ConfigOutputTypeJSON), string(ConfigOutputTypeYAML), string(ConfigOutputTypeNone)}
|
|
validationErrors = append(validationErrors,
|
|
types.NewFieldValidationError(
|
|
"Output",
|
|
string(*config.Output),
|
|
fmt.Errorf(
|
|
"output type must be one of: %s",
|
|
strings.Join(validOutputs, ", "),
|
|
),
|
|
),
|
|
)
|
|
}
|
|
}
|
|
|
|
if config.Insecure == nil {
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("Insecure", "", errors.New("insecure field is required")))
|
|
}
|
|
|
|
if config.DryRun == nil {
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("DryRun", "", errors.New("dryRun field is required")))
|
|
}
|
|
|
|
for i, proxy := range config.Proxies {
|
|
if !slices.Contains(ValidProxySchemes, proxy.Scheme) {
|
|
validationErrors = append(
|
|
validationErrors,
|
|
types.NewFieldValidationError(
|
|
fmt.Sprintf("Proxy[%d]", i),
|
|
proxy.String(),
|
|
fmt.Errorf("proxy scheme must be one of: %v", ValidProxySchemes),
|
|
),
|
|
)
|
|
}
|
|
}
|
|
|
|
templateErrors := ValidateTemplates(&config)
|
|
validationErrors = append(validationErrors, templateErrors...)
|
|
|
|
if len(validationErrors) > 0 {
|
|
return types.NewFieldValidationErrors(validationErrors)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func ReadAllConfigs() *Config {
|
|
envParser := NewConfigENVParser("SARIN")
|
|
envConfig, err := envParser.Parse()
|
|
_ = utilsErr.MustHandle(err,
|
|
utilsErr.OnType(func(err types.FieldParseErrors) error {
|
|
printParseErrors("ENV", err.Errors...)
|
|
fmt.Println()
|
|
os.Exit(1)
|
|
return nil
|
|
}),
|
|
)
|
|
|
|
cliParser := NewConfigCLIParser(os.Args)
|
|
cliConf, err := cliParser.Parse()
|
|
_ = utilsErr.MustHandle(err,
|
|
utilsErr.OnSentinel(types.ErrCLINoArgs, func(err error) error {
|
|
cliParser.PrintHelp()
|
|
fmt.Fprintln(os.Stderr, StyleYellow.Render("\nNo arguments provided."))
|
|
os.Exit(1)
|
|
return nil
|
|
}),
|
|
utilsErr.OnType(func(err types.CLIUnexpectedArgsError) error {
|
|
cliParser.PrintHelp()
|
|
fmt.Fprintln(os.Stderr,
|
|
StyleYellow.Render(
|
|
"\nUnexpected CLI arguments provided: ",
|
|
)+strings.Join(err.Args, ", "),
|
|
)
|
|
os.Exit(1)
|
|
return nil
|
|
}),
|
|
utilsErr.OnType(func(err types.FieldParseErrors) error {
|
|
cliParser.PrintHelp()
|
|
fmt.Println()
|
|
printParseErrors("CLI", err.Errors...)
|
|
os.Exit(1)
|
|
return nil
|
|
}),
|
|
)
|
|
|
|
for _, configFile := range append(envConfig.Files, cliConf.Files...) {
|
|
fileConfig, err := parseConfigFile(configFile, 10)
|
|
_ = utilsErr.MustHandle(err,
|
|
utilsErr.OnType(func(err types.ConfigFileReadError) error {
|
|
cliParser.PrintHelp()
|
|
fmt.Fprintln(os.Stderr,
|
|
StyleYellow.Render(
|
|
fmt.Sprintf("\nFailed to read config file (%s): ", configFile.Path())+err.Error(),
|
|
),
|
|
)
|
|
os.Exit(1)
|
|
return nil
|
|
}),
|
|
utilsErr.OnType(func(err types.UnmarshalError) error {
|
|
fmt.Fprintln(os.Stderr,
|
|
StyleYellow.Render(
|
|
fmt.Sprintf("\nFailed to parse config file (%s): ", configFile.Path())+err.Error(),
|
|
),
|
|
)
|
|
os.Exit(1)
|
|
return nil
|
|
}),
|
|
utilsErr.OnType(func(err types.FieldParseErrors) error {
|
|
printParseErrors(fmt.Sprintf("CONFIG FILE '%s'", configFile.Path()), err.Errors...)
|
|
os.Exit(1)
|
|
return nil
|
|
}),
|
|
)
|
|
|
|
envConfig.Merge(fileConfig)
|
|
}
|
|
|
|
envConfig.Merge(cliConf)
|
|
|
|
return envConfig
|
|
}
|
|
|
|
// parseConfigFile recursively parses a config file and its nested files up to maxDepth levels.
|
|
// Returns the merged configuration or an error if parsing fails.
|
|
// It can return the following errors:
|
|
// - types.ConfigFileReadError
|
|
// - types.UnmarshalError
|
|
// - types.FieldParseErrors
|
|
func parseConfigFile(configFile types.ConfigFile, maxDepth int) (*Config, error) {
|
|
configFileParser := NewConfigFileParser(configFile)
|
|
fileConfig, err := configFileParser.Parse()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if maxDepth <= 0 {
|
|
return fileConfig, nil
|
|
}
|
|
|
|
for _, c := range fileConfig.Files {
|
|
innerFileConfig, err := parseConfigFile(c, maxDepth-1)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
innerFileConfig.Merge(fileConfig)
|
|
fileConfig = innerFileConfig
|
|
}
|
|
|
|
return fileConfig, nil
|
|
}
|
|
|
|
func printParseErrors(parserName string, errors ...types.FieldParseError) {
|
|
for _, fieldErr := range errors {
|
|
if fieldErr.Value == "" {
|
|
fmt.Fprintln(os.Stderr,
|
|
StyleYellow.Render(fmt.Sprintf("[%s] Field '%s': ", parserName, fieldErr.Field))+fieldErr.Err.Error(),
|
|
)
|
|
} else {
|
|
fmt.Fprintln(os.Stderr,
|
|
StyleYellow.Render(fmt.Sprintf("[%s] Field '%s' (%s): ", parserName, fieldErr.Field, fieldErr.Value))+fieldErr.Err.Error(),
|
|
)
|
|
}
|
|
}
|
|
}
|
|
|
|
const (
|
|
scrollbarWidth = 1
|
|
scrollbarBottomSpace = 1
|
|
statusDisplayTime = 3 * time.Second
|
|
)
|
|
|
|
var (
|
|
printConfigBorderStyle = func() lipgloss.Border {
|
|
b := lipgloss.RoundedBorder()
|
|
return b
|
|
}()
|
|
|
|
printConfigHelpStyle = lipgloss.NewStyle().BorderStyle(printConfigBorderStyle).Padding(0, 1)
|
|
printConfigSuccessStatusStyle = lipgloss.NewStyle().BorderStyle(printConfigBorderStyle).Padding(0, 1).Foreground(lipgloss.Color("10"))
|
|
printConfigErrorStatusStyle = lipgloss.NewStyle().BorderStyle(printConfigBorderStyle).Padding(0, 1).Foreground(lipgloss.Color("9"))
|
|
printConfigKeyStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("12")).Bold(true)
|
|
printConfigDescStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("252"))
|
|
)
|
|
|
|
// printConfigClearStatusMsg arrives statusDisplayTime after a save attempt
// (scheduled by saveContent) and makes Update clear the header status message.
type printConfigClearStatusMsg struct{}

// printConfigModel is the Bubble Tea model behind Print's full-screen viewer.
type printConfigModel struct {
	// viewport scrolls the line-numbered rendered config.
	viewport viewport.Model
	// content is the glamour-rendered YAML shown in the viewport.
	content string
	// rawContent is the plain YAML that CTRL+S writes to disk.
	rawContent []byte
	// statusMsg, when non-empty, replaces the key help in the header.
	statusMsg string
	// ready becomes true once the first window size arrived and the viewport
	// was created.
	ready bool
	// start records that the user pressed ENTER; Print returns it.
	start bool
}
|
|
|
|
// Init implements tea.Model. There is no startup command: the viewport is
// built when the first tea.WindowSizeMsg reaches Update.
func (m printConfigModel) Init() tea.Cmd {
	return nil
}
|
|
|
|
func (m printConfigModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|
var cmd tea.Cmd
|
|
|
|
switch msg := msg.(type) {
|
|
case tea.KeyMsg:
|
|
switch msg.String() {
|
|
case "ctrl+c", "esc":
|
|
return m, tea.Quit
|
|
case "ctrl+s":
|
|
return m.saveContent()
|
|
case "enter":
|
|
m.start = true
|
|
return m, tea.Quit
|
|
}
|
|
|
|
case printConfigClearStatusMsg:
|
|
m.statusMsg = ""
|
|
return m, nil
|
|
|
|
case tea.WindowSizeMsg:
|
|
m.handleResize(msg)
|
|
}
|
|
|
|
m.viewport, cmd = m.viewport.Update(msg)
|
|
return m, cmd
|
|
}
|
|
|
|
func (m printConfigModel) View() string {
|
|
if !m.ready {
|
|
return "\n Initializing..."
|
|
}
|
|
|
|
content := lipgloss.JoinHorizontal(lipgloss.Top, m.viewport.View(), m.scrollbar())
|
|
return fmt.Sprintf("%s\n%s\n%s", m.headerView(), content, m.footerView())
|
|
}
|
|
|
|
func (m *printConfigModel) saveContent() (printConfigModel, tea.Cmd) {
|
|
filename := fmt.Sprintf("sarin_config_%s.yaml", time.Now().Format("2006-01-02_15-04-05"))
|
|
if err := os.WriteFile(filename, m.rawContent, 0600); err != nil {
|
|
m.statusMsg = printConfigErrorStatusStyle.Render("✗ Error saving file: " + err.Error())
|
|
} else {
|
|
m.statusMsg = printConfigSuccessStatusStyle.Render("✓ Saved to " + filename)
|
|
}
|
|
return *m, tea.Tick(statusDisplayTime, func(time.Time) tea.Msg { return printConfigClearStatusMsg{} })
|
|
}
|
|
|
|
func (m *printConfigModel) handleResize(msg tea.WindowSizeMsg) {
|
|
headerHeight := lipgloss.Height(m.headerView())
|
|
footerHeight := lipgloss.Height(m.footerView())
|
|
height := msg.Height - headerHeight - footerHeight
|
|
width := msg.Width - scrollbarWidth
|
|
|
|
if !m.ready {
|
|
m.viewport = viewport.New(width, height)
|
|
m.viewport.SetContent(m.contentWithLineNumbers())
|
|
m.ready = true
|
|
} else {
|
|
m.viewport.Width = width
|
|
m.viewport.Height = height
|
|
}
|
|
}
|
|
|
|
func (m printConfigModel) headerView() string {
|
|
var title string
|
|
if m.statusMsg != "" {
|
|
title = ("" + m.statusMsg)
|
|
} else {
|
|
sep := printConfigDescStyle.Render(" / ")
|
|
help := printConfigKeyStyle.Render("ENTER") + printConfigDescStyle.Render(" start") + sep +
|
|
printConfigKeyStyle.Render("CTRL+S") + printConfigDescStyle.Render(" save") + sep +
|
|
printConfigKeyStyle.Render("ESC") + printConfigDescStyle.Render(" exit")
|
|
title = printConfigHelpStyle.Render(help)
|
|
}
|
|
line := strings.Repeat("─", max(0, m.viewport.Width+scrollbarWidth-lipgloss.Width(title)))
|
|
return lipgloss.JoinHorizontal(lipgloss.Center, title, line)
|
|
}
|
|
|
|
func (m printConfigModel) footerView() string {
|
|
return strings.Repeat("─", m.viewport.Width+scrollbarWidth)
|
|
}
|
|
|
|
func (m printConfigModel) contentWithLineNumbers() string {
|
|
lines := strings.Split(m.content, "\n")
|
|
width := len(strconv.Itoa(len(lines)))
|
|
lineNumStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("246"))
|
|
|
|
var sb strings.Builder
|
|
for i, line := range lines {
|
|
lineNum := lineNumStyle.Render(fmt.Sprintf("%*d", width, i+1))
|
|
sb.WriteString(lineNum)
|
|
sb.WriteString(" ")
|
|
sb.WriteString(line)
|
|
if i < len(lines)-1 {
|
|
sb.WriteByte('\n')
|
|
}
|
|
}
|
|
return sb.String()
|
|
}
|
|
|
|
func (m printConfigModel) scrollbar() string {
|
|
height := m.viewport.Height
|
|
trackHeight := height - scrollbarBottomSpace
|
|
totalLines := m.viewport.TotalLineCount()
|
|
|
|
if totalLines <= height {
|
|
return strings.Repeat(" \n", trackHeight) + " "
|
|
}
|
|
|
|
thumbSize := max(1, (height*trackHeight)/totalLines)
|
|
thumbPos := int(m.viewport.ScrollPercent() * float64(trackHeight-thumbSize))
|
|
|
|
var sb strings.Builder
|
|
for i := range trackHeight {
|
|
if i >= thumbPos && i < thumbPos+thumbSize {
|
|
sb.WriteByte('\xe2') // █ (U+2588)
|
|
sb.WriteByte('\x96')
|
|
sb.WriteByte('\x88')
|
|
} else {
|
|
sb.WriteByte('\xe2') // ░ (U+2591)
|
|
sb.WriteByte('\x96')
|
|
sb.WriteByte('\x91')
|
|
}
|
|
sb.WriteByte('\n')
|
|
}
|
|
sb.WriteByte(' ')
|
|
return sb.String()
|
|
}
|