'''
Base netlink socket and marshal
===============================
All the netlink providers are derived from the socket
class, so they provide the normal socket API, including
`getsockopt()` and `setsockopt()`, and they can be used
in poll/select I/O loops, etc.
asynchronous I/O
----------------
To run the async reader thread, one should call
`NetlinkSocket.bind(async_cache=True)`. In that case
a background thread will be launched. The thread will
automatically collect all the incoming messages and
store them in a userspace buffer.

.. note::
    There is no need to turn on async I/O if you
    don't plan to receive broadcast messages.
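
A minimal sketch of using the async cache, assuming this module
is importable as `pyroute2.netlink.nlsocket`; the NETLINK_ROUTE
family and the group mask below are only illustrative::

    from pyroute2.netlink import NETLINK_ROUTE
    from pyroute2.netlink.nlsocket import NetlinkSocket

    s = NetlinkSocket(family=NETLINK_ROUTE)
    # start the background reader thread
    s.bind(groups=0b111, async_cache=True)
    try:
        while True:
            # broadcast messages are collected by the background
            # thread and read here from the userspace buffer
            for msg in s.get():
                print(msg['header']['type'])
    finally:
        s.close()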

ENOBUFS and async I/O
---------------------
When Netlink messages arrive faster than a program
reads them from the socket, the messages overflow
the socket buffer and one gets ENOBUFS on `recv()`::

    ... self.recv(bufsize)
    error: [Errno 105] No buffer space available
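
If you prefer to handle the overflow explicitly, here is a sketch
of catching the error; `nlsock` stands for a bound netlink socket,
while `handle()` and `resync_state()` are hypothetical application
callbacks::

    import errno

    try:
        for msg in nlsock.get():
            handle(msg)
    except OSError as e:
        if e.errno != errno.ENOBUFS:
            raise
        # some messages were lost: the application should
        # re-request the kernel state it is interested in
        resync_state()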

One way to avoid ENOBUFS is to use async I/O. Then the
library not only reads and buffers all the messages, but
also re-prioritizes the threads. By suppressing the parser
activity, the library increases the response delay, but
spares CPU to read and enqueue the arriving messages as
fast as possible.

With logging level DEBUG you can notice messages that
the library has started to calm down the parser thread::
DEBUG:root:Packet burst: the reader thread priority
is increased, beware of delays on netlink calls
Counters: delta=25 qsize=25 delay=0.1

This state requires no immediate action, just some more
attention. When the delay between messages on the parser
thread exceeds 1 second, the DEBUG messages become
WARNING ones::
WARNING:root:Packet burst: the reader thread priority
is increased, beware of delays on netlink calls
Counters: delta=2525 qsize=213536 delay=3

This state means that almost all the CPU resources are
dedicated to the reader thread. It doesn't mean that the
reader thread consumes 100% CPU -- it means that the CPU
is reserved for the case of more intensive bursts. The
library will return to the normal state only when the
broadcast storm is over, and then the CPU will be 100%
loaded with the parser for some time, while it processes
all the messages queued so far.
when async I/O doesn't help
---------------------------
Sometimes even turning on async I/O doesn't fix ENOBUFS.
Mostly it means that in this particular case the Python
performance is not enough even to read and store the raw
data from the socket. There is no workaround for such
cases, except using something *not* Python-based.

One can still play around with the SO_RCVBUF socket option,
but it doesn't help much. So keep it in mind, and if you
expect massive broadcast Netlink storms, perform stress
testing before deploying a solution in production.
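
If you still want to try, here is a sketch of enlarging the
receive buffer; the 16 MB value is arbitrary, and the kernel
may cap it at `net.core.rmem_max`::

    from pyroute2.netlink import NETLINK_ROUTE
    from pyroute2.netlink.nlsocket import NetlinkSocket

    s = NetlinkSocket(family=NETLINK_ROUTE, rcvbuf=16 * 1024 * 1024)
    s.bind()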
classes
-------
'''
import collections
import errno
import logging
import os
import random
import select
import struct
import threading
import time
import traceback
import warnings
from functools import partial
from socket import (
MSG_DONTWAIT,
MSG_PEEK,
MSG_TRUNC,
SO_RCVBUF,
SO_SNDBUF,
SOCK_DGRAM,
SOL_SOCKET,
)
from pyroute2 import config
from pyroute2.common import DEFAULT_RCVBUF, AddrPool
from pyroute2.config import AF_NETLINK
from pyroute2.netlink import (
NETLINK_ADD_MEMBERSHIP,
NETLINK_DROP_MEMBERSHIP,
NETLINK_EXT_ACK,
NETLINK_GENERIC,
NETLINK_GET_STRICT_CHK,
NETLINK_LISTEN_ALL_NSID,
NLM_F_ACK,
NLM_F_ACK_TLVS,
NLM_F_DUMP,
NLM_F_DUMP_INTR,
NLM_F_MULTI,
NLM_F_REQUEST,
NLMSG_DONE,
NLMSG_ERROR,
SOL_NETLINK,
mtypes,
nlmsg,
nlmsgerr,
)
from pyroute2.netlink.exceptions import (
ChaoticException,
NetlinkDecodeError,
NetlinkDumpInterrupted,
NetlinkError,
NetlinkHeaderDecodeError,
)
try:
from Queue import Queue
except ImportError:
from queue import Queue
log = logging.getLogger(__name__)
Stats = collections.namedtuple('Stats', ('qsize', 'delta', 'delay'))
NL_BUFSIZE = 32768
class CompileContext:
def __init__(self, netlink_socket):
self.netlink_socket = netlink_socket
self.netlink_socket.compiled = []
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
def close(self):
self.netlink_socket.compiled = None
class Marshal:
'''
Generic marshalling class
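
    A sketch of standalone usage; the RTM_NEWLINK/ifinfmsg pair
    below is just one possible `msg_map` entry, taken from the
    RTNL protocol, and `raw_data` stands for bytes received from
    a netlink socket::

        from pyroute2.netlink.rtnl import RTM_NEWLINK
        from pyroute2.netlink.rtnl.ifinfmsg import ifinfmsg

        marshal = Marshal()
        marshal.msg_map[RTM_NEWLINK] = ifinfmsg
        # parse() is a generator yielding decoded messages
        for msg in marshal.parse(raw_data):
            print(msg['header']['type'])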
'''
msg_map = {}
seq_map = None
key_offset = None
key_format = None
key_mask = None
debug = False
default_message_class = nlmsg
error_type = NLMSG_ERROR
def __init__(self):
self.lock = threading.Lock()
self.msg_map = self.msg_map.copy()
self.seq_map = {}
self.defragmentation = {}
def parse_one_message(
self, key, flags, sequence_number, data, offset, length
):
msg = None
error = None
msg_class = self.msg_map.get(key, self.default_message_class)
# ignore length for a while
# get the message
if (key == self.error_type) or (
key == NLMSG_DONE and flags & NLM_F_ACK_TLVS
):
msg = nlmsgerr(data, offset=offset)
else:
msg = msg_class(data, offset=offset)
try:
msg.decode()
except NetlinkHeaderDecodeError as e:
msg = nlmsg()
msg['header']['error'] = e
except NetlinkDecodeError as e:
msg['header']['error'] = e
if isinstance(msg, nlmsgerr) and msg['error'] != 0:
error = NetlinkError(
abs(msg['error']), msg.get_attr('NLMSGERR_ATTR_MSG')
)
enc_type = struct.unpack_from('H', data, offset + 24)[0]
enc_class = self.msg_map.get(enc_type, nlmsg)
enc = enc_class(data, offset=offset + 20)
enc.decode()
msg['header']['errmsg'] = enc
msg['header']['error'] = error
return msg
def get_parser(self, key, flags, sequence_number):
return self.seq_map.get(
sequence_number,
partial(self.parse_one_message, key, flags, sequence_number),
)
def parse(self, data, seq=None, callback=None, skip_alien_seq=False):
'''
        Parse the binary data.

        At this moment all transports except the native
        netlink one are deprecated in this library, so we do
        not support any defragmentation on that level.
'''
offset = 0
# there must be at least one header in the buffer,
# 'IHHII' == 16 bytes
while offset <= len(data) - 16:
# pick type and length
(length, key, flags, sequence_number) = struct.unpack_from(
'IHHI', data, offset
)
            if not 0 < length <= len(data):
                break
            if skip_alien_seq and sequence_number != seq:
                # advance the offset, otherwise the loop would
                # spin forever on an alien message
                offset += length
                continue
# support custom parser keys
# see also: pyroute2.netlink.diag.MarshalDiag
if self.key_format is not None:
(key,) = struct.unpack_from(
self.key_format, data, offset + self.key_offset
)
if self.key_mask is not None:
key &= self.key_mask
parser = self.get_parser(key, flags, sequence_number)
msg = parser(data, offset, length)
offset += length
if msg is None:
continue
if callable(callback) and seq == sequence_number:
try:
if callback(msg):
continue
except Exception:
pass
mtype = msg['header'].get('type', None)
if mtype in (1, 2, 3, 4) and 'event' not in msg:
msg['event'] = mtypes.get(mtype, 'none')
self.fix_message(msg)
yield msg
def fix_message(self, msg):
pass
# 8<-----------------------------------------------------------
# Singleton, containing possible modifiers to the NetlinkSocket
# bind() call.
#
# Normally, you can open only one netlink connection for one
# process, but there is a hack. Current PID_MAX_LIMIT is 2^22,
# so we can use the rest to modify the pid field.
#
# See also libnl library, lib/socket.c:generate_local_port()
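# For example, with pid 1234 and an allocated port 5 the effective
# address (epid) becomes 1234 + (5 << 22) == 20972754, which stays
# unique even when one process opens several netlink sockets.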
sockets = AddrPool(minaddr=0x0, maxaddr=0x3FF, reverse=True)
# 8<-----------------------------------------------------------
class LockProxy:
def __init__(self, factory, key):
self.factory = factory
self.refcount = 0
self.key = key
self.internal = threading.Lock()
self.lock = factory.klass()
def acquire(self, *argv, **kwarg):
with self.internal:
self.refcount += 1
return self.lock.acquire()
def release(self):
with self.internal:
self.refcount -= 1
if (self.refcount == 0) and (self.key != 0):
try:
del self.factory.locks[self.key]
except KeyError:
pass
return self.lock.release()
def __enter__(self):
self.acquire()
def __exit__(self, exc_type, exc_value, traceback):
self.release()
class LockFactory:
def __init__(self, klass=threading.RLock):
self.klass = klass
self.locks = {0: LockProxy(self, 0)}
def __enter__(self):
self.locks[0].acquire()
def __exit__(self, exc_type, exc_value, traceback):
self.locks[0].release()
def __getitem__(self, key):
if key is None:
key = 0
if key not in self.locks:
self.locks[key] = LockProxy(self, key)
return self.locks[key]
def __delitem__(self, key):
del self.locks[key]
class EngineBase:
def __init__(self, socket):
self.socket = socket
self.get_timeout = 30
self.get_timeout_exception = None
self.change_master = threading.Event()
self.read_lock = threading.Lock()
self.qsize = 0
@property
def marshal(self):
return self.socket.marshal
@property
def backlog(self):
return self.socket.backlog
@property
def backlog_lock(self):
return self.socket.backlog_lock
@property
def error_deque(self):
return self.socket.error_deque
@property
def lock(self):
return self.socket.lock
@property
def buffer_queue(self):
return self.socket.buffer_queue
@property
def epid(self):
return self.socket.epid
@property
def target(self):
return self.socket.target
@property
def callbacks(self):
return self.socket.callbacks
class EngineThreadSafe(EngineBase):
'''
    Thread-safe engine for netlink sockets. It buffers all
    incoming messages regardless of their sequence numbers and
    returns only the messages with the requested numbers. This
    is done using synchronization primitives in a quite
    complicated manner.
'''
def put(
self,
msg,
msg_type,
msg_flags=NLM_F_REQUEST,
addr=(0, 0),
msg_seq=0,
msg_pid=None,
):
'''
Construct a message from a dictionary and send it to
the socket. Parameters:
- msg -- the message in the dictionary format
- msg_type -- the message type
- msg_flags -- the message flags to use in the request
- addr -- `sendto()` addr, default `(0, 0)`
- msg_seq -- sequence number to use
- msg_pid -- pid to use, if `None` -- use os.getpid()
Example::
s = IPRSocket()
s.bind()
s.put({'index': 1}, RTM_GETLINK)
s.get()
s.close()
        Please notice that the return value of `s.get()` may not
        be the result of `s.put()`, but any broadcast message.
        To fix that, use `msg_seq` -- the response must contain
        the same `msg['header']['sequence_number']` value.
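
        A sketch with an explicit sequence number (the value 255
        is arbitrary, any currently unused number works)::

            seq = 255
            s.put({'index': 1}, RTM_GETLINK, msg_seq=seq)
            for msg in s.get(msg_seq=seq):
                print(msg['header']['sequence_number'])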
'''
if msg_seq != 0:
self.lock[msg_seq].acquire()
try:
if msg_seq not in self.backlog:
self.backlog[msg_seq] = []
if not isinstance(msg, nlmsg):
msg_class = self.marshal.msg_map[msg_type]
msg = msg_class(msg)
if msg_pid is None:
msg_pid = self.epid or os.getpid()
msg['header']['type'] = msg_type
msg['header']['flags'] = msg_flags
msg['header']['sequence_number'] = msg_seq
msg['header']['pid'] = msg_pid
self.socket.sendto_gate(msg, addr)
except:
raise
finally:
if msg_seq != 0:
self.lock[msg_seq].release()
def get(
self,
bufsize=DEFAULT_RCVBUF,
msg_seq=0,
terminate=None,
callback=None,
noraise=False,
):
'''
Get parsed messages list. If `msg_seq` is given, return
only messages with that `msg['header']['sequence_number']`,
saving all other messages into `self.backlog`.
The routine is thread-safe.
The `bufsize` parameter can be:
- -1: bufsize will be calculated from the first 4 bytes of
the network data
- 0: bufsize will be calculated from SO_RCVBUF sockopt
        - int > 0: just a bufsize
If `noraise` is true, error messages will be treated as any
other message.
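
        A sketch, letting the library take the buffer size from
        SO_RCVBUF while reading the responses for a known
        sequence number::

            for msg in s.get(bufsize=0, msg_seq=seq):
                print(msg['header']['type'])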
'''
ctime = time.time()
with self.lock[msg_seq]:
            if bufsize == -1:
                # get bufsize from the network data
                bufsize = struct.unpack("I", self.socket.recv(4, MSG_PEEK))[0]
            elif bufsize == 0:
                # get bufsize from SO_RCVBUF
                bufsize = self.socket.getsockopt(SOL_SOCKET, SO_RCVBUF) // 2
tmsg = None
enough = False
backlog_acquired = False
try:
while not enough:
# 8<-----------------------------------------------------------
#
# This stage changes the backlog, so use mutex to
# prevent side changes
self.backlog_lock.acquire()
backlog_acquired = True
##
# Stage 1. BEGIN
#
# 8<-----------------------------------------------------------
#
# Check backlog and return already collected
# messages.
#
if msg_seq == -1 and any(self.backlog.values()):
for seq, backlog in self.backlog.items():
if backlog:
for msg in backlog:
yield msg
self.backlog[seq] = []
enough = True
break
elif msg_seq == 0 and self.backlog[0]:
# Zero queue.
#
# Load the backlog, if there is valid
# content in it
for msg in self.backlog[0]:
yield msg
self.backlog[0] = []
# And just exit
break
elif msg_seq > 0 and len(self.backlog.get(msg_seq, [])):
# Any other msg_seq.
#
# Collect messages up to the terminator.
# Terminator conditions:
# * NLMSG_ERROR != 0
# * NLMSG_DONE
# * terminate() function (if defined)
# * not NLM_F_MULTI
#
                        # Please note that if the terminator has
                        # not occurred, more `recv()` rounds CAN
                        # be required.
for msg in tuple(self.backlog[msg_seq]):
# Drop the message from the backlog, if any
self.backlog[msg_seq].remove(msg)
# If there is an error, raise exception
if (
msg['header']['error'] is not None
and not noraise
):
# reschedule all the remaining messages,
# including errors and acks, into a
# separate deque
self.error_deque.extend(self.backlog[msg_seq])
# flush the backlog for this msg_seq
del self.backlog[msg_seq]
# The loop is done
raise msg['header']['error']
# If it is the terminator message, say "enough"
# and requeue all the rest into Zero queue
if terminate is not None:
tmsg = terminate(msg)
if isinstance(tmsg, nlmsg):
yield msg
if (msg['header']['type'] == NLMSG_DONE) or tmsg:
# The loop is done
enough = True
# If it is just a normal message, append it to
# the response
if not enough:
# finish the loop on single messages
if not msg['header']['flags'] & NLM_F_MULTI:
enough = True
yield msg
# Enough is enough, requeue the rest and delete
# our backlog
if enough:
self.backlog[0].extend(self.backlog[msg_seq])
del self.backlog[msg_seq]
break
# Next iteration
self.backlog_lock.release()
backlog_acquired = False
else:
# Stage 1. END
#
# 8<-------------------------------------------------------
#
# Stage 2. BEGIN
#
# 8<-------------------------------------------------------
#
# Receive the data from the socket and put the messages
# into the backlog
#
self.backlog_lock.release()
backlog_acquired = False
##
#
# Control the timeout. We should not be within the
# function more than TIMEOUT seconds. All the locks
# MUST be released here.
#
if (msg_seq != 0) and (
time.time() - ctime > self.get_timeout
):
# requeue already received for that msg_seq
self.backlog[0].extend(self.backlog[msg_seq])
del self.backlog[msg_seq]
# throw an exception
if self.get_timeout_exception:
raise self.get_timeout_exception()
else:
return
#
if self.read_lock.acquire(False):
try:
self.change_master.clear()
# If the socket is free to read from, occupy
# it and wait for the data
#
# This is a time consuming process, so all the
# locks, except the read lock must be released
data = self.socket.recv(bufsize)
# Parse data
msgs = tuple(
self.socket.marshal.parse(
data, msg_seq, callback
)
)
# Reset ctime -- timeout should be measured
# for every turn separately
ctime = time.time()
#
current = self.buffer_queue.qsize()
delta = current - self.qsize
delay = 0
if delta > 10:
delay = min(
3, max(0.01, float(current) / 60000)
)
message = (
"Packet burst: "
"delta=%s qsize=%s delay=%s"
% (delta, current, delay)
)
if delay < 1:
log.debug(message)
else:
log.warning(message)
time.sleep(delay)
self.qsize = current
# We've got the data, lock the backlog again
with self.backlog_lock:
for msg in msgs:
msg['header']['target'] = self.target
msg['header']['stats'] = Stats(
current, delta, delay
)
seq = msg['header']['sequence_number']
if seq not in self.backlog:
if (
msg['header']['type']
== NLMSG_ERROR
):
# Drop orphaned NLMSG_ERROR
# messages
continue
seq = 0
# 8<-----------------------------------
# Callbacks section
for cr in self.callbacks:
try:
if cr[0](msg):
cr[1](msg, *cr[2])
except:
# FIXME
#
# Usually such code formatting
# means that the method should
# be refactored to avoid such
# indentation.
#
# Plz do something with it.
#
lw = log.warning
lw("Callback fail: %s" % (cr))
lw(traceback.format_exc())
# 8<-----------------------------------
self.backlog[seq].append(msg)
# Now wake up other threads
self.change_master.set()
finally:
# Finally, release the read lock: all data
# processed
self.read_lock.release()
else:
# If the socket is occupied and there is still no
# data for us, wait for the next master change or
# for a timeout
self.change_master.wait(1)
# 8<-------------------------------------------------------
#
# Stage 2. END
#
# 8<-------------------------------------------------------
finally:
if backlog_acquired:
self.backlog_lock.release()
class EngineThreadUnsafe(EngineBase):
'''
    Thread-unsafe engine. Does not implement any locks on
    message processing and discards any message whose sequence
    number does not match the requested one.
'''
def put(
self,
msg,
msg_type,
msg_flags=NLM_F_REQUEST,
addr=(0, 0),
msg_seq=0,
msg_pid=None,
):
if not isinstance(msg, nlmsg):
msg_class = self.marshal.msg_map[msg_type]
msg = msg_class(msg)
if msg_pid is None:
msg_pid = self.epid or os.getpid()
msg['header']['type'] = msg_type
msg['header']['flags'] = msg_flags
msg['header']['sequence_number'] = msg_seq
msg['header']['pid'] = msg_pid
        self.socket.sendto_gate(msg, addr)
def get(
self,
bufsize=DEFAULT_RCVBUF,
msg_seq=0,
terminate=None,
callback=None,
noraise=False,
):
        if bufsize == -1:
            # get bufsize from the network data
            bufsize = struct.unpack("I", self.socket.recv(4, MSG_PEEK))[0]
        elif bufsize == 0:
            # get bufsize from SO_RCVBUF
            bufsize = self.socket.getsockopt(SOL_SOCKET, SO_RCVBUF) // 2
enough = False
while not enough:
            data = self.socket.recv(bufsize)
*messages, last = tuple(
self.marshal.parse(data, msg_seq, callback)
)
for msg in messages:
msg['header']['target'] = self.target
msg['header']['stats'] = Stats(0, 0, 0)
yield msg
if last['header']['type'] == NLMSG_DONE:
break
if (
(msg_seq == 0)
or (not last['header']['flags'] & NLM_F_MULTI)
or (callable(terminate) and terminate(last))
):
enough = True
yield last
class NetlinkSocketBase:
'''
Generic netlink socket.
'''
input_from_buffer_queue = False
def __init__(
self,
family=NETLINK_GENERIC,
port=None,
pid=None,
fileno=None,
sndbuf=1048576,
rcvbuf=1048576,
all_ns=False,
async_qsize=None,
nlm_generator=None,
target='localhost',
ext_ack=False,
strict_check=False,
groups=0,
nlm_echo=False,
):
# 8<-----------------------------------------
self.config = {
'family': family,
'port': port,
'pid': pid,
'fileno': fileno,
'sndbuf': sndbuf,
'rcvbuf': rcvbuf,
'all_ns': all_ns,
'async_qsize': async_qsize,
'target': target,
'nlm_generator': nlm_generator,
'ext_ack': ext_ack,
'strict_check': strict_check,
'groups': groups,
'nlm_echo': nlm_echo,
}
# 8<-----------------------------------------
self.addr_pool = AddrPool(minaddr=0x000000FF, maxaddr=0x0000FFFF)
self.epid = None
self.port = 0
self.fixed = True
self.family = family
self._fileno = fileno
self._sndbuf = sndbuf
self._rcvbuf = rcvbuf
self._use_peek = True
self.backlog = {0: []}
self.error_deque = collections.deque(maxlen=1000)
self.callbacks = [] # [(predicate, callback, args), ...]
self.buffer_thread = None
self.closed = False
self.compiled = None
self.uname = config.uname
self.target = target
self.groups = groups
self.capabilities = {
'create_bridge': config.kernel > [3, 2, 0],
'create_bond': config.kernel > [3, 2, 0],
'create_dummy': True,
'provide_master': config.kernel[0] > 2,
}
self.backlog_lock = threading.Lock()
self.sys_lock = threading.RLock()
self.lock = LockFactory()
self._sock = None
self._ctrl_read, self._ctrl_write = os.pipe()
if async_qsize is None:
async_qsize = config.async_qsize
self.async_qsize = async_qsize
if nlm_generator is None:
nlm_generator = config.nlm_generator
self.nlm_generator = nlm_generator
self.buffer_queue = Queue(maxsize=async_qsize)
self.log = []
self.all_ns = all_ns
self.ext_ack = ext_ack
self.strict_check = strict_check
if pid is None:
self.pid = os.getpid() & 0x3FFFFF
self.port = port
self.fixed = self.port is not None
elif pid == 0:
self.pid = os.getpid()
else:
self.pid = pid
# 8<-----------------------------------------
self.marshal = Marshal()
# 8<-----------------------------------------
if not nlm_generator:
def nlm_request(*argv, **kwarg):
return tuple(self._genlm_request(*argv, **kwarg))
def get(*argv, **kwarg):
return tuple(self._genlm_get(*argv, **kwarg))
self._genlm_request = self.nlm_request
self._genlm_get = self.get
self.nlm_request = nlm_request
self.get = get
def nlm_request_batch(*argv, **kwarg):
return tuple(self._genlm_request_batch(*argv, **kwarg))
self._genlm_request_batch = self.nlm_request_batch
self.nlm_request_batch = nlm_request_batch
# Set defaults
self.post_init()
self.engine = EngineThreadSafe(self)
def post_init(self):
pass
def clone(self):
return type(self)(**self.config)
def put(
self,
msg,
msg_type,
msg_flags=NLM_F_REQUEST,
addr=(0, 0),
msg_seq=0,
msg_pid=None,
):
return self.engine.put(
msg, msg_type, msg_flags, addr, msg_seq, msg_pid
)
def get(
self,
bufsize=DEFAULT_RCVBUF,
msg_seq=0,
terminate=None,
callback=None,
noraise=False,
):
return self.engine.get(bufsize, msg_seq, terminate, callback, noraise)
def close(self, code=errno.ECONNRESET):
if code > 0 and self.input_from_buffer_queue:
self.buffer_queue.put(
struct.pack('IHHQIQQ', 28, 2, 0, 0, code, 0, 0)
)
try:
os.close(self._ctrl_write)
os.close(self._ctrl_read)
except OSError:
# ignore the case when it is closed already
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
def release(self):
warnings.warn('deprecated, use close() instead', DeprecationWarning)
self.close()
def register_callback(self, callback, predicate=lambda x: True, args=None):
'''
Register a callback to run on a message arrival.
        Callback is the function that will be called with the
        message as the first argument. Predicate is an optional
        callable that returns True or False; the callback is
        invoked only when the predicate returns True. Args is a
        list or tuple of extra arguments to pass to the callback.
        The simplest example, assuming `ipr` is an IPRoute()
        instance::
# create a simplest callback that will print messages
def cb(msg):
print(msg)
# register callback for any message:
ipr.register_callback(cb)
More complex example, with filtering::
# Set object's attribute after the message key
def cb(msg, obj):
obj.some_attr = msg["some key"]
# Register the callback only for the loopback device, index 1:
ipr.register_callback(cb,
lambda x: x.get('index', None) == 1,
(self, ))
Please note: you do **not** need to register the default 0 queue
to invoke callbacks on broadcast messages. Callbacks are
iterated **before** messages get enqueued.
'''
if args is None:
args = []
self.callbacks.append((predicate, callback, args))
def unregister_callback(self, callback):
'''
Remove the first reference to the function from the callback
register
'''
cb = tuple(self.callbacks)
for cr in cb:
if cr[1] == callback:
self.callbacks.pop(cb.index(cr))
return
def register_policy(self, policy, msg_class=None):
'''
Register netlink encoding/decoding policy. Can
be specified in two ways:
`nlsocket.register_policy(MSG_ID, msg_class)`
to register one particular rule, or
`nlsocket.register_policy({MSG_ID1: msg_class})`
to register several rules at once.
E.g.::
policy = {RTM_NEWLINK: ifinfmsg,
RTM_DELLINK: ifinfmsg,
RTM_NEWADDR: ifaddrmsg,
RTM_DELADDR: ifaddrmsg}
nlsocket.register_policy(policy)
        One can call `register_policy()` as many times as one
        wants to -- it will just extend the current policy
        scheme, not replace it.
'''
if isinstance(policy, int) and msg_class is not None:
policy = {policy: msg_class}
if not isinstance(policy, dict):
raise TypeError('wrong policy type')
for key in policy:
self.marshal.msg_map[key] = policy[key]
return self.marshal.msg_map
def unregister_policy(self, policy):
'''
Unregister policy. Policy can be:
- int -- then it will just remove one policy
- list or tuple of ints -- remove all given
- dict -- remove policies by keys from dict
        In the last case the routine ignores the dict values;
        it is implemented this way just to be compatible with
        the `get_policy_map()` return value.
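
        E.g., using the same message types as in the
        `register_policy()` example above::

            nlsocket.unregister_policy(RTM_DELLINK)
            nlsocket.unregister_policy([RTM_NEWADDR, RTM_DELADDR])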
'''
if isinstance(policy, int):
policy = [policy]
elif isinstance(policy, dict):
policy = list(policy)
if not isinstance(policy, (tuple, list, set)):
raise TypeError('wrong policy type')
for key in policy:
del self.marshal.msg_map[key]
return self.marshal.msg_map
def get_policy_map(self, policy=None):
'''
        Return the policy for a given message type or for all
        message types. The policy parameter can be either an int
        or a list of ints. Always returns a dictionary.
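
        E.g., for message types registered as in the
        `register_policy()` example::

            # the full mapping
            full_map = nlsocket.get_policy_map()
            # only the selected message types
            subset = nlsocket.get_policy_map([RTM_NEWLINK, RTM_NEWADDR])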
'''
if policy is None:
return self.marshal.msg_map
if isinstance(policy, int):
policy = [policy]
if not isinstance(policy, (list, tuple, set)):
raise TypeError('wrong policy type')
ret = {}
for key in policy:
ret[key] = self.marshal.msg_map[key]
return ret
def _peek_bufsize(self, socket_descriptor):
data = bytearray()
try:
bufsize, _ = socket_descriptor.recvfrom_into(
data, 0, MSG_DONTWAIT | MSG_PEEK | MSG_TRUNC
)
except BlockingIOError:
self._use_peek = False
bufsize = socket_descriptor.getsockopt(SOL_SOCKET, SO_RCVBUF) // 2
return bufsize
def sendto(self, *argv, **kwarg):
return self._sendto(*argv, **kwarg)
def recv(self, bufsize, flags=0):
if self.input_from_buffer_queue:
data_in = self.buffer_queue.get()
if isinstance(data_in, Exception):
raise data_in
return data_in
return self._sock.recv(
self._peek_bufsize(self._sock) if self._use_peek else bufsize,
flags,
)
def recv_into(self, data, *argv, **kwarg):
if self.input_from_buffer_queue:
data_in = self.buffer_queue.get()
            if isinstance(data_in, Exception):
raise data_in
data[:] = data_in
return len(data_in)
return self._sock.recv_into(data, *argv, **kwarg)
def buffer_thread_routine(self):
poll = select.poll()
poll.register(self._sock, select.POLLIN | select.POLLPRI)
poll.register(self._ctrl_read, select.POLLIN | select.POLLPRI)
sockfd = self._sock.fileno()
while True:
events = poll.poll()
for fd, event in events:
if fd == sockfd:
try:
data = bytearray(64000)
self._sock.recv_into(data, 64000)
self.buffer_queue.put_nowait(data)
except Exception as e:
self.buffer_queue.put(e)
return
else:
return
def compile(self):
return CompileContext(self)
def _send_batch(self, msgs, addr=(0, 0)):
with self.backlog_lock:
for msg in msgs:
self.backlog[msg['header']['sequence_number']] = []
# We have locked the message locks in the caller already.
data = bytearray()
for msg in msgs:
if not isinstance(msg, nlmsg):
msg_class = self.marshal.msg_map[msg['header']['type']]
msg = msg_class(msg)
msg.reset()
msg.encode()
data += msg.data
if self.compiled is not None:
return self.compiled.append(data)
self._sock.sendto(data, addr)
def sendto_gate(self, msg, addr):
msg.reset()
msg.encode()
if self.compiled is not None:
return self.compiled.append(msg.data)
return self._sock.sendto(msg.data, addr)
def nlm_request_batch(self, msgs, noraise=False):
"""
This function is for messages which are expected to have side effects.
Do not blindly retry in case of errors as this might duplicate them.
"""
expected_responses = []
acquired = 0
seqs = self.addr_pool.alloc_multi(len(msgs))
try:
for seq in seqs:
self.lock[seq].acquire()
acquired += 1
for seq, msg in zip(seqs, msgs):
msg['header']['sequence_number'] = seq
if 'pid' not in msg['header']:
msg['header']['pid'] = self.epid or os.getpid()
if (msg['header']['flags'] & NLM_F_ACK) or (
msg['header']['flags'] & NLM_F_DUMP
):
expected_responses.append(seq)
self._send_batch(msgs)
if self.compiled is not None:
for data in self.compiled:
yield data
else:
for seq in expected_responses:
for msg in self.get(msg_seq=seq, noraise=noraise):
if msg['header']['flags'] & NLM_F_DUMP_INTR:
# Leave error handling to the caller
raise NetlinkDumpInterrupted()
yield msg
finally:
# Release locks in reverse order.
for seq in seqs[acquired - 1 :: -1]:
self.lock[seq].release()
with self.backlog_lock:
for seq in seqs:
# Clear the backlog. We may have raised an error
# causing the backlog to not be consumed entirely.
if seq in self.backlog:
del self.backlog[seq]
self.addr_pool.free(seq, ban=0xFF)
def nlm_request(
self,
msg,
msg_type,
msg_flags=NLM_F_REQUEST | NLM_F_DUMP,
terminate=None,
callback=None,
parser=None,
):
msg_seq = self.addr_pool.alloc()
defer = None
if callable(parser):
self.marshal.seq_map[msg_seq] = parser
with self.lock[msg_seq]:
retry_count = 0
try:
while True:
try:
self.put(msg, msg_type, msg_flags, msg_seq=msg_seq)
if self.compiled is not None:
for data in self.compiled:
yield data
else:
for msg in self.get(
msg_seq=msg_seq,
terminate=terminate,
callback=callback,
):
# analyze the response for effects to be
# deferred
if (
defer is None
and msg['header']['flags']
& NLM_F_DUMP_INTR
):
defer = NetlinkDumpInterrupted()
yield msg
break
except NetlinkError as e:
if e.code != errno.EBUSY:
raise
if retry_count >= 30:
raise
log.warning('Error 16, retry {}.'.format(retry_count))
time.sleep(0.3)
retry_count += 1
continue
except Exception:
raise
finally:
# Ban this msg_seq for 0xff rounds
#
                # It's a long story. Modern kernels, for RTM_SET.*
                # operations, always return NLMSG_ERROR(0) == success,
                # while not setting the NLM_F_MULTI flag on the other
                # response messages and thus sending no NLMSG_DONE. So
                # how to detect the end of the response? One can not
                # rely on NLMSG_ERROR on old kernels, but we have to
                # support them too. Thus we just ban the msg_seq for
                # several rounds, and an NLMSG_ERROR received later
                # becomes orphaned and is just dropped.
#
# Hack, but true.
self.addr_pool.free(msg_seq, ban=0xFF)
if msg_seq in self.marshal.seq_map:
self.marshal.seq_map.pop(msg_seq)
if defer is not None:
raise defer
class BatchAddrPool:
def alloc(self, *argv, **kwarg):
return 0
def free(self, *argv, **kwarg):
pass
class BatchBacklogQueue(list):
def append(self, *argv, **kwarg):
pass
def pop(self, *argv, **kwarg):
pass
class BatchBacklog(dict):
def __getitem__(self, key):
return BatchBacklogQueue()
def __setitem__(self, key, value):
pass
def __delitem__(self, key):
pass
class BatchSocket(NetlinkSocketBase):
def post_init(self):
self.backlog = BatchBacklog()
self.addr_pool = BatchAddrPool()
self._sock = None
self.reset()
def reset(self):
self.batch = bytearray()
def nlm_request(
self,
msg,
msg_type,
msg_flags=NLM_F_REQUEST | NLM_F_DUMP,
terminate=None,
callback=None,
):
msg_seq = self.addr_pool.alloc()
msg_pid = self.epid or os.getpid()
msg['header']['type'] = msg_type
msg['header']['flags'] = msg_flags
msg['header']['sequence_number'] = msg_seq
msg['header']['pid'] = msg_pid
msg.data = self.batch
msg.offset = len(self.batch)
msg.encode()
return []
def get(self, *argv, **kwarg):
pass
class NetlinkSocket(NetlinkSocketBase):
def post_init(self):
# recreate the underlying socket
with self.sys_lock:
if self._sock is not None:
self._sock.close()
self._sock = config.SocketBase(
AF_NETLINK, SOCK_DGRAM, self.family, self._fileno
)
self.setsockopt(SOL_SOCKET, SO_SNDBUF, self._sndbuf)
self.setsockopt(SOL_SOCKET, SO_RCVBUF, self._rcvbuf)
if self.ext_ack:
self.setsockopt(SOL_NETLINK, NETLINK_EXT_ACK, 1)
if self.all_ns:
self.setsockopt(SOL_NETLINK, NETLINK_LISTEN_ALL_NSID, 1)
if self.strict_check:
self.setsockopt(SOL_NETLINK, NETLINK_GET_STRICT_CHK, 1)
def __getattr__(self, attr):
if attr in (
'getsockname',
'getsockopt',
'makefile',
'setsockopt',
'setblocking',
'settimeout',
'gettimeout',
'shutdown',
'recvfrom',
'recvfrom_into',
'fileno',
):
return getattr(self._sock, attr)
elif attr in ('_sendto', '_recv', '_recv_into'):
return getattr(self._sock, attr.lstrip("_"))
raise AttributeError(attr)
def bind(self, groups=0, pid=None, **kwarg):
'''
        Bind the socket to the given multicast groups, using
        the given pid.

        - If pid is None, use automatic port allocation
        - If pid == 0, use the process pid
        - If pid == <int>, use the value as the pid
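
        A sketch of the three modes (only one `bind()` call is
        allowed per socket, the calls below are alternatives)::

            nlsock.bind()            # automatic port allocation
            nlsock.bind(pid=0)       # use the process pid
            nlsock.bind(pid=0x2F00)  # an arbitrary explicit value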
'''
if pid is not None:
self.port = 0
self.fixed = True
self.pid = pid or os.getpid()
if 'async' in kwarg:
# FIXME
# raise deprecation error after 0.5.3
#
log.warning(
'use "async_cache" instead of "async", '
'"async" is a keyword from Python 3.7'
)
async_cache = kwarg.get('async_cache') or kwarg.get('async')
self.groups = groups
# if we have pre-defined port, use it strictly
if self.fixed:
self.epid = self.pid + (self.port << 22)
self._sock.bind((self.epid, self.groups))
else:
for port in range(1024):
try:
self.port = port
self.epid = self.pid + (self.port << 22)
self._sock.bind((self.epid, self.groups))
break
except Exception:
# create a new underlying socket -- on kernel 4
# one failed bind() makes the socket useless
self.post_init()
else:
raise KeyError('no free address available')
# all is OK till now, so start async recv, if we need
if async_cache:
self.buffer_thread = threading.Thread(
name="Netlink async cache", target=self.buffer_thread_routine
)
self.input_from_buffer_queue = True
self.buffer_thread.daemon = True
self.buffer_thread.start()
def add_membership(self, group):
self.setsockopt(SOL_NETLINK, NETLINK_ADD_MEMBERSHIP, group)
def drop_membership(self, group):
self.setsockopt(SOL_NETLINK, NETLINK_DROP_MEMBERSHIP, group)
def close(self, code=errno.ECONNRESET):
'''
Correctly close the socket and free all resources.
'''
with self.sys_lock:
if self.closed:
return
self.closed = True
if self.buffer_thread:
os.write(self._ctrl_write, b'exit')
self.buffer_thread.join()
super(NetlinkSocket, self).close(code=code)
# Common shutdown procedure
self._sock.close()
class ChaoticNetlinkSocket(NetlinkSocket):
success_rate = 1
def __init__(self, *argv, **kwarg):
self.success_rate = kwarg.pop('success_rate', 0.7)
super(ChaoticNetlinkSocket, self).__init__(*argv, **kwarg)
def get(self, *argv, **kwarg):
if random.random() > self.success_rate:
raise ChaoticException()
return super(ChaoticNetlinkSocket, self).get(*argv, **kwarg)