Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[hostcfgd] Move hostcfgd back to ConfigDBConnector for subscribing to updates #10168

Merged
merged 13 commits into from
Apr 7, 2022
Merged
175 changes: 85 additions & 90 deletions src/sonic-host-services/scripts/hostcfgd
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ import signal

import jinja2
from sonic_py_common import device_info
from swsscommon.swsscommon import SubscriberStateTable, DBConnector, Select
from swsscommon.swsscommon import ConfigDBConnector, TableConsumable, Table
from swsscommon.swsscommon import ConfigDBConnector, DBConnector, Table

# FILE
PAM_AUTH_CONF = "/etc/pam.d/common-auth-sonic"
Expand Down Expand Up @@ -207,21 +206,23 @@ class FeatureHandler(object):
else:
self.resync_feature_state(self._cached_config[feature_name])

def sync_state_field(self):
def sync_state_field(self, feature_table):
"""
Summary:
Updates the state field in the FEATURE|* tables as the state field
might have to be rendered based on DEVICE_METADATA table
"""
feature_table = self._config_db.get_table('FEATURE')
for feature_name in feature_table.keys():
if not feature_name:
syslog.syslog(syslog.LOG_WARNING, "Feature is None")
continue

feature = Feature(feature_name, feature_table[feature_name], self._device_config)
if not feature.compare_state(feature_name, feature_table.get(feature_name, {})):
self.resync_feature_state(feature)

self._cached_config.setdefault(feature_name, feature)
self.update_feature_auto_restart(feature, feature_name)
self.update_feature_state(feature)
self.resync_feature_state(feature)

