Source code for firefly_fitter

import numpy as np
from scipy.stats import sigmaclip
import copy

def fitter(wavelength_in, data_in, error_in, models_in, SPM):
    """
    The essential ingredient of FIREFLY!

    Taking each base model as an initial guess, the fitter iteratively
    creates combinations of these base models when they improve the
    modified chi-squared value:

        value = chi-squared + BIC term - exploration term

    Parameters
    ----------
    wavelength_in : array, length N
        Wavelength grid (used for the sigma-clip mask / plotting).
    data_in : array, length N
        Observed spectrum.
    error_in : array, length N
        1-sigma uncertainties on ``data_in``.
    models_in : array, num_models x N
        Base models to be combined.
    SPM : object
        Must provide ``max_iterations`` (maximum number of iterations
        before giving up) and ``fit_per_iteration_cap`` (maximum number
        of fit objects created per iteration).

    Returns
    -------
    (weights, chi_squared, branch_num) : tuple of arrays
        One entry per fit object; each row of ``weights`` has length
        num_models.  Fit arrays may be any size up to 10000.
    """
    # The nested ``fit`` class and helper functions historically share
    # state through module-level globals.  Declare every global *before*
    # its first assignment: declaring ``global`` after assigning is a
    # SyntaxError in Python 3 (the original did this for num_models,
    # chi_models and clipped_arr).
    global models, data, error, wavelength
    global index_count, iterate_count, num_models, bic_n, chi_models
    global clipped_arr

    models = models_in
    data = data_in
    error = error_in
    wavelength = wavelength_in
    index_count = 0      # running index handed to each new fit object
    iterate_count = 0    # current iteration (a.k.a. branch) number

    # User-defined limits.  FIX: upper_limit_fit was previously read from
    # SPM but never used -- the iteration cap was hard-coded to 10 below.
    upper_limit_fit = SPM.max_iterations
    fit_cap = SPM.fit_per_iteration_cap

    num_models = len(models)
    num_wave = len(wavelength)
    bic_n = np.log(num_wave)  # BIC penalty scale: ln(N)

    # Pre-compute the normalised residual of every base model; the chi
    # array of any weighted combination is the same weighted sum of these
    # rows, which is much cheaper than re-forming the model each time.
    chi_models = np.zeros(np.shape(models))
    for m in range(num_models):
        chi_models[m] = (models[m] - data) / error

    class fit(object):
        """
        One candidate solution.  A fit object carries:

        - ``index``: position in the global creation order (useful for
          clipping later),
        - ``branch_num``: iteration at which it was created (0 for the
          initial single-model guesses),
        - ``weights``: array of base-model weights (length num_models),
        - ``chi_squared``: raw chi-squared of the weighted combination,
          evaluated on the sigma-clipped wavelength mask,

        and an in-built function to spawn children iteratively.

        Fits created at branch 0 also compute and store the 3-sigma clip
        mask (``clipped_arr``); later branches reuse the global mask.
        """

        def __init__(self, weights, branch_num):
            # ``global`` is a function-wide directive, and index_count is
            # mutated for *every* fit, so declare both unconditionally
            # (the original guarded these under ``if branch_num > 1``,
            # which had no effect on scoping).
            global clipped_arr, index_count
            super(fit, self).__init__()
            self.weights = weights
            self.branch_num = branch_num
            self.index = index_count

            # Only non-zero weights contribute -- skipping zeros saves
            # time for the sparse early-iteration combinations.
            index_weights = np.nonzero(self.weights)
            chi_arr = np.dot(self.weights[index_weights],
                             chi_models[index_weights])

            if branch_num == 0:
                # Initial guesses define the 3-sigma clip mask used by
                # every subsequent evaluation (removes artefacts).
                # sigmaclip returns (clipped_values, lower, upper).
                chi_clipped_arr = sigmaclip(chi_arr, low=3.0, high=3.0)
                chi_clip_sq = np.square(chi_clipped_arr[0])
                clipped_arr = ((chi_arr > chi_clipped_arr[1]) &
                               (chi_arr < chi_clipped_arr[2]))
                self.clipped_arr = clipped_arr
            else:
                chi_clip_sq = np.square(chi_arr[clipped_arr])

            self.chi_squared = np.sum(chi_clip_sq)
            index_count += 1

        def spawn_children(self, branch_num):
            """Return num_models children, each nudging one model's weight up."""
            children = []
            new_weights = self.weights * branch_num
            sum_weights = np.sum(new_weights) + 1
            for im in range(num_models):
                new_weights[im] += 1
                children.append(fit(new_weights / sum_weights, branch_num))
                new_weights[im] -= 1  # undo the nudge for the next model
            return children

    def retrieve_properties(fit_list):
        """Return (weights, chi_squared, branch_num) arrays for the given fits."""
        lf = len(fit_list)
        returned_weights = np.zeros((lf, num_models))
        returned_chis = np.zeros(lf)
        returned_branch = np.zeros(lf)
        for f, fit_obj in enumerate(fit_list):
            returned_weights[f] = fit_obj.weights
            returned_branch[f] = fit_obj.branch_num
            returned_chis[f] = fit_obj.chi_squared
        return returned_weights, returned_chis, returned_branch

    def bic_term():
        """BIC-like penalty used as the convergence margin."""
        return bic_n

    def previous_chi(branch_num, fit_list):
        """Best (minimum) chi-squared found so far -- ensures exploration."""
        return np.min([o.chi_squared for o in fit_list])

    def iterate(fit_list):
        """Spawn children of every current fit; keep those that improve."""
        global iterate_count
        iterate_count += 1
        print("Iteration step: " + str(iterate_count))

        count_new = 0
        # Snapshot the length so fits appended below are not re-expanded
        # within this same iteration (the original built a throwaway
        # copy.copy just to take its len).
        len_list = len(fit_list)
        save_bic = bic_term()
        previous_chis = previous_chi(iterate_count, fit_list)

        for f in range(len_list):
            new_list = fit_list[f].spawn_children(iterate_count)
            for child in new_list:
                # Keep a child only if it beats the best chi-squared so
                # far by more than the BIC margin.
                if child.chi_squared < previous_chis - save_bic:
                    count_new += 1
                    if count_new > fit_cap:
                        break
                    fit_list.append(child)
            if count_new > fit_cap:
                print("Capped solutions at " + str(fit_cap))
                break

        if count_new == 0:
            # No improvement anywhere: the iteration process has
            # converged and we may return.
            print("Converged!")
            print("Fit list with this many elements:")
            print(len(fit_list))
            return fit_list
        if iterate_count >= upper_limit_fit:
            # FIX: honour SPM.max_iterations -- this cap was previously
            # hard-coded to 10 despite the message below.
            print("Fit has not converged within user-defined number of iterations.")
            print("Make sure this is a reasonable number.")
            print("Returning all fits up to this stage.")
            return fit_list
        print("Found " + str(count_new) + " new solutions. Iterate further...")
        return iterate(fit_list)

    def mix(fit_list, full_fit_list, min_chi):
        """
        Mix the best solutions together to improve error estimations.
        Never go more than 100 best solutions!

        ``min_chi`` is currently unused (the importance filter is
        disabled) but kept for interface stability.
        """
        extra_fit_list = []
        print("Mixing best solutions to improve estimate.")
        for f1 in range(len(fit_list)):
            for f2 in range(len(full_fit_list)):
                # Blend each pair with a range of mixing fractions q.
                # NOTE(review): mixing two branch-0 fits yields
                # branch_num 0, which re-derives the global clip mask --
                # preserved from the original behaviour.
                for q in [0.0000001, 0.000001, 0.00001, 0.0001,
                          0.001, 0.01, 0.1, 1.0]:
                    new_fit = fit(
                        (fit_list[f1].weights + q * full_fit_list[f2].weights) / (1.0 + q),
                        fit_list[f1].branch_num + full_fit_list[f2].branch_num)
                    extra_fit_list.append(new_fit)
        print("Added " + str(len(extra_fit_list)) + " solutions!")
        return extra_fit_list

    # ----- Initialise fit objects over the initial set of models --------
    # One branch-0 fit per base model: unit weight on a single model.
    fit_list = []
    int_chi = []
    zero_weights = np.zeros(len(models))
    print("Initiating fits...")
    for im in range(len(models)):
        zero_weights[im] += 1
        fit_first = fit(copy.copy(zero_weights), 0)
        fit_list.append(fit_first)
        int_chi.append(fit_first.chi_squared)
        zero_weights[im] -= 1

    # Adopt the sigma-clip mask of the best single-model guess for all
    # subsequent chi-squared evaluations (removes artefacts).
    clipped_arr = fit_list[np.argmin(int_chi)].clipped_arr

    # fit_list holds the initial guesses from which we iterate.
    print("Calculated initial chi-squared values.")
    print("Begin iterative process.")
    final_fit_list = iterate(fit_list)

    junk, chis, more_junk = retrieve_properties(final_fit_list)
    best_fits = np.argsort(chis)

    # Mix at most the 10 best solutions into everything found so far.
    bf = min(len(best_fits), 10)
    extra_fit_list = mix(np.asarray(final_fit_list)[best_fits[:bf]].tolist(),
                         final_fit_list, np.min(chis))
    total_fit_list = final_fit_list + extra_fit_list
    return retrieve_properties(total_fit_list)