#!/usr/bin/env python
'''
Power of Cochran-Armitage trend test for case/ctrl genetic association studies
Gao Wang (*) 2012 distributed under GNU GPL

References
----------
Hum Hered 2002;53:146-152 (DOI: 10.1159/000064976)
Genet Epi 2003 (DOI: 10.1002/gepi.10245) for matrix notation for sigma under H1

Bash Example
------------
maf="0.1 0.2 0.3"
OR="1.1 1.2 1.3 1.4"
N="200 500 800"
f0=0.01
moi='additive'
alpha=0.05

for i in $maf; do
  for j in $OR; do
    for k in $N; do
      power=`catt -b $f0 -r $j -p $i -R $k -S $k -m $moi -a $alpha`
      echo $power $i $j $k $f0 $alpha $moi
    done
  done
done

ChangeLog
---------
2012-December-23
    * Initial implementation
'''
import sys
try:
    import numpy as np
    from scipy.stats import norm
except ImportError as e:
    sys.exit(e)
try:
    from argparse import ArgumentParser
except ImportError:
    sys.exit('argparse module is required (available in Python 2.7.2+ or Python 3.2.1+)')

class CATT:
    '''Asymptotic power of the Cochran-Armitage trend test (CATT) for a
    case/ctrl study of one bi-allelic marker.

    Genotypes are ordered (0, 1, 2) copies of the minor allele. Genotype
    frequencies in the population follow Hardy-Weinberg equilibrium and
    the marker is treated as the causal variant itself.

    Attributes set by the constructor
    ---------------------------------
    odds       : per-genotype disease odds implied by the model
    X          : genotype scores used by the trend test
    N1, N2, N  : number of cases, number of ctrls and their sum (floats)
    gf         : population genotype frequencies
    p, q       : genotype frequencies among cases / among ctrls
    prevalence : population disease prevalence
    '''
    def __init__(self, odr, f0, maf, n1, n2, model = 'additive'):
        '''
        odr   : odds ratio (must be >= 0)
        f0    : penetrance of the wildtype genotype, strictly in (0, 1)
        maf   : minor allele frequency, strictly in (0, 1)
        n1    : number of cases (> 0)
        n2    : number of ctrls (> 0)
        model : 'recessive', 'dominant', 'multiplicative' or 'additive';
                any other value is treated as 'additive'

        Raises ValueError when a numeric argument is outside its range;
        such inputs would otherwise produce NaN or a division by zero.
        '''
        if not 0.0 < f0 < 1.0:
            raise ValueError('baseline penetrance must be in (0, 1), got {0}'.format(f0))
        if not 0.0 < maf < 1.0:
            raise ValueError('minor allele frequency must be in (0, 1), got {0}'.format(maf))
        if odr < 0:
            raise ValueError('odds ratio must be non-negative, got {0}'.format(odr))
        if n1 <= 0 or n2 <= 0:
            raise ValueError('numbers of cases and ctrls must be positive, got {0} and {1}'.format(n1, n2))
        base_odds = f0 / (1.0 - f0)
        if model == 'recessive':
            # only the homozygous variant genotype carries the risk
            self.odds = np.array([base_odds, base_odds, base_odds * odr])
            self.X = np.array([0,0,1], dtype=float)
        elif model == 'dominant':
            # one copy of the variant allele is enough to carry the risk
            self.odds = np.array([base_odds, base_odds * odr, base_odds * odr])
            self.X = np.array([0,1,1], dtype=float)
        elif model == 'multiplicative':
            # odds are multiplied by odr for each copy of the variant allele
            self.odds = np.array([base_odds, base_odds * odr, base_odds * odr * odr])
            self.X = np.array([0,1,2], dtype=float)
        else:
            # additive: each copy adds (odr - 1) * base_odds to the odds;
            # the homozygote odds are floored at 0 for odr < 0.5
            self.odds = np.array([base_odds, base_odds * odr, max(base_odds * (2 * odr - 1.0), 0.0)])
            self.X = np.array([0,1,2], dtype=float)
        # sample size
        self.N1 = float(n1)
        self.N2 = float(n2)
        self.N = float(n1 + n2)
        # genotype frequency assuming HWE
        self.gf = np.array([(1-maf)**2, 2*maf*(1-maf), maf**2])
        # genotype frequency in cases
        self.p = self.conditionalGF('cases')
        # genotype frequency in ctrls
        self.q = self.conditionalGF('ctrls')

    def conditionalGF(self, condition = 'cases'):
        '''genotype frequency conditional on disease status,
        assuming complete LD between marker and disease causal variant.

        condition : 'ctrls' returns frequencies among unaffected
                    individuals; any other value returns frequencies
                    among affected individuals.

        Side effects: (re)sets self.prevalence to the population disease
        prevalence whatever group is requested, and writes the prevalence
        to stderr when condition == 'cases'.
        '''
        penetrance = self.odds / (1.0 + self.odds)
        # prevalence is a property of the population, not of the requested
        # group, so it is always computed from the penetrance
        self.prevalence = np.sum(penetrance * self.gf)
        if condition == 'ctrls':
            f = 1.0 - penetrance
        else:
            f = penetrance
        P_status_and_genotype = f * self.gf
        if condition == 'cases':
            sys.stderr.write('(disease prevalence = {0})\n'.format(np.around(self.prevalence,4)))
        # normalise by P(status) so the returned frequencies sum to one
        return P_status_and_genotype / np.sum(P_status_and_genotype)

    def getSigma0(self, sigma1, mu1):
        '''Return sqrt(sigma1^2 + mu1^2).

        Kept for backward compatibility. This equals the null standard
        deviation of the trend statistic only when N1 == N2 (or when the
        score variance is the same in both groups); power() therefore
        computes the null standard deviation from the pooled genotype
        frequencies instead of calling this method.
        '''
        return np.sqrt(sigma1 ** 2 + mu1 ** 2)

    def getMu1(self, X, N1, N2, N, p, q):
        '''Mean of U / N under H1, where
        U = (1/N) * sum_i X_i * (N2 * r_i - N1 * s_i)
        and r, s are the genotype counts in group 1 and group 2.
        sqrt(N) * getMu1(...) is the mean of N^(-1/2) * U.

        X      : genotype scores
        N1, N2 : sizes of group 1 and group 2; N is their sum
        p, q   : genotype frequencies in group 1 and group 2
        '''
        X = np.asarray(X, dtype=float)
        p = np.asarray(p, dtype=float)
        q = np.asarray(q, dtype=float)
        # float() guards against integer division under Python 2
        return np.sum(X * (p - q)) * float(N1) * float(N2) / float(N) ** 2

    def getSigma1(self, X, N1, N2, N, p, q):
        '''Standard deviation of N^(-1/2) * U under H1 (see getMu1 for U).

        With r ~ Multinomial(N1, p) and s ~ Multinomial(N2, q) independent,
        Var(U) = (N2^2 * N1 * Var_p(X) + N1^2 * N2 * Var_q(X)) / N^2, hence
        Var(N^(-1/2) U) = (N1 * N2 / N^3) * (N2 * Var_p(X) + N1 * Var_q(X)).
        Note the cross weighting: the score variance of one group is
        weighted by the size of the *other* group.
        '''
        X = np.asarray(X, dtype=float)
        N1, N2, N = float(N1), float(N2), float(N)
        var_p = np.dot(np.dot(X, self.correlationMatrix(p)), X)
        var_q = np.dot(np.dot(X, self.correlationMatrix(q)), X)
        variance = (N1 * N2 / N ** 3) * (N2 * var_p + N1 * var_q)
        return np.sqrt(variance)

    def _getPooledSigma0(self, X, N1, N2, N, p, q):
        '''Limit of the null standard deviation used to standardise
        N^(-1/2) * U: sqrt(N1 * N2 / N^2 * Var(X)) with Var(X) evaluated at
        the pooled genotype frequencies (N1 * p + N2 * q) / N.'''
        X = np.asarray(X, dtype=float)
        N1, N2, N = float(N1), float(N2), float(N)
        pooled = (N1 * np.asarray(p, dtype=float) + N2 * np.asarray(q, dtype=float)) / N
        var_pooled = np.dot(np.dot(X, self.correlationMatrix(pooled)), X)
        return np.sqrt(N1 * N2 / N ** 2 * var_pooled)

    def power(self, alpha, alternative = 'two-sided'):
        '''Two-sided power of the trend test at significance level alpha.

        N^(-1/2) * U is treated as normal with mean sqrt(N) * mu1 and
        standard deviation sigma1; H0 is rejected when its absolute value
        exceeds z_(1 - alpha/2) * sigma0.

        alternative is accepted for interface compatibility only: the
        two-sided power is returned whatever its value.
        '''
        sigma1 = self.getSigma1(self.X, self.N1, self.N2, self.N, self.p, self.q)
        mu1 = self.getMu1(self.X, self.N1, self.N2, self.N, self.p, self.q)
        sigma0 = self._getPooledSigma0(self.X, self.N1, self.N2, self.N, self.p, self.q)
        z = norm.ppf(1 - alpha / 2.0)
        shift = np.sqrt(self.N) * mu1
        upper_tail = 1 - norm.cdf((z * sigma0 - shift) / sigma1)
        lower_tail = norm.cdf((-1.0 * z * sigma0 - shift) / sigma1)
        return upper_tail + lower_tail

    def correlationMatrix(self, p):
        '''Single-draw multinomial covariance matrix of the genotype
        indicators: m_ii = pi * (1-pi) and m_ik = -pi*pk.
        (The method name is historical; the entries are covariances.)'''
        p = np.asarray(p, dtype=float)
        return np.diag(p) - np.outer(p, p)