def update_feature_state(self, feature):
cached_feature = self._cached_config[feature.name]
Expand Down Expand Up @@ -406,6 +407,10 @@ class Iptables(object):
'''
return (isinstance(key, tuple))

def load(self, lpbk_table):
for row in lpbk_table:
self.iptables_handler(row, lpbk_table[row])

def command(self, chain, ip, ver, op):
cmd = 'iptables' if ver == '4' else 'ip6tables'
cmd += ' -t mangle --{} {} -p tcp --tcp-flags SYN SYN'.format(op, chain)
Expand Down Expand Up @@ -890,15 +895,13 @@ class KdumpCfg(object):
memory = self.kdump_defaults["memory"]
if data.get("memory") is not None:
memory = data.get("memory")
if data.get("memory") is not None:
run_cmd("sonic-kdump-config --memory " + memory)
run_cmd("sonic-kdump-config --memory " + memory)

# Num dumps
num_dumps = self.kdump_defaults["num_dumps"]
if data.get("num_dumps") is not None:
num_dumps = data.get("num_dumps")
if data.get("num_dumps") is not None:
run_cmd("sonic-kdump-config --num_dumps " + num_dumps)
run_cmd("sonic-kdump-config --num_dumps " + num_dumps)

class NtpCfg(object):
"""
Expand All @@ -912,6 +915,15 @@ class NtpCfg(object):
self.ntp_global = {}
self.ntp_servers = set()

def load(self, ntp_global_conf, ntp_server_conf):
syslog.syslog(syslog.LOG_INFO, "NtpCfg load ...")

for row in ntp_global_conf:
self.ntp_global_update(row, ntp_global_conf[row], is_load=True)

# Force reload on init
self.ntp_server_update(0, None, is_load=True)

def handle_ntp_source_intf_chg(self, intf_name):
# if no ntp server configured, do nothing
if not self.ntp_servers:
Expand All @@ -925,7 +937,7 @@ class NtpCfg(object):
cmd = 'systemctl restart ntp-config'
run_cmd(cmd)

def ntp_global_update(self, key, data):
def ntp_global_update(self, key, data, is_load=False):
syslog.syslog(syslog.LOG_INFO, 'NTP GLOBAL Update')
orig_src = self.ntp_global.get('src_intf', '')
orig_src_set = set(orig_src.split(";"))
Expand All @@ -938,6 +950,9 @@ class NtpCfg(object):
# Update the Local Cache
self.ntp_global = data

# If initial load don't restart daemon
if is_load: return

# check if ntp server configured, if not, do nothing
if not self.ntp_servers:
syslog.syslog(syslog.LOG_INFO, "No ntp server when global config change, do nothing")
Expand All @@ -954,16 +969,19 @@ class NtpCfg(object):
cmd = 'service ntp restart'
run_cmd(cmd)

def ntp_server_update(self, key, op):
def ntp_server_update(self, key, op, is_load=False):
syslog.syslog(syslog.LOG_INFO, 'ntp server update key {}'.format(key))

restart_config = False
if op == "SET" and key not in self.ntp_servers:
restart_config = True
self.ntp_servers.add(key)
elif op == "DEL" and key in self.ntp_servers:
if not is_load:
if op == "SET" and key not in self.ntp_servers:
restart_config = True
self.ntp_servers.add(key)
elif op == "DEL" and key in self.ntp_servers:
restart_config = True
self.ntp_servers.remove(key)
else:
restart_config = True
self.ntp_servers.remove(key)

if restart_config:
cmd = 'systemctl restart ntp-config'
Expand Down Expand Up @@ -1034,31 +1052,24 @@ class HostConfigDaemon:
# before moving forward
self.config_db = ConfigDBConnector()
self.config_db.connect(wait_for_init=True, retry_on=True)
self.dbconn = DBConnector(CFG_DB, 0)
self.state_db_conn = DBConnector(STATE_DB, 0)
self.selector = Select()
syslog.syslog(syslog.LOG_INFO, 'ConfigDB connect success')

self.select = Select()
self.callbacks = dict()
self.subscriber_map = dict()

feature_state_table = Table(self.state_db_conn, 'FEATURE')

# Load DEVICE metadata configurations
self.device_config = {}
self.device_config['DEVICE_METADATA'] = self.config_db.get_table('DEVICE_METADATA')

# Load feature state table
self.state_db_conn = DBConnector(STATE_DB, 0)
feature_state_table = Table(self.state_db_conn, 'FEATURE')

# Initialize KDump Config and set the config to default if nothing is provided
self.kdumpCfg = KdumpCfg(self.config_db)
self.kdumpCfg.load(self.config_db.get_table('KDUMP'))

# Initialize IpTables
self.iptables = Iptables()

# Intialize Feature Handler
self.feature_handler = FeatureHandler(self.config_db, feature_state_table, self.device_config)
self.feature_handler.sync_state_field()

# Initialize Ntp Config Handler
self.ntpcfg = NtpCfg()
Expand All @@ -1073,21 +1084,28 @@ class HostConfigDaemon:
self.pamLimitsCfg = PamLimitsCfg(self.config_db)
self.pamLimitsCfg.update_config_file()

def load(self):
aaa = self.config_db.get_table('AAA')
tacacs_global = self.config_db.get_table('TACPLUS')
tacacs_server = self.config_db.get_table('TACPLUS_SERVER')
radius_global = self.config_db.get_table('RADIUS')
radius_server = self.config_db.get_table('RADIUS_SERVER')
def load(self, init_data):
features = init_data['FEATURE']
aaa = init_data['AAA']
tacacs_global = init_data['TACPLUS']
tacacs_server = init_data['TACPLUS_SERVER']
radius_global = init_data['RADIUS']
radius_server = init_data['RADIUS_SERVER']
lpbk_table = init_data['LOOPBACK_INTERFACE']
ntp_server = init_data['NTP_SERVER']
ntp_global = init_data['NTP']
kdump = init_data['KDUMP']

self.feature_handler.sync_state_field(features)
self.aaacfg.load(aaa, tacacs_global, tacacs_server, radius_global, radius_server)
self.iptables.load(lpbk_table)
self.ntpcfg.load(ntp_global, ntp_server)
self.kdumpCfg.load(kdump)

try:
dev_meta = self.config_db.get_table('DEVICE_METADATA')
if 'localhost' in dev_meta:
if 'hostname' in dev_meta['localhost']:
self.hostname_cache = dev_meta['localhost']['hostname']
except Exception as e:
pass
dev_meta = self.config_db.get_table('DEVICE_METADATA')
if 'localhost' in dev_meta:
if 'hostname' in dev_meta['localhost']:
self.hostname_cache = dev_meta['localhost']['hostname']

# Update AAA with the hostname
self.aaacfg.hostname_update(self.hostname_cache)
Expand Down Expand Up @@ -1181,68 +1199,46 @@ class HostConfigDaemon:
systemctl_cmd = "sudo systemctl is-system-running --wait --quiet"
subprocess.call(systemctl_cmd, shell=True)

def subscribe(self, table, callback, pri):
try:
if table not in self.callbacks:
self.callbacks[table] = []
subscriber = SubscriberStateTable(self.dbconn, table, TableConsumable.DEFAULT_POP_BATCH_SIZE, pri)
self.selector.addSelectable(subscriber) # Add to the Selector
self.subscriber_map[subscriber.getFd()] = (subscriber, table) # Maintain a mapping b/w subscriber & fd
def register_callbacks(self):

self.callbacks[table].append(callback)
except Exception as err:
syslog.syslog(syslog.LOG_ERR, "Subscribe to table {} failed with error {}".format(table, err))
def make_callback(func):
def callback(table, key, data):
if data is None:
op = "DEL"
else:
op = "SET"
return func(key, op, data)
return callback

def register_callbacks(self):
self.subscribe('KDUMP', lambda table, key, op, data: self.kdump_handler(key, op, data), HOSTCFGD_MAX_PRI)
self.config_db.subscribe('KDUMP', make_callback(self.kdump_handler))
# Handle FEATURE updates before other tables
self.subscribe('FEATURE', lambda table, key, op, data: self.feature_handler.handle(key, op, data), HOSTCFGD_MAX_PRI-1)
self.config_db.subscribe('FEATURE', make_callback(self.feature_handler.handle))
# Handle AAA, TACACS and RADIUS related tables
self.subscribe('AAA', lambda table, key, op, data: self.aaa_handler(key, op, data), HOSTCFGD_MAX_PRI-2)
self.subscribe('TACPLUS', lambda table, key, op, data: self.tacacs_global_handler(key, op, data), HOSTCFGD_MAX_PRI-2)
self.subscribe('TACPLUS_SERVER', lambda table, key, op, data: self.tacacs_server_handler(key, op, data), HOSTCFGD_MAX_PRI-2)
self.subscribe('RADIUS', lambda table, key, op, data: self.radius_global_handler(key, op, data), HOSTCFGD_MAX_PRI-2)
self.subscribe('RADIUS_SERVER', lambda table, key, op, data: self.radius_server_handler(key, op, data), HOSTCFGD_MAX_PRI-2)
self.config_db.subscribe('AAA', make_callback(self.aaa_handler))
self.config_db.subscribe('TACPLUS', make_callback(self.tacacs_global_handler))
self.config_db.subscribe('TACPLUS_SERVER', make_callback(self.tacacs_server_handler))
self.config_db.subscribe('RADIUS', make_callback(self.radius_global_handler))
self.config_db.subscribe('RADIUS_SERVER', make_callback(self.radius_server_handler))
# Handle IPTables configuration
self.subscribe('LOOPBACK_INTERFACE', lambda table, key, op, data: self.lpbk_handler(key, op, data), HOSTCFGD_MAX_PRI-3)
self.config_db.subscribe('LOOPBACK_INTERFACE', make_callback(self.lpbk_handler))
# Handle NTP & NTP_SERVER updates
self.subscribe('NTP', lambda table, key, op, data: self.ntp_global_handler(key, op, data), HOSTCFGD_MAX_PRI-4)
self.subscribe('NTP_SERVER', lambda table, key, op, data: self.ntp_server_handler(key, op, data), HOSTCFGD_MAX_PRI-4)
self.config_db.subscribe('NTP', make_callback(self.ntp_global_handler))
self.config_db.subscribe('NTP_SERVER', make_callback(self.ntp_server_handler))
# Handle updates to src intf changes in radius
self.subscribe('MGMT_INTERFACE', lambda table, key, op, data: self.mgmt_intf_handler(key, op, data), HOSTCFGD_MAX_PRI-5)
self.subscribe('VLAN_INTERFACE', lambda table, key, op, data: self.vlan_intf_handler(key, op, data), HOSTCFGD_MAX_PRI-5)
self.subscribe('VLAN_SUB_INTERFACE', lambda table, key, op, data: self.vlan_sub_intf_handler(key, op, data), HOSTCFGD_MAX_PRI-5)
self.subscribe('PORTCHANNEL_INTERFACE', lambda table, key, op, data: self.portchannel_intf_handler(key, op, data), HOSTCFGD_MAX_PRI-5)
self.subscribe('INTERFACE', lambda table, key, op, data: self.phy_intf_handler(key, op, data), HOSTCFGD_MAX_PRI-5)

self.config_db.subscribe('MGMT_INTERFACE', make_callback(self.mgmt_intf_handler))
self.config_db.subscribe('VLAN_INTERFACE', make_callback(self.vlan_intf_handler))
self.config_db.subscribe('VLAN_SUB_INTERFACE', make_callback(self.vlan_sub_intf_handler))
self.config_db.subscribe('PORTCHANNEL_INTERFACE', make_callback(self.portchannel_intf_handler))
self.config_db.subscribe('INTERFACE', make_callback(self.phy_intf_handler))
syslog.syslog(syslog.LOG_INFO,
"Waiting for systemctl to finish initialization")
self.wait_till_system_init_done()
syslog.syslog(syslog.LOG_INFO,
"systemctl has finished initialization -- proceeding ...")

def start(self):
while True:
state, selectable_ = self.selector.select(DEFAULT_SELECT_TIMEOUT)
if state == self.selector.TIMEOUT:
continue
elif state == self.selector.ERROR:
syslog.syslog(syslog.LOG_ERR,
"error returned by select")
continue

fd = selectable_.getFd()
# Get the Corresponding subscriber & table
subscriber, table = self.subscriber_map.get(fd, (None, ""))
if not subscriber:
syslog.syslog(syslog.LOG_ERR,
"No Subscriber object found for fd: {}, subscriber map: {}".format(fd, subscriber_map))
continue
key, op, fvs = subscriber.pop()
# Get the registered callback
cbs = self.callbacks.get(table, None)
for callback in cbs:
callback(table, key, op, dict(fvs))
self.config_db.listen(init_data_handler=self.load)


def main():
Expand All @@ -1251,7 +1247,6 @@ def main():
signal.signal(signal.SIGHUP, signal_handler)
daemon = HostConfigDaemon()
daemon.register_callbacks()
daemon.load()
daemon.start()

if __name__ == "__main__":
Expand Down
74 changes: 7 additions & 67 deletions src/sonic-host-services/tests/common/mock_configdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ class MockConfigDb(object):
"""
STATE_DB = None
CONFIG_DB = None
event_queue = []

def __init__(self, **kwargs):
pass
self.handlers = {}

@staticmethod
def set_config_db(test_config_db):
Expand Down Expand Up @@ -44,73 +45,12 @@ def set_entry(self, key, field, data):
def get_table(self, table_name):
return MockConfigDb.CONFIG_DB[table_name]

def subscribe(self, table_name, callback):
self.handlers[table_name] = callback

class MockSelect():

event_queue = []

@staticmethod
def set_event_queue(Q):
MockSelect.event_queue = Q

@staticmethod
def get_event_queue():
return MockSelect.event_queue

@staticmethod
def reset_event_queue():
MockSelect.event_queue = []

def __init__(self):
self.sub_map = {}
self.TIMEOUT = "TIMEOUT"
self.ERROR = "ERROR"

def addSelectable(self, subscriber):
self.sub_map[subscriber.table] = subscriber

def select(self, TIMEOUT):
if not MockSelect.get_event_queue():
raise TimeoutError
table, key = MockSelect.get_event_queue().pop(0)
self.sub_map[table].nextKey(key)
return "OBJECT", self.sub_map[table]


class MockSubscriberStateTable():

FD_INIT = 0

@staticmethod
def generate_fd():
curr = MockSubscriberStateTable.FD_INIT
MockSubscriberStateTable.FD_INIT = curr + 1
return curr

@staticmethod
def reset_fd():
MockSubscriberStateTable.FD_INIT = 0

def __init__(self, conn, table, pop, pri):
self.fd = MockSubscriberStateTable.generate_fd()
self.next_key = ''
self.table = table

def getFd(self):
return self.fd

def nextKey(self, key):
self.next_key = key

def pop(self):
table = MockConfigDb.CONFIG_DB.get(self.table, {})
if self.next_key not in table:
op = "DEL"
fvs = {}
else:
op = "SET"
fvs = table.get(self.next_key, {})
return self.next_key, op, fvs
def listen(self, init_data_handler=None):
for e in MockConfigDb.event_queue:
self.handlers[e[0]](e[0], e[1], self.get_entry(e[0], e[1]))


class MockDBConnector():
Expand Down
Loading