diff --git a/cmd/cli/main.go b/cmd/cli/main.go index b5703f7..18687a2 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -53,7 +53,12 @@ func main() { combinedConfig.Cookies, combinedConfig.Bodies, combinedConfig.Proxies, combinedConfig.Values, *combinedConfig.Output != config.ConfigOutputTypeNone, *combinedConfig.DryRun, + combinedConfig.Lua, combinedConfig.Js, ) + if err != nil { + fmt.Fprintln(os.Stderr, config.StyleRed.Render("[ERROR] ")+err.Error()) + os.Exit(1) + } _ = utilsErr.MustHandle(err, utilsErr.OnType(func(err types.ProxyDialError) error { fmt.Fprintln(os.Stderr, config.StyleRed.Render("[PROXY] ")+err.Error()) diff --git a/go.mod b/go.mod index d6a5c29..8f0299e 100644 --- a/go.mod +++ b/go.mod @@ -9,8 +9,10 @@ require ( github.com/charmbracelet/glamour v0.10.0 github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 github.com/charmbracelet/x/term v0.2.2 + github.com/dop251/goja v0.0.0-20260106131823-651366fbe6e3 github.com/joho/godotenv v1.5.1 github.com/valyala/fasthttp v1.69.0 + github.com/yuin/gopher-lua v1.1.1 go.aykhans.me/utils v1.0.7 go.yaml.in/yaml/v4 v4.0.0-rc.3 golang.org/x/net v0.49.0 @@ -32,6 +34,8 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dlclark/regexp2 v1.11.5 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect + github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect + github.com/google/pprof v0.0.0-20230207041349-798e818bf904 // indirect github.com/gorilla/css v1.0.1 // indirect github.com/klauspost/compress v1.18.2 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect diff --git a/go.sum b/go.sum index 11d4273..2b0e427 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0= +github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0= github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= github.com/alecthomas/chroma/v2 v2.21.1 h1:FaSDrp6N+3pphkNKU6HPCiYLgm8dbe5UXIXcoBhZSWA= @@ -46,8 +48,14 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ= github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dop251/goja v0.0.0-20260106131823-651366fbe6e3 h1:bVp3yUzvSAJzu9GqID+Z96P+eu5TKnIMJSV4QaZMauM= +github.com/dop251/goja v0.0.0-20260106131823-651366fbe6e3/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= +github.com/google/pprof v0.0.0-20230207041349-798e818bf904 h1:4/hN5RUoecvl+RmJRE2YxKWtnnQls6rQjjW5oV7qg2U= +github.com/google/pprof v0.0.0-20230207041349-798e818bf904/go.mod h1:uglQLonpP8qtYCYyzA+8c/9qtqgA3qsXGYqCPKARAFg= github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8= github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= @@ -95,6 +103,8 @@ github.com/yuin/goldmark v1.7.16 h1:n+CJdUxaFMiDUNnWC3dMWCIQJSkxH4uz3ZwQBkAlVNE= github.com/yuin/goldmark v1.7.16/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= github.com/yuin/goldmark-emoji v1.0.6 h1:QWfF2FYaXwL74tfGOW5izeiZepUDroDJfWubQI9HTHs= github.com/yuin/goldmark-emoji v1.0.6/go.mod h1:ukxJDKFpdFb5x0a5HqbdlcKtebh086iJpI31LTKmWuA= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= go.aykhans.me/utils v1.0.7 h1:ClHXHlWmkjfFlD7+w5BQY29lKCEztxY/yCf543x4hZw= go.aykhans.me/utils v1.0.7/go.mod h1:0Jz8GlZLN35cCHLOLx39sazWwEe33bF6SYlSeqzEXoI= go.yaml.in/yaml/v4 v4.0.0-rc.3 h1:3h1fjsh1CTAPjW7q/EMe+C8shx5d8ctzZTrLcs/j8Go= @@ -111,5 +121,7 @@ golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/config/cli.go b/internal/config/cli.go index 5e0ed60..5ae615f 100644 --- a/internal/config/cli.go +++ b/internal/config/cli.go @@ -17,20 +17,7 @@ const cliUsageText = `Usage: sarin [flags] Simple usage: - sarin -U https://example.com -d 1m - -Usage with all flags: - sarin -s -q -z -o json -f ./config.yaml -c 50 -r 100_000 -d 2m30s \ - -U https://example.com \ - -M POST \ - -V "sharedUUID={{ fakeit_UUID }}" \ - -B '{"product": "car"}' \ - -P "id={{ .Values.sharedUUID }}" \ - -H "User-Agent: {{ fakeit_UserAgent }}" -H "Accept: */*" \ - -C "token={{ .Values.sharedUUID }}" \ - -X "http://proxy.example.com" \ - -T 3s \ - -I + sarin -U https://example.com -r 1 Flags: General Config: @@ -55,7 +42,9 @@ Flags: -X, -proxy []string Proxy for the request (e.g. "http://proxy.example.com:8080") -V, -values []string List of values for templating (e.g. "key1=value1") -T, -timeout time Timeout for the request (e.g. 400ms, 3s, 1m10s) (default %v) - -I, -insecure bool Skip SSL/TLS certificate verification (default %v)` + -I, -insecure bool Skip SSL/TLS certificate verification (default %v) + -lua []string Lua script for request transformation (inline or @file/@url) + -js []string JavaScript script for request transformation (inline or @file/@url)` var _ IParser = ConfigCLIParser{} @@ -106,16 +95,18 @@ func (parser ConfigCLIParser) Parse() (*Config, error) { dryRun bool // Request config - urlInput string - methods = stringSliceArg{} - bodies = stringSliceArg{} - params = stringSliceArg{} - headers = stringSliceArg{} - cookies = stringSliceArg{} - proxies = stringSliceArg{} - values = stringSliceArg{} - timeout time.Duration - insecure bool + urlInput string + methods = stringSliceArg{} + bodies = stringSliceArg{} + params = stringSliceArg{} + headers = stringSliceArg{} + cookies = stringSliceArg{} + proxies = stringSliceArg{} + values = stringSliceArg{} + timeout time.Duration + insecure bool + luaScripts = stringSliceArg{} + jsScripts = stringSliceArg{} ) { @@ -177,6 +168,10 @@ func (parser ConfigCLIParser) Parse() (*Config, error) { flagSet.BoolVar(&insecure, "insecure", false, "Skip SSL/TLS certificate verification") flagSet.BoolVar(&insecure, "I", false, "Skip SSL/TLS certificate verification") + + flagSet.Var(&luaScripts, "lua", "Lua script for request transformation (inline or @file/@url)") + + flagSet.Var(&jsScripts, "js", "JavaScript script for request transformation (inline or @file/@url)") } // Parse the specific arguments provided to the parser, skipping the program name. @@ -259,6 +254,10 @@ func (parser ConfigCLIParser) Parse() (*Config, error) { config.Timeout = common.ToPtr(timeout) case "insecure", "I": config.Insecure = common.ToPtr(insecure) + case "lua": + config.Lua = append(config.Lua, luaScripts...) + case "js": + config.Js = append(config.Js, jsScripts...) } }) diff --git a/internal/config/config.go b/internal/config/config.go index cbaa374..4609c78 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,6 +1,7 @@ package config import ( + "context" "errors" "fmt" "net/url" @@ -16,6 +17,7 @@ import ( "github.com/charmbracelet/glamour/styles" "github.com/charmbracelet/lipgloss" "github.com/charmbracelet/x/term" + "go.aykhans.me/sarin/internal/script" "go.aykhans.me/sarin/internal/types" "go.aykhans.me/sarin/internal/version" "go.aykhans.me/utils/common" @@ -87,6 +89,8 @@ type Config struct { Bodies []string `yaml:"bodies,omitempty"` Proxies types.Proxies `yaml:"proxies,omitempty"` Values []string `yaml:"values,omitempty"` + Lua []string `yaml:"lua,omitempty"` + Js []string `yaml:"js,omitempty"` } func NewConfig() *Config { @@ -219,6 +223,8 @@ func (config Config) MarshalYAML() (any, error) { } addStringSlice(content, "values", config.Values, false) + addStringSlice(content, "lua", config.Lua, false) + addStringSlice(content, "js", config.Js, false) return root, nil } @@ -323,6 +329,12 @@ func (config *Config) Merge(newConfig *Config) { if len(newConfig.Values) != 0 { config.Values = append(config.Values, newConfig.Values...) } + if len(newConfig.Lua) != 0 { + config.Lua = append(config.Lua, newConfig.Lua...) + } + if len(newConfig.Js) != 0 { + config.Js = append(config.Js, newConfig.Js...) + } } func (config *Config) SetDefaults() { @@ -465,6 +477,44 @@ func (config Config) Validate() error { } } + // Create a context with timeout for script validation (loading from URLs) + scriptCtx, scriptCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer scriptCancel() + + for i, scriptSrc := range config.Lua { + if err := validateScriptSource(scriptSrc); err != nil { + validationErrors = append( + validationErrors, + types.NewFieldValidationError(fmt.Sprintf("Lua[%d]", i), scriptSrc, err), + ) + continue + } + // Validate script syntax + if err := script.ValidateScript(scriptCtx, scriptSrc, script.EngineTypeLua); err != nil { + validationErrors = append( + validationErrors, + types.NewFieldValidationError(fmt.Sprintf("Lua[%d]", i), scriptSrc, err), + ) + } + } + + for i, scriptSrc := range config.Js { + if err := validateScriptSource(scriptSrc); err != nil { + validationErrors = append( + validationErrors, + types.NewFieldValidationError(fmt.Sprintf("Js[%d]", i), scriptSrc, err), + ) + continue + } + // Validate script syntax + if err := script.ValidateScript(scriptCtx, scriptSrc, script.EngineTypeJavaScript); err != nil { + validationErrors = append( + validationErrors, + types.NewFieldValidationError(fmt.Sprintf("Js[%d]", i), scriptSrc, err), + ) + } + } + templateErrors := ValidateTemplates(&config) validationErrors = append(validationErrors, templateErrors...) @@ -582,6 +632,51 @@ func parseConfigFile(configFile types.ConfigFile, maxDepth int) (*Config, error) return fileConfig, nil } +// validateScriptSource validates a script source string. +// Scripts can be: +// - Inline script: any string not starting with "@" +// - Escaped "@": strings starting with "@@" (literal "@" at start) +// - File reference: "@/path/to/file" or "@./relative/path" +// - URL reference: "@http://..." or "@https://..." +func validateScriptSource(script string) error { + // Empty script is invalid + if script == "" { + return errors.New("script cannot be empty") + } + + // Not a file/URL reference - it's an inline script + if !strings.HasPrefix(script, "@") { + return nil + } + + // Escaped @ - it's an inline script starting with literal @ + if strings.HasPrefix(script, "@@") { + return nil + } + + // It's a file or URL reference - validate the source + source := script[1:] // Remove the @ prefix + + if source == "" { + return errors.New("script source cannot be empty after @") + } + + // Check if it's a URL + if strings.HasPrefix(source, "http://") || strings.HasPrefix(source, "https://") { + parsedURL, err := url.Parse(source) + if err != nil { + return fmt.Errorf("invalid URL: %w", err) + } + if parsedURL.Host == "" { + return errors.New("URL must have a host") + } + return nil + } + + // It's a file path - basic validation (not empty, checked above) + return nil +} + func printParseErrors(parserName string, errors ...types.FieldParseError) { for _, fieldErr := range errors { if fieldErr.Value == "" { diff --git a/internal/config/env.go b/internal/config/env.go index fbdd0e9..6736b9a 100644 --- a/internal/config/env.go +++ b/internal/config/env.go @@ -216,6 +216,14 @@ func (parser ConfigENVParser) Parse() (*Config, error) { config.Values = []string{values} } + if lua := parser.getEnv("LUA"); lua != "" { + config.Lua = []string{lua} + } + + if js := parser.getEnv("JS"); js != "" { + config.Js = []string{js} + } + if len(fieldParseErrors) > 0 { return nil, types.NewFieldParseErrors(fieldParseErrors) } diff --git a/internal/config/file.go b/internal/config/file.go index 5df450d..59f94db 100644 --- a/internal/config/file.go +++ b/internal/config/file.go @@ -202,6 +202,8 @@ type configYAML struct { Bodies stringOrSliceField `yaml:"body"` Proxies stringOrSliceField `yaml:"proxy"` Values stringOrSliceField `yaml:"values"` + Lua stringOrSliceField `yaml:"lua"` + Js stringOrSliceField `yaml:"js"` } // ParseYAML parses YAML config file arguments into a Config object. @@ -246,6 +248,8 @@ func (parser ConfigFileParser) ParseYAML(data []byte) (*Config, error) { } config.Bodies = append(config.Bodies, parsedData.Bodies...) config.Values = append(config.Values, parsedData.Values...) + config.Lua = append(config.Lua, parsedData.Lua...) + config.Js = append(config.Js, parsedData.Js...) if len(parsedData.ConfigFiles) > 0 { for _, configFile := range parsedData.ConfigFiles { diff --git a/internal/sarin/request.go b/internal/sarin/request.go index aeddcb4..06503fd 100644 --- a/internal/sarin/request.go +++ b/internal/sarin/request.go @@ -11,6 +11,7 @@ import ( "github.com/joho/godotenv" "github.com/valyala/fasthttp" + "go.aykhans.me/sarin/internal/script" "go.aykhans.me/sarin/internal/types" utilsSlice "go.aykhans.me/utils/slice" ) @@ -26,6 +27,9 @@ type valuesData struct { // NewRequestGenerator creates a new RequestGenerator function that generates HTTP requests // with the specified configuration. The returned RequestGenerator is NOT safe for concurrent // use by multiple goroutines. +// +// Note: Scripts must be validated before calling this function (e.g., in NewSarin). +// The caller is responsible for managing the scriptTransformer lifecycle. func NewRequestGenerator( methods []string, requestURL *url.URL, @@ -35,6 +39,7 @@ func NewRequestGenerator( bodies []string, values []string, fileCache *FileCache, + scriptTransformer *script.Transformer, ) (RequestGenerator, bool) { randSource := NewDefaultRandSource() //nolint:gosec // G404: Using non-cryptographic rand for load testing, not security @@ -53,6 +58,8 @@ func NewRequestGenerator( valuesGenerator := NewValuesGeneratorFunc(values, templateFuncMap) + hasScripts := scriptTransformer != nil && !scriptTransformer.IsEmpty() + var ( data valuesData path string @@ -98,13 +105,24 @@ func NewRequestGenerator( if requestURL.Scheme == "https" { req.URI().SetScheme("https") } + + // Apply script transformations if any + if hasScripts { + reqData := script.RequestDataFromFastHTTP(req) + if err = scriptTransformer.Transform(reqData); err != nil { + return err + } + script.ApplyToFastHTTP(reqData, req) + } + return nil }, isPathGeneratorDynamic || isMethodGeneratorDynamic || isParamsGeneratorDynamic || isHeadersGeneratorDynamic || isCookiesGeneratorDynamic || - isBodyGeneratorDynamic + isBodyGeneratorDynamic || + hasScripts } func NewMethodGeneratorFunc(localRand *rand.Rand, methods []string, templateFunctions template.FuncMap) (RequestGeneratorWithData, bool) { diff --git a/internal/sarin/sarin.go b/internal/sarin/sarin.go index 664e9b7..f18b048 100644 --- a/internal/sarin/sarin.go +++ b/internal/sarin/sarin.go @@ -14,6 +14,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" "github.com/valyala/fasthttp" + "go.aykhans.me/sarin/internal/script" "go.aykhans.me/sarin/internal/types" ) @@ -52,11 +53,13 @@ type sarin struct { hostClients []*fasthttp.HostClient responses *SarinResponseData fileCache *FileCache + scriptChain *script.Chain } // NewSarin creates a new sarin instance for load testing. // It can return the following errors: // - types.ProxyDialError +// - script loading errors func NewSarin( ctx context.Context, methods []string, @@ -75,6 +78,8 @@ func NewSarin( values []string, collectStats bool, dryRun bool, + luaScripts []string, + jsScripts []string, ) (*sarin, error) { if workers == 0 { workers = 1 @@ -85,6 +90,19 @@ func NewSarin( return nil, err } + // Load script sources + luaSources, err := script.LoadSources(ctx, luaScripts, script.EngineTypeLua) + if err != nil { + return nil, err + } + + jsSources, err := script.LoadSources(ctx, jsScripts, script.EngineTypeJavaScript) + if err != nil { + return nil, err + } + + scriptChain := script.NewChain(luaSources, jsSources) + srn := &sarin{ workers: workers, requestURL: requestURL, @@ -103,6 +121,7 @@ func NewSarin( dryRun: dryRun, hostClients: hostClients, fileCache: NewFileCache(time.Second * 10), + scriptChain: scriptChain, } if collectStats { @@ -193,7 +212,20 @@ func (q sarin) Worker( defer fasthttp.ReleaseRequest(req) defer fasthttp.ReleaseResponse(resp) - requestGenerator, isDynamic := NewRequestGenerator(q.methods, q.requestURL, q.params, q.headers, q.cookies, q.bodies, q.values, q.fileCache) + // Create script transformer for this worker (engines are not thread-safe) + // Scripts are pre-validated in NewSarin, so this should not fail + var scriptTransformer *script.Transformer + if !q.scriptChain.IsEmpty() { + scriptTransformer, err := q.scriptChain.NewTransformer() + if err != nil { + panic(err) + } + defer scriptTransformer.Close() + } + + requestGenerator, isDynamic := NewRequestGenerator( + q.methods, q.requestURL, q.params, q.headers, q.cookies, q.bodies, q.values, q.fileCache, scriptTransformer, + ) if q.dryRun { switch { diff --git a/internal/script/chain.go b/internal/script/chain.go new file mode 100644 index 0000000..c4158ba --- /dev/null +++ b/internal/script/chain.go @@ -0,0 +1,185 @@ +package script + +import ( + "fmt" + + "github.com/valyala/fasthttp" +) + +// Chain holds the loaded script sources and can create engine instances. +// The sources are loaded once, but engines are created per-worker since they're not thread-safe. +type Chain struct { + luaSources []*Source + jsSources []*Source +} + +// NewChain creates a new script chain from loaded sources. +// Lua scripts run first, then JavaScript scripts, in the order provided. +func NewChain(luaSources, jsSources []*Source) *Chain { + return &Chain{ + luaSources: luaSources, + jsSources: jsSources, + } +} + +// IsEmpty returns true if there are no scripts to execute. +func (c *Chain) IsEmpty() bool { + return len(c.luaSources) == 0 && len(c.jsSources) == 0 +} + +// Transformer holds instantiated script engines for a single worker. +// It is NOT safe for concurrent use. +type Transformer struct { + luaEngines []*LuaEngine + jsEngines []*JsEngine +} + +// NewTransformer creates engine instances from the chain's sources. +// Call this once per worker goroutine. +func (c *Chain) NewTransformer() (*Transformer, error) { + if c.IsEmpty() { + return &Transformer{}, nil + } + + t := &Transformer{ + luaEngines: make([]*LuaEngine, 0, len(c.luaSources)), + jsEngines: make([]*JsEngine, 0, len(c.jsSources)), + } + + // Create Lua engines + for i, src := range c.luaSources { + engine, err := NewLuaEngine(src.Content) + if err != nil { + t.Close() // Clean up already created engines + return nil, fmt.Errorf("lua script[%d]: %w", i, err) + } + t.luaEngines = append(t.luaEngines, engine) + } + + // Create JS engines + for i, src := range c.jsSources { + engine, err := NewJsEngine(src.Content) + if err != nil { + t.Close() // Clean up already created engines + return nil, fmt.Errorf("js script[%d]: %w", i, err) + } + t.jsEngines = append(t.jsEngines, engine) + } + + return t, nil +} + +// Transform applies all scripts to the request data. +// Lua scripts run first, then JavaScript scripts. +func (t *Transformer) Transform(req *RequestData) error { + // Run Lua scripts + for i, engine := range t.luaEngines { + if err := engine.Transform(req); err != nil { + return fmt.Errorf("lua script[%d]: %w", i, err) + } + } + + // Run JS scripts + for i, engine := range t.jsEngines { + if err := engine.Transform(req); err != nil { + return fmt.Errorf("js script[%d]: %w", i, err) + } + } + + return nil +} + +// Close releases all engine resources. +func (t *Transformer) Close() { + for _, engine := range t.luaEngines { + engine.Close() + } + for _, engine := range t.jsEngines { + engine.Close() + } +} + +// IsEmpty returns true if there are no engines. +func (t *Transformer) IsEmpty() bool { + return len(t.luaEngines) == 0 && len(t.jsEngines) == 0 +} + +// RequestDataFromFastHTTP extracts RequestData from a fasthttp.Request. +func RequestDataFromFastHTTP(req *fasthttp.Request) *RequestData { + data := &RequestData{ + Method: string(req.Header.Method()), + URL: string(req.URI().FullURI()), + Path: string(req.URI().Path()), + Body: string(req.Body()), + Headers: make(map[string][]string), + Params: make(map[string][]string), + Cookies: make(map[string][]string), + } + + // Extract headers (supports multiple values per key) + req.Header.All()(func(key, value []byte) bool { + k := string(key) + data.Headers[k] = append(data.Headers[k], string(value)) + return true + }) + + // Extract query params (supports multiple values per key) + req.URI().QueryArgs().All()(func(key, value []byte) bool { + k := string(key) + data.Params[k] = append(data.Params[k], string(value)) + return true + }) + + // Extract cookies (supports multiple values per key) + req.Header.Cookies()(func(key, value []byte) bool { + k := string(key) + data.Cookies[k] = append(data.Cookies[k], string(value)) + return true + }) + + return data +} + +// ApplyToFastHTTP applies the modified RequestData back to a fasthttp.Request. +func ApplyToFastHTTP(data *RequestData, req *fasthttp.Request) { + // Method + req.Header.SetMethod(data.Method) + + // Path (preserve scheme and host) + req.URI().SetPath(data.Path) + + // Body + req.SetBody([]byte(data.Body)) + + // Clear and set headers (supports multiple values per key) + req.Header.All()(func(key, _ []byte) bool { + keyStr := string(key) + if keyStr != "Host" { + req.Header.Del(keyStr) + } + return true + }) + for k, values := range data.Headers { + if k != "Host" { // Don't overwrite Host + for _, v := range values { + req.Header.Add(k, v) + } + } + } + + // Clear and set query params (supports multiple values per key) + req.URI().QueryArgs().Reset() + for k, values := range data.Params { + for _, v := range values { + req.URI().QueryArgs().Add(k, v) + } + } + + // Clear and set cookies (supports multiple values per key) + req.Header.DelAllCookies() + for k, values := range data.Cookies { + for _, v := range values { + req.Header.SetCookie(k, v) + } + } +} diff --git a/internal/script/js.go b/internal/script/js.go new file mode 100644 index 0000000..4e22f51 --- /dev/null +++ b/internal/script/js.go @@ -0,0 +1,198 @@ +package script + +import ( + "errors" + "fmt" + + "github.com/dop251/goja" +) + +// JsEngine implements the Engine interface using goja (JavaScript). +type JsEngine struct { + runtime *goja.Runtime + transform goja.Callable +} + +// NewJsEngine creates a new JavaScript script engine with the given script content. +// The script must define a global `transform` function that takes a request object +// and returns the modified request object. +// +// Example JavaScript script: +// +// function transform(req) { +// req.headers["X-Custom"] = "value"; +// return req; +// } +func NewJsEngine(scriptContent string) (*JsEngine, error) { + vm := goja.New() + + // Execute the script to define the transform function + _, err := vm.RunString(scriptContent) + if err != nil { + return nil, fmt.Errorf("failed to execute JavaScript script: %w", err) + } + + // Get the transform function + transformVal := vm.Get("transform") + if transformVal == nil || goja.IsUndefined(transformVal) || goja.IsNull(transformVal) { + return nil, errors.New("script must define a global 'transform' function") + } + + transform, ok := goja.AssertFunction(transformVal) + if !ok { + return nil, errors.New("'transform' must be a function") + } + + return &JsEngine{ + runtime: vm, + transform: transform, + }, nil +} + +// Transform executes the JavaScript transform function with the given request data. +func (e *JsEngine) Transform(req *RequestData) error { + // Convert RequestData to JavaScript object + reqObj := e.requestDataToObject(req) + + // Call transform(req) + result, err := e.transform(goja.Undefined(), reqObj) + if err != nil { + return fmt.Errorf("JavaScript transform error: %w", err) + } + + // Update RequestData from the returned object + if err := e.objectToRequestData(result, req); err != nil { + return fmt.Errorf("failed to parse transform result: %w", err) + } + + return nil +} + +// Close releases the JavaScript runtime resources. +func (e *JsEngine) Close() { + // goja doesn't have an explicit close method, but we can help GC + e.runtime = nil + e.transform = nil +} + +// requestDataToObject converts RequestData to a goja Value (JavaScript object). +func (e *JsEngine) requestDataToObject(req *RequestData) goja.Value { + obj := e.runtime.NewObject() + + _ = obj.Set("method", req.Method) + _ = obj.Set("url", req.URL) + _ = obj.Set("path", req.Path) + _ = obj.Set("body", req.Body) + + // Headers (map[string][]string -> object of arrays) + headers := e.runtime.NewObject() + for k, values := range req.Headers { + _ = headers.Set(k, e.stringSliceToArray(values)) + } + _ = obj.Set("headers", headers) + + // Params (map[string][]string -> object of arrays) + params := e.runtime.NewObject() + for k, values := range req.Params { + _ = params.Set(k, e.stringSliceToArray(values)) + } + _ = obj.Set("params", params) + + // Cookies (map[string][]string -> object of arrays) + cookies := e.runtime.NewObject() + for k, values := range req.Cookies { + _ = cookies.Set(k, e.stringSliceToArray(values)) + } + _ = obj.Set("cookies", cookies) + + return obj +} + +// objectToRequestData updates RequestData from a JavaScript object. +func (e *JsEngine) objectToRequestData(val goja.Value, req *RequestData) error { + if val == nil || goja.IsUndefined(val) || goja.IsNull(val) { + return errors.New("transform function must return an object") + } + + obj := val.ToObject(e.runtime) + if obj == nil { + return errors.New("transform function must return an object") + } + + // Method + if v := obj.Get("method"); v != nil && !goja.IsUndefined(v) { + req.Method = v.String() + } + + // URL + if v := obj.Get("url"); v != nil && !goja.IsUndefined(v) { + req.URL = v.String() + } + + // Path + if v := obj.Get("path"); v != nil && !goja.IsUndefined(v) { + req.Path = v.String() + } + + // Body + if v := obj.Get("body"); v != nil && !goja.IsUndefined(v) { + req.Body = v.String() + } + + // Headers + if v := obj.Get("headers"); v != nil && !goja.IsUndefined(v) && !goja.IsNull(v) { + req.Headers = e.objectToStringSliceMap(v.ToObject(e.runtime)) + } + + // Params + if v := obj.Get("params"); v != nil && !goja.IsUndefined(v) && !goja.IsNull(v) { + req.Params = e.objectToStringSliceMap(v.ToObject(e.runtime)) + } + + // Cookies + if v := obj.Get("cookies"); v != nil && !goja.IsUndefined(v) && !goja.IsNull(v) { + req.Cookies = e.objectToStringSliceMap(v.ToObject(e.runtime)) + } + + return nil +} + +// stringSliceToArray converts a Go []string to a JavaScript array. +func (e *JsEngine) stringSliceToArray(values []string) *goja.Object { + ifaces := make([]interface{}, len(values)) + for i, v := range values { + ifaces[i] = v + } + return e.runtime.NewArray(ifaces...) +} + +// objectToStringSliceMap converts a JavaScript object to a Go map[string][]string. +// Supports both single string values and array values. +func (e *JsEngine) objectToStringSliceMap(obj *goja.Object) map[string][]string { + if obj == nil { + return make(map[string][]string) + } + + result := make(map[string][]string) + for _, key := range obj.Keys() { + v := obj.Get(key) + if v == nil || goja.IsUndefined(v) || goja.IsNull(v) { + continue + } + + // Check if it's an array + if arr, ok := v.Export().([]interface{}); ok { + var values []string + for _, item := range arr { + if s, ok := item.(string); ok { + values = append(values, s) + } + } + result[key] = values + } else { + // Single value - wrap in slice + result[key] = []string{v.String()} + } + } + return result +} diff --git a/internal/script/lua.go b/internal/script/lua.go new file mode 100644 index 0000000..013f5c7 --- /dev/null +++ b/internal/script/lua.go @@ -0,0 +1,191 @@ +package script + +import ( + "errors" + "fmt" + + lua "github.com/yuin/gopher-lua" +) + +// LuaEngine implements the Engine interface using gopher-lua. +type LuaEngine struct { + state *lua.LState + transform *lua.LFunction +} + +// NewLuaEngine creates a new Lua script engine with the given script content. +// The script must define a global `transform` function that takes a request table +// and returns the modified request table. +// +// Example Lua script: +// +// function transform(req) +// req.headers["X-Custom"] = "value" +// return req +// end +func NewLuaEngine(scriptContent string) (*LuaEngine, error) { + L := lua.NewState() + + // Execute the script to define the transform function + if err := L.DoString(scriptContent); err != nil { + L.Close() + return nil, fmt.Errorf("failed to execute Lua script: %w", err) + } + + // Get the transform function + transform := L.GetGlobal("transform") + if transform.Type() != lua.LTFunction { + L.Close() + return nil, errors.New("script must define a global 'transform' function") + } + + return &LuaEngine{ + state: L, + transform: transform.(*lua.LFunction), + }, nil +} + +// Transform executes the Lua transform function with the given request data. +func (e *LuaEngine) Transform(req *RequestData) error { + // Convert RequestData to Lua table + reqTable := e.requestDataToTable(req) + + // Call transform(req) + e.state.Push(e.transform) + e.state.Push(reqTable) + if err := e.state.PCall(1, 1, nil); err != nil { + return fmt.Errorf("lua transform error: %w", err) + } + + // Get the result + result := e.state.Get(-1) + e.state.Pop(1) + + if result.Type() != lua.LTTable { + return fmt.Errorf("transform function must return a table, got %s", result.Type()) + } + + // Update RequestData from the returned table + e.tableToRequestData(result.(*lua.LTable), req) + + return nil +} + +// Close releases the Lua state resources. +func (e *LuaEngine) Close() { + if e.state != nil { + e.state.Close() + } +} + +// requestDataToTable converts RequestData to a Lua table. +func (e *LuaEngine) requestDataToTable(req *RequestData) *lua.LTable { + L := e.state + t := L.NewTable() + + t.RawSetString("method", lua.LString(req.Method)) + t.RawSetString("url", lua.LString(req.URL)) + t.RawSetString("path", lua.LString(req.Path)) + t.RawSetString("body", lua.LString(req.Body)) + + // Headers (map[string][]string -> table of arrays) + headers := L.NewTable() + for k, values := range req.Headers { + arr := L.NewTable() + for _, v := range values { + arr.Append(lua.LString(v)) + } + headers.RawSetString(k, arr) + } + t.RawSetString("headers", headers) + + // Params (map[string][]string -> table of arrays) + params := L.NewTable() + for k, values := range req.Params { + arr := L.NewTable() + for _, v := range values { + arr.Append(lua.LString(v)) + } + params.RawSetString(k, arr) + } + t.RawSetString("params", params) + + // Cookies (map[string][]string -> table of arrays) + cookies := L.NewTable() + for k, values := range req.Cookies { + arr := L.NewTable() + for _, v := range values { + arr.Append(lua.LString(v)) + } + cookies.RawSetString(k, arr) + } + t.RawSetString("cookies", cookies) + + return t +} + +// tableToRequestData updates RequestData from a Lua table. +func (e *LuaEngine) tableToRequestData(t *lua.LTable, req *RequestData) { + // Method + if v := t.RawGetString("method"); v.Type() == lua.LTString { + req.Method = string(v.(lua.LString)) + } + + // URL + if v := t.RawGetString("url"); v.Type() == lua.LTString { + req.URL = string(v.(lua.LString)) + } + + // Path + if v := t.RawGetString("path"); v.Type() == lua.LTString { + req.Path = string(v.(lua.LString)) + } + + // Body + if v := t.RawGetString("body"); v.Type() == lua.LTString { + req.Body = string(v.(lua.LString)) + } + + // Headers + if v := t.RawGetString("headers"); v.Type() == lua.LTTable { + req.Headers = e.tableToStringSliceMap(v.(*lua.LTable)) + } + + // Params + if v := t.RawGetString("params"); v.Type() == lua.LTTable { + req.Params = e.tableToStringSliceMap(v.(*lua.LTable)) + } + + // Cookies + if v := t.RawGetString("cookies"); v.Type() == lua.LTTable { + req.Cookies = e.tableToStringSliceMap(v.(*lua.LTable)) + } +} + +// tableToStringSliceMap converts a Lua table to a Go map[string][]string. +// Supports both single string values and array values. +func (e *LuaEngine) tableToStringSliceMap(t *lua.LTable) map[string][]string { + result := make(map[string][]string) + t.ForEach(func(k, v lua.LValue) { + if k.Type() != lua.LTString { + return + } + key := string(k.(lua.LString)) + + switch v.Type() { + case lua.LTString: + // Single string value + result[key] = []string{string(v.(lua.LString))} + case lua.LTTable: + // Array of strings + var values []string + v.(*lua.LTable).ForEach(func(_, item lua.LValue) { + if item.Type() == lua.LTString { + values = append(values, string(item.(lua.LString))) + } + }) + result[key] = values + } + }) + return result +} diff --git a/internal/script/script.go b/internal/script/script.go new file mode 100644 index 0000000..2c3dc7e --- /dev/null +++ b/internal/script/script.go @@ -0,0 +1,190 @@ +package script + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" +) + +// RequestData represents the request data passed to scripts for transformation. +// Scripts can modify any field and the changes will be applied to the actual request. +// Headers, Params, and Cookies use []string values to support multiple values per key. +type RequestData struct { + Method string `json:"method"` + URL string `json:"url"` + Path string `json:"path"` + Headers map[string][]string `json:"headers"` + Params map[string][]string `json:"params"` + Cookies map[string][]string `json:"cookies"` + Body string `json:"body"` +} + +// Engine defines the interface for script engines (Lua, JavaScript). +// Each engine must be able to transform request data using a user-provided script. +type Engine interface { + // Transform executes the script's transform function with the given request data. + // The script should modify the RequestData and return it. + Transform(req *RequestData) error + + // Close releases any resources held by the engine. + Close() +} + +// EngineType represents the type of script engine. +type EngineType string + +const ( + EngineTypeLua EngineType = "lua" + EngineTypeJavaScript EngineType = "js" +) + +// Source represents a loaded script source. +type Source struct { + Content string + EngineType EngineType +} + +// LoadSource loads a script from the given source string. +// The source can be: +// - Inline script: any string not starting with "@" +// - Escaped "@": strings starting with "@@" (literal "@" at start, returns string without first @) +// - File reference: "@/path/to/file" or "@./relative/path" +// - URL reference: "@http://..." or "@https://..." +func LoadSource(ctx context.Context, source string, engineType EngineType) (*Source, error) { + if source == "" { + return nil, errors.New("script source cannot be empty") + } + + var content string + var err error + + switch { + case strings.HasPrefix(source, "@@"): + // Escaped @ - it's an inline script starting with literal @ + content = source[1:] // Remove first @, keep the rest + case strings.HasPrefix(source, "@"): + // File or URL reference + ref := source[1:] + if strings.HasPrefix(ref, "http://") || strings.HasPrefix(ref, "https://") { + content, err = fetchURL(ctx, ref) + } else { + content, err = readFile(ref) + } + if err != nil { + return nil, fmt.Errorf("failed to load script from %q: %w", ref, err) + } + default: + // Inline script + content = source + } + + return &Source{ + Content: content, + EngineType: engineType, + }, nil +} + +// LoadSources loads multiple script sources. +func LoadSources(ctx context.Context, sources []string, engineType EngineType) ([]*Source, error) { + loaded := make([]*Source, 0, len(sources)) + for i, src := range sources { + source, err := LoadSource(ctx, src, engineType) + if err != nil { + return nil, fmt.Errorf("script[%d]: %w", i, err) + } + loaded = append(loaded, source) + } + return loaded, nil +} + +// ValidateScript validates a script source by loading it and checking syntax. +// It loads the script (from file/URL/inline), parses it, and verifies +// that a 'transform' function is defined. +func ValidateScript(ctx context.Context, source string, engineType EngineType) error { + // Load the script source + src, err := LoadSource(ctx, source, engineType) + if err != nil { + return err + } + + // Try to create an engine - this validates syntax and transform function + var engine Engine + switch engineType { + case EngineTypeLua: + engine, err = NewLuaEngine(src.Content) + case EngineTypeJavaScript: + engine, err = NewJsEngine(src.Content) + default: + return fmt.Errorf("unknown engine type: %s", engineType) + } + + if err != nil { + return err + } + + // Clean up the engine - we only needed it for validation + engine.Close() + return nil +} + +// ValidateScripts validates multiple script sources. +func ValidateScripts(ctx context.Context, sources []string, engineType EngineType) error { + for i, src := range sources { + if err := ValidateScript(ctx, src, engineType); err != nil { + return fmt.Errorf("script[%d]: %w", i, err) + } + } + return nil +} + +// fetchURL downloads content from an HTTP/HTTPS URL. +func fetchURL(ctx context.Context, url string) (string, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", fmt.Errorf("failed to fetch: %w", err) + } + defer resp.Body.Close() //nolint:errcheck + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("HTTP %d %s", resp.StatusCode, resp.Status) + } + + data, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response: %w", err) + } + + return string(data), nil +} + +// readFile reads content from a local file. +func readFile(path string) (string, error) { + if !filepath.IsAbs(path) { + pwd, err := os.Getwd() + if err != nil { + return "", fmt.Errorf("failed to get working directory: %w", err) + } + path = filepath.Join(pwd, path) + } + + data, err := os.ReadFile(path) //nolint:gosec + if err != nil { + return "", fmt.Errorf("failed to read file: %w", err) + } + + return string(data), nil +}