This article explains Java 8's Stream.flatMap and how to use it.
1. What is flatMap()?
1.1 Review the structure below. It is a two-level Stream, or a 2D array.
# Stream<String[]>
# Stream<Stream<String>>
# String[][]
[
[1, 2],
[3, 4],
[5, 6]
]
In Java 8, we can use flatMap to convert the above two-level Stream into a one-level Stream, or a 2D array into a 1D array.
# Stream<String>
# String[]
[1, 2, 3, 4, 5, 6]
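To make the two shapes concrete, here is a minimal sketch (the class and method names FlattenSketch, flatten, and countRowsWithMap are hypothetical) that contrasts map, which keeps one inner stream per row, with flatMap, which merges all rows into a single level. It uses Arrays::stream as the flattening function; Stream::of, used later in the article, behaves the same way for a String[].

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class FlattenSketch {
    // Flattens a String[][] such as [[1, 2], [3, 4], [5, 6]] into one List<String>.
    public static List<String> flatten(String[][] nested) {
        return Arrays.stream(nested)          // Stream<String[]>
                .flatMap(Arrays::stream)      // Stream<String>
                .collect(Collectors.toList());
    }

    // For contrast: map keeps the two levels, producing one inner stream per row.
    public static long countRowsWithMap(String[][] nested) {
        Stream<Stream<String>> twoLevels = Arrays.stream(nested).map(Arrays::stream);
        return twoLevels.count();
    }

    public static void main(String[] args) {
        String[][] nested = {{"1", "2"}, {"3", "4"}, {"5", "6"}};
        System.out.println(flatten(nested));          // [1, 2, 3, 4, 5, 6]
        System.out.println(countRowsWithMap(nested)); // 3
    }
}
```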
2. Why flatten a Stream?
2.1 It is challenging to process a Stream that contains more than one level, such as Stream<String[]>, Stream<List<LineItem>>, or Stream<Stream<String>>. So we flatten the two-level Stream into one level, such as Stream<String> or Stream<LineItem>, so that we can easily loop over the Stream and process it.
Review the examples below, before and after applying flatMap to a Stream.
2.2 Below is a 2D array. We can use Arrays.stream or Stream.of to convert it into a Stream, which produces a Stream of String[], that is, a Stream<String[]>.
String[][] array = new String[][]{{"a", "b"}, {"c", "d"}, {"e", "f"}};
// array to a stream
Stream<String[]> stream1 = Arrays.stream(array);
// same result
Stream<String[]> stream2 = Stream.of(array);
The resulting Stream<String[]> looks like this:
[
[a, b],
[c, d],
[e, f]
]
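A small sketch, with hypothetical class and method names, to check the element types: both calls yield three String[] elements, whereas passing a single one-level String[] to Stream.of spreads it into its individual strings, because Stream.of(T...) then binds T to String.

```java
import java.util.Arrays;
import java.util.stream.Stream;

public class ArrayToStreamSketch {
    // Arrays.stream(String[][]) yields one String[] element per row.
    public static long rowsViaArraysStream(String[][] array) {
        return Arrays.stream(array).count();
    }

    // Stream.of(String[][]) binds T to String[], so it also yields a Stream<String[]>.
    public static long rowsViaStreamOf(String[][] array) {
        return Stream.of(array).count();
    }

    // A single String[] row passed to Stream.of is spread into its elements: Stream<String>.
    public static long elementsInRow(String[] row) {
        return Stream.of(row).count();
    }

    public static void main(String[] args) {
        String[][] array = {{"a", "b"}, {"c", "d"}, {"e", "f"}};
        // Each element is still a String[], so Arrays.toString is needed to see its contents.
        Stream.of(array).forEach(row -> System.out.println(Arrays.toString(row)));
        System.out.println(rowsViaArraysStream(array)); // 3
        System.out.println(rowsViaStreamOf(array));     // 3
        System.out.println(elementsInRow(array[0]));    // 2
    }
}
```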
2.3 Here is the requirement: we want to filter out the a and print all the remaining characters.
First, we try Stream#filter directly. However, the program below does not do what we want: the x inside Stream#filter is a String[], not a String, so x.equals("a") is always false and the condition !x.equals("a") is always true. Nothing is filtered out; all three arrays are collected, and printing them shows array references instead of the letters.
String[][] array = new String[][]{{"a", "b"}, {"c", "d"}, {"e", "f"}};
// convert array to a stream
Stream<String[]> stream1 = Arrays.stream(array);
List<String[]> result = stream1
.filter(x -> !x.equals("a")) // x is a String[], not String! Always true
.collect(Collectors.toList());
System.out.println(result.size()); // 3
result.forEach(System.out::println); // prints references like [Ljava.lang.String;@<hash>, not [a, b]
OK, this time, we refactor the filter method to deal with the String[].
String[][] array = new String[][]{{"a", "b"}, {"c", "d"}, {"e", "f"}};
// array to a stream
Stream<String[]> stream1 = Arrays.stream(array);
// x is a String[]
List<String[]> result = stream1
.filter(x -> {
for(String s : x){ // really?
if(s.equals("a")){
return false;
}
}
return true;
}).collect(Collectors.toList());
// print array
result.forEach(x -> System.out.println(Arrays.toString(x)));
Output
[c, d]
[e, f]
In the above case, Stream#filter filters out the entire [a, b] array, but we want to filter out only the character a.
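If the goal were to keep the rows and drop only "a" from each one, a sketch like the following (hypothetical names) would filter inside every row with a nested stream. Note that the result is still two levels deep, a List<List<String>>, which is exactly the nesting that flatMap removes.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class FilterInsideRowsSketch {
    // Removes "a" from every row but keeps the two-level shape.
    public static List<List<String>> removeA(String[][] array) {
        return Arrays.stream(array)                          // Stream<String[]>
                .map(row -> Arrays.stream(row)               // Stream<String> for one row
                        .filter(s -> !"a".equals(s))         // drop "a" inside the row
                        .collect(Collectors.toList()))       // one List<String> per row
                .collect(Collectors.toList());               // List<List<String>>
    }

    public static void main(String[] args) {
        String[][] array = {{"a", "b"}, {"c", "d"}, {"e", "f"}};
        System.out.println(removeA(array)); // [[b], [c, d], [e, f]]
    }
}
```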
2.4 The fix is to flatten the array first and apply the filter afterward. In Java, to convert a 2D array into a 1D array, we can loop over the 2D array and copy all the elements into a new array, or we can use the Java 8 flatMap to flatten the 2D array into a 1D array, that is, from Stream<String[]> to Stream<String>.
String[][] array = new String[][]{{"a", "b"}, {"c", "d"}, {"e", "f"}};
// Java 8
String[] result = Stream.of(array) // Stream<String[]>
.flatMap(Stream::of) // Stream<String>
.toArray(String[]::new); // [a, b, c, d, e, f]
for (String s : result) {
System.out.println(s);
}
Output
a
b
c
d
e
f
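For comparison, the loop-based approach mentioned above might look like this minimal sketch (hypothetical class and method names); it produces the same [a, b, c, d, e, f] array without streams.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class LoopFlattenSketch {
    // Pre-Java 8 approach: copy every element of every row into one list, then into an array.
    public static String[] flattenWithLoop(String[][] array) {
        List<String> out = new ArrayList<>();
        for (String[] row : array) {
            for (String s : row) {
                out.add(s);
            }
        }
        return out.toArray(new String[0]);
    }

    public static void main(String[] args) {
        String[][] array = {{"a", "b"}, {"c", "d"}, {"e", "f"}};
        System.out.println(Arrays.toString(flattenWithLoop(array))); // [a, b, c, d, e, f]
    }
}
```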
Now we can easily filter out the a. Here is the final version.
String[][] array = new String[][]{{"a", "b"}, {"c", "d"}, {"e", "f"}};
List<String> collect = Stream.of(array) // Stream<String[]>
.flatMap(Stream::of) // Stream<String>
.filter(x -> !"a".equals(x)) // filter out the a
.collect(Collectors.toList()); // return a List
collect.forEach(System.out::println);
Output
b
c
d
e
f
I want to point out that dealing with a Stream of more than one level is challenging, confusing, and error-prone; we can use Stream#flatMap to flatten a two-level Stream into a one-level Stream.
Stream<String[]> -> flatMap -> Stream<String>
Stream<Set<String>> -> flatMap -> Stream<String>
Stream<List<String>> -> flatMap -> Stream<String>
Stream<List<Object>> -> flatMap -> Stream<Object>
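The same pattern applies to collections: a method reference such as List::stream or Set::stream turns each inner collection into a stream, and for Stream<Stream<String>> the identity lambda s -> s is enough. A minimal sketch with hypothetical names:

```java
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class FlattenCollectionsSketch {
    // Stream<List<String>> -> flatMap(List::stream) -> Stream<String>
    public static List<String> flattenLists(List<List<String>> lists) {
        return lists.stream().flatMap(List::stream).collect(Collectors.toList());
    }

    // Stream<Set<String>> -> flatMap(Set::stream) -> Stream<String>
    public static List<String> flattenSets(List<Set<String>> sets) {
        return sets.stream().flatMap(Set::stream).collect(Collectors.toList());
    }

    // Stream<Stream<String>> -> flatMap(s -> s) -> Stream<String>
    public static List<String> flattenStreams(Stream<Stream<String>> nested) {
        return nested.flatMap(s -> s).collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<List<String>> lists = Arrays.asList(Arrays.asList("a", "b"), Arrays.asList("c"));
        // LinkedHashSet keeps insertion order, so the printed order is predictable
        Set<String> first = new LinkedHashSet<>(Arrays.asList("x", "y"));
        Set<String> second = new LinkedHashSet<>(Arrays.asList("z"));
        List<Set<String>> sets = Arrays.asList(first, second);
        System.out.println(flattenLists(lists)); // [a, b, c]
        System.out.println(flattenSets(sets));   // [x, y, z]
        System.out.println(flattenStreams(Stream.of(Stream.of("p"), Stream.of("q", "r")))); // [p, q, r]
    }
}
```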
3. flatMap example – Find all books.
This example uses .stream() to convert a List into a stream of objects. Each object contains a set of books, and we use flatMap to produce a stream containing all the books from all the objects.
Finally, we filter out the books containing the word python and collect the result into a Set to remove duplicate books.
Developer.java
package com.favtuts.java8.stream.flatmap;
import java.util.HashSet;
import java.util.Set;
public class Developer {
public Developer(Integer id, String name, Set<String> book) {
this.id = id;
this.name = name;
this.book = book;
}
public Developer() {
}
private Integer id;
private String name;
private Set<String> book;
//getters, setters, toString
public Integer getId() {
return this.id;
}
public void setId(Integer id) {
this.id = id;
}
public String getName() {
return this.name;
}
public void setName(String name) {
this.name = name;
}
public Set<String> getBook() {
return this.book;
}
public void setBook(Set<String> book) {
this.book = book;
}
@Override
public String toString() {
return "{" +
" id='" + getId() + "'" +
", name='" + getName() + "'" +
", book='" + getBook() + "'" +
"}";
}
public void addBook(String book) {
if (this.book == null) {
this.book = new HashSet<>();
}
this.book.add(book);
}
}
FlatMapExample1.java
package com.favtuts.java8;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import com.favtuts.java8.stream.flatmap.Developer;
public class FlatMapExample1 {
public static void main(String[] args) {
filterBookOfDevelopers();
}
private static void filterBookOfDevelopers() {
Developer o1 = new Developer();
o1.setName("favtuts");
o1.addBook("Java 8 in Action");
o1.addBook("Spring Boot in Action");
o1.addBook("Effective Java (3rd Edition)");
Developer o2 = new Developer();
o2.setName("zilap");
o2.addBook("Learning Python, 5th Edition");
o2.addBook("Effective Java (3rd Edition)");
List<Developer> list = new ArrayList<>();
list.add(o1);
list.add(o2);
// hmm....Set of Set...how to process?
/*Set<Set<String>> collect = list.stream()
.map(x -> x.getBook())
.collect(Collectors.toSet());*/
Set<String> collect =
list.stream()
.map(x -> x.getBook()) // Stream<Set<String>>
.flatMap(x -> x.stream()) // Stream<String>
.filter(x -> !x.toLowerCase().contains("python")) // filter python book
.collect(Collectors.toSet()); // remove duplicated
collect.forEach(System.out::println);
}
}
Output (the order of a Set may vary)
Spring Boot in Action
Effective Java (3rd Edition)
Java 8 in Action
The map step is optional; we can call x.getBook().stream() directly inside flatMap.
Set<String> collect2 = list.stream()
//.map(x -> x.getBook())
.flatMap(x -> x.getBook().stream()) // Stream<String>
.filter(x -> !x.toLowerCase().contains("python")) // filter python book
.collect(Collectors.toSet());
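Collectors.toSet() makes no guarantee about ordering, so the printed order may differ from the output shown above. If a stable order matters, one alternative (a swapped-in technique, not the article's code) is distinct() followed by sorted(). In this sketch, each inner Set<String> stands in for one Developer's getBook() result, an assumption made to keep the example self-contained; the class and method names are hypothetical.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

public class DistinctBooksSketch {
    // Each inner Set stands in for one Developer's getBook() (assumption for a self-contained sketch).
    public static List<String> nonPythonBooks(List<Set<String>> booksPerDeveloper) {
        return booksPerDeveloper.stream()
                .flatMap(Set::stream)                              // Stream<String>
                .filter(b -> !b.toLowerCase().contains("python"))  // drop Python books
                .distinct()                                        // remove duplicates
                .sorted()                                          // deterministic, alphabetical order
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        Set<String> favtuts = new HashSet<>(Arrays.asList(
                "Java 8 in Action", "Spring Boot in Action", "Effective Java (3rd Edition)"));
        Set<String> zilap = new HashSet<>(Arrays.asList(
                "Learning Python, 5th Edition", "Effective Java (3rd Edition)"));
        // Effective Java (3rd Edition), Java 8 in Action, Spring Boot in Action
        nonPythonBooks(Arrays.asList(favtuts, zilap)).forEach(System.out::println);
    }
}
```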
4. flatMap example – Order and LineItems
This example is similar to the official flatMap JavaDoc example.
The orders list is turned into a stream of purchase orders, and each purchase order contains a collection of line items; we use flatMap to produce a single Stream<LineItem> containing all the line items from all the orders. We also add a reduce operation to sum the line items' total amounts.
FlatMapExample2.java
package com.favtuts.java8;
import java.math.BigDecimal;
import java.util.*;
import com.favtuts.java8.stream.flatmap.LineItem;
import com.favtuts.java8.stream.flatmap.Order;
public class FlatMapExample2 {
public static void main(String[] args) {
workOrderAndLineItems();
}
private static void workOrderAndLineItems() {
List<Order> orders = findAll();
/*
Stream<List<LineItem>> listStream = orders.stream()
.map(order -> order.getLineItems());
Stream<LineItem> lineItemStream = orders.stream()
.flatMap(order -> order.getLineItems().stream());
*/
// sum the line items' total amount
BigDecimal sumOfLineItems = orders.stream()
.flatMap(order -> order.getLineItems().stream()) // Stream<LineItem>
.map(line -> line.getTotal()) // Stream<BigDecimal>
.reduce(BigDecimal.ZERO, BigDecimal::add); // reduce to sum all
// sum the order's total amount
BigDecimal sumOfOrder = orders.stream()
.map(order -> order.getTotal()) // Stream<BigDecimal>
.reduce(BigDecimal.ZERO, BigDecimal::add); // reduce to sum all
System.out.println(sumOfLineItems); // 3194.20
System.out.println(sumOfOrder); // 3194.20
if (!sumOfOrder.equals(sumOfLineItems)) {
System.out.println("The sumOfOrder is not equals to sumOfLineItems!");
}
}
private static List<Order> findAll() {
LineItem item1 = new LineItem(1, "apple", 1, new BigDecimal("1.20"), new BigDecimal("1.20"));
LineItem item2 = new LineItem(2, "orange", 2, new BigDecimal(".50"), new BigDecimal("1.00"));
Order order1 = new Order(1, "A0000001", Arrays.asList(item1, item2), new BigDecimal("2.20"));
LineItem item3 = new LineItem(3, "monitor BenQ", 5, new BigDecimal("99.00"), new BigDecimal("495.00"));
LineItem item4 = new LineItem(4, "monitor LG", 10, new BigDecimal("120.00"), new BigDecimal("1200.00"));
Order order2 = new Order(2, "A0000002", Arrays.asList(item3, item4), new BigDecimal("1695.00"));
LineItem item5 = new LineItem(5, "One Plus 8T", 3, new BigDecimal("499.00"), new BigDecimal("1497.00"));
Order order3 = new Order(3, "A0000003", Arrays.asList(item5), new BigDecimal("1497.00"));
return Arrays.asList(order1, order2, order3);
}
}
Order.java
package com.favtuts.java8.stream.flatmap;
import java.math.BigDecimal;
import java.util.List;
public class Order {
public Order(Integer id, String invoice, List<LineItem> lineItems, BigDecimal total) {
this.id = id;
this.invoice = invoice;
this.lineItems = lineItems;
this.total = total;
}
public Order() {
}
private Integer id;
private String invoice;
private List<LineItem> lineItems;
private BigDecimal total;
// getter, setters, constructor
@Override
public String toString() {
return "Order{" +
" id='" + getId() + "'" +
", invoice='" + getInvoice() + "'" +
", lineItems='" + getLineItems() + "'" +
", total='" + getTotal() + "'" +
"}";
}
public Integer getId() {
return this.id;
}
public void setId(Integer id) {
this.id = id;
}
public String getInvoice() {
return this.invoice;
}
public void setInvoice(String invoice) {
this.invoice = invoice;
}
public List<LineItem> getLineItems() {
return this.lineItems;
}
public void setLineItems(List<LineItem> lineItems) {
this.lineItems = lineItems;
}
public BigDecimal getTotal() {
return this.total;
}
public void setTotal(BigDecimal total) {
this.total = total;
}
}
LineItem.java
package com.favtuts.java8.stream.flatmap;
import java.math.BigDecimal;
public class LineItem {
public LineItem(Integer id, String item, Integer qty, BigDecimal price, BigDecimal total) {
this.id = id;
this.item = item;
this.qty = qty;
this.price = price;
this.total = total;
}
public LineItem() {
}
private Integer id;
private String item;
private Integer qty;
private BigDecimal price;
private BigDecimal total;
// getter, setters, constructor
@Override
public String toString() {
return "LineItem{" +
" id='" + getId() + "'" +
", item='" + getItem() + "'" +
", qty='" + getQty() + "'" +
", price='" + getPrice() + "'" +
", total='" + getTotal() + "'" +
"}";
}
public Integer getId() {
return this.id;
}
public void setId(Integer id) {
this.id = id;
}
public String getItem() {
return this.item;
}
public void setItem(String item) {
this.item = item;
}
public Integer getQty() {
return this.qty;
}
public void setQty(Integer qty) {
this.qty = qty;
}
public BigDecimal getPrice() {
return this.price;
}
public void setPrice(BigDecimal price) {
this.price = price;
}
public BigDecimal getTotal() {
return this.total;
}
public void setTotal(BigDecimal total) {
this.total = total;
}
}
Output
3194.20
3194.20
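One detail worth knowing when comparing the two sums: BigDecimal.equals compares the scale as well as the value, so 3194.20 and 3194.2 are not equal, whereas compareTo compares only the numeric value. The example above works because both sums have scale 2. The sketch below (hypothetical names; each order is reduced to a list of its line-item totals taken from findAll(), an assumption to keep it self-contained) flattens the totals with flatMap and shows both comparisons.

```java
import java.math.BigDecimal;
import java.util.Arrays;
import java.util.List;

public class OrderTotalsSketch {
    // Line-item totals per order, copied from findAll(); a stand-in for List<Order> (assumption).
    public static List<List<BigDecimal>> sampleOrders() {
        return Arrays.asList(
                Arrays.asList(new BigDecimal("1.20"), new BigDecimal("1.00")),
                Arrays.asList(new BigDecimal("495.00"), new BigDecimal("1200.00")),
                Arrays.asList(new BigDecimal("1497.00")));
    }

    // Flatten every order's totals into one Stream<BigDecimal>, then add them up.
    public static BigDecimal sumAll(List<List<BigDecimal>> lineTotalsPerOrder) {
        return lineTotalsPerOrder.stream()
                .flatMap(List::stream)                       // Stream<BigDecimal>
                .reduce(BigDecimal.ZERO, BigDecimal::add);
    }

    // compareTo ignores scale, so 3194.20 and 3194.2 count as the same amount.
    public static boolean sameAmount(BigDecimal a, BigDecimal b) {
        return a.compareTo(b) == 0;
    }

    public static void main(String[] args) {
        BigDecimal sum = sumAll(sampleOrders());
        BigDecimal sameValueOtherScale = new BigDecimal("3194.2");
        System.out.println(sum);                                  // 3194.20
        System.out.println(sum.equals(sameValueOtherScale));      // false
        System.out.println(sameAmount(sum, sameValueOtherScale)); // true
    }
}
```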
5. flatMap example – Split lines by spaces
This example reads a text file, splits each line by spaces, and displays the total number of words.
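A text file (the code below reads it from /home/tvt/workspace/favtuts/flatmap.txt; on Windows, use a path such as c:\test\test.txt instead):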
hello world Java
hello world Python
hello world Node JS
hello world Rust
hello world Flutter
The comments in the code are self-explanatory.
FlatMapExample3.java
package com.favtuts.java8;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.stream.Stream;
public class FlatMapExample3 {
public static void main(String[] args) {
countTotalWordsTextFile();
}
private static void countTotalWordsTextFile() {
Path path = Paths.get("/home/tvt/workspace/favtuts/flatmap.txt");
// read file into a stream of lines; try-with-resources closes the underlying file
try (Stream<String> lines = Files.lines(path, StandardCharsets.UTF_8)) {
// a stream of arrays... hard to process.
// Stream<String[]> words = lines.map(line -> line.split(" +"));
// a stream of streams of strings... better to flatten to one level.
// Stream<Stream<String>> words = lines.map(line -> Stream.of(line.split(" +")));
// the result is a stream of words, good!
Stream<String> words = lines.flatMap(line -> Stream.of(line.split(" +")));
// count the number of words.
long noOfWords = words.count();
System.out.println(noOfWords); // 16
} catch (IOException e) {
e.printStackTrace();
}
}
}
Output
16
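Splitting with split(" +") yields an empty first token when a line starts with spaces, so such lines are over-counted. Below is a minimal sketch (hypothetical names; the lines are held in memory instead of read from a file) that trims each line, skips blank lines, and splits on the regex \s+ (written "\\s+" in Java), alongside the original approach for comparison.

```java
import java.util.stream.Stream;

public class WordCountSketch {
    // Same idea as the article: split on " +"; a leading space yields an empty first token.
    public static long naiveCount(Stream<String> lines) {
        return lines.flatMap(line -> Stream.of(line.split(" +"))).count();
    }

    // Trim first, skip blank lines, then split on any run of whitespace.
    public static long countWords(Stream<String> lines) {
        return lines
                .map(String::trim)
                .filter(line -> !line.isEmpty())
                .flatMap(line -> Stream.of(line.split("\\s+")))  // Stream<String> of words
                .count();
    }

    public static void main(String[] args) {
        Stream<String> sample = Stream.of(
                "hello world Java",
                "hello world Python",
                "hello world Node JS",
                "hello world Rust",
                "hello world Flutter");
        System.out.println(countWords(sample));                       // 16
        System.out.println(naiveCount(Stream.of("  hello   world "))); // 3 (includes an empty token)
        System.out.println(countWords(Stream.of("  hello   world "))); // 2
    }
}
```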
6. flatMap and primitive types
For primitive types such as int, long, and double, Java 8 Stream also provides the related flatMapTo{primitive type} methods (flatMapToInt, flatMapToLong, and flatMapToDouble) to flatten a Stream into a primitive stream; the concept is the same.
flatMapToInt -> IntStream
FlatMapExample4.java
package com.favtuts.java8;
import java.util.Arrays;
import java.util.stream.IntStream;
import java.util.stream.Stream;
public class FlatMapExample4 {
public static void main(String[] args) {
flatMapWithPrimitiveType();
}
private static void flatMapWithPrimitiveType() {
int[] array = {1, 2, 3, 4, 5, 6};
//Stream<int[]>
Stream<int[]> streamArray = Stream.of(array);
//Stream<int[]> -> flatMap -> IntStream
IntStream intStream = streamArray.flatMapToInt(x -> Arrays.stream(x));
intStream.forEach(x -> System.out.println(x));
}
}
Output
1
2
3
4
5
6
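The example above flattens a single int[], so the values come out unchanged. With a real 2D int array, flatMapToInt merges every row into one IntStream, which can then be summed or turned back into an int[]. A sketch with hypothetical names:

```java
import java.util.Arrays;
import java.util.stream.IntStream;

public class FlatMapToIntGridSketch {
    // int[][] -> Stream<int[]> -> flatMapToInt -> IntStream of every cell, then sum.
    public static int sumGrid(int[][] grid) {
        IntStream cells = Arrays.stream(grid).flatMapToInt(Arrays::stream);
        return cells.sum();
    }

    // Same flattening, collected back into a 1D int[].
    public static int[] flattenGrid(int[][] grid) {
        return Arrays.stream(grid).flatMapToInt(Arrays::stream).toArray();
    }

    public static void main(String[] args) {
        int[][] grid = {{1, 2}, {3, 4}, {5, 6}};
        System.out.println(Arrays.toString(flattenGrid(grid))); // [1, 2, 3, 4, 5, 6]
        System.out.println(sumGrid(grid));                      // 21
    }
}
```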
flatMapToLong -> LongStream
long[] array = {1, 2, 3, 4, 5, 6};
Stream<long[]> longArray = Stream.of(array);
LongStream longStream = longArray.flatMapToLong(x -> Arrays.stream(x));
System.out.println(longStream.count());
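flatMapToDouble follows the same pattern for double values. A minimal sketch, assuming a hypothetical double[][] of prices and a hypothetical class name:

```java
import java.util.Arrays;
import java.util.stream.DoubleStream;
import java.util.stream.Stream;

public class FlatMapToDoubleSketch {
    // Stream<double[]> -> flatMapToDouble -> DoubleStream
    public static DoubleStream flatten(double[][] rows) {
        Stream<double[]> stream = Arrays.stream(rows);
        return stream.flatMapToDouble(Arrays::stream);
    }

    public static void main(String[] args) {
        double[][] prices = {{1.5, 2.5}, {3.0}};
        // A stream can be consumed only once, so flatten again for each terminal operation.
        System.out.println(flatten(prices).sum());   // 7.0
        System.out.println(flatten(prices).count()); // 3
    }
}
```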
Download Source Code
$ git clone https://github.com/favtuts/java-core-tutorials-examples
$ cd java-basic/java8