# calculatePhylogeneticDiversity

#! /usr/bin/env python2.7

# Program description shown by --help; passed (dedented) to argparse as the
# parser description in the __main__ block.
desc = """
Calculate the Phylogenetic Diversity (PD: Faith 1992) of a group of taxa on a
tree. PD is the minimum total length of all the phylogenetic branches required
to span a given set of taxa on the phylogenetic tree (and does not include the
stem branch of a clade).

Faith DP. 1992. Conservation evaluation and phylogenetic diversity. Biol.
Conserv. 61:1-10.

Cymon J. Cox version 1.0 Tues 25 Oct 2011
"""

import argparse
import copy
import os
import sys
import textwrap
import unittest

#Squash any local configuration that might cause circular imports:
os.environ['P4_STARTUP'] = ""
from p4 import *

# NEXUS fixture read by RunUnittests.setUp(). Its branch lengths sum to
# 0.0264, the total tree length that the expected PD percentages in the tests
# are relative to. The sets block is wrapped in [ ], i.e. it is a NEXUS
# comment and is not read; each test defines its own taxon set instead.
test_file = """
#NEXUS

begin taxa;
  dimensions ntax=7;
  taxlabels t1 t2 t3 t4 t5 t6 t7;
end;

begin characters;
    dimensions nchar=2;
    Format datatype=dna gap=-;
    Matrix
t1 AA
t2 GG
t3 CC
t4 TG
t5 TA
t6 TT
t7 GT
;
end;

begin trees;
  tree t0 = ((t2:0.00189, (t1:0.00418, ((t4:0.00489, t5:0.00295):0.00014, (t6:0.00094, t7:0.00018):0.00125):0.0077):0.00083):0.00145)t3;
end;
[
begin sets;
    taxset one = t6 t7;
    taxset two = t4 t5;
    taxset three = t4 t5 t7;
    taxset four = t2 t4 t7;
end;
]
"""

class RunUnittests(unittest.TestCase):
    """Regression tests for calculateMinSpanTree() on the ``test_file`` tree.

    Clades 1 and 2 are located in place via a taxon set that is a split of
    the tree, and clade 5 by node number. The taxa of tests 3 and 4 do not
    form a clade, so the tree is first pruned down to them and the whole
    pruned tree is measured. Every PD percentage is relative to the length
    of the full, unpruned tree recorded in setUp().
    """

    def setUp(self):
        var.verboseRead = False
        var.warnReadNoFile = False
        read(test_file)
        self.t = var.trees[0]
        # The newick string places t3 at the root; re-root on node 3 before
        # measuring anything.
        self.t.reRoot(3)
        # Length of the full (unpruned) tree: the denominator of every PD.
        self.total_tree_length = self.t.len()
        self.t.taxNames = [n.name for n in self.t.iterLeavesPreOrder()]

    def _define_taxset(self, the_taxa):
        """Define taxon set ``ts1`` = *the_taxa* (space-separated names),
        attach it to the test tree and build the tree's split keys."""
        read("#nexus begin sets; taxset ts1 = %s; end;" % the_taxa)
        self.t.setNexusSets()
        self.t.makeSplitKeys(makeNodeForSplitKeyDict=True)

    def _clade_node(self, the_taxa):
        """Return the node subtending *the_taxa*; fail the test if the taxa
        are not a split of the tree."""
        self._define_taxset(the_taxa)
        node = self.t.taxSetIsASplit("ts1")
        if not node:
            self.fail("taxa %r do not form a clade" % the_taxa)
        return node

    def _pruned_root(self, the_taxa):
        """Remove every leaf not named in *the_taxa*; return the pruned root."""
        self._define_taxset(the_taxa)
        keep = set(the_taxa.split())
        # Snapshot the names first: removeNode() may alter self.t.taxNames.
        for taxon in list(self.t.taxNames):
            if taxon not in keep:
                self.t.removeNode(taxon)
        return self.t.root

    def _check(self, node, expected_total, expected_pd):
        """Assert the spanning length and PD for *node*, compared as
        5-decimal strings to avoid float-equality noise."""
        (total, pd) = calculateMinSpanTree(node, self.total_tree_length,
                                           verbose=False)
        self.assertEqual("%.5f" % total, expected_total)
        self.assertEqual("%.5f" % pd, expected_pd)

    def test_clade1(self):
        self._check(self._clade_node("t6 t7"), "0.00112", "4.24242")

    def test_clade2(self):
        self._check(self._clade_node("t4 t5"), "0.00784", "29.69697")

    def test_clade3(self):
        self._check(self._pruned_root("t4 t5 t7"), "0.00941", "35.64394")

    def test_clade4(self):
        self._check(self._pruned_root("t2 t4 t7"), "0.01688", "63.93939")

    def test_clade5(self):
        self._check(self.t.node(5), "0.01035", "39.20455")

    def tearDown(self):
        # Reset p4's global state so each test starts from a clean read.
        var.trees = []
        var.nexusSets = None
        var.alignments = []

