From 76cff6f6f0bc39115b5ea692643652454019865c Mon Sep 17 00:00:00 2001 From: Marcus Herrmann <marcus.herrmann@sed.ethz.ch> Date: Thu, 10 Nov 2016 18:36:19 +0100 Subject: [PATCH] Improve amp-mag regression - using statsmodels for regression --> better statistical analysis --> enable Weighted LS (if desired) (plotted by scaling scatter plot size) - plotting confidence + prediction intervals - plot a line with slope = 1 - improve residual plot --- TM/eventops.py | 41 +++++++----- TM/plotting.py | 178 ++++++++++++++++++++++++++++++++++--------------- version | 2 +- 3 files changed, 151 insertions(+), 70 deletions(-) diff --git a/TM/eventops.py b/TM/eventops.py index 4ad165a..61e5d1b 100644 --- a/TM/eventops.py +++ b/TM/eventops.py @@ -24,6 +24,7 @@ from itertools import repeat import math import scipy import numpy as np +import statsmodels.api as sm # ObsPy from obspy import UTCDateTime @@ -316,7 +317,7 @@ def correct_amplitude_attenuation(events, eventType): # return logLike -def regressMagnitude(foundEvents, catalog, magRange): +def regressMagnitude(foundEvents, catalog, magRange, w_base=1): """ Determines the magnitde (ML) for the detections (foundEvents) based on a regression b/w the catalogs' amplitudes and magnitudes. This relation is @@ -344,7 +345,6 @@ def regressMagnitude(foundEvents, catalog, magRange): if len(catalog[0]) > 4: raise Warning("Catalog has to much values. It should only have ID, time, Mag, amp.") - # Deepcopy catalog + events to not touch original catalog = copy.deepcopy(catalog) foundEvents = copy.deepcopy(foundEvents) @@ -362,21 +362,28 @@ def regressMagnitude(foundEvents, catalog, magRange): catalog_ML = np.asarray([x[2] for x in catalog_minMag]) # Do regression using Linear Least Squares of log10-amp - b, a, r_value, p_value, std_err = scipy.stats.linregress(np.log10(catalog_A), catalog_ML) - print("\nLS-regression: ML = %.2f * log(A) + %.2f" % (b, a)) - -# # Do regression using maximum likelihood -# nll = lambda *args: -loglikelihoodRegress(*args) -# initParams = [b, a, std_err] -# initParams = [1, 1, 1] -## catalog_XY = np.sort(np.column_stack((catalog_logA, catalog_ML)), axis=0) -# result = scipy.optimize.minimize(nll, initParams, args=(catalog_XY[:,0], catalog_XY[:,1]), -# method='nelder-mead') -# print(result) -# b_maxL, a_maxL, sd_maxL = result["x"] - - plotting.plot_amp_regression(catalog_ML, catalog_A, - regLineInfos=[b, a, r_value, p_value, std_err, '']) +# b, a, r2, _, std_err = scipy.stats.linregress(np.log10(catalog_A), catalog_ML) + + # Regression using statsmodels - weighting possible + # 1. Define the weights + # w_pot = catalog_A + w_pot = catalog_ML + weights = np.ones(len(catalog_A)) * w_base**w_pot + weights /= max(weights) + # 2. Fit a linear line + model = sm.WLS(catalog_ML, sm.add_constant(np.log10(catalog_A)), weights=weights) + results = model.fit() + # Get results & statistics + a, b = results.params + std_errs = results.bse + r2 = results.rsquared + + LStype = 'OLS' if w_base == 1 else 'WLS' + print("\n%s-regression: ML = %.2f * log(A) + %.2f" % (LStype, b, a)) + + # Plot the regression line + regLineInfos = [b, a, r2, std_errs, w_base, ''] + plotting.plot_amp_regression(catalog_ML, catalog_A, regLineInfos, weights) # Update foundEvents with REGRESSION ML [x.insert(2, float("{0:.2f}".format(b*math.log10(x[1]) + a))) for x in foundEvents] diff --git a/TM/plotting.py b/TM/plotting.py index 6fc7330..dae965a 100644 --- a/TM/plotting.py +++ b/TM/plotting.py @@ -723,7 +723,8 @@ def plotCHlocationOverview(directory): subprocess.call([command] + strArgs, env=environ, stderr=subprocess.STDOUT) -def plot_amp_regression(magSet, ampSet, regLineInfos=None, legendPos='upper left'): +def plot_amp_regression(magSet, ampSet, regLineInfos=None, + weights=None, legendPos='upper left'): """ Plots the regression for given magnitude <-> amplitude pairs. @@ -734,7 +735,7 @@ def plot_amp_regression(magSet, ampSet, regLineInfos=None, legendPos='upper left if isinstance(magSet, list) or isinstance(ampSet, list): numSets = len(magSet) if len(ampSet) != numSets: - raise Warning("Number of magnitude and amplitde sets doesn't agree.") + raise Warning("Number of magnitude and amplitude sets doesn't agree.") else: numSets = 1 # Pack the arrays into a list (with one entry) @@ -764,15 +765,28 @@ def plot_amp_regression(magSet, ampSet, regLineInfos=None, legendPos='upper left plt.ylabel("ML (SED)") # Variant 1: scatter with density coloring (only useful for ONE set) - if numSets == 1 and len(mags) > 100: - # Calculate the point density - xy = np.vstack([np.log10(amps), mags]) - density = scipy.stats.gaussian_kde(xy)(xy) - # Sort points by density; densest points are plotted last - idx = density.argsort() + if numSets == 1: + + if len(mags) > 100: + # Calculate the point density + xy = np.vstack([np.log10(amps), mags]) + density = scipy.stats.gaussian_kde(xy)(xy) + # Sort points by density; densest points are plotted last + idx = density.argsort() + else: + density = np.ones(len(amps)) + idx = np.arange(0, len(amps)) + + if weights is None: + weights = np.ones(len(amps)) # (no weigths) + + weights /= np.max(weights) # Make sure that it ranges b/w [0...1] + weights *= 25/np.mean(weights) # Scale for plot + # Plot - plt.scatter(amps[idx], mags[idx], s=25, alpha=0.5, - c=density[idx], cmap=plt.get_cmap('gist_heat'), edgecolor='') + plt.scatter(amps[idx], mags[idx], s=weights[idx], + alpha=0.5, c=density[idx], cmap=plt.get_cmap('gist_heat'), + edgecolor='') plt.gca().set_xscale('log') # otherwise Variant 2: simple one-colored scatter (per set) @@ -780,12 +794,14 @@ def plot_amp_regression(magSet, ampSet, regLineInfos=None, legendPos='upper left plt.semilogx(amps, mags, '.', c=colors[iSet]) # min/max among all sets - minAmp_log10 = math.log10(min([min(x) for x in ampSet])) - maxAmp_log10 = math.log10(max([max(x) for x in ampSet])) -# minAmp_log10 = math.log10(1e4) -# maxAmp_log10 = math.log10(2e7) + minAmp_log10 = np.log10(min([min(x) for x in ampSet])) + maxAmp_log10 = np.log10(max([max(x) for x in ampSet])) +# minAmp_log10 = np.log10(5e3) +# maxAmp_log10 = np.log10(5e6) minMag = min([min(x) for x in magSet]) maxMag = max([max(x) for x in magSet]) +# minMag = 0.7 +# maxMag = 3.3 # Plot (possible) regression line(s) + infos if regLineInfos: @@ -793,66 +809,124 @@ def plot_amp_regression(magSet, ampSet, regLineInfos=None, legendPos='upper left if not isinstance(regLineInfos[0], list): regLineInfos = [regLineInfos] # Pack into a list + # b=1 line + x_b1 = np.linspace(minAmp_log10*0.9, maxAmp_log10*1.5, 3) + a_n_b = np.mean(np.array([x[:2] for x in regLineInfos]), 0) + x_mean = np.mean([np.mean(np.log10(x)) for x in ampSet]) + plt.plot(10**x_b1, a_n_b[1] + (a_n_b[0]-1)*x_mean + 1*x_b1, ':', c='k', lw=1, + label='slope = 1') + for idx, regLine in enumerate(regLineInfos): # Unpack values - b, a, r_value, p_value, std_err, label = regLine + b, a, r2, std_errs, w_base, label = regLine + x = np.log10(ampSet[idx]) + # Create x vector - if len(regLineInfos) == numSets: - x_plot = np.array([math.log10(min(ampSet[idx]))-0.1, - math.log10(max(ampSet[idx]))+0.1]) - else: - x_plot = np.array([minAmp_log10-0.1, maxAmp_log10+0.1]) - # Plot - plt.plot(10**x_plot, a + b*x_plot, '-', c=colors[idx], lw=1, + x_plot = np.linspace(min(x), max(x), 100) + + y_plot = a + b*x_plot + + # Plot the regression line + w_label = '\nevent_weight: %d$^{ML}$' % w_base if w_base != 1 else '' + plt.plot(10**x_plot, y_plot, '-', c=colors[idx], lw=1, label=('%s\n' % label + 'ML = %.2f * log(A) + %.2f' % (b, a) + '\n$R^2$: %.4f $\sigma_{slope}$: %.4f' % - (r_value**2, std_err))) - label += '_' - # plt.plot(10**x_plot, a_maxL + b_maxL*x_plot, 'r-', lw=1) - - # plt.axis('scaled') # Doesn't work with semilog axis - plt.axis([10**(minAmp_log10-0.2), 10**(maxAmp_log10+0.2), - minMag-0.2, maxMag+0.2]) -# plt.axis([10**(4.9893-0.2), 10**(7.0858+0.2), 0.7-0.2, 3.4+0.2]) # SED events for OTER2 -# plt.axis([10**(3.8516-0.2), 10**(6.7256+0.2), 0.7-0.2, 3.4+0.2]) # SED events for MATTE - - # Legend - l = plt.legend(loc=legendPos, fontsize=legFontSize, fancybox=True) - # Color even the text of each entry with the appropriate color - for idx, legTexts in enumerate(l.get_texts()): - legTexts.set_color(colors[idx]) - - plt.savefig(config.directory + "/regression_mag-amp_%s%devents.png" % - (label, sum(len(x) for x in magSet)), dpi=150, bbox_inches='tight') -# plt.savefig(config.directory + "/regression_mag-amp_%devents.pdf" % len(catalog_ML), -# dpi=150, bbox_inches='tight') - # Determine current time as ID... to not overwrite possible previous file -# crtTime = time.strftime("%Y-%m-%dT%H%M%S", time.localtime()) -# plt.savefig("plots/regression_" + crtTime + ".png", dpi=150) - plt.close() + (r2, std_errs[1]) + w_label)) + label += '_' if label != '' else '' + + # Plot confidence + prediction interval + resids = magSet[idx] - (a + b*x) + ssr = np.sum(resids**2) + x_mean = np.log10(ampSet[idx]).mean() + n = len(ampSet[idx]) + dof = n - 2 + s = np.sqrt(ssr/dof) + t = scipy.stats.t.ppf(1-0.05/2, df=dof) # Using t-student's distribution + + conf = t * s * np.sqrt(1.0/n + + (x_plot-x_mean)**2 / (np.sum((x_plot-x_mean)**2))) + upper_confI = y_plot + abs(conf) + lower_confI = y_plot - abs(conf) + + pred = t * s * np.sqrt(1 + 1.0/n + + (x_plot-x_mean)**2 / (sum(x**2) - 1/n * sum(x**2))) + + upper_predI = y_plot + abs(pred) + lower_predI = y_plot - abs(pred) + + if idx == len(regLineInfos)-1: + label_CI = '95% confidence interval' + label_PI = '95% prediction interval' + else: + label_CI = '' + label_PI = '' + plt.fill_between(10**x_plot, lower_confI, upper_confI, + color=colors[idx], alpha=0.3, label=label_CI) + plt.fill_between(10**x_plot, lower_predI, upper_predI, + color=colors[idx], alpha=0.1, label=label_PI) + + # plt.axis('scaled') # Doesn't work with semilog axis + plt.axis([10**minAmp_log10*0.8, 10**maxAmp_log10*1.5, + minMag-0.2, maxMag+0.2]) + + # Legend + # Put the b=1 line at the end + handles, labels = plt.gca().get_legend_handles_labels() + handles.append(handles.pop(0)) + labels.append(labels.pop(0)) + + l = plt.legend(handles, labels, loc=legendPos, + fontsize=legFontSize, fancybox=True) + # Color even the text of each entry with the appropriate color + for idx, legTexts in enumerate(l.get_texts()): + if idx < len(regLineInfos): # To allow plotting label of CI & PI + legTexts.set_color(colors[idx]) + + if numSets > 1: + label = 'ALLclusters_' + + plt.savefig(config.directory + "/regression_mag-amp_%s%devents_w%d.png" % + (label, sum([len(x) for x in magSet]), w_base), + dpi=150, bbox_inches='tight') +# plt.savefig(config.directory + "/regression_mag-amp_%devents.pdf" % len(catalog_ML), +# dpi=150, bbox_inches='tight') + # Determine current time as ID... to not overwrite possible previous file +# crtTime = time.strftime("%Y-%m-%dT%H%M%S", time.localtime()) +# plt.savefig("plots/regression_" + crtTime + ".png", dpi=150) + plt.close() # Plot residuals if numSets == 1: amps = ampSet[0] mags = magSet[0] - resis = mags - a - b*np.log10(amps) - plt.plot(amps, resis, 'k.') + resids = mags - (a + b*np.log10(amps)) + resid_ssr = np.sum(resids**2) + resids_MSE = np.sum(resids**2)/len(amps) + + plt.scatter(amps, resids, s=weights, alpha=0.5, color='gray') plt.axhline(color='k', linestyle='--') + plt.fill_between(10**x_plot, abs(pred), -abs(pred), + color=colors[0], alpha=0.1, label=label_PI) + plt.gca().set_xscale('log') - plt.annotate('sum(residuals²): %.2f' % sum(resis**2), + plt.annotate('sum(residuals²): %.2f' % resid_ssr, xy=(0, 1), xycoords='axes fraction', fontsize=12, xytext=(4, 4), textcoords='offset points', ha='left', va='bottom', bbox=dict(boxstyle="square", fc="w")) + plt.annotate('MSE: %.3f' % resids_MSE, + xy=(1, 1), xycoords='axes fraction', fontsize=12, + xytext=(-4, 4), textcoords='offset points', + ha='right', va='bottom', bbox=dict(boxstyle="square", fc="w")) plt.xlabel("max. amplitude @%s" % (config.station)) plt.ylabel("residuals") - plt.axis([10**(minAmp_log10-0.2), 10**(maxAmp_log10+0.2), -1, 1]) + plt.axis([10**minAmp_log10*0.8, 10**maxAmp_log10*1.5, -1, 1]) - plt.savefig(config.directory + "/regression_mag-amp_resid_%s%devents.png" % - (label, len(mags)), dpi=150, bbox_inches='tight') + plt.savefig(config.directory + "/regression_mag-amp_resid_%s%devents_w%d.png" % + (label, len(mags), w_base), dpi=150, bbox_inches='tight') plt.close() diff --git a/version b/version index 17e51c3..d917d3e 100644 --- a/version +++ b/version @@ -1 +1 @@ -0.1.1 +0.1.2 -- GitLab