Merge tag 'for-linus' of git://git.kernel.org/pub/scm/virt/kvm/kvm
[cascardo/linux.git] / net / vmw_vsock / virtio_transport_common.c
1 /*
2  * common code for virtio vsock
3  *
4  * Copyright (C) 2013-2015 Red Hat, Inc.
5  * Author: Asias He <asias@redhat.com>
6  *         Stefan Hajnoczi <stefanha@redhat.com>
7  *
8  * This work is licensed under the terms of the GNU GPL, version 2.
9  */
10 #include <linux/spinlock.h>
11 #include <linux/module.h>
12 #include <linux/ctype.h>
13 #include <linux/list.h>
14 #include <linux/virtio.h>
15 #include <linux/virtio_ids.h>
16 #include <linux/virtio_config.h>
17 #include <linux/virtio_vsock.h>
18
19 #include <net/sock.h>
20 #include <net/af_vsock.h>
21
22 #define CREATE_TRACE_POINTS
23 #include <trace/events/vsock_virtio_transport_common.h>
24
25 /* How long to wait for graceful shutdown of a connection */
26 #define VSOCK_CLOSE_TIMEOUT (8 * HZ)
27
28 static const struct virtio_transport *virtio_transport_get_ops(void)
29 {
30         const struct vsock_transport *t = vsock_core_get_transport();
31
32         return container_of(t, struct virtio_transport, transport);
33 }
34
35 struct virtio_vsock_pkt *
36 virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
37                            size_t len,
38                            u32 src_cid,
39                            u32 src_port,
40                            u32 dst_cid,
41                            u32 dst_port)
42 {
43         struct virtio_vsock_pkt *pkt;
44         int err;
45
46         pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
47         if (!pkt)
48                 return NULL;
49
50         pkt->hdr.type           = cpu_to_le16(info->type);
51         pkt->hdr.op             = cpu_to_le16(info->op);
52         pkt->hdr.src_cid        = cpu_to_le64(src_cid);
53         pkt->hdr.dst_cid        = cpu_to_le64(dst_cid);
54         pkt->hdr.src_port       = cpu_to_le32(src_port);
55         pkt->hdr.dst_port       = cpu_to_le32(dst_port);
56         pkt->hdr.flags          = cpu_to_le32(info->flags);
57         pkt->len                = len;
58         pkt->hdr.len            = cpu_to_le32(len);
59         pkt->reply              = info->reply;
60
61         if (info->msg && len > 0) {
62                 pkt->buf = kmalloc(len, GFP_KERNEL);
63                 if (!pkt->buf)
64                         goto out_pkt;
65                 err = memcpy_from_msg(pkt->buf, info->msg, len);
66                 if (err)
67                         goto out;
68         }
69
70         trace_virtio_transport_alloc_pkt(src_cid, src_port,
71                                          dst_cid, dst_port,
72                                          len,
73                                          info->type,
74                                          info->op,
75                                          info->flags);
76
77         return pkt;
78
79 out:
80         kfree(pkt->buf);
81 out_pkt:
82         kfree(pkt);
83         return NULL;
84 }
85 EXPORT_SYMBOL_GPL(virtio_transport_alloc_pkt);
86
87 static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
88                                           struct virtio_vsock_pkt_info *info)
89 {
90         u32 src_cid, src_port, dst_cid, dst_port;
91         struct virtio_vsock_sock *vvs;
92         struct virtio_vsock_pkt *pkt;
93         u32 pkt_len = info->pkt_len;
94
95         src_cid = vm_sockets_get_local_cid();
96         src_port = vsk->local_addr.svm_port;
97         if (!info->remote_cid) {
98                 dst_cid = vsk->remote_addr.svm_cid;
99                 dst_port = vsk->remote_addr.svm_port;
100         } else {
101                 dst_cid = info->remote_cid;
102                 dst_port = info->remote_port;
103         }
104
105         vvs = vsk->trans;
106
107         /* we can send less than pkt_len bytes */
108         if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
109                 pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
110
111         /* virtio_transport_get_credit might return less than pkt_len credit */
112         pkt_len = virtio_transport_get_credit(vvs, pkt_len);
113
114         /* Do not send zero length OP_RW pkt */
115         if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
116                 return pkt_len;
117
118         pkt = virtio_transport_alloc_pkt(info, pkt_len,
119                                          src_cid, src_port,
120                                          dst_cid, dst_port);
121         if (!pkt) {
122                 virtio_transport_put_credit(vvs, pkt_len);
123                 return -ENOMEM;
124         }
125
126         virtio_transport_inc_tx_pkt(vvs, pkt);
127
128         return virtio_transport_get_ops()->send_pkt(pkt);
129 }
130
131 static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
132                                         struct virtio_vsock_pkt *pkt)
133 {
134         vvs->rx_bytes += pkt->len;
135 }
136
137 static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
138                                         struct virtio_vsock_pkt *pkt)
139 {
140         vvs->rx_bytes -= pkt->len;
141         vvs->fwd_cnt += pkt->len;
142 }
143
144 void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
145 {
146         spin_lock_bh(&vvs->tx_lock);
147         pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
148         pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
149         spin_unlock_bh(&vvs->tx_lock);
150 }
151 EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
152
153 u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
154 {
155         u32 ret;
156
157         spin_lock_bh(&vvs->tx_lock);
158         ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
159         if (ret > credit)
160                 ret = credit;
161         vvs->tx_cnt += ret;
162         spin_unlock_bh(&vvs->tx_lock);
163
164         return ret;
165 }
166 EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
167
168 void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
169 {
170         spin_lock_bh(&vvs->tx_lock);
171         vvs->tx_cnt -= credit;
172         spin_unlock_bh(&vvs->tx_lock);
173 }
174 EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
175
176 static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
177                                                int type,
178                                                struct virtio_vsock_hdr *hdr)
179 {
180         struct virtio_vsock_pkt_info info = {
181                 .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
182                 .type = type,
183         };
184
185         return virtio_transport_send_pkt_info(vsk, &info);
186 }
187
188 static ssize_t
189 virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
190                                    struct msghdr *msg,
191                                    size_t len)
192 {
193         struct virtio_vsock_sock *vvs = vsk->trans;
194         struct virtio_vsock_pkt *pkt;
195         size_t bytes, total = 0;
196         int err = -EFAULT;
197
198         spin_lock_bh(&vvs->rx_lock);
199         while (total < len && !list_empty(&vvs->rx_queue)) {
200                 pkt = list_first_entry(&vvs->rx_queue,
201                                        struct virtio_vsock_pkt, list);
202
203                 bytes = len - total;
204                 if (bytes > pkt->len - pkt->off)
205                         bytes = pkt->len - pkt->off;
206
207                 /* sk_lock is held by caller so no one else can dequeue.
208                  * Unlock rx_lock since memcpy_to_msg() may sleep.
209                  */
210                 spin_unlock_bh(&vvs->rx_lock);
211
212                 err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
213                 if (err)
214                         goto out;
215
216                 spin_lock_bh(&vvs->rx_lock);
217
218                 total += bytes;
219                 pkt->off += bytes;
220                 if (pkt->off == pkt->len) {
221                         virtio_transport_dec_rx_pkt(vvs, pkt);
222                         list_del(&pkt->list);
223                         virtio_transport_free_pkt(pkt);
224                 }
225         }
226         spin_unlock_bh(&vvs->rx_lock);
227
228         /* Send a credit pkt to peer */
229         virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
230                                             NULL);
231
232         return total;
233
234 out:
235         if (total)
236                 err = total;
237         return err;
238 }
239
240 ssize_t
241 virtio_transport_stream_dequeue(struct vsock_sock *vsk,
242                                 struct msghdr *msg,
243                                 size_t len, int flags)
244 {
245         if (flags & MSG_PEEK)
246                 return -EOPNOTSUPP;
247
248         return virtio_transport_stream_do_dequeue(vsk, msg, len);
249 }
250 EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
251
252 int
253 virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
254                                struct msghdr *msg,
255                                size_t len, int flags)
256 {
257         return -EOPNOTSUPP;
258 }
259 EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
260
261 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
262 {
263         struct virtio_vsock_sock *vvs = vsk->trans;
264         s64 bytes;
265
266         spin_lock_bh(&vvs->rx_lock);
267         bytes = vvs->rx_bytes;
268         spin_unlock_bh(&vvs->rx_lock);
269
270         return bytes;
271 }
272 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
273
274 static s64 virtio_transport_has_space(struct vsock_sock *vsk)
275 {
276         struct virtio_vsock_sock *vvs = vsk->trans;
277         s64 bytes;
278
279         bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
280         if (bytes < 0)
281                 bytes = 0;
282
283         return bytes;
284 }
285
286 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
287 {
288         struct virtio_vsock_sock *vvs = vsk->trans;
289         s64 bytes;
290
291         spin_lock_bh(&vvs->tx_lock);
292         bytes = virtio_transport_has_space(vsk);
293         spin_unlock_bh(&vvs->tx_lock);
294
295         return bytes;
296 }
297 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
298
299 int virtio_transport_do_socket_init(struct vsock_sock *vsk,
300                                     struct vsock_sock *psk)
301 {
302         struct virtio_vsock_sock *vvs;
303
304         vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
305         if (!vvs)
306                 return -ENOMEM;
307
308         vsk->trans = vvs;
309         vvs->vsk = vsk;
310         if (psk) {
311                 struct virtio_vsock_sock *ptrans = psk->trans;
312
313                 vvs->buf_size   = ptrans->buf_size;
314                 vvs->buf_size_min = ptrans->buf_size_min;
315                 vvs->buf_size_max = ptrans->buf_size_max;
316                 vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
317         } else {
318                 vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
319                 vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
320                 vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
321         }
322
323         vvs->buf_alloc = vvs->buf_size;
324
325         spin_lock_init(&vvs->rx_lock);
326         spin_lock_init(&vvs->tx_lock);
327         INIT_LIST_HEAD(&vvs->rx_queue);
328
329         return 0;
330 }
331 EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
332
333 u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
334 {
335         struct virtio_vsock_sock *vvs = vsk->trans;
336
337         return vvs->buf_size;
338 }
339 EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);
340
341 u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
342 {
343         struct virtio_vsock_sock *vvs = vsk->trans;
344
345         return vvs->buf_size_min;
346 }
347 EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);
348
349 u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
350 {
351         struct virtio_vsock_sock *vvs = vsk->trans;
352
353         return vvs->buf_size_max;
354 }
355 EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
356
357 void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
358 {
359         struct virtio_vsock_sock *vvs = vsk->trans;
360
361         if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
362                 val = VIRTIO_VSOCK_MAX_BUF_SIZE;
363         if (val < vvs->buf_size_min)
364                 vvs->buf_size_min = val;
365         if (val > vvs->buf_size_max)
366                 vvs->buf_size_max = val;
367         vvs->buf_size = val;
368         vvs->buf_alloc = val;
369 }
370 EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);
371
372 void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
373 {
374         struct virtio_vsock_sock *vvs = vsk->trans;
375
376         if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
377                 val = VIRTIO_VSOCK_MAX_BUF_SIZE;
378         if (val > vvs->buf_size)
379                 vvs->buf_size = val;
380         vvs->buf_size_min = val;
381 }
382 EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
383
384 void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
385 {
386         struct virtio_vsock_sock *vvs = vsk->trans;
387
388         if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
389                 val = VIRTIO_VSOCK_MAX_BUF_SIZE;
390         if (val < vvs->buf_size)
391                 vvs->buf_size = val;
392         vvs->buf_size_max = val;
393 }
394 EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
395
396 int
397 virtio_transport_notify_poll_in(struct vsock_sock *vsk,
398                                 size_t target,
399                                 bool *data_ready_now)
400 {
401         if (vsock_stream_has_data(vsk))
402                 *data_ready_now = true;
403         else
404                 *data_ready_now = false;
405
406         return 0;
407 }
408 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
409
410 int
411 virtio_transport_notify_poll_out(struct vsock_sock *vsk,
412                                  size_t target,
413                                  bool *space_avail_now)
414 {
415         s64 free_space;
416
417         free_space = vsock_stream_has_space(vsk);
418         if (free_space > 0)
419                 *space_avail_now = true;
420         else if (free_space == 0)
421                 *space_avail_now = false;
422
423         return 0;
424 }
425 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
426
427 int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
428         size_t target, struct vsock_transport_recv_notify_data *data)
429 {
430         return 0;
431 }
432 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
433
434 int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
435         size_t target, struct vsock_transport_recv_notify_data *data)
436 {
437         return 0;
438 }
439 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
440
441 int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
442         size_t target, struct vsock_transport_recv_notify_data *data)
443 {
444         return 0;
445 }
446 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
447
448 int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
449         size_t target, ssize_t copied, bool data_read,
450         struct vsock_transport_recv_notify_data *data)
451 {
452         return 0;
453 }
454 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
455
/* Send-init notification hook: does nothing and returns 0. */
int virtio_transport_notify_send_init(struct vsock_sock *vsk,
				      struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
462
/* Send pre-block notification hook: does nothing and returns 0. */
int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
					   struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
469
/* Send pre-enqueue notification hook: does nothing and returns 0. */
int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
					     struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
476
477 int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
478         ssize_t written, struct vsock_transport_send_notify_data *data)
479 {
480         return 0;
481 }
482 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
483
484 u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
485 {
486         struct virtio_vsock_sock *vvs = vsk->trans;
487
488         return vvs->buf_size;
489 }
490 EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
491
492 bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
493 {
494         return true;
495 }
496 EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
497
498 bool virtio_transport_stream_allow(u32 cid, u32 port)
499 {
500         return true;
501 }
502 EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
503
504 int virtio_transport_dgram_bind(struct vsock_sock *vsk,
505                                 struct sockaddr_vm *addr)
506 {
507         return -EOPNOTSUPP;
508 }
509 EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
510
511 bool virtio_transport_dgram_allow(u32 cid, u32 port)
512 {
513         return false;
514 }
515 EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
516
517 int virtio_transport_connect(struct vsock_sock *vsk)
518 {
519         struct virtio_vsock_pkt_info info = {
520                 .op = VIRTIO_VSOCK_OP_REQUEST,
521                 .type = VIRTIO_VSOCK_TYPE_STREAM,
522         };
523
524         return virtio_transport_send_pkt_info(vsk, &info);
525 }
526 EXPORT_SYMBOL_GPL(virtio_transport_connect);
527
528 int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
529 {
530         struct virtio_vsock_pkt_info info = {
531                 .op = VIRTIO_VSOCK_OP_SHUTDOWN,
532                 .type = VIRTIO_VSOCK_TYPE_STREAM,
533                 .flags = (mode & RCV_SHUTDOWN ?
534                           VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
535                          (mode & SEND_SHUTDOWN ?
536                           VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
537         };
538
539         return virtio_transport_send_pkt_info(vsk, &info);
540 }
541 EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
542
543 int
544 virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
545                                struct sockaddr_vm *remote_addr,
546                                struct msghdr *msg,
547                                size_t dgram_len)
548 {
549         return -EOPNOTSUPP;
550 }
551 EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
552
553 ssize_t
554 virtio_transport_stream_enqueue(struct vsock_sock *vsk,
555                                 struct msghdr *msg,
556                                 size_t len)
557 {
558         struct virtio_vsock_pkt_info info = {
559                 .op = VIRTIO_VSOCK_OP_RW,
560                 .type = VIRTIO_VSOCK_TYPE_STREAM,
561                 .msg = msg,
562                 .pkt_len = len,
563         };
564
565         return virtio_transport_send_pkt_info(vsk, &info);
566 }
567 EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
568
569 void virtio_transport_destruct(struct vsock_sock *vsk)
570 {
571         struct virtio_vsock_sock *vvs = vsk->trans;
572
573         kfree(vvs);
574 }
575 EXPORT_SYMBOL_GPL(virtio_transport_destruct);
576
577 static int virtio_transport_reset(struct vsock_sock *vsk,
578                                   struct virtio_vsock_pkt *pkt)
579 {
580         struct virtio_vsock_pkt_info info = {
581                 .op = VIRTIO_VSOCK_OP_RST,
582                 .type = VIRTIO_VSOCK_TYPE_STREAM,
583                 .reply = !!pkt,
584         };
585
586         /* Send RST only if the original pkt is not a RST pkt */
587         if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
588                 return 0;
589
590         return virtio_transport_send_pkt_info(vsk, &info);
591 }
592
593 /* Normally packets are associated with a socket.  There may be no socket if an
594  * attempt was made to connect to a socket that does not exist.
595  */
596 static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
597 {
598         struct virtio_vsock_pkt_info info = {
599                 .op = VIRTIO_VSOCK_OP_RST,
600                 .type = le16_to_cpu(pkt->hdr.type),
601                 .reply = true,
602         };
603
604         /* Send RST only if the original pkt is not a RST pkt */
605         if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
606                 return 0;
607
608         pkt = virtio_transport_alloc_pkt(&info, 0,
609                                          le32_to_cpu(pkt->hdr.dst_cid),
610                                          le32_to_cpu(pkt->hdr.dst_port),
611                                          le32_to_cpu(pkt->hdr.src_cid),
612                                          le32_to_cpu(pkt->hdr.src_port));
613         if (!pkt)
614                 return -ENOMEM;
615
616         return virtio_transport_get_ops()->send_pkt(pkt);
617 }
618
619 static void virtio_transport_wait_close(struct sock *sk, long timeout)
620 {
621         if (timeout) {
622                 DEFINE_WAIT(wait);
623
624                 do {
625                         prepare_to_wait(sk_sleep(sk), &wait,
626                                         TASK_INTERRUPTIBLE);
627                         if (sk_wait_event(sk, &timeout,
628                                           sock_flag(sk, SOCK_DONE)))
629                                 break;
630                 } while (!signal_pending(current) && timeout);
631
632                 finish_wait(sk_sleep(sk), &wait);
633         }
634 }
635
636 static void virtio_transport_do_close(struct vsock_sock *vsk,
637                                       bool cancel_timeout)
638 {
639         struct sock *sk = sk_vsock(vsk);
640
641         sock_set_flag(sk, SOCK_DONE);
642         vsk->peer_shutdown = SHUTDOWN_MASK;
643         if (vsock_stream_has_data(vsk) <= 0)
644                 sk->sk_state = SS_DISCONNECTING;
645         sk->sk_state_change(sk);
646
647         if (vsk->close_work_scheduled &&
648             (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
649                 vsk->close_work_scheduled = false;
650
651                 vsock_remove_sock(vsk);
652
653                 /* Release refcnt obtained when we scheduled the timeout */
654                 sock_put(sk);
655         }
656 }
657
658 static void virtio_transport_close_timeout(struct work_struct *work)
659 {
660         struct vsock_sock *vsk =
661                 container_of(work, struct vsock_sock, close_work.work);
662         struct sock *sk = sk_vsock(vsk);
663
664         sock_hold(sk);
665         lock_sock(sk);
666
667         if (!sock_flag(sk, SOCK_DONE)) {
668                 (void)virtio_transport_reset(vsk, NULL);
669
670                 virtio_transport_do_close(vsk, false);
671         }
672
673         vsk->close_work_scheduled = false;
674
675         release_sock(sk);
676         sock_put(sk);
677 }
678
679 /* User context, vsk->sk is locked */
680 static bool virtio_transport_close(struct vsock_sock *vsk)
681 {
682         struct sock *sk = &vsk->sk;
683
684         if (!(sk->sk_state == SS_CONNECTED ||
685               sk->sk_state == SS_DISCONNECTING))
686                 return true;
687
688         /* Already received SHUTDOWN from peer, reply with RST */
689         if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
690                 (void)virtio_transport_reset(vsk, NULL);
691                 return true;
692         }
693
694         if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
695                 (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);
696
697         if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
698                 virtio_transport_wait_close(sk, sk->sk_lingertime);
699
700         if (sock_flag(sk, SOCK_DONE)) {
701                 return true;
702         }
703
704         sock_hold(sk);
705         INIT_DELAYED_WORK(&vsk->close_work,
706                           virtio_transport_close_timeout);
707         vsk->close_work_scheduled = true;
708         schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
709         return false;
710 }
711
712 void virtio_transport_release(struct vsock_sock *vsk)
713 {
714         struct sock *sk = &vsk->sk;
715         bool remove_sock = true;
716
717         lock_sock(sk);
718         if (sk->sk_type == SOCK_STREAM)
719                 remove_sock = virtio_transport_close(vsk);
720         release_sock(sk);
721
722         if (remove_sock)
723                 vsock_remove_sock(vsk);
724 }
725 EXPORT_SYMBOL_GPL(virtio_transport_release);
726
727 static int
728 virtio_transport_recv_connecting(struct sock *sk,
729                                  struct virtio_vsock_pkt *pkt)
730 {
731         struct vsock_sock *vsk = vsock_sk(sk);
732         int err;
733         int skerr;
734
735         switch (le16_to_cpu(pkt->hdr.op)) {
736         case VIRTIO_VSOCK_OP_RESPONSE:
737                 sk->sk_state = SS_CONNECTED;
738                 sk->sk_socket->state = SS_CONNECTED;
739                 vsock_insert_connected(vsk);
740                 sk->sk_state_change(sk);
741                 break;
742         case VIRTIO_VSOCK_OP_INVALID:
743                 break;
744         case VIRTIO_VSOCK_OP_RST:
745                 skerr = ECONNRESET;
746                 err = 0;
747                 goto destroy;
748         default:
749                 skerr = EPROTO;
750                 err = -EINVAL;
751                 goto destroy;
752         }
753         return 0;
754
755 destroy:
756         virtio_transport_reset(vsk, pkt);
757         sk->sk_state = SS_UNCONNECTED;
758         sk->sk_err = skerr;
759         sk->sk_error_report(sk);
760         return err;
761 }
762
763 static int
764 virtio_transport_recv_connected(struct sock *sk,
765                                 struct virtio_vsock_pkt *pkt)
766 {
767         struct vsock_sock *vsk = vsock_sk(sk);
768         struct virtio_vsock_sock *vvs = vsk->trans;
769         int err = 0;
770
771         switch (le16_to_cpu(pkt->hdr.op)) {
772         case VIRTIO_VSOCK_OP_RW:
773                 pkt->len = le32_to_cpu(pkt->hdr.len);
774                 pkt->off = 0;
775
776                 spin_lock_bh(&vvs->rx_lock);
777                 virtio_transport_inc_rx_pkt(vvs, pkt);
778                 list_add_tail(&pkt->list, &vvs->rx_queue);
779                 spin_unlock_bh(&vvs->rx_lock);
780
781                 sk->sk_data_ready(sk);
782                 return err;
783         case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
784                 sk->sk_write_space(sk);
785                 break;
786         case VIRTIO_VSOCK_OP_SHUTDOWN:
787                 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
788                         vsk->peer_shutdown |= RCV_SHUTDOWN;
789                 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
790                         vsk->peer_shutdown |= SEND_SHUTDOWN;
791                 if (vsk->peer_shutdown == SHUTDOWN_MASK &&
792                     vsock_stream_has_data(vsk) <= 0)
793                         sk->sk_state = SS_DISCONNECTING;
794                 if (le32_to_cpu(pkt->hdr.flags))
795                         sk->sk_state_change(sk);
796                 break;
797         case VIRTIO_VSOCK_OP_RST:
798                 virtio_transport_do_close(vsk, true);
799                 break;
800         default:
801                 err = -EINVAL;
802                 break;
803         }
804
805         virtio_transport_free_pkt(pkt);
806         return err;
807 }
808
809 static void
810 virtio_transport_recv_disconnecting(struct sock *sk,
811                                     struct virtio_vsock_pkt *pkt)
812 {
813         struct vsock_sock *vsk = vsock_sk(sk);
814
815         if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
816                 virtio_transport_do_close(vsk, true);
817 }
818
819 static int
820 virtio_transport_send_response(struct vsock_sock *vsk,
821                                struct virtio_vsock_pkt *pkt)
822 {
823         struct virtio_vsock_pkt_info info = {
824                 .op = VIRTIO_VSOCK_OP_RESPONSE,
825                 .type = VIRTIO_VSOCK_TYPE_STREAM,
826                 .remote_cid = le32_to_cpu(pkt->hdr.src_cid),
827                 .remote_port = le32_to_cpu(pkt->hdr.src_port),
828                 .reply = true,
829         };
830
831         return virtio_transport_send_pkt_info(vsk, &info);
832 }
833
834 /* Handle server socket */
835 static int
836 virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
837 {
838         struct vsock_sock *vsk = vsock_sk(sk);
839         struct vsock_sock *vchild;
840         struct sock *child;
841
842         if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
843                 virtio_transport_reset(vsk, pkt);
844                 return -EINVAL;
845         }
846
847         if (sk_acceptq_is_full(sk)) {
848                 virtio_transport_reset(vsk, pkt);
849                 return -ENOMEM;
850         }
851
852         child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
853                                sk->sk_type, 0);
854         if (!child) {
855                 virtio_transport_reset(vsk, pkt);
856                 return -ENOMEM;
857         }
858
859         sk->sk_ack_backlog++;
860
861         lock_sock_nested(child, SINGLE_DEPTH_NESTING);
862
863         child->sk_state = SS_CONNECTED;
864
865         vchild = vsock_sk(child);
866         vsock_addr_init(&vchild->local_addr, le32_to_cpu(pkt->hdr.dst_cid),
867                         le32_to_cpu(pkt->hdr.dst_port));
868         vsock_addr_init(&vchild->remote_addr, le32_to_cpu(pkt->hdr.src_cid),
869                         le32_to_cpu(pkt->hdr.src_port));
870
871         vsock_insert_connected(vchild);
872         vsock_enqueue_accept(sk, child);
873         virtio_transport_send_response(vchild, pkt);
874
875         release_sock(child);
876
877         sk->sk_data_ready(sk);
878         return 0;
879 }
880
881 static bool virtio_transport_space_update(struct sock *sk,
882                                           struct virtio_vsock_pkt *pkt)
883 {
884         struct vsock_sock *vsk = vsock_sk(sk);
885         struct virtio_vsock_sock *vvs = vsk->trans;
886         bool space_available;
887
888         /* buf_alloc and fwd_cnt is always included in the hdr */
889         spin_lock_bh(&vvs->tx_lock);
890         vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
891         vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
892         space_available = virtio_transport_has_space(vsk);
893         spin_unlock_bh(&vvs->tx_lock);
894         return space_available;
895 }
896
897 /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
898  * lock.
899  */
900 void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
901 {
902         struct sockaddr_vm src, dst;
903         struct vsock_sock *vsk;
904         struct sock *sk;
905         bool space_available;
906
907         vsock_addr_init(&src, le32_to_cpu(pkt->hdr.src_cid),
908                         le32_to_cpu(pkt->hdr.src_port));
909         vsock_addr_init(&dst, le32_to_cpu(pkt->hdr.dst_cid),
910                         le32_to_cpu(pkt->hdr.dst_port));
911
912         trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
913                                         dst.svm_cid, dst.svm_port,
914                                         le32_to_cpu(pkt->hdr.len),
915                                         le16_to_cpu(pkt->hdr.type),
916                                         le16_to_cpu(pkt->hdr.op),
917                                         le32_to_cpu(pkt->hdr.flags),
918                                         le32_to_cpu(pkt->hdr.buf_alloc),
919                                         le32_to_cpu(pkt->hdr.fwd_cnt));
920
921         if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
922                 (void)virtio_transport_reset_no_sock(pkt);
923                 goto free_pkt;
924         }
925
926         /* The socket must be in connected or bound table
927          * otherwise send reset back
928          */
929         sk = vsock_find_connected_socket(&src, &dst);
930         if (!sk) {
931                 sk = vsock_find_bound_socket(&dst);
932                 if (!sk) {
933                         (void)virtio_transport_reset_no_sock(pkt);
934                         goto free_pkt;
935                 }
936         }
937
938         vsk = vsock_sk(sk);
939
940         space_available = virtio_transport_space_update(sk, pkt);
941
942         lock_sock(sk);
943
944         /* Update CID in case it has changed after a transport reset event */
945         vsk->local_addr.svm_cid = dst.svm_cid;
946
947         if (space_available)
948                 sk->sk_write_space(sk);
949
950         switch (sk->sk_state) {
951         case VSOCK_SS_LISTEN:
952                 virtio_transport_recv_listen(sk, pkt);
953                 virtio_transport_free_pkt(pkt);
954                 break;
955         case SS_CONNECTING:
956                 virtio_transport_recv_connecting(sk, pkt);
957                 virtio_transport_free_pkt(pkt);
958                 break;
959         case SS_CONNECTED:
960                 virtio_transport_recv_connected(sk, pkt);
961                 break;
962         case SS_DISCONNECTING:
963                 virtio_transport_recv_disconnecting(sk, pkt);
964                 virtio_transport_free_pkt(pkt);
965                 break;
966         default:
967                 virtio_transport_free_pkt(pkt);
968                 break;
969         }
970         release_sock(sk);
971
972         /* Release refcnt obtained when we fetched this socket out of the
973          * bound or connected list.
974          */
975         sock_put(sk);
976         return;
977
978 free_pkt:
979         virtio_transport_free_pkt(pkt);
980 }
981 EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
982
983 void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
984 {
985         kfree(pkt->buf);
986         kfree(pkt);
987 }
988 EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
989
990 MODULE_LICENSE("GPL v2");
991 MODULE_AUTHOR("Asias He");
992 MODULE_DESCRIPTION("common code for virtio vsock");