🐛 Fix race condition in streamProgress in countSlice

This commit is contained in:
Aykhan Shahsuvarov 2024-09-10 03:30:32 +04:00
parent 04bd3ab6d1
commit 7296646428

View File

@ -160,10 +160,10 @@ func releaseDodos(
wg.Add(dodosCount) wg.Add(dodosCount)
streamWG.Add(1) streamWG.Add(1)
responses := make([][]Response, dodosCount) responses := make([][]Response, dodosCount)
countSlice := make([]int, dodosCount) increase := make(chan int64, requestCount)
streamCtx, streamCtxCancel := context.WithCancel(context.Background()) streamCtx, streamCtxCancel := context.WithCancel(context.Background())
go streamProgress(streamCtx, &streamWG, requestCount, "Dodos Working🔥", &countSlice) go streamProgress(streamCtx, &streamWG, int64(requestCount), "Dodos Working🔥", increase)
for i := 0; i < dodosCount; i++ { for i := 0; i < dodosCount; i++ {
if i+1 == dodosCount { if i+1 == dodosCount {
@ -180,7 +180,7 @@ func releaseDodos(
ctx, ctx,
dodoSpecificRequest, dodoSpecificRequest,
&responses[i], &responses[i],
&countSlice[i], increase,
requestCountPerDodo, requestCountPerDodo,
clientDoFunc, clientDoFunc,
&wg, &wg,
@ -192,20 +192,11 @@ func releaseDodos(
return utils.Flatten(responses) return utils.Flatten(responses)
} }
// sendRequest sends multiple HTTP requests concurrently using the provided clientDo function.
// It takes a context, a request, a slice to store the response data, a counter to keep track of the number of requests,
// the total number of requests to be sent, a clientDo function to execute the requests,
// and a wait group to synchronize the goroutines.
// It releases the request and decrements the wait group counter when done.
// For each request, it checks if the context has been canceled and returns if so.
// It measures the time it takes to complete each request and appends the response data to the responseData slice.
// If an error occurs during the request, it appends a response with a status code of 0 and the error to the responseData slice.
// Otherwise, it appends a response with the actual status code and nil error to the responseData slice.
func sendRequest( func sendRequest(
ctx context.Context, ctx context.Context,
request *fasthttp.Request, request *fasthttp.Request,
responseData *[]Response, responseData *[]Response,
counter *int, increase chan<- int64,
requestCount int, requestCount int,
clientDo ClientDoFunc, clientDo ClientDoFunc,
wg *sync.WaitGroup, wg *sync.WaitGroup,
@ -219,7 +210,7 @@ func sendRequest(
} }
func() { func() {
defer func() { *counter++ }() defer func() { increase <- 1 }()
startTime := time.Now() startTime := time.Now()
response, err := clientDo(ctx, request) response, err := clientDo(ctx, request)
@ -351,10 +342,10 @@ func getActiveProxyClients(
wg.Add(dodosCount) wg.Add(dodosCount)
streamWG.Add(1) streamWG.Add(1)
var proxiesSlice []config.Proxy var proxiesSlice []config.Proxy
increase := make(chan int64, proxiesCount)
countSlice := make([]int, dodosCount)
streamCtx, streamCtxCancel := context.WithCancel(context.Background()) streamCtx, streamCtxCancel := context.WithCancel(context.Background())
go streamProgress(streamCtx, &streamWG, proxiesCount, "Searching for active proxies🌐", &countSlice) go streamProgress(streamCtx, &streamWG, int64(proxiesCount), "Searching for active proxies🌐", increase)
for i := 0; i < dodosCount; i++ { for i := 0; i < dodosCount; i++ {
if i+1 == dodosCount { if i+1 == dodosCount {
@ -367,7 +358,7 @@ func getActiveProxyClients(
proxiesSlice, proxiesSlice,
timeout, timeout,
&activeProxyClientsArray[i], &activeProxyClientsArray[i],
&countSlice[i], increase,
URL, URL,
&wg, &wg,
) )
@ -378,20 +369,27 @@ func getActiveProxyClients(
return utils.Flatten(activeProxyClientsArray) return utils.Flatten(activeProxyClientsArray)
} }
// findActiveProxyClients is a function that finds active proxy clients by sending HTTP GET requests to a list of proxies. // findActiveProxyClients checks a list of proxies to determine which ones are active
// It takes a context.Context, a slice of config.Proxy, a time.Duration for the timeout, a pointer to a slice of fasthttp.HostClient to store the active proxy clients, // and appends the active ones to the provided activeProxyClients slice.
// a pointer to an int to keep track of the count, a pointer to a url.URL for the URL to send the requests to, and a pointer to a sync.WaitGroup to synchronize the goroutines. //
// It sends GET requests to each proxy in the given list and checks if the response status code is 200. // Parameters:
// If the context is canceled, the function returns immediately. // - ctx: The context to control cancellation and timeout.
// The active proxy clients that pass the check are added to the provided slice of fasthttp.HostClient. // - proxies: A slice of Proxy configurations to be checked.
// The function is designed to be run concurrently using goroutines and the sync.WaitGroup is used to wait for all goroutines to finish. // - timeout: The duration to wait for each proxy check before timing out.
// The function is responsible for releasing acquired resources and closing idle connections. // - activeProxyClients: A pointer to a slice where active proxy clients will be appended.
// - increase: A channel to signal the increase of checked proxies count.
// - URL: The URL to be used for checking the proxies.
// - wg: A WaitGroup to signal when the function is done.
//
// The function sends a GET request to each proxy using the provided URL. If the proxy
// responds with a status code of 200, it is considered active and added to the activeProxyClients slice.
// The function respects the context's cancellation and timeout settings.
func findActiveProxyClients( func findActiveProxyClients(
ctx context.Context, ctx context.Context,
proxies []config.Proxy, proxies []config.Proxy,
timeout time.Duration, timeout time.Duration,
activeProxyClients *[]fasthttp.HostClient, activeProxyClients *[]fasthttp.HostClient,
count *int, increase chan<- int64,
URL *url.URL, URL *url.URL,
wg *sync.WaitGroup, wg *sync.WaitGroup,
) { ) {
@ -408,7 +406,7 @@ func findActiveProxyClients(
} }
func() { func() {
defer func() { *count++ }() defer func() { increase <- 1 }()
response := fasthttp.AcquireResponse() response := fasthttp.AcquireResponse()
defer fasthttp.ReleaseResponse(response) defer fasthttp.ReleaseResponse(response)
@ -651,17 +649,17 @@ func setRequestBody(req *fasthttp.Request, body string) {
req.SetBody([]byte(body)) req.SetBody([]byte(body))
} }
// streamProgress displays the progress of a stream operation. // streamProgress streams the progress of a task to the console using a progress bar.
// It takes a wait group, the total number of items to process, a message to display, // It listens for increments on the provided channel and updates the progress bar accordingly.
// and a pointer to a slice of counts for each item processed. //
// The function runs in a separate goroutine and updates the progress bar until all items are processed. // The function will stop and mark the progress as errored if the context is cancelled.
// Once all items are processed, it marks the progress bar as done and stops rendering. // It will also stop and mark the progress as done when the total number of increments is reached.
func streamProgress( func streamProgress(
ctx context.Context, ctx context.Context,
wg *sync.WaitGroup, wg *sync.WaitGroup,
total int, total int64,
message string, message string,
countSlice *[]int, increase <-chan int64,
) { ) {
defer wg.Done() defer wg.Done()
pw := progress.NewWriter() pw := progress.NewWriter()
@ -672,33 +670,22 @@ func streamProgress(
go pw.Render() go pw.Render()
dodosTracker := progress.Tracker{ dodosTracker := progress.Tracker{
Message: message, Message: message,
Total: int64(total), Total: total,
} }
pw.AppendTracker(&dodosTracker) pw.AppendTracker(&dodosTracker)
for { for {
totalCount := 0 select {
for _, count := range *countSlice { case <-ctx.Done():
totalCount += count
}
dodosTracker.SetValue(int64(totalCount))
if ctx.Err() != nil {
fmt.Printf("\r") fmt.Printf("\r")
dodosTracker.MarkAsErrored() dodosTracker.MarkAsErrored()
time.Sleep(time.Millisecond * 300) time.Sleep(time.Millisecond * 300)
pw.Stop() pw.Stop()
return return
}
if totalCount == total { case value := <-increase:
break dodosTracker.Increment(value)
} }
time.Sleep(time.Millisecond * 200)
} }
fmt.Printf("\r")
dodosTracker.MarkAsDone()
time.Sleep(time.Millisecond * 300)
pw.Stop()
} }
// checkConnection checks the internet connection by making requests to different websites. // checkConnection checks the internet connection by making requests to different websites.