mirror of https://github.com/F-Stack/f-stack.git
487 lines
14 KiB
Python
487 lines
14 KiB
Python
|
#!/usr/bin/env python3
|
||
|
# SPDX-License-Identifier: BSD-3-Clause
|
||
|
# Copyright (c) 2014 6WIND S.A.
|
||
|
# Copyright (c) 2023 Robin Jarry
|
||
|
|
||
|
"""
|
||
|
Craft IP{v6}/{TCP/UDP} traffic flows that will evenly spread over a given
|
||
|
number of RX queues according to the RSS algorithm.
|
||
|
"""
|
||
|
|
||
|
import argparse
|
||
|
import binascii
|
||
|
import ctypes
|
||
|
import ipaddress
|
||
|
import json
|
||
|
import struct
|
||
|
import typing
|
||
|
|
||
|
|
||
|
Address = typing.Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
|
||
|
Network = typing.Union[ipaddress.IPv4Network, ipaddress.IPv6Network]
|
||
|
PortList = typing.Iterable[int]
|
||
|
|
||
|
|
||
|
class Packet:
|
||
|
def __init__(self, ip_src: Address, ip_dst: Address, l4_sport: int, l4_dport: int):
|
||
|
self.ip_src = ip_src
|
||
|
self.ip_dst = ip_dst
|
||
|
self.l4_sport = l4_sport
|
||
|
self.l4_dport = l4_dport
|
||
|
|
||
|
def reverse(self):
|
||
|
return Packet(
|
||
|
ip_src=self.ip_dst,
|
||
|
l4_sport=self.l4_dport,
|
||
|
ip_dst=self.ip_src,
|
||
|
l4_dport=self.l4_sport,
|
||
|
)
|
||
|
|
||
|
def hash_data(self, use_l4_port: bool = False) -> bytes:
|
||
|
data = self.ip_src.packed + self.ip_dst.packed
|
||
|
if use_l4_port:
|
||
|
data += struct.pack(">H", self.l4_sport)
|
||
|
data += struct.pack(">H", self.l4_dport)
|
||
|
return data
|
||
|
|
||
|
|
||
|
class TrafficTemplate:
|
||
|
def __init__(
|
||
|
self,
|
||
|
ip_src: Network,
|
||
|
ip_dst: Network,
|
||
|
l4_sport_range: PortList,
|
||
|
l4_dport_range: PortList,
|
||
|
):
|
||
|
self.ip_src = ip_src
|
||
|
self.ip_dst = ip_dst
|
||
|
self.l4_sport_range = l4_sport_range
|
||
|
self.l4_dport_range = l4_dport_range
|
||
|
|
||
|
def __iter__(self) -> typing.Iterator[Packet]:
|
||
|
for ip_src in self.ip_src.hosts():
|
||
|
for ip_dst in self.ip_dst.hosts():
|
||
|
if ip_src == ip_dst:
|
||
|
continue
|
||
|
for sport in self.l4_sport_range:
|
||
|
for dport in self.l4_dport_range:
|
||
|
yield Packet(ip_src, ip_dst, sport, dport)
|
||
|
|
||
|
|
||
|
class RSSAlgo:
|
||
|
def __init__(
|
||
|
self,
|
||
|
queues_count: int,
|
||
|
key: bytes,
|
||
|
reta_size: int,
|
||
|
use_l4_port: bool,
|
||
|
):
|
||
|
self.queues_count = queues_count
|
||
|
self.reta = tuple(i % queues_count for i in range(reta_size))
|
||
|
self.key = key
|
||
|
self.use_l4_port = use_l4_port
|
||
|
|
||
|
def toeplitz_hash(self, data: bytes) -> int:
|
||
|
# see rte_softrss_* in lib/hash/rte_thash.h
|
||
|
hash_value = ctypes.c_uint32(0)
|
||
|
|
||
|
for i, byte in enumerate(data):
|
||
|
for j in range(8):
|
||
|
bit = (byte >> (7 - j)) & 0x01
|
||
|
|
||
|
if bit == 1:
|
||
|
keyword = ctypes.c_uint32(0)
|
||
|
keyword.value |= self.key[i] << 24
|
||
|
keyword.value |= self.key[i + 1] << 16
|
||
|
keyword.value |= self.key[i + 2] << 8
|
||
|
keyword.value |= self.key[i + 3]
|
||
|
|
||
|
if j > 0:
|
||
|
keyword.value <<= j
|
||
|
keyword.value |= self.key[i + 4] >> (8 - j)
|
||
|
|
||
|
hash_value.value ^= keyword.value
|
||
|
|
||
|
return hash_value.value
|
||
|
|
||
|
def get_queue_index(self, packet: Packet) -> int:
|
||
|
bytes_to_hash = packet.hash_data(self.use_l4_port)
|
||
|
|
||
|
# get the 32bit hash of the packet
|
||
|
hash_value = self.toeplitz_hash(bytes_to_hash)
|
||
|
|
||
|
# determine the offset in the redirection table
|
||
|
offset = hash_value & (len(self.reta) - 1)
|
||
|
|
||
|
return self.reta[offset]
|
||
|
|
||
|
|
||
|
def balanced_traffic(
|
||
|
algo: RSSAlgo,
|
||
|
traffic_template: TrafficTemplate,
|
||
|
check_reverse_traffic: bool = False,
|
||
|
all_flows: bool = False,
|
||
|
) -> typing.Iterator[typing.Tuple[int, int, Packet]]:
|
||
|
queues = set()
|
||
|
if check_reverse_traffic:
|
||
|
queues_reverse = set()
|
||
|
|
||
|
for pkt in traffic_template:
|
||
|
q = algo.get_queue_index(pkt)
|
||
|
|
||
|
# check if q is already filled
|
||
|
if not all_flows and q in queues:
|
||
|
continue
|
||
|
|
||
|
qr = algo.get_queue_index(pkt.reverse())
|
||
|
|
||
|
if check_reverse_traffic:
|
||
|
# check if q is already filled
|
||
|
if not all_flows and qr in queues_reverse:
|
||
|
continue
|
||
|
# mark this queue as matched
|
||
|
queues_reverse.add(qr)
|
||
|
|
||
|
# mark this queue as filled
|
||
|
queues.add(q)
|
||
|
|
||
|
yield q, qr, pkt
|
||
|
|
||
|
# stop when all queues have been filled
|
||
|
if not all_flows and len(queues) == algo.queues_count:
|
||
|
break
|
||
|
|
||
|
|
||
|
NO_PORT = (0,)
|
||
|
|
||
|
|
||
|
class DriverInfo:
|
||
|
def __init__(self, key: bytes = None, reta_size: int = None):
|
||
|
self.__key = key
|
||
|
self.__reta_size = reta_size
|
||
|
|
||
|
def rss_key(self) -> bytes:
|
||
|
return self.__key
|
||
|
|
||
|
def reta_size(self, num_queues: int) -> int:
|
||
|
return self.__reta_size
|
||
|
|
||
|
|
||
|
class MlxDriverInfo(DriverInfo):
|
||
|
def rss_key(self) -> bytes:
|
||
|
return bytes(
|
||
|
(
|
||
|
# fmt: off
|
||
|
# rss_hash_default_key, see drivers/net/mlx5/mlx5_rxq.c
|
||
|
0x2c, 0xc6, 0x81, 0xd1, 0x5b, 0xdb, 0xf4, 0xf7,
|
||
|
0xfc, 0xa2, 0x83, 0x19, 0xdb, 0x1a, 0x3e, 0x94,
|
||
|
0x6b, 0x9e, 0x38, 0xd9, 0x2c, 0x9c, 0x03, 0xd1,
|
||
|
0xad, 0x99, 0x44, 0xa7, 0xd9, 0x56, 0x3d, 0x59,
|
||
|
0x06, 0x3c, 0x25, 0xf3, 0xfc, 0x1f, 0xdc, 0x2a,
|
||
|
# fmt: on
|
||
|
)
|
||
|
)
|
||
|
|
||
|
def reta_size(self, num_queues: int) -> int:
|
||
|
if num_queues & (num_queues - 1) == 0:
|
||
|
# If the requested number of RX queues is power of two,
|
||
|
# use a table of this size.
|
||
|
return num_queues
|
||
|
# otherwise, use the maximum table size
|
||
|
return 512
|
||
|
|
||
|
|
||
|
DEFAULT_DRIVERS = {
|
||
|
"cnxk": DriverInfo(
|
||
|
key=bytes(
|
||
|
(
|
||
|
# fmt: off
|
||
|
# roc_nix_rss_key_default_fill, see drivers/common/cnxk/roc_nix_rss.c
|
||
|
# Marvell cnxk NICs take 48 bytes keys
|
||
|
0xfe, 0xed, 0x0b, 0xad, 0xfe, 0xed, 0x0b, 0xad,
|
||
|
0xfe, 0xed, 0x0b, 0xad, 0xfe, 0xed, 0x0b, 0xad,
|
||
|
0xfe, 0xed, 0x0b, 0xad, 0xfe, 0xed, 0x0b, 0xad,
|
||
|
0xfe, 0xed, 0x0b, 0xad, 0xfe, 0xed, 0x0b, 0xad,
|
||
|
0xfe, 0xed, 0x0b, 0xad, 0xfe, 0xed, 0x0b, 0xad,
|
||
|
0xfe, 0xed, 0x0b, 0xad, 0xfe, 0xed, 0x0b, 0xad,
|
||
|
# fmt: on
|
||
|
)
|
||
|
),
|
||
|
reta_size=64,
|
||
|
),
|
||
|
"intel": DriverInfo(
|
||
|
key=bytes(
|
||
|
(
|
||
|
# fmt: off
|
||
|
# rss_intel_key, see drivers/net/ixgbe/ixgbe_rxtx.c
|
||
|
0x6d, 0x5a, 0x56, 0xda, 0x25, 0x5b, 0x0e, 0xc2,
|
||
|
0x41, 0x67, 0x25, 0x3d, 0x43, 0xa3, 0x8f, 0xb0,
|
||
|
0xd0, 0xca, 0x2b, 0xcb, 0xae, 0x7b, 0x30, 0xb4,
|
||
|
0x77, 0xcb, 0x2d, 0xa3, 0x80, 0x30, 0xf2, 0x0c,
|
||
|
0x6a, 0x42, 0xb7, 0x3b, 0xbe, 0xac, 0x01, 0xfa,
|
||
|
# fmt: on
|
||
|
)
|
||
|
),
|
||
|
reta_size=128,
|
||
|
),
|
||
|
"i40e": DriverInfo(
|
||
|
key=bytes(
|
||
|
(
|
||
|
# fmt: off
|
||
|
# rss_key_default, see drivers/net/i40e/i40e_ethdev.c
|
||
|
# i40e is the only driver that takes 52 bytes keys
|
||
|
0x44, 0x39, 0x79, 0x6b, 0xb5, 0x4c, 0x50, 0x23,
|
||
|
0xb6, 0x75, 0xea, 0x5b, 0x12, 0x4f, 0x9f, 0x30,
|
||
|
0xb8, 0xa2, 0xc0, 0x3d, 0xdf, 0xdc, 0x4d, 0x02,
|
||
|
0xa0, 0x8c, 0x9b, 0x33, 0x4a, 0xf6, 0x4a, 0x4c,
|
||
|
0x05, 0xc6, 0xfa, 0x34, 0x39, 0x58, 0xd8, 0x55,
|
||
|
0x7d, 0x99, 0x58, 0x3a, 0xe1, 0x38, 0xc9, 0x2e,
|
||
|
0x81, 0x15, 0x03, 0x66,
|
||
|
# fmt: on
|
||
|
)
|
||
|
),
|
||
|
reta_size=512,
|
||
|
),
|
||
|
"mlx": MlxDriverInfo(),
|
||
|
}
|
||
|
|
||
|
|
||
|
def port_range(value):
|
||
|
try:
|
||
|
if "-" in value:
|
||
|
start, stop = value.split("-")
|
||
|
res = tuple(range(int(start), int(stop)))
|
||
|
else:
|
||
|
res = (int(value),)
|
||
|
return res or NO_PORT
|
||
|
except ValueError as e:
|
||
|
raise argparse.ArgumentTypeError(str(e)) from e
|
||
|
|
||
|
|
||
|
def positive_int(value):
|
||
|
try:
|
||
|
i = int(value)
|
||
|
if i <= 0:
|
||
|
raise argparse.ArgumentTypeError("must be strictly positive")
|
||
|
return i
|
||
|
except ValueError as e:
|
||
|
raise argparse.ArgumentTypeError(str(e)) from e
|
||
|
|
||
|
|
||
|
def power_of_two(value):
|
||
|
i = positive_int(value)
|
||
|
if i & (i - 1) != 0:
|
||
|
raise argparse.ArgumentTypeError("must be a power of two")
|
||
|
return i
|
||
|
|
||
|
|
||
|
def parse_args():
|
||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||
|
|
||
|
parser.add_argument(
|
||
|
"rx_queues",
|
||
|
metavar="RX_QUEUES",
|
||
|
type=positive_int,
|
||
|
help="""
|
||
|
The number of RX queues to fill.
|
||
|
""",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"ip_src",
|
||
|
metavar="SRC",
|
||
|
type=ipaddress.ip_network,
|
||
|
help="""
|
||
|
The source IP network/address.
|
||
|
""",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"ip_dst",
|
||
|
metavar="DST",
|
||
|
type=ipaddress.ip_network,
|
||
|
help="""
|
||
|
The destination IP network/address.
|
||
|
""",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"-s",
|
||
|
"--sport-range",
|
||
|
type=port_range,
|
||
|
default=NO_PORT,
|
||
|
help="""
|
||
|
The layer 4 (TCP/UDP) source port range.
|
||
|
Can be a single fixed value or a range <start>-<end>.
|
||
|
""",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"-d",
|
||
|
"--dport-range",
|
||
|
type=port_range,
|
||
|
default=NO_PORT,
|
||
|
help="""
|
||
|
The layer 4 (TCP/UDP) destination port range.
|
||
|
Can be a single fixed value or a range <start>-<end>.
|
||
|
""",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"-r",
|
||
|
"--check-reverse-traffic",
|
||
|
action="store_true",
|
||
|
help="""
|
||
|
The reversed traffic (source <-> dest) should also be evenly balanced
|
||
|
in the queues.
|
||
|
""",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"-k",
|
||
|
"--rss-key",
|
||
|
default="intel",
|
||
|
help=f"""
|
||
|
The random key used to compute the RSS hash. This option
|
||
|
supports either a well-known name or the hex value of the key
|
||
|
(well-known names: {', '.join(DEFAULT_DRIVERS)}, default: intel).
|
||
|
""",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"-t",
|
||
|
"--reta-size",
|
||
|
type=power_of_two,
|
||
|
help="""
|
||
|
Size of the redirection table or "RETA" (default: depends on driver if
|
||
|
using a well-known driver name, otherwise 128).
|
||
|
""",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"-a",
|
||
|
"--all-flows",
|
||
|
action="store_true",
|
||
|
help="""
|
||
|
Output ALL flows that can be created based on source and destination
|
||
|
address/port ranges along their matched queue number. ATTENTION: this
|
||
|
option can produce very long outputs depending on the address and port
|
||
|
range sizes.
|
||
|
""",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"-j",
|
||
|
"--json",
|
||
|
action="store_true",
|
||
|
help="""
|
||
|
Output in parseable JSON format.
|
||
|
""",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"-i",
|
||
|
"--info",
|
||
|
action="store_true",
|
||
|
help="""
|
||
|
Print RETA size and RSS key above the results. Not available with --json.
|
||
|
""",
|
||
|
)
|
||
|
|
||
|
args = parser.parse_args()
|
||
|
|
||
|
if args.ip_src.version != args.ip_dst.version:
|
||
|
parser.error(
|
||
|
f"{args.ip_src} and {args.ip_dst} don't have the same protocol version"
|
||
|
)
|
||
|
|
||
|
if args.json and args.info:
|
||
|
parser.error("--json and --info are mutually exclusive")
|
||
|
|
||
|
if args.rss_key in DEFAULT_DRIVERS:
|
||
|
driver_info = DEFAULT_DRIVERS[args.rss_key]
|
||
|
else:
|
||
|
try:
|
||
|
key = binascii.unhexlify(args.rss_key)
|
||
|
except (TypeError, ValueError) as e:
|
||
|
parser.error(f"RSS_KEY: {e}")
|
||
|
driver_info = DriverInfo(key=key, reta_size=128)
|
||
|
|
||
|
if args.reta_size is None:
|
||
|
args.reta_size = driver_info.reta_size(args.rx_queues)
|
||
|
|
||
|
if args.reta_size < args.rx_queues:
|
||
|
parser.error("RETA_SIZE must be greater than or equal to RX_QUEUES")
|
||
|
|
||
|
args.rss_key = driver_info.rss_key()
|
||
|
|
||
|
return args
|
||
|
|
||
|
|
||
|
def main():
|
||
|
args = parse_args()
|
||
|
use_l4_port = args.sport_range != NO_PORT or args.dport_range != NO_PORT
|
||
|
|
||
|
algo = RSSAlgo(
|
||
|
queues_count=args.rx_queues,
|
||
|
key=args.rss_key,
|
||
|
reta_size=args.reta_size,
|
||
|
use_l4_port=use_l4_port,
|
||
|
)
|
||
|
template = TrafficTemplate(
|
||
|
args.ip_src,
|
||
|
args.ip_dst,
|
||
|
args.sport_range,
|
||
|
args.dport_range,
|
||
|
)
|
||
|
|
||
|
results = balanced_traffic(
|
||
|
algo, template, args.check_reverse_traffic, args.all_flows
|
||
|
)
|
||
|
|
||
|
if args.json:
|
||
|
flows = []
|
||
|
for q, qr, pkt in results:
|
||
|
flows.append(
|
||
|
{
|
||
|
"queue": q,
|
||
|
"queue_reverse": qr,
|
||
|
"src_ip": str(pkt.ip_src),
|
||
|
"dst_ip": str(pkt.ip_dst),
|
||
|
"src_port": pkt.l4_sport,
|
||
|
"dst_port": pkt.l4_dport,
|
||
|
}
|
||
|
)
|
||
|
print(json.dumps(flows, indent=2))
|
||
|
return
|
||
|
|
||
|
if use_l4_port:
|
||
|
header = ["SRC_IP", "SPORT", "DST_IP", "DPORT", "QUEUE"]
|
||
|
else:
|
||
|
header = ["SRC_IP", "DST_IP", "QUEUE"]
|
||
|
if args.check_reverse_traffic:
|
||
|
header.append("QUEUE_REVERSE")
|
||
|
|
||
|
rows = [tuple(header)]
|
||
|
widths = [len(h) for h in header]
|
||
|
|
||
|
for q, qr, pkt in results:
|
||
|
if use_l4_port:
|
||
|
row = [pkt.ip_src, pkt.l4_sport, pkt.ip_dst, pkt.l4_dport, q]
|
||
|
else:
|
||
|
row = [pkt.ip_src, pkt.ip_dst, q]
|
||
|
if args.check_reverse_traffic:
|
||
|
row.append(qr)
|
||
|
cells = []
|
||
|
for i, r in enumerate(row):
|
||
|
r = str(r)
|
||
|
if len(r) > widths[i]:
|
||
|
widths[i] = len(r)
|
||
|
cells.append(r)
|
||
|
rows.append(tuple(cells))
|
||
|
|
||
|
if args.info:
|
||
|
print(f"RSS key: {binascii.hexlify(args.rss_key).decode()}")
|
||
|
print(f"RETA size: {args.reta_size}")
|
||
|
print()
|
||
|
|
||
|
fmt = [f"%-{w}s" for w in widths]
|
||
|
fmt[-1] = "%s" # avoid trailing whitespace
|
||
|
fmt = " ".join(fmt)
|
||
|
for row in rows:
|
||
|
print(fmt % row)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
|
||
|
main()
|