#!/usr/bin/env python
# encoding: utf-8
# The MIT License (MIT)
# Copyright (c) 2012-2019 CNRS
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# AUTHORS
# Hervé BREDIN - http://herve.niderb.fr
from .base import BaseMetric
from .base import Precision, PRECISION_RETRIEVED, PRECISION_RELEVANT_RETRIEVED
from .base import Recall, RECALL_RELEVANT, RECALL_RELEVANT_RETRIEVED
from .matcher import LabelMatcher, \
MATCH_TOTAL, MATCH_CORRECT, MATCH_CONFUSION, \
MATCH_MISSED_DETECTION, MATCH_FALSE_ALARM
from .utils import UEMSupportMixin
# Human-readable name reported by IdentificationErrorRate.metric_name().
IER_NAME = 'identification error rate'

# Keys of the components accumulated by the identification metrics.
# They are plain aliases of the LabelMatcher count keys, so that the
# ``counts`` mapping returned by the matcher can be indexed with them.
# Listed in the same order as IdentificationErrorRate.metric_components().
IER_TOTAL = MATCH_TOTAL
IER_CORRECT = MATCH_CORRECT
IER_FALSE_ALARM = MATCH_FALSE_ALARM
IER_MISS = MATCH_MISSED_DETECTION
IER_CONFUSION = MATCH_CONFUSION
class IdentificationErrorRate(UEMSupportMixin, BaseMetric):
    """Identification error rate

    ``ier = (wc x confusion + wf x false_alarm + wm x miss) / total``

    where
        - `confusion` is the total confusion duration in seconds
        - `false_alarm` is the total duration in seconds that the label
          matcher counts as false alarm (presumably hypothesis labels that
          have no reference counterpart in a segment -- see LabelMatcher)
        - `miss` is the total duration in seconds that the label matcher
          counts as missed detection (presumably reference labels that have
          no hypothesis counterpart in a segment -- see LabelMatcher)
        - `total` is the total duration of all tracks, as counted by the
          label matcher
        - wc, wf and wm are optional weights (default to 1)

    Every component is accumulated over the common timeline returned by
    ``uemify``: for each segment, its duration is multiplied by the
    corresponding count reported by ``LabelMatcher``.

    When `total` is zero, the rate is 0. if there is no weighted error at
    all, and 1. otherwise.

    Parameters
    ----------
    collar : float, optional
        Duration (in seconds) of collars removed from evaluation around
        boundaries of reference segments.
    skip_overlap : bool, optional
        Set to True to not evaluate overlap regions.
        Defaults to False (i.e. keep overlap regions).
    confusion, miss, false_alarm: float, optional
        Optional weights for confusion, miss and false alarm respectively.
        Default to 1. (no weight)
    """

    @classmethod
    def metric_name(cls):
        return IER_NAME

    @classmethod
    def metric_components(cls):
        return [
            IER_TOTAL,
            IER_CORRECT,
            IER_FALSE_ALARM, IER_MISS,
            IER_CONFUSION]

    def __init__(self, confusion=1., miss=1., false_alarm=1.,
                 collar=0., skip_overlap=False, **kwargs):
        super(IdentificationErrorRate, self).__init__(**kwargs)
        self.matcher_ = LabelMatcher()
        # Weights applied in compute_metric.
        self.confusion = confusion
        self.miss = miss
        self.false_alarm = false_alarm
        # Defaults used by compute_components unless overridden per call.
        self.collar = collar
        self.skip_overlap = skip_overlap

    def compute_components(self, reference, hypothesis, uem=None,
                           collar=None, skip_overlap=None, **kwargs):
        """Accumulate duration-weighted identification error components.

        Parameters
        ----------
        reference, hypothesis :
            Reference and hypothesis annotations, passed to ``self.uemify``.
        uem : optional
            Evaluation map, passed to ``self.uemify``.
        collar : float, optional
            Override self.collar
        skip_overlap : bool, optional
            Override self.skip_overlap

        Returns
        -------
        detail : dict
            Components initialised by ``self.init_components()``, with
            IER_TOTAL, IER_CORRECT, IER_CONFUSION, IER_MISS and
            IER_FALSE_ALARM incremented by segment duration times the
            matcher counts.

        See also
        --------
        :class:`pyannote.metric.diarization.DiarizationErrorRate` uses these
        two options in its `compute_components` method.
        """
        detail = self.init_components()

        if collar is None:
            collar = self.collar
        if skip_overlap is None:
            skip_overlap = self.skip_overlap

        R, H, common_timeline = self.uemify(
            reference, hypothesis, uem=uem,
            collar=collar, skip_overlap=skip_overlap,
            returns_timeline=True)

        # Matcher counts that are turned into durations.
        components = (IER_TOTAL, IER_CORRECT, IER_CONFUSION,
                      IER_MISS, IER_FALSE_ALARM)

        for segment in common_timeline:
            duration = segment.duration
            # unique=False keeps repeated labels, so each track counts.
            r = R.get_labels(segment, unique=False)
            h = H.get_labels(segment, unique=False)
            counts, _ = self.matcher_(r, h)
            for component in components:
                detail[component] += duration * counts[component]

        return detail

    def compute_metric(self, detail):
        """Compute the weighted identification error rate from `detail`.

        Returns 0. when both numerator and `total` are zero, 1. when only
        `total` is zero, and the weighted error divided by `total` otherwise.
        """
        numerator = 1. * (
            self.confusion * detail[IER_CONFUSION] +
            self.false_alarm * detail[IER_FALSE_ALARM] +
            self.miss * detail[IER_MISS]
        )
        denominator = 1. * detail[IER_TOTAL]
        if denominator == 0.:
            # Nothing to evaluate: perfect if no error, worst otherwise.
            return 0. if numerator == 0 else 1.
        return numerator / denominator
