Skip to content

Commit 5346339

Browse files
author
Paul Ramirez
committed
Fix One Euro filter's units of computation
Changed the units that the One Euro filter uses to compute the filtered coordinates. This was causing a crash because if two timestamps were sufficiently close to each other, by the time of implicitly converting from nanoseconds to seconds, they were considered equal. This led to a zero division when calculating the sampling frequency. Now, everything is handled in the scale of nanoseconds, and conversion are done if and only if they're necessary. Bug: 297226446 Flag: EXEMPT bugfix Test: TEST=libinput_tests; m $TEST && $ANDROID_HOST_OUT/nativetest64/$TEST/$TEST Change-Id: I7fced6db447074cccb3d938eb9dc7a9707433f53
1 parent 08ee199 commit 5346339

8 files changed

Lines changed: 258 additions & 23 deletions

File tree

include/input/CoordinateFilter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class CoordinateFilter {
4444
* the previous call.
4545
* @param coords Coordinates to be overwritten by the corresponding filtered coordinates.
4646
*/
47-
void filter(std::chrono::duration<float> timestamp, PointerCoords& coords);
47+
void filter(std::chrono::nanoseconds timestamp, PointerCoords& coords);
4848

4949
private:
5050
OneEuroFilter mXFilter;

include/input/OneEuroFilter.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class OneEuroFilter {
5656
* provided in the previous call.
5757
* @param rawPosition Position to be filtered.
5858
*/
59-
float filter(std::chrono::duration<float> timestamp, float rawPosition);
59+
float filter(std::chrono::nanoseconds timestamp, float rawPosition);
6060

6161
private:
6262
/**
@@ -67,7 +67,7 @@ class OneEuroFilter {
6767

6868
/**
6969
* Slope of the cutoff frequency criterion. This is the term scaling the absolute value of the
70-
* filtered signal's speed. The data member is dimensionless, that is, it does not have units.
70+
* filtered signal's speed. Units are 1 / position.
7171
*/
7272
const float mBeta;
7373

@@ -78,17 +78,17 @@ class OneEuroFilter {
7878
const float mSpeedCutoffFreq;
7979

8080
/**
81-
* The timestamp from the previous call. Units are seconds.
81+
* The timestamp from the previous call.
8282
*/
83-
std::optional<std::chrono::duration<float>> mPrevTimestamp;
83+
std::optional<std::chrono::nanoseconds> mPrevTimestamp;
8484

8585
/**
8686
* The raw position from the previous call.
8787
*/
8888
std::optional<float> mPrevRawPosition;
8989

9090
/**
91-
* The filtered velocity from the previous call. Units are position per second.
91+
* The filtered velocity from the previous call. Units are position per nanosecond.
9292
*/
9393
std::optional<float> mPrevFilteredVelocity;
9494

libs/input/CoordinateFilter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ namespace android {
2323
CoordinateFilter::CoordinateFilter(float minCutoffFreq, float beta)
2424
: mXFilter{minCutoffFreq, beta}, mYFilter{minCutoffFreq, beta} {}
2525

26-
void CoordinateFilter::filter(std::chrono::duration<float> timestamp, PointerCoords& coords) {
26+
void CoordinateFilter::filter(std::chrono::nanoseconds timestamp, PointerCoords& coords) {
2727
coords.setAxisValue(AMOTION_EVENT_AXIS_X, mXFilter.filter(timestamp, coords.getX()));
2828
coords.setAxisValue(AMOTION_EVENT_AXIS_Y, mYFilter.filter(timestamp, coords.getY()));
2929
}

libs/input/OneEuroFilter.cpp

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,34 +25,42 @@
2525
namespace android {
2626
namespace {
2727

28+
using namespace std::literals::chrono_literals;
29+
30+
const float kHertzPerGigahertz = 1E9f;
31+
const float kGigahertzPerHertz = 1E-9f;
32+
33+
// filteredSpeed's units are position per nanosecond. beta's units are 1 / position.
2834
inline float cutoffFreq(float minCutoffFreq, float beta, float filteredSpeed) {
29-
return minCutoffFreq + beta * std::abs(filteredSpeed);
35+
return kHertzPerGigahertz *
36+
((minCutoffFreq * kGigahertzPerHertz) + beta * std::abs(filteredSpeed));
3037
}
3138

32-
inline float smoothingFactor(std::chrono::duration<float> samplingPeriod, float cutoffFreq) {
33-
return samplingPeriod.count() / (samplingPeriod.count() + (1.0 / (2.0 * M_PI * cutoffFreq)));
39+
inline float smoothingFactor(std::chrono::nanoseconds samplingPeriod, float cutoffFreq) {
40+
const float constant = 2.0f * M_PI * samplingPeriod.count() * (cutoffFreq * kGigahertzPerHertz);
41+
return constant / (constant + 1);
3442
}
3543

36-
inline float lowPassFilter(float rawPosition, float prevFilteredPosition, float smoothingFactor) {
37-
return smoothingFactor * rawPosition + (1 - smoothingFactor) * prevFilteredPosition;
44+
inline float lowPassFilter(float rawValue, float prevFilteredValue, float smoothingFactor) {
45+
return smoothingFactor * rawValue + (1 - smoothingFactor) * prevFilteredValue;
3846
}
3947

4048
} // namespace
4149

4250
OneEuroFilter::OneEuroFilter(float minCutoffFreq, float beta, float speedCutoffFreq)
4351
: mMinCutoffFreq{minCutoffFreq}, mBeta{beta}, mSpeedCutoffFreq{speedCutoffFreq} {}
4452

45-
float OneEuroFilter::filter(std::chrono::duration<float> timestamp, float rawPosition) {
46-
LOG_IF(FATAL, mPrevFilteredPosition.has_value() && (timestamp <= *mPrevTimestamp))
47-
<< "Timestamp must be greater than mPrevTimestamp";
53+
float OneEuroFilter::filter(std::chrono::nanoseconds timestamp, float rawPosition) {
54+
LOG_IF(FATAL, mPrevTimestamp.has_value() && (*mPrevTimestamp >= timestamp))
55+
<< "Timestamp must be greater than mPrevTimestamp. Timestamp: " << timestamp.count()
56+
<< "ns. mPrevTimestamp: " << mPrevTimestamp->count() << "ns";
4857

49-
const std::chrono::duration<float> samplingPeriod = (mPrevTimestamp.has_value())
50-
? (timestamp - *mPrevTimestamp)
51-
: std::chrono::duration<float>{1.0};
58+
const std::chrono::nanoseconds samplingPeriod =
59+
(mPrevTimestamp.has_value()) ? (timestamp - *mPrevTimestamp) : 1s;
5260

5361
const float rawVelocity = (mPrevFilteredPosition.has_value())
54-
? ((rawPosition - *mPrevFilteredPosition) / samplingPeriod.count())
55-
: 0.0;
62+
? ((rawPosition - *mPrevFilteredPosition) / (samplingPeriod.count()))
63+
: 0.0f;
5664

5765
const float speedSmoothingFactor = smoothingFactor(samplingPeriod, mSpeedCutoffFreq);
5866

libs/input/tests/Android.bp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ cc_test {
1717
"IdGenerator_test.cpp",
1818
"InputChannel_test.cpp",
1919
"InputConsumer_test.cpp",
20+
"InputConsumerFilteredResampling_test.cpp",
2021
"InputConsumerResampling_test.cpp",
2122
"InputDevice_test.cpp",
2223
"InputEvent_test.cpp",
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
/**
2+
* Copyright 2024 The Android Open Source Project
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include <input/InputConsumerNoResampling.h>
18+
19+
#include <chrono>
20+
#include <iostream>
21+
#include <memory>
22+
#include <queue>
23+
24+
#include <TestEventMatchers.h>
25+
#include <TestInputChannel.h>
26+
#include <android-base/logging.h>
27+
#include <gmock/gmock.h>
28+
#include <gtest/gtest.h>
29+
#include <input/Input.h>
30+
#include <input/InputEventBuilders.h>
31+
#include <input/Resampler.h>
32+
#include <utils/Looper.h>
33+
#include <utils/StrongPointer.h>
34+
35+
namespace android {
36+
namespace {
37+
38+
using std::chrono::nanoseconds;
39+
40+
using ::testing::AllOf;
41+
using ::testing::Matcher;
42+
43+
const int32_t ACTION_DOWN = AMOTION_EVENT_ACTION_DOWN;
44+
const int32_t ACTION_MOVE = AMOTION_EVENT_ACTION_MOVE;
45+
46+
struct Pointer {
47+
int32_t id{0};
48+
ToolType toolType{ToolType::FINGER};
49+
float x{0.0f};
50+
float y{0.0f};
51+
bool isResampled{false};
52+
53+
PointerBuilder asPointerBuilder() const {
54+
return PointerBuilder{id, toolType}.x(x).y(y).isResampled(isResampled);
55+
}
56+
};
57+
58+
} // namespace
59+
60+
class InputConsumerFilteredResamplingTest : public ::testing::Test, public InputConsumerCallbacks {
61+
protected:
62+
InputConsumerFilteredResamplingTest()
63+
: mClientTestChannel{std::make_shared<TestInputChannel>("TestChannel")},
64+
mLooper{sp<Looper>::make(/*allowNonCallbacks=*/false)} {
65+
Looper::setForThread(mLooper);
66+
mConsumer = std::make_unique<
67+
InputConsumerNoResampling>(mClientTestChannel, mLooper, *this, []() {
68+
return std::make_unique<FilteredLegacyResampler>(/*minCutoffFreq=*/4.7, /*beta=*/0.01);
69+
});
70+
}
71+
72+
void invokeLooperCallback() const {
73+
sp<LooperCallback> callback;
74+
ASSERT_TRUE(mLooper->getFdStateDebug(mClientTestChannel->getFd(), /*ident=*/nullptr,
75+
/*events=*/nullptr, &callback, /*data=*/nullptr));
76+
ASSERT_NE(callback, nullptr);
77+
callback->handleEvent(mClientTestChannel->getFd(), ALOOPER_EVENT_INPUT, /*data=*/nullptr);
78+
}
79+
80+
void assertOnBatchedInputEventPendingWasCalled() {
81+
ASSERT_GT(mOnBatchedInputEventPendingInvocationCount, 0UL)
82+
<< "onBatchedInputEventPending was not called";
83+
--mOnBatchedInputEventPendingInvocationCount;
84+
}
85+
86+
void assertReceivedMotionEvent(const Matcher<MotionEvent>& matcher) {
87+
ASSERT_TRUE(!mMotionEvents.empty()) << "No motion events were received";
88+
std::unique_ptr<MotionEvent> motionEvent = std::move(mMotionEvents.front());
89+
mMotionEvents.pop();
90+
ASSERT_NE(motionEvent, nullptr) << "The consumed motion event must not be nullptr";
91+
EXPECT_THAT(*motionEvent, matcher);
92+
}
93+
94+
InputMessage nextPointerMessage(nanoseconds eventTime, int32_t action, const Pointer& pointer);
95+
96+
std::shared_ptr<TestInputChannel> mClientTestChannel;
97+
sp<Looper> mLooper;
98+
std::unique_ptr<InputConsumerNoResampling> mConsumer;
99+
100+
// Batched input events
101+
std::queue<std::unique_ptr<KeyEvent>> mKeyEvents;
102+
std::queue<std::unique_ptr<MotionEvent>> mMotionEvents;
103+
std::queue<std::unique_ptr<FocusEvent>> mFocusEvents;
104+
std::queue<std::unique_ptr<CaptureEvent>> mCaptureEvents;
105+
std::queue<std::unique_ptr<DragEvent>> mDragEvents;
106+
std::queue<std::unique_ptr<TouchModeEvent>> mTouchModeEvents;
107+
108+
private:
109+
// InputConsumer callbacks
110+
void onKeyEvent(std::unique_ptr<KeyEvent> event, uint32_t seq) override {
111+
mKeyEvents.push(std::move(event));
112+
mConsumer->finishInputEvent(seq, /*handled=*/true);
113+
}
114+
115+
void onMotionEvent(std::unique_ptr<MotionEvent> event, uint32_t seq) override {
116+
mMotionEvents.push(std::move(event));
117+
mConsumer->finishInputEvent(seq, /*handled=*/true);
118+
}
119+
120+
void onBatchedInputEventPending(int32_t pendingBatchSource) override {
121+
if (!mConsumer->probablyHasInput()) {
122+
ADD_FAILURE() << "Should deterministically have input because there is a batch";
123+
}
124+
++mOnBatchedInputEventPendingInvocationCount;
125+
}
126+
127+
void onFocusEvent(std::unique_ptr<FocusEvent> event, uint32_t seq) override {
128+
mFocusEvents.push(std::move(event));
129+
mConsumer->finishInputEvent(seq, /*handled=*/true);
130+
}
131+
132+
void onCaptureEvent(std::unique_ptr<CaptureEvent> event, uint32_t seq) override {
133+
mCaptureEvents.push(std::move(event));
134+
mConsumer->finishInputEvent(seq, /*handled=*/true);
135+
}
136+
137+
void onDragEvent(std::unique_ptr<DragEvent> event, uint32_t seq) override {
138+
mDragEvents.push(std::move(event));
139+
mConsumer->finishInputEvent(seq, /*handled=*/true);
140+
}
141+
142+
void onTouchModeEvent(std::unique_ptr<TouchModeEvent> event, uint32_t seq) override {
143+
mTouchModeEvents.push(std::move(event));
144+
mConsumer->finishInputEvent(seq, /*handled=*/true);
145+
}
146+
147+
uint32_t mLastSeq{0};
148+
size_t mOnBatchedInputEventPendingInvocationCount{0};
149+
};
150+
151+
InputMessage InputConsumerFilteredResamplingTest::nextPointerMessage(nanoseconds eventTime,
152+
int32_t action,
153+
const Pointer& pointer) {
154+
++mLastSeq;
155+
return InputMessageBuilder{InputMessage::Type::MOTION, mLastSeq}
156+
.eventTime(eventTime.count())
157+
.source(AINPUT_SOURCE_TOUCHSCREEN)
158+
.action(action)
159+
.pointer(pointer.asPointerBuilder())
160+
.build();
161+
}
162+
163+
TEST_F(InputConsumerFilteredResamplingTest, NeighboringTimestampsDoNotResultInZeroDivision) {
164+
mClientTestChannel->enqueueMessage(
165+
nextPointerMessage(0ms, ACTION_DOWN, Pointer{.x = 0.0f, .y = 0.0f}));
166+
167+
invokeLooperCallback();
168+
169+
assertReceivedMotionEvent(AllOf(WithMotionAction(ACTION_DOWN), WithSampleCount(1)));
170+
171+
const std::chrono::nanoseconds initialTime{56'821'700'000'000};
172+
173+
mClientTestChannel->enqueueMessage(nextPointerMessage(initialTime + 4'929'000ns, ACTION_MOVE,
174+
Pointer{.x = 1.0f, .y = 1.0f}));
175+
mClientTestChannel->enqueueMessage(nextPointerMessage(initialTime + 9'352'000ns, ACTION_MOVE,
176+
Pointer{.x = 2.0f, .y = 2.0f}));
177+
mClientTestChannel->enqueueMessage(nextPointerMessage(initialTime + 14'531'000ns, ACTION_MOVE,
178+
Pointer{.x = 3.0f, .y = 3.0f}));
179+
180+
invokeLooperCallback();
181+
mConsumer->consumeBatchedInputEvents(initialTime.count() + 18'849'395 /*ns*/);
182+
183+
assertOnBatchedInputEventPendingWasCalled();
184+
// Three samples are expected. The first two of the batch, and the resampled one. The
185+
// coordinates of the resampled sample are hardcoded because the matcher requires them. However,
186+
// the primary intention here is to check that the last sample is resampled.
187+
assertReceivedMotionEvent(AllOf(WithMotionAction(ACTION_MOVE), WithSampleCount(3),
188+
WithSample(/*sampleIndex=*/2,
189+
Sample{initialTime + 13'849'395ns,
190+
{PointerArgs{.x = 1.3286f,
191+
.y = 1.3286f,
192+
.isResampled = true}}})));
193+
194+
mClientTestChannel->enqueueMessage(nextPointerMessage(initialTime + 20'363'000ns, ACTION_MOVE,
195+
Pointer{.x = 4.0f, .y = 4.0f}));
196+
mClientTestChannel->enqueueMessage(nextPointerMessage(initialTime + 25'745'000ns, ACTION_MOVE,
197+
Pointer{.x = 5.0f, .y = 5.0f}));
198+
// This sample is part of the stream of messages, but should not be consumed because its
199+
// timestamp is greater than the ajusted frame time.
200+
mClientTestChannel->enqueueMessage(nextPointerMessage(initialTime + 31'337'000ns, ACTION_MOVE,
201+
Pointer{.x = 6.0f, .y = 6.0f}));
202+
203+
invokeLooperCallback();
204+
mConsumer->consumeBatchedInputEvents(initialTime.count() + 35'516'062 /*ns*/);
205+
206+
assertOnBatchedInputEventPendingWasCalled();
207+
// Four samples are expected because the last sample of the previous batch was not consumed.
208+
assertReceivedMotionEvent(AllOf(WithMotionAction(ACTION_MOVE), WithSampleCount(4)));
209+
210+
mClientTestChannel->assertFinishMessage(/*seq=*/1, /*handled=*/true);
211+
mClientTestChannel->assertFinishMessage(/*seq=*/2, /*handled=*/true);
212+
mClientTestChannel->assertFinishMessage(/*seq=*/3, /*handled=*/true);
213+
mClientTestChannel->assertFinishMessage(/*seq=*/4, /*handled=*/true);
214+
mClientTestChannel->assertFinishMessage(/*seq=*/5, /*handled=*/true);
215+
mClientTestChannel->assertFinishMessage(/*seq=*/6, /*handled=*/true);
216+
}
217+
218+
} // namespace android

libs/input/tests/OneEuroFilter_test.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,10 @@ class OneEuroFilterTest : public ::testing::Test {
9898
std::vector<Sample> filteredSignal;
9999
for (const Sample& sample : signal) {
100100
filteredSignal.push_back(
101-
Sample{sample.timestamp, mFilter.filter(sample.timestamp, sample.value)});
101+
Sample{sample.timestamp,
102+
mFilter.filter(std::chrono::duration_cast<std::chrono::nanoseconds>(
103+
sample.timestamp),
104+
sample.value)});
102105
}
103106
return filteredSignal;
104107
}

libs/input/tests/TestEventMatchers.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#pragma once
1818

1919
#include <chrono>
20+
#include <cmath>
2021
#include <ostream>
2122
#include <vector>
2223

@@ -150,14 +151,18 @@ class WithSampleMatcher {
150151
++pointerIndex) {
151152
const PointerCoords& pointerCoords =
152153
*(motionEvent.getHistoricalRawPointerCoords(pointerIndex, mSampleIndex));
153-
if ((pointerCoords.getX() != mSample.pointers[pointerIndex].x) ||
154-
(pointerCoords.getY() != mSample.pointers[pointerIndex].y)) {
154+
155+
if ((std::abs(pointerCoords.getX() - mSample.pointers[pointerIndex].x) >
156+
MotionEvent::ROUNDING_PRECISION) ||
157+
(std::abs(pointerCoords.getY() - mSample.pointers[pointerIndex].y) >
158+
MotionEvent::ROUNDING_PRECISION)) {
155159
*os << "sample coordinates mismatch at pointer index " << pointerIndex
156160
<< ". sample: (" << pointerCoords.getX() << ", " << pointerCoords.getY()
157161
<< ") expected: (" << mSample.pointers[pointerIndex].x << ", "
158162
<< mSample.pointers[pointerIndex].y << ")";
159163
return false;
160164
}
165+
161166
if (motionEvent.isResampled(pointerIndex, mSampleIndex) !=
162167
mSample.pointers[pointerIndex].isResampled) {
163168
*os << "resampling flag mismatch. sample: "

0 commit comments

Comments
 (0)