Merge pull request #169 from aykhans/feat/add-scripting

Add scripting
This commit is contained in:
2026-02-15 00:19:58 +04:00
committed by GitHub
27 changed files with 1656 additions and 176 deletions

View File

@@ -16,8 +16,12 @@ jobs:
- uses: actions/checkout@v5
- uses: actions/setup-go@v6
with:
go-version: 1.25.5
go-version: 1.26.0
- name: go fix
run: |
go fix ./...
git diff --exit-code
- name: golangci-lint
uses: golangci/golangci-lint-action@v9
with:
version: v2.7.2
version: v2.9.0

View File

@@ -35,7 +35,7 @@ jobs:
run: |
echo "VERSION=$(git describe --tags --always)" >> $GITHUB_ENV
echo "GIT_COMMIT=$(git rev-parse HEAD)" >> $GITHUB_ENV
echo "GO_VERSION=1.25.5" >> $GITHUB_ENV
echo "GO_VERSION=1.26.0" >> $GITHUB_ENV
- name: Set up Go
if: github.event_name == 'release' || inputs.build_binaries
@@ -53,12 +53,12 @@ jobs:
-X 'go.aykhans.me/sarin/internal/version.GoVersion=$(go version)' \
-s -w"
CGO_ENABLED=0 GOEXPERIMENT=greenteagc GOOS=linux GOARCH=amd64 go build -ldflags "$LDFLAGS" -o ./sarin-linux-amd64 ./cmd/cli/main.go
CGO_ENABLED=0 GOEXPERIMENT=greenteagc GOOS=linux GOARCH=arm64 go build -ldflags "$LDFLAGS" -o ./sarin-linux-arm64 ./cmd/cli/main.go
CGO_ENABLED=0 GOEXPERIMENT=greenteagc GOOS=darwin GOARCH=amd64 go build -ldflags "$LDFLAGS" -o ./sarin-darwin-amd64 ./cmd/cli/main.go
CGO_ENABLED=0 GOEXPERIMENT=greenteagc GOOS=darwin GOARCH=arm64 go build -ldflags "$LDFLAGS" -o ./sarin-darwin-arm64 ./cmd/cli/main.go
CGO_ENABLED=0 GOEXPERIMENT=greenteagc GOOS=windows GOARCH=amd64 go build -ldflags "$LDFLAGS" -o ./sarin-windows-amd64.exe ./cmd/cli/main.go
CGO_ENABLED=0 GOEXPERIMENT=greenteagc GOOS=windows GOARCH=arm64 go build -ldflags "$LDFLAGS" -o ./sarin-windows-arm64.exe ./cmd/cli/main.go
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags "$LDFLAGS" -o ./sarin-linux-amd64 ./cmd/cli/main.go
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -ldflags "$LDFLAGS" -o ./sarin-linux-arm64 ./cmd/cli/main.go
CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -ldflags "$LDFLAGS" -o ./sarin-darwin-amd64 ./cmd/cli/main.go
CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -ldflags "$LDFLAGS" -o ./sarin-darwin-arm64 ./cmd/cli/main.go
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -ldflags "$LDFLAGS" -o ./sarin-windows-amd64.exe ./cmd/cli/main.go
CGO_ENABLED=0 GOOS=windows GOARCH=arm64 go build -ldflags "$LDFLAGS" -o ./sarin-windows-arm64.exe ./cmd/cli/main.go
- name: Upload Release Assets
if: github.event_name == 'release' || inputs.build_binaries

View File

@@ -1,7 +1,7 @@
version: "2"
run:
go: "1.25"
go: "1.26"
concurrency: 12
linters:

View File

@@ -1,4 +1,4 @@
ARG GO_VERSION=1.25.5
ARG GO_VERSION=1.26.0
FROM docker.io/library/golang:${GO_VERSION}-alpine AS builder
@@ -12,7 +12,7 @@ RUN --mount=type=bind,source=./go.mod,target=./go.mod \
go mod download
RUN --mount=type=bind,source=./,target=./ \
CGO_ENABLED=0 GOEXPERIMENT=greenteagc go build \
CGO_ENABLED=0 go build \
-ldflags "-X 'go.aykhans.me/sarin/internal/version.Version=${VERSION}' \
-X 'go.aykhans.me/sarin/internal/version.GitCommit=${GIT_COMMIT}' \
-X 'go.aykhans.me/sarin/internal/version.BuildDate=$(date -u +%Y-%m-%dT%H:%M:%SZ)' \

View File

