Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 165 additions & 2 deletions PWGDQ/Core/MuonMatchingMlResponse.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,24 +47,37 @@ namespace o2::analysis
// possible input features for ML
enum class InputFeaturesMFTMuonMatch : uint8_t {
zMatching,
// MFT track parameters
xMFT,
yMFT,
qOverptMFT,
tglMFT,
phiMFT,
ptMFT,
etaMFT,
timeMFT,
timeResMFT,
clusterSizesAndTrackFlagsMFT,
trackTypeMFT,
dcaXY,
dcaZ,
chi2MFT,
nClustersMFT,
// MCH track parameters
xMCH,
yMCH,
qOverptMCH,
tglMCH,
phiMCH,
ptMCH,
etaMCH,
timeMCH,
timeResMCH,
nClustersMCH,
chi2MCH,
pdca,
Rabs,
// MFT covariances
cXXMFT,
cXYMFT,
cYYMFT,
Expand All @@ -80,6 +93,7 @@ enum class InputFeaturesMFTMuonMatch : uint8_t {
c1PtPhiMFT,
c1PtTglMFT,
c1Pt21Pt2MFT,
// MCH covariances
cXXMCH,
cXYMCH,
cYYMCH,
Expand All @@ -95,17 +109,32 @@ enum class InputFeaturesMFTMuonMatch : uint8_t {
c1PtPhiMCH,
c1PtTglMCH,
c1Pt21Pt2MCH,
// track residuals
deltaX,
deltaY,
deltaPhi,
deltaTgl,
deltaEta,
deltaPt,
deltaR,
deltaDirection,
sameSign,
pullX,
pullY,
pullPhi,
pullTgl,
pullEta,
pullPt,
pullR,
deltaPtRel,
// primary vertex parameters
posX,
posY,
posZ,
numContrib,
trackOccupancyInTimeRange,
ft0cOccupancyInTimeRange,
multMFT,
multFT0A,
multFT0C,
multNTracksPV,
Expand All @@ -117,10 +146,47 @@ enum class InputFeaturesMFTMuonMatch : uint8_t {
centFT0M,
centFT0A,
centFT0C,
// global forward track parameters
chi2MCHMFT,
chi2GlobMUON
chi2GlobMUON,
dcaX,
dcaY,
isAmbig
};

template <typename T1, typename T2>
float getDeltaR(T1 const& mftprop, T2 const& mchprop)
{
return std::sqrt((mchprop.getX() - mftprop.getX()) * (mchprop.getX() - mftprop.getX()) + (mchprop.getY() - mftprop.getY()) * (mchprop.getY() - mftprop.getY()));
}

template <typename T1, typename T2>
float getPullR(T1 const& mftprop, T2 const& mchprop)
{
double deltaR = getDeltaR(mftprop, mchprop);
double err2X = mftprop.getCovariances()(0, 0) + mchprop.getCovariances()(0, 0);
double err2Y = mftprop.getCovariances()(1, 1) + mchprop.getCovariances()(1, 1);
double errR = std::sqrt(err2X + err2Y);
return deltaR / errR;
}

template <typename T1, typename T2>
float getDeltaDirection(T1 const& mftprop, T2 const& mchprop)
{
double cos2 = std::cos(mchprop.getPhi()) * std::cos(mftprop.getPhi());
double sin2 = std::sin(mchprop.getPhi()) * std::sin(mftprop.getPhi());
double tgl2 = mchprop.getTgl() * mftprop.getTgl();
double cosDelta = (cos2 + sin2 + tgl2) / (std::sqrt(mchprop.getTgl() * mchprop.getTgl() + 1) * std::sqrt(mftprop.getTgl() * mftprop.getTgl() + 1));
return static_cast<float>(std::acos(std::clamp(cosDelta, -1.0, 1.0)));
}

float getPull(float mftVal, float mftErr2, float mchVal, float mchErr2)
{
float delta = mchVal - mftVal;
float err = std::sqrt(mchErr2 + mftErr2);
return (err > 0 ? delta / err : delta);
}

template <typename TypeOutputScore = float>
class MlResponseMFTMuonMatch : public MlResponse<TypeOutputScore>
{
Expand All @@ -135,7 +201,9 @@ class MlResponseMFTMuonMatch : public MlResponse<TypeOutputScore>
{
float inputFeature = 0.;
switch (idx) {
// matching parameters
CHECK_AND_FILL_FEATURE(zMatching, mftprop.getZ());
// MFT track parameters
CHECK_AND_FILL_FEATURE(xMFT, mftprop.getX());
CHECK_AND_FILL_FEATURE(yMFT, mftprop.getY());
CHECK_AND_FILL_FEATURE(qOverptMFT, mftprop.getInvQPt());
Expand All @@ -145,14 +213,27 @@ class MlResponseMFTMuonMatch : public MlResponse<TypeOutputScore>
CHECK_AND_FILL_FEATURE(nClustersMFT, mft.nClusters());
/*dummy value*/ CHECK_AND_FILL_FEATURE(dcaXY, 0);
/*dummy value*/ CHECK_AND_FILL_FEATURE(dcaZ, 0);
CHECK_AND_FILL_FEATURE(ptMFT, mftprop.getPt());
CHECK_AND_FILL_FEATURE(etaMFT, mftprop.getEta());
CHECK_AND_FILL_FEATURE(timeMFT, mft.trackTime());
CHECK_AND_FILL_FEATURE(timeResMFT, mft.trackTimeRes());
CHECK_AND_FILL_FEATURE(clusterSizesAndTrackFlagsMFT, mft.mftClusterSizesAndTrackFlags());
CHECK_AND_FILL_FEATURE(trackTypeMFT, (mft.isCA() ? 1 : 0));
// MCH track parameters
CHECK_AND_FILL_FEATURE(xMCH, mchprop.getX());
CHECK_AND_FILL_FEATURE(yMCH, mchprop.getY());
CHECK_AND_FILL_FEATURE(qOverptMCH, mchprop.getInvQPt());
CHECK_AND_FILL_FEATURE(tglMCH, mchprop.getTanl());
CHECK_AND_FILL_FEATURE(phiMCH, mchprop.getPhi());
CHECK_AND_FILL_FEATURE(ptMCH, mchprop.getPt());
CHECK_AND_FILL_FEATURE(etaMCH, mchprop.getEta());
CHECK_AND_FILL_FEATURE(timeMCH, mch.trackTime());
CHECK_AND_FILL_FEATURE(timeResMCH, mch.trackTimeRes());
CHECK_AND_FILL_FEATURE(nClustersMCH, mch.nClusters());
CHECK_AND_FILL_FEATURE(chi2MCH, mch.chi2());
CHECK_AND_FILL_FEATURE(pdca, muon.pDca());
CHECK_AND_FILL_FEATURE(Rabs, muon.rAtAbsorberEnd());
// MFT covariances
CHECK_AND_FILL_FEATURE(cXXMFT, mftprop.getCovariances()(0, 0));
CHECK_AND_FILL_FEATURE(cXYMFT, mftprop.getCovariances()(0, 1));
CHECK_AND_FILL_FEATURE(cYYMFT, mftprop.getCovariances()(1, 1));
Expand All @@ -168,6 +249,7 @@ class MlResponseMFTMuonMatch : public MlResponse<TypeOutputScore>
CHECK_AND_FILL_FEATURE(c1PtPhiMFT, mftprop.getCovariances()(2, 4));
CHECK_AND_FILL_FEATURE(c1PtTglMFT, mftprop.getCovariances()(3, 4));
CHECK_AND_FILL_FEATURE(c1Pt21Pt2MFT, mftprop.getCovariances()(4, 4));
// MCH covariances
CHECK_AND_FILL_FEATURE(cXXMCH, mchprop.getCovariances()(0, 0));
CHECK_AND_FILL_FEATURE(cXYMCH, mchprop.getCovariances()(0, 1));
CHECK_AND_FILL_FEATURE(cYYMCH, mchprop.getCovariances()(1, 1));
Expand All @@ -183,12 +265,32 @@ class MlResponseMFTMuonMatch : public MlResponse<TypeOutputScore>
CHECK_AND_FILL_FEATURE(c1PtPhiMCH, mchprop.getCovariances()(2, 4));
CHECK_AND_FILL_FEATURE(c1PtTglMCH, mchprop.getCovariances()(3, 4));
CHECK_AND_FILL_FEATURE(c1Pt21Pt2MCH, mchprop.getCovariances()(4, 4));
// Track residuals
CHECK_AND_FILL_FEATURE(deltaX, mchprop.getX() - mftprop.getX());
CHECK_AND_FILL_FEATURE(deltaY, mchprop.getY() - mftprop.getY());
CHECK_AND_FILL_FEATURE(deltaPhi, mchprop.getPhi() - mftprop.getPhi());
CHECK_AND_FILL_FEATURE(deltaTgl, mchprop.getTgl() - mftprop.getTgl());
CHECK_AND_FILL_FEATURE(deltaEta, mchprop.getEta() - mftprop.getEta());
CHECK_AND_FILL_FEATURE(deltaPt, mchprop.getPt() - mftprop.getPt());
CHECK_AND_FILL_FEATURE(deltaR, getDeltaR(mftprop, mchprop));
CHECK_AND_FILL_FEATURE(deltaDirection, getDeltaDirection(mftprop, mchprop));
CHECK_AND_FILL_FEATURE(deltaPtRel, (mchprop.getPt() - mftprop.getPt()) / (mchprop.getPt() + mftprop.getPt()));
CHECK_AND_FILL_FEATURE(sameSign, (mch.sign() == mft.sign()) ? 1 : 0);
CHECK_AND_FILL_FEATURE(pullX, getPull(mftprop.getX(), mftprop.getCovariances()(0, 0), mchprop.getX(), mchprop.getCovariances()(0, 0)));
CHECK_AND_FILL_FEATURE(pullY, getPull(mftprop.getY(), mftprop.getCovariances()(1, 1), mchprop.getY(), mchprop.getCovariances()(1, 1)));
CHECK_AND_FILL_FEATURE(pullPhi, getPull(mftprop.getPhi(), mftprop.getCovariances()(2, 2), mchprop.getPhi(), mchprop.getCovariances()(2, 2)));
CHECK_AND_FILL_FEATURE(pullTgl, getPull(mftprop.getTgl(), mftprop.getCovariances()(3, 3), mchprop.getTgl(), mchprop.getCovariances()(3, 3)));
/*dummy value*/ CHECK_AND_FILL_FEATURE(pullEta, 0);
/*dummy value*/ CHECK_AND_FILL_FEATURE(pullPt, 0);
CHECK_AND_FILL_FEATURE(pullR, getPullR(mftprop, mchprop));
// primary vertex parameters
CHECK_AND_FILL_FEATURE(posX, collision.posX());
CHECK_AND_FILL_FEATURE(posY, collision.posY());
CHECK_AND_FILL_FEATURE(posZ, collision.posZ());
CHECK_AND_FILL_FEATURE(numContrib, collision.numContrib());
CHECK_AND_FILL_FEATURE(trackOccupancyInTimeRange, collision.trackOccupancyInTimeRange());
CHECK_AND_FILL_FEATURE(ft0cOccupancyInTimeRange, collision.ft0cOccupancyInTimeRange());
CHECK_AND_FILL_FEATURE(multMFT, collision.mftNtracks());
CHECK_AND_FILL_FEATURE(multFT0A, collision.multFT0A());
CHECK_AND_FILL_FEATURE(multFT0C, collision.multFT0C());
CHECK_AND_FILL_FEATURE(multNTracksPV, collision.multNTracksPV());
Expand All @@ -200,7 +302,12 @@ class MlResponseMFTMuonMatch : public MlResponse<TypeOutputScore>
CHECK_AND_FILL_FEATURE(centFT0M, collision.centFT0M());
CHECK_AND_FILL_FEATURE(centFT0A, collision.centFT0A());
CHECK_AND_FILL_FEATURE(centFT0C, collision.centFT0C());
// global forward track parameters
CHECK_AND_FILL_FEATURE(chi2MCHMFT, muon.chi2MatchMCHMFT());
CHECK_AND_FILL_FEATURE(chi2GlobMUON, muon.chi2());
CHECK_AND_FILL_FEATURE(dcaX, muon.fwdDcaX());
CHECK_AND_FILL_FEATURE(dcaY, muon.fwdDcaX());
CHECK_AND_FILL_FEATURE(isAmbig, (muon.compatibleCollIds().size() == 1) ? 0 : 1);
}
return inputFeature;
}
Expand Down Expand Up @@ -243,25 +350,39 @@ class MlResponseMFTMuonMatch : public MlResponse<TypeOutputScore>
void setAvailableInputFeatures()
{
MlResponse<TypeOutputScore>::mAvailableInputFeatures = {
// matching parameters
FILL_MAP_MFTMUON_MATCH(zMatching),
// MFT track parameters
FILL_MAP_MFTMUON_MATCH(xMFT),
FILL_MAP_MFTMUON_MATCH(yMFT),
FILL_MAP_MFTMUON_MATCH(qOverptMFT),
FILL_MAP_MFTMUON_MATCH(tglMFT),
FILL_MAP_MFTMUON_MATCH(phiMFT),
FILL_MAP_MFTMUON_MATCH(ptMFT),
FILL_MAP_MFTMUON_MATCH(etaMFT),
FILL_MAP_MFTMUON_MATCH(timeMFT),
FILL_MAP_MFTMUON_MATCH(timeResMFT),
FILL_MAP_MFTMUON_MATCH(dcaXY),
FILL_MAP_MFTMUON_MATCH(dcaZ),
FILL_MAP_MFTMUON_MATCH(chi2MFT),
FILL_MAP_MFTMUON_MATCH(clusterSizesAndTrackFlagsMFT),
FILL_MAP_MFTMUON_MATCH(trackTypeMFT),
FILL_MAP_MFTMUON_MATCH(nClustersMFT),
// MCH track parameters
FILL_MAP_MFTMUON_MATCH(xMCH),
FILL_MAP_MFTMUON_MATCH(yMCH),
FILL_MAP_MFTMUON_MATCH(qOverptMCH),
FILL_MAP_MFTMUON_MATCH(tglMCH),
FILL_MAP_MFTMUON_MATCH(phiMCH),
FILL_MAP_MFTMUON_MATCH(nClustersMCH),
FILL_MAP_MFTMUON_MATCH(ptMCH),
FILL_MAP_MFTMUON_MATCH(etaMCH),
FILL_MAP_MFTMUON_MATCH(timeMCH),
FILL_MAP_MFTMUON_MATCH(timeResMCH),
FILL_MAP_MFTMUON_MATCH(chi2MCH),
FILL_MAP_MFTMUON_MATCH(pdca),
FILL_MAP_MFTMUON_MATCH(Rabs),
// MFT covariances
FILL_MAP_MFTMUON_MATCH(cXXMFT),
FILL_MAP_MFTMUON_MATCH(cXYMFT),
FILL_MAP_MFTMUON_MATCH(cYYMFT),
Expand All @@ -277,6 +398,7 @@ class MlResponseMFTMuonMatch : public MlResponse<TypeOutputScore>
FILL_MAP_MFTMUON_MATCH(c1PtPhiMFT),
FILL_MAP_MFTMUON_MATCH(c1PtTglMFT),
FILL_MAP_MFTMUON_MATCH(c1Pt21Pt2MFT),
// MCH covariances
FILL_MAP_MFTMUON_MATCH(cXXMCH),
FILL_MAP_MFTMUON_MATCH(cXYMCH),
FILL_MAP_MFTMUON_MATCH(cYYMCH),
Expand All @@ -292,8 +414,49 @@ class MlResponseMFTMuonMatch : public MlResponse<TypeOutputScore>
FILL_MAP_MFTMUON_MATCH(c1PtPhiMCH),
FILL_MAP_MFTMUON_MATCH(c1PtTglMCH),
FILL_MAP_MFTMUON_MATCH(c1Pt21Pt2MCH),
// track residuals
FILL_MAP_MFTMUON_MATCH(deltaX),
FILL_MAP_MFTMUON_MATCH(deltaY),
FILL_MAP_MFTMUON_MATCH(deltaPhi),
FILL_MAP_MFTMUON_MATCH(deltaTgl),
FILL_MAP_MFTMUON_MATCH(deltaEta),
FILL_MAP_MFTMUON_MATCH(deltaPt),
FILL_MAP_MFTMUON_MATCH(deltaR),
FILL_MAP_MFTMUON_MATCH(deltaDirection),
FILL_MAP_MFTMUON_MATCH(sameSign),
FILL_MAP_MFTMUON_MATCH(pullX),
FILL_MAP_MFTMUON_MATCH(pullY),
FILL_MAP_MFTMUON_MATCH(pullPhi),
FILL_MAP_MFTMUON_MATCH(pullTgl),
FILL_MAP_MFTMUON_MATCH(pullEta),
FILL_MAP_MFTMUON_MATCH(pullPt),
FILL_MAP_MFTMUON_MATCH(pullR),
FILL_MAP_MFTMUON_MATCH(deltaPtRel),
// primary vertex parameters
FILL_MAP_MFTMUON_MATCH(posX),
FILL_MAP_MFTMUON_MATCH(posY),
FILL_MAP_MFTMUON_MATCH(posZ),
FILL_MAP_MFTMUON_MATCH(numContrib),
FILL_MAP_MFTMUON_MATCH(trackOccupancyInTimeRange),
FILL_MAP_MFTMUON_MATCH(ft0cOccupancyInTimeRange),
FILL_MAP_MFTMUON_MATCH(multMFT),
FILL_MAP_MFTMUON_MATCH(multFT0A),
FILL_MAP_MFTMUON_MATCH(multFT0C),
FILL_MAP_MFTMUON_MATCH(multNTracksPV),
FILL_MAP_MFTMUON_MATCH(multNTracksPVeta1),
FILL_MAP_MFTMUON_MATCH(multNTracksPVetaHalf),
FILL_MAP_MFTMUON_MATCH(isInelGt0),
FILL_MAP_MFTMUON_MATCH(isInelGt1),
FILL_MAP_MFTMUON_MATCH(multFT0M),
FILL_MAP_MFTMUON_MATCH(centFT0M),
FILL_MAP_MFTMUON_MATCH(centFT0A),
FILL_MAP_MFTMUON_MATCH(centFT0C),
// global forward track parameters
FILL_MAP_MFTMUON_MATCH(chi2MCHMFT),
FILL_MAP_MFTMUON_MATCH(chi2GlobMUON)};
FILL_MAP_MFTMUON_MATCH(chi2GlobMUON),
FILL_MAP_MFTMUON_MATCH(dcaX),
FILL_MAP_MFTMUON_MATCH(dcaY),
FILL_MAP_MFTMUON_MATCH(isAmbig)};
}

uint8_t mCachedIndexBinning; // index correspondance between configurable and available input features
Expand Down
8 changes: 6 additions & 2 deletions PWGDQ/Tasks/qaMatching.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
#include "PWGDQ/Core/VarManager.h"

#include "Common/CCDB/RCTSelectionFlags.h"
#include "Common/DataModel/Centrality.h"
#include "Common/DataModel/CollisionAssociationTables.h"
#include "Common/DataModel/EventSelection.h"
#include "Common/DataModel/Multiplicity.h"
#include "Common/DataModel/TrackSelectionTables.h"
#include "Tools/ML/MlResponse.h"

#include <CCDB/BasicCCDBManager.h>
Expand Down Expand Up @@ -155,9 +159,9 @@ DECLARE_SOA_TABLE(QaMatchingCandidates, "AOD", "QAMCAND",
qamatching::PzAtVtx);
} // namespace o2::aod

using MyEvents = soa::Join<aod::Collisions, aod::EvSels>;
using MyEvents = soa::Join<aod::Collisions, aod::EvSels, aod::FT0Mults, aod::MFTMults, aod::PVMults, aod::CentFT0Ms, aod::CentFT0As, aod::CentFT0Cs>;
using MyMuons = soa::Join<aod::FwdTracks, aod::FwdTracksCov>;
using MyMuonsMC = soa::Join<aod::FwdTracks, aod::FwdTracksCov, aod::McFwdTrackLabels>;
using MyMuonsMC = soa::Join<aod::FwdTracks, aod::FwdTracksCov, aod::McFwdTrackLabels, aod::FwdTracksDCA, aod::FwdTrkCompColls>;
using MyMFTs = aod::MFTTracks;
using MyMFTCovariances = aod::MFTTracksCov;
using MyMFTsMC = soa::Join<aod::MFTTracks, aod::McMFTTrackLabels>;
Expand Down
Loading