@@ -212,7 +212,7 @@ unsigned int get_smbd_max_read_write_size(struct ksmbd_transport *kt)
212212static int smb_direct_post_send_data (struct smbdirect_socket * sc ,
213213 struct smbdirect_send_batch * send_ctx ,
214214 struct iov_iter * iter ,
215- size_t * remaining_data_length );
215+ u32 remaining_data_length );
216216
217217static void smb_direct_send_immediate_work (struct work_struct * work )
218218{
@@ -222,7 +222,7 @@ static void smb_direct_send_immediate_work(struct work_struct *work)
222222 if (sc -> status != SMBDIRECT_SOCKET_CONNECTED )
223223 return ;
224224
225- smb_direct_post_send_data (sc , NULL , NULL , NULL );
225+ smb_direct_post_send_data (sc , NULL , NULL , 0 );
226226}
227227
228228static struct smb_direct_transport * alloc_transport (struct rdma_cm_id * cm_id )
@@ -805,23 +805,27 @@ static int post_sendmsg(struct smbdirect_socket *sc,
805805static int smb_direct_post_send_data (struct smbdirect_socket * sc ,
806806 struct smbdirect_send_batch * send_ctx ,
807807 struct iov_iter * iter ,
808- size_t * _remaining_data_length )
808+ u32 remaining_data_length )
809809{
810810 const struct smbdirect_socket_parameters * sp = & sc -> parameters ;
811811 int ret ;
812812 struct smbdirect_send_io * msg ;
813813 struct smbdirect_data_transfer * packet ;
814814 size_t header_length ;
815- u32 remaining_data_length = 0 ;
816815 u32 data_length = 0 ;
817816 struct smbdirect_send_batch _send_ctx ;
818817 u16 new_credits ;
819818
820819 if (iter ) {
821820 header_length = sizeof (struct smbdirect_data_transfer );
821+ if (WARN_ON_ONCE (remaining_data_length == 0 ||
822+ iov_iter_count (iter ) > remaining_data_length ))
823+ return - EINVAL ;
822824 } else {
823825 /* If this is a packet without payload, don't send padding */
824826 header_length = offsetof(struct smbdirect_data_transfer , padding );
827+ if (WARN_ON_ONCE (remaining_data_length ))
828+ return - EINVAL ;
825829 }
826830
827831 if (!send_ctx ) {
@@ -858,14 +862,6 @@ static int smb_direct_post_send_data(struct smbdirect_socket *sc,
858862 new_credits = smbdirect_connection_grant_recv_credits (sc );
859863 }
860864
861- if (iter )
862- data_length = iov_iter_count (iter );
863-
864- if (_remaining_data_length ) {
865- * _remaining_data_length -= data_length ;
866- remaining_data_length = * _remaining_data_length ;
867- }
868-
869865 msg = smbdirect_connection_alloc_send_io (sc );
870866 if (IS_ERR (msg )) {
871867 ret = PTR_ERR (msg );
@@ -894,14 +890,14 @@ static int smb_direct_post_send_data(struct smbdirect_socket *sc,
894890 .local_dma_lkey = sc -> ib .pd -> local_dma_lkey ,
895891 .direction = DMA_TO_DEVICE ,
896892 };
893+ size_t payload_len = umin (iov_iter_count (iter ),
894+ sp -> max_send_size - sizeof (* packet ));
897895
898- ret = smbdirect_map_sges_from_iter (iter , data_length , & extract );
896+ ret = smbdirect_map_sges_from_iter (iter , payload_len , & extract );
899897 if (ret < 0 )
900898 goto err ;
901- if (WARN_ON_ONCE (ret != data_length )) {
902- ret = - EIO ;
903- goto err ;
904- }
899+ data_length = ret ;
900+ remaining_data_length -= data_length ;
905901 msg -> num_sge = extract .num_sge ;
906902 }
907903
@@ -970,13 +966,9 @@ static int smb_direct_writev(struct ksmbd_transport *t,
970966 struct smb_direct_transport * st = SMBD_TRANS (t );
971967 struct smbdirect_socket * sc = & st -> socket ;
972968 struct smbdirect_socket_parameters * sp = & sc -> parameters ;
973- size_t remaining_data_length ;
974- size_t iov_idx ;
975- size_t iov_ofs ;
976- size_t max_iov_size = sp -> max_send_size -
977- sizeof (struct smbdirect_data_transfer );
978969 int ret ;
979970 struct smbdirect_send_batch send_ctx ;
971+ struct iov_iter iter ;
980972 int error = 0 ;
981973
982974 if (sc -> status != SMBDIRECT_SOCKET_CONNECTED )
@@ -985,112 +977,31 @@ static int smb_direct_writev(struct ksmbd_transport *t,
985977 //FIXME: skip RFC1002 header..
986978 if (WARN_ON_ONCE (niovs <= 1 || iov [0 ].iov_len != 4 ))
987979 return - EINVAL ;
988- buflen -= 4 ;
989- iov_idx = 1 ;
990- iov_ofs = 0 ;
991-
992- remaining_data_length = buflen ;
993- ksmbd_debug (RDMA , "Sending smb (RDMA): smb_len=%u\n" , buflen );
994-
995- smb_direct_send_ctx_init (& send_ctx , need_invalidate , remote_key );
996- while (remaining_data_length ) {
997- struct kvec vecs [SMBDIRECT_SEND_IO_MAX_SGE - 1 ]; /* minus smbdirect hdr */
998- size_t possible_bytes = max_iov_size ;
999- size_t possible_vecs ;
1000- size_t bytes = 0 ;
1001- size_t nvecs = 0 ;
1002- struct iov_iter iter ;
1003-
1004- /*
1005- * For the last message remaining_data_length should be
1006- * have been 0 already!
1007- */
1008- if (WARN_ON_ONCE (iov_idx >= niovs )) {
1009- error = - EINVAL ;
1010- goto done ;
1011- }
980+ iov_iter_kvec (& iter , ITER_SOURCE , iov , niovs , buflen );
981+ iov_iter_advance (& iter , 4 );
1012982
1013- /*
1014- * We have 2 factors which limit the arguments we pass
1015- * to smb_direct_post_send_data():
1016- *
1017- * 1. The number of supported sges for the send,
1018- * while one is reserved for the smbdirect header.
1019- * And we currently need one SGE per page.
1020- * 2. The number of negotiated payload bytes per send.
1021- */
1022- possible_vecs = min_t (size_t , ARRAY_SIZE (vecs ), niovs - iov_idx );
1023-
1024- while (iov_idx < niovs && possible_vecs && possible_bytes ) {
1025- struct kvec * v = & vecs [nvecs ];
1026- int page_count ;
1027-
1028- v -> iov_base = ((u8 * )iov [iov_idx ].iov_base ) + iov_ofs ;
1029- v -> iov_len = min_t (size_t ,
1030- iov [iov_idx ].iov_len - iov_ofs ,
1031- possible_bytes );
1032- page_count = smbdirect_get_buf_page_count (v -> iov_base , v -> iov_len );
1033- if (page_count > possible_vecs ) {
1034- /*
1035- * If the number of pages in the buffer
1036- * is to much (because we currently require
1037- * one SGE per page), we need to limit the
1038- * length.
1039- *
1040- * We know possible_vecs is at least 1,
1041- * so we always keep the first page.
1042- *
1043- * We need to calculate the number extra
1044- * pages (epages) we can also keep.
1045- *
1046- * We calculate the number of bytes in the
1047- * first page (fplen), this should never be
1048- * larger than v->iov_len because page_count is
1049- * at least 2, but adding a limitation feels
1050- * better.
1051- *
1052- * Then we calculate the number of bytes (elen)
1053- * we can keep for the extra pages.
1054- */
1055- size_t epages = possible_vecs - 1 ;
1056- size_t fpofs = offset_in_page (v -> iov_base );
1057- size_t fplen = min_t (size_t , PAGE_SIZE - fpofs , v -> iov_len );
1058- size_t elen = min_t (size_t , v -> iov_len - fplen , epages * PAGE_SIZE );
1059-
1060- v -> iov_len = fplen + elen ;
1061- page_count = smbdirect_get_buf_page_count (v -> iov_base , v -> iov_len );
1062- if (WARN_ON_ONCE (page_count > possible_vecs )) {
1063- /*
1064- * Something went wrong in the above
1065- * logic...
1066- */
1067- error = - EINVAL ;
1068- goto done ;
1069- }
1070- }
1071- possible_vecs -= page_count ;
1072- nvecs += 1 ;
1073- possible_bytes -= v -> iov_len ;
1074- bytes += v -> iov_len ;
1075-
1076- iov_ofs += v -> iov_len ;
1077- if (iov_ofs >= iov [iov_idx ].iov_len ) {
1078- iov_idx += 1 ;
1079- iov_ofs = 0 ;
1080- }
1081- }
983+ /*
984+ * The size must fit into the negotiated
985+ * fragmented send size.
986+ */
987+ if (iov_iter_count (& iter ) > sp -> max_fragmented_send_size )
988+ return - EMSGSIZE ;
1082989
1083- iov_iter_kvec (& iter , ITER_SOURCE , vecs , nvecs , bytes );
990+ ksmbd_debug (RDMA , "Sending smb (RDMA): smb_len=%zu\n" ,
991+ iov_iter_count (& iter ));
1084992
1085- ret = smb_direct_post_send_data (sc , & send_ctx ,
1086- & iter , & remaining_data_length );
993+ smb_direct_send_ctx_init (& send_ctx , need_invalidate , remote_key );
994+ while (iov_iter_count (& iter )) {
995+ ret = smb_direct_post_send_data (sc ,
996+ & send_ctx ,
997+ & iter ,
998+ iov_iter_count (& iter ));
1087999 if (unlikely (ret )) {
10881000 error = ret ;
1089- goto done ;
1001+ break ;
10901002 }
10911003 }
10921004
1093- done :
10941005 ret = smb_direct_flush_send_list (sc , & send_ctx , true);
10951006 if (unlikely (!ret && error ))
10961007 ret = error ;