f-stack/dpdk/usertools/dpdk-rss-flows.py

487 lines
14 KiB
Python
Raw Normal View History

2025-01-10 11:50:43 +00:00
#!/usr/bin/env python3
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2014 6WIND S.A.
# Copyright (c) 2023 Robin Jarry
"""
Craft IP{v6}/{TCP/UDP} traffic flows that will evenly spread over a given
number of RX queues according to the RSS algorithm.
"""
import argparse
import binascii
import ctypes
import ipaddress
import json
import struct
import typing
# A single host address of either IP version.
Address = typing.Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
# A network of either IP version.
Network = typing.Union[ipaddress.IPv4Network, ipaddress.IPv6Network]
# Any iterable of layer 4 (TCP/UDP) port numbers.
PortList = typing.Iterable[int]
class Packet:
    """
    One traffic flow, described by its source/destination IP addresses and
    its source/destination layer 4 ports.
    """

    def __init__(self, ip_src: Address, ip_dst: Address, l4_sport: int, l4_dport: int):
        self.ip_src, self.ip_dst = ip_src, ip_dst
        self.l4_sport, self.l4_dport = l4_sport, l4_dport

    def reverse(self):
        """
        Return a new Packet for the opposite direction: addresses swapped
        and ports swapped.
        """
        return Packet(self.ip_dst, self.ip_src, self.l4_dport, self.l4_sport)

    def hash_data(self, use_l4_port: bool = False) -> bytes:
        """
        Build the byte string that is fed to the RSS hash.

        The result is the packed source address followed by the packed
        destination address. When ``use_l4_port`` is true, the source and
        destination ports are appended as big-endian unsigned 16-bit values.
        """
        addresses = self.ip_src.packed + self.ip_dst.packed
        if not use_l4_port:
            return addresses
        return addresses + struct.pack(">HH", self.l4_sport, self.l4_dport)
class TrafficTemplate:
    """
    Description of candidate traffic: a source network, a destination
    network and the layer 4 port values to combine with them.

    Iterating over a template yields one Packet per combination of source
    host, destination host, source port and destination port.
    """

    def __init__(
        self,
        ip_src: Network,
        ip_dst: Network,
        l4_sport_range: PortList,
        l4_dport_range: PortList,
    ):
        self.ip_src, self.ip_dst = ip_src, ip_dst
        self.l4_sport_range, self.l4_dport_range = l4_sport_range, l4_dport_range

    def __iter__(self) -> typing.Iterator[Packet]:
        for src in self.ip_src.hosts():
            # a flow from a host to itself is never generated
            destinations = (dst for dst in self.ip_dst.hosts() if dst != src)
            for dst in destinations:
                yield from (
                    Packet(src, dst, sport, dport)
                    for sport in self.l4_sport_range
                    for dport in self.l4_dport_range
                )
class RSSAlgo:
    """
    Software model of RSS queue selection: a Toeplitz hash of the packet
    fields followed by a lookup in a redirection table (RETA).
    """

    def __init__(
        self,
        queues_count: int,
        key: bytes,
        reta_size: int,
        use_l4_port: bool,
    ):
        self.key = key
        self.use_l4_port = use_l4_port
        self.queues_count = queues_count
        # the redirection table cycles through the queue indexes in order
        self.reta = tuple(slot % queues_count for slot in range(reta_size))

    def toeplitz_hash(self, data: bytes) -> int:
        """
        Return the 32-bit Toeplitz hash of ``data`` using ``self.key``.

        For every bit set in ``data``, the 32-bit window of the key that
        starts at the same bit offset is XORed into the result. The key is
        indexed up to 4 bytes past the current data byte, so a key that is
        too short raises IndexError.
        """
        # see rte_softrss_* in lib/hash/rte_thash.h
        key = self.key
        result = 0
        for pos, octet in enumerate(data):
            for shift in range(8):
                # bits are examined from the most significant one down
                if not octet & (0x80 >> shift):
                    continue
                window = (
                    (key[pos] << 24)
                    | (key[pos + 1] << 16)
                    | (key[pos + 2] << 8)
                    | key[pos + 3]
                )
                if shift:
                    # slide the window by `shift` bits, keeping 32 bits and
                    # pulling the missing low bits from the next key byte
                    window = (window << shift) & 0xFFFFFFFF
                    window |= key[pos + 4] >> (8 - shift)
                result ^= window
        return result

    def get_queue_index(self, packet: Packet) -> int:
        """
        Return the queue index that the redirection table assigns to
        ``packet``.
        """
        # get the 32bit hash of the packet
        digest = self.toeplitz_hash(packet.hash_data(self.use_l4_port))
        # determine the offset in the redirection table; the mask only acts
        # as a modulo when the table size is a power of two
        mask = len(self.reta) - 1
        return self.reta[digest & mask]
