sock: convert sk_peek_offset functions to WRITE_ONCE
[cascardo/linux.git] / include / net / sock.h
index e91b87f..09aec75 100644 (file)
@@ -178,7 +178,7 @@ struct sock_common {
        int                     skc_bound_dev_if;
        union {
                struct hlist_node       skc_bind_node;
-               struct hlist_nulls_node skc_portaddr_node;
+               struct hlist_node       skc_portaddr_node;
        };
        struct proto            *skc_prot;
        possible_net_t          skc_net;
@@ -438,6 +438,7 @@ struct sock {
                                                  struct sk_buff *skb);
        void                    (*sk_destruct)(struct sock *sk);
        struct sock_reuseport __rcu     *sk_reuseport_cb;
+       struct rcu_head         sk_rcu;
 };
 
 #define __sk_user_data(sk) ((*((void __rcu **)&(sk)->sk_user_data)))
@@ -458,26 +459,28 @@ struct sock {
 
 static inline int sk_peek_offset(struct sock *sk, int flags)
 {
-       if ((flags & MSG_PEEK) && (sk->sk_peek_off >= 0))
-               return sk->sk_peek_off;
-       else
-               return 0;
+       if (unlikely(flags & MSG_PEEK)) {
+               s32 off = READ_ONCE(sk->sk_peek_off);
+               if (off >= 0)
+                       return off;
+       }
+
+       return 0;
 }
 
 static inline void sk_peek_offset_bwd(struct sock *sk, int val)
 {
-       if (sk->sk_peek_off >= 0) {
-               if (sk->sk_peek_off >= val)
-                       sk->sk_peek_off -= val;
-               else
-                       sk->sk_peek_off = 0;
+       s32 off = READ_ONCE(sk->sk_peek_off);
+
+       if (unlikely(off >= 0)) {
+               off = max_t(s32, off - val, 0);
+               WRITE_ONCE(sk->sk_peek_off, off);
        }
 }
 
 static inline void sk_peek_offset_fwd(struct sock *sk, int val)
 {
-       if (sk->sk_peek_off >= 0)
-               sk->sk_peek_off += val;
+       sk_peek_offset_bwd(sk, -val);
 }
 
 /*
@@ -669,18 +672,18 @@ static inline void sk_add_bind_node(struct sock *sk,
        hlist_for_each_entry(__sk, list, sk_bind_node)
 
 /**
- * sk_nulls_for_each_entry_offset - iterate over a list at a given struct offset
+ * sk_for_each_entry_offset_rcu - iterate over a list at a given struct offset
  * @tpos:      the type * to use as a loop cursor.
  * @pos:       the &struct hlist_node to use as a loop cursor.
  * @head:      the head for your list.
  * @offset:    offset of hlist_node within the struct.
  *
  */
-#define sk_nulls_for_each_entry_offset(tpos, pos, head, offset)                       \
-       for (pos = (head)->first;                                              \
-            (!is_a_nulls(pos)) &&                                             \
+#define sk_for_each_entry_offset_rcu(tpos, pos, head, offset)                 \
+       for (pos = rcu_dereference((head)->first);                             \
+            pos != NULL &&                                                    \
                ({ tpos = (typeof(*tpos) *)((void *)pos - offset); 1;});       \
-            pos = pos->next)
+            pos = rcu_dereference(pos->next))
 
 static inline struct user_namespace *sk_user_ns(struct sock *sk)
 {
@@ -720,6 +723,7 @@ enum sock_flags {
                     */
        SOCK_FILTER_LOCKED, /* Filter cannot be changed anymore */
        SOCK_SELECT_ERR_QUEUE, /* Wake select on error queue */
+       SOCK_RCU_FREE, /* wait rcu grace period in sk_destruct() */
 };
 
 #define SK_FLAGS_TIMESTAMP ((1UL << SOCK_TIMESTAMP) | (1UL << SOCK_TIMESTAMPING_RX_SOFTWARE))
@@ -2010,6 +2014,13 @@ sock_skb_set_dropcount(const struct sock *sk, struct sk_buff *skb)
        SOCK_SKB_CB(skb)->dropcount = atomic_read(&sk->sk_drops);
 }
 
+static inline void sk_drops_add(struct sock *sk, const struct sk_buff *skb)
+{
+       int segs = max_t(u16, 1, skb_shinfo(skb)->gso_segs);
+
+       atomic_add(segs, &sk->sk_drops);
+}
+
 void __sock_recv_timestamp(struct msghdr *msg, struct sock *sk,
                           struct sk_buff *skb);
 void __sock_recv_wifi_status(struct msghdr *msg, struct sock *sk,