Add context and signal handling for graceful shutdown

This commit is contained in:
Aykhan Shahsuvarov 2024-08-24 19:41:12 +04:00
parent d25c2b2964
commit 2fe8df28d5
2 changed files with 108 additions and 19 deletions

23
main.go
View File

@ -1,9 +1,12 @@
package main package main
import ( import (
"context"
"net/url" "net/url"
"os" "os"
"os/signal"
"strings" "strings"
"syscall"
"time" "time"
"github.com/aykhans/dodo/config" "github.com/aykhans/dodo/config"
@ -79,6 +82,24 @@ func main() {
} }
dodoConf.Print() dodoConf.Print()
responses := requests.Run(dodoConf) ctx, cancel := context.WithCancel(context.Background())
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigChan
cancel()
}()
responses, err := requests.Run(ctx, dodoConf)
if err != nil {
if customerrors.Is(err, customerrors.ErrInterrupt) {
utils.PrintlnC(utils.Colors.Yellow, err.Error())
return
} else if customerrors.Is(err, customerrors.ErrNoInternet) {
utils.PrintAndExit("No internet connection")
}
panic(err)
}
responses.Print() responses.Print()
} }

View File

@ -1,6 +1,7 @@
package requests package requests
import ( import (
"context"
"fmt" "fmt"
"math/rand" "math/rand"
"net/url" "net/url"
@ -9,6 +10,7 @@ import (
"time" "time"
"github.com/aykhans/dodo/config" "github.com/aykhans/dodo/config"
customerrors "github.com/aykhans/dodo/custom_errors"
"github.com/aykhans/dodo/readers" "github.com/aykhans/dodo/readers"
"github.com/aykhans/dodo/utils" "github.com/aykhans/dodo/utils"
"github.com/jedib0t/go-pretty/v6/progress" "github.com/jedib0t/go-pretty/v6/progress"
@ -91,17 +93,27 @@ func (respones *Responses) Print() {
} }
// Run executes the HTTP requests based on the provided request configuration. // Run executes the HTTP requests based on the provided request configuration.
// It returns the Responses type, which contains the responses received from all the requests. // It checks for internet connection and returns an error if there is no connection.
func Run(requestConfig *config.RequestConfig) Responses { // If the context is canceled while checking proxies, it returns the ErrInterrupt.
if !checkConnection() { // If the context is canceled while sending requests, it returns the response objects obtained so far.
utils.PrintAndExit("No internet connection") func Run(ctx context.Context, requestConfig *config.RequestConfig) (Responses, error) {
checkConnectionCtx, checkConnectionCtxCancel := context.WithTimeout(ctx, 8*time.Second)
if !checkConnection(checkConnectionCtx) {
checkConnectionCtxCancel()
return nil, customerrors.ErrNoInternet
} }
checkConnectionCtxCancel()
clientFunc := getClientFunc( clientFunc := getClientFunc(
ctx,
requestConfig.Timeout, requestConfig.Timeout,
requestConfig.Proxies, requestConfig.Proxies,
requestConfig.GetValidDodosCountForRequests(), requestConfig.GetValidDodosCountForRequests(),
requestConfig.URL, requestConfig.URL,
) )
if clientFunc == nil {
return nil, customerrors.ErrInterrupt
}
request := newRequest( request := newRequest(
requestConfig.URL, requestConfig.URL,
@ -113,14 +125,18 @@ func Run(requestConfig *config.RequestConfig) Responses {
) )
defer fasthttp.ReleaseRequest(request) defer fasthttp.ReleaseRequest(request)
responses := releaseDodos( responses := releaseDodos(
ctx,
request, request,
requestConfig.Timeout, requestConfig.Timeout,
clientFunc, clientFunc,
requestConfig.GetValidDodosCountForRequests(), requestConfig.GetValidDodosCountForRequests(),
requestConfig.RequestCount, requestConfig.RequestCount,
) )
if ctx.Err() != nil && len(responses) == 0 {
return nil, customerrors.ErrInterrupt
}
return responses return responses, nil
} }
// releaseDodos sends multiple HTTP requests concurrently using multiple "dodos" (goroutines). // releaseDodos sends multiple HTTP requests concurrently using multiple "dodos" (goroutines).
@ -128,6 +144,7 @@ func Run(requestConfig *config.RequestConfig) Responses {
// dodosCount as the number of goroutines to be used, and requestCount as the total number of requests to be sent. // dodosCount as the number of goroutines to be used, and requestCount as the total number of requests to be sent.
// It returns the responses received from all the requests. // It returns the responses received from all the requests.
func releaseDodos( func releaseDodos(
ctx context.Context,
mainRequest *fasthttp.Request, mainRequest *fasthttp.Request,
timeout time.Duration, timeout time.Duration,
clientFunc ClientFunc, clientFunc ClientFunc,
@ -136,14 +153,17 @@ func releaseDodos(
) Responses { ) Responses {
var ( var (
wg sync.WaitGroup wg sync.WaitGroup
streamWG sync.WaitGroup
requestCountPerDodo int requestCountPerDodo int
) )
wg.Add(dodosCount + 1) // +1 for progress tracker wg.Add(dodosCount)
streamWG.Add(1)
responses := make([][]Response, dodosCount) responses := make([][]Response, dodosCount)
countSlice := make([]int, dodosCount) countSlice := make([]int, dodosCount)
go streamProgress(&wg, requestCount, "Dodos Working🔥", &countSlice) streamCtx, streamCtxCancel := context.WithCancel(context.Background())
go streamProgress(streamCtx, &streamWG, requestCount, "Dodos Working🔥", &countSlice)
for i := 0; i < dodosCount; i++ { for i := 0; i < dodosCount; i++ {
if i+1 == dodosCount { if i+1 == dodosCount {
@ -157,6 +177,7 @@ func releaseDodos(
mainRequest.CopyTo(dodoSpecificRequest) mainRequest.CopyTo(dodoSpecificRequest)
go sendRequest( go sendRequest(
ctx,
dodoSpecificRequest, dodoSpecificRequest,
timeout, timeout,
&responses[i], &responses[i],
@ -167,6 +188,8 @@ func releaseDodos(
) )
} }
wg.Wait() wg.Wait()
streamCtxCancel()
streamWG.Wait()
return utils.Flatten(responses) return utils.Flatten(responses)
} }
@ -175,6 +198,7 @@ func releaseDodos(
// For each request, it acquires a response object, gets a client, and measures the time taken to complete the request. // For each request, it acquires a response object, gets a client, and measures the time taken to complete the request.
// If an error occurs during the request, the error is recorded in the responseData slice. // If an error occurs during the request, the error is recorded in the responseData slice.
func sendRequest( func sendRequest(
ctx context.Context,
request *fasthttp.Request, request *fasthttp.Request,
timeout time.Duration, timeout time.Duration,
responseData *[]Response, responseData *[]Response,
@ -187,6 +211,10 @@ func sendRequest(
defer wg.Done() defer wg.Done()
for range requestCount { for range requestCount {
if ctx.Err() != nil {
return
}
func() { func() {
defer func() { *counter++ }() defer func() { *counter++ }()
@ -221,6 +249,7 @@ func sendRequest(
// If the user chooses to continue, it returns a ClientFunc with a shared client or a randomized client. // If the user chooses to continue, it returns a ClientFunc with a shared client or a randomized client.
// If there are no proxies available, it returns a ClientFunc with a shared client. // If there are no proxies available, it returns a ClientFunc with a shared client.
func getClientFunc( func getClientFunc(
ctx context.Context,
timeout time.Duration, timeout time.Duration,
proxies []config.Proxy, proxies []config.Proxy,
dodosCount int, dodosCount int,
@ -229,8 +258,11 @@ func getClientFunc(
isTLS := URL.Scheme == "https" isTLS := URL.Scheme == "https"
if len(proxies) > 0 { if len(proxies) > 0 {
activeProxyClients := getActiveProxyClients( activeProxyClients := getActiveProxyClients(
proxies, timeout, dodosCount, URL, ctx, proxies, timeout, dodosCount, URL,
) )
if ctx.Err() != nil {
return nil
}
activeProxyClientsCount := len(activeProxyClients) activeProxyClientsCount := len(activeProxyClients)
var yesOrNoMessage string var yesOrNoMessage string
if activeProxyClientsCount == 0 { if activeProxyClientsCount == 0 {
@ -287,6 +319,7 @@ func getClientFunc(
// Once all goroutines have completed, the function waits for them to finish and // Once all goroutines have completed, the function waits for them to finish and
// returns a flattened slice of active proxy clients. // returns a flattened slice of active proxy clients.
func getActiveProxyClients( func getActiveProxyClients(
ctx context.Context,
proxies []config.Proxy, proxies []config.Proxy,
timeout time.Duration, timeout time.Duration,
dodosCount int, dodosCount int,
@ -295,12 +328,17 @@ func getActiveProxyClients(
activeProxyClientsArray := make([][]fasthttp.HostClient, dodosCount) activeProxyClientsArray := make([][]fasthttp.HostClient, dodosCount)
proxiesCount := len(proxies) proxiesCount := len(proxies)
var wg sync.WaitGroup var (
wg.Add(dodosCount + 1) // +1 for progress tracker wg sync.WaitGroup
streamWG sync.WaitGroup
)
wg.Add(dodosCount)
streamWG.Add(1)
var proxiesSlice []config.Proxy var proxiesSlice []config.Proxy
countSlice := make([]int, dodosCount) countSlice := make([]int, dodosCount)
go streamProgress(&wg, proxiesCount, "Searching for active proxies🌐", &countSlice) streamCtx, streamCtxCancel := context.WithCancel(context.Background())
go streamProgress(streamCtx, &streamWG, proxiesCount, "Searching for active proxies🌐", &countSlice)
for i := 0; i < dodosCount; i++ { for i := 0; i < dodosCount; i++ {
if i+1 == dodosCount { if i+1 == dodosCount {
@ -309,6 +347,7 @@ func getActiveProxyClients(
proxiesSlice = proxies[i*proxiesCount/dodosCount : (i+1)*proxiesCount/dodosCount] proxiesSlice = proxies[i*proxiesCount/dodosCount : (i+1)*proxiesCount/dodosCount]
} }
go findActiveProxyClients( go findActiveProxyClients(
ctx,
proxiesSlice, proxiesSlice,
timeout, timeout,
&activeProxyClientsArray[i], &activeProxyClientsArray[i],
@ -318,6 +357,8 @@ func getActiveProxyClients(
) )
} }
wg.Wait() wg.Wait()
streamCtxCancel()
streamWG.Wait()
return utils.Flatten(activeProxyClientsArray) return utils.Flatten(activeProxyClientsArray)
} }
@ -326,6 +367,7 @@ func getActiveProxyClients(
// It also increments the count for each successful request. // It also increments the count for each successful request.
// The function is designed to be used as a concurrent operation, and it uses the WaitGroup to wait for all goroutines to finish. // The function is designed to be used as a concurrent operation, and it uses the WaitGroup to wait for all goroutines to finish.
func findActiveProxyClients( func findActiveProxyClients(
ctx context.Context,
proxies []config.Proxy, proxies []config.Proxy,
timeout time.Duration, timeout time.Duration,
activeProxyClients *[]fasthttp.HostClient, activeProxyClients *[]fasthttp.HostClient,
@ -341,6 +383,10 @@ func findActiveProxyClients(
request.Header.SetMethod("GET") request.Header.SetMethod("GET")
for _, proxy := range proxies { for _, proxy := range proxies {
if ctx.Err() != nil {
return
}
func() { func() {
defer func() { *count++ }() defer func() { *count++ }()
@ -520,6 +566,7 @@ func setRequestBody(req *fasthttp.Request, body string) {
// The function runs in a separate goroutine and updates the progress bar until all items are processed. // The function runs in a separate goroutine and updates the progress bar until all items are processed.
// Once all items are processed, it marks the progress bar as done and stops rendering. // Once all items are processed, it marks the progress bar as done and stops rendering.
func streamProgress( func streamProgress(
ctx context.Context,
wg *sync.WaitGroup, wg *sync.WaitGroup,
total int, total int,
message string, message string,
@ -543,11 +590,21 @@ func streamProgress(
totalCount += count totalCount += count
} }
dodosTracker.SetValue(int64(totalCount)) dodosTracker.SetValue(int64(totalCount))
if ctx.Err() != nil {
fmt.Printf("\r")
dodosTracker.MarkAsErrored()
time.Sleep(time.Millisecond * 300)
pw.Stop()
return
}
if totalCount == total { if totalCount == total {
break break
} }
time.Sleep(time.Millisecond * 200) time.Sleep(time.Millisecond * 200)
} }
fmt.Printf("\r")
dodosTracker.MarkAsDone() dodosTracker.MarkAsDone()
time.Sleep(time.Millisecond * 300) time.Sleep(time.Millisecond * 300)
pw.Stop() pw.Stop()
@ -555,14 +612,25 @@ func streamProgress(
// checkConnection checks the internet connection by making requests to different websites. // checkConnection checks the internet connection by making requests to different websites.
// It returns true if the connection is successful, otherwise false. // It returns true if the connection is successful, otherwise false.
func checkConnection() bool { func checkConnection(ctx context.Context) bool {
ch := make(chan bool)
go func() {
_, _, err := fasthttp.Get(nil, "https://www.google.com") _, _, err := fasthttp.Get(nil, "https://www.google.com")
if err != nil { if err != nil {
_, _, err = fasthttp.Get(nil, "https://www.bing.com") _, _, err = fasthttp.Get(nil, "https://www.bing.com")
if err != nil { if err != nil {
_, _, err = fasthttp.Get(nil, "https://www.yahoo.com") _, _, err = fasthttp.Get(nil, "https://www.yahoo.com")
return err == nil ch <- err == nil
}
ch <- true
}
ch <- true
}()
select {
case <-ctx.Done():
return false
case res := <-ch:
return res
} }
} }
return true
}