import sys,math
sys.dont_write_bytecode=True # don't make .pyc files
from lib import *
from about import *
from a12 import *

def scottknott(data,small=3,b=250, conf=0.05):
  """Rank a list of Nums by recursively splitting it.

  Each split is chosen (by minMu) to maximize the expected squared
  difference between the overall mean and the means of the two sides.
  A split is rejected when either side's combined count `n` is `small`
  or less (default 3, so each side needs more than 3 items), or when
  the two sides are judged the same: either their effect size is small
  (a12small) or the bootstrap test, run with `b` resamples at level
  `conf`, finds no significant difference.

  Returns `data`; rdiv sets a `rank` attribute on every member.
  An empty `data` is returned unchanged.
  """
  from functools import reduce  # builtin in Python 2, functools in 3
  if not data:
    return data  # reduce() raises on an empty sequence
  def theSame(one, two):
    # a negligible effect size means "same" regardless of significance
    if a12small(two, one): return True
    return not bootstrap(one, two, b=b, conf=conf)
  everything = reduce(lambda x,y: x+y, data)  # one Num holding all data
  # presumably saw() returns the raw values a Num has collected
  same = lambda l,r: theSame(l.saw(), r.saw())
  big  = lambda n: n > small
  return rdiv(data,everything,minMu,big,same)

def rdiv(data,  # a list of class Nums
         all,   # all the data combined into one num
         div,   # function: find the best split
         big,   # function: rejects small splits
         same): # function: rejects similar splits
  """Sort `data`, then recursively split it with `div`.

  Every leaf of the resulting split tree gets a `rank` attribute.
  Leaves are ranked left to right starting at 0, and each accepted
  cut bumps the rank of everything to its right by one.
  Returns `data` itself; its members are mutated in place.
  """
  def ranker(chunk, summary, rank=0):
    "Rank one chunk; return the highest rank used inside it."
    cut, lo, hi = div(chunk, summary, big, same)
    if not cut:
      # no acceptable split: the whole chunk shares one rank
      for item in chunk:
        item.rank = rank
      return rank
    lowest = ranker(chunk[:cut], lo, rank)
    return ranker(chunk[cut:], hi, lowest + 1)
  ranker(sorted(data), all)
  return data

def minMu(parts,all,big,same):
  """Choose where to cut `parts` into a left and a right side.

  The score of a cut is the expected squared deviation of each side's
  mean from the overall mean, weighted by side size:
      l.n/N*(mu - l.mu)**2 + r.n/N*(mu - r.mu)**2
  Candidates are skipped when either side fails `big`. They are also
  skipped when `same` says the sides do not differ; `same` is only
  consulted for candidates that pass `big`.
  Returns (cut, left, right), or (None, None, None) if no cut has a
  positive score.
  """
  mu = all.mu
  total = all.n * 1.0  # float so the weights divide exactly in Python 2
  best, cut, left, right = 0, None, None, None
  for i, lo, hi in leftRight(parts):
    if not (big(lo.n) and big(hi.n)):
      continue
    if same(lo, hi):
      continue
    gain = lo.n/total*(mu - lo.mu)**2 + \
           hi.n/total*(mu - hi.mu)**2
    if gain > best:
      best, cut, left, right = gain, i, lo, hi
  return cut, left, right

def leftRight(parts):
  """Iterator over every cut point of `parts`.

  For each index i in 1..len(parts)-1, yields (i, left, right). Here
  `left` is the sum of parts[:i] and `right` is the sum of parts[i:],
  combined with `+`. For efficiency, the right-hand sums are
  precomputed in one backwards pass.

  The caller's `parts` is never mutated. Every combination uses `a + b`
  rather than `a += b`, so element types with an in-place `__iadd__`
  (e.g. lists) are not corrupted. Empty or single-item `parts` yields
  nothing.
  """
  if not parts:
    return
  last = len(parts) - 1
  rights = {last: parts[last]}
  for j in range(last - 1, 0, -1):
    rights[j] = parts[j] + rights[j + 1]  # fresh object, no aliasing
  left = parts[0]
  for i in range(1, len(parts)):
    yield i, left, rights[i]
    left = left + parts[i]  # rebind, never mutate parts[0]

