Skip to content

Refactor sendAndReceive #64

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 7 additions & 12 deletions internal/stage_dtp2.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package internal
import (
"github.com/codecrafters-io/kafka-tester/internal/assertions"
"github.com/codecrafters-io/kafka-tester/internal/kafka_executable"
"github.com/codecrafters-io/kafka-tester/protocol"
kafkaapi "github.com/codecrafters-io/kafka-tester/protocol/api"
"github.com/codecrafters-io/kafka-tester/protocol/builder"
"github.com/codecrafters-io/kafka-tester/protocol/common"
"github.com/codecrafters-io/kafka-tester/protocol/kafka_client"
"github.com/codecrafters-io/kafka-tester/protocol/serializer"
"github.com/codecrafters-io/tester-utils/logger"
"github.com/codecrafters-io/tester-utils/test_case_harness"
Expand All @@ -25,13 +25,13 @@ func testDTPartitionWithUnknownTopic(stageHarness *test_case_harness.TestCaseHar
}

correlationId := getRandomCorrelationId()
broker := protocol.NewBroker("localhost:9092")
if err := broker.ConnectWithRetries(b, stageLogger); err != nil {
client := kafka_client.NewClient("localhost:9092")
if err := client.ConnectWithRetries(b, stageLogger); err != nil {
return err
}
defer func(broker *protocol.Broker) {
_ = broker.Close()
}(broker)
defer func(client *kafka_client.Client) {
_ = client.Close()
}(client)

request := kafkaapi.DescribeTopicPartitionsRequest{
Header: builder.NewRequestHeaderBuilder().BuildDescribeTopicPartitionsRequestHeader(correlationId),
Expand All @@ -45,15 +45,10 @@ func testDTPartitionWithUnknownTopic(stageHarness *test_case_harness.TestCaseHar
},
}

message := kafkaapi.EncodeDescribeTopicPartitionsRequest(&request)
stageLogger.Infof("Sending \"DescribeTopicPartitions\" (version: %v) request (Correlation id: %v)", request.Header.ApiVersion, request.Header.CorrelationId)
stageLogger.Debugf("Hexdump of sent \"DescribeTopicPartitions\" request: \n%v\n", GetFormattedHexdump(message))

response, err := broker.SendAndReceive(message)
response, err := client.SendAndReceive(request, stageLogger)
if err != nil {
return err
}
stageLogger.Debugf("Hexdump of received \"DescribeTopicPartitions\" response: \n%v\n", GetFormattedHexdump(response.RawBytes))

responseHeader, responseBody, err := kafkaapi.DecodeDescribeTopicPartitionsHeaderAndResponse(response.Payload, stageLogger)
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions internal/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ func getRandomCorrelationId() int32 {
return int32(random.RandomInt(0, math.MaxInt32-1))
}

// TODO: Remove in favor of protocol/utils.GetFormattedHexdump

func GetFormattedHexdump(data []byte) string {
// This is used for logs
// Contains headers + vertical & horizontal separators + offset
Expand Down
9 changes: 1 addition & 8 deletions protocol/api/describe_topic_partitions.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,12 @@ package kafkaapi
import (
"github.com/codecrafters-io/kafka-tester/protocol"
"github.com/codecrafters-io/kafka-tester/protocol/decoder"
"github.com/codecrafters-io/kafka-tester/protocol/encoder"
"github.com/codecrafters-io/kafka-tester/protocol/errors"
"github.com/codecrafters-io/tester-utils/logger"
)

func EncodeDescribeTopicPartitionsRequest(request *DescribeTopicPartitionsRequest) []byte {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will be removed.

encoder := encoder.Encoder{}
encoder.Init(make([]byte, 4096))

request.Encode(&encoder)
messageBytes := encoder.PackMessage()

return messageBytes
return request.Encode()
}

// DecodeDescribeTopicPartitionsHeaderAndResponse decodes the header and response
Expand Down
31 changes: 21 additions & 10 deletions protocol/api/describe_topic_partitions_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,6 @@ import (
"github.com/codecrafters-io/kafka-tester/protocol/encoder"
)

type DescribeTopicPartitionsRequest struct {
Header RequestHeader
Body DescribeTopicPartitionsRequestBody
}

func (r *DescribeTopicPartitionsRequest) Encode(pe *encoder.Encoder) {
r.Header.EncodeV2(pe)
r.Body.Encode(pe)
}

type DescribeTopicPartitionsRequestBody struct {
Topics []TopicName
ResponsePartitionLimit int32
Expand Down Expand Up @@ -61,3 +51,24 @@ func (c *Cursor) Encode(pe *encoder.Encoder) {
pe.PutInt32(c.PartitionIndex)
pe.PutEmptyTaggedFieldArray()
}

// DescribeTopicPartitionsRequest bundles a request header with a
// DescribeTopicPartitions request body so the pair can be encoded and sent
// as one wire message.
type DescribeTopicPartitionsRequest struct {
	Header RequestHeader
	Body   DescribeTopicPartitionsRequestBody
}

// Encode serializes the header followed by the body and packs the result
// into a single message via PackMessage.
func (r DescribeTopicPartitionsRequest) Encode() []byte {
	// Named `enc` so it does not shadow the imported `encoder` package.
	enc := encoder.Encoder{}
	enc.Init(make([]byte, 4096))

	r.Header.Encode(&enc)
	r.Body.Encode(&enc)

	return enc.PackMessage()
}

// GetHeader exposes the request header (API key, version, correlation id),
// which the client uses for logging when sending the request.
func (r DescribeTopicPartitionsRequest) GetHeader() RequestHeader {
	return r.Header
}
8 changes: 7 additions & 1 deletion protocol/api/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@ type RequestHeader struct {
ClientId string
}

func (h *RequestHeader) EncodeV2(enc *encoder.Encoder) {
// Encode serializes the request header; it currently always uses the v2
// header format (see EncodeV2).
func (h RequestHeader) Encode(enc *encoder.Encoder) {
	h.EncodeV2(enc)
}

// TODO: Don't expose this, only Encode

func (h RequestHeader) EncodeV2(enc *encoder.Encoder) {
enc.PutInt16(h.ApiKey)
enc.PutInt16(h.ApiVersion)
enc.PutInt32(h.CorrelationId)
Expand Down
8 changes: 8 additions & 0 deletions protocol/builder/interface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package builder

import kafkaapi "github.com/codecrafters-io/kafka-tester/protocol/api"

// RequestI is implemented by request types that can be sent over the wire:
// they serialize themselves to bytes and expose their request header
// (used by the client to log API name, version and correlation id).
type RequestI interface {
	Encode() []byte
	GetHeader() kafkaapi.RequestHeader
}
198 changes: 198 additions & 0 deletions protocol/kafka_client/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
package kafka_client

import (
"bytes"
"encoding/binary"
"fmt"
"io"
"net"
"time"

"github.com/codecrafters-io/kafka-tester/internal/kafka_executable"
"github.com/codecrafters-io/kafka-tester/protocol/builder"
"github.com/codecrafters-io/kafka-tester/protocol/utils"
"github.com/codecrafters-io/tester-utils/logger"
)

// Response holds a raw broker reply: RawBytes is the full frame including
// the 4-byte length prefix, Payload is the frame without that prefix.
type Response struct {
	RawBytes []byte
	Payload  []byte
}

// createFrom assembles a Response from the length prefix and the body bytes.
// The receiver itself is not read; the method returns a fresh value.
func (r *Response) createFrom(lengthResponse []byte, bodyResponse []byte) Response {
	raw := make([]byte, 0, len(lengthResponse)+len(bodyResponse))
	raw = append(raw, lengthResponse...)
	raw = append(raw, bodyResponse...)

	return Response{
		RawBytes: raw,
		Payload:  bodyResponse,
	}
}

// Client represents a single connection to the Kafka broker.
// NOTE(review): there is no synchronization around conn, so a Client should
// be used from a single goroutine at a time — the original claim of full
// concurrency safety does not hold.
type Client struct {
	id   int32
	addr string
	conn net.Conn
}

// NewClient creates and returns a Client targeting the given host:port address.
// This does not attempt to actually connect; call ConnectWithRetries for that.
func NewClient(addr string) *Client {
	return &Client{id: -1, addr: addr}
}

// ConnectWithRetries dials the broker, retrying once per second up to
// maxRetries times. It aborts early if the executable under test has exited,
// since a Kafka server is expected to keep running.
func (c *Client) ConnectWithRetries(executable *kafka_executable.KafkaExecutable, logger *logger.Logger) error {
	const maxRetries = 10
	logger.Debugf("Connecting to broker at: %s", c.addr)

	for retries := 0; ; retries++ {
		conn, err := net.Dial("tcp", c.addr)
		if err == nil {
			logger.Debugf("Connection to broker at %s successful", c.addr)
			c.conn = conn
			return nil
		}

		if retries >= maxRetries {
			logger.Infof("All retries failed. Exiting.")
			return err
		}

		if executable.HasExited() {
			return fmt.Errorf("Looks like your program has terminated. A Kafka server is expected to be a long-running process.")
		}

		// ToDo: per-retry "retrying in 1s" logging was disabled because it
		// made fixtures fail; errors in the first second stay silent.
		time.Sleep(1000 * time.Millisecond)
	}
}

// Close shuts down the underlying connection. It is a no-op on a Client that
// never connected (conn is nil), so callers can safely defer it regardless of
// whether the connect succeeded.
func (c *Client) Close() error {
	if c.conn == nil {
		return nil
	}
	if err := c.conn.Close(); err != nil {
		// Wrap with %w so callers can inspect the underlying error.
		return fmt.Errorf("failed to close connection to broker at %s: %w", c.addr, err)
	}
	return nil
}

// SendAndReceive encodes and sends the given request, logs hexdumps of both
// directions, and returns the broker's response frame.
func (c *Client) SendAndReceive(request builder.RequestI, stageLogger *logger.Logger) (Response, error) {
	header := request.GetHeader()
	apiName := utils.APIKeyToName(header.ApiKey)
	message := request.Encode()

	stageLogger.Infof("Sending \"%s\" (version: %v) request (Correlation id: %v)", apiName, header.ApiVersion, header.CorrelationId)
	stageLogger.Debugf("Hexdump of sent \"%s\" request: \n%v\n", apiName, utils.GetFormattedHexdump(message))

	if err := c.Send(message); err != nil {
		return Response{}, err
	}

	response, err := c.Receive()
	if err != nil {
		return Response{}, err
	}

	stageLogger.Debugf("Hexdump of received \"%s\" response: \n%v\n", apiName, utils.GetFormattedHexdump(response.RawBytes))

	return response, nil
}

// Send writes the message to the connection under a 100ms write deadline so a
// stuck broker cannot block the tester forever.
func (c *Client) Send(message []byte) error {
	if err := c.conn.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)); err != nil {
		return fmt.Errorf("failed to set write deadline: %v", err)
	}

	_, writeErr := c.conn.Write(message)

	// Clear the deadline so later writes on this connection are unaffected.
	c.conn.SetWriteDeadline(time.Time{})

	if writeErr == nil {
		return nil
	}
	if netErr, ok := writeErr.(net.Error); ok && netErr.Timeout() {
		return fmt.Errorf("write operation timed out")
	}
	return fmt.Errorf("error writing to connection: %v", writeErr)
}

// Receive reads one length-prefixed response from the connection.
//
// The 4-byte big-endian length prefix is read first (with no deadline, so
// this blocks until the broker starts responding — NOTE(review): confirm that
// is intentional), then the body is read under a 100ms deadline. If the body
// read times out, the partial bytes are returned with a nil error so callers
// can surface a more helpful decode error.
func (c *Client) Receive() (Response, error) {
	response := Response{}

	lengthResponse := make([]byte, 4) // length prefix
	_, err := io.ReadFull(c.conn, lengthResponse)
	if err != nil {
		return response, err
	}
	length := int32(binary.BigEndian.Uint32(lengthResponse))
	if length < 0 {
		// Guard: a corrupt prefix would otherwise make make([]byte, length) panic.
		return response, fmt.Errorf("received invalid message length: %d", length)
	}

	bodyResponse := make([]byte, length)

	// Set a deadline for the read operation
	err = c.conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
	if err != nil {
		return response, fmt.Errorf("failed to set read deadline: %v", err)
	}

	numBytesRead, err := io.ReadFull(c.conn, bodyResponse)

	// Reset the read deadline
	c.conn.SetReadDeadline(time.Time{})

	// Keep only the bytes actually read (matters on a partial, timed-out read).
	bodyResponse = bodyResponse[:numBytesRead]

	if err != nil {
		if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
			// If the read timed out, return the partial response we have so far
			// This way we can surface a better error message to help w debugging
			return response.createFrom(lengthResponse, bodyResponse), nil
		}
		return response, fmt.Errorf("error reading from connection: %v", err)
	}

	return response.createFrom(lengthResponse, bodyResponse), nil
}

// ReceiveRaw drains whatever bytes arrive on the connection within a 100ms
// window (capped at 1MB) and returns them without any framing or decoding.
func (c *Client) ReceiveRaw() ([]byte, error) {
	var buf bytes.Buffer

	// Set a deadline for the read operation
	if err := c.conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil {
		return nil, fmt.Errorf("failed to set read deadline: %v", err)
	}

	// Use a limited reader to prevent reading indefinitely
	limitedReader := io.LimitReader(c.conn, 1024*1024) // Limit to 1MB, adjust as needed
	_, err := io.Copy(&buf, limitedReader)

	// Reset the read deadline
	c.conn.SetReadDeadline(time.Time{})

	if err != nil && err != io.EOF {
		// A deadline expiry is the expected way for this read to end; anything
		// else is a real failure. (The original had an empty timeout branch.)
		if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
			return nil, fmt.Errorf("error reading from connection: %v", err)
		}
	}

	return buf.Bytes(), nil
}
Loading