How to Use TensorFlow with Java
Machine Learning is gaining popularity and usage around the globe. It has already drastically changed the way certain applications are built and will likely continue to be a huge (and increasing) part of our daily lives. There's no sugarcoating it: Machine Learning isn't simple. It can be daunting and seem very complex to many. Companies such as Google took it upon themselves to bring Machine Learning concepts closer to developers and help them gradually make their first steps. Thus, frameworks such as TensorFlow were born.
What is TensorFlow?
TensorFlow is an open-source Machine Learning framework developed by Google in Python and C++. It helps developers easily acquire data, prepare and train models, predict future states, and perform large-scale machine learning. With it, we can train and run deep neural networks which are most often used for Optical Character Recognition, Image Recognition/Classification, Natural Language Processing, etc.
Tensors and Operations
TensorFlow is based on computational graphs, which you can imagine as classic graphs with nodes and edges. Each node is referred to as an operation; it takes zero or more tensors in and produces zero or more tensors out. An operation can be as simple as basic addition, but it can also be very complex.
Tensors are depicted as the edges of the graph and are the core data unit. We perform different functions on these tensors as we feed them to operations. A tensor can have any number of dimensions, and the number of dimensions is referred to as its rank (scalar: rank 0, vector: rank 1, matrix: rank 2). Data flows through the computational graph as tensors, transformed by operations along the way, hence the name TensorFlow.
There are three main types of tensors: placeholders, variables, and constants.
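To make the idea of rank concrete, here is a minimal plain-Java sketch that computes the shape and rank of a scalar or a rectangular nested array. It does not use TensorFlow at all; the class name TensorRank and its methods are made up for this illustration.

```java
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.List;

class TensorRank {
    /** Returns the size of each dimension, e.g. [2, 3] for a 2x3 matrix. */
    static List<Integer> shapeOf(Object value) {
        List<Integer> shape = new ArrayList<>();
        Object current = value;
        // Descend one nesting level per dimension, assuming a rectangular array.
        while (current != null && current.getClass().isArray()) {
            int length = Array.getLength(current);
            shape.add(length);
            if (length == 0) {
                break; // nothing to descend into
            }
            current = Array.get(current, 0);
        }
        return shape;
    }

    /** The rank is simply the number of dimensions. */
    static int rankOf(Object value) {
        return shapeOf(value).size();
    }

    public static void main(String[] args) {
        System.out.println(rankOf(3.0f));                      // prints 0 (scalar)
        System.out.println(rankOf(new float[] {1f, 2f, 3f}));  // prints 1 (vector)
        System.out.println(rankOf(new float[2][3]));           // prints 2 (matrix)
        System.out.println(shapeOf(new float[2][3]));          // prints [2, 3]
    }
}
```

A scalar has an empty shape, which is why its rank is 0 rather than 1.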
Using Maven, installing TensorFlow is as easy as including the dependency:
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow</artifactId>
    <version>1.13.1</version>
</dependency>
If your device has a supported GPU, use these dependencies instead:
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>libtensorflow</artifactId>
    <version>1.13.1</version>
</dependency>
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>libtensorflow_jni_gpu</artifactId>
    <version>1.13.1</version>
</dependency>
You can check the version of TensorFlow currently installed by calling the static TensorFlow.version() method.
TensorFlow Java API
The Java API TensorFlow offers is contained within the org.tensorflow package. It's currently experimental, so it's not guaranteed to be stable.
Please note that the only fully supported language for TensorFlow is Python and that the Java API isn't nearly as functional.
It introduces us to new classes, an interface, an enum, and an exception.
The types introduced through the API are:
Graph: A data flow graph representing a TensorFlow computation
Operation: A Graph node that performs computation on Tensors
OperationBuilder: A builder class for Operations
Output: A symbolic handle to a tensor produced by an Operation
SavedModelBundle: Represents a model loaded from storage.
SavedModelBundle.Loader: Provides options for loading a SavedModel
Server: An in-process TensorFlow server, for use in distributed training
Session: Driver for Graph execution
Session.Run: Output tensors and metadata obtained when executing a session
Session.Runner: Run Operations and evaluate Tensors
Shape: The possibly partially known shape of a tensor produced by an operation
Tensor: A statically typed multi-dimensional array whose elements are of a type described by T
TensorFlow: Static utility methods describing the TensorFlow runtime
Tensors: Type-safe factory methods for creating Tensor objects
DataType: Represents the type of elements in a Tensor as an enum
Operand: Interface implemented by operands of a TensorFlow operation
TensorFlowException: Unchecked exception thrown when executing TensorFlow Graphs
If we compare all of this to the tf module in Python, there's an obvious difference. The Java API doesn't have nearly the same amount of functionality, at least for now.
As mentioned before, TensorFlow is based on computational graphs, where org.tensorflow.Graph is Java's implementation.
Note: Its instances are thread-safe, though we need to explicitly release resources used by the Graph after we're finished with it.
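One way to guarantee that release is try-with-resources, since Graph (like Session) implements AutoCloseable in the 1.x Java API. The sketch below uses a hypothetical stand-in class, NativeResource, which is not part of TensorFlow, so that it runs without the native library. With TensorFlow on the classpath, the same shape would be try (Graph graph = new Graph(); Session session = new Session(graph)) { ... }.

```java
import java.util.ArrayList;
import java.util.List;

class ReleaseOrderSketch {
    /** Hypothetical stand-in for a class that owns native memory, such as Graph or Session. */
    static final class NativeResource implements AutoCloseable {
        private final String name;
        private final List<String> log;

        NativeResource(String name, List<String> log) {
            this.name = name;
            this.log = log;
            log.add("open " + name);
        }

        @Override
        public void close() {
            // A real Graph or Session would free its native memory here.
            log.add("close " + name);
        }
    }

    /** Opens two resources, does some work, and returns the order of events. */
    static List<String> run() {
        List<String> log = new ArrayList<>();
        try (NativeResource graph = new NativeResource("graph", log);
             NativeResource session = new NativeResource("session", log)) {
            log.add("compute");
        } // resources are closed here, in reverse order of declaration
        return log;
    }

    public static void main(String[] args) {
        // prints: open graph, open session, compute, close session, close graph
        run().forEach(System.out::println);
    }
}
```

Because try-with-resources closes in reverse order, the session is released before the graph it depends on, and both are released even if the body throws.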
Let's start off with an empty graph:
Graph graph = new Graph();
This graph doesn't mean much yet, as it's empty. To do anything with it, we first need to load it up with operations. To do so, we use the opBuilder() method, which returns an OperationBuilder object that'll add the operation to our graph once we call the build() method.
Let's add a constant to our graph:
Operation x = graph.opBuilder("Const", "x")
    .setAttr("dtype", DataType.FLOAT)
    .setAttr("value", Tensor.create(3.0f))
    .build();
Placeholders are a "type" of variable that don't have a value at declaration. Their values will be assigned at a later date. This allows us to build graphs with operations without any actual data:
Operation y = graph.opBuilder("Placeholder", "y")
    .setAttr("dtype", DataType.FLOAT)
    .build();
And now finally, to round this off, we need to add certain functions. These could be as simple as multiplication, division, or addition, or as complex as matrix multiplications. The same as before, we define functions using the opBuilder() method:
Operation xy = graph.opBuilder("Mul", "xy")
    .addInput(x.output(0))
    .addInput(y.output(0))
    .build();
Note: We're using output(0) because an operation can have more than one output.
Sadly, the Java API doesn't yet include any tools that allow you to visualize graphs as you would in Python. When the Java API gets updated, so will this article.
As mentioned before, a Session is the driver for a Graph's execution. It encapsulates the environment in which Graphs are executed to compute tensors.
What this means is that the tensors in our graph that we constructed don't actually hold any value, as we didn't run the graph within a session.
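This deferred style of evaluation can be sketched in plain Java without TensorFlow. The toy class below, DeferredGraphSketch, is a hypothetical illustration and not the library's implementation: building the graph only records recipes, and values appear only when run is called with the fed inputs. In this sketch, a fed value takes precedence over a stored constant.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

class DeferredGraphSketch {
    // Each node is a recipe: given the fed values, produce a float.
    private final Map<String, Function<Map<String, Float>, Float>> nodes = new HashMap<>();

    void constant(String name, float value) {
        // Assumption of this sketch: a fed value overrides the constant.
        nodes.put(name, feeds -> feeds.getOrDefault(name, value));
    }

    void placeholder(String name) {
        nodes.put(name, feeds -> {
            Float fed = feeds.get(name);
            if (fed == null) {
                throw new IllegalStateException("No value fed for placeholder " + name);
            }
            return fed;
        });
    }

    void mul(String name, String left, String right) {
        // Inputs are looked up lazily, only when the node is evaluated.
        nodes.put(name, feeds -> nodes.get(left).apply(feeds) * nodes.get(right).apply(feeds));
    }

    /** Plays the role of a session run: evaluates one node with the given feeds. */
    float run(String fetch, Map<String, Float> feeds) {
        return nodes.get(fetch).apply(feeds);
    }

    /** Builds the same x * y graph used in this article. */
    static DeferredGraphSketch example() {
        DeferredGraphSketch graph = new DeferredGraphSketch();
        graph.constant("x", 3.0f);
        graph.placeholder("y");
        graph.mul("xy", "x", "y");
        return graph;
    }

    public static void main(String[] args) {
        DeferredGraphSketch graph = example();
        Map<String, Float> feeds = new HashMap<>();
        feeds.put("y", 2.0f);
        System.out.println(graph.run("xy", feeds)); // prints 6.0 (3.0 * 2.0)
        feeds.put("x", 5.0f);
        System.out.println(graph.run("xy", feeds)); // prints 10.0 (5.0 * 2.0)
    }
}
```

Running the toy graph without feeding y throws an exception, which mirrors the point above: until values are supplied at run time, the graph describes a computation but holds no results.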
Let's first add the graph to a session:
Session session = new Session(graph);
Our computation simply multiplies the x and y values. In order to run our graph and compute it, we fetch the xy operation and feed it the x and y values:
Tensor tensor = session.runner()
    .fetch("xy")
    .feed("x", Tensor.create(5.0f))
    .feed("y", Tensor.create(2.0f))
    .run()
    .get(0);
System.out.println(tensor.floatValue());
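Running this piece of code will yield 10.0, since the value fed for x (5.0) takes the place of the constant 3.0 defined earlier and is multiplied by the 2.0 fed for y.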
Saving Models in Python and Loading in Java
This may sound a bit odd, but since Python is the only well-supported language, the Java API still doesn't have the functionality to save models.
This means that the Java API is meant only for the serving use-case, at least until it's fully supported by TensorFlow. Fortunately, we can train and save models in Python and then load them in Java to serve them, using the SavedModelBundle class:
SavedModelBundle model = SavedModelBundle.load("./model", "serve");
Tensor tensor = model.session().runner()
    .fetch("xy")
    .feed("x", Tensor.create(5.0f))
    .feed("y", Tensor.create(2.0f))
    .run()
    .get(0);
System.out.println(tensor.floatValue());