diff --git a/spring-kafka/src/main/java/com/baeldung/partitioningstrategy/CustomPartitioner.java b/spring-kafka/src/main/java/com/baeldung/partitioningstrategy/CustomPartitioner.java new file mode 100644 index 0000000000..f02a91f6fc --- /dev/null +++ b/spring-kafka/src/main/java/com/baeldung/partitioningstrategy/CustomPartitioner.java @@ -0,0 +1,32 @@ +package com.baeldung.partitioningstrategy; + +import java.util.Map; + +import org.apache.kafka.clients.producer.Partitioner; +import org.apache.kafka.common.Cluster; + +public class CustomPartitioner implements Partitioner { + private static final int PREMIUM_PARTITION = 0; + private static final int NORMAL_PARTITION = 1; + + @Override + public int partition(String topic, Object key, byte[] keyBytes, Object value, byte[] valueBytes, Cluster cluster) { + String customerType = extractCustomerType(key.toString()); + return "premium".equalsIgnoreCase(customerType) ? PREMIUM_PARTITION : NORMAL_PARTITION; + } + + private String extractCustomerType(String key) { + String[] parts = key.split("_"); + return parts.length > 1 ? parts[1] : "normal"; + } + + @Override + public void configure(Map configs) { + + } + + @Override + public void close() { + + } +} diff --git a/spring-kafka/src/main/java/com/baeldung/partitioningstrategy/KafkaApplication.java b/spring-kafka/src/main/java/com/baeldung/partitioningstrategy/KafkaApplication.java new file mode 100644 index 0000000000..57ecc5e187 --- /dev/null +++ b/spring-kafka/src/main/java/com/baeldung/partitioningstrategy/KafkaApplication.java @@ -0,0 +1,48 @@ +package com.baeldung.partitioningstrategy; + +import java.util.HashMap; +import java.util.Map; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.context.annotation.Bean; +import org.springframework.kafka.core.DefaultKafkaProducerFactory; +import org.springframework.kafka.core.KafkaTemplate; +import org.springframework.kafka.core.ProducerFactory; + +@SpringBootApplication +public class KafkaApplication { + + @Bean + public KafkaTemplate kafkaTemplate() { + return new KafkaTemplate<>(producerFactory()); + } + + @Bean + public ProducerFactory producerFactory() { + Map configProps = new HashMap<>(); + configProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + configProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + configProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + return new DefaultKafkaProducerFactory<>(configProps); + } + + @Bean + public KafkaConsumer kafkaConsumer() { + Map configProps = new HashMap<>(); + configProps.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + configProps.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + configProps.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + configProps.put(ConsumerConfig.GROUP_ID_CONFIG, "test-group"); // Set a unique group ID + return new KafkaConsumer<>(configProps); + } + + public static void main(String[] args) { + SpringApplication.run(KafkaApplication.class, args); + } +} diff --git a/spring-kafka/src/main/java/com/baeldung/partitioningstrategy/KafkaMessageConsumer.java b/spring-kafka/src/main/java/com/baeldung/partitioningstrategy/KafkaMessageConsumer.java new file mode 100644 index 0000000000..a500e04737 --- /dev/null +++ b/spring-kafka/src/main/java/com/baeldung/partitioningstrategy/KafkaMessageConsumer.java @@ -0,0 +1,32 @@ +package com.baeldung.partitioningstrategy; + +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.List; + +import org.springframework.kafka.annotation.KafkaListener; +import org.springframework.messaging.handler.annotation.Header; +import org.springframework.messaging.handler.annotation.Payload; +import org.springframework.stereotype.Service; +import org.springframework.kafka.support.KafkaHeaders; + +import jakarta.annotation.Nullable; + +@Service +public class KafkaMessageConsumer { + private final List receivedMessages = new CopyOnWriteArrayList<>(); + + @KafkaListener(topics = { "order-topic", "default-topic" }, groupId = "test-group") + public void listen(@Payload String message, @Header(KafkaHeaders.RECEIVED_PARTITION) int partition, @Header(KafkaHeaders.RECEIVED_KEY) @Nullable String key) { + ReceivedMessage receivedMessage = new ReceivedMessage(key, message, partition); + System.out.println("Received message: " + receivedMessage); + receivedMessages.add(receivedMessage); + } + + public List getReceivedMessages() { + return receivedMessages; + } + + public void clearReceivedMessages() { + receivedMessages.clear(); + } +} diff --git a/spring-kafka/src/main/java/com/baeldung/partitioningstrategy/ReceivedMessage.java b/spring-kafka/src/main/java/com/baeldung/partitioningstrategy/ReceivedMessage.java new file mode 100644 index 0000000000..a262f62e39 --- /dev/null +++ b/spring-kafka/src/main/java/com/baeldung/partitioningstrategy/ReceivedMessage.java @@ -0,0 +1,30 @@ +package com.baeldung.partitioningstrategy; + +public class ReceivedMessage { + private final String key; + private final String message; + private final int partition; + + public ReceivedMessage(String key, String message, int partition) { + this.key = key; + this.message = message; + this.partition = partition; + } + + @Override + public String toString() { + return "Key: " + key + " - Message: " + message + " - Partition: " + partition; + } + + public String getKey() { + return key; + } + + public String getMessage() { + return message; + } + + public int getPartition() { + return partition; + } +} diff --git a/spring-kafka/src/test/java/com/baeldung/partitioningstrategy/KafkaApplicationIntegrationTest.java b/spring-kafka/src/test/java/com/baeldung/partitioningstrategy/KafkaApplicationIntegrationTest.java new file mode 100644 index 0000000000..2f2cccbb12 --- /dev/null +++ b/spring-kafka/src/test/java/com/baeldung/partitioningstrategy/KafkaApplicationIntegrationTest.java @@ -0,0 +1,193 @@ +package com.baeldung.partitioningstrategy; + +import static org.junit.Assert.assertEquals; + +import java.time.Duration; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.StringSerializer; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.kafka.core.DefaultKafkaProducerFactory; +import org.springframework.kafka.core.KafkaTemplate; +import org.springframework.kafka.test.EmbeddedKafkaBroker; +import org.springframework.kafka.test.context.EmbeddedKafka; +import org.springframework.kafka.test.utils.KafkaTestUtils; + +import static org.awaitility.Awaitility.await; +import static java.util.concurrent.TimeUnit.SECONDS; + +@SpringBootTest +@EmbeddedKafka(partitions = 3, brokerProperties = { "listeners=PLAINTEXT://localhost:9092" }) +public class KafkaApplicationIntegrationTest { + + @Autowired + private KafkaTemplate kafkaTemplate; + + @Autowired + private KafkaMessageConsumer kafkaMessageConsumer; + + @Autowired + private EmbeddedKafkaBroker embeddedKafkaBroker; + + @Autowired + private Consumer consumer; + + @BeforeEach + public void clearMessages() { + kafkaMessageConsumer.clearReceivedMessages(); + } + + @Test + public void givenDefaultPartitioner_whenSendingMessagesWithoutKey_shouldUseStickyDistribution() throws InterruptedException { + kafkaTemplate.send("default-topic", "message1"); + kafkaTemplate.send("default-topic", "message2"); + kafkaTemplate.send("default-topic", "message3"); + + await().atMost(2, SECONDS) + .until(() -> kafkaMessageConsumer.getReceivedMessages() + .size() >= 3); + + List records = kafkaMessageConsumer.getReceivedMessages(); + + Set uniquePartitions = records.stream() + .map(ReceivedMessage::getPartition) + .collect(Collectors.toSet()); + + assertEquals(1, uniquePartitions.size()); + } + + @Test + void givenProducerWithSameKeyMessages_whenSendingMessages_shouldRouteToSamePartition() throws InterruptedException { + kafkaTemplate.send("order-topic", "partitionA", "critical data"); + kafkaTemplate.send("order-topic", "partitionA", "more critical data"); + kafkaTemplate.send("order-topic", "partitionB", "another critical message"); + kafkaTemplate.send("order-topic", "partitionA", "another more critical data"); + + await().atMost(2, SECONDS) + .until(() -> kafkaMessageConsumer.getReceivedMessages() + .size() >= 4); + + List records = kafkaMessageConsumer.getReceivedMessages(); + Map> messagesByKey = groupMessagesByKey(records); + + messagesByKey.forEach((key, messages) -> { + int expectedPartition = messages.get(0) + .getPartition(); + for (ReceivedMessage message : messages) { + assertEquals("Messages with key '" + key + "' should be in the same partition", message.getPartition(), expectedPartition); + } + }); + } + + @Test + public void givenProducerWithSameKeyMessages_whenSendingMessages_shouldReceiveInProducedOrder() throws InterruptedException { + kafkaTemplate.send("order-topic", "partitionA", "message1"); + kafkaTemplate.send("order-topic", "partitionA", "message3"); + kafkaTemplate.send("order-topic", "partitionA", "message4"); + + await().atMost(2, SECONDS) + .until(() -> kafkaMessageConsumer.getReceivedMessages() + .size() >= 3); + + List records = kafkaMessageConsumer.getReceivedMessages(); + + StringBuilder resultMessage = new StringBuilder(); + records.forEach(record -> resultMessage.append(record.getMessage())); + String expectedMessage = "message1message3message4"; + + assertEquals("Messages with the same key should be received in the order they were produced within a partition", expectedMessage, + resultMessage.toString()); + } + + @Test + public void givenCustomPartitioner_whenSendingMessages_shouldRouteToCorrectPartition() throws InterruptedException { + // Configure the producer with the custom partitioner + KafkaTemplate kafkaTemplate = setProducerToUseCustomPartitioner(); + + kafkaTemplate.send("order-topic", "123_premium", "Order 123, Premium order message"); + kafkaTemplate.send("order-topic", "456_normal", "Normal order message"); + + await().atMost(2, SECONDS) + .until(() -> kafkaMessageConsumer.getReceivedMessages() + .size() >= 2); + + List records = kafkaMessageConsumer.getReceivedMessages(); + + // Validate that messages are routed to the correct partition based on customer type + for (ReceivedMessage record : records) { + if ("123_premium".equals(record.getKey())) { + assertEquals("Premium order message should be in partition 0", 0, record.getPartition()); + } else if ("456_normal".equals(record.getKey())) { + assertEquals("Normal order message should be in partition 1", 1, record.getPartition()); + } + } + } + + @Test + public void givenDirectPartitionAssignment_whenSendingMessages_shouldRouteToSpecifiedPartitions() throws Exception { + kafkaTemplate.send("order-topic", 0, "123_premium", "Premium order message"); + kafkaTemplate.send("order-topic", 1, "456_normal", "Normal order message"); + + await().atMost(2, SECONDS) + .until(() -> kafkaMessageConsumer.getReceivedMessages() + .size() >= 2); + + List records = kafkaMessageConsumer.getReceivedMessages(); + + for (ReceivedMessage record : records) { + if ("123_premium".equals(record.getKey())) { + assertEquals("Premium order message should be in partition 0", 0, record.getPartition()); + } else if ("456_normal".equals(record.getKey())) { + assertEquals("Normal order message should be in partition 1", 1, record.getPartition()); + } + } + } + + @Test + public void givenCustomPartitioner_whenSendingMessages_shouldConsumeOnlyFromSpecificPartition() throws InterruptedException { + KafkaTemplate kafkaTemplate = setProducerToUseCustomPartitioner(); + + kafkaTemplate.send("order-topic", "123_premium", "Order 123, Premium order message"); + kafkaTemplate.send("order-topic", "456_normal", "Normal order message"); + + await().atMost(2, SECONDS) + .until(() -> kafkaMessageConsumer.getReceivedMessages() + .size() >= 2); + + consumer.assign(Collections.singletonList(new TopicPartition("order-topic", 0))); + ConsumerRecords records = consumer.poll(Duration.ofMillis(100)); + + for (ConsumerRecord record : records) { + assertEquals("Premium order message should be in partition 0", 0, record.partition()); + assertEquals("123_premium", record.key()); + } + } + + private KafkaTemplate setProducerToUseCustomPartitioner() { + Map producerProps = KafkaTestUtils.producerProps(embeddedKafkaBroker.getBrokersAsString()); + producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + producerProps.put(ProducerConfig.PARTITIONER_CLASS_CONFIG, CustomPartitioner.class.getName()); + DefaultKafkaProducerFactory producerFactory = new DefaultKafkaProducerFactory<>(producerProps); + + return new KafkaTemplate<>(producerFactory); + } + + private Map> groupMessagesByKey(List messages) { + return messages.stream() + .filter(message -> message.getKey() != null) + .collect(Collectors.groupingBy(ReceivedMessage::getKey)); + } +}