A machine learning approach to segmentation of neurons
For my senior thesis at Princeton, I got involved with an ambitious open project in neuroscience - modeling the brain as a network of connections. It's a strongly held theory that brains, as well as the neural networks they inspire, perform learning by changing the strengths of the connections between individual neurons. By reinforcing connections, new patterns of network behavior become engrained in the system. Several major figures in the neuroscience academic community, including my adviser Sebastian Seung, have great hope for a method of study they call connectomics. The hope is that by analyzing the connection matrix between the neurons in the brain, we can begin to understand what is happening at a network level as memories are formed and skills are learned.
However, in order to study which neurons connect to which others, one must first gather data about these connections, which is no mean feat. The ends of dendritic branches are only nanometers in diameter, smaller than the wavelengths of light that would be used to see them with a traditional light microscope. By imaging a cube of tissue with an electron microscope and successively slicing away the face, scientists can capture a three-dimensional image of a cube of brain tissue. The remaining challenge is then to reconstruct the three-dimensional structure of the neurons from the serial 2D images. This process is beautifully illustrated in the short video below.
When I decided to work on this project for my thesis, I realized that as a single individual working for less than a year I could make the greatest impact by contributing to a larger project rather than attempting to build from scratch. I also wanted to find a standardized way to compare my results to others and try to improve on the state of the art. Early on, I found the SNEMI-3D contest, whose name is an acronym for Segmentation of Neurons from Electron Microscope Images in 3D. The challenge, which can be found on the web here, provides datasets, a success metric, and a leaderboard. I was drawn to the work of the team from Janelia Farm, then narrowly in second place, because their approach was well documented in a paper published in PLoS One in 2013, and the entire project was available as open source on GitHub.
The SNEMI-3D challenge datasets include a training volume and a test volume. The training volume has three channels, which are depicted below: the original grayscale electron microscope pixels, the ground truth labeling, and a version of the volume in which each value represents the probability that the pixel is part of a cell membrane. The membrane probabilities provide critical information for finding borders between the neurite cell structures. In the image below, the raw microscope data is shown first, the membrane probabilities second, and the labels last.
The Janelia project is known as Gala, which stands for Graph-based Active Learning of Agglomeration. The approach to the neuron segmentation problem is to begin by oversegmenting the target volume into superpixels. Each superpixel is then represented as a vertex in a graph, with graph edges between all superpixels that touch on at least one pixel. The algorithm then proceeds to consider each edge in order of the likelihood that the two superpixels it touches should be part of the same neurite in the final segmentation. The properties of the edge and of the two superpixels are added to the set of training examples, with the label of "good merge" or "bad merge" derived from a comparison with a ground truth segmentation of the volume. If the two superpixels are part of the same neurite in the ground truth, then they are merged in the graph. As the algorithm continues, it builds a list of labeled training examples until there are no more valid merges to make in the graph. At that point, Gala uses the labeled examples to train a binary classifier such as a random forest or support vector machine to discriminate between merges which should and should not be made. This constitutes one epoch, and in the next epoch, the ratings from the newly trained classifier are used to score the edges for potential merges.
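To make the shape of one training epoch concrete, here is a minimal Python sketch. It is not Gala's code: the function names, the tuple layout of `edges`, and the one-element feature vector are all assumptions for illustration, and it simplifies heavily by assuming each superpixel belongs to exactly one ground truth neurite and by never re-scoring edges after a merge.

```python
def find(parent, x):
    """Union-find lookup with path halving: the current segment of superpixel x."""
    while parent[x] != x:
        parent[x] = parent[parent[x]]
        x = parent[x]
    return x


def training_epoch(edges, truth):
    """Run one simplified epoch of merge-and-label.

    edges: list of (score, u, v), where a lower score means "more likely
           to belong together" (hypothetical layout).
    truth: dict mapping superpixel id -> ground truth neurite id.
    Returns (examples, parent), where examples is a list of
    (feature_vector, is_good_merge) pairs.
    """
    parent = {sp: sp for sp in truth}
    examples = []
    for score, u, v in sorted(edges):
        ru, rv = find(parent, u), find(parent, v)
        if ru == rv:
            continue  # already merged through an earlier edge
        good = truth[ru] == truth[rv]
        # Placeholder feature vector; a real one would describe the edge
        # and both superpixels.
        examples.append(((score,), good))
        if good:
            parent[rv] = ru  # only merges confirmed by ground truth are applied
    return examples, parent


edges = [(0.1, 1, 2), (0.9, 2, 3), (0.2, 3, 4)]
truth = {1: "a", 2: "a", 3: "b", 4: "b"}
examples, parent = training_epoch(edges, truth)
print(examples)
```

The labeled `examples` are what would then be handed to a binary classifier, whose predictions replace `score` in the next epoch.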
This formulation turns the classic computer vision problem of image segmentation into one of machine learning and binary classification. The effectiveness of supervised machine learning applications is often determined by the degree to which the classification targets (in this case potential merges) are well described by the feature vectors used to represent them to the learning algorithm. In the original Gala paper, the authors mention that developing more advanced features is a great source for potential improvement.
Of the months that I spent trying to improve the results of Gala on the SNEMI3D challenge, I spent the vast bulk of my time devising, implementing, and testing new features that could be fed to the random forest classifier I used as the learning engine. I ended up making three solid contributions to Gala - I created one genuinely useful feature, I made a substantial decrease in the extreme memory and time requirements of the software, and I ultimately boosted the score on SNEMI3D to the top of the leaderboard by changing the parameters used to create the initial oversegmentation of superpixels in each 2D image.
Improving resource usage
Working with Gala was made challenging by its intense demands for working memory. Training and classifying on the full SNEMI3D dataset took 14 hours and sustained usage of 30-90 GB of RAM. The cluster computers I worked on had 128 GB of RAM available, but because they were shared machines, full size runs were often killed by the system for hogging nearly all the memory. The bulk of the memory was used to store the graph data structure, which manages the superpixels and their connections. The most dramatic improvements could be realized only by vectorizing the graph into matrices and ditching the mildly bloated NetworkX graph library we were using, so I took the simpler route of decreasing the data stored on each node in the graph. Each superpixel originally stored the ids of every single pixel it contained; I removed these lists and wrote an optimized Cython implementation of flood fill to identify the pixels again when necessary. This optimization decreased peak memory usage by more than 20% and gave me the buffer I needed to run Gala comfortably on the computers I had available.
I also profiled the code to identify bottlenecks in the run time. By simply replacing workhorse functions with compiled Cython implementations that ran hundreds of times faster than the original interpreted Python code, I was able to decrease the entire runtime by 30%.
These improvements are summarized in the graph above, which plots memory usage against time for runs before and after optimization. Note that this data is from a run training on only one quarter of the training data, so both the running time and the memory used were four times greater for the full training.
Developing new features
The most obvious way to improve the accuracy of Gala was to improve the feature descriptors of the superpixels to be merged and the area of their contact to improve classification of good and bad merges. Despite spending the bulk of my time experimenting with new features, I was able to produce only one feature, the 'contact' feature, that improved classification accuracy across the board.
I designed the contact feature after I discovered that the most important existing features were those that quantified the distribution across membrane probabilities of the pixels at the boundary between the two candidate neurites. The contact feature takes that a step further by augmenting the feature vector with these values:
- the fraction of the current neurite pixels that are touching the other neurite. Thin pinched channels have an exceptionally low value for this.
- for each of a set of thresholds, the fraction of the pixels in the contact region that have a membrane probability below the threshold, i.e. are less likely to be membrane.
- for each threshold, the fraction from the previous item, normalized by the fraction of the segment volume that the contact represents.
After training on SNEMI3D, the random forest weighted these values to have the second highest importance. They elucidate whether the two neurites are connected at their tips or along a broadside, and they proved very effective in cross validation.
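The three groups of values in the contact feature can be sketched as follows. This is an illustrative reading of the description above, not the code in Gala: the function name, the threshold values, and the interpretation of "normalized by" as plain division are assumptions.

```python
def contact_features(contact_probs, segment_size, thresholds=(0.1, 0.5, 0.9)):
    """Compute contact-style features for one neurite at a candidate merge.

    contact_probs: membrane probabilities of the pixels in the contact region.
    segment_size:  total number of pixels in the neurite being described.
    thresholds:    hypothetical threshold values.
    """
    n = len(contact_probs)
    if n == 0 or segment_size <= 0:
        raise ValueError("need a non-empty contact region and a positive segment size")
    # Fraction of the neurite's pixels that touch the other neurite.
    contact_fraction = n / segment_size
    # Fraction of contact pixels below each threshold (less likely to be membrane).
    below = [sum(p < t for p in contact_probs) / n for t in thresholds]
    # The same fractions divided by the contact fraction (assumed normalization).
    normalized = [b / contact_fraction for b in below]
    return [contact_fraction] + below + normalized


features = contact_features([0.05, 0.2, 0.6, 0.95], segment_size=40)
print(features)
```

In practice the values would be computed for both neurites at a candidate merge and appended to the existing feature vector.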
I spent extensive time attempting to create features that leveraged information about the shape of the superpixels. The 'skeleton' feature preserved and periodically merged the centroids of each merged superpixel to maintain the 'spine' of each neurite, and added information about the angles between neurite skeletons at a merge point. This feature is illustrated in the image below, with two example merged neurites and their representative skeletons. Similarly, the 'direction' feature took a random sample of pixels local to a merge point and compared the angles between their principal components. The direction feature made accuracy worse, and the skeleton feature fared similarly, even with extensive tweaking of parameters.
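For the direction feature, the core computation is the angle between the principal axes of two pixel samples. The sketch below is a simplified 2D version under stated assumptions: it skips the random sampling step, works in a single plane rather than in the 3D volume, and uses the closed-form orientation of the major axis of a 2x2 covariance matrix. The function names are hypothetical.

```python
import math


def principal_angle(points):
    """Orientation (radians) of the first principal component of 2D points."""
    n = len(points)
    mean_x = sum(x for x, _ in points) / n
    mean_y = sum(y for _, y in points) / n
    sxx = sum((x - mean_x) ** 2 for x, _ in points) / n
    syy = sum((y - mean_y) ** 2 for _, y in points) / n
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in points) / n
    # Major-axis orientation of the covariance matrix [[sxx, sxy], [sxy, syy]].
    # For a perfectly isotropic cloud this returns 0.0, which is arbitrary.
    return 0.5 * math.atan2(2.0 * sxy, sxx - syy)


def angle_between_axes(theta_a, theta_b):
    """Smallest angle between two undirected axes, in the range [0, pi/2]."""
    diff = abs(theta_a - theta_b) % math.pi
    return min(diff, math.pi - diff)


horizontal = [(0, 0), (1, 0), (2, 0), (3, 0)]
vertical = [(0, 0), (0, 1), (0, 2)]
diagonal = [(0, 0), (1, 1), (2, 2)]
print(angle_between_axes(principal_angle(horizontal), principal_angle(diagonal)))
```

Because a principal axis has no preferred direction, the angle is folded into the range from 0 to pi/2 before it is used as a feature.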
I believe that there were two major factors holding back the effectiveness of shape features. First is that the SNEMI3D dataset is anisotropic, meaning that the distance between two pixels in an image is not the same as the distance between two images in the stack. This means that the original superpixels are pancakes, and that it is only after several merges that they begin to take on a discernible shape. This plays into the second problem, that the changing shape and character of the neurites as more superpixels are merged together means that certain values for the feature can indicate a good merge at one scale and a bad merge at another. Incorporating information into the feature vector about the number of merges already undergone by the neurite was not an effective way to combat this.
I also tried a host of other features, including adding a channel for the gradient magnitude of the membrane probability, and computing the original features only on pixels in a certain radius from the merge point, but was unable to get workable results from them. I considered using convolutional neural networks to auto-learn features, but decided not to pursue that route after neural networks were handily outperformed by hand designed features in this paper by Bogovic et al. and because getting such networks running would bring a whole new host of engineering challenges.
Though I was able to improve cross-validation accuracy with the contact feature, it produced only a tiny improvement in accuracy on the test set after training on the whole training set. What ultimately put my results over the top was changing the parameters of the oversegmentation. The oversegmentation generates the starting superpixels that are agglomerated to form neurites. Though I was using the same oversegmentation parameters that the creators of Gala used for their SNEMI3D submission, I realized that some of the most costly mistakes, in which two different neurites were mistakenly merged into one, had their root in the oversegmentation. If slices of two different neurites are grouped into the same superpixel, then as each is agglomerated, the two neurites end up conjoined. I changed the oversegmentation parameters to favor even smaller superpixels, then re-ran my experiments. I got dramatically improved results both in cross-validation and on the test set. I was encouraged to see that the contact feature made a substantial difference in the cross-validation performance with the new oversegmentation.
Increasing the number of superpixels in the oversegmentation had the side effect of dramatically increasing the running time of the algorithm, as far more merges were needed before all positive matches had been exhausted. Training and segmentation for the full size datasets ran for over 30 hours, using almost all available RAM the entire time, and I'm confident the training would never have finished if not for the optimizations.
Ultimately, by using the new oversegmentation and adding the contact feature, I was able to reduce our error on the test set to 0.1004, well below the second place error of 0.1239 and the 0.125 error score of the original Gala submission. The leaderboard can be seen here.
Despite my successes, Gala cannot be more than one piece of a neuron segmentation pipeline. It is limited both by its accuracy, which, while good, is not yet human level, and by its resource requirements, which prevent it from working on much larger datasets without being dramatically refactored. If you're still interested, you should take a look at the code on GitHub, or if you'd like to read more detail about what I did, take a look at my thesis.
A common approach to image segmentation by hierarchical superpixel agglomeration is to find the boundary probability of each pixel on the edge between two superpixels, e.g. by applying edge detection algorithms, and to repeatedly agglomerate the pair of neighboring superpixels whose bordering pixels have the lowest mean boundary probability. This method is used for the first epoch of Gala, before any classifiers have been trained or examples generated, using the cell membrane probability as the boundary probability.
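The mean-boundary strategy can be sketched in a few lines of Python. This is a deliberately naive illustration, not Gala's implementation: the dictionary layout, the stopping `threshold`, and the function name are assumptions, and it rescans all boundaries on every step where a real implementation would use a priority queue.

```python
def agglomerate(boundaries, threshold):
    """Greedily merge the pair with the lowest mean boundary probability.

    boundaries: dict mapping (u, v) -> non-empty list of boundary
                probabilities for the pixels between superpixels u and v.
    threshold:  stop once the lowest mean exceeds this value (hypothetical).
    Returns a dict mapping each superpixel id to its final segment id.
    """
    bounds = {frozenset(pair): list(probs) for pair, probs in boundaries.items()}
    segment = {sp: sp for pair in bounds for sp in pair}

    def mean(pair):
        return sum(bounds[pair]) / len(bounds[pair])

    while bounds:
        pair = min(bounds, key=mean)
        if mean(pair) > threshold:
            break
        keep, drop = sorted(pair)
        del bounds[pair]
        # Re-attach every boundary of the dropped segment to the kept one,
        # pooling pixels when both already bordered the same neighbor.
        for other in [p for p in bounds if drop in p]:
            probs = bounds.pop(other)
            (neighbor,) = other - {drop}
            bounds.setdefault(frozenset((keep, neighbor)), []).extend(probs)
        segment = {sp: (keep if seg == drop else seg) for sp, seg in segment.items()}
    return segment


boundaries = {
    (1, 2): [0.1, 0.2],
    (2, 3): [0.9, 0.8],
    (1, 3): [0.7],
    (3, 4): [0.05],
}
result = agglomerate(boundaries, threshold=0.5)
print(result)
```

Pooling the boundary pixels after each merge matters: the mean for the merged segment's border is recomputed over all of its pixels rather than averaged from the two old means.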
The distributions of membrane probabilities of the pixels on the edge between two superpixels are the source of the workhorse features used by the Gala random forest for the SNEMI3D dataset. The computation for two adjacent superpixels within one 2D plane is illustrated in the diagram below. a) shows the current oversegmentation into superpixels, where each color indicates the label of a distinct superpixel. c) shows the original membrane probability map, where dark values indicate high probability of a cell membrane. It is easy to tell visually from this that purple should be merged with red, not with yellow. b) shows the two overlaid, and d) shows a zoomed-in view of the edges where we seek to determine whether to merge the purple region with the red region or the yellow region. The edge pixels to be considered are highlighted in purple-red and purple-yellow. e) shows the membrane probability values of the highlighted edges, and the purple-yellow border is clearly much darker than the purple-red border. Finally, f) and g) show the respective histograms of the highlighted edges over the membrane probability values. Gala uses several quantifications of these distributions as the primary features used to train the learner to differentiate between good and bad merges.
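As a rough illustration of what "quantifications of these distributions" might look like, the sketch below summarizes the membrane probabilities along one edge with a mean, a median, and a normalized histogram. The particular statistics, the bin count, and the sample values are assumptions chosen for the example; the exact set Gala computes is not reproduced here.

```python
import statistics


def boundary_features(probs, bins=4):
    """Summarize the membrane probabilities (each in [0, 1]) along one edge."""
    counts = [0] * bins
    for p in probs:
        # Clamp so that p == 1.0 falls into the last bin.
        counts[min(int(p * bins), bins - 1)] += 1
    n = len(probs)
    return {
        "mean": statistics.fmean(probs),
        "median": statistics.median(probs),
        "histogram": [c / n for c in counts],
    }


# Made-up values: a faint purple-red border and a dark purple-yellow one.
purple_red = boundary_features([0.1, 0.2, 0.15, 0.05])
purple_yellow = boundary_features([0.9, 0.8, 0.85, 0.95])
print(purple_red)
print(purple_yellow)
```

With values like these, the purple-red edge concentrates its mass in the lowest bin and the purple-yellow edge in the highest, which is the separation a classifier can learn from.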
The primary challenge comes from the inherent problems of scale. A single neuron's dendritic branches, which form the incoming signal pathways from other neurons, can spread across as much as a millimeter. However, the branching ends themselves (called neurites) can be as small as a nanometer in diameter. This means that imaging a single neuron such that each neurite is at least one pixel wide would require a 3D image volume 1 million pixels per side, with a total size of an exavoxel, or one million teravoxels.
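The arithmetic behind that estimate is short enough to check directly. The variable names below are just for illustration.

```python
# Spread of a neuron's dendritic branches: about 1 millimeter, in nanometers.
neuron_extent_nm = 1_000_000
# Smallest neurite diameter: about 1 nanometer, so one voxel per nanometer.
neurite_diameter_nm = 1

voxels_per_side = neuron_extent_nm // neurite_diameter_nm
total_voxels = voxels_per_side ** 3   # a cube of voxels
teravoxels = total_voxels // 10 ** 12

print(voxels_per_side, total_voxels, teravoxels)
```

Ten to the eighteenth voxels is one exavoxel, so even at a single byte per voxel (an assumption about the storage format) the raw volume would occupy an exabyte.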