class IdentificationPrecision(UEMSupportMixin, Precision):
    """Identification Precision

    Retrieved duration is the sum, over the common timeline returned by
    ``uemify``, of segment duration times the number of hypothesis labels.
    Relevant retrieved duration is segment duration times the number of
    labels the ``LabelMatcher`` counts as correct.

    Parameters
    ----------
    collar : float, optional
        Duration (in seconds) of collars removed from evaluation around
        boundaries of reference segments.
    skip_overlap : bool, optional
        Set to True to not evaluate overlap regions.
        Defaults to False (i.e. keep overlap regions).
    """

    def __init__(self, collar=0., skip_overlap=False, **kwargs):
        super(IdentificationPrecision, self).__init__(**kwargs)
        self.collar = collar
        self.skip_overlap = skip_overlap
        self.matcher_ = LabelMatcher()

    def compute_components(self, reference, hypothesis, uem=None,
                           collar=None, skip_overlap=None, **kwargs):
        """Accumulate duration-weighted precision components.

        Parameters
        ----------
        reference, hypothesis :
            Reference and hypothesis annotations, passed to ``self.uemify``.
        uem : optional
            Evaluation map, passed to ``self.uemify``.
        collar : float, optional
            Override self.collar (previously silently ignored).
        skip_overlap : bool, optional
            Override self.skip_overlap (previously silently ignored).

        Returns
        -------
        detail : dict
            Components initialised by ``self.init_components()`` with
            PRECISION_RETRIEVED and PRECISION_RELEVANT_RETRIEVED updated.
        """
        detail = self.init_components()

        if collar is None:
            collar = self.collar
        if skip_overlap is None:
            skip_overlap = self.skip_overlap

        R, H, common_timeline = self.uemify(
            reference, hypothesis, uem=uem,
            collar=collar, skip_overlap=skip_overlap,
            returns_timeline=True)

        for segment in common_timeline:
            duration = segment.duration
            # unique=False keeps repeated labels, so each track counts.
            r = R.get_labels(segment, unique=False)
            h = H.get_labels(segment, unique=False)
            counts, _ = self.matcher_(r, h)
            detail[PRECISION_RETRIEVED] += duration * len(h)
            detail[PRECISION_RELEVANT_RETRIEVED] += \
                duration * counts[IER_CORRECT]

        return detail
class IdentificationRecall(UEMSupportMixin, Recall):
    """Identification Recall

    Relevant duration is the sum, over the common timeline returned by
    ``uemify``, of segment duration times the matcher's total count.
    Relevant retrieved duration is segment duration times the number of
    labels the ``LabelMatcher`` counts as correct.

    Parameters
    ----------
    collar : float, optional
        Duration (in seconds) of collars removed from evaluation around
        boundaries of reference segments.
    skip_overlap : bool, optional
        Set to True to not evaluate overlap regions.
        Defaults to False (i.e. keep overlap regions).
    """

    def __init__(self, collar=0., skip_overlap=False, **kwargs):
        super(IdentificationRecall, self).__init__(**kwargs)
        self.collar = collar
        self.skip_overlap = skip_overlap
        self.matcher_ = LabelMatcher()

    def compute_components(self, reference, hypothesis, uem=None,
                           collar=None, skip_overlap=None, **kwargs):
        """Accumulate duration-weighted recall components.

        Parameters
        ----------
        reference, hypothesis :
            Reference and hypothesis annotations, passed to ``self.uemify``.
        uem : optional
            Evaluation map, passed to ``self.uemify``.
        collar : float, optional
            Override self.collar (previously silently ignored).
        skip_overlap : bool, optional
            Override self.skip_overlap (previously silently ignored).

        Returns
        -------
        detail : dict
            Components initialised by ``self.init_components()`` with
            RECALL_RELEVANT and RECALL_RELEVANT_RETRIEVED updated.
        """
        detail = self.init_components()

        if collar is None:
            collar = self.collar
        if skip_overlap is None:
            skip_overlap = self.skip_overlap

        R, H, common_timeline = self.uemify(
            reference, hypothesis, uem=uem,
            collar=collar, skip_overlap=skip_overlap,
            returns_timeline=True)

        for segment in common_timeline:
            duration = segment.duration
            # unique=False keeps repeated labels, so each track counts.
            r = R.get_labels(segment, unique=False)
            h = H.get_labels(segment, unique=False)
            counts, _ = self.matcher_(r, h)
            detail[RECALL_RELEVANT] += duration * counts[IER_TOTAL]
            detail[RECALL_RELEVANT_RETRIEVED] += \
                duration * counts[IER_CORRECT]

        return detail