def balanced_traffic(
    algo: RSSAlgo,
    traffic_template: TrafficTemplate,
    check_reverse_traffic: bool = False,
    all_flows: bool = False,
) -> typing.Iterator[typing.Tuple[int, int, Packet]]:
    """
    Yield ``(queue, reverse_queue, packet)`` for packets of the template.

    Unless ``all_flows`` is set, a packet is skipped when its queue has
    already been produced, and iteration stops as soon as
    ``algo.queues_count`` distinct queues have been produced. When
    ``check_reverse_traffic`` is also set, a packet is additionally skipped
    when the queue of its reversed flow has already been produced.

    With ``all_flows`` set, every packet of the template is yielded.
    """
    filled = set()
    filled_reverse = set()
    skip_duplicates = not all_flows

    for packet in traffic_template:
        queue = algo.get_queue_index(packet)
        # check if the queue is already filled
        if skip_duplicates and queue in filled:
            continue

        reverse_queue = algo.get_queue_index(packet.reverse())
        if check_reverse_traffic:
            # check if the reverse queue is already filled
            if skip_duplicates and reverse_queue in filled_reverse:
                continue
            # mark the reverse queue as matched
            filled_reverse.add(reverse_queue)

        # mark this queue as filled
        filled.add(queue)
        yield queue, reverse_queue, packet

        # stop when all queues have been filled
        if skip_duplicates and len(filled) == algo.queues_count:
            return
# Port list meaning "no layer 4 port given": it is the default of the
# --sport-range/--dport-range options and the fallback of port_range() for
# an empty range. main() only hashes ports when a range differs from it.
NO_PORT = (0,)
class DriverInfo:
    """
    RSS parameters of a driver: its hash key and its redirection table size.

    Both values default to None; subclasses may override the accessors to
    compute them instead of storing them.
    """

    def __init__(
        self,
        key: typing.Optional[bytes] = None,
        reta_size: typing.Optional[int] = None,
    ):
        self.__key = key
        self.__reta_size = reta_size

    def rss_key(self) -> typing.Optional[bytes]:
        """Return the RSS key given at construction (None if not set)."""
        return self.__key

    def reta_size(self, num_queues: int) -> typing.Optional[int]:
        """
        Return the redirection table size given at construction (None if
        not set).

        ``num_queues`` is not used here; it is part of the signature so that
        subclasses can derive the size from the number of RX queues.
        """
        return self.__reta_size
class MlxDriverInfo(DriverInfo):
    """
    Driver information for "mlx": a fixed 40 bytes key and a redirection
    table whose size depends on the number of RX queues.
    """

    def rss_key(self) -> bytes:
        """Return the default RSS key."""
        # rss_hash_default_key, see drivers/net/mlx5/mlx5_rxq.c
        return bytes.fromhex(
            "2c c6 81 d1 5b db f4 f7 "
            "fc a2 83 19 db 1a 3e 94 "
            "6b 9e 38 d9 2c 9c 03 d1 "
            "ad 99 44 a7 d9 56 3d 59 "
            "06 3c 25 f3 fc 1f dc 2a"
        )

    def reta_size(self, num_queues: int) -> int:
        """Return the redirection table size for ``num_queues`` RX queues."""
        is_power_of_two = num_queues & (num_queues - 1) == 0
        # If the requested number of RX queues is power of two, use a table
        # of this size. Otherwise, use the maximum table size.
        return num_queues if is_power_of_two else 512
