Skip to content

Commit 4f9e86f

Browse files
Treehugger RobotAndroid (Google) Code Review
authored andcommitted
Merge "Fix One Euro filter's units of computation" into main
2 parents 38d3681 + 5346339 commit 4f9e86f

8 files changed

Lines changed: 258 additions & 23 deletions

File tree

include/input/CoordinateFilter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class CoordinateFilter {
4444
* the previous call.
4545
* @param coords Coordinates to be overwritten by the corresponding filtered coordinates.
4646
*/
47-
void filter(std::chrono::duration<float> timestamp, PointerCoords& coords);
47+
void filter(std::chrono::nanoseconds timestamp, PointerCoords& coords);
4848

4949
private:
5050
OneEuroFilter mXFilter;

include/input/OneEuroFilter.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class OneEuroFilter {
5656
* provided in the previous call.
5757
* @param rawPosition Position to be filtered.
5858
*/
59-
float filter(std::chrono::duration<float> timestamp, float rawPosition);
59+
float filter(std::chrono::nanoseconds timestamp, float rawPosition);
6060

6161
private:
6262
/**
@@ -67,7 +67,7 @@ class OneEuroFilter {
6767

6868
/**
6969
* Slope of the cutoff frequency criterion. This is the term scaling the absolute value of the
70-
* filtered signal's speed. The data member is dimensionless, that is, it does not have units.
70+
* filtered signal's speed. Units are 1 / position.
7171
*/
7272
const float mBeta;
7373

@@ -78,17 +78,17 @@ class OneEuroFilter {
7878
const float mSpeedCutoffFreq;
7979

8080
/**
81-
* The timestamp from the previous call. Units are seconds.
81+
* The timestamp from the previous call.
8282
*/
83-
std::optional<std::chrono::duration<float>> mPrevTimestamp;
83+
std::optional<std::chrono::nanoseconds> mPrevTimestamp;
8484

8585
/**
8686
* The raw position from the previous call.
8787
*/
8888
std::optional<float> mPrevRawPosition;
8989

9090
/**
91-
* The filtered velocity from the previous call. Units are position per second.
91+
* The filtered velocity from the previous call. Units are position per nanosecond.
9292
*/
9393
std::optional<float> mPrevFilteredVelocity;
9494

libs/input/CoordinateFilter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ namespace android {
2323
CoordinateFilter::CoordinateFilter(float minCutoffFreq, float beta)
2424
: mXFilter{minCutoffFreq, beta}, mYFilter{minCutoffFreq, beta} {}
2525

26-
void CoordinateFilter::filter(std::chrono::duration<float> timestamp, PointerCoords& coords) {
26+
void CoordinateFilter::filter(std::chrono::nanoseconds timestamp, PointerCoords& coords) {
2727
coords.setAxisValue(AMOTION_EVENT_AXIS_X, mXFilter.filter(timestamp, coords.getX()));
2828
coords.setAxisValue(AMOTION_EVENT_AXIS_Y, mYFilter.filter(timestamp, coords.getY()));
2929
}

libs/input/OneEuroFilter.cpp

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,34 +25,42 @@
2525
namespace android {
2626
namespace {
2727

28+
using namespace std::literals::chrono_literals;
29+
30+
const float kHertzPerGigahertz = 1E9f;
31+
const float kGigahertzPerHertz = 1E-9f;
32+
33+
// filteredSpeed's units are position per nanosecond. beta's units are 1 / position.
2834
inline float cutoffFreq(float minCutoffFreq, float beta, float filteredSpeed) {
29-
return minCutoffFreq + beta * std::abs(filteredSpeed);
35+
return kHertzPerGigahertz *
36+
((minCutoffFreq * kGigahertzPerHertz) + beta * std::abs(filteredSpeed));
3037
}
3138

32-
inline float smoothingFactor(std::chrono::duration<float> samplingPeriod, float cutoffFreq) {
33-
return samplingPeriod.count() / (samplingPeriod.count() + (1.0 / (2.0 * M_PI * cutoffFreq)));
39+
inline float smoothingFactor(std::chrono::nanoseconds samplingPeriod, float cutoffFreq) {
40+
const float constant = 2.0f * M_PI * samplingPeriod.count() * (cutoffFreq * kGigahertzPerHertz);
41+
return constant / (constant + 1);
3442
}
3543

36-
inline float lowPassFilter(float rawPosition, float prevFilteredPosition, float smoothingFactor) {
37-
return smoothingFactor * rawPosition + (1 - smoothingFactor) * prevFilteredPosition;
44+
inline float lowPassFilter(float rawValue, float prevFilteredValue, float smoothingFactor) {
45+
return smoothingFactor * rawValue + (1 - smoothingFactor) * prevFilteredValue;
3846
}
3947

4048
} // namespace
4149

4250
OneEuroFilter::OneEuroFilter(float minCutoffFreq, float beta, float speedCutoffFreq)
4351
: mMinCutoffFreq{minCutoffFreq}, mBeta{beta}, mSpeedCutoffFreq{speedCutoffFreq} {}
4452

45-
float OneEuroFilter::filter(std::chrono::duration<float> timestamp, float rawPosition) {
46-
LOG_IF(FATAL, mPrevFilteredPosition.has_value() && (timestamp <= *mPrevTimestamp))
47-
<< "Timestamp must be greater than mPrevTimestamp";
53+
float OneEuroFilter::filter(std::chrono::nanoseconds timestamp, float rawPosition) {
54+
LOG_IF(FATAL, mPrevTimestamp.has_value() && (*mPrevTimestamp >= timestamp))
55+
<< "Timestamp must be greater than mPrevTimestamp. Timestamp: " << timestamp.count()
56+
<< "ns. mPrevTimestamp: " << mPrevTimestamp->count() << "ns";
4857

49-
const std::chrono::duration<float> samplingPeriod = (mPrevTimestamp.has_value())
50-
? (timestamp - *mPrevTimestamp)
51-
: std::chrono::duration<float>{1.0};
58+
const std::chrono::nanoseconds samplingPeriod =
59+
(mPrevTimestamp.has_value()) ? (timestamp - *mPrevTimestamp) : 1s;
5260

5361
const float rawVelocity = (mPrevFilteredPosition.has_value())
54-
? ((rawPosition - *mPrevFilteredPosition) / samplingPeriod.count())
55-
: 0.0;
62+
? ((rawPosition - *mPrevFilteredPosition) / (samplingPeriod.count()))
63+
: 0.0f;
5664

5765
const float speedSmoothingFactor = smoothingFactor(samplingPeriod, mSpeedCutoffFreq);
5866

libs/input/tests/Android.bp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ cc_test {
1717
"IdGenerator_test.cpp",
1818
"InputChannel_test.cpp",
1919
"InputConsumer_test.cpp",
20+
"InputConsumerFilteredResampling_test.cpp",
2021
"InputConsumerResampling_test.cpp",
2122
"InputDevice_test.cpp",
2223
"InputEvent_test.cpp",
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
/**
2+
* Copyright 2024 The Android Open Source Project
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include <input/InputConsumerNoResampling.h>
18+
19+
#include <chrono>
20+
#include <iostream>
21+
#include <memory>
22+
#include <queue>
23+
24+
#include <TestEventMatchers.h>
25+
#include <TestInputChannel.h>
26+
#include <android-base/logging.h>
27+
#include <gmock/gmock.h>
28+
#include <gtest/gtest.h>
29+
#include <input/Input.h>
30+
#include <input/InputEventBuilders.h>
31+
#include <input/Resampler.h>
32+
#include <utils/Looper.h>
33+
#include <utils/StrongPointer.h>
34+
35+
namespace android {
36+
namespace {
37+
38+
using std::chrono::nanoseconds;
39+
40+
using ::testing::AllOf;
41+
using ::testing::Matcher;
42+
43+
const int32_t ACTION_DOWN = AMOTION_EVENT_ACTION_DOWN;
44+
const int32_t ACTION_MOVE = AMOTION_EVENT_ACTION_MOVE;
45+
46+
struct Pointer {
47+
int32_t id{0};
48+
ToolType toolType{ToolType::FINGER};
49+
float x{0.0f};
50+
float y{0.0f};
51+
bool isResampled{false};
52+
53+
PointerBuilder asPointerBuilder() const {
54+
return PointerBuilder{id, toolType}.x(x).y(y).isResampled(isResampled);
55+
}
56+
};
57+
58+
} // namespace
59+
60+
class InputConsumerFilteredResamplingTest : public ::testing::Test, public InputConsumerCallbacks {
61+
protected:
62+
InputConsumerFilteredResamplingTest()
63+
: mClientTestChannel{std::make_shared<TestInputChannel>("TestChannel")},
64+
mLooper{sp<Looper>::make(/*allowNonCallbacks=*/false)} {
65+
Looper::setForThread(mLooper);
66+
mConsumer = std::make_unique<
67+
InputConsumerNoResampling>(mClientTestChannel, mLooper, *this, []() {
68+
return std::make_unique<FilteredLegacyResampler>(/*minCutoffFreq=*/4.7, /*beta=*/0.01);
69+
});
70+
}
71+
72+
void invokeLooperCallback() const {
73+
sp<LooperCallback> callback;
74+
ASSERT_TRUE(mLooper->getFdStateDebug(mClientTestChannel->getFd(), /*ident=*/nullptr,
75+
/*events=*/nullptr, &callback, /*data=*/nullptr));
76+
ASSERT_NE(callback, nullptr);
77+
callback->handleEvent(mClientTestChannel->getFd(), ALOOPER_EVENT_INPUT, /*data=*/nullptr);
78+
}
79+
80+
void assertOnBatchedInputEventPendingWasCalled() {
81+
ASSERT_GT(mOnBatchedInputEventPendingInvocationCount, 0UL)
82+
<< "onBatchedInputEventPending was not called";
83+
--mOnBatchedInputEventPendingInvocationCount;
84+
}
85+
86+
void assertReceivedMotionEvent(const Matcher<MotionEvent>& matcher) {
87+
ASSERT_TRUE(!mMotionEvents.empty()) << "No motion events were received";
88+
std::unique_ptr<MotionEvent> motionEvent = std::move(mMotionEvents.front());
89+
mMotionEvents.pop();
90+
ASSERT_NE(motionEvent, nullptr) << "The consumed motion event must not be nullptr";
91+
EXPECT_THAT(*motionEvent, matcher);
92+
}
93+
94+
InputMessage nextPointerMessage(nanoseconds eventTime, int32_t action, const Pointer& pointer);
95+
96+
std::shared_ptr<TestInputChannel> mClientTestChannel;
97+
sp<Looper> mLooper;
98+
std::unique_ptr<InputConsumerNoResampling> mConsumer;
99+
100+
// Batched input events
101+
std::queue<std::unique_ptr<KeyEvent>> mKeyEvents;
102+
std::queue<std::unique_ptr<MotionEvent>> mMotionEvents;
103+
std::queue<std::unique_ptr<FocusEvent>> mFocusEvents;
104+
std::queue<std::unique_ptr<CaptureEvent>> mCaptureEvents;
105+
std::queue<std::unique_ptr<DragEvent>> mDragEvents;
106+
std::queue<std::unique_ptr<TouchModeEvent>> mTouchModeEvents;
107+
108+
private:
109+
// InputConsumer callbacks
110+
void onKeyEvent(std::unique_ptr<KeyEvent> event, uint32_t seq) override {
111+
mKeyEvents.push(std::move(event));
112+
mConsumer->finishInputEvent(seq, /*handled=*/true);
113+
}
114+
115+
void onMotionEvent(std::unique_ptr<MotionEvent> event, uint32_t seq) override {
116+
mMotionEvents.push(std::move(event));
117+
mConsumer->finishInputEvent(seq, /*handled=*/true);
118+
}
119+
120+
void onBatchedInputEventPending(int32_t pendingBatchSource) override {
121+
if (!mConsumer->probablyHasInput()) {
122+
ADD_FAILURE() << "Should deterministically have input because there is a batch";
123+
}
124+
++mOnBatchedInputEventPendingInvocationCount;
125+
}
126+
127+
void onFocusEvent(std::unique_ptr<FocusEvent> event, uint32_t seq) override {
128+
mFocusEvents.push(std::move(event));
129+
mConsumer->finishInputEvent(seq, /*handled=*/true);
130+
}
131+
132+
void onCaptureEvent(std::unique_ptr<CaptureEvent> event, uint32_t seq) override {
133+
mCaptureEvents.push(std::move(event));
134+
mConsumer->finishInputEvent(seq, /*handled=*/true);
135+
}
136+
137+
void onDragEvent(std::unique_ptr<DragEvent> event, uint32_t seq) override {
138+
mDragEvents.push(std::move(event));
139+
mConsumer->finishInputEvent(seq, /*handled=*/true);
140+
}
141+
142+
void onTouchModeEvent(std::unique_ptr<TouchModeEvent> event, uint32_t seq) override {
143+
mTouchModeEvents.push(std::move(event));
144+
mConsumer->finishInputEvent(seq, /*handled=*/true);
145+
}
146+
147+
uint32_t mLastSeq{0};
148+
size_t mOnBatchedInputEventPendingInvocationCount{0};
149+
};
150+
151+
InputMessage InputConsumerFilteredResamplingTest::nextPointerMessage(nanoseconds eventTime,
152+
int32_t action,
153+
const Pointer& pointer) {
154+
++mLastSeq;
155+
return InputMessageBuilder{InputMessage::Type::MOTION, mLastSeq}
156+
.eventTime(eventTime.count())
157+
.source(AINPUT_SOURCE_TOUCHSCREEN)
158+
.action(action)
159+
.pointer(pointer.asPointerBuilder())
160+
.build();
161+
}
162+
163+
TEST_F(InputConsumerFilteredResamplingTest, NeighboringTimestampsDoNotResultInZeroDivision) {
164+
mClientTestChannel->enqueueMessage(
165+
nextPointerMessage(0ms, ACTION_DOWN, Pointer{.x = 0.0f, .y = 0.0f}));
166+
167+
invokeLooperCallback();
168+
169+
assertReceivedMotionEvent(AllOf(WithMotionAction(ACTION_DOWN), WithSampleCount(1)));
170+
171+
const std::chrono::nanoseconds initialTime{56'821'700'000'000};
172+
173+
mClientTestChannel->enqueueMessage(nextPointerMessage(initialTime + 4'929'000ns, ACTION_MOVE,
174+
Pointer{.x = 1.0f, .y = 1.0f}));
175+
mClientTestChannel->enqueueMessage(nextPointerMessage(initialTime + 9'352'000ns, ACTION_MOVE,
176+
Pointer{.x = 2.0f, .y = 2.0f}));
177+
mClientTestChannel->enqueueMessage(nextPointerMessage(initialTime + 14'531'000ns, ACTION_MOVE,
178+
Pointer{.x = 3.0f, .y = 3.0f}));
179+
180+
invokeLooperCallback();
181+
mConsumer->consumeBatchedInputEvents(initialTime.count() + 18'849'395 /*ns*/);
182+
183+
assertOnBatchedInputEventPendingWasCalled();
184+
// Three samples are expected. The first two of the batch, and the resampled one. The
185+
// coordinates of the resampled sample are hardcoded because the matcher requires them. However,
186+
// the primary intention here is to check that the last sample is resampled.
187+
assertReceivedMotionEvent(AllOf(WithMotionAction(ACTION_MOVE), WithSampleCount(3),
188+
WithSample(/*sampleIndex=*/2,
189+
Sample{initialTime + 13'849'395ns,
190+
{PointerArgs{.x = 1.3286f,
191+
.y = 1.3286f,
192+
.isResampled = true}}})));
193+
194+
mClientTestChannel->enqueueMessage(nextPointerMessage(initialTime + 20'363'000ns, ACTION_MOVE,
195+
Pointer{.x = 4.0f, .y = 4.0f}));
196+
mClientTestChannel->enqueueMessage(nextPointerMessage(initialTime + 25'745'000ns, ACTION_MOVE,
197+
Pointer{.x = 5.0f, .y = 5.0f}));
198+
// This sample is part of the stream of messages, but should not be consumed because its
199+
// timestamp is greater than the ajusted frame time.
200+
mClientTestChannel->enqueueMessage(nextPointerMessage(initialTime + 31'337'000ns, ACTION_MOVE,
201+
Pointer{.x = 6.0f, .y = 6.0f}));
202+
203+
invokeLooperCallback();
204+
mConsumer->consumeBatchedInputEvents(initialTime.count() + 35'516'062 /*ns*/);
205+
206+
assertOnBatchedInputEventPendingWasCalled();
207+
// Four samples are expected because the last sample of the previous batch was not consumed.
208+
assertReceivedMotionEvent(AllOf(WithMotionAction(ACTION_MOVE), WithSampleCount(4)));
209+
210+
mClientTestChannel->assertFinishMessage(/*seq=*/1, /*handled=*/true);
211+
mClientTestChannel->assertFinishMessage(/*seq=*/2, /*handled=*/true);
212+
mClientTestChannel->assertFinishMessage(/*seq=*/3, /*handled=*/true);
213+
mClientTestChannel->assertFinishMessage(/*seq=*/4, /*handled=*/true);
214+
mClientTestChannel->assertFinishMessage(/*seq=*/5, /*handled=*/true);
215+
mClientTestChannel->assertFinishMessage(/*seq=*/6, /*handled=*/true);
216+
}
217+
218+
} // namespace android

libs/input/tests/OneEuroFilter_test.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,10 @@ class OneEuroFilterTest : public ::testing::Test {
9898
std::vector<Sample> filteredSignal;
9999
for (const Sample& sample : signal) {
100100
filteredSignal.push_back(
101-
Sample{sample.timestamp, mFilter.filter(sample.timestamp, sample.value)});
101+
Sample{sample.timestamp,
102+
mFilter.filter(std::chrono::duration_cast<std::chrono::nanoseconds>(
103+
sample.timestamp),
104+
sample.value)});
102105
}
103106
return filteredSignal;
104107
}

libs/input/tests/TestEventMatchers.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#pragma once
1818

1919
#include <chrono>
20+
#include <cmath>
2021
#include <ostream>
2122
#include <vector>
2223

@@ -156,14 +157,18 @@ class WithSampleMatcher {
156157
++pointerIndex) {
157158
const PointerCoords& pointerCoords =
158159
*(motionEvent.getHistoricalRawPointerCoords(pointerIndex, mSampleIndex));
159-
if ((pointerCoords.getX() != mSample.pointers[pointerIndex].x) ||
160-
(pointerCoords.getY() != mSample.pointers[pointerIndex].y)) {
160+
161+
if ((std::abs(pointerCoords.getX() - mSample.pointers[pointerIndex].x) >
162+
MotionEvent::ROUNDING_PRECISION) ||
163+
(std::abs(pointerCoords.getY() - mSample.pointers[pointerIndex].y) >
164+
MotionEvent::ROUNDING_PRECISION)) {
161165
*os << "sample coordinates mismatch at pointer index " << pointerIndex
162166
<< ". sample: (" << pointerCoords.getX() << ", " << pointerCoords.getY()
163167
<< ") expected: (" << mSample.pointers[pointerIndex].x << ", "
164168
<< mSample.pointers[pointerIndex].y << ")";
165169
return false;
166170
}
171+
167172
if (motionEvent.isResampled(pointerIndex, mSampleIndex) !=
168173
mSample.pointers[pointerIndex].isResampled) {
169174
*os << "resampling flag mismatch. sample: "

0 commit comments

Comments
 (0)