if __name__ == '__main__':
    # ---- command-line interface -------------------------------------
    cli = ArgumentParser(
        description='Cochran-Armitage trend test (CATT) power calculation for case/ctrl genetic association studies, implementing Freidlin et al, Hum Hered 2002;53:146-152')
    # disease model
    cli.add_argument(
        '-b', '--baseline_penetrance',
        metavar='f0', type=float, default=0.01,
        help='wildtype genotype penetrance, default to 0.01')
    cli.add_argument(
        '-r', '--odds_ratio',
        metavar='gamma', type=float, default=1.0,
        help='odds ratio, default to 1.0')
    cli.add_argument(
        '-p', '--maf',
        type=float, default=0.1,
        help='minor allele frequency, default to 0.1')
    # study design
    cli.add_argument(
        '-R', '--num_cases',
        metavar='N', type=int, default=500,
        help='number of cases, default to 500')
    cli.add_argument(
        '-S', '--num_ctrls',
        metavar='N', type=int, default=500,
        help='number of ctrls, default to 500')
    cli.add_argument(
        '-m', '--moi',
        default='additive',
        choices=['additive', 'dominant', 'recessive', 'multiplicative'],
        help='mode of inheritance, default to "additive"')
    # test settings
    cli.add_argument(
        '-a', '--alpha',
        type=float, default=0.05,
        help='significance level (test size), default to 0.05')
    cli.add_argument(
        '--alternative',
        default='two-sided', choices=['two-sided'],
        help='alternative hypothesis, default to two-sided')
    opts = cli.parse_args()
    # ---- compute and report the power -------------------------------
    # Any failure is reported as a one-line message and a non-zero exit
    # status rather than a traceback.
    try:
        calculator = CATT(opts.odds_ratio, opts.baseline_penetrance,
                          opts.maf, opts.num_cases, opts.num_ctrls,
                          model=opts.moi)
        result = calculator.power(opts.alpha, opts.alternative)
        print(np.around(result, 4))
    except Exception as err:
        sys.exit("ERROR: {0}".format(err))