@@ -22,13 +22,14 @@
Sarin is designed for efficient HTTP load testing with minimal resource consumption. It prioritizes simplicity—features like templating add zero overhead when unused.
| ✅ Supported | ❌ Not Supported |
| ---------------------------------------------------------- | --------------------------------- |
| High-performance with low memory footprint | Detailed response body analysis |
| Long-running duration/count based tests | Extensive response statistics |
| Dynamic requests via 320+ template functions | Web UI or complex TUI |
| Multiple proxy protocols<br>(HTTP, HTTPS, SOCKS5, SOCKS5H) | Scripting or multi-step scenarios |
| Flexible config (CLI, ENV, YAML) | HTTP/2, HTTP/3, WebSocket, gRPC |
| ✅ Supported | ❌ Not Supported |
| ---------------------------------------------------------- | ------------------------------- |
| High-performance with low memory footprint | Detailed response body analysis |
| Long-running duration/count based tests | Extensive response statistics |
| Dynamic requests via 320+ template functions | Web UI or complex TUI |
| Request scripting with Lua and JavaScript | Distributed load testing |
| Multiple proxy protocols<br>(HTTP, HTTPS, SOCKS5, SOCKS5H) | HTTP/2, HTTP/3, WebSocket, gRPC |
| Flexible config (CLI, ENV, YAML) | Plugins / extensions ecosystem |
## Installation
@@ -56,12 +57,12 @@ Download the latest binaries from the [releases](https://github.com/aykhans/sari
### Building from Source
Requires [Go 1.25+](https://golang.org/dl/).
Requires [Go 1.26+](https://golang.org/dl/).
```sh
git clone https://github.com/aykhans/sarin.git && cd sarin
CGO_ENABLED=0 GOEXPERIMENT=greenteagc go build \
CGO_ENABLED=0 go build \
-ldflags "-X 'go.aykhans.me/sarin/internal/version.Version=dev' \
-X 'go.aykhans.me/sarin/internal/version.GitCommit=$(git rev-parse HEAD)' \
-X 'go.aykhans.me/sarin/internal/version.BuildDate=$(date -u +%Y-%m-%dT%H:%M:%SZ)' \

View File

@@ -3,7 +3,7 @@ version: "3"
vars:
BIN_DIR: ./bin
GOLANGCI_LINT_VERSION: v2.7.2
GOLANGCI_LINT_VERSION: v2.9.0
GOLANGCI: "{{.BIN_DIR}}/golangci-lint-{{.GOLANGCI_LINT_VERSION}}"
tasks:
@@ -11,16 +11,22 @@ tasks:
desc: Run fmt, tidy, and lint.
cmds:
- task: fmt
- task: fix
- task: tidy
- task: lint
fmt:
desc: Run linters
desc: Run format
deps:
- install-golangci-lint
cmds:
- "{{.GOLANGCI}} fmt"
fix:
desc: Run go fix
cmds:
- go fix ./...
tidy:
desc: Run go mod tidy.
cmds:
@@ -52,7 +58,7 @@ tasks:
cmds:
- rm -f {{.OUTPUT}}
- >-
CGO_ENABLED=0 GOEXPERIMENT=greenteagc go build
CGO_ENABLED=0 go build
-ldflags "-X 'go.aykhans.me/sarin/internal/version.Version=$(git describe --tags --always)'
-X 'go.aykhans.me/sarin/internal/version.GitCommit=$(git rev-parse HEAD)'
-X 'go.aykhans.me/sarin/internal/version.BuildDate=$(date -u +%Y-%m-%dT%H:%M:%SZ)'

View File

@@ -53,6 +53,7 @@ func main() {
combinedConfig.Cookies, combinedConfig.Bodies, combinedConfig.Proxies, combinedConfig.Values,
*combinedConfig.Output != config.ConfigOutputTypeNone,
*combinedConfig.DryRun,
combinedConfig.Lua, combinedConfig.Js,
)
_ = utilsErr.MustHandle(err,
utilsErr.OnType(func(err types.ProxyDialError) error {
@@ -60,6 +61,16 @@ func main() {
os.Exit(1)
return nil
}),
utilsErr.OnSentinel(types.ErrScriptEmpty, func(err error) error {
fmt.Fprintln(os.Stderr, config.StyleRed.Render("[SCRIPT] ")+err.Error())
os.Exit(1)
return nil
}),
utilsErr.OnType(func(err types.ScriptLoadError) error {
fmt.Fprintln(os.Stderr, config.StyleRed.Render("[SCRIPT] ")+err.Error())
os.Exit(1)
return nil
}),
)
srn.Start(ctx)

View File

@@ -36,6 +36,8 @@ Use `-s` or `--show-config` to see the final merged configuration before sending
| [Cookies](#cookies) | `cookies`<br>(object) | `-cookie` / `-C`<br>(string / []string) | `SARIN_COOKIE`<br>(string) | - | HTTP cookies |
| [Proxy](#proxy) | `proxy`<br>(string / []string) | `-proxy` / `-X`<br>(string / []string) | `SARIN_PROXY`<br>(string) | - | Proxy URL(s) |
| [Values](#values) | `values`<br>(string / []string) | `-values` / `-V`<br>(string / []string) | `SARIN_VALUES`<br>(string) | - | Template values (key=value) |
| [Lua](#lua) | `lua`<br>(string / []string) | `-lua`<br>(string / []string) | `SARIN_LUA`<br>(string) | - | Lua script(s) |
| [Js](#js) | `js`<br>(string / []string) | `-js`<br>(string / []string) | `SARIN_JS`<br>(string) | - | JavaScript script(s) |
---
@@ -374,3 +376,133 @@ values: |
```sh
SARIN_VALUES="key1=value1"
```
## Lua
Lua script(s) for request transformation. Each script must define a global `transform` function that receives a request object and returns the modified request object. Scripts run after template rendering, before the request is sent.
If multiple Lua scripts are provided, they are chained in order—the output of one becomes the input to the next. When both Lua and JavaScript scripts are specified, all Lua scripts run first, then all JavaScript scripts.
**Script sources:**
Scripts can be provided as:
- **Inline script:** Direct script code
- **File reference:** `@/path/to/script.lua` or `@./relative/path.lua`
- **URL reference:** `@http://...` or `@https://...`
- **Escaped `@`:** `@@...` for inline scripts that start with a literal `@`
**The `transform` function:**
```lua
function transform(req)
-- req.method (string) - HTTP method (e.g. "GET", "POST")
-- req.path (string) - URL path (e.g. "/api/users")
-- req.body (string) - Request body
-- req.headers (table of string/arrays) - HTTP headers (e.g. {["X-Key"] = "value"})
-- req.params (table of string/arrays) - Query parameters (e.g. {["id"] = "123"})
-- req.cookies (table of string/arrays) - Cookies (e.g. {["session"] = "abc"})
req.headers["X-Custom"] = "my-value"
return req
end
```
> **Note:** Header, parameter, and cookie values can be a single string or a table (array) for multiple values per key (e.g. `{"val1", "val2"}`).
**YAML example:**
```yaml
lua: |
function transform(req)
req.headers["X-Custom"] = "my-value"
return req
end
# OR
lua:
- "@/path/to/script1.lua"
- "@/path/to/script2.lua"
```
**CLI example:**
```sh
-lua 'function transform(req) req.headers["X-Custom"] = "my-value" return req end'
# OR
-lua @/path/to/script1.lua -lua @/path/to/script2.lua
```
**ENV example:**
```sh
SARIN_LUA='function transform(req) req.headers["X-Custom"] = "my-value" return req end'
```
## Js
JavaScript script(s) for request transformation. Each script must define a global `transform` function that receives a request object and returns the modified request object. Scripts run after template rendering, before the request is sent.
If multiple JavaScript scripts are provided, they are chained in order—the output of one becomes the input to the next. When both Lua and JavaScript scripts are specified, all Lua scripts run first, then all JavaScript scripts.
**Script sources:**
Scripts can be provided as:
- **Inline script:** Direct script code
- **File reference:** `@/path/to/script.js` or `@./relative/path.js`
- **URL reference:** `@http://...` or `@https://...`
- **Escaped `@`:** `@@...` for inline scripts that start with a literal `@`
**The `transform` function:**
```javascript
function transform(req) {
// req.method (string) - HTTP method (e.g. "GET", "POST")
// req.path (string) - URL path (e.g. "/api/users")
// req.body (string) - Request body
// req.headers (object of string/arrays) - HTTP headers (e.g. {"X-Key": "value"})
// req.params (object of string/arrays) - Query parameters (e.g. {"id": "123"})
// req.cookies (object of string/arrays) - Cookies (e.g. {"session": "abc"})
req.headers["X-Custom"] = "my-value";
return req;
}
```
> **Note:** Header, parameter, and cookie values can be a single string or an array for multiple values per key (e.g. `["val1", "val2"]`).
**YAML example:**
```yaml
js: |
function transform(req) {
req.headers["X-Custom"] = "my-value";
return req;
}
# OR
js:
- "@/path/to/script1.js"
- "@/path/to/script2.js"
```
**CLI example:**
```sh
-js 'function transform(req) { req.headers["X-Custom"] = "my-value"; return req; }'
# OR
-js @/path/to/script1.js -js @/path/to/script2.js
```
**ENV example:**
```sh
SARIN_JS='function transform(req) { req.headers["X-Custom"] = "my-value"; return req; }'
```

View File

@@ -15,6 +15,7 @@ This guide provides practical examples for common Sarin use cases.
- [Docker Usage](#docker-usage)
- [Dry Run Mode](#dry-run-mode)
- [Show Configuration](#show-configuration)
- [Scripting](#scripting)
---
@@ -894,3 +895,124 @@ headers:
```
</details>
## Scripting
Transform requests using Lua or JavaScript scripts. Scripts run after template rendering, before the request is sent.
**Add a custom header with Lua:**
```sh
sarin -U http://example.com/api -r 1000 -c 10 \
-lua 'function transform(req) req.headers["X-Custom"] = "my-value" return req end'
```
<details>
<summary>YAML equivalent</summary>
```yaml
url: http://example.com/api
requests: 1000
concurrency: 10
lua: |
function transform(req)
req.headers["X-Custom"] = "my-value"
return req
end
```
</details>
**Modify request body with JavaScript:**
```sh
sarin -U http://example.com/api/data -r 1000 -c 10 \
-M POST \
-H "Content-Type: application/json" \
-B '{"name": "test"}' \
-js 'function transform(req) { var body = JSON.parse(req.body); body.timestamp = Date.now(); req.body = JSON.stringify(body); return req; }'
```
<details>
<summary>YAML equivalent</summary>
```yaml
url: http://example.com/api/data
requests: 1000
concurrency: 10
method: POST
headers:
Content-Type: application/json
body: '{"name": "test"}'
js: |
function transform(req) {
var body = JSON.parse(req.body);
body.timestamp = Date.now();
req.body = JSON.stringify(body);
return req;
}
```
</details>
**Load script from a file:**
```sh
sarin -U http://example.com/api -r 1000 -c 10 \
-lua @./scripts/transform.lua
```
<details>
<summary>YAML equivalent</summary>
```yaml
url: http://example.com/api
requests: 1000
concurrency: 10
lua: "@./scripts/transform.lua"
```
</details>
**Load script from a URL:**
```sh
sarin -U http://example.com/api -r 1000 -c 10 \
-js @https://example.com/scripts/transform.js
```
<details>
<summary>YAML equivalent</summary>
```yaml
url: http://example.com/api
requests: 1000
concurrency: 10
js: "@https://example.com/scripts/transform.js"
```
</details>
**Chain multiple scripts (Lua runs first, then JavaScript):**
```sh
sarin -U http://example.com/api -r 1000 -c 10 \
-lua @./scripts/auth.lua \
-lua @./scripts/headers.lua \
-js @./scripts/body.js
```
<details>
<summary>YAML equivalent</summary>
```yaml
url: http://example.com/api
requests: 1000
concurrency: 10
lua:
- "@./scripts/auth.lua"
- "@./scripts/headers.lua"
js: "@./scripts/body.js"
```
</details>

6
go.mod
View File

@@ -1,6 +1,6 @@
module go.aykhans.me/sarin
go 1.25.5
go 1.26.0
require (
github.com/brianvoe/gofakeit/v7 v7.14.0
@@ -9,8 +9,10 @@ require (
github.com/charmbracelet/glamour v0.10.0
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834
github.com/charmbracelet/x/term v0.2.2
github.com/dop251/goja v0.0.0-20260106131823-651366fbe6e3
github.com/joho/godotenv v1.5.1
github.com/valyala/fasthttp v1.69.0
github.com/yuin/gopher-lua v1.1.1
go.aykhans.me/utils v1.0.7
go.yaml.in/yaml/v4 v4.0.0-rc.3
golang.org/x/net v0.49.0
@@ -32,6 +34,8 @@ require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect
github.com/google/pprof v0.0.0-20230207041349-798e818bf904 // indirect
github.com/gorilla/css v1.0.1 // indirect
github.com/klauspost/compress v1.18.2 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect

12
go.sum
View File

@@ -1,3 +1,5 @@
github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0=
github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0=
github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
github.com/alecthomas/chroma/v2 v2.21.1 h1:FaSDrp6N+3pphkNKU6HPCiYLgm8dbe5UXIXcoBhZSWA=
@@ -46,8 +48,14 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dop251/goja v0.0.0-20260106131823-651366fbe6e3 h1:bVp3yUzvSAJzu9GqID+Z96P+eu5TKnIMJSV4QaZMauM=
github.com/dop251/goja v0.0.0-20260106131823-651366fbe6e3/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU=
github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg=
github.com/google/pprof v0.0.0-20230207041349-798e818bf904 h1:4/hN5RUoecvl+RmJRE2YxKWtnnQls6rQjjW5oV7qg2U=
github.com/google/pprof v0.0.0-20230207041349-798e818bf904/go.mod h1:uglQLonpP8qtYCYyzA+8c/9qtqgA3qsXGYqCPKARAFg=
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
@@ -95,6 +103,8 @@ github.com/yuin/goldmark v1.7.16 h1:n+CJdUxaFMiDUNnWC3dMWCIQJSkxH4uz3ZwQBkAlVNE=
github.com/yuin/goldmark v1.7.16/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
github.com/yuin/goldmark-emoji v1.0.6 h1:QWfF2FYaXwL74tfGOW5izeiZepUDroDJfWubQI9HTHs=
github.com/yuin/goldmark-emoji v1.0.6/go.mod h1:ukxJDKFpdFb5x0a5HqbdlcKtebh086iJpI31LTKmWuA=
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
go.aykhans.me/utils v1.0.7 h1:ClHXHlWmkjfFlD7+w5BQY29lKCEztxY/yCf543x4hZw=
go.aykhans.me/utils v1.0.7/go.mod h1:0Jz8GlZLN35cCHLOLx39sazWwEe33bF6SYlSeqzEXoI=
go.yaml.in/yaml/v4 v4.0.0-rc.3 h1:3h1fjsh1CTAPjW7q/EMe+C8shx5d8ctzZTrLcs/j8Go=
@@ -111,5 +121,7 @@ golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY=
golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -10,27 +10,13 @@ import (
"go.aykhans.me/sarin/internal/types"
versionpkg "go.aykhans.me/sarin/internal/version"
"go.aykhans.me/utils/common"
)
const cliUsageText = `Usage:
sarin [flags]
Simple usage:
sarin -U https://example.com -d 1m
Usage with all flags:
sarin -s -q -z -o json -f ./config.yaml -c 50 -r 100_000 -d 2m30s \
-U https://example.com \
-M POST \
-V "sharedUUID={{ fakeit_UUID }}" \
-B '{"product": "car"}' \
-P "id={{ .Values.sharedUUID }}" \
-H "User-Agent: {{ fakeit_UserAgent }}" -H "Accept: */*" \
-C "token={{ .Values.sharedUUID }}" \
-X "http://proxy.example.com" \
-T 3s \
-I
sarin -U https://example.com -r 1
Flags:
General Config:
@@ -55,7 +41,9 @@ Flags:
-X, -proxy []string Proxy for the request (e.g. "http://proxy.example.com:8080")
-V, -values []string List of values for templating (e.g. "key1=value1")
-T, -timeout time Timeout for the request (e.g. 400ms, 3s, 1m10s) (default %v)
-I, -insecure bool Skip SSL/TLS certificate verification (default %v)`
-I, -insecure bool Skip SSL/TLS certificate verification (default %v)
-lua []string Lua script for request transformation (inline or @file/@url)
-js []string JavaScript script for request transformation (inline or @file/@url)`
var _ IParser = ConfigCLIParser{}
@@ -106,16 +94,18 @@ func (parser ConfigCLIParser) Parse() (*Config, error) {
dryRun bool
// Request config
urlInput string
methods = stringSliceArg{}
bodies = stringSliceArg{}
params = stringSliceArg{}
headers = stringSliceArg{}
cookies = stringSliceArg{}
proxies = stringSliceArg{}
values = stringSliceArg{}
timeout time.Duration
insecure bool
urlInput string
methods = stringSliceArg{}
bodies = stringSliceArg{}
params = stringSliceArg{}
headers = stringSliceArg{}
cookies = stringSliceArg{}
proxies = stringSliceArg{}
values = stringSliceArg{}
timeout time.Duration
insecure bool
luaScripts = stringSliceArg{}
jsScripts = stringSliceArg{}
)
{
@@ -177,6 +167,10 @@ func (parser ConfigCLIParser) Parse() (*Config, error) {
flagSet.BoolVar(&insecure, "insecure", false, "Skip SSL/TLS certificate verification")
flagSet.BoolVar(&insecure, "I", false, "Skip SSL/TLS certificate verification")
flagSet.Var(&luaScripts, "lua", "Lua script for request transformation (inline or @file/@url)")
flagSet.Var(&jsScripts, "js", "JavaScript script for request transformation (inline or @file/@url)")
}
// Parse the specific arguments provided to the parser, skipping the program name.
@@ -207,23 +201,23 @@ func (parser ConfigCLIParser) Parse() (*Config, error) {
switch flagVar.Name {
// General config
case "show-config", "s":
config.ShowConfig = common.ToPtr(showConfig)
config.ShowConfig = new(showConfig)
case "config-file", "f":
for _, configFile := range configFiles {
config.Files = append(config.Files, *types.ParseConfigFile(configFile))
}
case "concurrency", "c":
config.Concurrency = common.ToPtr(concurrency)
config.Concurrency = new(concurrency)
case "requests", "r":
config.Requests = common.ToPtr(requestCount)
config.Requests = new(requestCount)
case "duration", "d":
config.Duration = common.ToPtr(duration)
config.Duration = new(duration)
case "quiet", "q":
config.Quiet = common.ToPtr(quiet)
config.Quiet = new(quiet)
case "output", "o":
config.Output = common.ToPtr(ConfigOutputType(output))
config.Output = new(ConfigOutputType(output))
case "dry-run", "z":
config.DryRun = common.ToPtr(dryRun)
config.DryRun = new(dryRun)
// Request config
case "url", "U":
@@ -256,9 +250,13 @@ func (parser ConfigCLIParser) Parse() (*Config, error) {
case "values", "V":
config.Values = append(config.Values, values...)
case "timeout", "T":
config.Timeout = common.ToPtr(timeout)
config.Timeout = new(timeout)
case "insecure", "I":
config.Insecure = common.ToPtr(insecure)
config.Insecure = new(insecure)
case "lua":
config.Lua = append(config.Lua, luaScripts...)
case "js":
config.Js = append(config.Js, jsScripts...)
}
})

View File

@@ -1,6 +1,7 @@
package config
import (
"context"
"errors"
"fmt"
"net/url"
@@ -16,6 +17,7 @@ import (
"github.com/charmbracelet/glamour/styles"
"github.com/charmbracelet/lipgloss"
"github.com/charmbracelet/x/term"
"go.aykhans.me/sarin/internal/script"
"go.aykhans.me/sarin/internal/types"
"go.aykhans.me/sarin/internal/version"
"go.aykhans.me/utils/common"
@@ -87,10 +89,8 @@ type Config struct {
Bodies []string `yaml:"bodies,omitempty"`
Proxies types.Proxies `yaml:"proxies,omitempty"`
Values []string `yaml:"values,omitempty"`
}
func NewConfig() *Config {
return &Config{}
Lua []string `yaml:"lua,omitempty"`
Js []string `yaml:"js,omitempty"`
}
func (config Config) MarshalYAML() (any, error) {
@@ -219,6 +219,8 @@ func (config Config) MarshalYAML() (any, error) {
}
addStringSlice(content, "values", config.Values, false)
addStringSlice(content, "lua", config.Lua, false)
addStringSlice(content, "js", config.Js, false)
return root, nil
}
@@ -323,6 +325,12 @@ func (config *Config) Merge(newConfig *Config) {
if len(newConfig.Values) != 0 {
config.Values = append(config.Values, newConfig.Values...)
}
if len(newConfig.Lua) != 0 {
config.Lua = append(config.Lua, newConfig.Lua...)
}
if len(newConfig.Js) != 0 {
config.Js = append(config.Js, newConfig.Js...)
}
}
func (config *Config) SetDefaults() {
@@ -348,26 +356,26 @@ func (config *Config) SetDefaults() {
config.Timeout = &Defaults.RequestTimeout
}
if config.Concurrency == nil {
config.Concurrency = common.ToPtr(Defaults.Concurrency)
config.Concurrency = new(Defaults.Concurrency)
}
if config.ShowConfig == nil {
config.ShowConfig = common.ToPtr(Defaults.ShowConfig)
config.ShowConfig = new(Defaults.ShowConfig)
}
if config.Quiet == nil {
config.Quiet = common.ToPtr(Defaults.Quiet)
config.Quiet = new(Defaults.Quiet)
}
if config.Insecure == nil {
config.Insecure = common.ToPtr(Defaults.Insecure)
config.Insecure = new(Defaults.Insecure)
}
if config.DryRun == nil {
config.DryRun = common.ToPtr(Defaults.DryRun)
config.DryRun = new(Defaults.DryRun)
}
if !config.Headers.Has("User-Agent") {
config.Headers = append(config.Headers, types.Header{Key: "User-Agent", Value: []string{Defaults.UserAgent}})
}
if config.Output == nil {
config.Output = common.ToPtr(Defaults.Output)
config.Output = new(Defaults.Output)
}
}
@@ -465,6 +473,44 @@ func (config Config) Validate() error {
}
}
// Create a context with timeout for script validation (loading from URLs)
scriptCtx, scriptCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer scriptCancel()
for i, scriptSrc := range config.Lua {
if err := validateScriptSource(scriptSrc); err != nil {
validationErrors = append(
validationErrors,
types.NewFieldValidationError(fmt.Sprintf("Lua[%d]", i), scriptSrc, err),
)
continue
}
// Validate script syntax
if err := script.ValidateScript(scriptCtx, scriptSrc, script.EngineTypeLua); err != nil {
validationErrors = append(
validationErrors,
types.NewFieldValidationError(fmt.Sprintf("Lua[%d]", i), scriptSrc, err),
)
}
}
for i, scriptSrc := range config.Js {
if err := validateScriptSource(scriptSrc); err != nil {
validationErrors = append(
validationErrors,
types.NewFieldValidationError(fmt.Sprintf("Js[%d]", i), scriptSrc, err),
)
continue
}
// Validate script syntax
if err := script.ValidateScript(scriptCtx, scriptSrc, script.EngineTypeJavaScript); err != nil {
validationErrors = append(
validationErrors,
types.NewFieldValidationError(fmt.Sprintf("Js[%d]", i), scriptSrc, err),
)
}
}
templateErrors := ValidateTemplates(&config)
validationErrors = append(validationErrors, templateErrors...)
@@ -582,6 +628,57 @@ func parseConfigFile(configFile types.ConfigFile, maxDepth int) (*Config, error)
return fileConfig, nil
}
// validateScriptSource validates a script source string.
// Scripts can be:
// - Inline script: any string not starting with "@"
// - Escaped "@": strings starting with "@@" (literal "@" at start)
// - File reference: "@/path/to/file" or "@./relative/path"
// - URL reference: "@http://..." or "@https://..."
//
// It can return the following errors:
// - types.ErrScriptEmpty
// - types.ErrScriptSourceEmpty
// - types.ErrScriptURLNoHost
// - types.URLParseError
func validateScriptSource(script string) error {
// Empty script is invalid
if script == "" {
return types.ErrScriptEmpty
}
// Not a file/URL reference - it's an inline script
if !strings.HasPrefix(script, "@") {
return nil
}
// Escaped @ - it's an inline script starting with literal @
if strings.HasPrefix(script, "@@") {
return nil
}
// It's a file or URL reference - validate the source
source := script[1:] // Remove the @ prefix
if source == "" {
return types.ErrScriptSourceEmpty
}
// Check if it's a http(s) URL
if strings.HasPrefix(source, "http://") || strings.HasPrefix(source, "https://") {
parsedURL, err := url.Parse(source)
if err != nil {
return types.NewURLParseError(source, err)
}
if parsedURL.Host == "" {
return types.ErrScriptURLNoHost
}
return nil
}
// It's a file path - basic validation (not empty, checked above)
return nil
}
func printParseErrors(parserName string, errors ...types.FieldParseError) {
for _, fieldErr := range errors {
if fieldErr.Value == "" {

View File

@@ -7,7 +7,6 @@ import (
"time"
"go.aykhans.me/sarin/internal/types"
"go.aykhans.me/utils/common"
utilsParse "go.aykhans.me/utils/parser"
)
@@ -67,7 +66,7 @@ func (parser ConfigENVParser) Parse() (*Config, error) {
}
if output := parser.getEnv("OUTPUT"); output != "" {
config.Output = common.ToPtr(ConfigOutputType(output))
config.Output = new(ConfigOutputType(output))
}
if insecure := parser.getEnv("INSECURE"); insecure != "" {
@@ -216,6 +215,14 @@ func (parser ConfigENVParser) Parse() (*Config, error) {
config.Values = []string{values}
}
if lua := parser.getEnv("LUA"); lua != "" {
config.Lua = []string{lua}
}
if js := parser.getEnv("JS"); js != "" {
config.Js = []string{js}
}
if len(fieldParseErrors) > 0 {
return nil, types.NewFieldParseErrors(fieldParseErrors)
}

View File

@@ -12,7 +12,6 @@ import (
"time"
"go.aykhans.me/sarin/internal/types"
"go.aykhans.me/utils/common"
"go.yaml.in/yaml/v4"
)
@@ -49,6 +48,10 @@ func (parser ConfigFileParser) Parse() (*Config, error) {
}
// fetchFile retrieves file contents from a local path or HTTP/HTTPS URL.
// It can return the following errors:
// - types.FileReadError
// - types.HTTPFetchError
// - types.HTTPStatusError
func fetchFile(ctx context.Context, src string) ([]byte, error) {
if strings.HasPrefix(src, "http://") || strings.HasPrefix(src, "https://") {
return fetchHTTP(ctx, src)
@@ -57,25 +60,28 @@ func fetchFile(ctx context.Context, src string) ([]byte, error) {
}
// fetchHTTP downloads file contents from an HTTP/HTTPS URL.
// It can return the following errors:
// - types.HTTPFetchError
// - types.HTTPStatusError
func fetchHTTP(ctx context.Context, url string) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
return nil, types.NewHTTPFetchError(url, err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to fetch file: %w", err)
return nil, types.NewHTTPFetchError(url, err)
}
defer resp.Body.Close() //nolint:errcheck
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to fetch file: HTTP %d %s", resp.StatusCode, resp.Status)
return nil, types.NewHTTPStatusError(url, resp.StatusCode, resp.Status)
}
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
return nil, types.NewHTTPFetchError(url, err)
}
return data, nil
@@ -83,19 +89,21 @@ func fetchHTTP(ctx context.Context, url string) ([]byte, error) {
// fetchLocal reads file contents from the local filesystem.
// It resolves relative paths from the current working directory.
// It can return the following errors:
// - types.FileReadError
func fetchLocal(src string) ([]byte, error) {
path := src
if !filepath.IsAbs(src) {
pwd, err := os.Getwd()
if err != nil {
return nil, fmt.Errorf("failed to get working directory: %w", err)
return nil, types.NewFileReadError(src, err)
}
path = filepath.Join(pwd, src)
}
data, err := os.ReadFile(path) //nolint:gosec
if err != nil {
return nil, fmt.Errorf("failed to read file: %w", err)
return nil, types.NewFileReadError(path, err)
}
return data, nil
@@ -202,6 +210,8 @@ type configYAML struct {
Bodies stringOrSliceField `yaml:"body"`
Proxies stringOrSliceField `yaml:"proxy"`
Values stringOrSliceField `yaml:"values"`
Lua stringOrSliceField `yaml:"lua"`
Js stringOrSliceField `yaml:"js"`
}
// ParseYAML parses YAML config file arguments into a Config object.
@@ -230,7 +240,7 @@ func (parser ConfigFileParser) ParseYAML(data []byte) (*Config, error) {
config.Quiet = parsedData.Quiet
if parsedData.Output != nil {
config.Output = common.ToPtr(ConfigOutputType(*parsedData.Output))
config.Output = new(ConfigOutputType(*parsedData.Output))
}
config.Insecure = parsedData.Insecure
@@ -246,6 +256,8 @@ func (parser ConfigFileParser) ParseYAML(data []byte) (*Config, error) {
}
config.Bodies = append(config.Bodies, parsedData.Bodies...)
config.Values = append(config.Values, parsedData.Values...)
config.Lua = append(config.Lua, parsedData.Lua...)
config.Js = append(config.Js, parsedData.Js...)
if len(parsedData.ConfigFiles) > 0 {
for _, configFile := range parsedData.ConfigFiles {

View File

@@ -8,6 +8,8 @@ import (
"go.aykhans.me/sarin/internal/types"
)
// It can return the following errors:
// - types.TemplateParseError
func validateTemplateString(value string, funcMap template.FuncMap) error {
if value == "" {
return nil
@@ -15,7 +17,7 @@ func validateTemplateString(value string, funcMap template.FuncMap) error {
_, err := template.New("").Funcs(funcMap).Parse(value)
if err != nil {
return fmt.Errorf("template parse error: %w", err)
return types.NewTemplateParseError(err)
}
return nil

View File

@@ -5,7 +5,6 @@ import (
"context"
"crypto/tls"
"encoding/base64"
"errors"
"math"
"net"
"net/http"
@@ -95,6 +94,9 @@ func NewHostClients(
return []*fasthttp.HostClient{client}, nil
}
// NewProxyDialFunc creates a dial function for the given proxy URL.
// It can return the following errors:
// - types.ProxyUnsupportedSchemeError
func NewProxyDialFunc(ctx context.Context, proxyURL *url.URL, timeout time.Duration) (fasthttp.DialFunc, error) {
var (
dialer fasthttp.DialFunc
@@ -117,16 +119,14 @@ func NewProxyDialFunc(ctx context.Context, proxyURL *url.URL, timeout time.Durat
case "https":
dialer = fasthttpHTTPSDialerDualStackTimeout(proxyURL, timeout)
default:
return nil, errors.New("unsupported proxy scheme")
}
if dialer == nil {
return nil, errors.New("internal error: proxy dialer is nil")
return nil, types.NewProxyUnsupportedSchemeError(proxyURL.Scheme)
}
return dialer, nil
}
// The returned dial function can return the following errors:
// - types.ProxyDialError
func fasthttpSocksDialerDualStackTimeout(ctx context.Context, proxyURL *url.URL, timeout time.Duration, resolveLocally bool) (fasthttp.DialFunc, error) {
netDialer := &net.Dialer{}
@@ -147,12 +147,18 @@ func fasthttpSocksDialerDualStackTimeout(ctx context.Context, proxyURL *url.URL,
return nil, err
}
proxyStr := proxyURL.String()
// Assert to ContextDialer for timeout support
contextDialer, ok := socksDialer.(proxy.ContextDialer)
if !ok {
// Fallback without timeout (should not happen with net.Dialer)
return func(addr string) (net.Conn, error) {
return socksDialer.Dial("tcp", addr)
conn, err := socksDialer.Dial("tcp", addr)
if err != nil {
return nil, types.NewProxyDialError(proxyStr, err)
}
return conn, nil
}, nil
}
@@ -163,7 +169,7 @@ func fasthttpSocksDialerDualStackTimeout(ctx context.Context, proxyURL *url.URL,
if resolveLocally {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
return nil, types.NewProxyDialError(proxyStr, err)
}
// Cap DNS resolution to half the timeout to reserve time for dial
@@ -171,10 +177,10 @@ func fasthttpSocksDialerDualStackTimeout(ctx context.Context, proxyURL *url.URL,
ips, err := net.DefaultResolver.LookupIP(dnsCtx, "ip", host)
dnsCancel()
if err != nil {
return nil, err
return nil, types.NewProxyDialError(proxyStr, err)
}
if len(ips) == 0 {
return nil, errors.New("no IP addresses found for host: " + host)
return nil, types.NewProxyDialError(proxyStr, types.NewProxyResolveError(host))
}
// Use the first resolved IP
@@ -184,16 +190,22 @@ func fasthttpSocksDialerDualStackTimeout(ctx context.Context, proxyURL *url.URL,
// Use remaining time for dial
remaining := time.Until(deadline)
if remaining <= 0 {
return nil, context.DeadlineExceeded
return nil, types.NewProxyDialError(proxyStr, context.DeadlineExceeded)
}
dialCtx, dialCancel := context.WithTimeout(ctx, remaining)
defer dialCancel()
return contextDialer.DialContext(dialCtx, "tcp", addr)
conn, err := contextDialer.DialContext(dialCtx, "tcp", addr)
if err != nil {
return nil, types.NewProxyDialError(proxyStr, err)
}
return conn, nil
}, nil
}
// The returned dial function can return the following errors:
// - types.ProxyDialError
func fasthttpHTTPSDialerDualStackTimeout(proxyURL *url.URL, timeout time.Duration) fasthttp.DialFunc {
proxyAddr := proxyURL.Host
if proxyURL.Port() == "" {
@@ -209,24 +221,26 @@ func fasthttpHTTPSDialerDualStackTimeout(proxyURL *url.URL, timeout time.Duratio
proxyAuth = "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
}
proxyStr := proxyURL.String()
return func(addr string) (net.Conn, error) {
// Establish TCP connection to proxy with timeout
start := time.Now()
conn, err := fasthttp.DialDualStackTimeout(proxyAddr, timeout)
if err != nil {
return nil, err
return nil, types.NewProxyDialError(proxyStr, err)
}
remaining := timeout - time.Since(start)
if remaining <= 0 {
conn.Close() //nolint:errcheck,gosec
return nil, context.DeadlineExceeded
return nil, types.NewProxyDialError(proxyStr, context.DeadlineExceeded)
}
// Set deadline for the TLS handshake and CONNECT request
if err := conn.SetDeadline(time.Now().Add(remaining)); err != nil {
conn.Close() //nolint:errcheck,gosec
return nil, err
return nil, types.NewProxyDialError(proxyStr, err)
}
// Upgrade to TLS
@@ -235,7 +249,7 @@ func fasthttpHTTPSDialerDualStackTimeout(proxyURL *url.URL, timeout time.Duratio
})
if err := tlsConn.Handshake(); err != nil {
tlsConn.Close() //nolint:errcheck,gosec
return nil, err
return nil, types.NewProxyDialError(proxyStr, err)
}
// Build and send CONNECT request
@@ -251,7 +265,7 @@ func fasthttpHTTPSDialerDualStackTimeout(proxyURL *url.URL, timeout time.Duratio
if err := connectReq.Write(tlsConn); err != nil {
tlsConn.Close() //nolint:errcheck,gosec
return nil, err
return nil, types.NewProxyDialError(proxyStr, err)
}
// Read response using buffered reader, but return wrapped connection
@@ -260,19 +274,19 @@ func fasthttpHTTPSDialerDualStackTimeout(proxyURL *url.URL, timeout time.Duratio
resp, err := http.ReadResponse(bufReader, connectReq)
if err != nil {
tlsConn.Close() //nolint:errcheck,gosec
return nil, err
return nil, types.NewProxyDialError(proxyStr, err)
}
resp.Body.Close() //nolint:errcheck,gosec
if resp.StatusCode != http.StatusOK {
tlsConn.Close() //nolint:errcheck,gosec
return nil, errors.New("proxy CONNECT failed: " + resp.Status)
return nil, types.NewProxyDialError(proxyStr, types.NewProxyConnectError(resp.Status))
}
// Clear deadline for the tunneled connection
if err := tlsConn.SetDeadline(time.Time{}); err != nil {
tlsConn.Close() //nolint:errcheck,gosec
return nil, err
return nil, types.NewProxyDialError(proxyStr, err)
}
// Return wrapped connection that uses the buffered reader

View File

@@ -1,7 +1,6 @@
package sarin
import (
"fmt"
"io"
"net/http"
"os"
@@ -10,6 +9,8 @@ import (
"strings"
"sync"
"time"
"go.aykhans.me/sarin/internal/types"
)
// CachedFile holds the cached content and metadata of a file.
@@ -31,6 +32,10 @@ func NewFileCache(requestTimeout time.Duration) *FileCache {
// GetOrLoad retrieves a file from cache or loads it using the provided source.
// The source can be a local file path or an HTTP/HTTPS URL.
// It can return the following errors:
// - types.FileReadError
// - types.HTTPFetchError
// - types.HTTPStatusError
func (fc *FileCache) GetOrLoad(source string) (*CachedFile, error) {
if val, ok := fc.cache.Load(source); ok {
return val.(*CachedFile), nil
@@ -59,14 +64,21 @@ func (fc *FileCache) GetOrLoad(source string) (*CachedFile, error) {
return actual.(*CachedFile), nil
}
// readLocalFile reads a file from the local filesystem and returns its content and filename.
// It can return the following errors:
// - types.FileReadError
func (fc *FileCache) readLocalFile(filePath string) ([]byte, string, error) {
content, err := os.ReadFile(filePath) //nolint:gosec
if err != nil {
return nil, "", fmt.Errorf("failed to read file %s: %w", filePath, err)
return nil, "", types.NewFileReadError(filePath, err)
}
return content, filepath.Base(filePath), nil
}
// fetchURL downloads file contents from an HTTP/HTTPS URL.
// It can return the following errors:
// - types.HTTPFetchError
// - types.HTTPStatusError
func (fc *FileCache) fetchURL(url string) ([]byte, string, error) {
client := &http.Client{
Timeout: fc.requestTimeout,
@@ -74,17 +86,17 @@ func (fc *FileCache) fetchURL(url string) ([]byte, string, error) {
resp, err := client.Get(url)
if err != nil {
return nil, "", fmt.Errorf("failed to fetch URL %s: %w", url, err)
return nil, "", types.NewHTTPFetchError(url, err)
}
defer resp.Body.Close() //nolint:errcheck
if resp.StatusCode != http.StatusOK {
return nil, "", fmt.Errorf("failed to fetch URL %s: HTTP %d", url, resp.StatusCode)
return nil, "", types.NewHTTPStatusError(url, resp.StatusCode, resp.Status)
}
content, err := io.ReadAll(resp.Body)
if err != nil {
return nil, "", fmt.Errorf("failed to read response body from %s: %w", url, err)
return nil, "", types.NewHTTPFetchError(url, err)
}
// Extract filename from URL path

View File

@@ -2,7 +2,6 @@ package sarin
import (
"bytes"
"fmt"
"maps"
"math/rand/v2"
"net/url"
@@ -11,13 +10,14 @@ import (
"github.com/joho/godotenv"
"github.com/valyala/fasthttp"
"go.aykhans.me/sarin/internal/script"
"go.aykhans.me/sarin/internal/types"
utilsSlice "go.aykhans.me/utils/slice"
)
type RequestGenerator func(*fasthttp.Request) error
type RequestGeneratorWithData func(*fasthttp.Request, any) error
type requestDataGenerator func(*script.RequestData, any) error
type valuesData struct {
Values map[string]string
@@ -26,6 +26,9 @@ type valuesData struct {
// NewRequestGenerator creates a new RequestGenerator function that generates HTTP requests
// with the specified configuration. The returned RequestGenerator is NOT safe for concurrent
// use by multiple goroutines.
//
// Note: Scripts must be validated before calling this function (e.g., in NewSarin).
// The caller is responsible for managing the scriptTransformer lifecycle.
func NewRequestGenerator(
methods []string,
requestURL *url.URL,
@@ -35,6 +38,7 @@ func NewRequestGenerator(
bodies []string,
values []string,
fileCache *FileCache,
scriptTransformer *script.Transformer,
) (RequestGenerator, bool) {
randSource := NewDefaultRandSource()
//nolint:gosec // G404: Using non-cryptographic rand for load testing, not security
@@ -53,13 +57,24 @@ func NewRequestGenerator(
valuesGenerator := NewValuesGeneratorFunc(values, templateFuncMap)
hasScripts := scriptTransformer != nil && !scriptTransformer.IsEmpty()
host := requestURL.Host
scheme := requestURL.Scheme
reqData := &script.RequestData{
Headers: make(map[string][]string),
Params: make(map[string][]string),
Cookies: make(map[string][]string),
}
var (
data valuesData
path string
err error
)
return func(req *fasthttp.Request) error {
req.Header.SetHost(requestURL.Host)
resetRequestData(reqData)
data, err = valuesGenerator()
if err != nil {
@@ -70,87 +85,135 @@ func NewRequestGenerator(
if err != nil {
return err
}
req.SetRequestURI(path)
reqData.Path = path
if err = methodGenerator(req, data); err != nil {
if err = methodGenerator(reqData, data); err != nil {
return err
}
bodyTemplateFuncMapData.ClearFormDataContenType()
if err = bodyGenerator(req, data); err != nil {
if err = bodyGenerator(reqData, data); err != nil {
return err
}
if err = headersGenerator(req, data); err != nil {
if err = headersGenerator(reqData, data); err != nil {
return err
}
if bodyTemplateFuncMapData.GetFormDataContenType() != "" {
req.Header.Add("Content-Type", bodyTemplateFuncMapData.GetFormDataContenType())
reqData.Headers["Content-Type"] = append(reqData.Headers["Content-Type"], bodyTemplateFuncMapData.GetFormDataContenType())
}
if err = paramsGenerator(req, data); err != nil {
if err = paramsGenerator(reqData, data); err != nil {
return err
}
if err = cookiesGenerator(req, data); err != nil {
if err = cookiesGenerator(reqData, data); err != nil {
return err
}
if requestURL.Scheme == "https" {
req.URI().SetScheme("https")
if hasScripts {
if err = scriptTransformer.Transform(reqData); err != nil {
return err
}
}
applyRequestDataToFastHTTP(reqData, req, host, scheme)
return nil
}, isPathGeneratorDynamic ||
isMethodGeneratorDynamic ||
isParamsGeneratorDynamic ||
isHeadersGeneratorDynamic ||
isCookiesGeneratorDynamic ||
isBodyGeneratorDynamic
isBodyGeneratorDynamic ||
hasScripts
}
func NewMethodGeneratorFunc(localRand *rand.Rand, methods []string, templateFunctions template.FuncMap) (RequestGeneratorWithData, bool) {
func resetRequestData(reqData *script.RequestData) {
reqData.Method = ""
reqData.Path = ""
reqData.Body = ""
clear(reqData.Headers)
clear(reqData.Params)
clear(reqData.Cookies)
}
func applyRequestDataToFastHTTP(reqData *script.RequestData, req *fasthttp.Request, host, scheme string) {
req.Header.SetHost(host)
req.SetRequestURI(reqData.Path)
req.Header.SetMethod(reqData.Method)
req.SetBody([]byte(reqData.Body))
for k, values := range reqData.Headers {
for _, v := range values {
req.Header.Add(k, v)
}
}
for k, values := range reqData.Params {
for _, v := range values {
req.URI().QueryArgs().Add(k, v)
}
}
if len(reqData.Cookies) > 0 {
cookieStrings := make([]string, 0, len(reqData.Cookies))
for k, values := range reqData.Cookies {
for _, v := range values {
cookieStrings = append(cookieStrings, k+"="+v)
}
}
req.Header.Add("Cookie", strings.Join(cookieStrings, "; "))
}
if scheme == "https" {
req.URI().SetScheme("https")
}
}
func NewMethodGeneratorFunc(localRand *rand.Rand, methods []string, templateFunctions template.FuncMap) (requestDataGenerator, bool) {
methodGenerator, isDynamic := buildStringSliceGenerator(localRand, methods, templateFunctions)
var (
method string
err error
)
return func(req *fasthttp.Request, data any) error {
return func(reqData *script.RequestData, data any) error {
method, err = methodGenerator()(data)
if err != nil {
return err
}
req.Header.SetMethod(method)
reqData.Method = method
return nil
}, isDynamic
}
func NewBodyGeneratorFunc(localRand *rand.Rand, bodies []string, templateFunctions template.FuncMap) (RequestGeneratorWithData, bool) {
func NewBodyGeneratorFunc(localRand *rand.Rand, bodies []string, templateFunctions template.FuncMap) (requestDataGenerator, bool) {
bodyGenerator, isDynamic := buildStringSliceGenerator(localRand, bodies, templateFunctions)
var (
body string
err error
)
return func(req *fasthttp.Request, data any) error {
return func(reqData *script.RequestData, data any) error {
body, err = bodyGenerator()(data)
if err != nil {
return err
}
req.SetBody([]byte(body))
reqData.Body = body
return nil
}, isDynamic
}
func NewParamsGeneratorFunc(localRand *rand.Rand, params types.Params, templateFunctions template.FuncMap) (RequestGeneratorWithData, bool) {
func NewParamsGeneratorFunc(localRand *rand.Rand, params types.Params, templateFunctions template.FuncMap) (requestDataGenerator, bool) {
generators, isDynamic := buildKeyValueGenerators(localRand, params, templateFunctions)
var (
key, value string
err error
)
return func(req *fasthttp.Request, data any) error {
return func(reqData *script.RequestData, data any) error {
for _, gen := range generators {
key, err = gen.Key(data)
if err != nil {
@@ -162,20 +225,20 @@ func NewParamsGeneratorFunc(localRand *rand.Rand, params types.Params, templateF
return err
}
req.URI().QueryArgs().Add(key, value)
reqData.Params[key] = append(reqData.Params[key], value)
}
return nil
}, isDynamic
}
func NewHeadersGeneratorFunc(localRand *rand.Rand, headers types.Headers, templateFunctions template.FuncMap) (RequestGeneratorWithData, bool) {
func NewHeadersGeneratorFunc(localRand *rand.Rand, headers types.Headers, templateFunctions template.FuncMap) (requestDataGenerator, bool) {
generators, isDynamic := buildKeyValueGenerators(localRand, headers, templateFunctions)
var (
key, value string
err error
)
return func(req *fasthttp.Request, data any) error {
return func(reqData *script.RequestData, data any) error {
for _, gen := range generators {
key, err = gen.Key(data)
if err != nil {
@@ -187,41 +250,33 @@ func NewHeadersGeneratorFunc(localRand *rand.Rand, headers types.Headers, templa
return err
}
req.Header.Add(key, value)
reqData.Headers[key] = append(reqData.Headers[key], value)
}
return nil
}, isDynamic
}
func NewCookiesGeneratorFunc(localRand *rand.Rand, cookies types.Cookies, templateFunctions template.FuncMap) (RequestGeneratorWithData, bool) {
func NewCookiesGeneratorFunc(localRand *rand.Rand, cookies types.Cookies, templateFunctions template.FuncMap) (requestDataGenerator, bool) {
generators, isDynamic := buildKeyValueGenerators(localRand, cookies, templateFunctions)
var (
key, value string
err error
)
if len(generators) > 0 {
return func(req *fasthttp.Request, data any) error {
cookieStrings := make([]string, 0, len(generators))
for _, gen := range generators {
key, err = gen.Key(data)
if err != nil {
return err
}
value, err = gen.Value()(data)
if err != nil {
return err
}
cookieStrings = append(cookieStrings, key+"="+value)
return func(reqData *script.RequestData, data any) error {
for _, gen := range generators {
key, err = gen.Key(data)
if err != nil {
return err
}
req.Header.Add("Cookie", strings.Join(cookieStrings, "; "))
return nil
}, isDynamic
}
return func(req *fasthttp.Request, data any) error {
value, err = gen.Value()(data)
if err != nil {
return err
}
reqData.Cookies[key] = append(reqData.Cookies[key], value)
}
return nil
}, isDynamic
}
@@ -243,12 +298,12 @@ func NewValuesGeneratorFunc(values []string, templateFunctions template.FuncMap)
for _, generator := range generators {
rendered, err = generator(nil)
if err != nil {
return valuesData{}, fmt.Errorf("values rendering: %w", err)
return valuesData{}, types.NewTemplateRenderError(err)
}
data, err = godotenv.Unmarshal(rendered)
if err != nil {
return valuesData{}, fmt.Errorf("values rendering: %w", err)
return valuesData{}, types.NewTemplateRenderError(err)
}
maps.Copy(result, data)
@@ -265,7 +320,7 @@ func createTemplateFunc(value string, templateFunctions template.FuncMap) (func(
return func(data any) (string, error) {
var buf bytes.Buffer
if err = tmpl.Execute(&buf, data); err != nil {
return "", fmt.Errorf("template rendering: %w", err)
return "", types.NewTemplateRenderError(err)
}
return buf.String(), nil
}, true

View File

@@ -14,6 +14,7 @@ import (
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/valyala/fasthttp"
"go.aykhans.me/sarin/internal/script"
"go.aykhans.me/sarin/internal/types"
)
@@ -52,11 +53,14 @@ type sarin struct {
hostClients []*fasthttp.HostClient
responses *SarinResponseData
fileCache *FileCache
scriptChain *script.Chain
}
// NewSarin creates a new sarin instance for load testing.
// It can return the following errors:
// - types.ProxyDialError
// - types.ProxyDialError
// - types.ErrScriptEmpty
// - types.ScriptLoadError
func NewSarin(
ctx context.Context,
methods []string,
@@ -75,6 +79,8 @@ func NewSarin(
values []string,
collectStats bool,
dryRun bool,
luaScripts []string,
jsScripts []string,
) (*sarin, error) {
if workers == 0 {
workers = 1
@@ -85,6 +91,19 @@ func NewSarin(
return nil, err
}
// Load script sources
luaSources, err := script.LoadSources(ctx, luaScripts, script.EngineTypeLua)
if err != nil {
return nil, err
}
jsSources, err := script.LoadSources(ctx, jsScripts, script.EngineTypeJavaScript)
if err != nil {
return nil, err
}
scriptChain := script.NewChain(luaSources, jsSources)
srn := &sarin{
workers: workers,
requestURL: requestURL,
@@ -103,6 +122,7 @@ func NewSarin(
dryRun: dryRun,
hostClients: hostClients,
fileCache: NewFileCache(time.Second * 10),
scriptChain: scriptChain,
}
if collectStats {
@@ -193,7 +213,21 @@ func (q sarin) Worker(
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
requestGenerator, isDynamic := NewRequestGenerator(q.methods, q.requestURL, q.params, q.headers, q.cookies, q.bodies, q.values, q.fileCache)
// Create script transformer for this worker (engines are not thread-safe)
// Scripts are pre-validated in NewSarin, so this should not fail
var scriptTransformer *script.Transformer
if !q.scriptChain.IsEmpty() {
var err error
scriptTransformer, err = q.scriptChain.NewTransformer()
if err != nil {
panic(err)
}
defer scriptTransformer.Close()
}
requestGenerator, isDynamic := NewRequestGenerator(
q.methods, q.requestURL, q.params, q.headers, q.cookies, q.bodies, q.values, q.fileCache, scriptTransformer,
)
if q.dryRun {
switch {

View File

@@ -3,7 +3,6 @@ package sarin
import (
"bytes"
"encoding/base64"
"errors"
"math/rand/v2"
"mime/multipart"
"strings"
@@ -12,6 +11,7 @@ import (
"time"
"github.com/brianvoe/gofakeit/v7"
"go.aykhans.me/sarin/internal/types"
)
func NewDefaultTemplateFuncMap(randSource rand.Source, fileCache *FileCache) template.FuncMap {
@@ -90,7 +90,7 @@ func NewDefaultTemplateFuncMap(randSource rand.Source, fileCache *FileCache) tem
// {{ file_Base64 "https://example.com/image.png" }}
"file_Base64": func(source string) (string, error) {
if fileCache == nil {
return "", errors.New("file cache is not initialized")
return "", types.ErrFileCacheNotInitialized
}
cached, err := fileCache.GetOrLoad(source)
if err != nil {
@@ -582,7 +582,7 @@ func NewDefaultBodyTemplateFuncMap(
// {{ body_FormData "name" "John" "avatar" "@/path/to/photo.jpg" "doc" "@https://example.com/file.pdf" }}
funcMap["body_FormData"] = func(pairs ...string) (string, error) {
if len(pairs)%2 != 0 {
return "", errors.New("body_FormData requires an even number of arguments (key-value pairs)")
return "", types.ErrFormDataOddArgs
}
var multipartData bytes.Buffer
@@ -602,7 +602,7 @@ func NewDefaultBodyTemplateFuncMap(
case strings.HasPrefix(val, "@"):
// File (local path or remote URL)
if fileCache == nil {
return "", errors.New("file cache is not initialized")
return "", types.ErrFileCacheNotInitialized
}
source := val[1:]
cached, err := fileCache.GetOrLoad(source)

107
internal/script/chain.go Normal file
View File

@@ -0,0 +1,107 @@
package script
import (
"go.aykhans.me/sarin/internal/types"
)
// Chain holds the loaded script sources and can create engine instances.
// The sources are loaded once, but engines are created per-worker since they're not thread-safe.
type Chain struct {
luaSources []*Source
jsSources []*Source
}
// NewChain creates a new script chain from loaded sources.
// Lua scripts run first, then JavaScript scripts, in the order provided.
func NewChain(luaSources, jsSources []*Source) *Chain {
return &Chain{
luaSources: luaSources,
jsSources: jsSources,
}
}
// IsEmpty returns true if there are no scripts to execute.
func (c *Chain) IsEmpty() bool {
return len(c.luaSources) == 0 && len(c.jsSources) == 0
}
// Transformer holds instantiated script engines for a single worker.
// It is NOT safe for concurrent use.
type Transformer struct {
luaEngines []*LuaEngine
jsEngines []*JsEngine
}
// NewTransformer creates engine instances from the chain's sources.
// Call this once per worker goroutine.
// It can return the following errors:
// - types.ScriptChainError
func (c *Chain) NewTransformer() (*Transformer, error) {
if c.IsEmpty() {
return &Transformer{}, nil
}
t := &Transformer{
luaEngines: make([]*LuaEngine, 0, len(c.luaSources)),
jsEngines: make([]*JsEngine, 0, len(c.jsSources)),
}
// Create Lua engines
for i, src := range c.luaSources {
engine, err := NewLuaEngine(src.Content)
if err != nil {
t.Close() // Clean up already created engines
return nil, types.NewScriptChainError("lua", i, err)
}
t.luaEngines = append(t.luaEngines, engine)
}
// Create JS engines
for i, src := range c.jsSources {
engine, err := NewJsEngine(src.Content)
if err != nil {
t.Close() // Clean up already created engines
return nil, types.NewScriptChainError("js", i, err)
}
t.jsEngines = append(t.jsEngines, engine)
}
return t, nil
}
// Transform applies all scripts to the request data.
// Lua scripts run first, then JavaScript scripts.
// It can return the following errors:
// - types.ScriptChainError
func (t *Transformer) Transform(req *RequestData) error {
// Run Lua scripts
for i, engine := range t.luaEngines {
if err := engine.Transform(req); err != nil {
return types.NewScriptChainError("lua", i, err)
}
}
// Run JS scripts
for i, engine := range t.jsEngines {
if err := engine.Transform(req); err != nil {
return types.NewScriptChainError("js", i, err)
}
}
return nil
}
// Close releases all engine resources.
func (t *Transformer) Close() {
for _, engine := range t.luaEngines {
engine.Close()
}
for _, engine := range t.jsEngines {
engine.Close()
}
}
// IsEmpty returns true if there are no engines.
func (t *Transformer) IsEmpty() bool {
return len(t.luaEngines) == 0 && len(t.jsEngines) == 0
}

198
internal/script/js.go Normal file
View File

@@ -0,0 +1,198 @@
package script
import (
"errors"
"github.com/dop251/goja"
"go.aykhans.me/sarin/internal/types"
)
// JsEngine implements the Engine interface using goja (JavaScript).
type JsEngine struct {
runtime *goja.Runtime
transform goja.Callable
}
// NewJsEngine creates a new JavaScript script engine with the given script content.
// The script must define a global `transform` function that takes a request object
// and returns the modified request object.
//
// Example JavaScript script:
//
// function transform(req) {
// req.headers["X-Custom"] = ["value"];
// return req;
// }
//
// It can return the following errors:
// - types.ErrScriptTransformMissing
// - types.ScriptExecutionError
func NewJsEngine(scriptContent string) (*JsEngine, error) {
vm := goja.New()
// Execute the script to define the transform function
_, err := vm.RunString(scriptContent)
if err != nil {
return nil, types.NewScriptExecutionError("JavaScript", err)
}
// Get the transform function
transformVal := vm.Get("transform")
if transformVal == nil || goja.IsUndefined(transformVal) || goja.IsNull(transformVal) {
return nil, types.ErrScriptTransformMissing
}
transform, ok := goja.AssertFunction(transformVal)
if !ok {
return nil, types.NewScriptExecutionError("JavaScript", errors.New("'transform' must be a function"))
}
return &JsEngine{
runtime: vm,
transform: transform,
}, nil
}
// Transform executes the JavaScript transform function with the given request data.
// It can return the following errors:
// - types.ScriptExecutionError
func (e *JsEngine) Transform(req *RequestData) error {
// Convert RequestData to JavaScript object
reqObj := e.requestDataToObject(req)
// Call transform(req)
result, err := e.transform(goja.Undefined(), reqObj)
if err != nil {
return types.NewScriptExecutionError("JavaScript", err)
}
// Update RequestData from the returned object
if err := e.objectToRequestData(result, req); err != nil {
return types.NewScriptExecutionError("JavaScript", err)
}
return nil
}
// Close releases the JavaScript runtime resources.
func (e *JsEngine) Close() {
// goja doesn't have an explicit close method, but we can help GC
e.runtime = nil
e.transform = nil
}
// requestDataToObject converts RequestData to a goja Value (JavaScript object).
func (e *JsEngine) requestDataToObject(req *RequestData) goja.Value {
obj := e.runtime.NewObject()
_ = obj.Set("method", req.Method)
_ = obj.Set("path", req.Path)
_ = obj.Set("body", req.Body)
// Headers (map[string][]string -> object of arrays)
headers := e.runtime.NewObject()
for k, values := range req.Headers {
_ = headers.Set(k, e.stringSliceToArray(values))
}
_ = obj.Set("headers", headers)
// Params (map[string][]string -> object of arrays)
params := e.runtime.NewObject()
for k, values := range req.Params {
_ = params.Set(k, e.stringSliceToArray(values))
}
_ = obj.Set("params", params)
// Cookies (map[string][]string -> object of arrays)
cookies := e.runtime.NewObject()
for k, values := range req.Cookies {
_ = cookies.Set(k, e.stringSliceToArray(values))
}
_ = obj.Set("cookies", cookies)
return obj
}
// objectToRequestData updates RequestData from a JavaScript object.
func (e *JsEngine) objectToRequestData(val goja.Value, req *RequestData) error {
if val == nil || goja.IsUndefined(val) || goja.IsNull(val) {
return types.ErrScriptTransformReturnObject
}
obj := val.ToObject(e.runtime)
if obj == nil {
return types.ErrScriptTransformReturnObject
}
// Method
if v := obj.Get("method"); v != nil && !goja.IsUndefined(v) {
req.Method = v.String()
}
// Path
if v := obj.Get("path"); v != nil && !goja.IsUndefined(v) {
req.Path = v.String()
}
// Body
if v := obj.Get("body"); v != nil && !goja.IsUndefined(v) {
req.Body = v.String()
}
// Headers
if v := obj.Get("headers"); v != nil && !goja.IsUndefined(v) && !goja.IsNull(v) {
req.Headers = e.objectToStringSliceMap(v.ToObject(e.runtime))
}
// Params
if v := obj.Get("params"); v != nil && !goja.IsUndefined(v) && !goja.IsNull(v) {
req.Params = e.objectToStringSliceMap(v.ToObject(e.runtime))
}
// Cookies
if v := obj.Get("cookies"); v != nil && !goja.IsUndefined(v) && !goja.IsNull(v) {
req.Cookies = e.objectToStringSliceMap(v.ToObject(e.runtime))
}
return nil
}
// stringSliceToArray converts a Go []string to a JavaScript array.
func (e *JsEngine) stringSliceToArray(values []string) *goja.Object {
ifaces := make([]any, len(values))
for i, v := range values {
ifaces[i] = v
}
return e.runtime.NewArray(ifaces...)
}
// objectToStringSliceMap converts a JavaScript object to a Go map[string][]string.
// Supports both single string values and array values.
func (e *JsEngine) objectToStringSliceMap(obj *goja.Object) map[string][]string {
if obj == nil {
return make(map[string][]string)
}
result := make(map[string][]string)
for _, key := range obj.Keys() {
v := obj.Get(key)
if v == nil || goja.IsUndefined(v) || goja.IsNull(v) {
continue
}
// Check if it's an array
if arr, ok := v.Export().([]any); ok {
var values []string
for _, item := range arr {
if s, ok := item.(string); ok {
values = append(values, s)
}
}
result[key] = values
} else {
// Single value - wrap in slice
result[key] = []string{v.String()}
}
}
return result
}

191
internal/script/lua.go Normal file
View File

@@ -0,0 +1,191 @@
package script
import (
"fmt"
lua "github.com/yuin/gopher-lua"
"go.aykhans.me/sarin/internal/types"
)
// LuaEngine implements the Engine interface using gopher-lua.
type LuaEngine struct {
state *lua.LState
transform *lua.LFunction
}
// NewLuaEngine creates a new Lua script engine with the given script content.
// The script must define a global `transform` function that takes a request table
// and returns the modified request table.
//
// Example Lua script:
//
// function transform(req)
// req.headers["X-Custom"] = {"value"}
// return req
// end
//
// It can return the following errors:
// - types.ErrScriptTransformMissing
// - types.ScriptExecutionError
func NewLuaEngine(scriptContent string) (*LuaEngine, error) {
L := lua.NewState()
// Execute the script to define the transform function
if err := L.DoString(scriptContent); err != nil {
L.Close()
return nil, types.NewScriptExecutionError("Lua", err)
}
// Get the transform function
transform := L.GetGlobal("transform")
if transform.Type() != lua.LTFunction {
L.Close()
return nil, types.ErrScriptTransformMissing
}
return &LuaEngine{
state: L,
transform: transform.(*lua.LFunction),
}, nil
}
// Transform executes the Lua transform function with the given request data.
// It can return the following errors:
// - types.ScriptExecutionError
func (e *LuaEngine) Transform(req *RequestData) error {
// Convert RequestData to Lua table
reqTable := e.requestDataToTable(req)
// Call transform(req)
e.state.Push(e.transform)
e.state.Push(reqTable)
if err := e.state.PCall(1, 1, nil); err != nil {
return types.NewScriptExecutionError("Lua", err)
}
// Get the result
result := e.state.Get(-1)
e.state.Pop(1)
if result.Type() != lua.LTTable {
return types.NewScriptExecutionError("Lua", fmt.Errorf("transform function must return a table, got %s", result.Type()))
}
// Update RequestData from the returned table
e.tableToRequestData(result.(*lua.LTable), req)
return nil
}
// Close releases the Lua state resources.
func (e *LuaEngine) Close() {
if e.state != nil {
e.state.Close()
}
}
// requestDataToTable converts RequestData to a Lua table.
func (e *LuaEngine) requestDataToTable(req *RequestData) *lua.LTable {
L := e.state
t := L.NewTable()
t.RawSetString("method", lua.LString(req.Method))
t.RawSetString("path", lua.LString(req.Path))
t.RawSetString("body", lua.LString(req.Body))
// Headers (map[string][]string -> table of arrays)
headers := L.NewTable()
for k, values := range req.Headers {
arr := L.NewTable()
for _, v := range values {
arr.Append(lua.LString(v))
}
headers.RawSetString(k, arr)
}
t.RawSetString("headers", headers)
// Params (map[string][]string -> table of arrays)
params := L.NewTable()
for k, values := range req.Params {
arr := L.NewTable()
for _, v := range values {
arr.Append(lua.LString(v))
}
params.RawSetString(k, arr)
}
t.RawSetString("params", params)
// Cookies (map[string][]string -> table of arrays)
cookies := L.NewTable()
for k, values := range req.Cookies {
arr := L.NewTable()
for _, v := range values {
arr.Append(lua.LString(v))
}
cookies.RawSetString(k, arr)
}
t.RawSetString("cookies", cookies)
return t
}
// tableToRequestData updates RequestData from a Lua table.
func (e *LuaEngine) tableToRequestData(t *lua.LTable, req *RequestData) {
// Method
if v := t.RawGetString("method"); v.Type() == lua.LTString {
req.Method = string(v.(lua.LString))
}
// Path
if v := t.RawGetString("path"); v.Type() == lua.LTString {
req.Path = string(v.(lua.LString))
}
// Body
if v := t.RawGetString("body"); v.Type() == lua.LTString {
req.Body = string(v.(lua.LString))
}
// Headers
if v := t.RawGetString("headers"); v.Type() == lua.LTTable {
req.Headers = e.tableToStringSliceMap(v.(*lua.LTable))
}
// Params
if v := t.RawGetString("params"); v.Type() == lua.LTTable {
req.Params = e.tableToStringSliceMap(v.(*lua.LTable))
}
// Cookies
if v := t.RawGetString("cookies"); v.Type() == lua.LTTable {
req.Cookies = e.tableToStringSliceMap(v.(*lua.LTable))
}
}
// tableToStringSliceMap converts a Lua table to a Go map[string][]string.
// Supports both single string values and array values.
func (e *LuaEngine) tableToStringSliceMap(t *lua.LTable) map[string][]string {
result := make(map[string][]string)
t.ForEach(func(k, v lua.LValue) {
if k.Type() != lua.LTString {
return
}
key := string(k.(lua.LString))
switch v.Type() {
case lua.LTString:
// Single string value
result[key] = []string{string(v.(lua.LString))}
case lua.LTTable:
// Array of strings
var values []string
v.(*lua.LTable).ForEach(func(_, item lua.LValue) {
if item.Type() == lua.LTString {
values = append(values, string(item.(lua.LString)))
}
})
result[key] = values
}
})
return result
}

197
internal/script/script.go Normal file
View File

@@ -0,0 +1,197 @@
package script
import (
"context"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"go.aykhans.me/sarin/internal/types"
)
// RequestData represents the request data passed to scripts for transformation.
// Scripts can modify any field and the changes will be applied to the actual request.
// Headers, Params, and Cookies use []string values to support multiple values per key.
type RequestData struct {
Method string `json:"method"`
Path string `json:"path"`
Headers map[string][]string `json:"headers"`
Params map[string][]string `json:"params"`
Cookies map[string][]string `json:"cookies"`
Body string `json:"body"`
}
// Engine defines the interface for script engines (Lua, JavaScript).
// Each engine must be able to transform request data using a user-provided script.
type Engine interface {
// Transform executes the script's transform function with the given request data.
// The script should modify the RequestData and return it.
Transform(req *RequestData) error
// Close releases any resources held by the engine.
Close()
}
// EngineType represents the type of script engine.
type EngineType string
const (
EngineTypeLua EngineType = "lua"
EngineTypeJavaScript EngineType = "js"
)
// Source represents a loaded script source.
type Source struct {
Content string
EngineType EngineType
}
// LoadSource loads a script from the given source string.
// The source can be:
// - Inline script: any string not starting with "@"
// - Escaped "@": strings starting with "@@" (literal "@" at start, returns string without first @)
// - File reference: "@/path/to/file" or "@./relative/path"
// - URL reference: "@http://..." or "@https://..."
//
// It can return the following errors:
// - types.ErrScriptEmpty
// - types.ScriptLoadError
func LoadSource(ctx context.Context, source string, engineType EngineType) (*Source, error) {
if source == "" {
return nil, types.ErrScriptEmpty
}
var content string
var err error
switch {
case strings.HasPrefix(source, "@@"):
// Escaped @ - it's an inline script starting with literal @
content = source[1:] // Remove first @, keep the rest
case strings.HasPrefix(source, "@"):
// File or URL reference
ref := source[1:]
if strings.HasPrefix(ref, "http://") || strings.HasPrefix(ref, "https://") {
content, err = fetchURL(ctx, ref)
} else {
content, err = readFile(ref)
}
if err != nil {
return nil, types.NewScriptLoadError(ref, err)
}
default:
// Inline script
content = source
}
return &Source{
Content: content,
EngineType: engineType,
}, nil
}
// LoadSources loads multiple script sources.
// It can return the following errors:
// - types.ErrScriptEmpty
// - types.ScriptLoadError
func LoadSources(ctx context.Context, sources []string, engineType EngineType) ([]*Source, error) {
loaded := make([]*Source, 0, len(sources))
for _, src := range sources {
source, err := LoadSource(ctx, src, engineType)
if err != nil {
return nil, err
}
loaded = append(loaded, source)
}
return loaded, nil
}
// ValidateScript validates a script source by loading it and checking syntax.
// It loads the script (from file/URL/inline), parses it, and verifies
// that a 'transform' function is defined.
// It can return the following errors:
// - types.ErrScriptEmpty
// - types.ErrScriptTransformMissing
// - types.ScriptLoadError
// - types.ScriptExecutionError
// - types.ScriptUnknownEngineError
func ValidateScript(ctx context.Context, source string, engineType EngineType) error {
// Load the script source
src, err := LoadSource(ctx, source, engineType)
if err != nil {
return err
}
// Try to create an engine - this validates syntax and transform function
var engine Engine
switch engineType {
case EngineTypeLua:
engine, err = NewLuaEngine(src.Content)
case EngineTypeJavaScript:
engine, err = NewJsEngine(src.Content)
default:
return types.NewScriptUnknownEngineError(string(engineType))
}
if err != nil {
return err
}
// Clean up the engine - we only needed it for validation
engine.Close()
return nil
}
// fetchURL downloads content from an HTTP/HTTPS URL.
// It can return the following errors:
// - types.HTTPFetchError
// - types.HTTPStatusError
func fetchURL(ctx context.Context, url string) (string, error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return "", types.NewHTTPFetchError(url, err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", types.NewHTTPFetchError(url, err)
}
defer resp.Body.Close() //nolint:errcheck
if resp.StatusCode != http.StatusOK {
return "", types.NewHTTPStatusError(url, resp.StatusCode, resp.Status)
}
data, err := io.ReadAll(resp.Body)
if err != nil {
return "", types.NewHTTPFetchError(url, err)
}
return string(data), nil
}
// readFile reads content from a local file.
// It can return the following errors:
// - types.FileReadError
func readFile(path string) (string, error) {
if !filepath.IsAbs(path) {
pwd, err := os.Getwd()
if err != nil {
return "", types.NewFileReadError(path, err)
}
path = filepath.Join(pwd, path)
}
data, err := os.ReadFile(path) //nolint:gosec
if err != nil {
return "", types.NewFileReadError(path, err)
}
return string(data), nil
}

View File

@@ -6,16 +6,12 @@ import (
"strings"
)
var (
// General
ErrNoError = errors.New("no error (internal)")
// CLI
ErrCLINoArgs = errors.New("CLI expects arguments but received none")
)
// ======================================== General ========================================
var (
errNoError = errors.New("no error (internal)")
)
type FieldParseError struct {
Field string
Value string
@@ -24,7 +20,7 @@ type FieldParseError struct {
func NewFieldParseError(field string, value string, err error) FieldParseError {
if err == nil {
err = ErrNoError
err = errNoError
}
return FieldParseError{field, value, err}
}
@@ -72,7 +68,7 @@ type FieldValidationError struct {
func NewFieldValidationError(field string, value string, err error) FieldValidationError {
if err == nil {
err = ErrNoError
err = errNoError
}
return FieldValidationError{field, value, err}
}
@@ -118,7 +114,7 @@ type UnmarshalError struct {
func NewUnmarshalError(err error) UnmarshalError {
if err == nil {
err = ErrNoError
err = errNoError
}
return UnmarshalError{err}
}
@@ -131,8 +127,133 @@ func (e UnmarshalError) Unwrap() error {
return e.error
}
// ======================================== General I/O ========================================
type FileReadError struct {
Path string
Err error
}
func NewFileReadError(path string, err error) FileReadError {
if err == nil {
err = errNoError
}
return FileReadError{path, err}
}
func (e FileReadError) Error() string {
return fmt.Sprintf("failed to read file %s: %v", e.Path, e.Err)
}
func (e FileReadError) Unwrap() error {
return e.Err
}
type HTTPFetchError struct {
URL string
Err error
}
func NewHTTPFetchError(url string, err error) HTTPFetchError {
if err == nil {
err = errNoError
}
return HTTPFetchError{url, err}
}
func (e HTTPFetchError) Error() string {
return fmt.Sprintf("failed to fetch %s: %v", e.URL, e.Err)
}
func (e HTTPFetchError) Unwrap() error {
return e.Err
}
type HTTPStatusError struct {
URL string
StatusCode int
Status string
}
func NewHTTPStatusError(url string, statusCode int, status string) HTTPStatusError {
return HTTPStatusError{url, statusCode, status}
}
func (e HTTPStatusError) Error() string {
return fmt.Sprintf("HTTP %d %s (url: %s)", e.StatusCode, e.Status, e.URL)
}
type URLParseError struct {
URL string
Err error
}
func NewURLParseError(url string, err error) URLParseError {
if err == nil {
err = errNoError
}
return URLParseError{url, err}
}
func (e URLParseError) Error() string {
return fmt.Sprintf("invalid URL %q: %v", e.URL, e.Err)
}
func (e URLParseError) Unwrap() error {
return e.Err
}
// ======================================== Template ========================================
var (
ErrFileCacheNotInitialized = errors.New("file cache is not initialized")
ErrFormDataOddArgs = errors.New("body_FormData requires an even number of arguments (key-value pairs)")
)
type TemplateParseError struct {
Err error
}
func NewTemplateParseError(err error) TemplateParseError {
if err == nil {
err = errNoError
}
return TemplateParseError{err}
}
func (e TemplateParseError) Error() string {
return "template parse error: " + e.Err.Error()
}
func (e TemplateParseError) Unwrap() error {
return e.Err
}
type TemplateRenderError struct {
Err error
}
func NewTemplateRenderError(err error) TemplateRenderError {
if err == nil {
err = errNoError
}
return TemplateRenderError{err}
}
func (e TemplateRenderError) Error() string {
return "template rendering: " + e.Err.Error()
}
func (e TemplateRenderError) Unwrap() error {
return e.Err
}
// ======================================== CLI ========================================
var (
ErrCLINoArgs = errors.New("CLI expects arguments but received none")
)
type CLIUnexpectedArgsError struct {
Args []string
}
@@ -153,7 +274,7 @@ type ConfigFileReadError struct {
func NewConfigFileReadError(err error) ConfigFileReadError {
if err == nil {
err = ErrNoError
err = errNoError
}
return ConfigFileReadError{err}
}
@@ -168,6 +289,61 @@ func (e ConfigFileReadError) Unwrap() error {
// ======================================== Proxy ========================================
type ProxyUnsupportedSchemeError struct {
Scheme string
}
func NewProxyUnsupportedSchemeError(scheme string) ProxyUnsupportedSchemeError {
return ProxyUnsupportedSchemeError{scheme}
}
func (e ProxyUnsupportedSchemeError) Error() string {
return "unsupported proxy scheme: " + e.Scheme
}
type ProxyParseError struct {
Err error
}
func NewProxyParseError(err error) ProxyParseError {
if err == nil {
err = errNoError
}
return ProxyParseError{err}
}
func (e ProxyParseError) Error() string {
return "failed to parse proxy URL: " + e.Err.Error()
}
func (e ProxyParseError) Unwrap() error {
return e.Err
}
type ProxyConnectError struct {
Status string
}
func NewProxyConnectError(status string) ProxyConnectError {
return ProxyConnectError{status}
}
func (e ProxyConnectError) Error() string {
return "proxy CONNECT failed: " + e.Status
}
type ProxyResolveError struct {
Host string
}
func NewProxyResolveError(host string) ProxyResolveError {
return ProxyResolveError{host}
}
func (e ProxyResolveError) Error() string {
return "no IP addresses found for host: " + e.Host
}
type ProxyDialError struct {
Proxy string
Err error
@@ -175,7 +351,7 @@ type ProxyDialError struct {
func NewProxyDialError(proxy string, err error) ProxyDialError {
if err == nil {
err = ErrNoError
err = errNoError
}
return ProxyDialError{proxy, err}
}
@@ -187,3 +363,86 @@ func (e ProxyDialError) Error() string {
func (e ProxyDialError) Unwrap() error {
return e.Err
}
// ======================================== Script ========================================
var (
ErrScriptEmpty = errors.New("script cannot be empty")
ErrScriptSourceEmpty = errors.New("script source cannot be empty after @")
ErrScriptTransformMissing = errors.New("script must define a global 'transform' function")
ErrScriptTransformReturnObject = errors.New("transform function must return an object")
ErrScriptURLNoHost = errors.New("script URL must have a host")
)
type ScriptLoadError struct {
Source string
Err error
}
func NewScriptLoadError(source string, err error) ScriptLoadError {
if err == nil {
err = errNoError
}
return ScriptLoadError{source, err}
}
func (e ScriptLoadError) Error() string {
return fmt.Sprintf("failed to load script from %q: %v", e.Source, e.Err)
}
func (e ScriptLoadError) Unwrap() error {
return e.Err
}
type ScriptExecutionError struct {
EngineType string
Err error
}
func NewScriptExecutionError(engineType string, err error) ScriptExecutionError {
if err == nil {
err = errNoError
}
return ScriptExecutionError{engineType, err}
}
func (e ScriptExecutionError) Error() string {
return fmt.Sprintf("%s script error: %v", e.EngineType, e.Err)
}
func (e ScriptExecutionError) Unwrap() error {
return e.Err
}
type ScriptChainError struct {
EngineType string
Index int
Err error
}
func NewScriptChainError(engineType string, index int, err error) ScriptChainError {
if err == nil {
err = errNoError
}
return ScriptChainError{engineType, index, err}
}
func (e ScriptChainError) Error() string {
return fmt.Sprintf("%s script[%d]: %v", e.EngineType, e.Index, e.Err)
}
func (e ScriptChainError) Unwrap() error {
return e.Err
}
type ScriptUnknownEngineError struct {
EngineType string
}
func NewScriptUnknownEngineError(engineType string) ScriptUnknownEngineError {
return ScriptUnknownEngineError{engineType}
}
func (e ScriptUnknownEngineError) Error() string {
return "unknown engine type: " + e.EngineType
}

View File

@@ -1,7 +1,6 @@
package types
import (
"fmt"
"net/url"
)
@@ -17,6 +16,9 @@ func (proxies *Proxies) Append(proxy ...Proxy) {
*proxies = append(*proxies, proxy...)
}
// Parse parses a raw proxy string and appends it to the list.
// It can return the following errors:
// - ProxyParseError
func (proxies *Proxies) Parse(rawValue string) error {
parsedProxy, err := ParseProxy(rawValue)
if err != nil {
@@ -27,10 +29,13 @@ func (proxies *Proxies) Parse(rawValue string) error {
return nil
}
// ParseProxy parses a raw proxy URL string into a Proxy.
// It can return the following errors:
// - ProxyParseError
func ParseProxy(rawValue string) (*Proxy, error) {
urlParsed, err := url.Parse(rawValue)
if err != nil {
return nil, fmt.Errorf("failed to parse proxy URL: %w", err)
return nil, NewProxyParseError(err)
}
proxyParsed := Proxy(*urlParsed)