ipv6: fix a potential deadlock in do_ipv6_setsockopt()
[cascardo/linux.git] / net / rxrpc / output.c
index 0d47db8..5dab1ff 100644 (file)
 #include <net/af_rxrpc.h>
 #include "ar-internal.h"
 
-struct rxrpc_pkt_buffer {
+struct rxrpc_ack_buffer {
        struct rxrpc_wire_header whdr;
-       union {
-               struct {
-                       struct rxrpc_ackpacket ack;
-                       u8 acks[255];
-                       u8 pad[3];
-               };
-               __be32 abort_code;
-       };
+       struct rxrpc_ackpacket ack;
+       u8 acks[255];
+       u8 pad[3];
        struct rxrpc_ackinfo ackinfo;
 };
 
+struct rxrpc_abort_buffer {
+       struct rxrpc_wire_header whdr;
+       __be32 abort_code;
+};
+
 /*
  * Fill out an ACK packet.
  */
 static size_t rxrpc_fill_out_ack(struct rxrpc_call *call,
-                                struct rxrpc_pkt_buffer *pkt,
+                                struct rxrpc_ack_buffer *pkt,
                                 rxrpc_seq_t *_hard_ack,
-                                rxrpc_seq_t *_top)
+                                rxrpc_seq_t *_top,
+                                u8 reason)
 {
        rxrpc_serial_t serial;
        rxrpc_seq_t hard_ack, top, seq;
@@ -58,10 +59,10 @@ static size_t rxrpc_fill_out_ack(struct rxrpc_call *call,
        pkt->ack.firstPacket    = htonl(hard_ack + 1);
        pkt->ack.previousPacket = htonl(call->ackr_prev_seq);
        pkt->ack.serial         = htonl(serial);
-       pkt->ack.reason         = call->ackr_reason;
+       pkt->ack.reason         = reason;
        pkt->ack.nAcks          = top - hard_ack;
 
-       if (pkt->ack.reason == RXRPC_ACK_PING)
+       if (reason == RXRPC_ACK_PING)
                pkt->whdr.flags |= RXRPC_REQUEST_ACK;
 
        if (after(top, hard_ack)) {
@@ -91,22 +92,19 @@ static size_t rxrpc_fill_out_ack(struct rxrpc_call *call,
 }
 
 /*
- * Send an ACK or ABORT call packet.
+ * Send an ACK call packet.
  */
-int rxrpc_send_call_packet(struct rxrpc_call *call, u8 type)
+int rxrpc_send_ack_packet(struct rxrpc_call *call, bool ping)
 {
        struct rxrpc_connection *conn = NULL;
-       struct rxrpc_pkt_buffer *pkt;
+       struct rxrpc_ack_buffer *pkt;
        struct msghdr msg;
        struct kvec iov[2];
        rxrpc_serial_t serial;
        rxrpc_seq_t hard_ack, top;
        size_t len, n;
-       bool ping = false;
-       int ioc, ret;
-       u32 abort_code;
-
-       _enter("%u,%s", call->debug_id, rxrpc_pkts[type]);
+       int ret;
+       u8 reason;
 
        spin_lock_bh(&call->lock);
        if (call->conn)
@@ -131,68 +129,44 @@ int rxrpc_send_call_packet(struct rxrpc_call *call, u8 type)
        pkt->whdr.cid           = htonl(call->cid);
        pkt->whdr.callNumber    = htonl(call->call_id);
        pkt->whdr.seq           = 0;
-       pkt->whdr.type          = type;
-       pkt->whdr.flags         = conn->out_clientflag;
+       pkt->whdr.type          = RXRPC_PACKET_TYPE_ACK;
+       pkt->whdr.flags         = RXRPC_SLOW_START_OK | conn->out_clientflag;
        pkt->whdr.userStatus    = 0;
        pkt->whdr.securityIndex = call->security_ix;
        pkt->whdr._rsvd         = 0;
        pkt->whdr.serviceId     = htons(call->service_id);
 
-       iov[0].iov_base = pkt;
-       iov[0].iov_len  = sizeof(pkt->whdr);
-       len = sizeof(pkt->whdr);
-
-       switch (type) {
-       case RXRPC_PACKET_TYPE_ACK:
-               spin_lock_bh(&call->lock);
+       spin_lock_bh(&call->lock);
+       if (ping) {
+               reason = RXRPC_ACK_PING;
+       } else {
+               reason = call->ackr_reason;
                if (!call->ackr_reason) {
                        spin_unlock_bh(&call->lock);
                        ret = 0;
                        goto out;
                }
-               ping = (call->ackr_reason == RXRPC_ACK_PING);
-               n = rxrpc_fill_out_ack(call, pkt, &hard_ack, &top);
                call->ackr_reason = 0;
+       }
+       n = rxrpc_fill_out_ack(call, pkt, &hard_ack, &top, reason);
 
-               spin_unlock_bh(&call->lock);
-
-
-               pkt->whdr.flags |= RXRPC_SLOW_START_OK;
-
-               iov[0].iov_len += sizeof(pkt->ack) + n;
-               iov[1].iov_base = &pkt->ackinfo;
-               iov[1].iov_len  = sizeof(pkt->ackinfo);
-               len += sizeof(pkt->ack) + n + sizeof(pkt->ackinfo);
-               ioc = 2;
-               break;
-
-       case RXRPC_PACKET_TYPE_ABORT:
-               abort_code = call->abort_code;
-               pkt->abort_code = htonl(abort_code);
-               iov[0].iov_len += sizeof(pkt->abort_code);
-               len += sizeof(pkt->abort_code);
-               ioc = 1;
-               break;
+       spin_unlock_bh(&call->lock);
 
-       default:
-               BUG();
-               ret = -ENOANO;
-               goto out;
-       }
+       iov[0].iov_base = pkt;
+       iov[0].iov_len  = sizeof(pkt->whdr) + sizeof(pkt->ack) + n;
+       iov[1].iov_base = &pkt->ackinfo;
+       iov[1].iov_len  = sizeof(pkt->ackinfo);
+       len = iov[0].iov_len + iov[1].iov_len;
 
        serial = atomic_inc_return(&conn->serial);
        pkt->whdr.serial = htonl(serial);
-       switch (type) {
-       case RXRPC_PACKET_TYPE_ACK:
-               trace_rxrpc_tx_ack(call, serial,
-                                  ntohl(pkt->ack.firstPacket),
-                                  ntohl(pkt->ack.serial),
-                                  pkt->ack.reason, pkt->ack.nAcks);
-               break;
-       }
+       trace_rxrpc_tx_ack(call, serial,
+                          ntohl(pkt->ack.firstPacket),
+                          ntohl(pkt->ack.serial),
+                          pkt->ack.reason, pkt->ack.nAcks);
 
        if (ping) {
-               call->ackr_ping = serial;
+               call->ping_serial = serial;
                smp_wmb();
                /* We need to stick a time in before we send the packet in case
                 * the reply gets back before kernel_sendmsg() completes - but
@@ -201,19 +175,19 @@ int rxrpc_send_call_packet(struct rxrpc_call *call, u8 type)
                 * the packet transmission is more likely to happen towards the
                 * end of the kernel_sendmsg() call.
                 */
-               call->ackr_ping_time = ktime_get_real();
+               call->ping_time = ktime_get_real();
                set_bit(RXRPC_CALL_PINGING, &call->flags);
                trace_rxrpc_rtt_tx(call, rxrpc_rtt_tx_ping, serial);
        }
-       ret = kernel_sendmsg(conn->params.local->socket,
-                            &msg, iov, ioc, len);
+
+       ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, 2, len);
        if (ping)
-               call->ackr_ping_time = ktime_get_real();
+               call->ping_time = ktime_get_real();
 
-       if (type == RXRPC_PACKET_TYPE_ACK &&
-           call->state < RXRPC_CALL_COMPLETE) {
+       if (call->state < RXRPC_CALL_COMPLETE) {
                if (ret < 0) {
-                       clear_bit(RXRPC_CALL_PINGING, &call->flags);
+                       if (ping)
+                               clear_bit(RXRPC_CALL_PINGING, &call->flags);
                        rxrpc_propose_ACK(call, pkt->ack.reason,
                                          ntohs(pkt->ack.maxSkew),
                                          ntohl(pkt->ack.serial),
@@ -235,6 +209,56 @@ out:
        return ret;
 }
 
+/*
+ * Send an ABORT call packet.
+ */
+int rxrpc_send_abort_packet(struct rxrpc_call *call)
+{
+       struct rxrpc_connection *conn = NULL;
+       struct rxrpc_abort_buffer pkt;
+       struct msghdr msg;
+       struct kvec iov[1];
+       rxrpc_serial_t serial;
+       int ret;
+
+       spin_lock_bh(&call->lock);
+       if (call->conn)
+               conn = rxrpc_get_connection_maybe(call->conn);
+       spin_unlock_bh(&call->lock);
+       if (!conn)
+               return -ECONNRESET;
+
+       msg.msg_name    = &call->peer->srx.transport;
+       msg.msg_namelen = call->peer->srx.transport_len;
+       msg.msg_control = NULL;
+       msg.msg_controllen = 0;
+       msg.msg_flags   = 0;
+
+       pkt.whdr.epoch          = htonl(conn->proto.epoch);
+       pkt.whdr.cid            = htonl(call->cid);
+       pkt.whdr.callNumber     = htonl(call->call_id);
+       pkt.whdr.seq            = 0;
+       pkt.whdr.type           = RXRPC_PACKET_TYPE_ABORT;
+       pkt.whdr.flags          = conn->out_clientflag;
+       pkt.whdr.userStatus     = 0;
+       pkt.whdr.securityIndex  = call->security_ix;
+       pkt.whdr._rsvd          = 0;
+       pkt.whdr.serviceId      = htons(call->service_id);
+       pkt.abort_code          = htonl(call->abort_code);
+
+       iov[0].iov_base = &pkt;
+       iov[0].iov_len  = sizeof(pkt);
+
+       serial = atomic_inc_return(&conn->serial);
+       pkt.whdr.serial = htonl(serial);
+
+       ret = kernel_sendmsg(conn->params.local->socket,
+                            &msg, iov, 1, sizeof(pkt));
+
+       rxrpc_put_connection(conn);
+       return ret;
+}
+
 /*
  * send a packet through the transport endpoint
  */
@@ -283,11 +307,12 @@ int rxrpc_send_data_packet(struct rxrpc_call *call, struct sk_buff *skb,
        /* If our RTT cache needs working on, request an ACK.  Also request
         * ACKs if a DATA packet appears to have been lost.
         */
-       if (retrans ||
-           call->cong_mode == RXRPC_CALL_SLOW_START ||
-           (call->peer->rtt_usage < 3 && sp->hdr.seq & 1) ||
-           ktime_before(ktime_add_ms(call->peer->rtt_last_req, 1000),
-                        ktime_get_real()))
+       if (!(sp->hdr.flags & RXRPC_LAST_PACKET) &&
+           (retrans ||
+            call->cong_mode == RXRPC_CALL_SLOW_START ||
+            (call->peer->rtt_usage < 3 && sp->hdr.seq & 1) ||
+            ktime_before(ktime_add_ms(call->peer->rtt_last_req, 1000),
+                         ktime_get_real())))
                whdr.flags |= RXRPC_REQUEST_ACK;
 
        if (IS_ENABLED(CONFIG_AF_RXRPC_INJECT_LOSS)) {