# Well-known RSS configurations, selectable by name with --rss-key.
DEFAULT_DRIVERS = {
    "cnxk": DriverInfo(
        # roc_nix_rss_key_default_fill, see drivers/common/cnxk/roc_nix_rss.c
        # Marvell cnxk NICs take 48 bytes keys
        key=bytes.fromhex("feed0bad") * 12,
        reta_size=64,
    ),
    "intel": DriverInfo(
        # rss_intel_key, see drivers/net/ixgbe/ixgbe_rxtx.c
        key=bytes.fromhex(
            "6d 5a 56 da 25 5b 0e c2 "
            "41 67 25 3d 43 a3 8f b0 "
            "d0 ca 2b cb ae 7b 30 b4 "
            "77 cb 2d a3 80 30 f2 0c "
            "6a 42 b7 3b be ac 01 fa"
        ),
        reta_size=128,
    ),
    "i40e": DriverInfo(
        # rss_key_default, see drivers/net/i40e/i40e_ethdev.c
        # i40e is the only driver that takes 52 bytes keys
        key=bytes.fromhex(
            "44 39 79 6b b5 4c 50 23 "
            "b6 75 ea 5b 12 4f 9f 30 "
            "b8 a2 c0 3d df dc 4d 02 "
            "a0 8c 9b 33 4a f6 4a 4c "
            "05 c6 fa 34 39 58 d8 55 "
            "7d 99 58 3a e1 38 c9 2e "
            "81 15 03 66"
        ),
        reta_size=512,
    ),
    "mlx": MlxDriverInfo(),
}
def port_range(value):
    """
    argparse type for a layer 4 port specification.

    ``value`` is either a single port ("80") or a range "<start>-<end>".
    The range follows Python ``range()`` semantics: ``start`` is included
    and ``end`` is excluded. An empty range (``start >= end``) falls back
    to ``NO_PORT``.

    Returns a tuple of port numbers.

    Raises argparse.ArgumentTypeError when the value cannot be parsed or
    when it would produce a port outside 0-65535. Ports are later packed as
    unsigned 16-bit integers, so larger values cannot be hashed.
    """
    max_port = 0xFFFF
    try:
        if "-" in value:
            start, stop = value.split("-")
            start, stop = int(start), int(stop)
        else:
            start = int(value)
            stop = start + 1
    except ValueError as e:
        raise argparse.ArgumentTypeError(str(e)) from e

    # Validate before building the tuple so that an oversized range is
    # rejected without allocating it.
    if start < stop and (start < 0 or stop - 1 > max_port):
        raise argparse.ArgumentTypeError(
            f"ports must be between 0 and {max_port}: {value!r}"
        )

    res = tuple(range(start, stop))
    return res or NO_PORT
def positive_int(value):
    """
    argparse type for a strictly positive integer.

    Returns the parsed integer. Raises argparse.ArgumentTypeError when
    ``value`` is not an integer literal or is not greater than zero.
    """
    try:
        number = int(value)
    except ValueError as e:
        raise argparse.ArgumentTypeError(str(e)) from e
    if number > 0:
        return number
    raise argparse.ArgumentTypeError("must be strictly positive")
def power_of_two(value):
    """
    argparse type for a strictly positive integer that is a power of two.

    Returns the parsed integer. Raises argparse.ArgumentTypeError otherwise.
    """
    number = positive_int(value)
    # a power of two has a single bit set, so clearing its lowest set bit
    # leaves zero
    if number & (number - 1):
        raise argparse.ArgumentTypeError("must be a power of two")
    return number
