diff --git a/docs/07-network-policy-database-overview.md b/docs/07-network-policy-database-overview.md
index 67b3285dc..00fee4427 100644
--- a/docs/07-network-policy-database-overview.md
+++ b/docs/07-network-policy-database-overview.md
@@ -26,10 +26,10 @@ This document is intended to help people who are poking around the `network_poli
## How to access an internal database
1. Bosh ssh onto the VM where the `policy-server` is running. You can figure out what machine by running `bosh is --ps | grep policy-server`.
-2. Grab the mysql config.
+2. Grab the mysql config.
```
$ cat /var/vcap/jobs/policy-server/config/policy-server.json | grep \"database\" -A 11
-
+
"database": {
"type": "mysql",
"user": "USER",
@@ -43,7 +43,7 @@ This document is intended to help people who are poking around the `network_poli
"skip_hostname_validation": false
},
```
-
+
3. Bosh ssh onto the database VM.
4. Connect to the mysql instance.
```
@@ -61,7 +61,11 @@ Below are all of the tables in the `network_policy` database.
| gorp_migrations | Record of which migrations have been run. |
| groups | List of all apps that are either the source or destination of a network policy. |
| policies | List of source apps and destination metadata for network policies. |
-
+| policies_info | A single row recording the last update time of any network policy, used to save on DB queries from vxlan-policy-agent. |
+| security_groups | Lists all security groups defined in CAPI. This is populated by policy-server-asg-syncer, and is not a source of truth. |
+| security_groups_info | A single row recording the last update time of any security group data, used to save on DB queries from vxlan-policy-agent. |
+| running_security_groups_spaces | A join table associating security groups with the spaces they are bound to for running lifecycle workloads. |
+| staging_security_groups_spaces | A join table associating security groups with the spaces they are bound to for staging lifecycle workloads. |
+
The following tables were related to dynamic egress, which has been removed
from the codebase. These tables should no longer present in your database as of
@@ -80,7 +84,7 @@ v3.6.0.
## Network Policy Related Tables
-There are three tables related to cf networking policies: policies, groups, and destinations.
+There are four tables related to cf networking policies: policies, groups, destinations, and policies_info.
### Groups
@@ -152,11 +156,25 @@ mysql> describe policies;
| group_id | This is the id for the group table entry that represents the source app. |
| destination_id | This is the id for the destinations table entry that represents the destination metadata. |
+### policies_info
+This table contains a single row whose single value is the last-updated timestamp of the policy data,
+allowing the VXLAN Policy Agent to short-circuit its sync loop when no changes have been made.
+
+```
+mysql> describe policies_info;
++--------------+--------------+------+-----+----------------------+-------------------+
+| Field | Type | Null | Key | Default | Extra |
++--------------+--------------+------+-----+----------------------+-------------------+
+| id | int | NO | PRI | NULL | auto_increment |
+| last_updated | timestamp(6) | NO | | CURRENT_TIMESTAMP(6) | DEFAULT_GENERATED |
++--------------+--------------+------+-----+----------------------+-------------------+
+```
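+
+The VXLAN Policy Agent can compare this timestamp against the value it saw on its previous poll and
+skip the full policy fetch when nothing has changed. A purely illustrative query:
+
+```
+mysql> SELECT last_updated FROM policies_info;
+```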
+
## Networking Policy Example
-In this example:
-* There is a network policy from AppA to AppB.
+In this example:
+* There is a network policy from AppA to AppB.
* AppA has guid `2ffe4b0f-b03c-48bb-a4fa-bf22657d34a2`
* AppB has guid `5346072e-7265-45f9-b70a-80c42e3f13ae`
@@ -181,14 +199,117 @@ mysql> select * from policies; mysql> select * from destinations;
| 2 | 5346072e-7265-45f9-b70a-80c42e3f13ae | app <--+
| 3 | NULL | app |
+----+--------------------------------------+------+
+```
+
+## Security Group Related Tables
+
+There are four tables storing information about security groups: security_groups, running_security_groups_spaces,
+staging_security_groups_spaces, and security_groups_info.
+
+
+### security_groups
+This table stores a copy of all security groups found in CAPI, so vxlan-policy-agent can query
+policy-server-internal for this information, rather than overwhelm CAPI with requests. Its data is
+synced and updated via the policy-server-asg-syncer process, and is not a source of truth for ASG data.
+
+```
+mysql> describe security_groups;
++-----------------+--------------+------+-----+---------+----------------+
+| Field | Type | Null | Key | Default | Extra |
++-----------------+--------------+------+-----+---------+----------------+
+| id | bigint | NO | PRI | NULL | auto_increment |
+| guid | varchar(36) | NO | UNI | NULL | |
+| name | varchar(255) | NO | | NULL | |
+| rules | mediumtext | YES | | NULL | |
+| staging_default | tinyint(1) | YES | MUL | 0 | |
+| running_default | tinyint(1) | YES | MUL | 0 | |
+| staging_spaces | json | YES | | NULL | |
+| running_spaces | json | YES | | NULL | |
++-----------------+--------------+------+-----+---------+----------------+
+```
+| Field | Note |
+|---|---|
+| id | An internal id for each record |
+| guid | The CAPI GUID of the security group |
+| name | The name of the security group as it appears in CAPI |
+| hash | A SHA256 hash of the ASG data, used to check whether records need updating during policy-server-asg-syncer polls |
+| rules | The rules (in JSON) associated with the ASG defined in CAPI |
+| staging_default | Whether or not this is a globally bound security group for `staging` lifecycles |
+| running_default | Whether or not this is a globally bound security group for `running` lifecycles |
+| staging_spaces | A JSON list of CAPI guids for all spaces this security group is bound to for the `staging` lifecycle. This duplicates data in the `staging_security_groups_spaces` table, but because it is already JSON it can be returned directly when serving queries from VXLAN Policy Agent, while filtering is still done via the `staging_security_groups_spaces` table. |
+| running_spaces | A JSON list of CAPI guids for all spaces this security group is bound to for the `running` lifecycle. This duplicates data in the `running_security_groups_spaces` table, but because it is already JSON it can be returned directly when serving queries from VXLAN Policy Agent, while filtering is still done via the `running_security_groups_spaces` table. |
+
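+As a purely illustrative example, the globally-bound ASGs can be listed with:
+
+```
+mysql> SELECT guid, name FROM security_groups WHERE staging_default = true OR running_default = true;
+```
+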
+### running_security_groups_spaces
+This is a join table that enables faster querying of security_groups when filtering by running space
+guids. It is used by the BySpaceGuids() store function when returning lists of ASGs for a given set of
+space guids. Querying the space associations directly in the security_groups table results in
+unindexed queries and giant full-table scans, which topple databases with thousands of ASGs. This
+table enables indexed lookups from space guids to the security groups they are bound to, drastically
+speeding up query times for VXLAN Policy Agent requests.
+
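+A minimal sketch of the kind of indexed lookup this table enables (not necessarily the exact statement
+the store issues):
+
+```
+mysql> SELECT security_group_guid FROM running_security_groups_spaces WHERE space_guid IN ('space-guid-1', 'space-guid-2');
+```
+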
+It is synced and updated via the policy-server-asg-syncer process, and is not a source of
+truth for ASG data.
```
+mysql> describe running_security_groups_spaces;
++---------------------+--------------+------+-----+---------+-------+
+| Field | Type | Null | Key | Default | Extra |
++---------------------+--------------+------+-----+---------+-------+
+| space_guid | varchar(255) | NO | PRI | NULL | |
+| security_group_guid | varchar(255) | NO | PRI | NULL | |
++---------------------+--------------+------+-----+---------+-------+
+```
+
+| Field | Note |
+|---|---|
+| space_guid | This value is the CAPI guid for the space bound to a given security group via the `running` app lifecycle |
+| security_group_guid | This value is the CAPI guid for the security group bound to a given space via the `running` app lifecycle |
+
+### staging_security_groups_spaces
+This is a join table that enables faster querying of security_groups when filtering by staging space
+guids. It is used by the BySpaceGuids() store function when returning lists of ASGs for a given set of
+space guids. Querying the space associations directly in the security_groups table results in
+unindexed queries and giant full-table scans, which topple databases with thousands of ASGs. This
+table enables indexed lookups from space guids to the security groups they are bound to, drastically
+speeding up query times for VXLAN Policy Agent requests.
+
+It is synced and updated via the policy-server-asg-syncer process, and is not a source of
+truth for ASG data.
+
+```
+mysql> describe staging_security_groups_spaces;
++---------------------+--------------+------+-----+---------+-------+
+| Field | Type | Null | Key | Default | Extra |
++---------------------+--------------+------+-----+---------+-------+
+| space_guid | varchar(255) | NO | PRI | NULL | |
+| security_group_guid | varchar(255) | NO | PRI | NULL | |
++---------------------+--------------+------+-----+---------+-------+
+```
+
+| Field | Note |
+|---|---|
+| space_guid | This value is the CAPI guid for the space bound to a given security group via the `staging` app lifecycle |
+| security_group_guid | This value is the CAPI guid for the security group bound to a given space via the `staging` app lifecycle |
+
+### security_groups_info
+This table contains a single row whose single value is the last-updated timestamp of the security
+group data, allowing the VXLAN Policy Agent to short-circuit its sync loop when no changes have been made.
+
+```
+mysql> describe security_groups_info;
++--------------+--------------+------+-----+----------------------+-------------------+
+| Field | Type | Null | Key | Default | Extra |
++--------------+--------------+------+-----+----------------------+-------------------+
+| id | int | NO | PRI | NULL | auto_increment |
+| last_updated | timestamp(6) | NO | | CURRENT_TIMESTAMP(6) | DEFAULT_GENERATED |
++--------------+--------------+------+-----+----------------------+-------------------+
+```
+
## Migration Related Tables
-There are two tables related to migraitons: gorp_migrations and gorp_lock.
+There are two tables related to migrations: gorp_migrations and gorp_lock.
### gorp_migrations
This table tracks what database migrations have been applied.
diff --git a/src/code.cloudfoundry.org/cf-pusher/cmd/multispace-pusher/main.go b/src/code.cloudfoundry.org/cf-pusher/cmd/multispace-pusher/main.go
index cdc756e41..f1d8a4547 100644
--- a/src/code.cloudfoundry.org/cf-pusher/cmd/multispace-pusher/main.go
+++ b/src/code.cloudfoundry.org/cf-pusher/cmd/multispace-pusher/main.go
@@ -29,6 +29,7 @@ type Config struct {
TotalSpaces int `json:"total_spaces"`
AppsPerSpace int `json:"apps_per_space"`
SkipASGCreation bool `json:"skip_asg_creation"`
+ SkipSpaceCreation bool `json:"skip_space_creation"`
}
type ConcurrentSpaceSetup struct {
@@ -60,9 +61,9 @@ func main() {
if !config.SkipASGCreation {
// Create global asgs
createGlobalASGs(config)
- // Create a bunch of bindable ASGs
- asgs = createASGs(config.TotalASGs-config.GlobalASGs, config.ASGSize, config.Prefix, globalAdapter)
}
+ // Create a bunch of bindable ASGs
+ asgs = createASGs(config.TotalASGs-config.GlobalASGs, config.ASGSize, config.Prefix, globalAdapter, config.SkipASGCreation)
spaces = createSpacesConcurrently(config)
orgName := fmt.Sprintf("%s-org", config.Prefix)
@@ -94,30 +95,34 @@ func createSpacesConcurrently(config Config) []string {
sem := make(chan bool, config.Concurrency)
var spaceNames []string
for i := 0; i < config.TotalSpaces; i++ {
- sem <- true
setup := generateConcurrentSpaceSetup(i, config)
spaceNames = append(spaceNames, setup.OrgSpaceCreator.Space)
- go func(s *ConcurrentSpaceSetup, c Config, index int) {
- defer func() { <-sem }()
-
- // Connect to the api with this adapter
- if err := s.ApiConnector.Connect(); err != nil {
- log.Fatalf("connecting to api: %s", err)
- }
+ if !config.SkipSpaceCreation {
+ sem <- true
+ go func(s *ConcurrentSpaceSetup, c Config, index int) {
+ defer func() { <-sem }()
+
+ // Connect to the api with this adapter
+ if err := s.ApiConnector.Connect(); err != nil {
+ log.Fatalf("connecting to api: %s", err)
+ }
- // Create and target the space
- if err := s.OrgSpaceCreator.Create(); err != nil {
- log.Fatalf("creating org and space: %s", err)
- }
+ // Create and target the space
+ if err := s.OrgSpaceCreator.Create(); err != nil {
+ log.Fatalf("creating org and space: %s", err)
+ }
- // Push apps for this space
- if err := s.AppPusher.Push(); err != nil {
- log.Printf("Got an error while pushing proxy apps: %s", err)
- }
- }(setup, config, i)
+ // Push apps for this space
+ if err := s.AppPusher.Push(); err != nil {
+ log.Printf("Got an error while pushing proxy apps: %s", err)
+ }
+ }(setup, config, i)
+ }
}
- for i := 0; i < cap(sem); i++ {
- sem <- true
+ if !config.SkipSpaceCreation {
+ for i := 0; i < cap(sem); i++ {
+ sem <- true
+ }
}
return spaceNames
}
@@ -268,27 +273,29 @@ func bindASGToThisSpace(asg string, orgName, spaceName string, adapter *cf_cli_a
}
}
-func createASGs(howMany, asgSize int, prefix string, adapter *cf_cli_adapter.Adapter) []string {
+func createASGs(howMany, asgSize int, prefix string, adapter *cf_cli_adapter.Adapter, skipASGCreation bool) []string {
var asgNames []string
for i := 0; i < howMany; i++ {
asgName := fmt.Sprintf("%s-many-%d-asg", prefix, i)
asgNames = append(asgNames, asgName)
- asgContent := testsupport.BuildASG(asgSize)
- asgFile, err := testsupport.CreateTempFile(asgContent)
- if err != nil {
- log.Fatalf("creating asg file: %s", err)
- }
-
- // check ASG and create if not OK
- asgChecker := cf_command.ASGChecker{Adapter: adapter}
- asgErr := asgChecker.CheckASG(asgName, asgContent)
- if asgErr != nil {
- // install ASG
- if err := adapter.DeleteSecurityGroup(asgName); err != nil {
- log.Fatalf("deleting security group: %s", err)
+ if !skipASGCreation {
+ asgContent := testsupport.BuildASG(asgSize)
+ asgFile, err := testsupport.CreateTempFile(asgContent)
+ if err != nil {
+ log.Fatalf("creating asg file: %s", err)
}
- if err := adapter.CreateSecurityGroup(asgName, asgFile); err != nil {
- log.Fatalf("creating security group: %s", err)
+
+ // check ASG and create if not OK
+ asgChecker := cf_command.ASGChecker{Adapter: adapter}
+ asgErr := asgChecker.CheckASG(asgName, asgContent)
+ if asgErr != nil {
+ // install ASG
+ if err := adapter.DeleteSecurityGroup(asgName); err != nil {
+ log.Fatalf("deleting security group: %s", err)
+ }
+ if err := adapter.CreateSecurityGroup(asgName, asgFile); err != nil {
+ log.Fatalf("creating security group: %s", err)
+ }
}
}
}
diff --git a/src/code.cloudfoundry.org/policy-server/cmd/policy-server-asg-syncer/main.go b/src/code.cloudfoundry.org/policy-server/cmd/policy-server-asg-syncer/main.go
index 12314a6e9..f8e1c6cff 100644
--- a/src/code.cloudfoundry.org/policy-server/cmd/policy-server-asg-syncer/main.go
+++ b/src/code.cloudfoundry.org/policy-server/cmd/policy-server-asg-syncer/main.go
@@ -73,7 +73,8 @@ func main() {
}
securityGroupsStore := &store.SGStore{
- Conn: connectionPool,
+ Logger: logger.Session("security-groups-store"),
+ Conn: connectionPool,
}
metricsSender := &metrics.MetricsSender{
diff --git a/src/code.cloudfoundry.org/policy-server/cmd/policy-server-internal/main.go b/src/code.cloudfoundry.org/policy-server/cmd/policy-server-internal/main.go
index ee0e70c14..06f88d329 100644
--- a/src/code.cloudfoundry.org/policy-server/cmd/policy-server-internal/main.go
+++ b/src/code.cloudfoundry.org/policy-server/cmd/policy-server-internal/main.go
@@ -80,7 +80,8 @@ func main() {
)
securityGroupsStore := &store.SGStore{
- Conn: connectionPool,
+ Logger: logger.Session("security-groups-store"),
+ Conn: connectionPool,
}
tagDataStore := store.NewTagStore(connectionPool, &store.GroupTable{}, conf.TagLength)
diff --git a/src/code.cloudfoundry.org/policy-server/store/helpers/helpers.go b/src/code.cloudfoundry.org/policy-server/store/helpers/helpers.go
index 053f0fd28..4bcaf6a4e 100644
--- a/src/code.cloudfoundry.org/policy-server/store/helpers/helpers.go
+++ b/src/code.cloudfoundry.org/policy-server/store/helpers/helpers.go
@@ -35,19 +35,3 @@ func RebindForSQLDialect(query, dialect string) string {
}
return strings.Join(strParts, "")
}
-
-func RebindForSQLDialectAndMark(query, dialect, mark string) string {
- if dialect != Postgres && dialect != MySQL {
- panic(fmt.Sprintf("Unrecognized DB dialect '%s'", dialect))
- }
-
- if dialect == MySQL {
- return strings.ReplaceAll(query, mark, "?")
- }
-
- strParts := strings.Split(query, mark)
- for i := 1; i < len(strParts); i++ {
- strParts[i-1] = fmt.Sprintf("%s$%d", strParts[i-1], i)
- }
- return strings.Join(strParts, "")
-}
diff --git a/src/code.cloudfoundry.org/policy-server/store/migrations/migrations.go b/src/code.cloudfoundry.org/policy-server/store/migrations/migrations.go
index 70a24287e..d5d5f6f65 100644
--- a/src/code.cloudfoundry.org/policy-server/store/migrations/migrations.go
+++ b/src/code.cloudfoundry.org/policy-server/store/migrations/migrations.go
@@ -472,4 +472,40 @@ var MigrationsToPerform = PolicyServerMigrations{
Id: "91",
Up: migration_v0091,
},
+ PolicyServerMigration{
+ Id: "92",
+ Up: migration_v0092,
+ },
+ PolicyServerMigration{
+ Id: "93",
+ Up: migration_v0093,
+ },
+ PolicyServerMigration{
+ Id: "94",
+ Up: migration_v0094,
+ },
+ PolicyServerMigration{
+ Id: "95",
+ Up: migration_v0095,
+ },
+ PolicyServerMigration{
+ Id: "96",
+ Up: migration_v0096,
+ },
+ PolicyServerMigration{
+ Id: "97",
+ Up: migration_v0097,
+ },
+ PolicyServerMigration{
+ Id: "98",
+ Up: migration_v0098,
+ },
+ PolicyServerMigration{
+ Id: "99",
+ Up: migration_v0099,
+ },
+ PolicyServerMigration{
+ Id: "100",
+ Up: migration_v0100,
+ },
}
diff --git a/src/code.cloudfoundry.org/policy-server/store/migrations/migrator_test.go b/src/code.cloudfoundry.org/policy-server/store/migrations/migrator_test.go
index c0624a50f..ee46c43d2 100644
--- a/src/code.cloudfoundry.org/policy-server/store/migrations/migrator_test.go
+++ b/src/code.cloudfoundry.org/policy-server/store/migrations/migrator_test.go
@@ -836,13 +836,13 @@ var _ = Describe("migrations", func() {
By("inserting new data")
_, err = realDb.Exec(realDb.RawConnection().Rebind(`
- INSERT INTO egress_policies (source_id, destination_id)
+ INSERT INTO egress_policies (source_id, destination_id)
VALUES (?, ?)`), terminalId, terminalId)
Expect(err).NotTo(HaveOccurred())
By("verifying new row exists")
rows, err = realDb.Query(`
- SELECT id FROM egress_policies
+ SELECT id FROM egress_policies
WHERE source_id=1 AND destination_id=1`)
Expect(err).NotTo(HaveOccurred())
Expect(scanCountRow(rows)).To(Equal(1))
@@ -850,7 +850,7 @@ var _ = Describe("migrations", func() {
It("constrains the terminal id to existing rows", func() {
_, err := realDb.Exec(`
- INSERT INTO egress_policies (source_id, destination_id)
+ INSERT INTO egress_policies (source_id, destination_id)
VALUES (42, 23)`)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("violates foreign key constraint"))
@@ -882,13 +882,13 @@ var _ = Describe("migrations", func() {
By("inserting new data")
_, err = realDb.Exec(`
- INSERT INTO ip_ranges (protocol, start_ip, end_ip, terminal_id)
+ INSERT INTO ip_ranges (protocol, start_ip, end_ip, terminal_id)
VALUES ('tcp', '1.2.3.4', '2.3.4.5', ?)`, terminalId)
Expect(err).NotTo(HaveOccurred())
By("verifying new row exists")
rows, err = realDb.Query(`
- SELECT id FROM ip_ranges
+ SELECT id FROM ip_ranges
WHERE protocol='tcp' AND start_ip='1.2.3.4' AND end_ip='2.3.4.5' AND terminal_id=1`)
Expect(err).NotTo(HaveOccurred())
Expect(scanCountRow(rows)).To(Equal(1))
@@ -896,7 +896,7 @@ var _ = Describe("migrations", func() {
It("constrains the policy id to existing rows", func() {
_, err := realDb.Exec(`
- INSERT INTO ip_ranges (protocol, start_ip, end_ip, terminal_id)
+ INSERT INTO ip_ranges (protocol, start_ip, end_ip, terminal_id)
VALUES ('tcp', '1.2.3.4', '2.3.4.5', 42)`)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("foreign key constraint fails"))
@@ -926,13 +926,13 @@ var _ = Describe("migrations", func() {
By("inserting new data")
_, err = realDb.Exec(realDb.RawConnection().Rebind(`
- INSERT INTO ip_ranges (protocol, start_ip, end_ip, terminal_id)
+ INSERT INTO ip_ranges (protocol, start_ip, end_ip, terminal_id)
VALUES ('tcp', '1.2.3.4', '2.3.4.5', ?)`), terminalId)
Expect(err).NotTo(HaveOccurred())
By("verifying new row exists")
rows, err = realDb.Query(`
- SELECT id FROM ip_ranges
+ SELECT id FROM ip_ranges
WHERE protocol='tcp' AND start_ip='1.2.3.4' AND end_ip='2.3.4.5' AND terminal_id=1`)
Expect(err).NotTo(HaveOccurred())
Expect(scanCountRow(rows)).To(Equal(1))
@@ -940,7 +940,7 @@ var _ = Describe("migrations", func() {
It("constrains the policy id to existing rows", func() {
_, err := realDb.Exec(`
- INSERT INTO ip_ranges (protocol, start_ip, end_ip, terminal_id)
+ INSERT INTO ip_ranges (protocol, start_ip, end_ip, terminal_id)
VALUES ('tcp','1.2.3.4','2.3.4.5',42)`)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("violates foreign key constraint"))
@@ -1011,7 +1011,7 @@ var _ = Describe("migrations", func() {
By("inserting new data")
_, err = realDb.Exec(realDb.RawConnection().Rebind(`
- INSERT INTO apps (terminal_id, app_guid)
+ INSERT INTO apps (terminal_id, app_guid)
VALUES (?,'an-app-guid')`), terminalId)
Expect(err).NotTo(HaveOccurred())
@@ -2165,6 +2165,220 @@ var _ = Describe("migrations", func() {
Expect(numMigrations).To(Equal(2))
})
})
+
+ Describe("v92-93 - adding new association tables", func() {
+
+ It("succeeds", func() {
+ migrateTo("93")
+ })
+ })
+ Describe("v94", func() {
+ It("adds a `hash` column to security_groups", func() {
+ migrateTo("94")
+ _, err := realDb.Query("SELECT hash from security_groups")
+ Expect(err).NotTo(HaveOccurred())
+ })
+ })
+ Describe("v95-100 - migrating json running_spaces/staging_spaces to join tables", func() {
+ BeforeEach(func() {
+ migrateTo("94")
+
+ var rows int
+ err := realDb.QueryRow("SELECT COUNT(*) FROM security_groups").Scan(&rows)
+ Expect(err).NotTo(HaveOccurred())
+ Expect(rows).To(Equal(0))
+ err = realDb.QueryRow("SELECT COUNT(*) FROM running_security_groups_spaces").Scan(&rows)
+ Expect(err).NotTo(HaveOccurred())
+ Expect(rows).To(Equal(0))
+ err = realDb.QueryRow("SELECT COUNT(*) FROM staging_security_groups_spaces").Scan(&rows)
+ Expect(err).NotTo(HaveOccurred())
+ Expect(rows).To(Equal(0))
+ })
+ AfterEach(func() {
+ _, err := realDb.Exec("DELETE FROM security_groups")
+ Expect(err).NotTo(HaveOccurred())
+
+ _, err = realDb.Exec("DELETE FROM running_security_groups_spaces")
+ Expect(err).NotTo(HaveOccurred())
+ _, err = realDb.Exec("DELETE FROM staging_security_groups_spaces")
+ Expect(err).NotTo(HaveOccurred())
+ })
+ Context("when no rows exist in the security_groups table", func() {
+ It("migrates without error", func() {
+ migrateTo("100")
+ })
+ })
+ Context("when a security group has running spaces but no staging spaces", func() {
+ BeforeEach(func() {
+ _, err := realDb.Exec(
+ `INSERT INTO security_groups (name, guid, running_spaces, staging_spaces) VALUES
+ ('guid-1', 'guid-1', JSON_ARRAY('space-1', 'space-2'), JSON_ARRAY())`)
+ Expect(err).NotTo(HaveOccurred())
+ })
+ It("adds entries to the running_space join table, but nothing to the staging space join table", func() {
+ migrateTo("100")
+
+ ExpectAssociatedSpacesToConsistOf(realDb, "running", []JoinRow{{
+ SecurityGroup: "guid-1",
+ Space: "space-1",
+ }, {
+ SecurityGroup: "guid-1",
+ Space: "space-2",
+ }})
+ ExpectAssociatedSpacesToConsistOf(realDb, "staging", []JoinRow{})
+ })
+ })
+ Context("when a security group has staging spaces but no running spaces", func() {
+ BeforeEach(func() {
+ _, err := realDb.Exec(
+ `INSERT INTO security_groups (name, guid, running_spaces, staging_spaces) VALUES
+ ('guid-1', 'guid-1', JSON_ARRAY(), JSON_ARRAY('space-1', 'space-2'))`)
+ Expect(err).NotTo(HaveOccurred())
+ })
+ It("adds entries to the staging_space join table, but nothing to the running space join table", func() {
+ migrateTo("100")
+
+ ExpectAssociatedSpacesToConsistOf(realDb, "staging", []JoinRow{{
+ SecurityGroup: "guid-1",
+ Space: "space-1",
+ }, {
+ SecurityGroup: "guid-1",
+ Space: "space-2",
+ }})
+ ExpectAssociatedSpacesToConsistOf(realDb, "running", []JoinRow{})
+ })
+ })
+ Context("when a security group has no staging spaces or no running spaces", func() {
+ BeforeEach(func() {
+ _, err := realDb.Exec(
+ `INSERT INTO security_groups (name, guid, running_spaces, staging_spaces) VALUES
+ ('guid-1', 'guid-1', JSON_ARRAY(), JSON_ARRAY())`)
+ Expect(err).NotTo(HaveOccurred())
+ })
+ It("adds nothing to the join tables", func() {
+ migrateTo("100")
+
+ ExpectAssociatedSpacesToConsistOf(realDb, "running", []JoinRow{})
+ ExpectAssociatedSpacesToConsistOf(realDb, "staging", []JoinRow{})
+ })
+ })
+ Context("when a security group has staging spaces and running spaces", func() {
+ BeforeEach(func() {
+ _, err := realDb.Exec(
+ `INSERT INTO security_groups (name, guid, running_spaces, staging_spaces) VALUES
+ ('guid-1', 'guid-1', JSON_ARRAY('space-1', 'space-2', 'common-space-1'), JSON_ARRAY('space-3', 'space-4', 'common-space-1'))`)
+ Expect(err).NotTo(HaveOccurred())
+ })
+ It("adds the security group to both tables with the correct running + staging spaces in each", func() {
+ migrateTo("100")
+
+ ExpectAssociatedSpacesToConsistOf(realDb, "staging", []JoinRow{{
+ SecurityGroup: "guid-1",
+ Space: "space-3",
+ }, {
+ SecurityGroup: "guid-1",
+ Space: "space-4",
+ }, {
+ SecurityGroup: "guid-1",
+ Space: "common-space-1",
+ }})
+ ExpectAssociatedSpacesToConsistOf(realDb, "running", []JoinRow{{
+ SecurityGroup: "guid-1",
+ Space: "space-1",
+ }, {
+ SecurityGroup: "guid-1",
+ Space: "space-2",
+ }, {
+ SecurityGroup: "guid-1",
+ Space: "common-space-1",
+ }})
+ })
+ })
+ Context("when the longest number of spaces associated with a security group is in the running_spaces column", func() {
+ BeforeEach(func() {
+ _, err := realDb.Exec(
+ `INSERT INTO security_groups (name, guid, running_spaces, staging_spaces) VALUES
+ ('guid-1', 'guid-1', JSON_ARRAY('space-1', 'space-2', 'space-3'), JSON_ARRAY('space-3', 'space-4')),
+ ('guid-2', 'guid-2', JSON_ARRAY('space-1', 'space-2'), JSON_ARRAY('space-3', 'space-4'))
+ `)
+ Expect(err).NotTo(HaveOccurred())
+ })
+ It("is able to migrate successfully", func() {
+ migrateTo("100")
+ })
+ })
+ Context("when the longest number of spaces associated with a security group is in the staging_spaces column", func() {
+ BeforeEach(func() {
+ _, err := realDb.Exec(
+ `INSERT INTO security_groups (name, guid, running_spaces, staging_spaces) VALUES
+ ('guid-1', 'guid-1', JSON_ARRAY('space-1', 'space-2', 'space-3'), JSON_ARRAY('space-3', 'space-4')),
+ ('guid-2', 'guid-2', JSON_ARRAY('space-1', 'space-2'), JSON_ARRAY('space-3', 'space-4', 'space-5', 'space-6'))
+ `)
+ Expect(err).NotTo(HaveOccurred())
+ })
+ It("is able to migrate successfully", func() {
+ migrateTo("100")
+ })
+ })
+ Context("when multiple rows have multiple spaces bound", func() {
+ BeforeEach(func() {
+ _, err := realDb.Exec(
+ `INSERT INTO security_groups (name, guid, running_spaces, staging_spaces) VALUES
+ ('guid-1', 'guid-1', JSON_ARRAY('space-1', 'space-2', 'space-3'), JSON_ARRAY('space-3', 'space-4')),
+ ('guid-2', 'guid-2', JSON_ARRAY('space-1', 'space-2'), JSON_ARRAY('space-3', 'space-4', 'space-5', 'space-6')),
+ ('guid-3', 'guid-3', JSON_ARRAY(), JSON_ARRAY()),
+ ('guid-4', 'guid-4', JSON_ARRAY(), JSON_ARRAY('space-10')),
+ ('guid-5', 'guid-5', JSON_ARRAY('space-11'), JSON_ARRAY())
+ `)
+ Expect(err).NotTo(HaveOccurred())
+ })
+ It("adds entries into all relevant join records", func() {
+ migrateTo("100")
+
+ ExpectAssociatedSpacesToConsistOf(realDb, "staging", []JoinRow{{
+ SecurityGroup: "guid-1",
+ Space: "space-3",
+ }, {
+ SecurityGroup: "guid-1",
+ Space: "space-4",
+ }, {
+ SecurityGroup: "guid-2",
+ Space: "space-3",
+ }, {
+ SecurityGroup: "guid-2",
+ Space: "space-4",
+ }, {
+ SecurityGroup: "guid-2",
+ Space: "space-5",
+ }, {
+ SecurityGroup: "guid-2",
+ Space: "space-6",
+ }, {
+ SecurityGroup: "guid-4",
+ Space: "space-10",
+ }})
+ ExpectAssociatedSpacesToConsistOf(realDb, "running", []JoinRow{{
+ SecurityGroup: "guid-1",
+ Space: "space-1",
+ }, {
+ SecurityGroup: "guid-1",
+ Space: "space-2",
+ }, {
+ SecurityGroup: "guid-1",
+ Space: "space-3",
+ }, {
+ SecurityGroup: "guid-2",
+ Space: "space-1",
+ }, {
+ SecurityGroup: "guid-2",
+ Space: "space-2",
+ }, {
+ SecurityGroup: "guid-5",
+ Space: "space-11",
+ }})
+ })
+ })
+ })
})
Describe("Down Migration", func() {
@@ -2350,3 +2564,28 @@ func isPostgresOrMySQL57(realDb *db.ConnWrapper) bool {
}
return true
}
+
+type JoinRow struct {
+ SecurityGroup string
+ Space string
+}
+
+func ExpectAssociatedSpacesToConsistOf(realDb *db.ConnWrapper, table string, expectedRows []JoinRow) {
+ rows, err := realDb.Query(fmt.Sprintf("SELECT security_group_guid, space_guid FROM %s_security_groups_spaces", table))
+ ExpectWithOffset(1, err).ToNot(HaveOccurred())
+ if len(expectedRows) == 0 {
+ ExpectWithOffset(1, scanCountRow(rows)).To(Equal(0))
+ } else {
+ var receivedRows []JoinRow
+ for rows.Next() {
+ var sg, space string
+ err := rows.Scan(&sg, &space)
+ ExpectWithOffset(1, err).ToNot(HaveOccurred())
+ receivedRows = append(receivedRows, JoinRow{
+ SecurityGroup: sg,
+ Space: space,
+ })
+ }
+ ExpectWithOffset(1, receivedRows).To(ConsistOf(expectedRows))
+ }
+}
diff --git a/src/code.cloudfoundry.org/policy-server/store/migrations/v0092.go b/src/code.cloudfoundry.org/policy-server/store/migrations/v0092.go
new file mode 100644
index 000000000..4d011f5a0
--- /dev/null
+++ b/src/code.cloudfoundry.org/policy-server/store/migrations/v0092.go
@@ -0,0 +1,24 @@
+package migrations
+
+var migration_v0092 = map[string][]string{
+ "mysql": {
+ `CREATE TABLE IF NOT EXISTS "running_security_groups_spaces" (
+ space_guid VARCHAR(36) NOT NULL,
+ security_group_guid VARCHAR(36) NOT NULL,
+ UNIQUE (space_guid, security_group_guid),
+ PRIMARY KEY (space_guid, security_group_guid),
+ CONSTRAINT "running_security_groups_guid_fkey" FOREIGN KEY (security_group_guid) REFERENCES "security_groups" (guid) ON DELETE CASCADE
+ )
+ `,
+ },
+ "postgres": {
+ `CREATE TABLE IF NOT EXISTS running_security_groups_spaces (
+ space_guid varchar(36) NOT NULL,
+ security_group_guid varchar(36) NOT NULL
+ CONSTRAINT running_sg_spaces_fk REFERENCES security_groups(guid) ON DELETE CASCADE,
+ CONSTRAINT running_sg_spaces_unique UNIQUE (space_guid, security_group_guid),
+ CONSTRAINT running_sg_spaces_pk PRIMARY KEY (space_guid, security_group_guid)
+ )
+ `,
+ },
+}
diff --git a/src/code.cloudfoundry.org/policy-server/store/migrations/v0093.go b/src/code.cloudfoundry.org/policy-server/store/migrations/v0093.go
new file mode 100644
index 000000000..94be11d1b
--- /dev/null
+++ b/src/code.cloudfoundry.org/policy-server/store/migrations/v0093.go
@@ -0,0 +1,24 @@
+package migrations
+
+var migration_v0093 = map[string][]string{
+ "mysql": {
+ `CREATE TABLE IF NOT EXISTS "staging_security_groups_spaces" (
+ space_guid VARCHAR(36) NOT NULL,
+ security_group_guid VARCHAR(36) NOT NULL,
+ UNIQUE (space_guid, security_group_guid),
+ PRIMARY KEY (space_guid, security_group_guid),
+ CONSTRAINT "staging_security_groups_guid_fkey" FOREIGN KEY (security_group_guid) REFERENCES "security_groups" (guid) ON DELETE CASCADE
+ )
+ `,
+ },
+ "postgres": {
+ `CREATE TABLE IF NOT EXISTS staging_security_groups_spaces (
+ space_guid varchar(36) NOT NULL,
+ security_group_guid varchar(36) NOT NULL
+ CONSTRAINT staging_sg_spaces_fk REFERENCES security_groups(guid) ON DELETE CASCADE,
+ CONSTRAINT staging_sg_spaces_unique UNIQUE (space_guid, security_group_guid),
+ CONSTRAINT staging_sg_spaces_pk PRIMARY KEY (space_guid, security_group_guid)
+ )
+ `,
+ },
+}
diff --git a/src/code.cloudfoundry.org/policy-server/store/migrations/v0094.go b/src/code.cloudfoundry.org/policy-server/store/migrations/v0094.go
new file mode 100644
index 000000000..5a71ec760
--- /dev/null
+++ b/src/code.cloudfoundry.org/policy-server/store/migrations/v0094.go
@@ -0,0 +1,10 @@
+package migrations
+
+var migration_v0094 = map[string][]string{
+ "mysql": {
+ `ALTER TABLE security_groups ADD COLUMN hash VARCHAR(255) DEFAULT ''`,
+ },
+ "postgres": {
+ `ALTER TABLE security_groups ADD COLUMN hash VARCHAR(255) DEFAULT ''`,
+ },
+}
diff --git a/src/code.cloudfoundry.org/policy-server/store/migrations/v0095.go b/src/code.cloudfoundry.org/policy-server/store/migrations/v0095.go
new file mode 100644
index 000000000..486add728
--- /dev/null
+++ b/src/code.cloudfoundry.org/policy-server/store/migrations/v0095.go
@@ -0,0 +1,24 @@
+package migrations
+
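+// This migration defines a MySQL helper stored procedure used to expand the JSON space arrays into
+// the new join tables: it finds the longest running_spaces/staging_spaces array across
+// security_groups and fills a temp_sequence table with ids 0 through max-1. v0096 invokes the
+// procedure, v0097/v0098 join against temp_sequence to extract each array element into the join
+// tables, and v0099/v0100 drop the temporary table and the procedure. Postgres can unnest JSON
+// directly (jsonb_array_elements_text), so its statement list here is empty.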
+var migration_v0095 = map[string][]string{
+ "mysql": {`
+
+CREATE PROCEDURE generate_sequence()
+BEGIN
+ DECLARE max_length INT;
+
+ SELECT GREATEST(MAX(JSON_LENGTH(running_spaces)), MAX(JSON_LENGTH(staging_spaces))) INTO max_length
+ FROM security_groups;
+
+ CREATE TABLE IF NOT EXISTS temp_sequence (id INT NOT NULL PRIMARY KEY);
+
+ SET @counter = 0;
+
+ WHILE @counter < max_length DO
+ INSERT INTO temp_sequence (id) VALUES (@counter);
+ SET @counter = @counter + 1;
+ END WHILE;
+END
+`},
+ "postgres": {},
+}
diff --git a/src/code.cloudfoundry.org/policy-server/store/migrations/v0096.go b/src/code.cloudfoundry.org/policy-server/store/migrations/v0096.go
new file mode 100644
index 000000000..6bcbc3e5b
--- /dev/null
+++ b/src/code.cloudfoundry.org/policy-server/store/migrations/v0096.go
@@ -0,0 +1,8 @@
+package migrations
+
+var migration_v0096 = map[string][]string{
+ "mysql": {
+ `CALL generate_sequence();`,
+ },
+ "postgres": {},
+}
diff --git a/src/code.cloudfoundry.org/policy-server/store/migrations/v0097.go b/src/code.cloudfoundry.org/policy-server/store/migrations/v0097.go
new file mode 100644
index 000000000..0755355a3
--- /dev/null
+++ b/src/code.cloudfoundry.org/policy-server/store/migrations/v0097.go
@@ -0,0 +1,20 @@
+package migrations
+
+var migration_v0097 = map[string][]string{
+ "mysql": {
+ `INSERT INTO running_security_groups_spaces (security_group_guid, space_guid) SELECT
+ guid AS security_group_guid,
+ JSON_UNQUOTE(JSON_EXTRACT(security_groups.running_spaces, CONCAT('$[', x.id, ']'))) AS space_guid
+ FROM
+ security_groups
+ JOIN
+ (SELECT id from temp_sequence) x
+ WHERE x.id < JSON_LENGTH(security_groups.running_spaces)`,
+ },
+ "postgres": {
+ `INSERT INTO running_security_groups_spaces (security_group_guid, space_guid) SELECT
+ guid AS security_group_guid,
+ jsonb_array_elements_text(running_spaces) AS space_id
+ FROM security_groups`,
+ },
+}
diff --git a/src/code.cloudfoundry.org/policy-server/store/migrations/v0098.go b/src/code.cloudfoundry.org/policy-server/store/migrations/v0098.go
new file mode 100644
index 000000000..9a90e3a2b
--- /dev/null
+++ b/src/code.cloudfoundry.org/policy-server/store/migrations/v0098.go
@@ -0,0 +1,20 @@
+package migrations
+
+var migration_v0098 = map[string][]string{
+ "mysql": {
+ `INSERT INTO staging_security_groups_spaces (security_group_guid, space_guid) SELECT
+ guid AS security_group_guid,
+ JSON_UNQUOTE(JSON_EXTRACT(security_groups.staging_spaces, CONCAT('$[', x.id, ']'))) AS space_guid
+ FROM
+ security_groups
+ JOIN
+ (SELECT id from temp_sequence) x
+ WHERE x.id < JSON_LENGTH(security_groups.staging_spaces)`,
+ },
+ "postgres": {
+ `INSERT INTO staging_security_groups_spaces (security_group_guid, space_guid) SELECT
+ guid AS security_group_guid,
+ jsonb_array_elements_text(staging_spaces) AS space_id
+ FROM security_groups`,
+ },
+}
diff --git a/src/code.cloudfoundry.org/policy-server/store/migrations/v0099.go b/src/code.cloudfoundry.org/policy-server/store/migrations/v0099.go
new file mode 100644
index 000000000..494d75f08
--- /dev/null
+++ b/src/code.cloudfoundry.org/policy-server/store/migrations/v0099.go
@@ -0,0 +1,8 @@
+package migrations
+
+var migration_v0099 = map[string][]string{
+ "mysql": {
+ `DROP TABLE IF EXISTS temp_sequence`,
+ },
+ "postgres": {},
+}
diff --git a/src/code.cloudfoundry.org/policy-server/store/migrations/v0100.go b/src/code.cloudfoundry.org/policy-server/store/migrations/v0100.go
new file mode 100644
index 000000000..1f3fdd7e0
--- /dev/null
+++ b/src/code.cloudfoundry.org/policy-server/store/migrations/v0100.go
@@ -0,0 +1,8 @@
+package migrations
+
+var migration_v0100 = map[string][]string{
+ "mysql": {
+ `DROP PROCEDURE IF EXISTS generate_sequence`,
+ },
+ "postgres": {},
+}
diff --git a/src/code.cloudfoundry.org/policy-server/store/security_groups_store.go b/src/code.cloudfoundry.org/policy-server/store/security_groups_store.go
index af497ed37..a28507952 100644
--- a/src/code.cloudfoundry.org/policy-server/store/security_groups_store.go
+++ b/src/code.cloudfoundry.org/policy-server/store/security_groups_store.go
@@ -1,11 +1,15 @@
package store
import (
+ "crypto/sha256"
+ "encoding/json"
"fmt"
+ "slices"
"strings"
"time"
"code.cloudfoundry.org/cf-networking-helpers/db"
+ "code.cloudfoundry.org/lager/v3"
"code.cloudfoundry.org/policy-server/store/helpers"
)
@@ -17,56 +21,65 @@ type SecurityGroupsStore interface {
}
type SGStore struct {
- Conn Database
+ Logger lager.Logger
+ Conn Database
}
-func (sgs *SGStore) BySpaceGuids(spaceGuids []string, page Page) ([]SecurityGroup, Pagination, error) {
- query := `
- SELECT
- id,
- guid,
- name,
- rules,
- staging_default,
- running_default,
- staging_spaces,
- running_spaces
- FROM security_groups`
-
- whereClause := `staging_default=true OR running_default=true`
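+// buildBoundASGQuery returns a subquery selecting the guids of security groups bound, via the given
+// join table prefix ("running" or "staging"), to any of the provided space guids.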
+func buildBoundASGQuery(table string, spaceGuids []string) string {
+ return fmt.Sprintf("SELECT security_group_guid AS guid FROM %s_security_groups_spaces WHERE space_guid IN (%s)", table, helpers.QuestionMarks(len(spaceGuids)))
+}
+func (sgs *SGStore) BySpaceGuids(spaceGuids []string, page Page) ([]SecurityGroup, Pagination, error) {
+ var boundASGQuery string
if len(spaceGuids) > 0 {
- whereClause = fmt.Sprintf("%s OR %s OR %s",
- whereClause,
- sgs.jsonOverlapsSQL("staging_spaces", spaceGuids),
- sgs.jsonOverlapsSQL("running_spaces", spaceGuids),
- )
+ boundASGQuery = fmt.Sprintf(`
+ UNION
+ %s
+ UNION
+ %s`, buildBoundASGQuery("staging", spaceGuids), buildBoundASGQuery("running", spaceGuids))
}
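+	// Select every security group that is either globally bound (staging_default/running_default) or
+	// bound to one of the provided spaces via the staging/running join tables.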
+ query := fmt.Sprintf(`SELECT
+ sgs.id,
+ sgs.guid,
+ sgs.name,
+ sgs.rules,
+ sgs.staging_default,
+ sgs.running_default,
+ sgs.staging_spaces,
+ sgs.running_spaces
+FROM security_groups AS sgs WHERE guid in (
+ SELECT guid FROM (
+ SELECT guid FROM security_groups WHERE staging_default = true
+ UNION
+ SELECT guid FROM security_groups WHERE running_default = true
+%s
+ ) as bound
+)
+`, boundASGQuery)
- query = fmt.Sprintf("%s WHERE (%s)", query, whereClause)
-
- // one for running and one for staging
- whereBindings := make([]interface{}, len(spaceGuids)*2)
+ whereBindings := make([]any, len(spaceGuids))
for i, spaceGuid := range spaceGuids {
whereBindings[i] = spaceGuid
- whereBindings[i+len(spaceGuids)] = spaceGuid
}
+ // add a second set for running space guids
+ whereBindings = append(whereBindings, whereBindings...)
if page.From > 0 {
- query = query + " AND id >= %"
+ query = query + " AND sgs.id >= ?"
whereBindings = append(whereBindings, page.From)
}
- query = query + " ORDER BY id"
+ query = query + " ORDER BY sgs.id"
if page.Limit > 0 {
// we don't use a placeholder because limit is an integer and it is safe to interpolate it
query = fmt.Sprintf(`%s LIMIT %d`, query, page.Limit+1)
}
- rebindedQuery := helpers.RebindForSQLDialectAndMark(query, sgs.Conn.DriverName(), "%")
+ rebindedQuery := helpers.RebindForSQLDialect(query, sgs.Conn.DriverName())
rows, err := sgs.Conn.Query(rebindedQuery, whereBindings...)
if err != nil {
+ sgs.Logger.Error("selecting-security-groups", err, lager.Data{"query": rebindedQuery})
return nil, Pagination{}, fmt.Errorf("selecting security groups: %s", err)
}
defer rows.Close()
@@ -98,6 +111,18 @@ func (sgs *SGStore) BySpaceGuids(spaceGuids []string, page Page) ([]SecurityGrou
return result, Pagination{Next: nextId}, nil
}
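+
+// calculateAsgHash produces a stable fingerprint of a security group: the space guid slices are
+// sorted first so the hash does not change when the same spaces arrive in a different order, then
+// the group is JSON-marshaled and SHA256-summed.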
+func calculateAsgHash(group SecurityGroup) (string, error) {
+ slices.Sort(group.RunningSpaceGuids)
+ slices.Sort(group.StagingSpaceGuids)
+ groupJson, err := json.Marshal(group)
+ if err != nil {
+ return "", fmt.Errorf("failed-marshaling-asg-as-json: %s", err)
+ }
+
+ hash := fmt.Sprintf("%x", sha256.Sum256(groupJson))
+ return hash, nil
+}
+
func (sgs *SGStore) Replace(newSecurityGroups []SecurityGroup) error {
tx, err := sgs.Conn.Beginx()
if err != nil {
@@ -105,61 +130,95 @@ func (sgs *SGStore) Replace(newSecurityGroups []SecurityGroup) error {
}
defer tx.Rollback()
- existingGuids := map[string]bool{}
- rows, err := tx.Queryx("SELECT guid FROM security_groups")
+ if len(newSecurityGroups) == 0 {
+ _, err = tx.Exec("DELETE FROM security_groups")
+ if err != nil {
+ return fmt.Errorf("deleting ALL security groups: %s", err)
+ }
+ err = tx.Commit()
+ if err != nil {
+ return fmt.Errorf("committing transaction to delete ALL security groups: %s", err)
+ }
+ return nil
+ }
+
+ existingGuids := map[string]string{}
+ rows, err := tx.Queryx("SELECT guid, hash FROM security_groups")
if err != nil {
return fmt.Errorf("selecting security groups: %s", err)
}
if rows != nil {
defer rows.Close()
for rows.Next() {
- var guid string
- err := rows.Scan(&guid)
+ var guid, hash string
+ err := rows.Scan(&guid, &hash)
if err != nil {
return fmt.Errorf("scanning security group result: %s", err)
}
- existingGuids[guid] = true
+ existingGuids[guid] = hash
}
}
- upsertQuery := tx.Rebind(`
+ upsertQuery := `
INSERT INTO security_groups
- (guid, name, rules, staging_default, running_default, staging_spaces, running_spaces)
- VALUES(?, ?, ?, ?, ?, ?, ?) ` +
- sgs.onConflictUpdateSQL() +
- ` name=?, rules=?, staging_default=?, running_default=?, staging_spaces=?, running_spaces=?`)
+ (guid, name, hash, rules, staging_default, running_default, staging_spaces, running_spaces)
+ VALUES`
+ columnsPerRecord := 8
+ onConflictQuery := sgs.onConflictUpdateSQL("name", "hash", "rules", "staging_default", "running_default", "staging_spaces", "running_spaces")
+ stagingBindings := map[string][]string{}
+ runningBindings := map[string][]string{}
+ var insertValues []any
for _, group := range newSecurityGroups {
+ originalHash := existingGuids[group.Guid]
delete(existingGuids, group.Guid)
- _, err := tx.Exec(upsertQuery,
- group.Guid,
- group.Name,
- group.Rules,
- group.StagingDefault,
- group.RunningDefault,
- group.StagingSpaceGuids,
- group.RunningSpaceGuids,
- group.Name,
- group.Rules,
- group.StagingDefault,
- group.RunningDefault,
- group.StagingSpaceGuids,
- group.RunningSpaceGuids,
- )
+ newHash, err := calculateAsgHash(group)
if err != nil {
- return fmt.Errorf("saving security group %s (%s): %s", group.Guid, group.Name, err)
+ return fmt.Errorf("failed-calculating-asg-hash: %s", err)
+ }
+ if newHash != originalHash {
+ insertValues = append(insertValues,
+ group.Guid,
+ group.Name,
+ newHash,
+ group.Rules,
+ group.StagingDefault,
+ group.RunningDefault,
+ group.StagingSpaceGuids,
+ group.RunningSpaceGuids,
+ )
+
+ stagingBindings[group.Guid] = group.StagingSpaceGuids
+ runningBindings[group.Guid] = group.RunningSpaceGuids
+ }
+ }
+
+ if len(insertValues) > 0 {
+ sgs.Logger.Debug("updating-existing-security-groups", lager.Data{"": len(insertValues) / columnsPerRecord})
+ err = sgs.BatchPreparedStatement(tx, upsertQuery, onConflictQuery, insertValues, columnsPerRecord)
+ if err != nil {
+ return fmt.Errorf("upserting security groups: %s", err)
+ }
+
+ err = sgs.ReplaceSecurityGroupSpaceAssociations(tx, "staging_security_groups_spaces", stagingBindings)
+ if err != nil {
+ return fmt.Errorf("replacing staging space associations: %s", err)
+ }
+ err = sgs.ReplaceSecurityGroupSpaceAssociations(tx, "running_security_groups_spaces", runningBindings)
+ if err != nil {
+ return fmt.Errorf("replacing running space associations: %s", err)
}
}
if len(existingGuids) > 0 {
- guids := []interface{}{}
+ sgs.Logger.Debug("deleting-stale-security-groups", lager.Data{"num_records": len(insertValues) / columnsPerRecord})
+ guidsToDelete := []any{}
for guid := range existingGuids {
- guids = append(guids, guid)
+ guidsToDelete = append(guidsToDelete, guid)
}
- _, err = tx.Exec(tx.Rebind(`
- DELETE FROM security_groups WHERE guid IN (`+helpers.QuestionMarks(len(existingGuids))+`)`),
- guids...)
+
+ err = sgs.BatchPreparedStatement(tx, "DELETE FROM security_groups WHERE guid IN (", ")", guidsToDelete, 1)
if err != nil {
return fmt.Errorf("deleting security groups: %s", err)
}
@@ -169,6 +228,7 @@ func (sgs *SGStore) Replace(newSecurityGroups []SecurityGroup) error {
if err != nil {
return fmt.Errorf("updating security_groups_info.last_updated: %s", err)
}
+ sgs.Logger.Debug("committing-transaction")
err = tx.Commit()
if err != nil {
return fmt.Errorf("committing transaction: %s", err)
@@ -176,31 +236,94 @@ func (sgs *SGStore) Replace(newSecurityGroups []SecurityGroup) error {
return nil
}
-func (sgs *SGStore) jsonOverlapsSQL(columnName string, filterValues []string) string {
- switch sgs.Conn.DriverName() {
- case helpers.MySQL:
- clauses := []string{}
- for range filterValues {
- clauses = append(clauses, fmt.Sprintf(`json_contains(%s, json_quote(?))`, columnName))
+func (sgs *SGStore) BatchPreparedStatement(tx db.Transaction, statementStart, statementEnd string, parameterValues []any, parametersPerRecord int) error {
+ if len(parameterValues) == 0 {
+ return nil
+ }
+ // Both mysql + postgres claim to use 16bit integers in the protocol spec to identify how many
+ // parameters are being provided, 0 indicating no parameters, and a max of 65535.
+ parameterLimit := 65535
+
+ maxRecordCount := parameterLimit / parametersPerRecord
+ parametersPerBatch := maxRecordCount * parametersPerRecord
+
+ batchNumber := 1
+ for i := 0; i < len(parameterValues); i += parametersPerBatch {
+ lastIndex := min(i+parametersPerBatch, len(parameterValues))
+ recordCount := min((lastIndex-i)/parametersPerRecord, maxRecordCount)
+ values := parameterValues[i:lastIndex]
+
+ reboundStatement := tx.Rebind(
+ fmt.Sprintf("%s %s %s",
+ statementStart,
+ strings.TrimSuffix(
+ strings.Repeat(fmt.Sprintf("(%s), ", helpers.QuestionMarks(parametersPerRecord)), recordCount),
+ ", ",
+ ),
+ statementEnd,
+ ),
+ )
+
+ sgs.Logger.Debug("executing-batched-statement", lager.Data{"batch": batchNumber, "recordCount": recordCount, "paramCount": len(values)})
+ _, err := tx.Exec(reboundStatement, values...)
+ if err != nil {
+ sgs.Logger.Error("batch-prepared-statement-failed", err, lager.Data{"query": reboundStatement})
+ return fmt.Errorf("executing batched statement: %s", err)
}
- return strings.Join(clauses, " OR ")
- case helpers.Postgres:
- filterList := helpers.MarksWithSeparator(len(filterValues), "%", ", ")
- return fmt.Sprintf(`%s ?| array[%s]`, columnName, filterList)
- default:
- return ""
+ batchNumber++
}
+
+ return nil
}
-func (sgs *SGStore) onConflictUpdateSQL() string {
+func (sgs *SGStore) ReplaceSecurityGroupSpaceAssociations(tx db.Transaction, table string, bindings map[string][]string) error {
+ var deleteWhereBindings, insertValues []any
+ var insertTxSize, deleteTxSize int
+ for sgGuid, spaceGuids := range bindings {
+ deleteWhereBindings = append(deleteWhereBindings, sgGuid)
+ deleteTxSize += len(sgGuid)
+
+ for _, spaceGuid := range spaceGuids {
+ insertValues = append(insertValues, sgGuid, spaceGuid)
+ insertTxSize += len(sgGuid) + len(spaceGuid)
+ }
+ }
+
+ deleteQuery := fmt.Sprintf("DELETE FROM %s WHERE security_group_guid IN (", table)
+ err := sgs.BatchPreparedStatement(tx, deleteQuery, ")", deleteWhereBindings, 1)
+ if err != nil {
+ return fmt.Errorf("deleting previous associations: %s", err)
+ }
+
+ replaceQuery := fmt.Sprintf("INSERT INTO %s (security_group_guid, space_guid) VALUES", table)
+ columnsPerRecord := 2
+ err = sgs.BatchPreparedStatement(tx, replaceQuery, "", insertValues, columnsPerRecord)
+ if err != nil {
+ return fmt.Errorf("creating new associations: %s", err)
+ }
+
+ return nil
+}
+
+func (sgs *SGStore) onConflictUpdateSQL(columns ...string) string {
+ var conflictSql string
switch sgs.Conn.DriverName() {
case helpers.MySQL:
- return "ON DUPLICATE KEY UPDATE"
+ conflictSql = "ON DUPLICATE KEY UPDATE"
+ for _, column := range columns {
+ conflictSql = fmt.Sprintf("%s %s = VALUES(%s),", conflictSql, column, column)
+ }
+ conflictSql = strings.TrimRight(conflictSql, ",")
case helpers.Postgres:
- return "ON CONFLICT (guid) DO UPDATE SET"
+ conflictSql = "ON CONFLICT (guid) DO UPDATE SET "
+ for _, column := range columns {
+ conflictSql = fmt.Sprintf("%s %s = EXCLUDED.%s,", conflictSql, column, column)
+ }
+ conflictSql = strings.TrimRight(conflictSql, ",")
default:
return ""
}
+ return conflictSql
}
func (sgs *SGStore) LastUpdated() (int, error) {
diff --git a/src/code.cloudfoundry.org/policy-server/store/security_groups_store_test.go b/src/code.cloudfoundry.org/policy-server/store/security_groups_store_test.go
index b595abe56..665bfb97a 100644
--- a/src/code.cloudfoundry.org/policy-server/store/security_groups_store_test.go
+++ b/src/code.cloudfoundry.org/policy-server/store/security_groups_store_test.go
@@ -8,12 +8,14 @@ import (
dbHelper "code.cloudfoundry.org/cf-networking-helpers/db"
dbfakes "code.cloudfoundry.org/cf-networking-helpers/db/fakes"
"code.cloudfoundry.org/cf-networking-helpers/testsupport"
- "code.cloudfoundry.org/lager/v3"
+ "code.cloudfoundry.org/lager/v3/lagertest"
"code.cloudfoundry.org/policy-server/store"
"code.cloudfoundry.org/policy-server/store/fakes"
+ "code.cloudfoundry.org/policy-server/store/helpers"
testhelpers "code.cloudfoundry.org/test-helpers"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
+ "github.com/onsi/gomega/gbytes"
)
var _ = Describe("SecurityGroupsStore", func() {
@@ -21,21 +23,30 @@ var _ = Describe("SecurityGroupsStore", func() {
securityGroupsStore *store.SGStore
dbConf dbHelper.Config
realDb *dbHelper.ConnWrapper
+ testLogger *lagertest.TestLogger
)
+ getNumRecords := func(table string) int {
+ var count int
+ ExpectWithOffset(1, realDb).ToNot(BeNil())
+ err := realDb.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", table)).Scan(&count)
+ ExpectWithOffset(1, err).ToNot(HaveOccurred())
+ return count
+ }
+
BeforeEach(func() {
dbConf = testsupport.GetDBConfig()
dbConf.DatabaseName = fmt.Sprintf("security_groups_store_test_%d", time.Now().UnixNano())
dbConf.Timeout = 30
testhelpers.CreateDatabase(dbConf)
- logger := lager.NewLogger("Security Groups Store Test")
-
var err error
- realDb, err = dbHelper.NewConnectionPool(dbConf, 200, 200, 5*time.Minute, "Security Groups Store Test", "Security Groups Store Test", logger)
+ testLogger = lagertest.NewTestLogger("asg-syncer-test")
+ realDb, err = dbHelper.NewConnectionPool(dbConf, 200, 200, 5*time.Minute, "Security Groups Store Test", "Security Groups Store Test", testLogger)
Expect(err).NotTo(HaveOccurred())
securityGroupsStore = &store.SGStore{
- Conn: realDb,
+ Conn: realDb,
+ Logger: testLogger,
}
migrate(realDb)
@@ -89,7 +100,7 @@ var _ = Describe("SecurityGroupsStore", func() {
})
Context("search by staging space guid", func() {
- It("fetches global asgs and asgs attached to provided spaces", func() {
+ It("fetches asgs attached to provided spaces", func() {
securityGroups, pagination, err := securityGroupsStore.BySpaceGuids([]string{"space-b"}, store.Page{})
Expect(err).ToNot(HaveOccurred())
@@ -106,7 +117,7 @@ var _ = Describe("SecurityGroupsStore", func() {
})
Context("search by running space guid", func() {
- It("fetches global asgs and asgs attached to provided spaces", func() {
+ It("fetches attached to provided spaces", func() {
securityGroups, pagination, err := securityGroupsStore.BySpaceGuids([]string{"space-a"}, store.Page{})
Expect(err).ToNot(HaveOccurred())
@@ -328,7 +339,7 @@ var _ = Describe("SecurityGroupsStore", func() {
})
Describe("Replace", func() {
- var initialRules, newRules []store.SecurityGroup
+ var initialRules, newRules, emptyRules []store.SecurityGroup
BeforeEach(func() {
initialRules = []store.SecurityGroup{{
@@ -336,6 +347,7 @@ var _ = Describe("SecurityGroupsStore", func() {
Name: "first-asg",
Rules: "firstRules",
RunningSpaceGuids: []string{"first-space"},
+ StagingSpaceGuids: []string{"fourth-space"},
}, {
Guid: "second-guid",
Name: "second-name",
@@ -362,6 +374,8 @@ var _ = Describe("SecurityGroupsStore", func() {
RunningDefault: true,
}}
+ emptyRules = []store.SecurityGroup{}
+
err := securityGroupsStore.Replace(initialRules)
Expect(err).ToNot(HaveOccurred())
})
@@ -399,6 +413,260 @@ var _ = Describe("SecurityGroupsStore", func() {
Expect(securityGroups).To(ConsistOf(initialRules))
})
+ Context("when the only change is a deletion", func() {
+ BeforeEach(func() {
+ newRules = initialRules[0:1]
+ })
+ It("deletes the record", func() {
+ err := securityGroupsStore.Replace(newRules)
+ Expect(err).ToNot(HaveOccurred())
+
+ securityGroups, _, err := securityGroupsStore.BySpaceGuids([]string{"first-space", "second-space", "third-space"}, store.Page{})
+ Expect(err).ToNot(HaveOccurred())
+
+ Expect(securityGroups).To(ConsistOf(newRules))
+ })
+ })
+ Context("when two ASGs change but the third doesn't", func() {
+ var hashes []string
+ BeforeEach(func() {
+ newRules = initialRules
+ newRules[1].Guid = "new-guid"
+ newRules = append(newRules, store.SecurityGroup{
+ Guid: "third-guid",
+ Name: "third-name",
+ Rules: "thirdRules",
+ StagingSpaceGuids: []string{"third-space"},
+ StagingDefault: true,
+ RunningSpaceGuids: []string{},
+ })
+ Eventually(testLogger).Should(gbytes.Say("committing-transaction"))
+
+ rows, err := realDb.Query("SELECT hash FROM security_groups ORDER BY id")
+ Expect(err).NotTo(HaveOccurred())
+ for rows.Next() {
+ var hash string
+ err := rows.Scan(&hash)
+ Expect(err).NotTo(HaveOccurred())
+ hashes = append(hashes, hash)
+ }
+ })
+ It("only updates the two changing", func() {
+ err := securityGroupsStore.Replace(newRules)
+ Expect(err).ToNot(HaveOccurred())
+
+ securityGroups, _, err := securityGroupsStore.BySpaceGuids([]string{"first-space", "second-space", "third-space"}, store.Page{})
+ Expect(err).ToNot(HaveOccurred())
+ Expect(securityGroups).To(ConsistOf(newRules))
+
+ Eventually(testLogger).Should(gbytes.Say(`"num_records":2`))
+
+ var newHashes []string
+ rows, err := realDb.Query("SELECT hash FROM security_groups ORDER BY id")
+ Expect(err).NotTo(HaveOccurred())
+ for rows.Next() {
+ var hash string
+ err := rows.Scan(&hash)
+ Expect(err).NotTo(HaveOccurred())
+ newHashes = append(newHashes, hash)
+ }
+ Expect(newHashes[0]).To(Equal(hashes[0]))
+ Expect(newHashes[1]).To(Equal("f842efd3fd4944dc6dd669b8f6817f914826485ebec5aa697632d8a10f3ddc30"))
+ Expect(newHashes[2]).To(Equal("2a3d6609af1920ccf8fa2c9d2e5d829592ee91f146c46fee2506c01a59aacd3f"))
+ })
+ })
+ Context("testing security-group-space associations", func() {
+ BeforeEach(func() {
+
+ newRules[0].StagingDefault = false
+ newRules[0].RunningDefault = false
+ newRules[1].StagingDefault = false
+ newRules[1].RunningDefault = false
+ })
+ Context("when a security group loses a running space binding", func() {
+ BeforeEach(func() {
+ newRules[1].RunningSpaceGuids = []string{"third-space", "first-space"}
+ })
+ It("removes the association for that space from that ASG", func() {
+ err := securityGroupsStore.Replace(newRules)
+ Expect(err).ToNot(HaveOccurred())
+
+ results := map[string][]string{}
+ query := helpers.RebindForSQLDialect("SELECT security_group_guid, space_guid FROM running_security_groups_spaces WHERE security_group_guid = ? ORDER BY space_guid", securityGroupsStore.Conn.DriverName())
+ rows, err := securityGroupsStore.Conn.Query(query, "second-guid")
+ Expect(err).NotTo(HaveOccurred())
+ for rows.Next() {
+ var sg, space string
+ err := rows.Scan(&sg, &space)
+ Expect(err).NotTo(HaveOccurred())
+
+ results[sg] = append(results[sg], space)
+ }
+ Expect(results).To(Equal(map[string][]string{
+ "second-guid": []string{"first-space", "third-space"},
+ }))
+ })
+ })
+ Context("when a security group loses a staging space binding", func() {
+ BeforeEach(func() {
+ newRules[1].StagingSpaceGuids = []string{"third-space", "first-space"}
+ })
+ It("removes the association for that space from that ASG", func() {
+ err := securityGroupsStore.Replace(newRules)
+ Expect(err).ToNot(HaveOccurred())
+
+ results := map[string][]string{}
+ query := helpers.RebindForSQLDialect("SELECT security_group_guid, space_guid FROM staging_security_groups_spaces WHERE security_group_guid = ? ORDER BY space_guid", securityGroupsStore.Conn.DriverName())
+ rows, err := securityGroupsStore.Conn.Query(query, "second-guid")
+ Expect(err).NotTo(HaveOccurred())
+ for rows.Next() {
+ var sg, space string
+ err := rows.Scan(&sg, &space)
+ Expect(err).NotTo(HaveOccurred())
+
+ results[sg] = append(results[sg], space)
+ }
+ Expect(results).To(Equal(map[string][]string{
+ "second-guid": []string{"first-space", "third-space"},
+ }))
+ })
+ })
+ Context("when a security group loses its staging space bindings", func() {
+ BeforeEach(func() {
+ newRules[1].StagingSpaceGuids = []string{}
+ })
+ It("removes all associations for that ASG", func() {
+ err := securityGroupsStore.Replace(newRules)
+ Expect(err).ToNot(HaveOccurred())
+
+ var associations int
+ query := helpers.RebindForSQLDialect("SELECT COUNT(*) FROM staging_security_groups_spaces WHERE security_group_guid = ?", securityGroupsStore.Conn.DriverName())
+ err = securityGroupsStore.Conn.QueryRow(query, "second-guid").Scan(&associations)
+ Expect(err).NotTo(HaveOccurred())
+ Expect(associations).To(Equal(0))
+
+ err = securityGroupsStore.Conn.QueryRow("SELECT COUNT(*) FROM staging_security_groups_spaces").Scan(&associations)
+ Expect(err).NotTo(HaveOccurred())
+ Expect(associations).ToNot(Equal(0))
+ })
+ })
+
+ Context("when a security group loses its running space bindings", func() {
+ BeforeEach(func() {
+ newRules[1].RunningSpaceGuids = []string{}
+ newRules[0].RunningSpaceGuids = []string{"fourth-space"}
+ })
+ It("removes all associations from that ASG", func() {
+ err := securityGroupsStore.Replace(newRules)
+ Expect(err).ToNot(HaveOccurred())
+
+ var associations int
+ query := helpers.RebindForSQLDialect("SELECT COUNT(*) FROM running_security_groups_spaces WHERE security_group_guid = ?", securityGroupsStore.Conn.DriverName())
+ err = securityGroupsStore.Conn.QueryRow(query, "second-guid").Scan(&associations)
+ Expect(err).NotTo(HaveOccurred())
+ Expect(associations).To(Equal(0))
+
+ err = securityGroupsStore.Conn.QueryRow("SELECT COUNT(*) FROM running_security_groups_spaces").Scan(&associations)
+ Expect(err).NotTo(HaveOccurred())
+ Expect(associations).ToNot(Equal(0))
+ })
+ })
+ Context("when no more running asgs are directly bound", func() {
+ BeforeEach(func() {
+ newRules[0].RunningSpaceGuids = []string{}
+ newRules[1].RunningSpaceGuids = []string{}
+ })
+ It("removes all space associations from from all ASGs", func() {
+ err := securityGroupsStore.Replace(newRules)
+ Expect(err).ToNot(HaveOccurred())
+
+ var associations int
+ err = securityGroupsStore.Conn.QueryRow("SELECT COUNT(*) FROM running_security_groups_spaces").Scan(&associations)
+ Expect(err).NotTo(HaveOccurred())
+ Expect(associations).To(Equal(0))
+ })
+ })
+ Context("when no more staging asgs are directly bound", func() {
+ BeforeEach(func() {
+ newRules[0].StagingSpaceGuids = []string{}
+ newRules[1].StagingSpaceGuids = []string{}
+ })
+ It("removes all space associations from from all ASGs", func() {
+ err := securityGroupsStore.Replace(newRules)
+ Expect(err).ToNot(HaveOccurred())
+
+ var associations int
+ err = securityGroupsStore.Conn.QueryRow("SELECT COUNT(*) FROM staging_security_groups_spaces").Scan(&associations)
+ Expect(err).NotTo(HaveOccurred())
+ Expect(associations).To(Equal(0))
+ })
+ })
+ Context("when the only updated ASGs are ones that don't have space bindings", func() {
+ BeforeEach(func() {
+ err := securityGroupsStore.Replace(newRules)
+ Expect(err).ToNot(HaveOccurred())
+ securityGroups, _, err := securityGroupsStore.BySpaceGuids([]string{"first-space", "second-space", "third-space"}, store.Page{})
+ Expect(err).ToNot(HaveOccurred())
+ Expect(securityGroups).To(ConsistOf(newRules))
+
+ newRules = append(newRules, store.SecurityGroup{
+ Name: "only-global",
+ Guid: "only-global-guid",
+ StagingDefault: true,
+ RunningDefault: true,
+ })
+ })
+ It("doesn't affect any of the space bindings", func() {
+ err := securityGroupsStore.Replace(newRules)
+ Expect(err).NotTo(HaveOccurred())
+
+ securityGroups, _, err := securityGroupsStore.BySpaceGuids([]string{"first-space", "second-space", "third-space"}, store.Page{})
+ Expect(err).ToNot(HaveOccurred())
+
+ Expect(securityGroups).To(ConsistOf(newRules))
+ })
+ })
+ })
+ Context("when no more asgs exist", func() {
+ It("removes all ASGs", func() {
+ Expect(getNumRecords("security_groups")).ToNot(Equal(0))
+ Expect(getNumRecords("staging_security_groups_spaces")).ToNot(Equal(0))
+ Expect(getNumRecords("running_security_groups_spaces")).ToNot(Equal(0))
+
+ err := securityGroupsStore.Replace(emptyRules)
+ Expect(err).ToNot(HaveOccurred())
+
+ securityGroups, _, err := securityGroupsStore.BySpaceGuids([]string{"first-space", "second-space", "third-space"}, store.Page{})
+ Expect(err).ToNot(HaveOccurred())
+
+ Expect(securityGroups).To(ConsistOf(emptyRules))
+ time.Sleep(1 * time.Second)
+ Expect(getNumRecords("security_groups")).To(Equal(0))
+ Expect(getNumRecords("staging_security_groups_spaces")).To(Equal(0))
+ Expect(getNumRecords("running_security_groups_spaces")).To(Equal(0))
+ })
+ })
+ Context("when no pre-existing data exists", func() {
+ BeforeEach(func() {
+ err := securityGroupsStore.Replace(emptyRules)
+ Expect(err).NotTo(HaveOccurred())
+
+ Expect(getNumRecords("security_groups")).To(Equal(0))
+ Expect(getNumRecords("staging_security_groups_spaces")).To(Equal(0))
+ Expect(getNumRecords("running_security_groups_spaces")).To(Equal(0))
+ })
+ It("creates new data", func() {
+ err := securityGroupsStore.Replace(newRules)
+ Expect(err).ToNot(HaveOccurred())
+
+ securityGroups, _, err := securityGroupsStore.BySpaceGuids([]string{"first-space", "second-space", "third-space"}, store.Page{})
+ Expect(err).ToNot(HaveOccurred())
+
+ Expect(securityGroups).To(ConsistOf(newRules))
+ })
+ })
Context("when errors occur", func() {
var mockDB *fakes.Db
@@ -408,6 +676,7 @@ var _ = Describe("SecurityGroupsStore", func() {
tx = new(dbfakes.Transaction)
mockDB.BeginxReturns(tx, nil)
securityGroupsStore.Conn = mockDB
+ mockDB.DriverNameReturns("mysql")
})
Context("beginning a transaction", func() {
@@ -421,6 +690,39 @@ var _ = Describe("SecurityGroupsStore", func() {
})
})
+ Context("when deleting all ASGs", func() {
+ Context("and the error is when deleting ASGs", func() {
+ BeforeEach(func() {
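+ // fail the first Exec call, which the error assertion below ties to the bulk delete of all security groups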
+ tx.ExecReturnsOnCall(0, nil, errors.New("can't exec SQL"))
+ })
+
+ It("returns an error", func() {
+ err := securityGroupsStore.Replace(emptyRules)
+ Expect(err).To(MatchError("deleting ALL security groups: can't exec SQL"))
+ })
+
+ It("rolls back the transaction", func() {
+ securityGroupsStore.Replace(newRules)
+ Expect(tx.RollbackCallCount()).To(Equal(1))
+ })
+ })
+ Context("and committing the transaction fails", func() {
+ BeforeEach(func() {
+ tx.CommitReturns(errors.New("can't commit transaction"))
+ })
+
+ It("returns an error", func() {
+ err := securityGroupsStore.Replace(emptyRules)
+ Expect(err).To(MatchError("committing transaction to delete ALL security groups: can't commit transaction"))
+ })
+
+ It("rolls back the transaction", func() {
+ securityGroupsStore.Replace(newRules)
+ Expect(tx.RollbackCallCount()).To(Equal(1))
+ })
+ })
+ })
+
Context("getting existing security groups", func() {
BeforeEach(func() {
tx.QueryxReturns(nil, errors.New("can't exec SQL"))
@@ -444,7 +746,38 @@ var _ = Describe("SecurityGroupsStore", func() {
It("returns an error", func() {
err := securityGroupsStore.Replace(newRules)
- Expect(err).To(MatchError("saving security group third-guid (third-name): can't exec SQL"))
+ Expect(err).To(MatchError("upserting security groups: executing batched statement: can't exec SQL"))
+ })
+
+ It("rolls back the transaction", func() {
+ securityGroupsStore.Replace(newRules)
+ Expect(tx.RollbackCallCount()).To(Equal(1))
+ })
+ })
+
+ Context("updating running security group bindings", func() {
+ BeforeEach(func() {
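+ // fail the fourth Exec call, which the error assertion below ties to deleting stale running-space associations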
+ tx.ExecReturnsOnCall(3, nil, errors.New("can't exec SQL"))
+ })
+
+ It("returns an error", func() {
+ err := securityGroupsStore.Replace(newRules)
+ Expect(err).To(MatchError("replacing running space associations: deleting previous associations: executing batched statement: can't exec SQL"))
+ })
+
+ It("rolls back the transaction", func() {
+ securityGroupsStore.Replace(newRules)
+ Expect(tx.RollbackCallCount()).To(Equal(1))
+ })
+ })
+ Context("updating staging security group bindings", func() {
+ BeforeEach(func() {
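+ // fail the second Exec call, which the error assertion below ties to deleting stale staging-space associations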
+ tx.ExecReturnsOnCall(1, nil, errors.New("can't exec SQL"))
+ })
+
+ It("returns an error", func() {
+ err := securityGroupsStore.Replace(newRules)
+ Expect(err).To(MatchError("replacing staging space associations: deleting previous associations: executing batched statement: can't exec SQL"))
})
It("rolls back the transaction", func() {
@@ -471,6 +804,116 @@ var _ = Describe("SecurityGroupsStore", func() {
})
})
+ Describe("ReplaceSecurityGroupSpaceAssociations()", func() {
+ var tx *dbfakes.Transaction
+ var sgSpaceBindings map[string][]string
+ BeforeEach(func() {
+ mockDB := new(fakes.Db)
+ tx = new(dbfakes.Transaction)
+ mockDB.BeginxReturns(tx, nil)
+ securityGroupsStore.Conn = mockDB
+ mockDB.DriverNameReturns("mysql")
+
+ sgSpaceBindings = map[string][]string{
+ "sg-1": {"space-1", "space-2"},
+ "sg-2": {"space-2", "space-3"},
+ }
+ })
+ Context("when errors occur", func() {
+ Context("during the delete", func() {
+ BeforeEach(func() {
+ tx.ExecReturnsOnCall(0, nil, errors.New("injected failure"))
+ })
+ It("returns an error", func() {
+ err := securityGroupsStore.ReplaceSecurityGroupSpaceAssociations(tx, "staging_security_groups_spaces", sgSpaceBindings)
+ Expect(err).To(HaveOccurred())
+ Expect(err).To(MatchError("deleting previous associations: executing batched statement: injected failure"))
+ })
+ })
+ Context("during the insert", func() {
+ BeforeEach(func() {
+ tx.ExecReturnsOnCall(1, nil, errors.New("injected failure"))
+ })
+ It("returns an error", func() {
+ err := securityGroupsStore.ReplaceSecurityGroupSpaceAssociations(tx, "staging_security_groups_spaces", sgSpaceBindings)
+ Expect(err).To(HaveOccurred())
+ Expect(err).To(MatchError("creating new associations: executing batched statement: injected failure"))
+ })
+ })
+ })
+ // Happy-path coverage lives in the Replace() tests above, since that's the behavior
+ // we actually care about succeeding.
+ })
+
+ Describe("BatchPreparedStatement()", func() {
+ paramsPerRecord := 2
+ // 35,000 records are used so this exercises the real placeholder limits of Postgres and MySQL,
+ // rather than testing a smaller dataset against an artificially lowered max just to trigger batching
+ var values []any
+ var tx dbHelper.Transaction
+ BeforeEach(func() {
+ values = []any{}
+ for i := range 35000 {
+ values = append(values, fmt.Sprintf("fake-asg-%d", i))
+ values = append(values, fmt.Sprintf("fake-name-%d", i))
+ }
+ })
+ JustBeforeEach(func() {
+ var err error
+ tx, err = securityGroupsStore.Conn.Beginx()
+ Expect(err).NotTo(HaveOccurred())
+ })
+ AfterEach(func() {
+ _, err := realDb.Exec("DELETE FROM staging_security_groups_spaces")
+ Expect(err).NotTo(HaveOccurred())
+ })
+
+ It("doesn't error and inserts every record", func() {
+ err := securityGroupsStore.BatchPreparedStatement(tx, "INSERT INTO security_groups (guid, name) VALUES", "", values, paramsPerRecord)
+ Expect(err).NotTo(HaveOccurred())
+ Expect(tx.Commit()).To(Succeed())
+ rows, err := securityGroupsStore.Conn.Query("SELECT guid, name FROM security_groups ORDER BY guid")
+ Expect(err).NotTo(HaveOccurred())
+ results := map[string]string{}
+ for rows.Next() {
+ var sg, name string
+ err := rows.Scan(&sg, &name)
+ Expect(err).NotTo(HaveOccurred())
+ results[sg] = name
+ }
+ for i := range 35000 {
+ space, ok := results[fmt.Sprintf("fake-asg-%d", i)]
+ Expect(ok).To(BeTrue())
+ Expect(space).To(Equal(fmt.Sprintf("fake-name-%d", i)))
+ }
+ })
+ Context("when run against a mockDB", func() {
+ var fakeTx *dbfakes.Transaction
+ BeforeEach(func() {
+ mockDB := new(fakes.Db)
+ fakeTx = new(dbfakes.Transaction)
+ mockDB.BeginxReturns(fakeTx, nil)
+ securityGroupsStore.Conn = mockDB
+ mockDB.DriverNameReturns("mysql")
+ })
+ It("batches into two chunks", func() {
+ err := securityGroupsStore.BatchPreparedStatement(tx, "INSERT INTO security_groups (guid, name) VALUES", "", values, paramsPerRecord)
+ Expect(err).NotTo(HaveOccurred())
+ Expect(fakeTx.ExecCallCount()).To(Equal(2))
+ })
+ Context("when execution fails", func() {
+ BeforeEach(func() {
+ fakeTx.ExecReturns(nil, fmt.Errorf("injected failure"))
+ })
+ It("returns an error", func() {
+ err := securityGroupsStore.BatchPreparedStatement(tx, "INSERT INTO security_groups (guid, name) VALUES", "", values, paramsPerRecord)
+ Expect(err).To(HaveOccurred())
+ Expect(err).To(MatchError("executing batched statement: injected failure"))
+ })
+ })
+ })
+ })
+
Describe("LastUpdated()", func() {
var currentTime int64
BeforeEach(func() {
@@ -478,7 +921,7 @@ var _ = Describe("SecurityGroupsStore", func() {
migrateAndPopulateTags(realDb, 1)
currentTime = time.Now().UnixNano()
- securityGroupsStore.Replace([]store.SecurityGroup{{
+ err := securityGroupsStore.Replace([]store.SecurityGroup{{
Guid: "third-guid",
Name: "third-name",
Rules: "thirdRules",
@@ -486,6 +929,7 @@ var _ = Describe("SecurityGroupsStore", func() {
StagingDefault: true,
RunningSpaceGuids: []string{},
}})
+ Expect(err).NotTo(HaveOccurred())
})
It("returns a timestamp in UnixNano format", func() {
updatedTime, err := securityGroupsStore.LastUpdated()