def bootstrap(y0,z0,conf=0.05,b=1000):
  """The bootstrap hypothesis test from
     p220 to 223 of Efron's book 'An
     introduction to the bootstrap'.

  One-sided test: returns True when the mean of `z0` is significantly
  larger than the mean of `y0`. Both samples are shifted to the pooled
  mean, `b` resamples of each are drawn with replacement, and the
  achieved significance level #{t* >= t_obs}/b is compared to `conf`.

  y0, z0 : sequences of numbers (the two samples)
  conf   : significance level
  b      : number of bootstrap resamples
  """
  import random
  class total(object):
    "quick and dirty data collector"
    def __init__(i,some=()):
      i.sum = i.n = i.mu = 0 ; i.all=[]
      for one in some: i.put(one)
    def put(i,x):
      i.all.append(x)
      i.sum +=x; i.n += 1; i.mu = float(i.sum)/i.n
    def __add__(i1,i2): return total(i1.all + i2.all)
  def variance(t):
    # sample variance; for n == 1 the squared-deviation sum is 0, so
    # dividing by max(n-1, 1) gives 0 instead of ZeroDivisionError
    tmp = 0
    for x in t.all: tmp += (x - t.mu)**2
    return float(tmp)/max(t.n - 1, 1)
  def testStatistic(y,z):
    """Difference in means (z - y), scaled by the standard error
    whenever either sample has non-zero variance."""
    s1, s2 = variance(y), variance(z)
    delta = z.mu - y.mu
    if s1+s2:
      delta =  delta/((s1/y.n + s2/z.n)**0.5)
    return delta
  def resample(lst):
    # random.choice always yields an in-range element; the former
    # int(random.uniform(0,n)) could round to n and raise IndexError
    return total([random.choice(lst) for _ in lst])
  y, z   = total(y0), total(z0)
  x      = y + z
  tobs   = testStatistic(y,z)
  # shift both samples to the pooled mean so the null hypothesis holds
  yhat   = [y1 - y.mu + x.mu for y1 in y.all]
  zhat   = [z1 - z.mu + x.mu for z1 in z.all]
  bigger = 0.0
  for _ in range(b):
    # Efron's ASL counts t* >= t_obs; a strict '>' wrongly reports
    # identical constant samples as significantly different
    if testStatistic(resample(yhat), resample(zhat)) >= tobs:
      bigger += 1
  return bigger / b < conf

def bootstrapd():
  """Demo of `bootstrap` on pairs of Gaussian samples.

  Prints one tuple per pair:
  (n, mu1, sigma1, mu2, sigma2, 'different' | 'same').
  It then prints the result of `msecs` applied to one larger
  (n=1000) run.
  """
  def worker(n=30,mu1=10,sigma1=1,mu2=10.2,sigma2=1):
    sample = lambda mu, sigma: [random.gauss(mu, sigma) for _ in range(n)]
    first  = sample(mu1, sigma1)
    second = sample(mu2, sigma2)
    verdict = 'different' if bootstrap(first, second) else 'same'
    return n, mu1, sigma1, mu2, sigma2, verdict
  cases = [(30, 10.1, 1,  10.2, 1),
           (30, 10.1, 1,  10.8, 1),
           (30, 10.1, 10, 10.8, 1)]
  for case in cases:
    print(worker(*case))
  print(msecs(lambda : worker(1000, 10.1, 1, 10.2, 1)))

if __name__ == '__main__':
  # cmd() comes from one of the star imports above (presumably lib). It
  # appears to produce a Python expression string, with 'bootstrapd()'
  # as the default. TODO confirm how it reads sys.argv.
  # NOTE(review): eval() runs that string in this module's globals. If it
  # is built from command-line arguments, only run with trusted input.
  eval(cmd('bootstrapd()'))