def parse_args():
    """
    Parse and validate the command line.

    Returns the argparse namespace with two fields resolved:

    * ``rss_key`` holds the key bytes, taken from the well-known driver of
      that name or decoded from the hexadecimal value given by the user.
    * ``reta_size`` holds the redirection table size: the user value when
      given, otherwise the driver default (128 for a user-supplied key).

    Any invalid combination ends the program through ``parser.error()``:
    mixed IP versions, --json with --info, a key that is not valid hex, a
    redirection table smaller than the number of queues, or a key too short
    to hash the requested packet fields.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "rx_queues",
        metavar="RX_QUEUES",
        type=positive_int,
        help="""
        The number of RX queues to fill.
        """,
    )
    parser.add_argument(
        "ip_src",
        metavar="SRC",
        type=ipaddress.ip_network,
        help="""
        The source IP network/address.
        """,
    )
    parser.add_argument(
        "ip_dst",
        metavar="DST",
        type=ipaddress.ip_network,
        help="""
        The destination IP network/address.
        """,
    )
    parser.add_argument(
        "-s",
        "--sport-range",
        type=port_range,
        default=NO_PORT,
        help="""
        The layer 4 (TCP/UDP) source port range.
        Can be a single fixed value or a range <start>-<end>.
        """,
    )
    parser.add_argument(
        "-d",
        "--dport-range",
        type=port_range,
        default=NO_PORT,
        help="""
        The layer 4 (TCP/UDP) destination port range.
        Can be a single fixed value or a range <start>-<end>.
        """,
    )
    parser.add_argument(
        "-r",
        "--check-reverse-traffic",
        action="store_true",
        help="""
        The reversed traffic (source <-> dest) should also be evenly balanced
        in the queues.
        """,
    )
    parser.add_argument(
        "-k",
        "--rss-key",
        default="intel",
        help=f"""
        The random key used to compute the RSS hash. This option
        supports either a well-known name or the hex value of the key
        (well-known names: {', '.join(DEFAULT_DRIVERS)}, default: intel).
        """,
    )
    parser.add_argument(
        "-t",
        "--reta-size",
        type=power_of_two,
        help="""
        Size of the redirection table or "RETA" (default: depends on driver if
        using a well-known driver name, otherwise 128).
        """,
    )
    parser.add_argument(
        "-a",
        "--all-flows",
        action="store_true",
        help="""
        Output ALL flows that can be created based on source and destination
        address/port ranges along their matched queue number. ATTENTION: this
        option can produce very long outputs depending on the address and port
        range sizes.
        """,
    )
    parser.add_argument(
        "-j",
        "--json",
        action="store_true",
        help="""
        Output in parseable JSON format.
        """,
    )
    parser.add_argument(
        "-i",
        "--info",
        action="store_true",
        help="""
        Print RETA size and RSS key above the results. Not available with --json.
        """,
    )

    args = parser.parse_args()

    if args.ip_src.version != args.ip_dst.version:
        parser.error(
            f"{args.ip_src} and {args.ip_dst} don't have the same protocol version"
        )
    if args.json and args.info:
        parser.error("--json and --info are mutually exclusive")

    if args.rss_key in DEFAULT_DRIVERS:
        driver_info = DEFAULT_DRIVERS[args.rss_key]
    else:
        # not a well-known name: the value must be the key itself, in hex
        try:
            key = binascii.unhexlify(args.rss_key)
        except (TypeError, ValueError) as e:
            # parser.error() exits, so `key` is always bound below
            parser.error(f"RSS_KEY: {e}")
        driver_info = DriverInfo(key=key, reta_size=128)

    if args.reta_size is None:
        args.reta_size = driver_info.reta_size(args.rx_queues)
    if args.reta_size < args.rx_queues:
        parser.error("RETA_SIZE must be greater than or equal to RX_QUEUES")

    args.rss_key = driver_info.rss_key()

    # The hash covers both packed addresses, plus two 16-bit ports when a
    # port range was given, and it indexes the key up to 4 bytes past the
    # current data byte. Reject a key that is too short here instead of
    # failing later with an IndexError in the middle of the computation.
    hashed_len = 2 * (args.ip_src.max_prefixlen // 8)
    if args.sport_range != NO_PORT or args.dport_range != NO_PORT:
        hashed_len += 4
    min_key_len = hashed_len + 4
    if len(args.rss_key) < min_key_len:
        parser.error(
            f"RSS_KEY: key is too short ({len(args.rss_key)} bytes), "
            f"at least {min_key_len} bytes are required"
        )

    return args
def main():
    """
    Command line entry point: compute the flows and print them, either as
    a JSON list or as a table of left-aligned columns.
    """
    args = parse_args()

    # ports take part in the hash as soon as one of the two ranges was given
    no_ports = args.sport_range == NO_PORT and args.dport_range == NO_PORT
    use_l4_port = not no_ports

    algo = RSSAlgo(
        queues_count=args.rx_queues,
        key=args.rss_key,
        reta_size=args.reta_size,
        use_l4_port=use_l4_port,
    )
    template = TrafficTemplate(
        args.ip_src,
        args.ip_dst,
        args.sport_range,
        args.dport_range,
    )
    results = balanced_traffic(
        algo, template, args.check_reverse_traffic, args.all_flows
    )

    if args.json:
        flows = [
            {
                "queue": queue,
                "queue_reverse": reverse_queue,
                "src_ip": str(packet.ip_src),
                "dst_ip": str(packet.ip_dst),
                "src_port": packet.l4_sport,
                "dst_port": packet.l4_dport,
            }
            for queue, reverse_queue, packet in results
        ]
        print(json.dumps(flows, indent=2))
        return

    header = (
        ["SRC_IP", "SPORT", "DST_IP", "DPORT", "QUEUE"]
        if use_l4_port
        else ["SRC_IP", "DST_IP", "QUEUE"]
    )
    if args.check_reverse_traffic:
        header.append("QUEUE_REVERSE")

    # collect every row as strings first: column widths depend on all rows
    table = [tuple(header)]
    for queue, reverse_queue, packet in results:
        if use_l4_port:
            fields = [
                packet.ip_src,
                packet.l4_sport,
                packet.ip_dst,
                packet.l4_dport,
                queue,
            ]
        else:
            fields = [packet.ip_src, packet.ip_dst, queue]
        if args.check_reverse_traffic:
            fields.append(reverse_queue)
        table.append(tuple(str(field) for field in fields))

    widths = [
        max(len(line[column]) for line in table) for column in range(len(header))
    ]

    if args.info:
        print(f"RSS key: {binascii.hexlify(args.rss_key).decode()}")
        print(f"RETA size: {args.reta_size}")
        print()

    for line in table:
        padded = [cell.ljust(width) for cell, width in zip(line[:-1], widths)]
        # the last column is not padded to avoid trailing whitespace
        print(" ".join(padded + [line[-1]]))
# Run the tool only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()