Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions awswrangler/athena/_executions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
_QUERY_FINAL_STATES,
_QUERY_WAIT_POLLING_DELAY,
_apply_formatter,
_get_workgroup_config,
_get_default_workgroup_config,
_start_query_execution,
_WorkGroupConfig,
)
Expand Down Expand Up @@ -149,7 +149,7 @@ def start_query_execution(
query_execution_id = cache_info.query_execution_id
_logger.debug("Valid cache found. Retrieving...")
else:
wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup)
wg_config: _WorkGroupConfig = _get_default_workgroup_config()
query_execution_id = _start_query_execution(
sql=sql,
wg_config=wg_config,
Expand Down
20 changes: 10 additions & 10 deletions awswrangler/athena/_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
_apply_formatter,
_apply_query_metadata,
_empty_dataframe_response,
_get_default_workgroup_config,
_get_query_metadata,
_get_s3_output,
_get_workgroup_config,
Expand Down Expand Up @@ -431,9 +432,8 @@ def _resolve_query_without_cache_regular(
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
client_request_token: str | None = None,
) -> pd.DataFrame | Iterator[pd.DataFrame]:
wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup)
s3_output = _get_s3_output(s3_output=s3_output, wg_config=wg_config, boto3_session=boto3_session)
s3_output = s3_output[:-1] if s3_output[-1] == "/" else s3_output
wg_config: _WorkGroupConfig = _get_default_workgroup_config()
s3_output = s3_output[:-1] if s3_output and s3_output[-1] == "/" else s3_output
_logger.debug("Executing sql: %s", sql)
query_id: str = _start_query_execution(
sql=sql,
Expand Down Expand Up @@ -597,13 +597,13 @@ def _unload(
athena_query_wait_polling_delay: float,
execution_params: list[str] | None,
) -> _QueryMetadata:
wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup)
wg_config: _WorkGroupConfig = _get_workgroup_config(workgroup=workgroup, session=boto3_session)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UNLOAD requires a TO location, so I am keeping the previous implementation here, but I added an exception for the case where no S3 output location is available from either the user or the workgroup.

s3_output: str = _get_s3_output(s3_output=path, wg_config=wg_config, boto3_session=boto3_session)
s3_output = s3_output[:-1] if s3_output[-1] == "/" else s3_output
# Athena does not enforce a Query Result Location for UNLOAD. Thus, the workgroup output location
# is only used if no path is supplied.
if not path:
path = s3_output
s3_output = s3_output[:-1] if s3_output and s3_output[-1] == "/" else s3_output
if not s3_output:
raise exceptions.InvalidArgumentValue(
"Output S3 location is required for UNLOAD, either as the path argument or as a workgroup configuration"
)

# Set UNLOAD parameters
unload_parameters = f" format='{file_format}'"
Expand All @@ -614,7 +614,7 @@ def _unload(
if partitioned_by:
unload_parameters += f" , partitioned_by=ARRAY{partitioned_by}"

sql = f"UNLOAD ({sql}) TO '{path}' WITH ({unload_parameters})"
sql = f"UNLOAD ({sql}) TO '{s3_output}' WITH ({unload_parameters})"
_logger.debug("Executing unload query: %s", sql)
try:
query_id: str = _start_query_execution(
Expand Down
20 changes: 14 additions & 6 deletions awswrangler/athena/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,20 @@ def _start_query_execution(
args: dict[str, Any] = {"QueryString": sql}

# s3_output
args["ResultConfiguration"] = {
"OutputLocation": _get_s3_output(s3_output=s3_output, wg_config=wg_config, boto3_session=boto3_session)
}
if s3_output:
args["ResultConfiguration"] = {"OutputLocation": s3_output}

# encryption
if wg_config.enforced is True:
if "ResultConfiguration" not in args:
args["ResultConfiguration"] = {}
if wg_config.encryption is not None:
args["ResultConfiguration"]["EncryptionConfiguration"] = {"EncryptionOption": wg_config.encryption}
if wg_config.kms_key is not None:
args["ResultConfiguration"]["EncryptionConfiguration"]["KmsKey"] = wg_config.kms_key
elif encryption is not None:
if "ResultConfiguration" not in args:
args["ResultConfiguration"] = {}
args["ResultConfiguration"]["EncryptionConfiguration"] = {"EncryptionOption": encryption}
if kms_key is not None:
args["ResultConfiguration"]["EncryptionConfiguration"]["KmsKey"] = kms_key
Expand Down Expand Up @@ -140,6 +143,12 @@ def _start_query_execution(
return response["QueryExecutionId"]


def _get_default_workgroup_config() -> _WorkGroupConfig:
    """Build a workgroup config with nothing enforced and every optional setting unset.

    Returns
    -------
    _WorkGroupConfig
        A config with ``enforced=False`` and ``s3_output``, ``encryption`` and
        ``kms_key`` all set to ``None``.
    """
    default_config: _WorkGroupConfig = _WorkGroupConfig(
        enforced=False,
        s3_output=None,
        encryption=None,
        kms_key=None,
    )
    # Log the generated config so debug output shows which settings are in effect.
    _logger.debug("Default workgroup config:\n%s", default_config)
    return default_config


def _get_workgroup_config(session: boto3.Session | None = None, workgroup: str = "primary") -> _WorkGroupConfig:
enforced: bool
wg_s3_output: str | None
Expand Down Expand Up @@ -783,9 +792,8 @@ def create_ctas_table(

fully_qualified_name = f'"{ctas_database}"."{ctas_table}"'

wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup)
s3_output = _get_s3_output(s3_output=s3_output, wg_config=wg_config, boto3_session=boto3_session)
s3_output = s3_output[:-1] if s3_output[-1] == "/" else s3_output
wg_config: _WorkGroupConfig = _get_default_workgroup_config()
s3_output = s3_output[:-1] if s3_output and s3_output[-1] == "/" else s3_output
# If the workgroup enforces an external location, then it overrides the user supplied argument
external_location_str: str = (
f" external_location = '{s3_output}/{ctas_table}',\n" if (not wg_config.enforced) and (s3_output) else ""
Expand Down
8 changes: 4 additions & 4 deletions awswrangler/athena/_write_iceberg.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from awswrangler._config import apply_configs
from awswrangler.athena._executions import wait_query
from awswrangler.athena._utils import (
_get_workgroup_config,
_get_default_workgroup_config,
_start_query_execution,
_WorkGroupConfig,
)
Expand Down Expand Up @@ -303,7 +303,7 @@ def _merge_iceberg(
None

"""
wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup)
wg_config: _WorkGroupConfig = _get_default_workgroup_config()

sql_statement: str
if merge_cols:
Expand Down Expand Up @@ -488,7 +488,7 @@ def to_iceberg( # noqa: PLR0913
... )

"""
wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup)
wg_config: _WorkGroupConfig = _get_default_workgroup_config()
temp_table: str = f"temp_table_{uuid.uuid4().hex}"

_validate_args(
Expand Down Expand Up @@ -743,7 +743,7 @@ def delete_from_iceberg_table(
if not merge_cols:
raise exceptions.InvalidArgumentValue("Merge columns must be specified.")

wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup)
wg_config: _WorkGroupConfig = _get_default_workgroup_config()
temp_table: str = f"temp_table_{uuid.uuid4().hex}"

if not temp_path and not wg_config.s3_output:
Expand Down
Loading