def calculateMinSpanTree(node, total_tree_length, verbose=True):
    """Calculate the PD of the clade subtended by *node*, excluding its stem.

    Sums ``n.br.len`` over every node visited by ``node.iterPreOrder()``
    except *node* itself, so the branch leading into *node* (the stem) is
    not counted.

    Args:
        node: p4 node at the base of the clade (or the root of a pruned tree).
        total_tree_length: length of the whole tree; PD is expressed as a
            percentage of this value.
        verbose: if True, print the clade's taxa, each branch length summed,
            and the totals.

    Returns:
        (total, pd): the summed branch length, and that sum as a percentage
        of *total_tree_length*.

    Raises:
        ValueError: if *total_tree_length* is zero, since PD is then undefined.
    """
    if not total_tree_length:
        raise ValueError("total_tree_length must be non-zero to compute PD")
    if verbose:
        leaves = [n.name for n in node.iterLeaves()]
        print("\t%18s = %s" % ("Taxa", " ".join(leaves)))
    total = 0.0
    for n in node.iterPreOrder():
        # Identity test: skip only the starting node, i.e. the clade's stem.
        if n is node:
            continue
        if verbose:
            nl = "Branch length %i" % n.nodeNum
            print("\t%18s = %.5f" % (nl, n.br.len))
        total += n.br.len
    pd = (total / total_tree_length) * 100
    if verbose:
        print("\t%18s = %.5f" % ("Min. span. tree", total))
        print("\n\t%18s = %.5f" % ("Total tree length", total_tree_length))
        print("\t%18s = %.5f\n" % ("PD (%)", pd))
    return (total, pd)
        
def _run_unittests():
    """Run every RunUnittests test and exit; the status is non-zero on failure."""
    print("\n\t !!!! Ignoring all other options and running unittests...\n")
    runner = unittest.TextTestRunner(verbosity=2)
    suite = unittest.TestLoader().loadTestsFromTestCase(RunUnittests)
    result = runner.run(suite)
    # Report failures through the exit status so callers/CI can detect them.
    sys.exit(0 if result.wasSuccessful() else 1)


def _abort(*lines):
    """Print each message line, then "Aborting.", and exit with status 1."""
    for line in lines:
        print(line)
    print("Aborting.")
    sys.exit(1)


def _choose_root_interactively(t):
    """Prompt until the user enters the number of a non-leaf node of *t*."""
    while True:
        i = raw_input("\tNode number to re-root with... ")
        try:
            nn = int(i)
        except ValueError:
            print("\tTry a number this time...")
            continue
        # NOTE(review): an out-of-range number makes t.node() raise p4's own
        # error rather than re-prompting.
        if t.node(nn).isLeaf:
            print("\tTree must be re-rooted to an internal node...")
            continue
        return nn


