Add context and signal handling for graceful shutdown

This commit is contained in:
Aykhan Shahsuvarov 2024-08-24 19:41:12 +04:00
parent d25c2b2964
commit 2fe8df28d5
2 changed files with 108 additions and 19 deletions

23
main.go
View File

@ -1,9 +1,12 @@
package main
import (
"context"
"net/url"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/aykhans/dodo/config"
@ -79,6 +82,24 @@ func main() {
}
dodoConf.Print()
responses := requests.Run(dodoConf)
ctx, cancel := context.WithCancel(context.Background())
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigChan
cancel()
}()
responses, err := requests.Run(ctx, dodoConf)
if err != nil {
if customerrors.Is(err, customerrors.ErrInterrupt) {
utils.PrintlnC(utils.Colors.Yellow, err.Error())
return
} else if customerrors.Is(err, customerrors.ErrNoInternet) {
utils.PrintAndExit("No internet connection")
}
panic(err)
}
responses.Print()
}

View File

@ -1,6 +1,7 @@
package requests
import (
"context"
"fmt"
"math/rand"
"net/url"
@ -9,6 +10,7 @@ import (
"time"
"github.com/aykhans/dodo/config"
customerrors "github.com/aykhans/dodo/custom_errors"
"github.com/aykhans/dodo/readers"
"github.com/aykhans/dodo/utils"
"github.com/jedib0t/go-pretty/v6/progress"
@ -91,17 +93,27 @@ func (respones *Responses) Print() {
}
// Run executes the HTTP requests based on the provided request configuration.
// It returns the Responses type, which contains the responses received from all the requests.
func Run(requestConfig *config.RequestConfig) Responses {
if !checkConnection() {
utils.PrintAndExit("No internet connection")
// It checks for internet connection and returns an error if there is no connection.
// If the context is canceled while checking proxies, it returns the ErrInterrupt.
// If the context is canceled while sending requests, it returns the response objects obtained so far.
func Run(ctx context.Context, requestConfig *config.RequestConfig) (Responses, error) {
checkConnectionCtx, checkConnectionCtxCancel := context.WithTimeout(ctx, 8*time.Second)
if !checkConnection(checkConnectionCtx) {
checkConnectionCtxCancel()
return nil, customerrors.ErrNoInternet
}
checkConnectionCtxCancel()
clientFunc := getClientFunc(
ctx,
requestConfig.Timeout,
requestConfig.Proxies,
requestConfig.GetValidDodosCountForRequests(),
requestConfig.URL,
)
if clientFunc == nil {
return nil, customerrors.ErrInterrupt
}
request := newRequest(
requestConfig.URL,
@ -113,14 +125,18 @@ func Run(requestConfig *config.RequestConfig) Responses {
)
defer fasthttp.ReleaseRequest(request)
responses := releaseDodos(
ctx,
request,
requestConfig.Timeout,
clientFunc,
requestConfig.GetValidDodosCountForRequests(),
requestConfig.RequestCount,
)
if ctx.Err() != nil && len(responses) == 0 {
return nil, customerrors.ErrInterrupt
}
return responses
return responses, nil
}
// releaseDodos sends multiple HTTP requests concurrently using multiple "dodos" (goroutines).
@ -128,6 +144,7 @@ func Run(requestConfig *config.RequestConfig) Responses {
// dodosCount as the number of goroutines to be used, and requestCount as the total number of requests to be sent.
// It returns the responses received from all the requests.
func releaseDodos(
ctx context.Context,
mainRequest *fasthttp.Request,
timeout time.Duration,
clientFunc ClientFunc,
@ -136,14 +153,17 @@ func releaseDodos(
) Responses {
var (
wg sync.WaitGroup
streamWG sync.WaitGroup
requestCountPerDodo int
)
wg.Add(dodosCount + 1) // +1 for progress tracker
wg.Add(dodosCount)
streamWG.Add(1)
responses := make([][]Response, dodosCount)
countSlice := make([]int, dodosCount)
go streamProgress(&wg, requestCount, "Dodos Working🔥", &countSlice)
streamCtx, streamCtxCancel := context.WithCancel(context.Background())
go streamProgress(streamCtx, &streamWG, requestCount, "Dodos Working🔥", &countSlice)
for i := 0; i < dodosCount; i++ {
if i+1 == dodosCount {
@ -157,6 +177,7 @@ func releaseDodos(
mainRequest.CopyTo(dodoSpecificRequest)
go sendRequest(
ctx,
dodoSpecificRequest,
timeout,
&responses[i],
@ -167,6 +188,8 @@ func releaseDodos(
)
}
wg.Wait()
streamCtxCancel()
streamWG.Wait()
return utils.Flatten(responses)
}
@ -175,6 +198,7 @@ func releaseDodos(
// For each request, it acquires a response object, gets a client, and measures the time taken to complete the request.
// If an error occurs during the request, the error is recorded in the responseData slice.
func sendRequest(
ctx context.Context,
request *fasthttp.Request,
timeout time.Duration,
responseData *[]Response,
@ -187,6 +211,10 @@ func sendRequest(
defer wg.Done()
for range requestCount {
if ctx.Err() != nil {
return
}
func() {
defer func() { *counter++ }()
@ -221,6 +249,7 @@ func sendRequest(
// If the user chooses to continue, it returns a ClientFunc with a shared client or a randomized client.
// If there are no proxies available, it returns a ClientFunc with a shared client.
func getClientFunc(
ctx context.Context,
timeout time.Duration,
proxies []config.Proxy,
dodosCount int,
@ -229,8 +258,11 @@ func getClientFunc(
isTLS := URL.Scheme == "https"
if len(proxies) > 0 {
activeProxyClients := getActiveProxyClients(
proxies, timeout, dodosCount, URL,
ctx, proxies, timeout, dodosCount, URL,
)
if ctx.Err() != nil {
return nil
}
activeProxyClientsCount := len(activeProxyClients)
var yesOrNoMessage string
if activeProxyClientsCount == 0 {
@ -287,6 +319,7 @@ func getClientFunc(
// Once all goroutines have completed, the function waits for them to finish and
// returns a flattened slice of active proxy clients.
func getActiveProxyClients(
ctx context.Context,
proxies []config.Proxy,
timeout time.Duration,
dodosCount int,
@ -295,12 +328,17 @@ func getActiveProxyClients(
activeProxyClientsArray := make([][]fasthttp.HostClient, dodosCount)
proxiesCount := len(proxies)
var wg sync.WaitGroup
wg.Add(dodosCount + 1) // +1 for progress tracker
var (
wg sync.WaitGroup
streamWG sync.WaitGroup
)
wg.Add(dodosCount)
streamWG.Add(1)
var proxiesSlice []config.Proxy
countSlice := make([]int, dodosCount)
go streamProgress(&wg, proxiesCount, "Searching for active proxies🌐", &countSlice)
streamCtx, streamCtxCancel := context.WithCancel(context.Background())
go streamProgress(streamCtx, &streamWG, proxiesCount, "Searching for active proxies🌐", &countSlice)
for i := 0; i < dodosCount; i++ {
if i+1 == dodosCount {
@ -309,6 +347,7 @@ func getActiveProxyClients(
proxiesSlice = proxies[i*proxiesCount/dodosCount : (i+1)*proxiesCount/dodosCount]
}
go findActiveProxyClients(
ctx,
proxiesSlice,
timeout,
&activeProxyClientsArray[i],
@ -318,6 +357,8 @@ func getActiveProxyClients(
)
}
wg.Wait()
streamCtxCancel()
streamWG.Wait()
return utils.Flatten(activeProxyClientsArray)
}
@ -326,6 +367,7 @@ func getActiveProxyClients(
// It also increments the count for each successful request.
// The function is designed to be used as a concurrent operation, and it uses the WaitGroup to wait for all goroutines to finish.
func findActiveProxyClients(
ctx context.Context,
proxies []config.Proxy,
timeout time.Duration,
activeProxyClients *[]fasthttp.HostClient,
@ -341,6 +383,10 @@ func findActiveProxyClients(
request.Header.SetMethod("GET")
for _, proxy := range proxies {
if ctx.Err() != nil {
return
}
func() {
defer func() { *count++ }()
@ -520,6 +566,7 @@ func setRequestBody(req *fasthttp.Request, body string) {
// The function runs in a separate goroutine and updates the progress bar until all items are processed.
// Once all items are processed, it marks the progress bar as done and stops rendering.
func streamProgress(
ctx context.Context,
wg *sync.WaitGroup,
total int,
message string,
@ -543,11 +590,21 @@ func streamProgress(
totalCount += count
}
dodosTracker.SetValue(int64(totalCount))
if ctx.Err() != nil {
fmt.Printf("\r")
dodosTracker.MarkAsErrored()
time.Sleep(time.Millisecond * 300)
pw.Stop()
return
}
if totalCount == total {
break
}
time.Sleep(time.Millisecond * 200)
}
fmt.Printf("\r")
dodosTracker.MarkAsDone()
time.Sleep(time.Millisecond * 300)
pw.Stop()
@ -555,14 +612,25 @@ func streamProgress(
// checkConnection checks the internet connection by making requests to different websites.
// It returns true if the connection is successful, otherwise false.
func checkConnection() bool {
_, _, err := fasthttp.Get(nil, "https://www.google.com")
if err != nil {
_, _, err = fasthttp.Get(nil, "https://www.bing.com")
func checkConnection(ctx context.Context) bool {
ch := make(chan bool)
go func() {
_, _, err := fasthttp.Get(nil, "https://www.google.com")
if err != nil {
_, _, err = fasthttp.Get(nil, "https://www.yahoo.com")
return err == nil
_, _, err = fasthttp.Get(nil, "https://www.bing.com")
if err != nil {
_, _, err = fasthttp.Get(nil, "https://www.yahoo.com")
ch <- err == nil
}
ch <- true
}
ch <- true
}()
select {
case <-ctx.Done():
return false
case res := <-ch:
return res
}
return true
}