Init: mediaserver

2023-02-08 12:13:28 +01:00
parent 848bc9739c
commit f7c23d4ba9
31914 changed files with 6175775 additions and 0 deletions


@@ -0,0 +1,884 @@
# Based on the ssh connection plugin by Michael DeHaan
#
# Copyright: (c) 2018, Pat Sharkey <psharkey@cleo.com>
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
DOCUMENTATION = '''
author:
- Pat Sharkey (@psharkey) <psharkey@cleo.com>
- HanumanthaRao MVL (@hanumantharaomvl) <hanumanth@flux7.com>
- Gaurav Ashtikar (@gau1991) <gaurav.ashtikar@flux7.com>
name: aws_ssm
short_description: execute via AWS Systems Manager
description:
- This connection plugin allows Ansible to execute tasks on an EC2 instance via the AWS SSM CLI.
requirements:
- The remote EC2 instance must be running the AWS Systems Manager Agent (SSM Agent).
- The control machine must have the AWS Session Manager plugin installed.
- The remote EC2 Linux instance must have curl installed.
options:
access_key_id:
description: The STS access key to use when connecting via session-manager.
vars:
- name: ansible_aws_ssm_access_key_id
version_added: 1.3.0
secret_access_key:
description: The STS secret key to use when connecting via session-manager.
vars:
- name: ansible_aws_ssm_secret_access_key
version_added: 1.3.0
session_token:
description: The STS session token to use when connecting via session-manager.
vars:
- name: ansible_aws_ssm_session_token
version_added: 1.3.0
instance_id:
description: The EC2 instance ID.
vars:
- name: ansible_aws_ssm_instance_id
region:
description: The region the EC2 instance is located in.
vars:
- name: ansible_aws_ssm_region
default: 'us-east-1'
bucket_name:
description: The name of the S3 bucket used for file transfers.
vars:
- name: ansible_aws_ssm_bucket_name
plugin:
description: This defines the location of the session-manager-plugin binary.
vars:
- name: ansible_aws_ssm_plugin
default: '/usr/local/bin/session-manager-plugin'
profile:
description: Sets AWS profile to use.
vars:
- name: ansible_aws_ssm_profile
version_added: 1.5.0
reconnection_retries:
description: Number of attempts to connect.
default: 3
type: integer
vars:
- name: ansible_aws_ssm_retries
ssm_timeout:
description: Connection timeout in seconds.
default: 60
type: integer
vars:
- name: ansible_aws_ssm_timeout
bucket_sse_mode:
description: Server-side encryption mode to use for uploads on the S3 bucket used for file transfer.
choices: [ 'AES256', 'aws:kms' ]
required: false
version_added: 2.2.0
vars:
- name: ansible_aws_ssm_bucket_sse_mode
bucket_sse_kms_key_id:
description: KMS key ID to use when encrypting objects using C(bucket_sse_mode=aws:kms). Ignored otherwise.
version_added: 2.2.0
vars:
- name: ansible_aws_ssm_bucket_sse_kms_key_id
ssm_document:
description: SSM document to use when connecting.
vars:
- name: ansible_aws_ssm_document
version_added: 5.2.0
s3_addressing_style:
description:
- The addressing style to use when using S3 URLs.
- When the S3 bucket isn't in the same region as the Instance,
explicitly setting the addressing style to 'virtual' may be necessary, as
this forces the use of a specific endpoint; see
U(https://repost.aws/knowledge-center/s3-http-307-response).
choices: [ 'path', 'virtual', 'auto' ]
default: 'auto'
version_added: 5.2.0
vars:
- name: ansible_aws_ssm_s3_addressing_style
'''
EXAMPLES = r'''
# Wait for SSM Agent to be available on the Instance
- name: Wait for connection to be available
vars:
ansible_connection: aws_ssm
ansible_aws_ssm_bucket_name: nameofthebucket
ansible_aws_ssm_region: us-west-2
# When the S3 bucket isn't in the same region as the Instance
# Explicitly setting the addressing style to 'virtual' may be necessary
# https://repost.aws/knowledge-center/s3-http-307-response
ansible_aws_ssm_s3_addressing_style: virtual
tasks:
- name: Wait for connection
wait_for_connection:
# Stop Spooler Process on Windows Instances
- name: Stop Spooler Service on Windows Instances
vars:
ansible_connection: aws_ssm
ansible_shell_type: powershell
ansible_aws_ssm_bucket_name: nameofthebucket
ansible_aws_ssm_region: us-east-1
tasks:
- name: Stop spooler service
win_service:
name: spooler
state: stopped
# Install a Nginx Package on Linux Instance
- name: Install a Nginx Package
vars:
ansible_connection: aws_ssm
ansible_aws_ssm_bucket_name: nameofthebucket
ansible_aws_ssm_region: us-west-2
tasks:
- name: Install a Nginx Package
yum:
name: nginx
state: present
# Create a directory in Windows Instances
- name: Create a directory in Windows Instance
vars:
ansible_connection: aws_ssm
ansible_shell_type: powershell
ansible_aws_ssm_bucket_name: nameofthebucket
ansible_aws_ssm_region: us-east-1
tasks:
- name: Create a Directory
win_file:
path: C:\Windows\temp
state: directory
# Making use of Dynamic Inventory Plugin
# =======================================
# aws_ec2.yml (Dynamic Inventory - Linux)
# This will return the Instance IDs matching the filter
#plugin: aws_ec2
#regions:
# - us-east-1
#hostnames:
# - instance-id
#filters:
# tag:SSMTag: ssmlinux
# -----------------------
- name: install aws-cli
hosts: all
gather_facts: false
vars:
ansible_connection: aws_ssm
ansible_aws_ssm_bucket_name: nameofthebucket
ansible_aws_ssm_region: us-east-1
tasks:
- name: aws-cli
raw: yum install -y awscli
tags: aws-cli
# Execution: ansible-playbook linux.yaml -i aws_ec2.yml
# The playbook tasks will get executed on the instance IDs returned by the dynamic inventory plugin, using the SSM connection.
# =====================================================
# aws_ec2.yml (Dynamic Inventory - Windows)
#plugin: aws_ec2
#regions:
# - us-east-1
#hostnames:
# - instance-id
#filters:
# tag:SSMTag: ssmwindows
# -----------------------
- name: Create a dir.
hosts: all
gather_facts: false
vars:
ansible_connection: aws_ssm
ansible_shell_type: powershell
ansible_aws_ssm_bucket_name: nameofthebucket
ansible_aws_ssm_region: us-east-1
tasks:
- name: Create the directory
win_file:
path: C:\Temp\SSM_Testing5
state: directory
# Execution: ansible-playbook win_file.yaml -i aws_ec2.yml
# The playbook tasks will get executed on the instance IDs returned by the dynamic inventory plugin, using the SSM connection.
# Install a Nginx Package on Linux Instance; with specific SSE for file transfer
- name: Install a Nginx Package
vars:
ansible_connection: aws_ssm
ansible_aws_ssm_bucket_name: nameofthebucket
ansible_aws_ssm_region: us-west-2
ansible_aws_ssm_bucket_sse_mode: 'aws:kms'
ansible_aws_ssm_bucket_sse_kms_key_id: alias/kms-key-alias
tasks:
- name: Install a Nginx Package
yum:
name: nginx
state: present
# Install a Nginx Package on Linux Instance; with dedicated SSM document
- name: Install a Nginx Package
vars:
ansible_connection: aws_ssm
ansible_aws_ssm_bucket_name: nameofthebucket
ansible_aws_ssm_region: us-west-2
ansible_aws_ssm_document: nameofthecustomdocument
tasks:
- name: Install a Nginx Package
yum:
name: nginx
state: present
'''
import os
import getpass
import json
import pty
import random
import re
import select
import string
import subprocess
import time
try:
import boto3
from botocore.client import Config
except ImportError:
pass # handled by the HAS_BOTO3 check in __init__
from functools import wraps
from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3
from ansible.errors import AnsibleConnectionFailure, AnsibleError, AnsibleFileNotFound
from ansible.module_utils.basic import missing_required_lib
from ansible.module_utils.six.moves import xrange
from ansible.module_utils._text import to_bytes, to_native, to_text
from ansible.plugins.connection import ConnectionBase
from ansible.plugins.shell.powershell import _common_args
from ansible.utils.display import Display
display = Display()
def _ssm_retry(func):
"""
Decorator to retry in the case of a connection failure
Will retry if:
* an exception is caught
Will not retry if:
* remaining_tries is <2
* retries limit reached
"""
@wraps(func)
def wrapped(self, *args, **kwargs):
remaining_tries = int(self.get_option('reconnection_retries')) + 1
cmd_summary = f"{args[0]}..."
for attempt in range(remaining_tries):
try:
return_tuple = func(self, *args, **kwargs)
self._vvvv(f"ssm_retry: (success) {to_text(return_tuple)}")
break
except (AnsibleConnectionFailure, Exception) as e:
if attempt == remaining_tries - 1:
raise
pause = 2 ** attempt - 1
pause = min(pause, 30)
if isinstance(e, AnsibleConnectionFailure):
msg = f"ssm_retry: attempt: {attempt}, cmd ({cmd_summary}), pausing for {pause} seconds"
else:
msg = f"ssm_retry: attempt: {attempt}, caught exception({e}) from cmd ({cmd_summary}), pausing for {pause} seconds"
self._vv(msg)
time.sleep(pause)
# Do not attempt to reuse the existing session on retries
# This will cause the SSM session to be completely restarted,
# as well as reinitializing the boto3 clients
self.close()
continue
return return_tuple
return wrapped
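# A worked sketch of the backoff schedule above (hypothetical run): with the
# default reconnection_retries=3 there are four tries in total, and the pause
# after a failed attempt N is min(2**N - 1, 30) seconds, i.e. 0s, 1s, 3s;
# the fourth failure re-raises instead of pausing.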
def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i:i + n]
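# For example, list(chunks('abcdefgh', 3)) yields ['abc', 'def', 'gh'];
# exec_command() below uses this to write long wrapped commands to the
# plugin's stdin in bounded 1024-byte pieces.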
class Connection(ConnectionBase):
''' AWS SSM based connections '''
transport = 'community.aws.aws_ssm'
allow_executable = False
allow_extras = True
has_pipelining = False
is_windows = False
_client = None
_s3_client = None
_session = None
_stdout = None
_session_id = ''
_timeout = False
MARK_LENGTH = 26
def _display(self, f, message):
if self.host:
host_args = {"host": self.host}
else:
host_args = {}
f(to_text(message), **host_args)
def _v(self, message):
self._display(display.v, message)
def _vv(self, message):
self._display(display.vv, message)
def _vvv(self, message):
self._display(display.vvv, message)
def _vvvv(self, message):
self._display(display.vvvv, message)
def _get_bucket_endpoint(self):
# Fetch the correct S3 endpoint for use with our bucket.
# If we don't explicitly set the endpoint then some commands will use the global
# endpoint and fail
# (new AWS regions and new buckets in a region other than the one we're running in)
region_name = self.get_option('region') or 'us-east-1'
profile_name = self.get_option('profile') or ''
self._vvvv("_get_bucket_endpoint: S3 (global)")
tmp_s3_client = self._get_boto_client(
's3', region_name=region_name, profile_name=profile_name,
)
# Fetch the location of the bucket so we can open a client against the 'right' endpoint
# This /should/ always work
bucket_location = tmp_s3_client.get_bucket_location(
Bucket=(self.get_option('bucket_name')),
)
bucket_region = bucket_location['LocationConstraint']
# Create another client for the region the bucket lives in, so we can nab the endpoint URL
self._vvvv(f"_get_bucket_endpoint: S3 (bucket region) - {bucket_region}")
s3_bucket_client = self._get_boto_client(
's3', region_name=bucket_region, profile_name=profile_name,
)
return s3_bucket_client.meta.endpoint_url, s3_bucket_client.meta.region_name
def _init_clients(self):
self._vvvv("INITIALIZE BOTO3 CLIENTS")
profile_name = self.get_option('profile') or ''
region_name = self.get_option('region')
# The SSM Boto client, currently used to initiate and manage the session
# Note: does not handle the actual SSM session traffic
self._vvvv("SETUP BOTO3 CLIENTS: SSM")
ssm_client = self._get_boto_client(
'ssm', region_name=region_name, profile_name=profile_name,
)
self._client = ssm_client
s3_endpoint_url, s3_region_name = self._get_bucket_endpoint()
self._vvvv(f"SETUP BOTO3 CLIENTS: S3 {s3_endpoint_url}")
s3_bucket_client = self._get_boto_client(
's3', region_name=s3_region_name, endpoint_url=s3_endpoint_url, profile_name=profile_name,
)
self._s3_client = s3_bucket_client
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not HAS_BOTO3:
raise AnsibleError(missing_required_lib("boto3"))
self.host = self._play_context.remote_addr
if getattr(self._shell, "SHELL_FAMILY", '') == 'powershell':
self.delegate = None
self.has_native_async = True
self.always_pipeline_modules = True
self.module_implementation_preferences = ('.ps1', '.exe', '')
self.protocol = None
self.shell_id = None
self._shell_type = 'powershell'
self.is_windows = True
def __del__(self):
self.close()
def _connect(self):
''' connect to the host via ssm '''
self._play_context.remote_user = getpass.getuser()
if not self._session_id:
self.start_session()
return self
def reset(self):
''' start a fresh ssm session '''
self._vvvv('reset called on ssm connection')
return self.start_session()
def start_session(self):
''' start ssm session '''
if self.get_option('instance_id') is None:
self.instance_id = self.host
else:
self.instance_id = self.get_option('instance_id')
self._vvv(f"ESTABLISH SSM CONNECTION TO: {self.instance_id}")
executable = self.get_option('plugin')
if not os.path.exists(to_bytes(executable, errors='surrogate_or_strict')):
raise AnsibleError(f"failed to find the executable specified {executable}.")
self._init_clients()
self._vvvv(f"START SSM SESSION: {self.instance_id}")
start_session_args = dict(Target=self.instance_id, Parameters={})
document_name = self.get_option('ssm_document')
if document_name is not None:
start_session_args['DocumentName'] = document_name
response = self._client.start_session(**start_session_args)
self._session_id = response['SessionId']
region_name = self.get_option('region')
profile_name = self.get_option('profile') or ''
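# The six positional arguments below are what the session-manager-plugin
# binary expects: the StartSession API response as JSON, the region, the
# operation name, the AWS profile, the original request parameters as JSON,
# and the SSM endpoint URL.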
cmd = [
executable,
json.dumps(response),
region_name,
"StartSession",
profile_name,
json.dumps({"Target": self.instance_id}),
self._client.meta.endpoint_url,
]
self._vvvv(f"SSM COMMAND: {to_text(cmd)}")
stdout_r, stdout_w = pty.openpty()
session = subprocess.Popen(
cmd,
stdin=subprocess.PIPE,
stdout=stdout_w,
stderr=subprocess.PIPE,
close_fds=True,
bufsize=0,
)
os.close(stdout_w)
self._stdout = os.fdopen(stdout_r, 'rb', 0)
self._session = session
self._poll_stdout = select.poll()
self._poll_stdout.register(self._stdout, select.POLLIN)
# Disable command echo and prompt.
self._prepare_terminal()
self._vvvv(f"SSM CONNECTION ID: {self._session_id}")
return session
@_ssm_retry
def exec_command(self, cmd, in_data=None, sudoable=True):
''' run a command on the ssm host '''
super().exec_command(cmd, in_data=in_data, sudoable=sudoable)
self._vvv(f"EXEC: {to_text(cmd)}")
session = self._session
mark_begin = "".join([random.choice(string.ascii_letters) for i in xrange(self.MARK_LENGTH)])
if self.is_windows:
mark_start = mark_begin + " $LASTEXITCODE"
else:
mark_start = mark_begin
mark_end = "".join([random.choice(string.ascii_letters) for i in xrange(self.MARK_LENGTH)])
# Wrap the command in markers appropriate to the shell being used
cmd = self._wrap_command(cmd, sudoable, mark_start, mark_end)
self._flush_stderr(session)
for chunk in chunks(cmd, 1024):
session.stdin.write(to_bytes(chunk, errors='surrogate_or_strict'))
# Read stdout between the markers
stdout = ''
win_line = ''
begin = False
stop_time = int(round(time.time())) + self.get_option('ssm_timeout')
while session.poll() is None:
remaining = stop_time - int(round(time.time()))
if remaining < 1:
self._timeout = True
self._vvvv(f"EXEC timeout stdout: \n{to_text(stdout)}")
raise AnsibleConnectionFailure(
f"SSM exec_command timeout on host: {self.instance_id}")
if self._poll_stdout.poll(1000):
line = self._filter_ansi(self._stdout.readline())
self._vvvv(f"EXEC stdout line: \n{to_text(line)}")
else:
self._vvvv(f"EXEC remaining: {remaining}")
continue
if not begin and self.is_windows:
win_line = win_line + line
line = win_line
if mark_start in line:
begin = True
if not line.startswith(mark_start):
stdout = ''
continue
if begin:
if mark_end in line:
self._vvvv(f"POST_PROCESS: \n{to_text(stdout)}")
returncode, stdout = self._post_process(stdout, mark_begin)
self._vvvv(f"POST_PROCESSED: \n{to_text(stdout)}")
break
stdout = stdout + line
stderr = self._flush_stderr(session)
return (returncode, stdout, stderr)
def _prepare_terminal(self):
''' perform any one-time terminal settings '''
# No windows setup for now
if self.is_windows:
return
# *_complete variables are 3 valued:
# - None: not started
# - False: started
# - True: complete
startup_complete = False
disable_echo_complete = None
disable_echo_cmd = to_bytes("stty -echo\n", errors="surrogate_or_strict")
disable_prompt_complete = None
end_mark = "".join(
[random.choice(string.ascii_letters) for i in xrange(self.MARK_LENGTH)]
)
disable_prompt_cmd = to_bytes(
"PS1='' ; printf '\\n%s\\n' '" + end_mark + "'\n",
errors="surrogate_or_strict",
)
disable_prompt_reply = re.compile(
r"\r\r\n" + re.escape(end_mark) + r"\r\r\n", re.MULTILINE
)
stdout = ""
# Custom command execution for when we're waiting for startup
stop_time = int(round(time.time())) + self.get_option("ssm_timeout")
while (not disable_prompt_complete) and (self._session.poll() is None):
remaining = stop_time - int(round(time.time()))
if remaining < 1:
self._timeout = True
self._vvvv(f"PRE timeout stdout: \n{to_bytes(stdout)}")
raise AnsibleConnectionFailure(
f"SSM start_session timeout on host: {self.instance_id}"
)
if self._poll_stdout.poll(1000):
stdout += to_text(self._stdout.read(1024))
self._vvvv(f"PRE stdout line: \n{to_bytes(stdout)}")
else:
self._vvvv(f"PRE remaining: {remaining}")
# wait until the prompt is ready
if startup_complete is False:
match = str(stdout).find("Starting session with SessionId")
if match != -1:
self._vvvv("PRE startup output received")
startup_complete = True
# disable echo
if startup_complete and (disable_echo_complete is None):
self._vvvv(f"PRE Disabling Echo: {disable_echo_cmd}")
self._session.stdin.write(disable_echo_cmd)
disable_echo_complete = False
if disable_echo_complete is False:
match = str(stdout).find("stty -echo")
if match != -1:
disable_echo_complete = True
# disable prompt
if disable_echo_complete and disable_prompt_complete is None:
self._vvvv(f"PRE Disabling Prompt: \n{disable_prompt_cmd}")
self._session.stdin.write(disable_prompt_cmd)
disable_prompt_complete = False
if disable_prompt_complete is False:
match = disable_prompt_reply.search(stdout)
if match:
stdout = stdout[match.end():]
disable_prompt_complete = True
if not disable_prompt_complete:
raise AnsibleConnectionFailure(
f"SSM process closed during _prepare_terminal on host: {self.instance_id}"
)
self._vvvv("PRE Terminal configured")
def _wrap_command(self, cmd, sudoable, mark_start, mark_end):
''' wrap command so stdout and status can be extracted '''
if self.is_windows:
if not cmd.startswith(" ".join(_common_args) + " -EncodedCommand"):
cmd = self._shell._encode_script(cmd, preserve_rc=True)
cmd = cmd + "; echo " + mark_start + "\necho " + mark_end + "\n"
else:
if sudoable:
cmd = "sudo " + cmd
cmd = (
f"printf '%s\\n' '{mark_start}';\n"
f"echo | {cmd};\n"
f"printf '\\n%s\\n%s\\n' \"$?\" '{mark_end}';\n"
)
self._vvvv(f"_wrap_command: \n'{to_text(cmd)}'")
return cmd
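# With the markers in place, the stream coming back from a non-Windows host
# looks roughly like (hypothetical marker values):
#   mark_start
#   ...command output...
#   <blank line>
#   <exit status from $?>
#   mark_end
# _post_process() below unpacks exactly this framing to recover the return
# code and strip the trailer lines.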
def _post_process(self, stdout, mark_begin):
''' extract command status and strip unwanted lines '''
if not self.is_windows:
# Get command return code
returncode = int(stdout.splitlines()[-2])
# Throw away final lines
for _x in range(0, 3):
stdout = stdout[:stdout.rfind('\n')]
return (returncode, stdout)
# Windows is a little more complex
# Value of $LASTEXITCODE will be the line after the mark
trailer = stdout[stdout.rfind(mark_begin):]
last_exit_code = trailer.splitlines()[1]
if last_exit_code.isdigit():
returncode = int(last_exit_code)
else:
returncode = -1
# output to keep will be before the mark
stdout = stdout[:stdout.rfind(mark_begin)]
# If it looks like JSON remove any newlines
if stdout.startswith('{'):
stdout = stdout.replace('\n', '')
return (returncode, stdout)
def _filter_ansi(self, line):
''' remove any ANSI terminal control codes '''
line = to_text(line)
if self.is_windows:
osc_filter = re.compile(r'\x1b\][^\x07]*\x07')
line = osc_filter.sub('', line)
ansi_filter = re.compile(r'(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]')
line = ansi_filter.sub('', line)
# Replace or strip sequence (at terminal width)
line = line.replace('\r\r\n', '\n')
if len(line) == 201:
line = line[:-1]
return line
def _flush_stderr(self, session_process):
''' read and return stderr with minimal blocking '''
poll_stderr = select.poll()
poll_stderr.register(session_process.stderr, select.POLLIN)
stderr = ''
while session_process.poll() is None:
if not poll_stderr.poll(1):
break
line = session_process.stderr.readline()
self._vvvv(f"stderr line: {to_text(line)}")
stderr = stderr + to_text(line)
return stderr
def _get_url(self, client_method, bucket_name, out_path, http_method, extra_args=None):
''' Generate URL for get_object / put_object '''
client = self._s3_client
params = {'Bucket': bucket_name, 'Key': out_path}
if extra_args is not None:
params.update(extra_args)
return client.generate_presigned_url(client_method, Params=params, ExpiresIn=3600, HttpMethod=http_method)
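# The URL returned here is a presigned S3 URL valid for 3600 seconds, so the
# remote host can upload or download with plain curl / Invoke-WebRequest and
# never needs AWS credentials of its own.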
def _get_boto_client(self, service, region_name=None, profile_name=None, endpoint_url=None):
''' Gets a boto3 client based on the STS token '''
aws_access_key_id = self.get_option('access_key_id')
aws_secret_access_key = self.get_option('secret_access_key')
aws_session_token = self.get_option('session_token')
if aws_access_key_id is None:
aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID", None)
if aws_secret_access_key is None:
aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY", None)
if aws_session_token is None:
aws_session_token = os.environ.get("AWS_SESSION_TOKEN", None)
if not profile_name:
profile_name = os.environ.get("AWS_PROFILE", None)
session_args = dict(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
region_name=region_name,
)
if profile_name:
session_args['profile_name'] = profile_name
session = boto3.session.Session(**session_args)
client = session.client(
service,
endpoint_url=endpoint_url,
config=Config(
signature_version="s3v4",
s3={'addressing_style': self.get_option('s3_addressing_style')}
)
)
return client
def _escape_path(self, path):
return path.replace("\\", "/")
def _generate_encryption_settings(self):
put_args = {}
put_headers = {}
if not self.get_option('bucket_sse_mode'):
return put_args, put_headers
put_args['ServerSideEncryption'] = self.get_option('bucket_sse_mode')
put_headers['x-amz-server-side-encryption'] = self.get_option('bucket_sse_mode')
if self.get_option('bucket_sse_mode') == 'aws:kms' and self.get_option('bucket_sse_kms_key_id'):
put_args['SSEKMSKeyId'] = self.get_option('bucket_sse_kms_key_id')
put_headers['x-amz-server-side-encryption-aws-kms-key-id'] = self.get_option('bucket_sse_kms_key_id')
return put_args, put_headers
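# Example output of _generate_encryption_settings() (hypothetical key alias):
# with bucket_sse_mode='aws:kms' and bucket_sse_kms_key_id='alias/example-key'
# it returns
#   put_args    = {'ServerSideEncryption': 'aws:kms',
#                  'SSEKMSKeyId': 'alias/example-key'}
#   put_headers = {'x-amz-server-side-encryption': 'aws:kms',
#                  'x-amz-server-side-encryption-aws-kms-key-id': 'alias/example-key'}
# put_args is passed to boto3 uploads/presigning, while put_headers must be
# sent on the presigned PUT request itself for the signature to match.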
def _generate_commands(self, bucket_name, s3_path, in_path, out_path):
put_args, put_headers = self._generate_encryption_settings()
put_url = self._get_url('put_object', bucket_name, s3_path, 'PUT', extra_args=put_args)
get_url = self._get_url('get_object', bucket_name, s3_path, 'GET')
if self.is_windows:
put_command_headers = "; ".join([f"'{h}' = '{v}'" for h, v in put_headers.items()])
put_command = (
"Invoke-WebRequest -Method PUT "
f"-Headers @{{{put_command_headers}}} " # @{'key' = 'value'; 'key2' = 'value2'}
f"-InFile '{in_path}' "
f"-Uri '{put_url}' "
f"-UseBasicParsing"
)
get_command = (
"Invoke-WebRequest "
f"'{get_url}' "
f"-OutFile '{out_path}'"
)
else:
put_command_headers = " ".join([f"-H '{h}: {v}'" for h, v in put_headers.items()])
put_command = (
"curl --request PUT "
f"{put_command_headers} "
f"--upload-file '{in_path}' "
f"'{put_url}'"
)
get_command = (
"curl "
f"-o '{out_path}' "
f"'{get_url}'"
)
return get_command, put_command, put_args
@_ssm_retry
def _file_transport_command(self, in_path, out_path, ssm_action):
''' transfer a file to/from host using an intermediate S3 bucket '''
bucket_name = self.get_option("bucket_name")
s3_path = self._escape_path(f"{self.instance_id}/{out_path}")
get_command, put_command, put_args = self._generate_commands(
bucket_name, s3_path, in_path, out_path,
)
client = self._s3_client
if ssm_action == 'get':
(returncode, stdout, stderr) = self.exec_command(put_command, in_data=None, sudoable=False)
with open(to_bytes(out_path, errors='surrogate_or_strict'), 'wb') as data:
client.download_fileobj(bucket_name, s3_path, data)
else:
with open(to_bytes(in_path, errors='surrogate_or_strict'), 'rb') as data:
client.upload_fileobj(data, bucket_name, s3_path, ExtraArgs=put_args)
(returncode, stdout, stderr) = self.exec_command(get_command, in_data=None, sudoable=False)
# Remove the files from the bucket after they've been transferred
client.delete_object(Bucket=bucket_name, Key=s3_path)
# Check the return code
if returncode == 0:
return (returncode, stdout, stderr)
raise AnsibleError(
f"failed to transfer file {in_path} to {out_path}:\n"
f"{stdout}\n{stderr}")
def put_file(self, in_path, out_path):
''' transfer a file from local to remote '''
super().put_file(in_path, out_path)
self._vvv(f"PUT {in_path} TO {out_path}")
if not os.path.exists(to_bytes(in_path, errors='surrogate_or_strict')):
raise AnsibleFileNotFound(f"file or module does not exist: {in_path}")
return self._file_transport_command(in_path, out_path, 'put')
def fetch_file(self, in_path, out_path):
''' fetch a file from remote to local '''
super().fetch_file(in_path, out_path)
self._vvv(f"FETCH {in_path} TO {out_path}")
return self._file_transport_command(in_path, out_path, 'get')
def close(self):
''' terminate the connection '''
if self._session_id:
self._vvv(f"CLOSING SSM CONNECTION TO: {self.instance_id}")
if self._timeout:
self._session.terminate()
else:
cmd = b"\nexit\n"
self._session.communicate(cmd)
self._vvvv(f"TERMINATE SSM SESSION: {self._session_id}")
self._client.terminate_session(SessionId=self._session_id)
self._session_id = ''


@@ -0,0 +1,376 @@
# Copyright: Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
#
# Note: This code should probably live in amazon.aws rather than community.aws.
# However, for the sake of getting something into a useful shape first, it makes
# sense for it to start life in community.aws.
#
from __future__ import absolute_import, division, print_function
__metaclass__ = type
from copy import deepcopy
from functools import wraps
try:
import botocore
except ImportError:
pass # Handled by AnsibleAWSModule
from ansible.module_utils.common.dict_transformations import camel_dict_to_snake_dict
from ansible_collections.amazon.aws.plugins.module_utils.tagging import boto3_tag_list_to_ansible_dict
class BaseWaiterFactory():
"""
A helper class used for creating additional waiters.
Unlike the waiters available directly from botocore these waiters will
automatically retry on common (temporary) AWS failures.
This class should be treated as an abstract class and subclassed before use.
A subclass should:
- create the necessary client to pass to BaseWaiterFactory.__init__
- override _BaseWaiterFactory._waiter_model_data to return the data defining
the waiter
Usage:
waiter_factory = BaseWaiterFactory(module, client)
waiter = waiter_factory.get_waiter('my_waiter_name')
waiter.wait(**params)
"""
module = None
client = None
def __init__(self, module, client):
self.module = module
self.client = client
# While it would be nice to supplement this with the upstream data,
# unfortunately the client doesn't have a public method for getting the
# waiter configs.
data = self._inject_ratelimit_retries(self._waiter_model_data)
self._model = botocore.waiter.WaiterModel(
waiter_config=dict(version=2, waiters=data),
)
@property
def _waiter_model_data(self):
r"""
Subclasses should override this method to return a dictionary mapping
waiter names to the waiter definition.
This data is similar to the data found in botocore's waiters-2.json
files (for example: botocore/botocore/data/ec2/2016-11-15/waiters-2.json)
with two differences:
1) Waiter names do not have transformations applied during lookup
2) Only the 'waiters' data is required, the data is assumed to be
version 2
for example:
@property
def _waiter_model_data(self):
return dict(
tgw_attachment_deleted=dict(
operation='DescribeTransitGatewayAttachments',
delay=5, maxAttempts=120,
acceptors=[
dict(state='retry', matcher='pathAll', expected='deleting', argument='TransitGatewayAttachments[].State'),
dict(state='success', matcher='pathAll', expected='deleted', argument='TransitGatewayAttachments[].State'),
dict(state='success', matcher='path', expected=True, argument='length(TransitGatewayAttachments[]) == `0`'),
dict(state='success', matcher='error', expected='InvalidRouteTableID.NotFound'),
]
),
)
or
@property
def _waiter_model_data(self):
return {
"instance_exists": {
"delay": 5,
"maxAttempts": 40,
"operation": "DescribeInstances",
"acceptors": [
{
"matcher": "path",
"expected": true,
"argument": "length(Reservations[]) > `0`",
"state": "success"
},
{
"matcher": "error",
"expected": "InvalidInstanceID.NotFound",
"state": "retry"
}
]
},
}
"""
return dict()
def _inject_ratelimit_retries(self, model):
extra_retries = [
'RequestLimitExceeded', 'Unavailable', 'ServiceUnavailable',
'InternalFailure', 'InternalError', 'TooManyRequestsException',
'Throttling']
acceptors = []
for error in extra_retries:
acceptors.append(dict(state="retry", matcher="error", expected=error))
_model = deepcopy(model)
for waiter in _model:
_model[waiter]["acceptors"].extend(acceptors)
return _model
def get_waiter(self, waiter_name):
waiters = self._model.waiter_names
if waiter_name not in waiters:
self.module.fail_json(
'Unable to find waiter {0}. Available_waiters: {1}'
.format(waiter_name, waiters))
return botocore.waiter.create_waiter_with_client(
waiter_name, self._model, self.client,
)
class Boto3Mixin():
@staticmethod
def aws_error_handler(description):
r"""
A simple wrapper that handles the usual botocore exceptions and exits
with module.fail_json_aws. Designed to be used with BaseResourceManager.
Assumptions:
1) The first argument (usually `self`) of the method being wrapped will have a
'module' attribute which is an AnsibleAWSModule
2) First argument of method being wrapped will have an
_extra_error_output() method which takes no arguments and returns a
dictionary of extra parameters to be returned in the event of a
botocore exception.
Parameters:
description (string): In the event of a botocore exception the error
message will be 'Failed to {DESCRIPTION}'.
Example Usage:
class ExampleClass(Boto3Mixin):
def __init__(self, module):
self.module = module
self._get_client()
@Boto3Mixin.aws_error_handler("connect to AWS")
def _get_client(self):
self.client = self.module.client('ec2')
@Boto3Mixin.aws_error_handler("describe EC2 instances")
def _do_something(self, **params):
return self.client.describe_instances(**params)
"""
def wrapper(func):
@wraps(func)
def handler(_self, *args, **kwargs):
extra_output = _self._extra_error_output()
try:
return func(_self, *args, **kwargs)
except (botocore.exceptions.WaiterError) as e:
_self.module.fail_json_aws(e, msg='Failed waiting for {DESC}'.format(DESC=description), **extra_output)
except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
_self.module.fail_json_aws(e, msg='Failed to {DESC}'.format(DESC=description), **extra_output)
return handler
return wrapper
def _normalize_boto3_resource(self, resource, add_tags=False):
r"""
Performs common boto3 resource to Ansible resource conversion.
`resource['Tags']` will by default be converted from the boto3 tag list
format to a simple dictionary.
Parameters:
resource (dict): The boto3 style resource to convert to the normal Ansible
format (snake_case).
add_tags (bool): When `true`, if a resource does not have a 'Tags' property,
the returned resource will have its tags set to an empty
dictionary.
"""
if resource is None:
return None
tags = resource.get('Tags', None)
if tags:
tags = boto3_tag_list_to_ansible_dict(tags)
elif add_tags or tags is not None:
tags = {}
normalized_resource = camel_dict_to_snake_dict(resource)
if tags is not None:
normalized_resource['tags'] = tags
return normalized_resource
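# Example for _normalize_boto3_resource() (hypothetical resource): a
# boto3-style
#   {'VpcId': 'vpc-123', 'Tags': [{'Key': 'Name', 'Value': 'demo'}]}
# is returned as
#   {'vpc_id': 'vpc-123', 'tags': {'Name': 'demo'}}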
def _extra_error_output(self):
# In the event of an error it can be helpful to output things like the
# 'name'/'arn' of a resource.
return dict()
class BaseResourceManager(Boto3Mixin):
def __init__(self, module):
r"""
Parameters:
module (AnsibleAWSModule): An Ansible module.
"""
self.module = module
self.changed = False
self.original_resource = dict()
self.updated_resource = dict()
self._resource_updates = dict()
self._preupdate_resource = dict()
self._wait = True
self._wait_timeout = None
super(BaseResourceManager, self).__init__()
def _merge_resource_changes(self, filter_immutable=True, creation=False):
"""
Merges the contents of the 'pre_update' resource and metadata variables
with the pending updates
"""
resource = deepcopy(self._preupdate_resource)
resource.update(self._resource_updates)
if filter_immutable:
resource = self._filter_immutable_resource_attributes(resource)
return resource
def _filter_immutable_resource_attributes(self, resource):
return deepcopy(resource)
def _do_creation_wait(self, **params):
pass
def _do_deletion_wait(self, **params):
pass
def _do_update_wait(self, **params):
pass
@property
def _waiter_config(self):
params = dict()
if self._wait_timeout:
delay = min(5, self._wait_timeout)
max_attempts = (self._wait_timeout // delay)
config = dict(Delay=delay, MaxAttempts=max_attempts)
params['WaiterConfig'] = config
return params
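# For example, _wait_timeout=120 yields WaiterConfig={'Delay': 5,
# 'MaxAttempts': 24}: poll every 5 seconds, up to 24 times, covering the
# requested two minutes.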
def _wait_for_deletion(self):
if not self._wait:
return
params = self._waiter_config
self._do_deletion_wait(**params)
def _wait_for_creation(self):
if not self._wait:
return
params = self._waiter_config
self._do_creation_wait(**params)
def _wait_for_update(self):
if not self._wait:
return
params = self._waiter_config
self._do_update_wait(**params)
def _generate_updated_resource(self):
"""
Merges all pending changes into self.updated_resource
Used during check mode where it's not possible to get and
refresh the resource
"""
return self._merge_resource_changes(filter_immutable=False)
# If you override _flush_create you're responsible for handling check_mode
# If you override _do_create_resource it will only be called if check_mode == False
def _flush_create(self):
changed = True
if not self.module.check_mode:
changed = self._do_create_resource()
self._wait_for_creation()
self.updated_resource = self.get_resource()
else: # (CHECK MODE)
self.updated_resource = self._normalize_resource(self._generate_updated_resource())
self._resource_updates = dict()
self.changed = changed
return True
def _check_updates_pending(self):
if self._resource_updates:
return True
return False
# If you override _flush_update you're responsible for handling check_mode
# If you override _do_update_resource it will only be called if there are
# updates pending and check_mode == False
def _flush_update(self):
if not self._check_updates_pending():
self.updated_resource = self.original_resource
return False
if not self.module.check_mode:
self._do_update_resource()
self._wait_for_update()
self.updated_resource = self.get_resource()
else: # (CHECK_MODE)
self.updated_resource = self._normalize_resource(self._generate_updated_resource())
self._resource_updates = dict()
return True
def flush_changes(self):
if self.original_resource:
return self._flush_update()
else:
return self._flush_create()
def _set_resource_value(self, key, value, description=None, immutable=False):
if value is None:
return False
if value == self._get_resource_value(key):
return False
if immutable and self.original_resource:
if description is None:
description = key
self.module.fail_json(msg='{0} cannot be updated after creation'
.format(description))
self._resource_updates[key] = value
self.changed = True
return True
def _get_resource_value(self, key, default=None):
default_value = self._preupdate_resource.get(key, default)
return self._resource_updates.get(key, default_value)
def set_wait(self, wait):
if wait is None:
return False
if wait == self._wait:
return False
self._wait = wait
return True
def set_wait_timeout(self, timeout):
if timeout is None:
return False
if timeout == self._wait_timeout:
return False
self._wait_timeout = timeout
return True


@@ -0,0 +1,189 @@
# Copyright: Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
from copy import deepcopy
from ansible_collections.amazon.aws.plugins.module_utils.core import is_boto3_error_code
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import AWSRetry
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import ansible_dict_to_boto3_filter_list
from ansible_collections.amazon.aws.plugins.module_utils.tagging import boto3_tag_list_to_ansible_dict
from ansible_collections.amazon.aws.plugins.module_utils.tagging import ansible_dict_to_boto3_tag_list
from ansible_collections.amazon.aws.plugins.module_utils.tagging import compare_aws_tags
from ansible_collections.amazon.aws.plugins.module_utils.tagging import boto3_tag_specifications
from ansible_collections.community.aws.plugins.module_utils.base import BaseResourceManager
from ansible_collections.community.aws.plugins.module_utils.base import BaseWaiterFactory
from ansible_collections.community.aws.plugins.module_utils.base import Boto3Mixin
class Ec2WaiterFactory(BaseWaiterFactory):
def __init__(self, module):
# the AWSRetry wrapper doesn't support the wait functions (there's no
# public call we can cleanly wrap)
client = module.client('ec2')
super(Ec2WaiterFactory, self).__init__(module, client)
@property
def _waiter_model_data(self):
data = super(Ec2WaiterFactory, self)._waiter_model_data
return data
class Ec2Boto3Mixin(Boto3Mixin):
@AWSRetry.jittered_backoff()
def _paginated_describe_subnets(self, **params):
paginator = self.client.get_paginator('describe_subnets')
return paginator.paginate(**params).build_full_result()
@Boto3Mixin.aws_error_handler('describe subnets')
def _describe_subnets(self, **params):
try:
result = self._paginated_describe_subnets(**params)
except is_boto3_error_code('SubnetID.NotFound'):
return None
return result.get('Subnets', None)
class BaseEc2Manager(Ec2Boto3Mixin, BaseResourceManager):
resource_id = None
TAG_RESOURCE_TYPE = None
# This can be overridden by a subclass *if* 'Tags' isn't returned as a part of
# the standard Resource description
TAGS_ON_RESOURCE = True
# If the resource supports using "TagSpecifications" on creation, we can pass the initial tags when the resource is created
TAGS_ON_CREATE = 'TagSpecifications'
def __init__(self, module, id=None):
r"""
Parameters:
module (AnsibleAWSModule): An Ansible module.
"""
super(BaseEc2Manager, self).__init__(module)
self.client = self._create_client()
self._tagging_updates = dict()
self.resource_id = id
# Name parameter is unique (by region) and cannot be modified.
if self.resource_id:
resource = deepcopy(self.get_resource())
self.original_resource = resource
def _flush_update(self):
changed = False
changed |= self._do_tagging()
changed |= super(BaseEc2Manager, self)._flush_update()
return changed
@Boto3Mixin.aws_error_handler('connect to AWS')
def _create_client(self, client_name='ec2'):
client = self.module.client(client_name, retry_decorator=AWSRetry.jittered_backoff())
return client
@Boto3Mixin.aws_error_handler('set tags on resource')
def _add_tags(self, **params):
self.client.create_tags(aws_retry=True, **params)
return True
@Boto3Mixin.aws_error_handler('unset tags on resource')
def _remove_tags(self, **params):
self.client.delete_tags(aws_retry=True, **params)
return True
@AWSRetry.jittered_backoff()
def _paginated_describe_tags(self, **params):
paginator = self.client.get_paginator('describe_tags')
return paginator.paginate(**params).build_full_result()
@Boto3Mixin.aws_error_handler('list tags on resource')
def _describe_tags(self, id=None):
if not id:
id = self.resource_id
filters = ansible_dict_to_boto3_filter_list({"resource-id": id})
tags = self._paginated_describe_tags(Filters=filters)
return tags
def _get_tags(self, id=None):
if id is None:
id = self.resource_id
# If the Tags are available from the resource, then use them
if self.TAGS_ON_RESOURCE:
tags = self._preupdate_resource.get('Tags', [])
# Otherwise we'll have to look them up
else:
tags = self._describe_tags(id=id)
return boto3_tag_list_to_ansible_dict(tags)
def _do_tagging(self):
changed = False
tags_to_add = self._tagging_updates.get('add')
tags_to_remove = self._tagging_updates.get('remove')
if tags_to_add:
changed = True
tags = ansible_dict_to_boto3_tag_list(tags_to_add)
if not self.module.check_mode:
self._add_tags(Resources=[self.resource_id], Tags=tags)
if tags_to_remove:
changed = True
if not self.module.check_mode:
tag_list = [dict(Key=tagkey) for tagkey in tags_to_remove]
self._remove_tags(Resources=[self.resource_id], Tags=tag_list)
return changed
def _merge_resource_changes(self, filter_immutable=True, creation=False):
resource = super(BaseEc2Manager, self)._merge_resource_changes(
filter_immutable=filter_immutable,
creation=creation
)
if creation:
if not self.TAGS_ON_CREATE:
resource.pop('Tags', None)
elif self.TAGS_ON_CREATE == 'TagSpecifications':
tags = boto3_tag_list_to_ansible_dict(resource.pop('Tags', []))
tag_specs = boto3_tag_specifications(tags, types=[self.TAG_RESOURCE_TYPE])
if tag_specs:
resource['TagSpecifications'] = tag_specs
return resource
def set_tags(self, tags, purge_tags):
if tags is None:
return False
changed = False
# Tags are returned as a part of the resource, but have to be updated
# via dedicated tagging methods
current_tags = self._get_tags()
# So that diff works in check mode we need to know the full target state
if purge_tags:
desired_tags = deepcopy(tags)
else:
desired_tags = deepcopy(current_tags)
desired_tags.update(tags)
tags_to_add, tags_to_remove = compare_aws_tags(current_tags, tags, purge_tags)
if tags_to_add:
self._tagging_updates['add'] = tags_to_add
changed = True
if tags_to_remove:
self._tagging_updates['remove'] = tags_to_remove
changed = True
if changed:
# Tags are stored as a list but treated like a dictionary; the
# simplistic '==' comparison in _set_resource_value doesn't compare
# them properly
return self._set_resource_value('Tags', ansible_dict_to_boto3_tag_list(desired_tags))
return False


@@ -0,0 +1,62 @@
# source: https://github.com/tlastowka/calculate_multipart_etag/blob/master/calculate_multipart_etag.py
#
# calculate_multipart_etag Copyright (C) 2015
# Tony Lastowka <tlastowka at gmail dot com>
# https://github.com/tlastowka
#
#
# calculate_multipart_etag is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# calculate_multipart_etag is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with calculate_multipart_etag. If not, see <http://www.gnu.org/licenses/>.
import hashlib
try:
from boto3.s3.transfer import TransferConfig
DEFAULT_CHUNK_SIZE = TransferConfig().multipart_chunksize
except ImportError:
# boto3 availability is handled by AnsibleAWSModule; fall back to 5 MiB,
# the S3 minimum multipart part size.
DEFAULT_CHUNK_SIZE = 5 * 1024 * 1024
def calculate_multipart_etag(source_path, chunk_size=DEFAULT_CHUNK_SIZE):
"""
Calculates a multipart upload ETag for Amazon S3.
Arguments:
source_path -- The file to calculate the ETag for
chunk_size -- The chunk size to calculate with.
"""
md5s = []
with open(source_path, 'rb') as fp:
while True:
data = fp.read(chunk_size)
if not data:
break
md5 = hashlib.new('md5', usedforsecurity=False)
md5.update(data)
md5s.append(md5)
if len(md5s) == 1:
new_etag = '"{0}"'.format(md5s[0].hexdigest())
else: # > 1
digests = b"".join(m.digest() for m in md5s)
new_md5 = hashlib.new('md5', digests, usedforsecurity=False)
new_etag = '"{0}-{1}"'.format(new_md5.hexdigest(), len(md5s))
return new_etag
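# A minimal usage sketch (hypothetical file): for a 12 MiB file hashed with
# boto3's default 8 MiB multipart chunk size this hashes two chunks and
# returns an ETag of the form '"<md5-of-concatenated-digests>-2"', matching
# what S3 reports for the corresponding two-part multipart upload.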


@@ -0,0 +1,280 @@
# This file is part of Ansible
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
from copy import deepcopy
import datetime
import functools
import time
try:
import botocore
except ImportError:
pass # caught by AnsibleAWSModule
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import (
ansible_dict_to_boto3_tag_list,
camel_dict_to_snake_dict,
compare_aws_tags,
)
from ansible_collections.amazon.aws.plugins.module_utils.core import is_boto3_error_code
from ansible_collections.amazon.aws.plugins.module_utils.tagging import (
boto3_tag_list_to_ansible_dict,
)
from ansible.module_utils.six import string_types
def get_domain_status(client, module, domain_name):
"""
Get the status of an existing OpenSearch cluster.
"""
try:
response = client.describe_domain(DomainName=domain_name)
except is_boto3_error_code("ResourceNotFoundException"):
return None
except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e: # pylint: disable=duplicate-except
module.fail_json_aws(e, msg="Couldn't get domain {0}".format(domain_name))
return response["DomainStatus"]
def get_domain_config(client, module, domain_name):
"""
Get the configuration of an existing OpenSearch cluster, convert the data
such that it can be used as input parameter to client.update_domain().
The status info is removed.
The returned config includes the 'EngineVersion' property, it needs to be removed
from the dict before invoking client.update_domain().
Return (domain_config, domain_arn) or (None, None) if the domain does not exist.
"""
try:
response = client.describe_domain_config(DomainName=domain_name)
except is_boto3_error_code("ResourceNotFoundException"):
return (None, None)
except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e: # pylint: disable=duplicate-except
module.fail_json_aws(e, msg="Couldn't get domain {0}".format(domain_name))
domain_config = {}
arn = None
if response is not None:
for k in response["DomainConfig"]:
domain_config[k] = response["DomainConfig"][k]["Options"]
domain_config["DomainName"] = domain_name
# If the ES cluster is attached to the Internet, the "VPCOptions" property is not present.
if "VPCOptions" in domain_config:
# The "VPCOptions" returned by the describe_domain_config API has
# additional attributes that would cause an error if sent in the HTTP POST body.
dc = {}
if "SubnetIds" in domain_config["VPCOptions"]:
dc["SubnetIds"] = deepcopy(domain_config["VPCOptions"]["SubnetIds"])
if "SecurityGroupIds" in domain_config["VPCOptions"]:
dc["SecurityGroupIds"] = deepcopy(domain_config["VPCOptions"]["SecurityGroupIds"])
domain_config["VPCOptions"] = dc
# The "StartAt" property is converted to datetime, but when doing comparisons it should
# be in the string format "YYYY-MM-DD".
for s in domain_config["AutoTuneOptions"]["MaintenanceSchedules"]:
if isinstance(s["StartAt"], datetime.datetime):
s["StartAt"] = s["StartAt"].strftime("%Y-%m-%d")
# Provisioning of "AdvancedOptions" is not supported by this module yet.
domain_config.pop("AdvancedOptions", None)
# Get the ARN of the OpenSearch cluster.
domain = get_domain_status(client, module, domain_name)
if domain is not None:
arn = domain["ARN"]
return (domain_config, arn)
def normalize_opensearch(client, module, domain):
"""
Merge the input domain object with tags associated with the domain,
convert the attributes from camel case to snake case, and return the object.
"""
try:
domain["Tags"] = boto3_tag_list_to_ansible_dict(
client.list_tags(ARN=domain["ARN"], aws_retry=True)["TagList"]
)
except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
module.fail_json_aws(
e, "Couldn't get tags for domain %s" % domain["domain_name"]
)
except KeyError:
module.fail_json(msg=str(domain))
return camel_dict_to_snake_dict(domain, ignore_list=["Tags"])
def wait_for_domain_status(client, module, domain_name, waiter_name):
if not module.params["wait"]:
return
timeout = module.params["wait_timeout"]
deadline = time.time() + timeout
status_msg = ""
while time.time() < deadline:
status = get_domain_status(client, module, domain_name)
if status is None:
status_msg = "Not Found"
if waiter_name == "domain_deleted":
return
else:
status_msg = "Created: {0}. Processing: {1}. UpgradeProcessing: {2}".format(
status["Created"],
status["Processing"],
status["UpgradeProcessing"],
)
if (
waiter_name == "domain_available"
and status["Created"]
and not status["Processing"]
and not status["UpgradeProcessing"]
):
return
time.sleep(15)
# Timeout occurred.
module.fail_json(
msg=f"Timeout waiting for wait state '{waiter_name}'. {status_msg}"
)
def parse_version(engine_version):
'''
Parse the engine version, which should be Elasticsearch_X.Y or OpenSearch_X.Y
Return dict { 'engine_type': engine_type, 'major': major, 'minor': minor }
'''
version = engine_version.split("_")
if len(version) != 2:
return None
semver = version[1].split(".")
if len(semver) != 2:
return None
engine_type = version[0]
if engine_type not in ['Elasticsearch', 'OpenSearch']:
return None
if not (semver[0].isdigit() and semver[1].isdigit()):
return None
major = int(semver[0])
minor = int(semver[1])
return {'engine_type': engine_type, 'major': major, 'minor': minor}
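# For example, parse_version('OpenSearch_1.1') returns
# {'engine_type': 'OpenSearch', 'major': 1, 'minor': 1}, while malformed or
# unsupported inputs such as 'OpenSearch-1.1' or 'Solr_1.0' return None.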
def compare_domain_versions(version1, version2):
supported_engines = {
'Elasticsearch': 1,
'OpenSearch': 2,
}
if isinstance(version1, string_types):
version1 = parse_version(version1)
if isinstance(version2, string_types):
version2 = parse_version(version2)
if version1 is None and version2 is not None:
return -1
elif version1 is not None and version2 is None:
return 1
elif version1 is None and version2 is None:
return 0
e1 = supported_engines.get(version1.get('engine_type'))
e2 = supported_engines.get(version2.get('engine_type'))
if e1 < e2:
return -1
elif e1 > e2:
return 1
else:
if version1.get('major') < version2.get('major'):
return -1
elif version1.get('major') > version2.get('major'):
return 1
else:
if version1.get('minor') < version2.get('minor'):
return -1
elif version1.get('minor') > version2.get('minor'):
return 1
else:
return 0
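# For example, compare_domain_versions('Elasticsearch_7.10', 'OpenSearch_1.0')
# returns -1: every Elasticsearch version sorts before every OpenSearch
# version, and major/minor numbers are only compared when the engine types
# match.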
def get_target_increment_version(client, module, domain_name, target_version):
"""
Returns the highest compatible version which is less than or equal to target_version.
When upgrading a domain from version V1 to V2, it may not be possible to upgrade
directly from V1 to V2. The domain may have to be upgraded through intermediate versions.
Return None if there is no such version.
For example, it's not possible to upgrade directly from Elasticsearch 5.5 to 7.10.
"""
api_compatible_versions = None
try:
api_compatible_versions = client.get_compatible_versions(DomainName=domain_name)
except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e:
module.fail_json_aws(
e,
msg="Couldn't get compatible versions for domain {0}".format(
domain_name),
)
compat = api_compatible_versions.get('CompatibleVersions')
if not compat:
module.fail_json(
"Unable to determine list of compatible versions",
compatible_versions=api_compatible_versions)
if compat[0].get("TargetVersions") is None:
module.fail_json(
"No compatible versions found",
compatible_versions=api_compatible_versions)
compatible_versions = []
for v in compat[0].get("TargetVersions"):
if target_version == v:
# It's possible to upgrade directly to the target version.
return target_version
semver = parse_version(v)
if semver is not None:
# Keep the original version string so the function consistently
# returns strings; compare_domain_versions() parses them on demand.
compatible_versions.append(v)
# No direct upgrade is possible. Upgrade to the highest version available.
compatible_versions = sorted(compatible_versions, key=functools.cmp_to_key(compare_domain_versions))
# Return the highest compatible version which is lower than target_version
for v in reversed(compatible_versions):
if compare_domain_versions(v, target_version) <= 0:
return v
return None
def ensure_tags(client, module, resource_arn, existing_tags, tags, purge_tags):
if tags is None:
return False
tags_to_add, tags_to_remove = compare_aws_tags(existing_tags, tags, purge_tags)
changed = bool(tags_to_add or tags_to_remove)
if tags_to_add:
if module.check_mode:
module.exit_json(
changed=True, msg="Would have added tags to domain if not in check mode"
)
try:
client.add_tags(
ARN=resource_arn,
TagList=ansible_dict_to_boto3_tag_list(tags_to_add),
)
except (
botocore.exceptions.ClientError,
botocore.exceptions.BotoCoreError,
) as e:
module.fail_json_aws(
e, "Couldn't add tags to domain {0}".format(resource_arn)
)
if tags_to_remove:
if module.check_mode:
module.exit_json(
changed=True, msg="Would have removed tags if not in check mode"
)
try:
client.remove_tags(ARN=resource_arn, TagKeys=tags_to_remove)
except (
botocore.exceptions.ClientError,
botocore.exceptions.BotoCoreError,
) as e:
module.fail_json_aws(
e, "Couldn't remove tags from domain {0}".format(resource_arn)
)
return changed


@@ -0,0 +1,125 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import re
import copy
try:
import botocore
except ImportError:
pass # handled by AnsibleAWSModule
from ansible_collections.amazon.aws.plugins.module_utils.core import is_boto3_error_code
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import AWSRetry
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import camel_dict_to_snake_dict
@AWSRetry.jittered_backoff()
def _list_topics_with_backoff(client):
paginator = client.get_paginator('list_topics')
return paginator.paginate().build_full_result()['Topics']
@AWSRetry.jittered_backoff(catch_extra_error_codes=['NotFound'])
def _list_topic_subscriptions_with_backoff(client, topic_arn):
paginator = client.get_paginator('list_subscriptions_by_topic')
return paginator.paginate(TopicArn=topic_arn).build_full_result()['Subscriptions']
@AWSRetry.jittered_backoff(catch_extra_error_codes=['NotFound'])
def _list_subscriptions_with_backoff(client):
paginator = client.get_paginator('list_subscriptions')
return paginator.paginate().build_full_result()['Subscriptions']
def list_topic_subscriptions(client, module, topic_arn):
try:
return _list_topic_subscriptions_with_backoff(client, topic_arn)
except is_boto3_error_code('AuthorizationError'):
try:
# potentially AuthorizationError when listing subscriptions for third party topic
return [sub for sub in _list_subscriptions_with_backoff(client)
if sub['TopicArn'] == topic_arn]
except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
module.fail_json_aws(e, msg="Couldn't get subscriptions list for topic %s" % topic_arn)
except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: # pylint: disable=duplicate-except
module.fail_json_aws(e, msg="Couldn't get subscriptions list for topic %s" % topic_arn)
def list_topics(client, module):
try:
topics = _list_topics_with_backoff(client)
except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e:
module.fail_json_aws(e, msg="Couldn't get topic list")
return [t['TopicArn'] for t in topics]
def topic_arn_lookup(client, module, name):
# topic names cannot have colons, so this captures the full topic name
all_topics = list_topics(client, module)
lookup_topic = ':%s' % name
for topic in all_topics:
if topic.endswith(lookup_topic):
return topic
def compare_delivery_policies(policy_a, policy_b):
_policy_a = copy.deepcopy(policy_a)
_policy_b = copy.deepcopy(policy_b)
# AWS automatically injects disableSubscriptionOverrides if you set an
# http policy
if 'http' in policy_a:
if 'disableSubscriptionOverrides' not in policy_a['http']:
_policy_a['http']['disableSubscriptionOverrides'] = False
if 'http' in policy_b:
if 'disableSubscriptionOverrides' not in policy_b['http']:
_policy_b['http']['disableSubscriptionOverrides'] = False
comparison = (_policy_a != _policy_b)
return comparison
def canonicalize_endpoint(protocol, endpoint):
# AWS SNS expects phone numbers in E.164 format and canonicalizes them to it
# See <https://docs.aws.amazon.com/sns/latest/dg/sms_publish-to-phone.html>
if protocol == 'sms':
return re.sub('[^0-9+]*', '', endpoint)
return endpoint
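# For example, canonicalize_endpoint('sms', '+1 (555) 010-0123') returns
# '+15550100123'; endpoints for all other protocols pass through unchanged.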
def get_info(connection, module, topic_arn):
name = module.params.get('name')
topic_type = module.params.get('topic_type')
state = module.params.get('state')
subscriptions = module.params.get('subscriptions')
purge_subscriptions = module.params.get('purge_subscriptions')
subscriptions_existing = module.params.get('subscriptions_existing', [])
subscriptions_deleted = module.params.get('subscriptions_deleted', [])
subscriptions_added = module.params.get('subscriptions_added', [])
topic_created = module.params.get('topic_created', False)
topic_deleted = module.params.get('topic_deleted', False)
attributes_set = module.params.get('attributes_set', [])
check_mode = module.check_mode
info = {
'name': name,
'topic_type': topic_type,
'state': state,
'subscriptions_new': subscriptions,
'subscriptions_existing': subscriptions_existing,
'subscriptions_deleted': subscriptions_deleted,
'subscriptions_added': subscriptions_added,
'subscriptions_purge': purge_subscriptions,
'check_mode': check_mode,
'topic_created': topic_created,
'topic_deleted': topic_deleted,
'attributes_set': attributes_set,
}
if state != 'absent':
if topic_arn in list_topics(connection, module):
info.update(camel_dict_to_snake_dict(connection.get_topic_attributes(TopicArn=topic_arn)['Attributes']))
info['delivery_policy'] = info.pop('effective_delivery_policy')
info['subscriptions'] = [camel_dict_to_snake_dict(sub) for sub in list_topic_subscriptions(connection, module, topic_arn)]
return info

View File

@@ -0,0 +1,345 @@
# Copyright: Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
from copy import deepcopy
from ansible_collections.amazon.aws.plugins.module_utils.core import is_boto3_error_code
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import AWSRetry
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import ansible_dict_to_boto3_filter_list
from ansible_collections.community.aws.plugins.module_utils.ec2 import BaseEc2Manager
from ansible_collections.community.aws.plugins.module_utils.ec2 import Boto3Mixin
from ansible_collections.community.aws.plugins.module_utils.ec2 import Ec2WaiterFactory
class TgwWaiterFactory(Ec2WaiterFactory):
@property
def _waiter_model_data(self):
data = super(TgwWaiterFactory, self)._waiter_model_data
# split the TGW waiters so we can keep them close to everything else.
tgw_data = dict(
tgw_attachment_available=dict(
operation='DescribeTransitGatewayAttachments',
delay=5, maxAttempts=120,
acceptors=[
dict(state='success', matcher='pathAll', expected='available', argument='TransitGatewayAttachments[].State'),
]
),
tgw_attachment_deleted=dict(
operation='DescribeTransitGatewayAttachments',
delay=5, maxAttempts=120,
acceptors=[
dict(state='retry', matcher='pathAll', expected='deleting', argument='TransitGatewayAttachments[].State'),
dict(state='success', matcher='pathAll', expected='deleted', argument='TransitGatewayAttachments[].State'),
dict(state='success', matcher='path', expected=True, argument='length(TransitGatewayAttachments[]) == `0`'),
dict(state='success', matcher='error', expected='InvalidRouteTableID.NotFound'),
]
),
)
data.update(tgw_data)
return data
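# Illustrative sketch (assumes the Ec2WaiterFactory interface exposes
# get_waiter()): the custom model above yields botocore-style waiters that
# poll DescribeTransitGatewayAttachments every 5s, up to 120 attempts, e.g.
#   factory = TgwWaiterFactory(module)
#   waiter = factory.get_waiter('tgw_attachment_available')
#   waiter.wait(TransitGatewayAttachmentIds=['tgw-attach-0123456789abcdef0'])
# ('tgw-attach-0123456789abcdef0' is a placeholder attachment ID.)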
class TGWAttachmentBoto3Mixin(Boto3Mixin):
def __init__(self, module, **kwargs):
self.tgw_waiter_factory = TgwWaiterFactory(module)
super(TGWAttachmentBoto3Mixin, self).__init__(module, **kwargs)
# Paginators can't be (easily) wrapped, so we wrap this method with the
# retry decorator instead - a failed fetch retries from the start, which
# is better than simply giving up.
@AWSRetry.jittered_backoff()
def _paginated_describe_transit_gateway_vpc_attachments(self, **params):
paginator = self.client.get_paginator('describe_transit_gateway_vpc_attachments')
return paginator.paginate(**params).build_full_result()
@Boto3Mixin.aws_error_handler('describe transit gateway attachments')
def _describe_vpc_attachments(self, **params):
result = self._paginated_describe_transit_gateway_vpc_attachments(**params)
return result.get('TransitGatewayVpcAttachments', None)
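# Illustrative note: a throttling failure on a later page re-runs the whole
# paginated fetch from page one (the decorator wraps the full call), which
# costs extra API requests but never returns a partially fetched listing.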
@Boto3Mixin.aws_error_handler('create transit gateway attachment')
def _create_vpc_attachment(self, **params):
result = self.client.create_transit_gateway_vpc_attachment(aws_retry=True, **params)
return result.get('TransitGatewayVpcAttachment', None)
@Boto3Mixin.aws_error_handler('modify transit gateway attachment')
def _modify_vpc_attachment(self, **params):
result = self.client.modify_transit_gateway_vpc_attachment(aws_retry=True, **params)
return result.get('TransitGatewayVpcAttachment', None)
@Boto3Mixin.aws_error_handler('delete transit gateway attachment')
def _delete_vpc_attachment(self, **params):
try:
result = self.client.delete_transit_gateway_vpc_attachment(aws_retry=True, **params)
except is_boto3_error_code('ResourceNotFoundException'):
return None
return result.get('TransitGatewayVpcAttachment', None)
@Boto3Mixin.aws_error_handler('transit gateway attachment to finish deleting')
def _wait_tgw_attachment_deleted(self, **params):
waiter = self.tgw_waiter_factory.get_waiter('tgw_attachment_deleted')
waiter.wait(**params)
@Boto3Mixin.aws_error_handler('transit gateway attachment to become available')
def _wait_tgw_attachment_available(self, **params):
waiter = self.tgw_waiter_factory.get_waiter('tgw_attachment_available')
waiter.wait(**params)
def _normalize_tgw_attachment(self, attachment):
return self._normalize_boto3_resource(attachment)
def _get_tgw_vpc_attachment(self, **params):
# Only for use with a single attachment; use _describe_vpc_attachments
# for multiple attachments.
attachments = self._describe_vpc_attachments(**params)
if not attachments:
return None
attachment = attachments[0]
return attachment
class BaseTGWManager(BaseEc2Manager):
@Boto3Mixin.aws_error_handler('connect to AWS')
def _create_client(self, client_name='ec2'):
if client_name == 'ec2':
error_codes = ['IncorrectState']
else:
error_codes = []
retry_decorator = AWSRetry.jittered_backoff(
catch_extra_error_codes=error_codes,
)
client = self.module.client(client_name, retry_decorator=retry_decorator)
return client
class TransitGatewayVpcAttachmentManager(TGWAttachmentBoto3Mixin, BaseTGWManager):
TAG_RESOURCE_TYPE = 'transit-gateway-attachment'
def __init__(self, module, id=None):
self._subnet_updates = dict()
super(TransitGatewayVpcAttachmentManager, self).__init__(module=module, id=id)
def _get_id_params(self, id=None, id_list=False):
if not id:
id = self.resource_id
if not id:
# Users should never see this, but let's cover ourselves
self.module.fail_json(msg='Attachment identifier parameter missing')
if id_list:
return dict(TransitGatewayAttachmentIds=[id])
return dict(TransitGatewayAttachmentId=id)
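# Illustrative example (not part of the module): Describe* APIs take a list
# of IDs while Modify/Delete take a single ID, hence the id_list switch, e.g.
#   manager._get_id_params(id='tgw-attach-0123', id_list=True)
#   # -> {'TransitGatewayAttachmentIds': ['tgw-attach-0123']}
#   manager._get_id_params(id='tgw-attach-0123', id_list=False)
#   # -> {'TransitGatewayAttachmentId': 'tgw-attach-0123'}
# ('tgw-attach-0123' is a placeholder attachment ID.)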
def _extra_error_output(self):
output = super(TransitGatewayVpcAttachmentManager, self)._extra_error_output()
if self.resource_id:
output['TransitGatewayAttachmentId'] = self.resource_id
return output
def _filter_immutable_resource_attributes(self, resource):
resource = super(TransitGatewayVpcAttachmentManager, self)._filter_immutable_resource_attributes(resource)
resource.pop('TransitGatewayId', None)
resource.pop('VpcId', None)
resource.pop('VpcOwnerId', None)
resource.pop('State', None)
resource.pop('SubnetIds', None)
resource.pop('CreationTime', None)
resource.pop('Tags', None)
return resource
def _set_option(self, name, value):
if value is None:
return False
# For now VPC Attachment options are all enable/disable
if value:
value = 'enable'
else:
value = 'disable'
options = deepcopy(self._preupdate_resource.get('Options', dict()))
options.update(self._resource_updates.get('Options', dict()))
options[name] = value
return self._set_resource_value('Options', options)
def set_dns_support(self, value):
return self._set_option('DnsSupport', value)
def set_ipv6_support(self, value):
return self._set_option('Ipv6Support', value)
def set_appliance_mode_support(self, value):
return self._set_option('ApplianceModeSupport', value)
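# Illustrative example (not part of the module): a None value is a no-op,
# while booleans are mapped onto the API's 'enable'/'disable' strings, e.g.
#   manager.set_dns_support(True)   # stages Options={'DnsSupport': 'enable'}
#   manager.set_ipv6_support(None)  # -> False, nothing staged
# on top of whatever Options the pre-update resource already carried.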
def set_transit_gateway(self, tgw_id):
return self._set_resource_value('TransitGatewayId', tgw_id)
def set_vpc(self, vpc_id):
return self._set_resource_value('VpcId', vpc_id)
def set_subnets(self, subnets=None, purge=True):
if subnets is None:
return False
current_subnets = set(self._preupdate_resource.get('SubnetIds', []))
desired_subnets = set(subnets)
if not purge:
desired_subnets = desired_subnets.union(current_subnets)
# We'll pull the VPC ID from the subnets, no point asking for
# information we 'know'.
subnet_details = self._describe_subnets(SubnetIds=list(desired_subnets))
vpc_id = self.subnets_to_vpc(desired_subnets, subnet_details)
self._set_resource_value('VpcId', vpc_id, immutable=True)
# Only one subnet per-AZ is permitted
azs = [s.get('AvailabilityZoneId') for s in subnet_details]
if len(azs) != len(set(azs)):
self.module.fail_json(
msg='Only one attachment subnet per availability zone may be set.',
availability_zones=azs, subnets=subnet_details)
subnets_to_add = list(desired_subnets.difference(current_subnets))
subnets_to_remove = list(current_subnets.difference(desired_subnets))
if not subnets_to_remove and not subnets_to_add:
return False
self._subnet_updates = dict(add=subnets_to_add, remove=subnets_to_remove)
self._set_resource_value('SubnetIds', list(desired_subnets))
return True
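# Illustrative example (not part of the module): with purge=False the
# requested subnets are merged into the current set rather than replacing
# it, e.g. a current set of {'subnet-aaa'} plus
#   manager.set_subnets(['subnet-bbb'], purge=False)
# targets both subnets, provided they share a VPC and sit in different
# availability zones ('subnet-aaa'/'subnet-bbb' are placeholders).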
def subnets_to_vpc(self, subnets, subnet_details=None):
if not subnets:
return None
if subnet_details is None:
subnet_details = self._describe_subnets(SubnetIds=list(subnets))
vpcs = [s.get('VpcId') for s in subnet_details]
if len(set(vpcs)) > 1:
self.module.fail_json(
msg='Attachment subnets may only be in one VPC, multiple VPCs found',
vpcs=list(set(vpcs)), subnets=subnet_details)
return vpcs[0]
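# Illustrative example (not part of the module): the VPC is inferred from
# the subnet descriptions rather than requested as a parameter, e.g.
#   details = [{'SubnetId': 'subnet-aaa', 'VpcId': 'vpc-123'},
#              {'SubnetId': 'subnet-bbb', 'VpcId': 'vpc-123'}]
#   manager.subnets_to_vpc(['subnet-aaa', 'subnet-bbb'], details)
#   # -> 'vpc-123'; subnets spanning multiple VPCs fail the module instead.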
def _do_deletion_wait(self, id=None, **params):
all_params = self._get_id_params(id=id, id_list=True)
all_params.update(**params)
return self._wait_tgw_attachment_deleted(**all_params)
def _do_creation_wait(self, id=None, **params):
all_params = self._get_id_params(id=id, id_list=True)
all_params.update(**params)
return self._wait_tgw_attachment_available(**all_params)
def _do_update_wait(self, id=None, **params):
all_params = self._get_id_params(id=id, id_list=True)
all_params.update(**params)
return self._wait_tgw_attachment_available(**all_params)
def _do_create_resource(self):
params = self._merge_resource_changes(filter_immutable=False, creation=True)
response = self._create_vpc_attachment(**params)
if response:
self.resource_id = response.get('TransitGatewayAttachmentId', None)
return response
def _do_update_resource(self):
if self._preupdate_resource.get('State', None) == 'pending':
# Resources generally don't like it if you try to update before creation
# is complete. If things are in a 'pending' state they'll often throw
# exceptions.
self._wait_for_creation()
elif self._preupdate_resource.get('State', None) == 'deleting':
self.module.fail_json(msg='Deletion in progress, unable to update',
attachments=[self.original_resource])
updates = self._filter_immutable_resource_attributes(self._resource_updates)
subnets_to_add = self._subnet_updates.get('add', [])
subnets_to_remove = self._subnet_updates.get('remove', [])
if subnets_to_add:
updates['AddSubnetIds'] = subnets_to_add
if subnets_to_remove:
updates['RemoveSubnetIds'] = subnets_to_remove
if not updates:
return False
if self.module.check_mode:
return True
updates.update(self._get_id_params(id_list=False))
self._modify_vpc_attachment(**updates)
return True
def get_resource(self):
return self.get_attachment()
def delete(self, id=None):
if id:
id_params = self._get_id_params(id=id, id_list=True)
result = self._get_tgw_vpc_attachment(**id_params)
else:
result = self._preupdate_resource
self.updated_resource = dict()
if not result:
return False
if result.get('State') == 'deleting':
self._wait_for_deletion()
return False
if self.module.check_mode:
self.changed = True
return True
id_params = self._get_id_params(id=id, id_list=False)
result = self._delete_vpc_attachment(**id_params)
self.changed |= bool(result)
self._wait_for_deletion()
return bool(result)
def list(self, filters=None, id=None):
params = dict()
if id:
params['TransitGatewayAttachmentIds'] = [id]
if filters:
params['Filters'] = ansible_dict_to_boto3_filter_list(filters)
attachments = self._describe_vpc_attachments(**params)
if not attachments:
return list()
return [self._normalize_tgw_attachment(a) for a in attachments]
def get_attachment(self, id=None):
# Describe* calls need the ID wrapped in a list
id_params = self._get_id_params(id=id, id_list=True)
result = self._get_tgw_vpc_attachment(**id_params)
if not result:
return None
if not id:
self._preupdate_resource = deepcopy(result)
attachment = self._normalize_tgw_attachment(result)
return attachment
def _normalize_resource(self, resource):
return self._normalize_tgw_attachment(resource)

View File

@@ -0,0 +1,206 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
try:
from botocore.exceptions import ClientError, BotoCoreError
except ImportError:
pass # caught by AnsibleAWSModule
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import AWSRetry
from ansible_collections.amazon.aws.plugins.module_utils.tagging import ansible_dict_to_boto3_tag_list
from ansible_collections.amazon.aws.plugins.module_utils.tagging import boto3_tag_list_to_ansible_dict
from ansible_collections.amazon.aws.plugins.module_utils.tagging import compare_aws_tags
@AWSRetry.jittered_backoff()
def _list_tags(wafv2, arn, fail_json_aws, next_marker=None):
params = dict(ResourceARN=arn)
if next_marker:
params['NextMarker'] = next_marker
try:
return wafv2.list_tags_for_resource(**params)
except (BotoCoreError, ClientError) as e:
fail_json_aws(e, msg="Failed to list wafv2 tags")
def describe_wafv2_tags(wafv2, arn, fail_json_aws):
next_marker = None
tag_list = []
# there is currently no paginator for wafv2
while True:
response = _list_tags(wafv2, arn, fail_json_aws, next_marker)
next_marker = response.get('NextMarker', None)
tag_info = response.get('TagInfoForResource', {})
tag_list.extend(tag_info.get('TagList', []))
if not next_marker:
break
return boto3_tag_list_to_ansible_dict(tag_list)
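# Illustrative example (not part of the module): the loop above implements
# NextMarker pagination by hand; a typical call site passes the client, the
# resource ARN and the module's failure callback, e.g.
#   tags = describe_wafv2_tags(wafv2_client, web_acl_arn, module.fail_json_aws)
#   # -> {'Env': 'prod', 'Team': 'sec'} (shape depends on the resource's tags)
# where wafv2_client and web_acl_arn are assumed to exist at the call site.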
def ensure_wafv2_tags(wafv2, arn, tags, purge_tags, fail_json_aws, check_mode):
if tags is None:
return False
current_tags = describe_wafv2_tags(wafv2, arn, fail_json_aws)
tags_to_add, tags_to_remove = compare_aws_tags(current_tags, tags, purge_tags)
if not tags_to_add and not tags_to_remove:
return False
if check_mode:
return True
if tags_to_add:
try:
boto3_tags = ansible_dict_to_boto3_tag_list(tags_to_add)
wafv2.tag_resource(ResourceARN=arn, Tags=boto3_tags)
except (BotoCoreError, ClientError) as e:
fail_json_aws(e, msg="Failed to add wafv2 tags")
if tags_to_remove:
try:
wafv2.untag_resource(ResourceARN=arn, TagKeys=tags_to_remove)
except (BotoCoreError, ClientError) as e:
fail_json_aws(e, msg="Failed to remove wafv2 tags")
return True
def wafv2_list_web_acls(wafv2, scope, fail_json_aws, nextmarker=None):
# there is currently no paginator for wafv2
req_obj = {
'Scope': scope,
'Limit': 100
}
if nextmarker:
req_obj['NextMarker'] = nextmarker
try:
response = wafv2.list_web_acls(**req_obj)
except (BotoCoreError, ClientError) as e:
fail_json_aws(e, msg="Failed to list wafv2 web acl")
if response.get('NextMarker'):
response['WebACLs'] += wafv2_list_web_acls(wafv2, scope, fail_json_aws, nextmarker=response.get('NextMarker')).get('WebACLs')
return response
def wafv2_list_rule_groups(wafv2, scope, fail_json_aws, nextmarker=None):
# there is currently no paginator for wafv2
req_obj = {
'Scope': scope,
'Limit': 100
}
if nextmarker:
req_obj['NextMarker'] = nextmarker
try:
response = wafv2.list_rule_groups(**req_obj)
except (BotoCoreError, ClientError) as e:
fail_json_aws(e, msg="Failed to list wafv2 rule group")
if response.get('NextMarker'):
response['RuleGroups'] += wafv2_list_rule_groups(wafv2, scope, fail_json_aws, nextmarker=response.get('NextMarker')).get('RuleGroups')
return response
def wafv2_snake_dict_to_camel_dict(a):
if not isinstance(a, dict):
return a
retval = {}
for item in a.keys():
if isinstance(a.get(item), dict):
if 'Ip' in item:
retval[item.replace('Ip', 'IP')] = wafv2_snake_dict_to_camel_dict(a.get(item))
elif 'Arn' == item:
retval['ARN'] = wafv2_snake_dict_to_camel_dict(a.get(item))
else:
retval[item] = wafv2_snake_dict_to_camel_dict(a.get(item))
elif isinstance(a.get(item), list):
retval[item] = []
for idx in range(len(a.get(item))):
retval[item].append(wafv2_snake_dict_to_camel_dict(a.get(item)[idx]))
elif 'Ip' in item:
retval[item.replace('Ip', 'IP')] = a.get(item)
elif 'Arn' == item:
retval['ARN'] = a.get(item)
else:
retval[item] = a.get(item)
return retval
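# Illustrative example (not part of the module): the renames only touch the
# AWS-specific casing quirks ('Ip' -> 'IP', 'Arn' -> 'ARN'), recursing into
# nested dicts and lists, e.g.
#   wafv2_snake_dict_to_camel_dict({'IpSetReferenceStatement': {'Arn': 'x'}})
#   # -> {'IPSetReferenceStatement': {'ARN': 'x'}}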
def nested_byte_values_to_strings(rule, keyname):
"""
currently valid nested byte values in statements array are
- OrStatement
- AndStatement
- NotStatement
"""
if rule.get('Statement', {}).get(keyname):
for idx in range(len(rule.get('Statement', {}).get(keyname, {}).get('Statements'))):
if rule['Statement'][keyname]['Statements'][idx].get('ByteMatchStatement'):
rule['Statement'][keyname]['Statements'][idx]['ByteMatchStatement']['SearchString'] = \
rule.get('Statement').get(keyname).get('Statements')[idx].get('ByteMatchStatement').get('SearchString').decode('utf-8')
return rule
def byte_values_to_strings_before_compare(rules):
for idx in range(len(rules)):
if rules[idx].get('Statement', {}).get('ByteMatchStatement', {}).get('SearchString'):
rules[idx]['Statement']['ByteMatchStatement']['SearchString'] = \
rules[idx].get('Statement').get('ByteMatchStatement').get('SearchString').decode('utf-8')
else:
for statement in ['AndStatement', 'OrStatement', 'NotStatement']:
if rules[idx].get('Statement', {}).get(statement):
rules[idx] = nested_byte_values_to_strings(rules[idx], statement)
return rules
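# Illustrative note: botocore returns ByteMatchStatement SearchString values
# as bytes while user-supplied rules carry str, so decoding before comparing
# avoids spurious diffs, e.g. b'admin' != 'admin' but 'admin' == 'admin'.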
def compare_priority_rules(existing_rules, requested_rules, purge_rules, state):
diff = False
existing_rules = sorted(existing_rules, key=lambda k: k['Priority'])
existing_rules = byte_values_to_strings_before_compare(existing_rules)
requested_rules = sorted(requested_rules, key=lambda k: k['Priority'])
if purge_rules and state == 'present':
merged_rules = requested_rules
if len(existing_rules) == len(requested_rules):
for idx in range(len(existing_rules)):
if existing_rules[idx] != requested_rules[idx]:
diff = True
break
else:
diff = True
else:
# find rules that share a priority:
# * pop the matching rule from the existing rules
# * compare it against the requested rule
merged_rules = []
ex_idx_pop = []
for existing_idx in range(len(existing_rules)):
for requested_idx in range(len(requested_rules)):
if existing_rules[existing_idx].get('Priority') == requested_rules[requested_idx].get('Priority'):
if state == 'present':
ex_idx_pop.append(existing_idx)
if existing_rules[existing_idx] != requested_rules[requested_idx]:
diff = True
elif existing_rules[existing_idx] == requested_rules[requested_idx]:
ex_idx_pop.append(existing_idx)
diff = True
prev_count = len(existing_rules)
# pop from the highest index first so the earlier indices stay valid
for idx in sorted(ex_idx_pop, reverse=True):
existing_rules.pop(idx)
if state == 'present':
merged_rules = existing_rules + requested_rules
if len(merged_rules) != prev_count:
diff = True
else:
merged_rules = existing_rules
return diff, merged_rules
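# Illustrative example (not part of the module): with purge_rules=False and
# state='present', rules are matched by Priority; an existing rule whose
# priority is absent from the request is kept, e.g.
#   existing  = [rule_p0, rule_p1]          # priorities 0 and 1
#   requested = [rule_p1_changed]           # priority 1, new settings
#   compare_priority_rules(existing, requested, False, 'present')
#   # -> (True, [rule_p0, rule_p1_changed])
# (rule_p0/rule_p1/rule_p1_changed are placeholder rule dicts.)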

Some files were not shown because too many files have changed in this diff