def _report_taxon_set(t, taxonset, total_tree_length, verbose):
    """Print the PD of one taxon set of *t*.

    If the set is a split of the tree, the clade is measured in place on a
    copy of the tree; otherwise the copy is pruned down to the set's taxa and
    the whole pruned tree is measured.
    """
    taxonset.setUseTaxNames()
    if verbose:
        print("\t%18s = %s" % ("Taxon set", ", ".join(taxonset.taxNames)))
    # Work on a duplicate so pruning never modifies t itself.
    t1 = t.dupe()
    t1.makeSplitKeys(makeNodeForSplitKeyDict=True)
    subTreeNode = t1.taxSetIsASplit(taxonset.name)
    if subTreeNode:
        calculateMinSpanTree(subTreeNode, total_tree_length)
        return
    if verbose:
        print("t.taxNames: %s" % t.taxNames)
    keep = set(taxonset.taxNames)
    # Iterate t's names, not t1's, since t1's list is modified by pruning.
    for taxon in t.taxNames:
        if verbose:
            print("doing: %s" % taxon)
        if taxon not in keep:
            if verbose:
                print("removing taxon %s" % taxon)
            t1.removeNode(taxon)
            if verbose:
                t1.draw()
    print("\tPruned tree:")
    t1.draw()
    calculateMinSpanTree(t1.root, total_tree_length)


def main(treefile, re_root, nodenum, run_unittests, verbose):
    """Report the PD of a clade (by node number) or of each taxon set.

    Args:
        treefile: path of a Nexus file; the first tree read is used.
        re_root: if True, interactively re-root before measuring (this also
            happens whenever the tree is rooted on a leaf).
        nodenum: node number of the clade to measure, or None to measure
            every taxon set defined in the file instead.
        run_unittests: if True, ignore everything else, run the unit tests
            and exit.
        verbose: print extra progress information for taxon sets.

    Exits with status 1 if no tree is found, or if neither / both of a taxon
    set and a node number are supplied.
    """
    if run_unittests:
        _run_unittests()

    read(treefile)
    try:
        t = var.trees[0]
    except IndexError:
        _abort("No tree found in treefile.")
    t.taxNames = [n.name for n in t.iterLeavesPreOrder()]
    t.setNexusSets()
    # A sets block may hold no taxon sets (e.g. only charsets), so test for
    # taxon sets specifically rather than for the sets block.
    has_taxsets = bool(t.nexusSets and t.nexusSets.taxSets)
    if not has_taxsets and nodenum is None:
        _abort("Must either be a Nexus taxon set defined in the tree file,",
               "or a node number must be given with option -n")
    if has_taxsets and nodenum is not None:
        _abort("Both a taxon set and node number given with -n (use one or",
               "the other, not both)")

    print("\n")
    if re_root or t.root.isLeaf:
        t.draw(showInternalNodeNames=1, showNodeNums=1)
        if t.root.isLeaf:
            print("\tTree must be re-rooted to an internal node...")
        t.reRoot(_choose_root_interactively(t))
        t.ladderize()
    t.draw(showInternalNodeNames=1, showNodeNums=1)

    total_tree_length = t.len()

    if nodenum is not None:
        print("\t%18s = %i" % ("Node", nodenum))
        calculateMinSpanTree(t.node(nodenum), total_tree_length)
    else:
        for taxonset in t.nexusSets.taxSets:
            _report_taxon_set(t, taxonset, total_tree_length, verbose)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=textwrap.dedent(desc),
    )
    # Optional at the parser level so that -u/--unittest can run without a
    # tree file; it is required otherwise (checked after parsing below).
    parser.add_argument("treefile", nargs="?", default=None,
                        help="Path to Nexus tree file.")
    parser.add_argument("-r", "--re_root",
                        dest="re_root",
                        help="Re-root tree. Default: False",
                        default=False,
                        action='store_true')
    parser.add_argument("-n", "--nodenum",
                        dest="nodenum",
                        help="Node number of clade. Default: None",
                        default=None,
                        type=int)
    parser.add_argument("-v", "--verbose",
                        dest="verbose",
                        help="Verbose. Default: False",
                        default=False,
                        action='store_true')
    parser.add_argument("-u", "--unittest",
                        dest="unittest",
                        help="Run unittests. Default: False",
                        default=False,
                        action='store_true')
    args = parser.parse_args()
    if args.treefile is None and not args.unittest:
        parser.error("treefile is required unless -u/--unittest is given")
    main(args.treefile, args.re_root, args.nodenum, args.unittest, args.verbose)