Skip to content

Commit 4dffc9b

Browse files
Eric Biggers authored and herbertx (Herbert Xu) committed
crypto: scatterwalk - Fix memcpy_sglist() to always succeed
The original implementation of memcpy_sglist() was broken because it didn't handle scatterlists that describe exactly the same memory, which is a case that many callers rely on. The current implementation is broken too because it calls the skcipher_walk functions which can fail. It ignores any errors from those functions. Fix it by replacing it with a new implementation written from scratch. It always succeeds. It's also a bit faster, since it avoids the overhead of skcipher_walk. skcipher_walk includes a lot of functionality (such as alignmask handling) that's irrelevant here. Reported-by: Colin Ian King <coking@nvidia.com> Closes: https://lore.kernel.org/r/20251114122620.111623-1-coking@nvidia.com Fixes: 131bdce ("crypto: scatterwalk - Add memcpy_sglist") Fixes: 0f8d42b ("crypto: scatterwalk - Move skcipher walk and use it for memcpy_sglist") Cc: stable@vger.kernel.org Signed-off-by: Eric Biggers <ebiggers@kernel.org> Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
1 parent 5727a84 commit 4dffc9b

2 files changed

Lines changed: 115 additions & 34 deletions

File tree

crypto/scatterwalk.c

Lines changed: 84 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -101,26 +101,97 @@ void memcpy_to_sglist(struct scatterlist *sg, unsigned int start,
101101
}
102102
EXPORT_SYMBOL_GPL(memcpy_to_sglist);
103103

/**
 * memcpy_sglist() - Copy data from one scatterlist to another
 * @dst: The destination scatterlist.  Can be NULL if @nbytes == 0.
 * @src: The source scatterlist.  Can be NULL if @nbytes == 0.
 * @nbytes: Number of bytes to copy
 *
 * The scatterlists can describe exactly the same memory, in which case this
 * function is a no-op.  No other overlaps are supported.
 *
 * Context: Any context
 */
void memcpy_sglist(struct scatterlist *dst, struct scatterlist *src,
		   unsigned int nbytes)
{
	/*
	 * Absolute positions within the current sg entries; each starts at
	 * sg->offset and stays below sg->offset + sg->length.
	 */
	unsigned int src_offset, dst_offset;

	if (unlikely(nbytes == 0)) /* in case src and/or dst is NULL */
		return;

	src_offset = src->offset;
	dst_offset = dst->offset;
	for (;;) {
		/*
		 * Compute the length to copy this step: limited by the bytes
		 * remaining in each sg entry and by the total remaining.
		 */
		unsigned int len = min3(src->offset + src->length - src_offset,
					dst->offset + dst->length - dst_offset,
					nbytes);
		struct page *src_page = sg_page(src);
		struct page *dst_page = sg_page(dst);
		const void *src_virt;
		void *dst_virt;

		if (IS_ENABLED(CONFIG_HIGHMEM)) {
			/* HIGHMEM: we may have to actually map the pages. */
			const unsigned int src_oip = offset_in_page(src_offset);
			const unsigned int dst_oip = offset_in_page(dst_offset);
			const unsigned int limit = PAGE_SIZE;

			/* Further limit len to not cross a page boundary. */
			len = min3(len, limit - src_oip, limit - dst_oip);

			/* Compute the source and destination pages. */
			src_page += src_offset / PAGE_SIZE;
			dst_page += dst_offset / PAGE_SIZE;

			if (src_page != dst_page) {
				/* Copy between different pages. */
				memcpy_page(dst_page, dst_oip,
					    src_page, src_oip, len);
				flush_dcache_page(dst_page);
			} else if (src_oip != dst_oip) {
				/*
				 * Copy between different parts of same page.
				 * One mapping suffices since src and dst share
				 * the page; src data is at dst_virt + src_oip.
				 */
				dst_virt = kmap_local_page(dst_page);
				memcpy(dst_virt + dst_oip, dst_virt + src_oip,
				       len);
				kunmap_local(dst_virt);
				flush_dcache_page(dst_page);
			} /* Else, it's the same memory.  No action needed. */
		} else {
			/*
			 * !HIGHMEM: no mapping needed.  Just work in the linear
			 * buffer of each sg entry.  Note that we can cross page
			 * boundaries, as they are not significant in this case.
			 */
			src_virt = page_address(src_page) + src_offset;
			dst_virt = page_address(dst_page) + dst_offset;
			if (src_virt != dst_virt) {
				memcpy(dst_virt, src_virt, len);
				if (ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE)
					__scatterwalk_flush_dcache_pages(
						dst_page, dst_offset, len);
			} /* Else, it's the same memory.  No action needed. */
		}
		nbytes -= len;
		if (nbytes == 0) /* No more to copy? */
			break;

		/*
		 * There's more to copy.  Advance the offsets by the length
		 * copied this step, and advance the sg entries as needed.
		 */
		src_offset += len;
		if (src_offset >= src->offset + src->length) {
			src = sg_next(src);
			src_offset = src->offset;
		}
		dst_offset += len;
		if (dst_offset >= dst->offset + dst->length) {
			dst = sg_next(dst);
			dst_offset = dst->offset;
		}
	}
}
EXPORT_SYMBOL_GPL(memcpy_sglist);
126197

include/crypto/scatterwalk.h

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,34 @@ static inline void scatterwalk_done_src(struct scatter_walk *walk,
227227
scatterwalk_advance(walk, nbytes);
228228
}
229229

230+
/*
 * Flush the dcache of any pages that overlap the region
 * [offset, offset + nbytes) relative to base_page.
 *
 * This should be called only when ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE, to ensure
 * that all relevant code (including the call to sg_page() in the caller, if
 * applicable) gets fully optimized out when !ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE.
 */
static inline void __scatterwalk_flush_dcache_pages(struct page *base_page,
						    unsigned int offset,
						    unsigned int nbytes)
{
	unsigned int num_pages;

	/* Normalize so that offset is an offset within the first page. */
	base_page += offset / PAGE_SIZE;
	offset %= PAGE_SIZE;

	/*
	 * This is an overflow-safe version of
	 * num_pages = DIV_ROUND_UP(offset + nbytes, PAGE_SIZE).
	 * (offset + nbytes may wrap; offset + nbytes % PAGE_SIZE cannot,
	 * since both terms are now below PAGE_SIZE.)
	 */
	num_pages = nbytes / PAGE_SIZE;
	num_pages += DIV_ROUND_UP(offset + (nbytes % PAGE_SIZE), PAGE_SIZE);

	for (unsigned int i = 0; i < num_pages; i++)
		flush_dcache_page(base_page + i);
}
257+
230258
/**
 * scatterwalk_done_dst() - Finish one step of a walk of destination scatterlist
 * @walk: the scatter_walk
 * @nbytes: the number of bytes processed this step
 *
 * Unmaps the current mapping, flushes the dcache of the written region on
 * architectures that require it, and advances the walk by @nbytes.
 */
static inline void scatterwalk_done_dst(struct scatter_walk *walk,
					unsigned int nbytes)
{
	scatterwalk_unmap(walk);
	/*
	 * The explicit ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE check lets the
	 * helper call (including its sg_page()) be optimized out entirely
	 * on architectures that don't need dcache flushing.
	 */
	if (ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE)
		__scatterwalk_flush_dcache_pages(sg_page(walk->sg),
						 walk->offset, nbytes);
	scatterwalk_advance(walk, nbytes);
}
266276

0 commit comments

Comments
 (0)