Mongo go driver beta integration (#1042)

* Changed mongo.go to use new driver

* Modified mongo cataloger

* More new driver related changes

* Change lister.go

* Change saver.go

* Change imports

* Remove unnecessary Count query

* Use IndexView for indexing

* Rename ModuleStore fields

* Use map of key:sorting-order for creating the index

* Minor changes

* Use client options to configure mongo client

* Use method chaining

* gofmt changes

* Change imports

* Fix some build errors

* Use new GridFS API

* Fix more build errors

* Add Go Mongo driver to dependency modules

* Use multierror

* Leave download stream open

* Remove mgo error handling

* Copy zip instead of loading all in memory

* Use context.WithTimeout() wherever possible

* Raise KindNotFound when mod@ver isn't found

* NopCloser not needed

* Fix IndexView error

* Fix build errors

* Remove another mgo error usage

* Fix build error

* Changes according to review

* Formatting changes as per gofmt

* Modify gofmt argument to show the expected formatting (diff)

* Handle ErrNoDocument error and error arising from query execution

* Fix kind of returned error

* Minor changes

* Bug fixes

* gofmt related changes

* Minor change

* Use Insecure from MongoConfig, remove Insecure from global Config

* Remove stray print statement
This commit is contained in:
Arpit Gogia
2019-04-17 23:29:01 +05:30
committed by marpio
parent c11212ba16
commit 974077e73b
311 changed files with 52520 additions and 112 deletions
+8 -1
View File
@@ -9,14 +9,17 @@ require (
github.com/DataDog/opencensus-go-exporter-datadog v0.0.0-20180917103902-e6c7f767dc57
github.com/aws/aws-sdk-go v1.15.24
github.com/bsm/redis-lock v8.0.0+incompatible
github.com/codegangsta/negroni v1.0.0 // indirect
github.com/fatih/color v1.7.0
github.com/globalsign/mgo v0.0.0-20180828104044-6f9f54af1356
github.com/go-playground/locales v0.12.1 // indirect
github.com/go-playground/universal-translator v0.16.0 // indirect
github.com/go-redis/redis v6.15.2+incompatible
github.com/go-stack/stack v1.8.0 // indirect
github.com/gobuffalo/envy v1.6.7
github.com/gobuffalo/httptest v1.0.4
github.com/gogo/protobuf v1.2.0 // indirect
github.com/golang/snappy v0.0.1 // indirect
github.com/google/go-cmp v0.2.0
github.com/google/martian v2.1.0+incompatible // indirect
github.com/google/uuid v1.1.1
@@ -38,9 +41,13 @@ require (
github.com/spf13/afero v1.1.2
github.com/stretchr/testify v1.3.0
github.com/technosophos/moniker v0.0.0-20180509230615-a5dbd03a2245
github.com/tidwall/pretty v0.0.0-20180105212114-65a9db5fad51 // indirect
github.com/tinylib/msgp v1.0.2 // indirect
github.com/unrolled/secure v0.0.0-20181221173256-0d6b5bb13069
github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c // indirect
github.com/xdg/stringprep v1.0.0 // indirect
go.etcd.io/etcd v0.0.0-20190215181705-784daa04988c
go.mongodb.org/mongo-driver v1.0.0
go.opencensus.io v0.17.0
golang.org/x/crypto v0.0.0-20181029103014-dab2b1051b5d // indirect
golang.org/x/net v0.0.0-20181029044818-c44066c5c816 // indirect
@@ -49,7 +56,7 @@ require (
golang.org/x/sys v0.0.0-20181031143558-9b800f95dbbc // indirect
google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf
google.golang.org/appengine v1.3.0 // indirect
gopkg.in/DataDog/dd-trace-go.v1 v1.3.0 // indirect
gopkg.in/DataDog/dd-trace-go.v1 v1.10.0 // indirect
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
gopkg.in/go-playground/assert.v1 v1.2.1 // indirect
gopkg.in/go-playground/validator.v9 v9.20.2
+41 -2
View File
@@ -23,6 +23,9 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24
github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs=
github.com/bsm/redis-lock v8.0.0+incompatible h1:QgB0J2pNG8hUfndTIvpPh38F5XsUTTvO7x8Sls++9Mk=
github.com/bsm/redis-lock v8.0.0+incompatible/go.mod h1:8dGkQ5GimBCahwF2R67tqGCJbyDZSp0gzO7wq3pDrik=
github.com/codegangsta/negroni v1.0.0 h1:+aYywywx4bnKXWvoWtRfJ91vC59NbEhEY03sZjQhbVY=
github.com/codegangsta/negroni v1.0.0/go.mod h1:v0y3T5G7Y1UlFfyxFn/QLRU4a2EuNau2iZY63YTKWo0=
github.com/coreos/go-semver v0.2.0 h1:3Jm3tLmsgAYcjC+4Up7hJrFBPr+n7rAqYeSw/SZazuY=
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7 h1:u9SHYsPQNyt5tgDm3YN7+9dYrpK96E5wFilTFWIDZOM=
github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
@@ -31,11 +34,15 @@ github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf/go.mod h1:E3G3o1h8I7cfc
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4 h1:qk/FSDDxo05wdJH28W+p5yivv7LuLYLRXPPD8KQCtZs=
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys=
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk=
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/globalsign/mgo v0.0.0-20180828104044-6f9f54af1356 h1:5bNaeqHyuxTGYlx42mevVN+R0TGdOrwj8MQl0yo1260=
github.com/globalsign/mgo v0.0.0-20180828104044-6f9f54af1356/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q=
@@ -47,6 +54,8 @@ github.com/go-playground/universal-translator v0.16.0 h1:X++omBR/4cE2MNg91AoC3rm
github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY=
github.com/go-redis/redis v6.15.2+incompatible h1:9SpNVG76gr6InJGxoZ6IuuxaCOQwDAhzyXg+Bs+0Sb4=
github.com/go-redis/redis v6.15.2+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA=
github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/gobuffalo/envy v1.6.7 h1:XMZGuFqTupAXhZTriQ+qO38QvNOSU/0rl3hEPCFci/4=
github.com/gobuffalo/envy v1.6.7/go.mod h1:N+GkhhZ/93bGZc6ZKhJLP6+m+tCNPKwgSpH9kaifseQ=
github.com/gobuffalo/httptest v1.0.4 h1:P0uKaPEjti1bbJmuBILE3QQ7iU1cS7oIkxVba5HbcVE=
@@ -56,9 +65,13 @@ github.com/gogo/protobuf v1.2.0 h1:xU6/SpYbvkNYiptHJYEDRseDLvYE7wSqhYYNy0QSUzI=
github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903 h1:LbsanbbD6LieFkXbj9YNNBupiGHJgFeLpO0j0Fza1h8=
github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/btree v0.0.0-20180124185431-e89373fe6b4a h1:ZJu5NB1Bk5ms4vw0Xu4i+jD32SE9jQXyfnOvwhHqlT0=
github.com/google/btree v0.0.0-20180124185431-e89373fe6b4a/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
@@ -80,20 +93,26 @@ github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyC
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/sessions v1.1.3 h1:uXoZdcdA5XdXF3QzuSlheVRUvjl+1rKY7zBXL68L9RU=
github.com/gorilla/sessions v1.1.3/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w=
github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c h1:Lh2aW+HnU2Nbe1gqD9SOJLJxW1jBMmQOktN2acDyJk8=
github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0 h1:Iju5GlWwrvL6UBg4zJJt3btmonfrMlCDdsejg4CZE7c=
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs=
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 h1:Ovs26xHkKqVztRpIrF/92BcuyuQ/YW4NSIpoGtfXNho=
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk=
github.com/grpc-ecosystem/grpc-gateway v1.4.1 h1:pX7cnDwSSmG0dR9yNjCQSSpmsJOqFdT7SzVp5Yl9uVw=
github.com/grpc-ecosystem/grpc-gateway v1.4.1/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw=
github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.0.0 h1:iVjPR7a6H0tWELX5NxNe7bYopibicUzc7uPribsnS6o=
github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk=
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8 h1:12VvqtR6Aowv3l/EQUlocDHW2Cp4G9WJVH7uyH8QFJE=
github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc=
github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg=
github.com/jonboulle/clockwork v0.1.0 h1:VKV+ZcuP6l3yW9doeqz6ziZGgcynBVQO+obU0+0hcPo=
github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo=
github.com/jtolds/gls v4.2.1+incompatible h1:fSuqC+Gmlu6l/ZYAoZzx2pyucC8Xza35fpRVWLVmUEE=
github.com/jtolds/gls v4.2.1+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
@@ -123,7 +142,9 @@ github.com/minio/minio-go v6.0.5+incompatible/go.mod h1:7guKYtitv8dktvNUGrhzmNlA
github.com/mitchellh/go-homedir v1.0.0 h1:vKb8ShqSby24Yrqr/yDYkuFz8d0WUjys40rvnGC8aR0=
github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/olekukonko/tablewriter v0.0.0-20170122224234-a0225b3f23b5/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo=
github.com/onsi/ginkgo v1.6.0 h1:Ix8l273rp3QzYgXSR+c8d1fTG7UPgYkOSELPhiY/YGw=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/gomega v1.4.2 h1:3mYCb7aPxS/RU7TI1y4rkEn1oKmPRjNJLNEXgw7MH2I=
github.com/onsi/gomega v1.4.2/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8=
github.com/philhofer/fwd v1.0.0 h1:UbZqGr5Y38ApvM/V/jEljVxwocdweyH+vmYvRPBnbqQ=
@@ -152,10 +173,12 @@ github.com/smartystreets/assertions v0.0.0-20180820201707-7c9eb446e3cf h1:6V1qxN
github.com/smartystreets/assertions v0.0.0-20180820201707-7c9eb446e3cf/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v0.0.0-20180222194500-ef6db91d284a h1:JSvGDIbmil4Ui/dDdFBExb7/cmkNjyX5F97oglmvCDo=
github.com/smartystreets/goconvey v0.0.0-20180222194500-ef6db91d284a/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s=
github.com/soheilhy/cmux v0.1.4 h1:0HKaf1o97UwFjHH9o5XsHUOF+tqmdA7KEzXLpiyaw0E=
github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM=
github.com/spf13/afero v1.1.2 h1:m8/z1t7/fwjysjQRYbP0RD+bUIF/8tJwPdEZsI83ACI=
github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ=
github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ=
github.com/spf13/pflag v1.0.1 h1:aCvUg6QPl3ibpQUxyLkrEkCHtPqYJL4x9AuhqVqFis4=
github.com/spf13/pflag v1.0.1/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
@@ -164,17 +187,29 @@ github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/technosophos/moniker v0.0.0-20180509230615-a5dbd03a2245 h1:DNVk+NIkGS0RbLkjQOLCJb/759yfCysThkMbl7EXxyY=
github.com/technosophos/moniker v0.0.0-20180509230615-a5dbd03a2245/go.mod h1:O1c8HleITsZqzNZDjSNzirUGsMT0oGu9LhHKoJrqO+A=
github.com/tidwall/pretty v0.0.0-20180105212114-65a9db5fad51 h1:BP2bjP495BBPaBcS5rmqviTfrOkN5rO5ceKAMRZCRFc=
github.com/tidwall/pretty v0.0.0-20180105212114-65a9db5fad51/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/tinylib/msgp v1.0.2 h1:DfdQrzQa7Yh2es9SuLkixqxuXS2SxsdYn0KbdrOGWD8=
github.com/tinylib/msgp v1.0.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE=
github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8 h1:ndzgwNDnKIqyCvHTXaCqh9KlOWKvBry6nuXMJmonVsE=
github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/ugorji/go v1.1.1 h1:gmervu+jDMvXTbcHQ0pd2wee85nEoE0BsVyEuzkfK8w=
github.com/ugorji/go v1.1.1/go.mod h1:hnLbHMwcvSihnDhEfx2/BzKp2xb0Y+ErdfYcrs9tkJQ=
github.com/unrolled/secure v0.0.0-20181221173256-0d6b5bb13069 h1:RKeYksgIwGE8zFJTvXI1WWx09QPrGyaVFMy0vpU7j/o=
github.com/unrolled/secure v0.0.0-20181221173256-0d6b5bb13069/go.mod h1:mnPT77IAdsi/kV7+Es7y+pXALeV3h7G6dQF6mNYjcLA=
github.com/urfave/cli v1.18.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA=
github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c h1:u40Z8hqBAAQyv+vATcGgV0YCnDjqSL7/q/JyPhhJSPk=
github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I=
github.com/xdg/stringprep v1.0.0 h1:d9X0esnoa3dFsV0FG35rAT0RIhYFlPq7MiP+DW89La0=
github.com/xdg/stringprep v1.0.0/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y=
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 h1:eY9dn8+vbi4tKz5Qo6v2eYzo7kUS51QINcR5jNpbZS8=
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU=
go.etcd.io/bbolt v1.3.2 h1:Z/90sZLPOeCy2PwprqkFa25PdkusRzaj9P8zm/KNyvk=
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
go.etcd.io/etcd v0.0.0-20190215181705-784daa04988c h1:pkiZ418C7QN/HIps1lDF1+lzZhdgMpvFN4kDcxrYhD0=
go.etcd.io/etcd v0.0.0-20190215181705-784daa04988c/go.mod h1:RutfZdQAP913VY0GI8/Mjwf50+IZ7Mpg2zt3SDs17/g=
go.mongodb.org/mongo-driver v1.0.0 h1:KxPRDyfB2xXnDE2My8acoOWBQkfv3tz0SaWTRZjJR0c=
go.mongodb.org/mongo-driver v1.0.0/go.mod h1:u7ryQJ+DOzQmeO7zB6MHyr8jkEQvC8vH7qLUO4lqsUM=
go.opencensus.io v0.17.0 h1:2Cu88MYg+1LU+WVD+NWwYhyP0kKgRlN9QjWGaX0jKTE=
go.opencensus.io v0.17.0/go.mod h1:mp1VrMQxhlqqDpKvH4UcQUa4YwlzNmymAjPrDdfxNpI=
go.uber.org/atomic v1.3.2 h1:2Oa65PReHzfn29GpvgsYwloV9AVFHPDk8tYxt2c2tr4=
@@ -203,6 +238,7 @@ golang.org/x/sys v0.0.0-20181031143558-9b800f95dbbc h1:SdCq5U4J+PpbSDIl9bM0V1e1U
golang.org/x/sys v0.0.0-20181031143558-9b800f95dbbc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2 h1:+DCIGbF/swA92ohVg0//6X2IVY3KZs6p9mix0ziNYJM=
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf h1:rjxqQmxjyqerRKEj+tZW+MCm4LgpFXu18bsEoCMgDsk=
google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0=
@@ -213,20 +249,23 @@ google.golang.org/genproto v0.0.0-20180831171423-11092d34479b h1:lohp5blsw53GBXt
google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/grpc v1.14.0 h1:ArxJuB1NWfPY6r9Gp9gqwplT0Ge7nqv9msgu03lHLmo=
google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
gopkg.in/DataDog/dd-trace-go.v1 v1.3.0 h1:5FIqJszYWD+FWV/fLSySU/XafqYVCJwiffzA3AZc1/4=
gopkg.in/DataDog/dd-trace-go.v1 v1.3.0/go.mod h1:DVp8HmDh8PuTu2Z0fVVlBsyWaC++fzwVCaGWylTe3tg=
gopkg.in/DataDog/dd-trace-go.v1 v1.10.0 h1:aKIe93NsKAn5Gm/A4nNO4hlPuKTnhaf+khqu0OgdzpQ=
gopkg.in/DataDog/dd-trace-go.v1 v1.10.0/go.mod h1:DVp8HmDh8PuTu2Z0fVVlBsyWaC++fzwVCaGWylTe3tg=
gopkg.in/airbrake/gobrake.v2 v2.0.9/go.mod h1:/h5ZAUhDkGaJfjzjKLSjv6zCL6O0LLBxU4K+aSYdM/U=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw=
gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
gopkg.in/gemnasium/logrus-airbrake-hook.v2 v2.1.2/go.mod h1:Xk6kEKp8OKb+X14hQBKWaSkCsqBpgog8nAV2xsGOxlo=
gopkg.in/go-playground/assert.v1 v1.2.1 h1:xoYuJVE7KT85PYWrN730RguIQO0ePzVRfFMXadIrXTM=
gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE=
gopkg.in/go-playground/validator.v9 v9.20.2 h1:6AVDyt8bk0FDiSYSeWivUfzqEjHyVSCMRkpTr6ZCIgk=
gopkg.in/go-playground/validator.v9 v9.20.2/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+2 -2
View File
@@ -1,12 +1,12 @@
package storage
import (
"github.com/globalsign/mgo/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// Module represents a vgo module saved in a storage backend.
type Module struct {
ID bson.ObjectId `bson:"_id,omitempty"`
ID primitive.ObjectID `bson:"_id,omitempty"`
Module string `bson:"module"`
Version string `bson:"version"`
Mod []byte `bson:"mod"`
+26 -10
View File
@@ -2,11 +2,13 @@ package mongo
import (
"context"
"github.com/globalsign/mgo/bson"
"github.com/gomods/athens/pkg/errors"
"github.com/gomods/athens/pkg/paths"
"github.com/gomods/athens/pkg/storage"
"github.com/hashicorp/go-multierror"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo/options"
)
// Catalog implements the (./pkg/storage).Cataloger interface
@@ -15,23 +17,37 @@ func (s *ModuleStore) Catalog(ctx context.Context, token string, pageSize int) (
const op errors.Op = "mongo.Catalog"
q := bson.M{}
if token != "" {
q = bson.M{"_id": bson.M{"$gt": bson.ObjectIdHex(token)}}
t, err := primitive.ObjectIDFromHex(token)
if err == nil {
q = bson.M{"_id": bson.M{"$gt": t}}
}
}
fields := bson.M{"module": 1, "version": 1}
projection := bson.M{"module": 1, "version": 1}
sort := bson.M{"_id": 1}
c := s.s.DB(s.d).C(s.c)
c := s.client.Database(s.db).Collection(s.coll)
tctx, cancel := context.WithTimeout(ctx, s.timeout)
defer cancel()
modules := make([]storage.Module, 0)
err := c.Find(q).
Select(fields).
Sort("_id").
Limit(pageSize).
All(&modules)
findOptions := options.Find().SetProjection(projection).SetSort(sort).SetLimit(int64(pageSize))
cursor, err := c.Find(tctx, q, findOptions)
if err != nil {
return nil, "", errors.E(op, err)
}
var errs error
for cursor.Next(ctx) {
var module storage.Module
if err := cursor.Decode(&module); err != nil {
errs = multierror.Append(errs, err)
} else {
modules = append(modules, module)
}
}
// If there are 0 results, return empty results without an error
if len(modules) == 0 {
return nil, "", nil
+6 -3
View File
@@ -3,9 +3,9 @@ package mongo
import (
"context"
"github.com/globalsign/mgo/bson"
"github.com/gomods/athens/pkg/errors"
"github.com/gomods/athens/pkg/observ"
"go.mongodb.org/mongo-driver/bson"
)
// Exists checks for a specific version of a module
@@ -13,8 +13,11 @@ func (s *ModuleStore) Exists(ctx context.Context, module, vsn string) (bool, err
var op errors.Op = "mongo.Exists"
ctx, span := observ.StartSpan(ctx, op.String())
defer span.End()
c := s.s.DB(s.d).C(s.c)
count, err := c.Find(bson.M{"module": module, "version": vsn}).Count()
c := s.client.Database(s.db).Collection(s.coll)
tctx, cancel := context.WithTimeout(ctx, s.timeout)
defer cancel()
count, err := c.CountDocuments(tctx, bson.M{"module": module, "version": vsn})
if err != nil {
return false, errors.E(op, errors.M(module), errors.V(vsn), err)
}
+30 -8
View File
@@ -3,9 +3,12 @@ package mongo
import (
"context"
"github.com/globalsign/mgo/bson"
"github.com/gomods/athens/pkg/errors"
"github.com/gomods/athens/pkg/observ"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/gridfs"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/x/bsonx"
)
// Delete removes a specific version of a module
@@ -15,21 +18,40 @@ func (s *ModuleStore) Delete(ctx context.Context, module, version string) error
defer span.End()
exists, err := s.Exists(ctx, module, version)
if err != nil {
return errors.E(op, err, errors.M(module), errors.V(version))
return errors.E(op, errors.M(module), errors.V(version), errors.KindNotFound)
}
if !exists {
return errors.E(op, errors.M(module), errors.V(version), errors.KindNotFound)
}
db := s.s.DB(s.d)
c := db.C(s.c)
err = db.GridFS("fs").Remove(s.gridFileName(module, version))
db := s.client.Database(s.db)
c := db.Collection(s.coll)
bucket, err := gridfs.NewBucket(db, &options.BucketOptions{})
if err != nil {
return errors.E(op, err, errors.M(module), errors.V(version))
return errors.E(op, errors.M(module), errors.V(version), errors.KindNotFound)
}
err = c.Remove(bson.M{"module": module, "version": version})
filter := bsonx.Doc{}
filter = filter.Set("filename", bsonx.String(s.gridFileName(module, version)))
cursor, err := bucket.Find(filter)
var x bsonx.Doc
for cursor.Next(ctx) {
cursor.Decode(&x)
}
if err = bucket.Delete(x.Lookup("_id").ObjectID()); err != nil {
kind := errors.KindUnexpected
if err == gridfs.ErrFileNotFound {
kind = errors.KindNotFound
}
return errors.E(op, err, kind, errors.M(module), errors.V(version))
}
tctx, cancel := context.WithTimeout(ctx, s.timeout)
defer cancel()
_, err = c.DeleteOne(tctx, bson.M{"module": module, "version": version})
if err != nil {
return errors.E(op, err, errors.M(module), errors.V(version))
return errors.E(op, err, errors.KindNotFound, errors.M(module), errors.V(version))
}
return nil
}
+39 -16
View File
@@ -4,11 +4,13 @@ import (
"context"
"io"
"github.com/globalsign/mgo"
"github.com/globalsign/mgo/bson"
"github.com/gomods/athens/pkg/errors"
"github.com/gomods/athens/pkg/observ"
"github.com/gomods/athens/pkg/storage"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/gridfs"
"go.mongodb.org/mongo-driver/mongo/options"
)
// Info implements storage.Getter
@@ -16,15 +18,24 @@ func (s *ModuleStore) Info(ctx context.Context, module, vsn string) ([]byte, err
const op errors.Op = "mongo.Info"
ctx, span := observ.StartSpan(ctx, op.String())
defer span.End()
c := s.s.DB(s.d).C(s.c)
c := s.client.Database(s.db).Collection(s.coll)
result := &storage.Module{}
err := c.Find(bson.M{"module": module, "version": vsn}).One(result)
if err != nil {
tctx, cancel := context.WithTimeout(ctx, s.timeout)
defer cancel()
queryResult := c.FindOne(tctx, bson.M{"module": module, "version": vsn})
if queryErr := queryResult.Err(); queryErr != nil {
return nil, errors.E(op, queryErr, errors.M(module), errors.V(vsn))
}
if err := queryResult.Decode(&result); err != nil {
kind := errors.KindUnexpected
if err == mgo.ErrNotFound {
if err == mongo.ErrNoDocuments {
kind = errors.KindNotFound
}
return nil, errors.E(op, kind, errors.M(module), errors.V(vsn), err)
return nil, errors.E(op, err, kind, errors.M(module), errors.V(vsn))
}
return result.Info, nil
@@ -35,15 +46,22 @@ func (s *ModuleStore) GoMod(ctx context.Context, module, vsn string) ([]byte, er
const op errors.Op = "mongo.GoMod"
ctx, span := observ.StartSpan(ctx, op.String())
defer span.End()
c := s.s.DB(s.d).C(s.c)
c := s.client.Database(s.db).Collection(s.coll)
result := &storage.Module{}
err := c.Find(bson.M{"module": module, "version": vsn}).One(result)
if err != nil {
tctx, cancel := context.WithTimeout(ctx, s.timeout)
defer cancel()
queryResult := c.FindOne(tctx, bson.M{"module": module, "version": vsn})
if queryErr := queryResult.Err(); queryErr != nil {
return nil, errors.E(op, queryErr, errors.M(module), errors.V(vsn))
}
if err := queryResult.Decode(result); err != nil {
kind := errors.KindUnexpected
if err == mgo.ErrNotFound {
if err == mongo.ErrNoDocuments {
kind = errors.KindNotFound
}
return nil, errors.E(op, kind, errors.M(module), errors.V(vsn), err)
return nil, errors.E(op, err, kind, errors.M(module), errors.V(vsn))
}
return result.Mod, nil
@@ -56,15 +74,20 @@ func (s *ModuleStore) Zip(ctx context.Context, module, vsn string) (io.ReadClose
defer span.End()
zipName := s.gridFileName(module, vsn)
fs := s.s.DB(s.d).GridFS("fs")
f, err := fs.Open(zipName)
db := s.client.Database(s.db)
bucket, err := gridfs.NewBucket(db, &options.BucketOptions{})
if err != nil {
return nil, errors.E(op, err, errors.M(module), errors.V(vsn))
}
dStream, err := bucket.OpenDownloadStreamByName(zipName, options.GridFSName())
if err != nil {
kind := errors.KindUnexpected
if err == mgo.ErrNotFound {
if err == gridfs.ErrFileNotFound {
kind = errors.KindNotFound
}
return nil, errors.E(op, err, kind, errors.M(module), errors.V(vsn))
}
return f, nil
return dStream, nil
}
+24 -9
View File
@@ -3,28 +3,43 @@ package mongo
import (
"context"
"github.com/globalsign/mgo"
"github.com/globalsign/mgo/bson"
"github.com/gomods/athens/pkg/errors"
"github.com/gomods/athens/pkg/observ"
"github.com/gomods/athens/pkg/storage"
multierror "github.com/hashicorp/go-multierror"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// List lists all versions of a module
func (s *ModuleStore) List(ctx context.Context, module string) ([]string, error) {
func (s *ModuleStore) List(ctx context.Context, moduleName string) ([]string, error) {
const op errors.Op = "mongo.List"
ctx, span := observ.StartSpan(ctx, op.String())
defer span.End()
c := s.s.DB(s.d).C(s.c)
fields := bson.M{"version": 1}
result := make([]storage.Module, 0)
err := c.Find(bson.M{"module": module}).Select(fields).All(&result)
c := s.client.Database(s.db).Collection(s.coll)
projection := bson.M{"version": 1, "_id": 0}
query := bson.M{"module": moduleName}
tctx, cancel := context.WithTimeout(ctx, s.timeout)
defer cancel()
cursor, err := c.Find(tctx, query, &options.FindOptions{Projection: projection})
if err != nil {
return nil, errors.E(op, err, errors.M(moduleName))
}
result := make([]storage.Module, 0)
var errs error
for cursor.Next(ctx) {
var module storage.Module
if err := cursor.Decode(&module); err != nil {
kind := errors.KindUnexpected
if err == mgo.ErrNotFound {
if err == mongo.ErrNoDocuments {
kind = errors.KindNotFound
}
return nil, errors.E(op, kind, errors.M(module), err)
errs = multierror.Append(errs, errors.E(op, err, kind))
} else {
result = append(result, module)
}
}
versions := make([]string, len(result))
+45 -41
View File
@@ -1,24 +1,25 @@
package mongo
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net"
"strings"
"time"
"github.com/globalsign/mgo"
"github.com/gomods/athens/pkg/config"
"github.com/gomods/athens/pkg/errors"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// ModuleStore represents a mongo backed storage backend.
type ModuleStore struct {
s *mgo.Session
d string // database
c string // collection
client *mongo.Client
db string // database
coll string // collection
url string
certPath string
insecure bool // Only to be used for development instances
@@ -32,9 +33,16 @@ func NewStorage(conf *config.MongoConfig, timeout time.Duration) (*ModuleStore,
if conf == nil {
return nil, errors.E(op, "No Mongo Configuration provided")
}
ms := &ModuleStore{url: conf.URL, certPath: conf.CertPath, timeout: timeout}
ms := &ModuleStore{url: conf.URL, certPath: conf.CertPath, timeout: timeout, insecure: conf.InsecureConn}
client, err := ms.newClient(conf)
ms.client = client
if err != nil {
return nil, errors.E(op, err)
}
_, err = ms.connect(conf)
err := ms.connect(conf)
if err != nil {
return nil, errors.E(op, err)
}
@@ -42,54 +50,48 @@ func NewStorage(conf *config.MongoConfig, timeout time.Duration) (*ModuleStore,
return ms, nil
}
func (m *ModuleStore) connect(conf *config.MongoConfig) error {
func (m *ModuleStore) connect(conf *config.MongoConfig) (*mongo.Collection, error) {
const op errors.Op = "mongo.connect"
var err error
m.s, err = m.newSession(m.timeout, m.insecure, conf)
err = m.client.Connect(context.Background())
if err != nil {
return errors.E(op, err)
return nil, errors.E(op, err)
}
return m.initDatabase()
return m.initDatabase(), nil
}
func (m *ModuleStore) initDatabase() error {
m.c = "modules"
func (m *ModuleStore) initDatabase() *mongo.Collection {
// TODO: database and collection as env vars, or params to New()? together with user/mongo
m.db = "athens"
m.coll = "modules"
index := mgo.Index{
Key: []string{"base_url", "module", "version"},
Unique: true,
Background: true,
Sparse: true,
}
c := m.s.DB(m.d).C(m.c)
return c.EnsureIndex(index)
c := m.client.Database(m.db).Collection(m.coll)
indexView := c.Indexes()
keys := make(map[string]int)
keys["base_url"] = 1
keys["module"] = 1
keys["version"] = 1
indexOptions := options.Index().SetBackground(true).SetSparse(true).SetUnique(true)
indexView.CreateOne(context.Background(), mongo.IndexModel{Keys: keys, Options: indexOptions}, options.CreateIndexes())
return c
}
func (m *ModuleStore) newSession(timeout time.Duration, insecure bool, conf *config.MongoConfig) (*mgo.Session, error) {
func (m *ModuleStore) newClient(conf *config.MongoConfig) (*mongo.Client, error) {
tlsConfig := &tls.Config{}
dialInfo, err := mgo.ParseURL(m.url)
if err != nil {
return nil, err
}
dialInfo.Timeout = timeout
if dialInfo.Database != "" {
m.d = dialInfo.Database
} else {
m.d = conf.DefaultDBName
}
clientOptions := options.Client()
// Maybe check for error using Validate()?
clientOptions = clientOptions.ApplyURI(m.url)
if m.certPath != "" {
// Sets only when the env var is setup in config.dev.toml
tlsConfig.InsecureSkipVerify = insecure
tlsConfig.InsecureSkipVerify = m.insecure
var roots *x509.CertPool
// See if there is a system cert pool
roots, err = x509.SystemCertPool()
roots, err := x509.SystemCertPool()
if err != nil {
// If there is no system cert pool, create a new one
roots = x509.NewCertPool()
@@ -105,13 +107,15 @@ func (m *ModuleStore) newSession(timeout time.Duration, insecure bool, conf *con
}
tlsConfig.ClientCAs = roots
dialInfo.DialServer = func(addr *mgo.ServerAddr) (net.Conn, error) {
return tls.Dial("tcp", addr.String(), tlsConfig)
clientOptions = clientOptions.SetTLSConfig(tlsConfig)
}
clientOptions = clientOptions.SetConnectTimeout(m.timeout)
client, err := mongo.NewClient(clientOptions)
if err != nil {
return nil, err
}
return mgo.DialWithInfo(dialInfo)
return client, nil
}
func (m *ModuleStore) gridFileName(mod, ver string) string {
+3 -2
View File
@@ -1,6 +1,7 @@
package mongo
import (
"context"
"os"
"testing"
@@ -15,8 +16,8 @@ func TestBackend(t *testing.T) {
}
func (m *ModuleStore) clear() error {
m.s.DB(m.d).DropDatabase()
return m.initDatabase()
m.client.Database(m.db).Drop(context.Background())
return nil
}
func BenchmarkBackend(b *testing.B) {
+17 -6
View File
@@ -8,6 +8,8 @@ import (
"github.com/gomods/athens/pkg/errors"
"github.com/gomods/athens/pkg/observ"
"github.com/gomods/athens/pkg/storage"
"go.mongodb.org/mongo-driver/mongo/gridfs"
"go.mongodb.org/mongo-driver/mongo/options"
)
// Save stores a module in mongo storage.
@@ -25,14 +27,20 @@ func (s *ModuleStore) Save(ctx context.Context, module, version string, mod []by
}
zipName := s.gridFileName(module, version)
fs := s.s.DB(s.d).GridFS("fs")
f, err := fs.Create(zipName)
db := s.client.Database(s.db)
bucket, err := gridfs.NewBucket(db, options.GridFSBucket())
if err != nil {
return errors.E(op, err, errors.M(module), errors.V(version))
}
defer f.Close()
numBytesWritten, err := io.Copy(f, zip)
uStream, err := bucket.OpenUploadStream(zipName, options.GridFSUpload())
if err != nil {
return errors.E(op, err, errors.M(module), errors.V(version))
}
defer uStream.Close()
numBytesWritten, err := io.Copy(uStream, zip)
if err != nil {
return errors.E(op, err, errors.M(module), errors.V(version))
}
@@ -48,8 +56,11 @@ func (s *ModuleStore) Save(ctx context.Context, module, version string, mod []by
Info: info,
}
c := s.s.DB(s.d).C(s.c)
err = c.Insert(m)
c := s.client.Database(s.db).Collection(s.coll)
tctx, cancel := context.WithTimeout(ctx, s.timeout)
defer cancel()
_, err = c.InsertOne(tctx, m, options.InsertOne().SetBypassDocumentValidation(false))
if err != nil {
return errors.E(op, err, errors.M(module), errors.V(version))
}
+1 -1
View File
@@ -5,4 +5,4 @@
set -euo pipefail
GO_FILES=$(find . -iname '*.go' -type f | grep -v /vendor/) # All the .go files, excluding vendor/
test -z $(gofmt -s -l $GO_FILES | tee /dev/stderr)
test -z $(gofmt -s -d $GO_FILES | tee /dev/stderr)
+15
View File
@@ -0,0 +1,15 @@
language: go
sudo: false
go:
- 1.7.x
- 1.8.x
- 1.9.x
- 1.10.x
- 1.11.x
- tip
before_install:
- go get github.com/mattn/goveralls
script:
- goveralls -service=travis-ci
+21
View File
@@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2014 Chris Hines
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+38
View File
@@ -0,0 +1,38 @@
[![GoDoc](https://godoc.org/github.com/go-stack/stack?status.svg)](https://godoc.org/github.com/go-stack/stack)
[![Go Report Card](https://goreportcard.com/badge/go-stack/stack)](https://goreportcard.com/report/go-stack/stack)
[![TravisCI](https://travis-ci.org/go-stack/stack.svg?branch=master)](https://travis-ci.org/go-stack/stack)
[![Coverage Status](https://coveralls.io/repos/github/go-stack/stack/badge.svg?branch=master)](https://coveralls.io/github/go-stack/stack?branch=master)
# stack
Package stack implements utilities to capture, manipulate, and format call
stacks. It provides a simpler API than package runtime.
The implementation takes care of the minutia and special cases of interpreting
the program counter (pc) values returned by runtime.Callers.
## Versioning
Package stack publishes releases via [semver](http://semver.org/) compatible Git
tags prefixed with a single 'v'. The master branch always contains the latest
release. The develop branch contains unreleased commits.
## Formatting
Package stack's types implement fmt.Formatter, which provides a simple and
flexible way to declaratively configure formatting when used with logging or
error tracking packages.
```go
func DoTheThing() {
c := stack.Caller(0)
log.Print(c) // "source.go:10"
log.Printf("%+v", c) // "pkg/path/source.go:10"
log.Printf("%n", c) // "DoTheThing"
s := stack.Trace().TrimRuntime()
log.Print(s) // "[source.go:15 caller.go:42 main.go:14]"
}
```
See the docs for all of the supported formatting options.
+1
View File
@@ -0,0 +1 @@
module github.com/go-stack/stack
+400
View File
@@ -0,0 +1,400 @@
// +build go1.7
// Package stack implements utilities to capture, manipulate, and format call
// stacks. It provides a simpler API than package runtime.
//
// The implementation takes care of the minutia and special cases of
// interpreting the program counter (pc) values returned by runtime.Callers.
//
// Package stack's types implement fmt.Formatter, which provides a simple and
// flexible way to declaratively configure formatting when used with logging
// or error tracking packages.
package stack
import (
"bytes"
"errors"
"fmt"
"io"
"runtime"
"strconv"
"strings"
)
// Call records a single function invocation from a goroutine stack.
type Call struct {
frame runtime.Frame
}
// Caller returns a Call from the stack of the current goroutine. The argument
// skip is the number of stack frames to ascend, with 0 identifying the
// calling function.
func Caller(skip int) Call {
// As of Go 1.9 we need room for up to three PC entries.
//
// 0. An entry for the stack frame prior to the target to check for
// special handling needed if that prior entry is runtime.sigpanic.
// 1. A possible second entry to hold metadata about skipped inlined
// functions. If inline functions were not skipped the target frame
// PC will be here.
// 2. A third entry for the target frame PC when the second entry
// is used for skipped inline functions.
var pcs [3]uintptr
n := runtime.Callers(skip+1, pcs[:])
frames := runtime.CallersFrames(pcs[:n])
frame, _ := frames.Next()
frame, _ = frames.Next()
return Call{
frame: frame,
}
}
// String implements fmt.Stinger. It is equivalent to fmt.Sprintf("%v", c).
func (c Call) String() string {
return fmt.Sprint(c)
}
// MarshalText implements encoding.TextMarshaler. It formats the Call the same
// as fmt.Sprintf("%v", c).
func (c Call) MarshalText() ([]byte, error) {
if c.frame == (runtime.Frame{}) {
return nil, ErrNoFunc
}
buf := bytes.Buffer{}
fmt.Fprint(&buf, c)
return buf.Bytes(), nil
}
// ErrNoFunc means that the Call has a nil *runtime.Func. The most likely
// cause is a Call with the zero value.
var ErrNoFunc = errors.New("no call stack information")
// Format implements fmt.Formatter with support for the following verbs.
//
// %s source file
// %d line number
// %n function name
// %k last segment of the package path
// %v equivalent to %s:%d
//
// It accepts the '+' and '#' flags for most of the verbs as follows.
//
// %+s path of source file relative to the compile time GOPATH,
// or the module path joined to the path of source file relative
// to module root
// %#s full path of source file
// %+n import path qualified function name
// %+k full package path
// %+v equivalent to %+s:%d
// %#v equivalent to %#s:%d
func (c Call) Format(s fmt.State, verb rune) {
if c.frame == (runtime.Frame{}) {
fmt.Fprintf(s, "%%!%c(NOFUNC)", verb)
return
}
switch verb {
case 's', 'v':
file := c.frame.File
switch {
case s.Flag('#'):
// done
case s.Flag('+'):
file = pkgFilePath(&c.frame)
default:
const sep = "/"
if i := strings.LastIndex(file, sep); i != -1 {
file = file[i+len(sep):]
}
}
io.WriteString(s, file)
if verb == 'v' {
buf := [7]byte{':'}
s.Write(strconv.AppendInt(buf[:1], int64(c.frame.Line), 10))
}
case 'd':
buf := [6]byte{}
s.Write(strconv.AppendInt(buf[:0], int64(c.frame.Line), 10))
case 'k':
name := c.frame.Function
const pathSep = "/"
start, end := 0, len(name)
if i := strings.LastIndex(name, pathSep); i != -1 {
start = i + len(pathSep)
}
const pkgSep = "."
if i := strings.Index(name[start:], pkgSep); i != -1 {
end = start + i
}
if s.Flag('+') {
start = 0
}
io.WriteString(s, name[start:end])
case 'n':
name := c.frame.Function
if !s.Flag('+') {
const pathSep = "/"
if i := strings.LastIndex(name, pathSep); i != -1 {
name = name[i+len(pathSep):]
}
const pkgSep = "."
if i := strings.Index(name, pkgSep); i != -1 {
name = name[i+len(pkgSep):]
}
}
io.WriteString(s, name)
}
}
// Frame returns the call frame infomation for the Call.
func (c Call) Frame() runtime.Frame {
return c.frame
}
// PC returns the program counter for this call frame; multiple frames may
// have the same PC value.
//
// Deprecated: Use Call.Frame instead.
func (c Call) PC() uintptr {
return c.frame.PC
}
// CallStack records a sequence of function invocations from a goroutine
// stack.
type CallStack []Call
// String implements fmt.Stinger. It is equivalent to fmt.Sprintf("%v", cs).
func (cs CallStack) String() string {
return fmt.Sprint(cs)
}
var (
openBracketBytes = []byte("[")
closeBracketBytes = []byte("]")
spaceBytes = []byte(" ")
)
// MarshalText implements encoding.TextMarshaler. It formats the CallStack the
// same as fmt.Sprintf("%v", cs).
func (cs CallStack) MarshalText() ([]byte, error) {
buf := bytes.Buffer{}
buf.Write(openBracketBytes)
for i, pc := range cs {
if i > 0 {
buf.Write(spaceBytes)
}
fmt.Fprint(&buf, pc)
}
buf.Write(closeBracketBytes)
return buf.Bytes(), nil
}
// Format implements fmt.Formatter by printing the CallStack as square brackets
// ([, ]) surrounding a space separated list of Calls each formatted with the
// supplied verb and options.
func (cs CallStack) Format(s fmt.State, verb rune) {
s.Write(openBracketBytes)
for i, pc := range cs {
if i > 0 {
s.Write(spaceBytes)
}
pc.Format(s, verb)
}
s.Write(closeBracketBytes)
}
// Trace returns a CallStack for the current goroutine with element 0
// identifying the calling function.
func Trace() CallStack {
var pcs [512]uintptr
n := runtime.Callers(1, pcs[:])
frames := runtime.CallersFrames(pcs[:n])
cs := make(CallStack, 0, n)
// Skip extra frame retrieved just to make sure the runtime.sigpanic
// special case is handled.
frame, more := frames.Next()
for more {
frame, more = frames.Next()
cs = append(cs, Call{frame: frame})
}
return cs
}
// TrimBelow returns a slice of the CallStack with all entries below c
// removed.
func (cs CallStack) TrimBelow(c Call) CallStack {
for len(cs) > 0 && cs[0] != c {
cs = cs[1:]
}
return cs
}
// TrimAbove returns a slice of the CallStack with all entries above c
// removed.
func (cs CallStack) TrimAbove(c Call) CallStack {
for len(cs) > 0 && cs[len(cs)-1] != c {
cs = cs[:len(cs)-1]
}
return cs
}
// pkgIndex returns the index that results in file[index:] being the path of
// file relative to the compile time GOPATH, and file[:index] being the
// $GOPATH/src/ portion of file. funcName must be the name of a function in
// file as returned by runtime.Func.Name.
func pkgIndex(file, funcName string) int {
// As of Go 1.6.2 there is no direct way to know the compile time GOPATH
// at runtime, but we can infer the number of path segments in the GOPATH.
// We note that runtime.Func.Name() returns the function name qualified by
// the import path, which does not include the GOPATH. Thus we can trim
// segments from the beginning of the file path until the number of path
// separators remaining is one more than the number of path separators in
// the function name. For example, given:
//
// GOPATH /home/user
// file /home/user/src/pkg/sub/file.go
// fn.Name() pkg/sub.Type.Method
//
// We want to produce:
//
// file[:idx] == /home/user/src/
// file[idx:] == pkg/sub/file.go
//
// From this we can easily see that fn.Name() has one less path separator
// than our desired result for file[idx:]. We count separators from the
// end of the file path until it finds two more than in the function name
// and then move one character forward to preserve the initial path
// segment without a leading separator.
const sep = "/"
i := len(file)
for n := strings.Count(funcName, sep) + 2; n > 0; n-- {
i = strings.LastIndex(file[:i], sep)
if i == -1 {
i = -len(sep)
break
}
}
// get back to 0 or trim the leading separator
return i + len(sep)
}
// pkgFilePath returns the frame's filepath relative to the compile-time GOPATH,
// or its module path joined to its path relative to the module root.
//
// As of Go 1.11 there is no direct way to know the compile time GOPATH or
// module paths at runtime, but we can piece together the desired information
// from available information. We note that runtime.Frame.Function contains the
// function name qualified by the package path, which includes the module path
// but not the GOPATH. We can extract the package path from that and append the
// last segments of the file path to arrive at the desired package qualified
// file path. For example, given:
//
// GOPATH /home/user
// import path pkg/sub
// frame.File /home/user/src/pkg/sub/file.go
// frame.Function pkg/sub.Type.Method
// Desired return pkg/sub/file.go
//
// It appears that we simply need to trim ".Type.Method" from frame.Function and
// append "/" + path.Base(file).
//
// But there are other wrinkles. Although it is idiomatic to do so, the internal
// name of a package is not required to match the last segment of its import
// path. In addition, the introduction of modules in Go 1.11 allows working
// without a GOPATH. So we also must make these work right:
//
// GOPATH /home/user
// import path pkg/go-sub
// package name sub
// frame.File /home/user/src/pkg/go-sub/file.go
// frame.Function pkg/sub.Type.Method
// Desired return pkg/go-sub/file.go
//
// Module path pkg/v2
// import path pkg/v2/go-sub
// package name sub
// frame.File /home/user/cloned-pkg/go-sub/file.go
// frame.Function pkg/v2/sub.Type.Method
// Desired return pkg/v2/go-sub/file.go
//
// We can handle all of these situations by using the package path extracted
// from frame.Function up to, but not including, the last segment as the prefix
// and the last two segments of frame.File as the suffix of the returned path.
// This preserves the existing behavior when working in a GOPATH without modules
// and a semantically equivalent behavior when used in module aware project.
func pkgFilePath(frame *runtime.Frame) string {
pre := pkgPrefix(frame.Function)
post := pathSuffix(frame.File)
if pre == "" {
return post
}
return pre + "/" + post
}
// pkgPrefix returns the import path of the function's package with the final
// segment removed.
func pkgPrefix(funcName string) string {
const pathSep = "/"
end := strings.LastIndex(funcName, pathSep)
if end == -1 {
return ""
}
return funcName[:end]
}
// pathSuffix returns the last two segments of path.
func pathSuffix(path string) string {
const pathSep = "/"
lastSep := strings.LastIndex(path, pathSep)
if lastSep == -1 {
return path
}
return path[strings.LastIndex(path[:lastSep], pathSep)+1:]
}
var runtimePath string
func init() {
var pcs [3]uintptr
runtime.Callers(0, pcs[:])
frames := runtime.CallersFrames(pcs[:])
frame, _ := frames.Next()
file := frame.File
idx := pkgIndex(frame.File, frame.Function)
runtimePath = file[:idx]
if runtime.GOOS == "windows" {
runtimePath = strings.ToLower(runtimePath)
}
}
func inGoroot(c Call) bool {
file := c.frame.File
if len(file) == 0 || file[0] == '?' {
return true
}
if runtime.GOOS == "windows" {
file = strings.ToLower(file)
}
return strings.HasPrefix(file, runtimePath) || strings.HasSuffix(file, "/_testmain.go")
}
// TrimRuntime returns a slice of the CallStack with the topmost entries from
// the go runtime removed. It considers any calls originating from unknown
// files, files under GOROOT, or _testmain.go as part of the runtime.
func (cs CallStack) TrimRuntime() CallStack {
for len(cs) > 0 && inGoroot(cs[len(cs)-1]) {
cs = cs[:len(cs)-1]
}
return cs
}
+16
View File
@@ -0,0 +1,16 @@
cmd/snappytool/snappytool
testdata/bench
# These explicitly listed benchmark data files are for an obsolete version of
# snappy_test.go.
testdata/alice29.txt
testdata/asyoulik.txt
testdata/fireworks.jpeg
testdata/geo.protodata
testdata/html
testdata/html_x_4
testdata/kppkn.gtb
testdata/lcet10.txt
testdata/paper-100k.pdf
testdata/plrabn12.txt
testdata/urls.10K
+15
View File
@@ -0,0 +1,15 @@
# This is the official list of Snappy-Go authors for copyright purposes.
# This file is distinct from the CONTRIBUTORS files.
# See the latter for an explanation.
# Names should be added to this file as
# Name or Organization <email address>
# The email address is not required for organizations.
# Please keep the list sorted.
Damian Gryski <dgryski@gmail.com>
Google Inc.
Jan Mercl <0xjnml@gmail.com>
Rodolfo Carvalho <rhcarvalho@gmail.com>
Sebastien Binet <seb.binet@gmail.com>
+37
View File
@@ -0,0 +1,37 @@
# This is the official list of people who can contribute
# (and typically have contributed) code to the Snappy-Go repository.
# The AUTHORS file lists the copyright holders; this file
# lists people. For example, Google employees are listed here
# but not in AUTHORS, because Google holds the copyright.
#
# The submission process automatically checks to make sure
# that people submitting code are listed in this file (by email address).
#
# Names should be added to this file only after verifying that
# the individual or the individual's organization has agreed to
# the appropriate Contributor License Agreement, found here:
#
# http://code.google.com/legal/individual-cla-v1.0.html
# http://code.google.com/legal/corporate-cla-v1.0.html
#
# The agreement for individuals can be filled out on the web.
#
# When adding J Random Contributor's name to this file,
# either J's name or J's organization's name should be
# added to the AUTHORS file, depending on whether the
# individual or corporate CLA was used.
# Names should be added to this file like so:
# Name <email address>
# Please keep the list sorted.
Damian Gryski <dgryski@gmail.com>
Jan Mercl <0xjnml@gmail.com>
Kai Backman <kaib@golang.org>
Marc-Antoine Ruel <maruel@chromium.org>
Nigel Tao <nigeltao@golang.org>
Rob Pike <r@golang.org>
Rodolfo Carvalho <rhcarvalho@gmail.com>
Russ Cox <rsc@golang.org>
Sebastien Binet <seb.binet@gmail.com>
+27
View File
@@ -0,0 +1,27 @@
Copyright (c) 2011 The Snappy-Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+107
View File
@@ -0,0 +1,107 @@
The Snappy compression format in the Go programming language.
To download and install from source:
$ go get github.com/golang/snappy
Unless otherwise noted, the Snappy-Go source files are distributed
under the BSD-style license found in the LICENSE file.
Benchmarks.
The golang/snappy benchmarks include compressing (Z) and decompressing (U) ten
or so files, the same set used by the C++ Snappy code (github.com/google/snappy
and note the "google", not "golang"). On an "Intel(R) Core(TM) i7-3770 CPU @
3.40GHz", Go's GOARCH=amd64 numbers as of 2016-05-29:
"go test -test.bench=."
_UFlat0-8 2.19GB/s ± 0% html
_UFlat1-8 1.41GB/s ± 0% urls
_UFlat2-8 23.5GB/s ± 2% jpg
_UFlat3-8 1.91GB/s ± 0% jpg_200
_UFlat4-8 14.0GB/s ± 1% pdf
_UFlat5-8 1.97GB/s ± 0% html4
_UFlat6-8 814MB/s ± 0% txt1
_UFlat7-8 785MB/s ± 0% txt2
_UFlat8-8 857MB/s ± 0% txt3
_UFlat9-8 719MB/s ± 1% txt4
_UFlat10-8 2.84GB/s ± 0% pb
_UFlat11-8 1.05GB/s ± 0% gaviota
_ZFlat0-8 1.04GB/s ± 0% html
_ZFlat1-8 534MB/s ± 0% urls
_ZFlat2-8 15.7GB/s ± 1% jpg
_ZFlat3-8 740MB/s ± 3% jpg_200
_ZFlat4-8 9.20GB/s ± 1% pdf
_ZFlat5-8 991MB/s ± 0% html4
_ZFlat6-8 379MB/s ± 0% txt1
_ZFlat7-8 352MB/s ± 0% txt2
_ZFlat8-8 396MB/s ± 1% txt3
_ZFlat9-8 327MB/s ± 1% txt4
_ZFlat10-8 1.33GB/s ± 1% pb
_ZFlat11-8 605MB/s ± 1% gaviota
"go test -test.bench=. -tags=noasm"
_UFlat0-8 621MB/s ± 2% html
_UFlat1-8 494MB/s ± 1% urls
_UFlat2-8 23.2GB/s ± 1% jpg
_UFlat3-8 1.12GB/s ± 1% jpg_200
_UFlat4-8 4.35GB/s ± 1% pdf
_UFlat5-8 609MB/s ± 0% html4
_UFlat6-8 296MB/s ± 0% txt1
_UFlat7-8 288MB/s ± 0% txt2
_UFlat8-8 309MB/s ± 1% txt3
_UFlat9-8 280MB/s ± 1% txt4
_UFlat10-8 753MB/s ± 0% pb
_UFlat11-8 400MB/s ± 0% gaviota
_ZFlat0-8 409MB/s ± 1% html
_ZFlat1-8 250MB/s ± 1% urls
_ZFlat2-8 12.3GB/s ± 1% jpg
_ZFlat3-8 132MB/s ± 0% jpg_200
_ZFlat4-8 2.92GB/s ± 0% pdf
_ZFlat5-8 405MB/s ± 1% html4
_ZFlat6-8 179MB/s ± 1% txt1
_ZFlat7-8 170MB/s ± 1% txt2
_ZFlat8-8 189MB/s ± 1% txt3
_ZFlat9-8 164MB/s ± 1% txt4
_ZFlat10-8 479MB/s ± 1% pb
_ZFlat11-8 270MB/s ± 1% gaviota
For comparison (Go's encoded output is byte-for-byte identical to C++'s), here
are the numbers from C++ Snappy's
make CXXFLAGS="-O2 -DNDEBUG -g" clean snappy_unittest.log && cat snappy_unittest.log
BM_UFlat/0 2.4GB/s html
BM_UFlat/1 1.4GB/s urls
BM_UFlat/2 21.8GB/s jpg
BM_UFlat/3 1.5GB/s jpg_200
BM_UFlat/4 13.3GB/s pdf
BM_UFlat/5 2.1GB/s html4
BM_UFlat/6 1.0GB/s txt1
BM_UFlat/7 959.4MB/s txt2
BM_UFlat/8 1.0GB/s txt3
BM_UFlat/9 864.5MB/s txt4
BM_UFlat/10 2.9GB/s pb
BM_UFlat/11 1.2GB/s gaviota
BM_ZFlat/0 944.3MB/s html (22.31 %)
BM_ZFlat/1 501.6MB/s urls (47.78 %)
BM_ZFlat/2 14.3GB/s jpg (99.95 %)
BM_ZFlat/3 538.3MB/s jpg_200 (73.00 %)
BM_ZFlat/4 8.3GB/s pdf (83.30 %)
BM_ZFlat/5 903.5MB/s html4 (22.52 %)
BM_ZFlat/6 336.0MB/s txt1 (57.88 %)
BM_ZFlat/7 312.3MB/s txt2 (61.91 %)
BM_ZFlat/8 353.1MB/s txt3 (54.99 %)
BM_ZFlat/9 289.9MB/s txt4 (66.26 %)
BM_ZFlat/10 1.2GB/s pb (19.68 %)
BM_ZFlat/11 527.4MB/s gaviota (37.72 %)
+237
View File
@@ -0,0 +1,237 @@
// Copyright 2011 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package snappy
import (
"encoding/binary"
"errors"
"io"
)
var (
// ErrCorrupt reports that the input is invalid.
ErrCorrupt = errors.New("snappy: corrupt input")
// ErrTooLarge reports that the uncompressed length is too large.
ErrTooLarge = errors.New("snappy: decoded block is too large")
// ErrUnsupported reports that the input isn't supported.
ErrUnsupported = errors.New("snappy: unsupported input")
errUnsupportedLiteralLength = errors.New("snappy: unsupported literal length")
)
// DecodedLen returns the length of the decoded block.
func DecodedLen(src []byte) (int, error) {
v, _, err := decodedLen(src)
return v, err
}
// decodedLen returns the length of the decoded block and the number of bytes
// that the length header occupied.
func decodedLen(src []byte) (blockLen, headerLen int, err error) {
v, n := binary.Uvarint(src)
if n <= 0 || v > 0xffffffff {
return 0, 0, ErrCorrupt
}
const wordSize = 32 << (^uint(0) >> 32 & 1)
if wordSize == 32 && v > 0x7fffffff {
return 0, 0, ErrTooLarge
}
return int(v), n, nil
}
const (
decodeErrCodeCorrupt = 1
decodeErrCodeUnsupportedLiteralLength = 2
)
// Decode returns the decoded form of src. The returned slice may be a sub-
// slice of dst if dst was large enough to hold the entire decoded block.
// Otherwise, a newly allocated slice will be returned.
//
// The dst and src must not overlap. It is valid to pass a nil dst.
func Decode(dst, src []byte) ([]byte, error) {
dLen, s, err := decodedLen(src)
if err != nil {
return nil, err
}
if dLen <= len(dst) {
dst = dst[:dLen]
} else {
dst = make([]byte, dLen)
}
switch decode(dst, src[s:]) {
case 0:
return dst, nil
case decodeErrCodeUnsupportedLiteralLength:
return nil, errUnsupportedLiteralLength
}
return nil, ErrCorrupt
}
// NewReader returns a new Reader that decompresses from r, using the framing
// format described at
// https://github.com/google/snappy/blob/master/framing_format.txt
func NewReader(r io.Reader) *Reader {
return &Reader{
r: r,
decoded: make([]byte, maxBlockSize),
buf: make([]byte, maxEncodedLenOfMaxBlockSize+checksumSize),
}
}
// Reader is an io.Reader that can read Snappy-compressed bytes.
type Reader struct {
r io.Reader
err error
decoded []byte
buf []byte
// decoded[i:j] contains decoded bytes that have not yet been passed on.
i, j int
readHeader bool
}
// Reset discards any buffered data, resets all state, and switches the Snappy
// reader to read from r. This permits reusing a Reader rather than allocating
// a new one.
func (r *Reader) Reset(reader io.Reader) {
r.r = reader
r.err = nil
r.i = 0
r.j = 0
r.readHeader = false
}
func (r *Reader) readFull(p []byte, allowEOF bool) (ok bool) {
if _, r.err = io.ReadFull(r.r, p); r.err != nil {
if r.err == io.ErrUnexpectedEOF || (r.err == io.EOF && !allowEOF) {
r.err = ErrCorrupt
}
return false
}
return true
}
// Read satisfies the io.Reader interface.
func (r *Reader) Read(p []byte) (int, error) {
if r.err != nil {
return 0, r.err
}
for {
if r.i < r.j {
n := copy(p, r.decoded[r.i:r.j])
r.i += n
return n, nil
}
if !r.readFull(r.buf[:4], true) {
return 0, r.err
}
chunkType := r.buf[0]
if !r.readHeader {
if chunkType != chunkTypeStreamIdentifier {
r.err = ErrCorrupt
return 0, r.err
}
r.readHeader = true
}
chunkLen := int(r.buf[1]) | int(r.buf[2])<<8 | int(r.buf[3])<<16
if chunkLen > len(r.buf) {
r.err = ErrUnsupported
return 0, r.err
}
// The chunk types are specified at
// https://github.com/google/snappy/blob/master/framing_format.txt
switch chunkType {
case chunkTypeCompressedData:
// Section 4.2. Compressed data (chunk type 0x00).
if chunkLen < checksumSize {
r.err = ErrCorrupt
return 0, r.err
}
buf := r.buf[:chunkLen]
if !r.readFull(buf, false) {
return 0, r.err
}
checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
buf = buf[checksumSize:]
n, err := DecodedLen(buf)
if err != nil {
r.err = err
return 0, r.err
}
if n > len(r.decoded) {
r.err = ErrCorrupt
return 0, r.err
}
if _, err := Decode(r.decoded, buf); err != nil {
r.err = err
return 0, r.err
}
if crc(r.decoded[:n]) != checksum {
r.err = ErrCorrupt
return 0, r.err
}
r.i, r.j = 0, n
continue
case chunkTypeUncompressedData:
// Section 4.3. Uncompressed data (chunk type 0x01).
if chunkLen < checksumSize {
r.err = ErrCorrupt
return 0, r.err
}
buf := r.buf[:checksumSize]
if !r.readFull(buf, false) {
return 0, r.err
}
checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
// Read directly into r.decoded instead of via r.buf.
n := chunkLen - checksumSize
if n > len(r.decoded) {
r.err = ErrCorrupt
return 0, r.err
}
if !r.readFull(r.decoded[:n], false) {
return 0, r.err
}
if crc(r.decoded[:n]) != checksum {
r.err = ErrCorrupt
return 0, r.err
}
r.i, r.j = 0, n
continue
case chunkTypeStreamIdentifier:
// Section 4.1. Stream identifier (chunk type 0xff).
if chunkLen != len(magicBody) {
r.err = ErrCorrupt
return 0, r.err
}
if !r.readFull(r.buf[:len(magicBody)], false) {
return 0, r.err
}
for i := 0; i < len(magicBody); i++ {
if r.buf[i] != magicBody[i] {
r.err = ErrCorrupt
return 0, r.err
}
}
continue
}
if chunkType <= 0x7f {
// Section 4.5. Reserved unskippable chunks (chunk types 0x02-0x7f).
r.err = ErrUnsupported
return 0, r.err
}
// Section 4.4 Padding (chunk type 0xfe).
// Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd).
if !r.readFull(r.buf[:chunkLen], false) {
return 0, r.err
}
}
}
+14
View File
@@ -0,0 +1,14 @@
// Copyright 2016 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !appengine
// +build gc
// +build !noasm
package snappy
// decode has the same semantics as in decode_other.go.
//
//go:noescape
func decode(dst, src []byte) int
+490
View File
@@ -0,0 +1,490 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !appengine
// +build gc
// +build !noasm
#include "textflag.h"
// The asm code generally follows the pure Go code in decode_other.go, except
// where marked with a "!!!".
// func decode(dst, src []byte) int
//
// All local variables fit into registers. The non-zero stack size is only to
// spill registers and push args when issuing a CALL. The register allocation:
// - AX scratch
// - BX scratch
// - CX length or x
// - DX offset
// - SI &src[s]
// - DI &dst[d]
// + R8 dst_base
// + R9 dst_len
// + R10 dst_base + dst_len
// + R11 src_base
// + R12 src_len
// + R13 src_base + src_len
// - R14 used by doCopy
// - R15 used by doCopy
//
// The registers R8-R13 (marked with a "+") are set at the start of the
// function, and after a CALL returns, and are not otherwise modified.
//
// The d variable is implicitly DI - R8, and len(dst)-d is R10 - DI.
// The s variable is implicitly SI - R11, and len(src)-s is R13 - SI.
TEXT ·decode(SB), NOSPLIT, $48-56
// Initialize SI, DI and R8-R13.
MOVQ dst_base+0(FP), R8
MOVQ dst_len+8(FP), R9
MOVQ R8, DI
MOVQ R8, R10
ADDQ R9, R10
MOVQ src_base+24(FP), R11
MOVQ src_len+32(FP), R12
MOVQ R11, SI
MOVQ R11, R13
ADDQ R12, R13
loop:
// for s < len(src)
CMPQ SI, R13
JEQ end
// CX = uint32(src[s])
//
// switch src[s] & 0x03
MOVBLZX (SI), CX
MOVL CX, BX
ANDL $3, BX
CMPL BX, $1
JAE tagCopy
// ----------------------------------------
// The code below handles literal tags.
// case tagLiteral:
// x := uint32(src[s] >> 2)
// switch
SHRL $2, CX
CMPL CX, $60
JAE tagLit60Plus
// case x < 60:
// s++
INCQ SI
doLit:
// This is the end of the inner "switch", when we have a literal tag.
//
// We assume that CX == x and x fits in a uint32, where x is the variable
// used in the pure Go decode_other.go code.
// length = int(x) + 1
//
// Unlike the pure Go code, we don't need to check if length <= 0 because
// CX can hold 64 bits, so the increment cannot overflow.
INCQ CX
// Prepare to check if copying length bytes will run past the end of dst or
// src.
//
// AX = len(dst) - d
// BX = len(src) - s
MOVQ R10, AX
SUBQ DI, AX
MOVQ R13, BX
SUBQ SI, BX
// !!! Try a faster technique for short (16 or fewer bytes) copies.
//
// if length > 16 || len(dst)-d < 16 || len(src)-s < 16 {
// goto callMemmove // Fall back on calling runtime·memmove.
// }
//
// The C++ snappy code calls this TryFastAppend. It also checks len(src)-s
// against 21 instead of 16, because it cannot assume that all of its input
// is contiguous in memory and so it needs to leave enough source bytes to
// read the next tag without refilling buffers, but Go's Decode assumes
// contiguousness (the src argument is a []byte).
CMPQ CX, $16
JGT callMemmove
CMPQ AX, $16
JLT callMemmove
CMPQ BX, $16
JLT callMemmove
// !!! Implement the copy from src to dst as a 16-byte load and store.
// (Decode's documentation says that dst and src must not overlap.)
//
// This always copies 16 bytes, instead of only length bytes, but that's
// OK. If the input is a valid Snappy encoding then subsequent iterations
// will fix up the overrun. Otherwise, Decode returns a nil []byte (and a
// non-nil error), so the overrun will be ignored.
//
// Note that on amd64, it is legal and cheap to issue unaligned 8-byte or
// 16-byte loads and stores. This technique probably wouldn't be as
// effective on architectures that are fussier about alignment.
MOVOU 0(SI), X0
MOVOU X0, 0(DI)
// d += length
// s += length
ADDQ CX, DI
ADDQ CX, SI
JMP loop
callMemmove:
// if length > len(dst)-d || length > len(src)-s { etc }
CMPQ CX, AX
JGT errCorrupt
CMPQ CX, BX
JGT errCorrupt
// copy(dst[d:], src[s:s+length])
//
// This means calling runtime·memmove(&dst[d], &src[s], length), so we push
// DI, SI and CX as arguments. Coincidentally, we also need to spill those
// three registers to the stack, to save local variables across the CALL.
MOVQ DI, 0(SP)
MOVQ SI, 8(SP)
MOVQ CX, 16(SP)
MOVQ DI, 24(SP)
MOVQ SI, 32(SP)
MOVQ CX, 40(SP)
CALL runtime·memmove(SB)
// Restore local variables: unspill registers from the stack and
// re-calculate R8-R13.
MOVQ 24(SP), DI
MOVQ 32(SP), SI
MOVQ 40(SP), CX
MOVQ dst_base+0(FP), R8
MOVQ dst_len+8(FP), R9
MOVQ R8, R10
ADDQ R9, R10
MOVQ src_base+24(FP), R11
MOVQ src_len+32(FP), R12
MOVQ R11, R13
ADDQ R12, R13
// d += length
// s += length
ADDQ CX, DI
ADDQ CX, SI
JMP loop
tagLit60Plus:
// !!! This fragment does the
//
// s += x - 58; if uint(s) > uint(len(src)) { etc }
//
// checks. In the asm version, we code it once instead of once per switch case.
ADDQ CX, SI
SUBQ $58, SI
MOVQ SI, BX
SUBQ R11, BX
CMPQ BX, R12
JA errCorrupt
// case x == 60:
CMPL CX, $61
JEQ tagLit61
JA tagLit62Plus
// x = uint32(src[s-1])
MOVBLZX -1(SI), CX
JMP doLit
tagLit61:
// case x == 61:
// x = uint32(src[s-2]) | uint32(src[s-1])<<8
MOVWLZX -2(SI), CX
JMP doLit
tagLit62Plus:
CMPL CX, $62
JA tagLit63
// case x == 62:
// x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16
MOVWLZX -3(SI), CX
MOVBLZX -1(SI), BX
SHLL $16, BX
ORL BX, CX
JMP doLit
tagLit63:
// case x == 63:
// x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
MOVL -4(SI), CX
JMP doLit
// The code above handles literal tags.
// ----------------------------------------
// The code below handles copy tags.
tagCopy4:
// case tagCopy4:
// s += 5
ADDQ $5, SI
// if uint(s) > uint(len(src)) { etc }
MOVQ SI, BX
SUBQ R11, BX
CMPQ BX, R12
JA errCorrupt
// length = 1 + int(src[s-5])>>2
SHRQ $2, CX
INCQ CX
// offset = int(uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24)
MOVLQZX -4(SI), DX
JMP doCopy
tagCopy2:
// case tagCopy2:
// s += 3
ADDQ $3, SI
// if uint(s) > uint(len(src)) { etc }
MOVQ SI, BX
SUBQ R11, BX
CMPQ BX, R12
JA errCorrupt
// length = 1 + int(src[s-3])>>2
SHRQ $2, CX
INCQ CX
// offset = int(uint32(src[s-2]) | uint32(src[s-1])<<8)
MOVWQZX -2(SI), DX
JMP doCopy
tagCopy:
// We have a copy tag. We assume that:
// - BX == src[s] & 0x03
// - CX == src[s]
CMPQ BX, $2
JEQ tagCopy2
JA tagCopy4
// case tagCopy1:
// s += 2
ADDQ $2, SI
// if uint(s) > uint(len(src)) { etc }
MOVQ SI, BX
SUBQ R11, BX
CMPQ BX, R12
JA errCorrupt
// offset = int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
MOVQ CX, DX
ANDQ $0xe0, DX
SHLQ $3, DX
MOVBQZX -1(SI), BX
ORQ BX, DX
// length = 4 + int(src[s-2])>>2&0x7
SHRQ $2, CX
ANDQ $7, CX
ADDQ $4, CX
doCopy:
// This is the end of the outer "switch", when we have a copy tag.
//
// We assume that:
// - CX == length && CX > 0
// - DX == offset
// if offset <= 0 { etc }
CMPQ DX, $0
JLE errCorrupt
// if d < offset { etc }
MOVQ DI, BX
SUBQ R8, BX
CMPQ BX, DX
JLT errCorrupt
// if length > len(dst)-d { etc }
MOVQ R10, BX
SUBQ DI, BX
CMPQ CX, BX
JGT errCorrupt
// forwardCopy(dst[d:d+length], dst[d-offset:]); d += length
//
// Set:
// - R14 = len(dst)-d
// - R15 = &dst[d-offset]
MOVQ R10, R14
SUBQ DI, R14
MOVQ DI, R15
SUBQ DX, R15
// !!! Try a faster technique for short (16 or fewer bytes) forward copies.
//
// First, try using two 8-byte load/stores, similar to the doLit technique
// above. Even if dst[d:d+length] and dst[d-offset:] can overlap, this is
// still OK if offset >= 8. Note that this has to be two 8-byte load/stores
// and not one 16-byte load/store, and the first store has to be before the
// second load, due to the overlap if offset is in the range [8, 16).
//
// if length > 16 || offset < 8 || len(dst)-d < 16 {
// goto slowForwardCopy
// }
// copy 16 bytes
// d += length
CMPQ CX, $16
JGT slowForwardCopy
CMPQ DX, $8
JLT slowForwardCopy
CMPQ R14, $16
JLT slowForwardCopy
MOVQ 0(R15), AX
MOVQ AX, 0(DI)
MOVQ 8(R15), BX
MOVQ BX, 8(DI)
ADDQ CX, DI
JMP loop
slowForwardCopy:
// !!! If the forward copy is longer than 16 bytes, or if offset < 8, we
// can still try 8-byte load stores, provided we can overrun up to 10 extra
// bytes. As above, the overrun will be fixed up by subsequent iterations
// of the outermost loop.
//
// The C++ snappy code calls this technique IncrementalCopyFastPath. Its
// commentary says:
//
// ----
//
// The main part of this loop is a simple copy of eight bytes at a time
// until we've copied (at least) the requested amount of bytes. However,
// if d and d-offset are less than eight bytes apart (indicating a
// repeating pattern of length < 8), we first need to expand the pattern in
// order to get the correct results. For instance, if the buffer looks like
// this, with the eight-byte <d-offset> and <d> patterns marked as
// intervals:
//
// abxxxxxxxxxxxx
// [------] d-offset
// [------] d
//
// a single eight-byte copy from <d-offset> to <d> will repeat the pattern
// once, after which we can move <d> two bytes without moving <d-offset>:
//
// ababxxxxxxxxxx
// [------] d-offset
// [------] d
//
// and repeat the exercise until the two no longer overlap.
//
// This allows us to do very well in the special case of one single byte
// repeated many times, without taking a big hit for more general cases.
//
// The worst case of extra writing past the end of the match occurs when
// offset == 1 and length == 1; the last copy will read from byte positions
// [0..7] and write to [4..11], whereas it was only supposed to write to
// position 1. Thus, ten excess bytes.
//
// ----
//
// That "10 byte overrun" worst case is confirmed by Go's
// TestSlowForwardCopyOverrun, which also tests the fixUpSlowForwardCopy
// and finishSlowForwardCopy algorithm.
//
// if length > len(dst)-d-10 {
// goto verySlowForwardCopy
// }
SUBQ $10, R14
CMPQ CX, R14
JGT verySlowForwardCopy
makeOffsetAtLeast8:
// !!! As above, expand the pattern so that offset >= 8 and we can use
// 8-byte load/stores.
//
// for offset < 8 {
// copy 8 bytes from dst[d-offset:] to dst[d:]
// length -= offset
// d += offset
// offset += offset
// // The two previous lines together means that d-offset, and therefore
// // R15, is unchanged.
// }
CMPQ DX, $8
JGE fixUpSlowForwardCopy
MOVQ (R15), BX
MOVQ BX, (DI)
SUBQ DX, CX
ADDQ DX, DI
ADDQ DX, DX
JMP makeOffsetAtLeast8
fixUpSlowForwardCopy:
// !!! Add length (which might be negative now) to d (implied by DI being
// &dst[d]) so that d ends up at the right place when we jump back to the
// top of the loop. Before we do that, though, we save DI to AX so that, if
// length is positive, copying the remaining length bytes will write to the
// right place.
MOVQ DI, AX
ADDQ CX, DI
finishSlowForwardCopy:
// !!! Repeat 8-byte load/stores until length <= 0. Ending with a negative
// length means that we overrun, but as above, that will be fixed up by
// subsequent iterations of the outermost loop.
CMPQ CX, $0
JLE loop
MOVQ (R15), BX
MOVQ BX, (AX)
ADDQ $8, R15
ADDQ $8, AX
SUBQ $8, CX
JMP finishSlowForwardCopy
verySlowForwardCopy:
// verySlowForwardCopy is a simple implementation of forward copy. In C
// parlance, this is a do/while loop instead of a while loop, since we know
// that length > 0. In Go syntax:
//
// for {
// dst[d] = dst[d - offset]
// d++
// length--
// if length == 0 {
// break
// }
// }
MOVB (R15), BX
MOVB BX, (DI)
INCQ R15
INCQ DI
DECQ CX
JNZ verySlowForwardCopy
JMP loop
// The code above handles copy tags.
// ----------------------------------------
end:
// This is the end of the "for s < len(src)".
//
// if d != len(dst) { etc }
CMPQ DI, R10
JNE errCorrupt
// return 0
MOVQ $0, ret+48(FP)
RET
errCorrupt:
// return decodeErrCodeCorrupt
MOVQ $1, ret+48(FP)
RET
+101
View File
@@ -0,0 +1,101 @@
// Copyright 2016 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !amd64 appengine !gc noasm
package snappy
// decode writes the decoding of src to dst. It assumes that the varint-encoded
// length of the decompressed bytes has already been read, and that len(dst)
// equals that length.
//
// It returns 0 on success or a decodeErrCodeXxx error code on failure.
func decode(dst, src []byte) int {
var d, s, offset, length int
for s < len(src) {
switch src[s] & 0x03 {
case tagLiteral:
x := uint32(src[s] >> 2)
switch {
case x < 60:
s++
case x == 60:
s += 2
if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
return decodeErrCodeCorrupt
}
x = uint32(src[s-1])
case x == 61:
s += 3
if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
return decodeErrCodeCorrupt
}
x = uint32(src[s-2]) | uint32(src[s-1])<<8
case x == 62:
s += 4
if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
return decodeErrCodeCorrupt
}
x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16
case x == 63:
s += 5
if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
return decodeErrCodeCorrupt
}
x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
}
length = int(x) + 1
if length <= 0 {
return decodeErrCodeUnsupportedLiteralLength
}
if length > len(dst)-d || length > len(src)-s {
return decodeErrCodeCorrupt
}
copy(dst[d:], src[s:s+length])
d += length
s += length
continue
case tagCopy1:
s += 2
if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
return decodeErrCodeCorrupt
}
length = 4 + int(src[s-2])>>2&0x7
offset = int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
case tagCopy2:
s += 3
if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
return decodeErrCodeCorrupt
}
length = 1 + int(src[s-3])>>2
offset = int(uint32(src[s-2]) | uint32(src[s-1])<<8)
case tagCopy4:
s += 5
if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
return decodeErrCodeCorrupt
}
length = 1 + int(src[s-5])>>2
offset = int(uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24)
}
if offset <= 0 || d < offset || length > len(dst)-d {
return decodeErrCodeCorrupt
}
// Copy from an earlier sub-slice of dst to a later sub-slice. Unlike
// the built-in copy function, this byte-by-byte copy always runs
// forwards, even if the slices overlap. Conceptually, this is:
//
// d += forwardCopy(dst[d:d+length], dst[d-offset:])
for end := d + length; d != end; d++ {
dst[d] = dst[d-offset]
}
}
if d != len(dst) {
return decodeErrCodeCorrupt
}
return 0
}
+285
View File
@@ -0,0 +1,285 @@
// Copyright 2011 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package snappy
import (
"encoding/binary"
"errors"
"io"
)
// Encode returns the encoded form of src. The returned slice may be a sub-
// slice of dst if dst was large enough to hold the entire encoded block.
// Otherwise, a newly allocated slice will be returned.
//
// The dst and src must not overlap. It is valid to pass a nil dst.
func Encode(dst, src []byte) []byte {
if n := MaxEncodedLen(len(src)); n < 0 {
panic(ErrTooLarge)
} else if len(dst) < n {
dst = make([]byte, n)
}
// The block starts with the varint-encoded length of the decompressed bytes.
d := binary.PutUvarint(dst, uint64(len(src)))
for len(src) > 0 {
p := src
src = nil
if len(p) > maxBlockSize {
p, src = p[:maxBlockSize], p[maxBlockSize:]
}
if len(p) < minNonLiteralBlockSize {
d += emitLiteral(dst[d:], p)
} else {
d += encodeBlock(dst[d:], p)
}
}
return dst[:d]
}
// inputMargin is the minimum number of extra input bytes to keep, inside
// encodeBlock's inner loop. On some architectures, this margin lets us
// implement a fast path for emitLiteral, where the copy of short (<= 16 byte)
// literals can be implemented as a single load to and store from a 16-byte
// register. That literal's actual length can be as short as 1 byte, so this
// can copy up to 15 bytes too much, but that's OK as subsequent iterations of
// the encoding loop will fix up the copy overrun, and this inputMargin ensures
// that we don't overrun the dst and src buffers.
const inputMargin = 16 - 1
// minNonLiteralBlockSize is the minimum size of the input to encodeBlock that
// could be encoded with a copy tag. This is the minimum with respect to the
// algorithm used by encodeBlock, not a minimum enforced by the file format.
//
// The encoded output must start with at least a 1 byte literal, as there are
// no previous bytes to copy. A minimal (1 byte) copy after that, generated
// from an emitCopy call in encodeBlock's main loop, would require at least
// another inputMargin bytes, for the reason above: we want any emitLiteral
// calls inside encodeBlock's main loop to use the fast path if possible, which
// requires being able to overrun by inputMargin bytes. Thus,
// minNonLiteralBlockSize equals 1 + 1 + inputMargin.
//
// The C++ code doesn't use this exact threshold, but it could, as discussed at
// https://groups.google.com/d/topic/snappy-compression/oGbhsdIJSJ8/discussion
// The difference between Go (2+inputMargin) and C++ (inputMargin) is purely an
// optimization. It should not affect the encoded form. This is tested by
// TestSameEncodingAsCppShortCopies.
const minNonLiteralBlockSize = 1 + 1 + inputMargin
// MaxEncodedLen returns the maximum length of a snappy block, given its
// uncompressed length.
//
// It will return a negative value if srcLen is too large to encode.
func MaxEncodedLen(srcLen int) int {
n := uint64(srcLen)
if n > 0xffffffff {
return -1
}
// Compressed data can be defined as:
// compressed := item* literal*
// item := literal* copy
//
// The trailing literal sequence has a space blowup of at most 62/60
// since a literal of length 60 needs one tag byte + one extra byte
// for length information.
//
// Item blowup is trickier to measure. Suppose the "copy" op copies
// 4 bytes of data. Because of a special check in the encoding code,
// we produce a 4-byte copy only if the offset is < 65536. Therefore
// the copy op takes 3 bytes to encode, and this type of item leads
// to at most the 62/60 blowup for representing literals.
//
// Suppose the "copy" op copies 5 bytes of data. If the offset is big
// enough, it will take 5 bytes to encode the copy op. Therefore the
// worst case here is a one-byte literal followed by a five-byte copy.
// That is, 6 bytes of input turn into 7 bytes of "compressed" data.
//
// This last factor dominates the blowup, so the final estimate is:
n = 32 + n + n/6
if n > 0xffffffff {
return -1
}
return int(n)
}
var errClosed = errors.New("snappy: Writer is closed")
// NewWriter returns a new Writer that compresses to w.
//
// The Writer returned does not buffer writes. There is no need to Flush or
// Close such a Writer.
//
// Deprecated: the Writer returned is not suitable for many small writes, only
// for few large writes. Use NewBufferedWriter instead, which is efficient
// regardless of the frequency and shape of the writes, and remember to Close
// that Writer when done.
func NewWriter(w io.Writer) *Writer {
return &Writer{
w: w,
obuf: make([]byte, obufLen),
}
}
// NewBufferedWriter returns a new Writer that compresses to w, using the
// framing format described at
// https://github.com/google/snappy/blob/master/framing_format.txt
//
// The Writer returned buffers writes. Users must call Close to guarantee all
// data has been forwarded to the underlying io.Writer. They may also call
// Flush zero or more times before calling Close.
func NewBufferedWriter(w io.Writer) *Writer {
return &Writer{
w: w,
ibuf: make([]byte, 0, maxBlockSize),
obuf: make([]byte, obufLen),
}
}
// Writer is an io.Writer that can write Snappy-compressed bytes.
type Writer struct {
w io.Writer
err error
// ibuf is a buffer for the incoming (uncompressed) bytes.
//
// Its use is optional. For backwards compatibility, Writers created by the
// NewWriter function have ibuf == nil, do not buffer incoming bytes, and
// therefore do not need to be Flush'ed or Close'd.
ibuf []byte
// obuf is a buffer for the outgoing (compressed) bytes.
obuf []byte
// wroteStreamHeader is whether we have written the stream header.
wroteStreamHeader bool
}
// Reset discards the writer's state and switches the Snappy writer to write to
// w. This permits reusing a Writer rather than allocating a new one.
func (w *Writer) Reset(writer io.Writer) {
w.w = writer
w.err = nil
if w.ibuf != nil {
w.ibuf = w.ibuf[:0]
}
w.wroteStreamHeader = false
}
// Write satisfies the io.Writer interface.
func (w *Writer) Write(p []byte) (nRet int, errRet error) {
if w.ibuf == nil {
// Do not buffer incoming bytes. This does not perform or compress well
// if the caller of Writer.Write writes many small slices. This
// behavior is therefore deprecated, but still supported for backwards
// compatibility with code that doesn't explicitly Flush or Close.
return w.write(p)
}
// The remainder of this method is based on bufio.Writer.Write from the
// standard library.
for len(p) > (cap(w.ibuf)-len(w.ibuf)) && w.err == nil {
var n int
if len(w.ibuf) == 0 {
// Large write, empty buffer.
// Write directly from p to avoid copy.
n, _ = w.write(p)
} else {
n = copy(w.ibuf[len(w.ibuf):cap(w.ibuf)], p)
w.ibuf = w.ibuf[:len(w.ibuf)+n]
w.Flush()
}
nRet += n
p = p[n:]
}
if w.err != nil {
return nRet, w.err
}
n := copy(w.ibuf[len(w.ibuf):cap(w.ibuf)], p)
w.ibuf = w.ibuf[:len(w.ibuf)+n]
nRet += n
return nRet, nil
}
func (w *Writer) write(p []byte) (nRet int, errRet error) {
if w.err != nil {
return 0, w.err
}
for len(p) > 0 {
obufStart := len(magicChunk)
if !w.wroteStreamHeader {
w.wroteStreamHeader = true
copy(w.obuf, magicChunk)
obufStart = 0
}
var uncompressed []byte
if len(p) > maxBlockSize {
uncompressed, p = p[:maxBlockSize], p[maxBlockSize:]
} else {
uncompressed, p = p, nil
}
checksum := crc(uncompressed)
// Compress the buffer, discarding the result if the improvement
// isn't at least 12.5%.
compressed := Encode(w.obuf[obufHeaderLen:], uncompressed)
chunkType := uint8(chunkTypeCompressedData)
chunkLen := 4 + len(compressed)
obufEnd := obufHeaderLen + len(compressed)
if len(compressed) >= len(uncompressed)-len(uncompressed)/8 {
chunkType = chunkTypeUncompressedData
chunkLen = 4 + len(uncompressed)
obufEnd = obufHeaderLen
}
// Fill in the per-chunk header that comes before the body.
w.obuf[len(magicChunk)+0] = chunkType
w.obuf[len(magicChunk)+1] = uint8(chunkLen >> 0)
w.obuf[len(magicChunk)+2] = uint8(chunkLen >> 8)
w.obuf[len(magicChunk)+3] = uint8(chunkLen >> 16)
w.obuf[len(magicChunk)+4] = uint8(checksum >> 0)
w.obuf[len(magicChunk)+5] = uint8(checksum >> 8)
w.obuf[len(magicChunk)+6] = uint8(checksum >> 16)
w.obuf[len(magicChunk)+7] = uint8(checksum >> 24)
if _, err := w.w.Write(w.obuf[obufStart:obufEnd]); err != nil {
w.err = err
return nRet, err
}
if chunkType == chunkTypeUncompressedData {
if _, err := w.w.Write(uncompressed); err != nil {
w.err = err
return nRet, err
}
}
nRet += len(uncompressed)
}
return nRet, nil
}
// Flush flushes the Writer to its underlying io.Writer.
func (w *Writer) Flush() error {
if w.err != nil {
return w.err
}
if len(w.ibuf) == 0 {
return nil
}
w.write(w.ibuf)
w.ibuf = w.ibuf[:0]
return w.err
}
// Close calls Flush and then closes the Writer.
func (w *Writer) Close() error {
w.Flush()
ret := w.err
if w.err == nil {
w.err = errClosed
}
return ret
}
+29
View File
@@ -0,0 +1,29 @@
// Copyright 2016 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !appengine
// +build gc
// +build !noasm
package snappy
// emitLiteral has the same semantics as in encode_other.go.
//
//go:noescape
func emitLiteral(dst, lit []byte) int
// emitCopy has the same semantics as in encode_other.go.
//
//go:noescape
func emitCopy(dst []byte, offset, length int) int
// extendMatch has the same semantics as in encode_other.go.
//
//go:noescape
func extendMatch(src []byte, i, j int) int
// encodeBlock has the same semantics as in encode_other.go.
//
//go:noescape
func encodeBlock(dst, src []byte) (d int)
+730
View File
@@ -0,0 +1,730 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !appengine
// +build gc
// +build !noasm
#include "textflag.h"
// The XXX lines assemble on Go 1.4, 1.5 and 1.7, but not 1.6, due to a
// Go toolchain regression. See https://github.com/golang/go/issues/15426 and
// https://github.com/golang/snappy/issues/29
//
// As a workaround, the package was built with a known good assembler, and
// those instructions were disassembled by "objdump -d" to yield the
// 4e 0f b7 7c 5c 78 movzwq 0x78(%rsp,%r11,2),%r15
// style comments, in AT&T asm syntax. Note that rsp here is a physical
// register, not Go/asm's SP pseudo-register (see https://golang.org/doc/asm).
// The instructions were then encoded as "BYTE $0x.." sequences, which assemble
// fine on Go 1.6.
// The asm code generally follows the pure Go code in encode_other.go, except
// where marked with a "!!!".
// ----------------------------------------------------------------------------
// func emitLiteral(dst, lit []byte) int
//
// All local variables fit into registers. The register allocation:
// - AX len(lit)
// - BX n
// - DX return value
// - DI &dst[i]
// - R10 &lit[0]
//
// The 24 bytes of stack space is to call runtime·memmove.
//
// The unusual register allocation of local variables, such as R10 for the
// source pointer, matches the allocation used at the call site in encodeBlock,
// which makes it easier to manually inline this function.
TEXT ·emitLiteral(SB), NOSPLIT, $24-56
MOVQ dst_base+0(FP), DI
MOVQ lit_base+24(FP), R10
MOVQ lit_len+32(FP), AX
MOVQ AX, DX
MOVL AX, BX
SUBL $1, BX
CMPL BX, $60
JLT oneByte
CMPL BX, $256
JLT twoBytes
threeBytes:
MOVB $0xf4, 0(DI)
MOVW BX, 1(DI)
ADDQ $3, DI
ADDQ $3, DX
JMP memmove
twoBytes:
MOVB $0xf0, 0(DI)
MOVB BX, 1(DI)
ADDQ $2, DI
ADDQ $2, DX
JMP memmove
oneByte:
SHLB $2, BX
MOVB BX, 0(DI)
ADDQ $1, DI
ADDQ $1, DX
memmove:
MOVQ DX, ret+48(FP)
// copy(dst[i:], lit)
//
// This means calling runtime·memmove(&dst[i], &lit[0], len(lit)), so we push
// DI, R10 and AX as arguments.
MOVQ DI, 0(SP)
MOVQ R10, 8(SP)
MOVQ AX, 16(SP)
CALL runtime·memmove(SB)
RET
// ----------------------------------------------------------------------------
// func emitCopy(dst []byte, offset, length int) int
//
// All local variables fit into registers. The register allocation:
// - AX length
// - SI &dst[0]
// - DI &dst[i]
// - R11 offset
//
// The unusual register allocation of local variables, such as R11 for the
// offset, matches the allocation used at the call site in encodeBlock, which
// makes it easier to manually inline this function.
TEXT ·emitCopy(SB), NOSPLIT, $0-48
MOVQ dst_base+0(FP), DI
MOVQ DI, SI
MOVQ offset+24(FP), R11
MOVQ length+32(FP), AX
loop0:
// for length >= 68 { etc }
CMPL AX, $68
JLT step1
// Emit a length 64 copy, encoded as 3 bytes.
MOVB $0xfe, 0(DI)
MOVW R11, 1(DI)
ADDQ $3, DI
SUBL $64, AX
JMP loop0
step1:
// if length > 64 { etc }
CMPL AX, $64
JLE step2
// Emit a length 60 copy, encoded as 3 bytes.
MOVB $0xee, 0(DI)
MOVW R11, 1(DI)
ADDQ $3, DI
SUBL $60, AX
step2:
// if length >= 12 || offset >= 2048 { goto step3 }
CMPL AX, $12
JGE step3
CMPL R11, $2048
JGE step3
// Emit the remaining copy, encoded as 2 bytes.
MOVB R11, 1(DI)
SHRL $8, R11
SHLB $5, R11
SUBB $4, AX
SHLB $2, AX
ORB AX, R11
ORB $1, R11
MOVB R11, 0(DI)
ADDQ $2, DI
// Return the number of bytes written.
SUBQ SI, DI
MOVQ DI, ret+40(FP)
RET
step3:
// Emit the remaining copy, encoded as 3 bytes.
SUBL $1, AX
SHLB $2, AX
ORB $2, AX
MOVB AX, 0(DI)
MOVW R11, 1(DI)
ADDQ $3, DI
// Return the number of bytes written.
SUBQ SI, DI
MOVQ DI, ret+40(FP)
RET
// ----------------------------------------------------------------------------
// func extendMatch(src []byte, i, j int) int
//
// All local variables fit into registers. The register allocation:
// - DX &src[0]
// - SI &src[j]
// - R13 &src[len(src) - 8]
// - R14 &src[len(src)]
// - R15 &src[i]
//
// The unusual register allocation of local variables, such as R15 for a source
// pointer, matches the allocation used at the call site in encodeBlock, which
// makes it easier to manually inline this function.
TEXT ·extendMatch(SB), NOSPLIT, $0-48
MOVQ src_base+0(FP), DX
MOVQ src_len+8(FP), R14
MOVQ i+24(FP), R15
MOVQ j+32(FP), SI
ADDQ DX, R14
ADDQ DX, R15
ADDQ DX, SI
MOVQ R14, R13
SUBQ $8, R13
cmp8:
// As long as we are 8 or more bytes before the end of src, we can load and
// compare 8 bytes at a time. If those 8 bytes are equal, repeat.
CMPQ SI, R13
JA cmp1
MOVQ (R15), AX
MOVQ (SI), BX
CMPQ AX, BX
JNE bsf
ADDQ $8, R15
ADDQ $8, SI
JMP cmp8
bsf:
// If those 8 bytes were not equal, XOR the two 8 byte values, and return
// the index of the first byte that differs. The BSF instruction finds the
// least significant 1 bit, the amd64 architecture is little-endian, and
// the shift by 3 converts a bit index to a byte index.
XORQ AX, BX
BSFQ BX, BX
SHRQ $3, BX
ADDQ BX, SI
// Convert from &src[ret] to ret.
SUBQ DX, SI
MOVQ SI, ret+40(FP)
RET
cmp1:
// In src's tail, compare 1 byte at a time.
CMPQ SI, R14
JAE extendMatchEnd
MOVB (R15), AX
MOVB (SI), BX
CMPB AX, BX
JNE extendMatchEnd
ADDQ $1, R15
ADDQ $1, SI
JMP cmp1
extendMatchEnd:
// Convert from &src[ret] to ret.
SUBQ DX, SI
MOVQ SI, ret+40(FP)
RET
// ----------------------------------------------------------------------------
// func encodeBlock(dst, src []byte) (d int)
//
// All local variables fit into registers, other than "var table". The register
// allocation:
// - AX . .
// - BX . .
// - CX 56 shift (note that amd64 shifts by non-immediates must use CX).
// - DX 64 &src[0], tableSize
// - SI 72 &src[s]
// - DI 80 &dst[d]
// - R9 88 sLimit
// - R10 . &src[nextEmit]
// - R11 96 prevHash, currHash, nextHash, offset
// - R12 104 &src[base], skip
// - R13 . &src[nextS], &src[len(src) - 8]
// - R14 . len(src), bytesBetweenHashLookups, &src[len(src)], x
// - R15 112 candidate
//
// The second column (56, 64, etc) is the stack offset to spill the registers
// when calling other functions. We could pack this slightly tighter, but it's
// simpler to have a dedicated spill map independent of the function called.
//
// "var table [maxTableSize]uint16" takes up 32768 bytes of stack space. An
// extra 56 bytes, to call other functions, and an extra 64 bytes, to spill
// local variables (registers) during calls gives 32768 + 56 + 64 = 32888.
TEXT ·encodeBlock(SB), 0, $32888-56
MOVQ dst_base+0(FP), DI
MOVQ src_base+24(FP), SI
MOVQ src_len+32(FP), R14
// shift, tableSize := uint32(32-8), 1<<8
MOVQ $24, CX
MOVQ $256, DX
calcShift:
// for ; tableSize < maxTableSize && tableSize < len(src); tableSize *= 2 {
// shift--
// }
CMPQ DX, $16384
JGE varTable
CMPQ DX, R14
JGE varTable
SUBQ $1, CX
SHLQ $1, DX
JMP calcShift
varTable:
// var table [maxTableSize]uint16
//
// In the asm code, unlike the Go code, we can zero-initialize only the
// first tableSize elements. Each uint16 element is 2 bytes and each MOVOU
// writes 16 bytes, so we can do only tableSize/8 writes instead of the
// 2048 writes that would zero-initialize all of table's 32768 bytes.
SHRQ $3, DX
LEAQ table-32768(SP), BX
PXOR X0, X0
memclr:
MOVOU X0, 0(BX)
ADDQ $16, BX
SUBQ $1, DX
JNZ memclr
// !!! DX = &src[0]
MOVQ SI, DX
// sLimit := len(src) - inputMargin
MOVQ R14, R9
SUBQ $15, R9
// !!! Pre-emptively spill CX, DX and R9 to the stack. Their values don't
// change for the rest of the function.
MOVQ CX, 56(SP)
MOVQ DX, 64(SP)
MOVQ R9, 88(SP)
// nextEmit := 0
MOVQ DX, R10
// s := 1
ADDQ $1, SI
// nextHash := hash(load32(src, s), shift)
MOVL 0(SI), R11
IMULL $0x1e35a7bd, R11
SHRL CX, R11
outer:
// for { etc }
// skip := 32
MOVQ $32, R12
// nextS := s
MOVQ SI, R13
// candidate := 0
MOVQ $0, R15
inner0:
// for { etc }
// s := nextS
MOVQ R13, SI
// bytesBetweenHashLookups := skip >> 5
MOVQ R12, R14
SHRQ $5, R14
// nextS = s + bytesBetweenHashLookups
ADDQ R14, R13
// skip += bytesBetweenHashLookups
ADDQ R14, R12
// if nextS > sLimit { goto emitRemainder }
MOVQ R13, AX
SUBQ DX, AX
CMPQ AX, R9
JA emitRemainder
// candidate = int(table[nextHash])
// XXX: MOVWQZX table-32768(SP)(R11*2), R15
// XXX: 4e 0f b7 7c 5c 78 movzwq 0x78(%rsp,%r11,2),%r15
BYTE $0x4e
BYTE $0x0f
BYTE $0xb7
BYTE $0x7c
BYTE $0x5c
BYTE $0x78
// table[nextHash] = uint16(s)
MOVQ SI, AX
SUBQ DX, AX
// XXX: MOVW AX, table-32768(SP)(R11*2)
// XXX: 66 42 89 44 5c 78 mov %ax,0x78(%rsp,%r11,2)
BYTE $0x66
BYTE $0x42
BYTE $0x89
BYTE $0x44
BYTE $0x5c
BYTE $0x78
// nextHash = hash(load32(src, nextS), shift)
MOVL 0(R13), R11
IMULL $0x1e35a7bd, R11
SHRL CX, R11
// if load32(src, s) != load32(src, candidate) { continue } break
MOVL 0(SI), AX
MOVL (DX)(R15*1), BX
CMPL AX, BX
JNE inner0
fourByteMatch:
// As per the encode_other.go code:
//
// A 4-byte match has been found. We'll later see etc.
// !!! Jump to a fast path for short (<= 16 byte) literals. See the comment
// on inputMargin in encode.go.
MOVQ SI, AX
SUBQ R10, AX
CMPQ AX, $16
JLE emitLiteralFastPath
// ----------------------------------------
// Begin inline of the emitLiteral call.
//
// d += emitLiteral(dst[d:], src[nextEmit:s])
MOVL AX, BX
SUBL $1, BX
CMPL BX, $60
JLT inlineEmitLiteralOneByte
CMPL BX, $256
JLT inlineEmitLiteralTwoBytes
inlineEmitLiteralThreeBytes:
MOVB $0xf4, 0(DI)
MOVW BX, 1(DI)
ADDQ $3, DI
JMP inlineEmitLiteralMemmove
inlineEmitLiteralTwoBytes:
MOVB $0xf0, 0(DI)
MOVB BX, 1(DI)
ADDQ $2, DI
JMP inlineEmitLiteralMemmove
inlineEmitLiteralOneByte:
SHLB $2, BX
MOVB BX, 0(DI)
ADDQ $1, DI
inlineEmitLiteralMemmove:
// Spill local variables (registers) onto the stack; call; unspill.
//
// copy(dst[i:], lit)
//
// This means calling runtime·memmove(&dst[i], &lit[0], len(lit)), so we push
// DI, R10 and AX as arguments.
MOVQ DI, 0(SP)
MOVQ R10, 8(SP)
MOVQ AX, 16(SP)
ADDQ AX, DI // Finish the "d +=" part of "d += emitLiteral(etc)".
MOVQ SI, 72(SP)
MOVQ DI, 80(SP)
MOVQ R15, 112(SP)
CALL runtime·memmove(SB)
MOVQ 56(SP), CX
MOVQ 64(SP), DX
MOVQ 72(SP), SI
MOVQ 80(SP), DI
MOVQ 88(SP), R9
MOVQ 112(SP), R15
JMP inner1
inlineEmitLiteralEnd:
// End inline of the emitLiteral call.
// ----------------------------------------
emitLiteralFastPath:
// !!! Emit the 1-byte encoding "uint8(len(lit)-1)<<2".
MOVB AX, BX
SUBB $1, BX
SHLB $2, BX
MOVB BX, (DI)
ADDQ $1, DI
// !!! Implement the copy from lit to dst as a 16-byte load and store.
// (Encode's documentation says that dst and src must not overlap.)
//
// This always copies 16 bytes, instead of only len(lit) bytes, but that's
// OK. Subsequent iterations will fix up the overrun.
//
// Note that on amd64, it is legal and cheap to issue unaligned 8-byte or
// 16-byte loads and stores. This technique probably wouldn't be as
// effective on architectures that are fussier about alignment.
MOVOU 0(R10), X0
MOVOU X0, 0(DI)
ADDQ AX, DI
inner1:
// for { etc }
// base := s
MOVQ SI, R12
// !!! offset := base - candidate
MOVQ R12, R11
SUBQ R15, R11
SUBQ DX, R11
// ----------------------------------------
// Begin inline of the extendMatch call.
//
// s = extendMatch(src, candidate+4, s+4)
// !!! R14 = &src[len(src)]
MOVQ src_len+32(FP), R14
ADDQ DX, R14
// !!! R13 = &src[len(src) - 8]
MOVQ R14, R13
SUBQ $8, R13
// !!! R15 = &src[candidate + 4]
ADDQ $4, R15
ADDQ DX, R15
// !!! s += 4
ADDQ $4, SI
inlineExtendMatchCmp8:
// As long as we are 8 or more bytes before the end of src, we can load and
// compare 8 bytes at a time. If those 8 bytes are equal, repeat.
CMPQ SI, R13
JA inlineExtendMatchCmp1
MOVQ (R15), AX
MOVQ (SI), BX
CMPQ AX, BX
JNE inlineExtendMatchBSF
ADDQ $8, R15
ADDQ $8, SI
JMP inlineExtendMatchCmp8
inlineExtendMatchBSF:
// If those 8 bytes were not equal, XOR the two 8 byte values, and return
// the index of the first byte that differs. The BSF instruction finds the
// least significant 1 bit, the amd64 architecture is little-endian, and
// the shift by 3 converts a bit index to a byte index.
XORQ AX, BX
BSFQ BX, BX
SHRQ $3, BX
ADDQ BX, SI
JMP inlineExtendMatchEnd
inlineExtendMatchCmp1:
// In src's tail, compare 1 byte at a time.
CMPQ SI, R14
JAE inlineExtendMatchEnd
MOVB (R15), AX
MOVB (SI), BX
CMPB AX, BX
JNE inlineExtendMatchEnd
ADDQ $1, R15
ADDQ $1, SI
JMP inlineExtendMatchCmp1
inlineExtendMatchEnd:
// End inline of the extendMatch call.
// ----------------------------------------
// ----------------------------------------
// Begin inline of the emitCopy call.
//
// d += emitCopy(dst[d:], base-candidate, s-base)
// !!! length := s - base
MOVQ SI, AX
SUBQ R12, AX
inlineEmitCopyLoop0:
// for length >= 68 { etc }
CMPL AX, $68
JLT inlineEmitCopyStep1
// Emit a length 64 copy, encoded as 3 bytes.
MOVB $0xfe, 0(DI)
MOVW R11, 1(DI)
ADDQ $3, DI
SUBL $64, AX
JMP inlineEmitCopyLoop0
inlineEmitCopyStep1:
// if length > 64 { etc }
CMPL AX, $64
JLE inlineEmitCopyStep2
// Emit a length 60 copy, encoded as 3 bytes.
MOVB $0xee, 0(DI)
MOVW R11, 1(DI)
ADDQ $3, DI
SUBL $60, AX
inlineEmitCopyStep2:
// if length >= 12 || offset >= 2048 { goto inlineEmitCopyStep3 }
CMPL AX, $12
JGE inlineEmitCopyStep3
CMPL R11, $2048
JGE inlineEmitCopyStep3
// Emit the remaining copy, encoded as 2 bytes.
MOVB R11, 1(DI)
SHRL $8, R11
SHLB $5, R11
SUBB $4, AX
SHLB $2, AX
ORB AX, R11
ORB $1, R11
MOVB R11, 0(DI)
ADDQ $2, DI
JMP inlineEmitCopyEnd
inlineEmitCopyStep3:
// Emit the remaining copy, encoded as 3 bytes.
SUBL $1, AX
SHLB $2, AX
ORB $2, AX
MOVB AX, 0(DI)
MOVW R11, 1(DI)
ADDQ $3, DI
inlineEmitCopyEnd:
// End inline of the emitCopy call.
// ----------------------------------------
// nextEmit = s
MOVQ SI, R10
// if s >= sLimit { goto emitRemainder }
MOVQ SI, AX
SUBQ DX, AX
CMPQ AX, R9
JAE emitRemainder
// As per the encode_other.go code:
//
// We could immediately etc.
// x := load64(src, s-1)
MOVQ -1(SI), R14
// prevHash := hash(uint32(x>>0), shift)
MOVL R14, R11
IMULL $0x1e35a7bd, R11
SHRL CX, R11
// table[prevHash] = uint16(s-1)
MOVQ SI, AX
SUBQ DX, AX
SUBQ $1, AX
// XXX: MOVW AX, table-32768(SP)(R11*2)
// XXX: 66 42 89 44 5c 78 mov %ax,0x78(%rsp,%r11,2)
BYTE $0x66
BYTE $0x42
BYTE $0x89
BYTE $0x44
BYTE $0x5c
BYTE $0x78
// currHash := hash(uint32(x>>8), shift)
SHRQ $8, R14
MOVL R14, R11
IMULL $0x1e35a7bd, R11
SHRL CX, R11
// candidate = int(table[currHash])
// XXX: MOVWQZX table-32768(SP)(R11*2), R15
// XXX: 4e 0f b7 7c 5c 78 movzwq 0x78(%rsp,%r11,2),%r15
BYTE $0x4e
BYTE $0x0f
BYTE $0xb7
BYTE $0x7c
BYTE $0x5c
BYTE $0x78
// table[currHash] = uint16(s)
ADDQ $1, AX
// XXX: MOVW AX, table-32768(SP)(R11*2)
// XXX: 66 42 89 44 5c 78 mov %ax,0x78(%rsp,%r11,2)
BYTE $0x66
BYTE $0x42
BYTE $0x89
BYTE $0x44
BYTE $0x5c
BYTE $0x78
// if uint32(x>>8) == load32(src, candidate) { continue }
MOVL (DX)(R15*1), BX
CMPL R14, BX
JEQ inner1
// nextHash = hash(uint32(x>>16), shift)
SHRQ $8, R14
MOVL R14, R11
IMULL $0x1e35a7bd, R11
SHRL CX, R11
// s++
ADDQ $1, SI
// break out of the inner1 for loop, i.e. continue the outer loop.
JMP outer
emitRemainder:
// if nextEmit < len(src) { etc }
MOVQ src_len+32(FP), AX
ADDQ DX, AX
CMPQ R10, AX
JEQ encodeBlockEnd
// d += emitLiteral(dst[d:], src[nextEmit:])
//
// Push args.
MOVQ DI, 0(SP)
MOVQ $0, 8(SP) // Unnecessary, as the callee ignores it, but conservative.
MOVQ $0, 16(SP) // Unnecessary, as the callee ignores it, but conservative.
MOVQ R10, 24(SP)
SUBQ R10, AX
MOVQ AX, 32(SP)
MOVQ AX, 40(SP) // Unnecessary, as the callee ignores it, but conservative.
// Spill local variables (registers) onto the stack; call; unspill.
MOVQ DI, 80(SP)
CALL ·emitLiteral(SB)
MOVQ 80(SP), DI
// Finish the "d +=" part of "d += emitLiteral(etc)".
ADDQ 48(SP), DI
encodeBlockEnd:
MOVQ dst_base+0(FP), AX
SUBQ AX, DI
MOVQ DI, d+48(FP)
RET
+238
View File
@@ -0,0 +1,238 @@
// Copyright 2016 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !amd64 appengine !gc noasm
package snappy
func load32(b []byte, i int) uint32 {
b = b[i : i+4 : len(b)] // Help the compiler eliminate bounds checks on the next line.
return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
}
func load64(b []byte, i int) uint64 {
b = b[i : i+8 : len(b)] // Help the compiler eliminate bounds checks on the next line.
return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 |
uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
}
// emitLiteral writes a literal chunk and returns the number of bytes written.
//
// It assumes that:
// dst is long enough to hold the encoded bytes
// 1 <= len(lit) && len(lit) <= 65536
func emitLiteral(dst, lit []byte) int {
i, n := 0, uint(len(lit)-1)
switch {
case n < 60:
dst[0] = uint8(n)<<2 | tagLiteral
i = 1
case n < 1<<8:
dst[0] = 60<<2 | tagLiteral
dst[1] = uint8(n)
i = 2
default:
dst[0] = 61<<2 | tagLiteral
dst[1] = uint8(n)
dst[2] = uint8(n >> 8)
i = 3
}
return i + copy(dst[i:], lit)
}
// emitCopy writes a copy chunk and returns the number of bytes written.
//
// It assumes that:
// dst is long enough to hold the encoded bytes
// 1 <= offset && offset <= 65535
// 4 <= length && length <= 65535
func emitCopy(dst []byte, offset, length int) int {
i := 0
// The maximum length for a single tagCopy1 or tagCopy2 op is 64 bytes. The
// threshold for this loop is a little higher (at 68 = 64 + 4), and the
// length emitted down below is is a little lower (at 60 = 64 - 4), because
// it's shorter to encode a length 67 copy as a length 60 tagCopy2 followed
// by a length 7 tagCopy1 (which encodes as 3+2 bytes) than to encode it as
// a length 64 tagCopy2 followed by a length 3 tagCopy2 (which encodes as
// 3+3 bytes). The magic 4 in the 64±4 is because the minimum length for a
// tagCopy1 op is 4 bytes, which is why a length 3 copy has to be an
// encodes-as-3-bytes tagCopy2 instead of an encodes-as-2-bytes tagCopy1.
for length >= 68 {
// Emit a length 64 copy, encoded as 3 bytes.
dst[i+0] = 63<<2 | tagCopy2
dst[i+1] = uint8(offset)
dst[i+2] = uint8(offset >> 8)
i += 3
length -= 64
}
if length > 64 {
// Emit a length 60 copy, encoded as 3 bytes.
dst[i+0] = 59<<2 | tagCopy2
dst[i+1] = uint8(offset)
dst[i+2] = uint8(offset >> 8)
i += 3
length -= 60
}
if length >= 12 || offset >= 2048 {
// Emit the remaining copy, encoded as 3 bytes.
dst[i+0] = uint8(length-1)<<2 | tagCopy2
dst[i+1] = uint8(offset)
dst[i+2] = uint8(offset >> 8)
return i + 3
}
// Emit the remaining copy, encoded as 2 bytes.
dst[i+0] = uint8(offset>>8)<<5 | uint8(length-4)<<2 | tagCopy1
dst[i+1] = uint8(offset)
return i + 2
}
// extendMatch returns the largest k such that k <= len(src) and that
// src[i:i+k-j] and src[j:k] have the same contents.
//
// It assumes that:
// 0 <= i && i < j && j <= len(src)
func extendMatch(src []byte, i, j int) int {
for ; j < len(src) && src[i] == src[j]; i, j = i+1, j+1 {
}
return j
}
func hash(u, shift uint32) uint32 {
return (u * 0x1e35a7bd) >> shift
}
// encodeBlock encodes a non-empty src to a guaranteed-large-enough dst. It
// assumes that the varint-encoded length of the decompressed bytes has already
// been written.
//
// It also assumes that:
// len(dst) >= MaxEncodedLen(len(src)) &&
// minNonLiteralBlockSize <= len(src) && len(src) <= maxBlockSize
func encodeBlock(dst, src []byte) (d int) {
// Initialize the hash table. Its size ranges from 1<<8 to 1<<14 inclusive.
// The table element type is uint16, as s < sLimit and sLimit < len(src)
// and len(src) <= maxBlockSize and maxBlockSize == 65536.
const (
maxTableSize = 1 << 14
// tableMask is redundant, but helps the compiler eliminate bounds
// checks.
tableMask = maxTableSize - 1
)
shift := uint32(32 - 8)
for tableSize := 1 << 8; tableSize < maxTableSize && tableSize < len(src); tableSize *= 2 {
shift--
}
// In Go, all array elements are zero-initialized, so there is no advantage
// to a smaller tableSize per se. However, it matches the C++ algorithm,
// and in the asm versions of this code, we can get away with zeroing only
// the first tableSize elements.
var table [maxTableSize]uint16
// sLimit is when to stop looking for offset/length copies. The inputMargin
// lets us use a fast path for emitLiteral in the main loop, while we are
// looking for copies.
sLimit := len(src) - inputMargin
// nextEmit is where in src the next emitLiteral should start from.
nextEmit := 0
// The encoded form must start with a literal, as there are no previous
// bytes to copy, so we start looking for hash matches at s == 1.
s := 1
nextHash := hash(load32(src, s), shift)
for {
// Copied from the C++ snappy implementation:
//
// Heuristic match skipping: If 32 bytes are scanned with no matches
// found, start looking only at every other byte. If 32 more bytes are
// scanned (or skipped), look at every third byte, etc.. When a match
// is found, immediately go back to looking at every byte. This is a
// small loss (~5% performance, ~0.1% density) for compressible data
// due to more bookkeeping, but for non-compressible data (such as
// JPEG) it's a huge win since the compressor quickly "realizes" the
// data is incompressible and doesn't bother looking for matches
// everywhere.
//
// The "skip" variable keeps track of how many bytes there are since
// the last match; dividing it by 32 (ie. right-shifting by five) gives
// the number of bytes to move ahead for each iteration.
skip := 32
nextS := s
candidate := 0
for {
s = nextS
bytesBetweenHashLookups := skip >> 5
nextS = s + bytesBetweenHashLookups
skip += bytesBetweenHashLookups
if nextS > sLimit {
goto emitRemainder
}
candidate = int(table[nextHash&tableMask])
table[nextHash&tableMask] = uint16(s)
nextHash = hash(load32(src, nextS), shift)
if load32(src, s) == load32(src, candidate) {
break
}
}
// A 4-byte match has been found. We'll later see if more than 4 bytes
// match. But, prior to the match, src[nextEmit:s] are unmatched. Emit
// them as literal bytes.
d += emitLiteral(dst[d:], src[nextEmit:s])
// Call emitCopy, and then see if another emitCopy could be our next
// move. Repeat until we find no match for the input immediately after
// what was consumed by the last emitCopy call.
//
// If we exit this loop normally then we need to call emitLiteral next,
// though we don't yet know how big the literal will be. We handle that
// by proceeding to the next iteration of the main loop. We also can
// exit this loop via goto if we get close to exhausting the input.
for {
// Invariant: we have a 4-byte match at s, and no need to emit any
// literal bytes prior to s.
base := s
// Extend the 4-byte match as long as possible.
//
// This is an inlined version of:
// s = extendMatch(src, candidate+4, s+4)
s += 4
for i := candidate + 4; s < len(src) && src[i] == src[s]; i, s = i+1, s+1 {
}
d += emitCopy(dst[d:], base-candidate, s-base)
nextEmit = s
if s >= sLimit {
goto emitRemainder
}
// We could immediately start working at s now, but to improve
// compression we first update the hash table at s-1 and at s. If
// another emitCopy is not our next move, also calculate nextHash
// at s+1. At least on GOARCH=amd64, these three hash calculations
// are faster as one load64 call (with some shifts) instead of
// three load32 calls.
x := load64(src, s-1)
prevHash := hash(uint32(x>>0), shift)
table[prevHash&tableMask] = uint16(s - 1)
currHash := hash(uint32(x>>8), shift)
candidate = int(table[currHash&tableMask])
table[currHash&tableMask] = uint16(s)
if uint32(x>>8) != load32(src, candidate) {
nextHash = hash(uint32(x>>16), shift)
s++
break
}
}
}
emitRemainder:
if nextEmit < len(src) {
d += emitLiteral(dst[d:], src[nextEmit:])
}
return d
}
+1
View File
@@ -0,0 +1 @@
module github.com/golang/snappy
+98
View File
@@ -0,0 +1,98 @@
// Copyright 2011 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package snappy implements the Snappy compression format. It aims for very
// high speeds and reasonable compression.
//
// There are actually two Snappy formats: block and stream. They are related,
// but different: trying to decompress block-compressed data as a Snappy stream
// will fail, and vice versa. The block format is the Decode and Encode
// functions and the stream format is the Reader and Writer types.
//
// The block format, the more common case, is used when the complete size (the
// number of bytes) of the original data is known upfront, at the time
// compression starts. The stream format, also known as the framing format, is
// for when that isn't always true.
//
// The canonical, C++ implementation is at https://github.com/google/snappy and
// it only implements the block format.
package snappy // import "github.com/golang/snappy"
import (
"hash/crc32"
)
/*
Each encoded block begins with the varint-encoded length of the decoded data,
followed by a sequence of chunks. Chunks begin and end on byte boundaries. The
first byte of each chunk is broken into its 2 least and 6 most significant bits
called l and m: l ranges in [0, 4) and m ranges in [0, 64). l is the chunk tag.
Zero means a literal tag. All other values mean a copy tag.
For literal tags:
- If m < 60, the next 1 + m bytes are literal bytes.
- Otherwise, let n be the little-endian unsigned integer denoted by the next
m - 59 bytes. The next 1 + n bytes after that are literal bytes.
For copy tags, length bytes are copied from offset bytes ago, in the style of
Lempel-Ziv compression algorithms. In particular:
- For l == 1, the offset ranges in [0, 1<<11) and the length in [4, 12).
The length is 4 + the low 3 bits of m. The high 3 bits of m form bits 8-10
of the offset. The next byte is bits 0-7 of the offset.
- For l == 2, the offset ranges in [0, 1<<16) and the length in [1, 65).
The length is 1 + m. The offset is the little-endian unsigned integer
denoted by the next 2 bytes.
- For l == 3, this tag is a legacy format that is no longer issued by most
encoders. Nonetheless, the offset ranges in [0, 1<<32) and the length in
[1, 65). The length is 1 + m. The offset is the little-endian unsigned
integer denoted by the next 4 bytes.
*/
const (
tagLiteral = 0x00
tagCopy1 = 0x01
tagCopy2 = 0x02
tagCopy4 = 0x03
)
const (
checksumSize = 4
chunkHeaderSize = 4
magicChunk = "\xff\x06\x00\x00" + magicBody
magicBody = "sNaPpY"
// maxBlockSize is the maximum size of the input to encodeBlock. It is not
// part of the wire format per se, but some parts of the encoder assume
// that an offset fits into a uint16.
//
// Also, for the framing format (Writer type instead of Encode function),
// https://github.com/google/snappy/blob/master/framing_format.txt says
// that "the uncompressed data in a chunk must be no longer than 65536
// bytes".
maxBlockSize = 65536
// maxEncodedLenOfMaxBlockSize equals MaxEncodedLen(maxBlockSize), but is
// hard coded to be a const instead of a variable, so that obufLen can also
// be a const. Their equivalence is confirmed by
// TestMaxEncodedLenOfMaxBlockSize.
maxEncodedLenOfMaxBlockSize = 76490
obufHeaderLen = len(magicChunk) + checksumSize + chunkHeaderSize
obufLen = obufHeaderLen + maxEncodedLenOfMaxBlockSize
)
const (
chunkTypeCompressedData = 0x00
chunkTypeUncompressedData = 0x01
chunkTypePadding = 0xfe
chunkTypeStreamIdentifier = 0xff
)
var crcTable = crc32.MakeTable(crc32.Castagnoli)
// crc implements the checksum specified in section 3 of
// https://github.com/google/snappy/blob/master/framing_format.txt
func crc(b []byte) uint32 {
c := crc32.Update(0, crcTable, b)
return uint32(c>>15|c<<17) + 0xa282ead8
}
View File
+11
View File
@@ -0,0 +1,11 @@
language: go
sudo: false
go:
- "1.7"
- "1.8"
- "1.9"
- "1.10"
- master
matrix:
allow_failures:
- go: master
+175
View File
@@ -0,0 +1,175 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
+71
View File
@@ -0,0 +1,71 @@
[![GoDoc](https://godoc.org/github.com/xdg/scram?status.svg)](https://godoc.org/github.com/xdg/scram)
[![Build Status](https://travis-ci.org/xdg/scram.svg?branch=master)](https://travis-ci.org/xdg/scram)
# scram  Go implementation of RFC-5802
## Description
Package scram provides client and server implementations of the Salted
Challenge Response Authentication Mechanism (SCRAM) described in
[RFC-5802](https://tools.ietf.org/html/rfc5802) and
[RFC-7677](https://tools.ietf.org/html/rfc7677).
It includes both client and server side support.
Channel binding and extensions are not (yet) supported.
## Examples
### Client side
package main
import "github.com/xdg/scram"
func main() {
// Get Client with username, password and (optional) authorization ID.
clientSHA1, err := scram.SHA1.NewClient("mulder", "trustno1", "")
if err != nil {
panic(err)
}
// Prepare the authentication conversation. Use the empty string as the
// initial server message argument to start the conversation.
conv := clientSHA1.NewConversation()
var serverMsg string
// Get the first message, send it and read the response.
firstMsg, err := conv.Step(serverMsg)
if err != nil {
panic(err)
}
serverMsg = sendClientMsg(firstMsg)
// Get the second message, send it, and read the response.
secondMsg, err := conv.Step(serverMsg)
if err != nil {
panic(err)
}
serverMsg = sendClientMsg(secondMsg)
// Validate the server's final message. We have no further message to
// send so ignore that return value.
_, err = conv.Step(serverMsg)
if err != nil {
panic(err)
}
return
}
func sendClientMsg(s string) string {
// A real implementation would send this to a server and read a reply.
return ""
}
## Copyright and License
Copyright 2018 by David A. Golden. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"). You may
obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+130
View File
@@ -0,0 +1,130 @@
// Copyright 2018 by David A. Golden. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package scram
import (
"sync"
"golang.org/x/crypto/pbkdf2"
)
// Client implements the client side of SCRAM authentication. It holds
// configuration values needed to initialize new client-side conversations for
// a specific username, password and authorization ID tuple. Client caches
// the computationally-expensive parts of a SCRAM conversation as described in
// RFC-5802. If repeated authentication conversations may be required for a
// user (e.g. disconnect/reconnect), the user's Client should be preserved.
//
// For security reasons, Clients have a default minimum PBKDF2 iteration count
// of 4096. If a server requests a smaller iteration count, an authentication
// conversation will error.
//
// A Client can also be used by a server application to construct the hashed
// authentication values to be stored for a new user. See StoredCredentials()
// for more.
type Client struct {
sync.RWMutex
username string
password string
authzID string
minIters int
nonceGen NonceGeneratorFcn
hashGen HashGeneratorFcn
cache map[KeyFactors]derivedKeys
}
func newClient(username, password, authzID string, fcn HashGeneratorFcn) *Client {
return &Client{
username: username,
password: password,
authzID: authzID,
minIters: 4096,
nonceGen: defaultNonceGenerator,
hashGen: fcn,
cache: make(map[KeyFactors]derivedKeys),
}
}
// WithMinIterations changes minimum required PBKDF2 iteration count.
func (c *Client) WithMinIterations(n int) *Client {
c.Lock()
defer c.Unlock()
c.minIters = n
return c
}
// WithNonceGenerator replaces the default nonce generator (base64 encoding of
// 24 bytes from crypto/rand) with a custom generator. This is provided for
// testing or for users with custom nonce requirements.
func (c *Client) WithNonceGenerator(ng NonceGeneratorFcn) *Client {
c.Lock()
defer c.Unlock()
c.nonceGen = ng
return c
}
// NewConversation constructs a client-side authentication conversation.
// Conversations cannot be reused, so this must be called for each new
// authentication attempt.
func (c *Client) NewConversation() *ClientConversation {
c.RLock()
defer c.RUnlock()
return &ClientConversation{
client: c,
nonceGen: c.nonceGen,
hashGen: c.hashGen,
minIters: c.minIters,
}
}
func (c *Client) getDerivedKeys(kf KeyFactors) derivedKeys {
dk, ok := c.getCache(kf)
if !ok {
dk = c.computeKeys(kf)
c.setCache(kf, dk)
}
return dk
}
// GetStoredCredentials takes a salt and iteration count structure and
// provides the values that must be stored by a server to authentication a
// user. These values are what the Server credential lookup function must
// return for a given username.
func (c *Client) GetStoredCredentials(kf KeyFactors) StoredCredentials {
dk := c.getDerivedKeys(kf)
return StoredCredentials{
KeyFactors: kf,
StoredKey: dk.StoredKey,
ServerKey: dk.ServerKey,
}
}
func (c *Client) computeKeys(kf KeyFactors) derivedKeys {
h := c.hashGen()
saltedPassword := pbkdf2.Key([]byte(c.password), []byte(kf.Salt), kf.Iters, h.Size(), c.hashGen)
clientKey := computeHMAC(c.hashGen, saltedPassword, []byte("Client Key"))
return derivedKeys{
ClientKey: clientKey,
StoredKey: computeHash(c.hashGen, clientKey),
ServerKey: computeHMAC(c.hashGen, saltedPassword, []byte("Server Key")),
}
}
func (c *Client) getCache(kf KeyFactors) (derivedKeys, bool) {
c.RLock()
defer c.RUnlock()
dk, ok := c.cache[kf]
return dk, ok
}
func (c *Client) setCache(kf KeyFactors, dk derivedKeys) {
c.Lock()
defer c.Unlock()
c.cache[kf] = dk
return
}
+149
View File
@@ -0,0 +1,149 @@
// Copyright 2018 by David A. Golden. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package scram
import (
"crypto/hmac"
"encoding/base64"
"errors"
"fmt"
"strings"
)
type clientState int
const (
clientStarting clientState = iota
clientFirst
clientFinal
clientDone
)
// ClientConversation implements the client-side of an authentication
// conversation with a server. A new conversation must be created for
// each authentication attempt.
type ClientConversation struct {
client *Client
nonceGen NonceGeneratorFcn
hashGen HashGeneratorFcn
minIters int
state clientState
valid bool
gs2 string
nonce string
c1b string
serveSig []byte
}
// Step takes a string provided from a server (or just an empty string for the
// very first conversation step) and attempts to move the authentication
// conversation forward. It returns a string to be sent to the server or an
// error if the server message is invalid. Calling Step after a conversation
// completes is also an error.
func (cc *ClientConversation) Step(challenge string) (response string, err error) {
switch cc.state {
case clientStarting:
cc.state = clientFirst
response, err = cc.firstMsg()
case clientFirst:
cc.state = clientFinal
response, err = cc.finalMsg(challenge)
case clientFinal:
cc.state = clientDone
response, err = cc.validateServer(challenge)
default:
response, err = "", errors.New("Conversation already completed")
}
return
}
// Done returns true if the conversation is completed or has errored.
func (cc *ClientConversation) Done() bool {
return cc.state == clientDone
}
// Valid returns true if the conversation successfully authenticated with the
// server, including counter-validation that the server actually has the
// user's stored credentials.
func (cc *ClientConversation) Valid() bool {
return cc.valid
}
func (cc *ClientConversation) firstMsg() (string, error) {
// Values are cached for use in final message parameters
cc.gs2 = cc.gs2Header()
cc.nonce = cc.client.nonceGen()
cc.c1b = fmt.Sprintf("n=%s,r=%s", encodeName(cc.client.username), cc.nonce)
return cc.gs2 + cc.c1b, nil
}
func (cc *ClientConversation) finalMsg(s1 string) (string, error) {
msg, err := parseServerFirst(s1)
if err != nil {
return "", err
}
// Check nonce prefix and update
if !strings.HasPrefix(msg.nonce, cc.nonce) {
return "", errors.New("server nonce did not extend client nonce")
}
cc.nonce = msg.nonce
// Check iteration count vs minimum
if msg.iters < cc.minIters {
return "", fmt.Errorf("server requested too few iterations (%d)", msg.iters)
}
// Create client-final-message-without-proof
c2wop := fmt.Sprintf(
"c=%s,r=%s",
base64.StdEncoding.EncodeToString([]byte(cc.gs2)),
cc.nonce,
)
// Create auth message
authMsg := cc.c1b + "," + s1 + "," + c2wop
// Get derived keys from client cache
dk := cc.client.getDerivedKeys(KeyFactors{Salt: string(msg.salt), Iters: msg.iters})
// Create proof as clientkey XOR clientsignature
clientSignature := computeHMAC(cc.hashGen, dk.StoredKey, []byte(authMsg))
clientProof := xorBytes(dk.ClientKey, clientSignature)
proof := base64.StdEncoding.EncodeToString(clientProof)
// Cache ServerSignature for later validation
cc.serveSig = computeHMAC(cc.hashGen, dk.ServerKey, []byte(authMsg))
return fmt.Sprintf("%s,p=%s", c2wop, proof), nil
}
func (cc *ClientConversation) validateServer(s2 string) (string, error) {
msg, err := parseServerFinal(s2)
if err != nil {
return "", err
}
if len(msg.err) > 0 {
return "", fmt.Errorf("server error: %s", msg.err)
}
if !hmac.Equal(msg.verifier, cc.serveSig) {
return "", errors.New("server validation failed")
}
cc.valid = true
return "", nil
}
func (cc *ClientConversation) gs2Header() string {
if cc.client.authzID == "" {
return "n,,"
}
return fmt.Sprintf("n,%s,", encodeName(cc.client.authzID))
}
+97
View File
@@ -0,0 +1,97 @@
// Copyright 2018 by David A. Golden. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package scram
import (
"crypto/hmac"
"crypto/rand"
"encoding/base64"
"strings"
)
// NonceGeneratorFcn defines a function that returns a string of high-quality
// random printable ASCII characters EXCLUDING the comma (',') character. The
// default nonce generator provides Base64 encoding of 24 bytes from
// crypto/rand.
type NonceGeneratorFcn func() string
// derivedKeys collects the three cryptographically derived values
// into one struct for caching.
type derivedKeys struct {
ClientKey []byte
StoredKey []byte
ServerKey []byte
}
// KeyFactors represent the two server-provided factors needed to compute
// client credentials for authentication. Salt is decoded bytes (i.e. not
// base64), but in string form so that KeyFactors can be used as a map key for
// cached credentials.
type KeyFactors struct {
Salt string
Iters int
}
// StoredCredentials are the values that a server must store for a given
// username to allow authentication. They include the salt and iteration
// count, plus the derived values to authenticate a client and for the server
// to authenticate itself back to the client.
//
// NOTE: these are specific to a given hash function. To allow a user to
// authenticate with either SCRAM-SHA-1 or SCRAM-SHA-256, two sets of
// StoredCredentials must be created and stored, one for each hash function.
type StoredCredentials struct {
KeyFactors
StoredKey []byte
ServerKey []byte
}
// CredentialLookup is a callback to provide StoredCredentials for a given
// username. This is used to configure Server objects.
//
// NOTE: these are specific to a given hash function. The callback provided
// to a Server with a given hash function must provide the corresponding
// StoredCredentials.
type CredentialLookup func(string) (StoredCredentials, error)
func defaultNonceGenerator() string {
raw := make([]byte, 24)
nonce := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
rand.Read(raw)
base64.StdEncoding.Encode(nonce, raw)
return string(nonce)
}
func encodeName(s string) string {
return strings.Replace(strings.Replace(s, "=", "=3D", -1), ",", "=2C", -1)
}
func decodeName(s string) (string, error) {
// TODO Check for = not followed by 2C or 3D
return strings.Replace(strings.Replace(s, "=2C", ",", -1), "=3D", "=", -1), nil
}
func computeHash(hg HashGeneratorFcn, b []byte) []byte {
h := hg()
h.Write(b)
return h.Sum(nil)
}
func computeHMAC(hg HashGeneratorFcn, key, data []byte) []byte {
mac := hmac.New(hg, key)
mac.Write(data)
return mac.Sum(nil)
}
func xorBytes(a, b []byte) []byte {
// TODO check a & b are same length, or just xor to smallest
xor := make([]byte, len(a))
for i := range a {
xor[i] = a[i] ^ b[i]
}
return xor
}
+24
View File
@@ -0,0 +1,24 @@
// Copyright 2018 by David A. Golden. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package scram provides client and server implementations of the Salted
// Challenge Response Authentication Mechanism (SCRAM) described in RFC-5802
// and RFC-7677.
//
// Usage
//
// The scram package provides two variables, `SHA1` and `SHA256`, that are
// used to construct Client or Server objects.
//
// clientSHA1, err := scram.SHA1.NewClient(username, password, authID)
// clientSHA256, err := scram.SHA256.NewClient(username, password, authID)
//
// serverSHA1, err := scram.SHA1.NewServer(credentialLookupFcn)
// serverSHA256, err := scram.SHA256.NewServer(credentialLookupFcn)
//
// These objects are used to construct ClientConversation or
// ServerConversation objects that are used to carry out authentication.
package scram
+205
View File
@@ -0,0 +1,205 @@
// Copyright 2018 by David A. Golden. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package scram
import (
"encoding/base64"
"errors"
"fmt"
"strconv"
"strings"
)
type c1Msg struct {
gs2Header string
authzID string
username string
nonce string
c1b string
}
type c2Msg struct {
cbind []byte
nonce string
proof []byte
c2wop string
}
type s1Msg struct {
nonce string
salt []byte
iters int
}
type s2Msg struct {
verifier []byte
err string
}
func parseField(s, k string) (string, error) {
t := strings.TrimPrefix(s, k+"=")
if t == s {
return "", fmt.Errorf("error parsing '%s' for field '%s'", s, k)
}
return t, nil
}
func parseGS2Flag(s string) (string, error) {
if s[0] == 'p' {
return "", fmt.Errorf("channel binding requested but not supported")
}
if s == "n" || s == "y" {
return s, nil
}
return "", fmt.Errorf("error parsing '%s' for gs2 flag", s)
}
func parseFieldBase64(s, k string) ([]byte, error) {
raw, err := parseField(s, k)
if err != nil {
return nil, err
}
dec, err := base64.StdEncoding.DecodeString(raw)
if err != nil {
return nil, err
}
return dec, nil
}
func parseFieldInt(s, k string) (int, error) {
raw, err := parseField(s, k)
if err != nil {
return 0, err
}
num, err := strconv.Atoi(raw)
if err != nil {
return 0, fmt.Errorf("error parsing field '%s': %v", k, err)
}
return num, nil
}
func parseClientFirst(c1 string) (msg c1Msg, err error) {
fields := strings.Split(c1, ",")
if len(fields) < 4 {
err = errors.New("not enough fields in first server message")
return
}
gs2flag, err := parseGS2Flag(fields[0])
if err != nil {
return
}
// 'a' field is optional
if len(fields[1]) > 0 {
msg.authzID, err = parseField(fields[1], "a")
if err != nil {
return
}
}
// Recombine and save the gs2 header
msg.gs2Header = gs2flag + "," + msg.authzID + ","
// Check for unsupported extensions field "m".
if strings.HasPrefix(fields[2], "m=") {
err = errors.New("SCRAM message extensions are not supported")
return
}
msg.username, err = parseField(fields[2], "n")
if err != nil {
return
}
msg.nonce, err = parseField(fields[3], "r")
if err != nil {
return
}
msg.c1b = strings.Join(fields[2:], ",")
return
}
func parseClientFinal(c2 string) (msg c2Msg, err error) {
fields := strings.Split(c2, ",")
if len(fields) < 3 {
err = errors.New("not enough fields in first server message")
return
}
msg.cbind, err = parseFieldBase64(fields[0], "c")
if err != nil {
return
}
msg.nonce, err = parseField(fields[1], "r")
if err != nil {
return
}
// Extension fields may come between nonce and proof, so we
// grab the *last* fields as proof.
msg.proof, err = parseFieldBase64(fields[len(fields)-1], "p")
if err != nil {
return
}
msg.c2wop = c2[:strings.LastIndex(c2, ",")]
return
}
func parseServerFirst(s1 string) (msg s1Msg, err error) {
// Check for unsupported extensions field "m".
if strings.HasPrefix(s1, "m=") {
err = errors.New("SCRAM message extensions are not supported")
return
}
fields := strings.Split(s1, ",")
if len(fields) < 3 {
err = errors.New("not enough fields in first server message")
return
}
msg.nonce, err = parseField(fields[0], "r")
if err != nil {
return
}
msg.salt, err = parseFieldBase64(fields[1], "s")
if err != nil {
return
}
msg.iters, err = parseFieldInt(fields[2], "i")
return
}
func parseServerFinal(s2 string) (msg s2Msg, err error) {
fields := strings.Split(s2, ",")
msg.verifier, err = parseFieldBase64(fields[0], "v")
if err == nil {
return
}
msg.err, err = parseField(fields[0], "e")
return
}
+66
View File
@@ -0,0 +1,66 @@
// Copyright 2018 by David A. Golden. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package scram
import (
"crypto/sha1"
"crypto/sha256"
"fmt"
"hash"
"github.com/xdg/stringprep"
)
// HashGeneratorFcn abstracts a factory function that returns a hash.Hash
// value to be used for SCRAM operations. Generally, one would use the
// provided package variables, `scram.SHA1` and `scram.SHA256`, for the most
// common forms of SCRAM.
type HashGeneratorFcn func() hash.Hash
// SHA1 is a function that returns a crypto/sha1 hasher and should be used to
// create Client objects configured for SHA-1 hashing.
var SHA1 HashGeneratorFcn = func() hash.Hash { return sha1.New() }
// SHA256 is a function that returns a crypto/sha256 hasher and should be used
// to create Client objects configured for SHA-256 hashing.
var SHA256 HashGeneratorFcn = func() hash.Hash { return sha256.New() }
// NewClient constructs a SCRAM client component based on a given hash.Hash
// factory receiver. This constructor will normalize the username, password
// and authzID via the SASLprep algorithm, as recommended by RFC-5802. If
// SASLprep fails, the method returns an error.
func (f HashGeneratorFcn) NewClient(username, password, authzID string) (*Client, error) {
var userprep, passprep, authprep string
var err error
if userprep, err = stringprep.SASLprep.Prepare(username); err != nil {
return nil, fmt.Errorf("Error SASLprepping username '%s': %v", username, err)
}
if passprep, err = stringprep.SASLprep.Prepare(password); err != nil {
return nil, fmt.Errorf("Error SASLprepping password '%s': %v", password, err)
}
if authprep, err = stringprep.SASLprep.Prepare(authzID); err != nil {
return nil, fmt.Errorf("Error SASLprepping authzID '%s': %v", authzID, err)
}
return newClient(userprep, passprep, authprep, f), nil
}
// NewClientUnprepped acts like NewClient, except none of the arguments will
// be normalized via SASLprep. This is not generally recommended, but is
// provided for users that may have custom normalization needs.
func (f HashGeneratorFcn) NewClientUnprepped(username, password, authzID string) (*Client, error) {
return newClient(username, password, authzID, f), nil
}
// NewServer constructs a SCRAM server component based on a given hash.Hash
// factory receiver. To be maximally generic, it uses dependency injection to
// handle credential lookup, which is the process of turning a username string
// into a struct with stored credentials for authentication.
func (f HashGeneratorFcn) NewServer(cl CredentialLookup) (*Server, error) {
return newServer(cl, f)
}
+50
View File
@@ -0,0 +1,50 @@
// Copyright 2018 by David A. Golden. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package scram
import "sync"
// Server implements the server side of SCRAM authentication. It holds
// configuration values needed to initialize new server-side conversations.
// Generally, this can be persistent within an application.
type Server struct {
sync.RWMutex
credentialCB CredentialLookup
nonceGen NonceGeneratorFcn
hashGen HashGeneratorFcn
}
func newServer(cl CredentialLookup, fcn HashGeneratorFcn) (*Server, error) {
return &Server{
credentialCB: cl,
nonceGen: defaultNonceGenerator,
hashGen: fcn,
}, nil
}
// WithNonceGenerator replaces the default nonce generator (base64 encoding of
// 24 bytes from crypto/rand) with a custom generator. This is provided for
// testing or for users with custom nonce requirements.
func (s *Server) WithNonceGenerator(ng NonceGeneratorFcn) *Server {
s.Lock()
defer s.Unlock()
s.nonceGen = ng
return s
}
// NewConversation constructs a server-side authentication conversation.
// Conversations cannot be reused, so this must be called for each new
// authentication attempt.
func (s *Server) NewConversation() *ServerConversation {
s.RLock()
defer s.RUnlock()
return &ServerConversation{
nonceGen: s.nonceGen,
hashGen: s.hashGen,
credentialCB: s.credentialCB,
}
}
+151
View File
@@ -0,0 +1,151 @@
// Copyright 2018 by David A. Golden. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package scram
import (
"crypto/hmac"
"encoding/base64"
"errors"
"fmt"
)
type serverState int
const (
serverFirst serverState = iota
serverFinal
serverDone
)
// ServerConversation implements the server-side of an authentication
// conversation with a client. A new conversation must be created for
// each authentication attempt.
type ServerConversation struct {
nonceGen NonceGeneratorFcn
hashGen HashGeneratorFcn
credentialCB CredentialLookup
state serverState
credential StoredCredentials
valid bool
gs2Header string
username string
authzID string
nonce string
c1b string
s1 string
}
// Step takes a string provided from a client and attempts to move the
// authentication conversation forward. It returns a string to be sent to the
// client or an error if the client message is invalid. Calling Step after a
// conversation completes is also an error.
func (sc *ServerConversation) Step(challenge string) (response string, err error) {
switch sc.state {
case serverFirst:
sc.state = serverFinal
response, err = sc.firstMsg(challenge)
case serverFinal:
sc.state = serverDone
response, err = sc.finalMsg(challenge)
default:
response, err = "", errors.New("Conversation already completed")
}
return
}
// Done returns true if the conversation is completed or has errored.
func (sc *ServerConversation) Done() bool {
return sc.state == serverDone
}
// Valid returns true if the conversation successfully authenticated the
// client.
func (sc *ServerConversation) Valid() bool {
return sc.valid
}
// Username returns the client-provided username. This is valid to call
// if the first conversation Step() is successful.
func (sc *ServerConversation) Username() string {
return sc.username
}
// AuthzID returns the (optional) client-provided authorization identity, if
// any. If one was not provided, it returns the empty string. This is valid
// to call if the first conversation Step() is successful.
func (sc *ServerConversation) AuthzID() string {
return sc.authzID
}
func (sc *ServerConversation) firstMsg(c1 string) (string, error) {
msg, err := parseClientFirst(c1)
if err != nil {
sc.state = serverDone
return "", err
}
sc.gs2Header = msg.gs2Header
sc.username = msg.username
sc.authzID = msg.authzID
sc.credential, err = sc.credentialCB(msg.username)
if err != nil {
sc.state = serverDone
return "e=unknown-user", err
}
sc.nonce = msg.nonce + sc.nonceGen()
sc.c1b = msg.c1b
sc.s1 = fmt.Sprintf("r=%s,s=%s,i=%d",
sc.nonce,
base64.StdEncoding.EncodeToString([]byte(sc.credential.Salt)),
sc.credential.Iters,
)
return sc.s1, nil
}
// For errors, returns server error message as well as non-nil error. Callers
// can choose whether to send server error or not.
func (sc *ServerConversation) finalMsg(c2 string) (string, error) {
msg, err := parseClientFinal(c2)
if err != nil {
return "", err
}
// Check channel binding matches what we expect; in this case, we expect
// just the gs2 header we received as we don't support channel binding
// with a data payload. If we add binding, we need to independently
// compute the header to match here.
if string(msg.cbind) != sc.gs2Header {
return "e=channel-bindings-dont-match", fmt.Errorf("channel binding received '%s' doesn't match expected '%s'", msg.cbind, sc.gs2Header)
}
// Check nonce received matches what we sent
if msg.nonce != sc.nonce {
return "e=other-error", errors.New("nonce received did not match nonce sent")
}
// Create auth message
authMsg := sc.c1b + "," + sc.s1 + "," + msg.c2wop
// Retrieve ClientKey from proof and verify it
clientSignature := computeHMAC(sc.hashGen, sc.credential.StoredKey, []byte(authMsg))
clientKey := xorBytes([]byte(msg.proof), clientSignature)
storedKey := computeHash(sc.hashGen, clientKey)
// Compare with constant-time function
if !hmac.Equal(storedKey, sc.credential.StoredKey) {
return "e=invalid-proof", errors.New("challenge proof invalid")
}
sc.valid = true
// Compute and return server verifier
serverSignature := computeHMAC(sc.hashGen, sc.credential.ServerKey, []byte(authMsg))
return "v=" + base64.StdEncoding.EncodeToString(serverSignature), nil
}
View File
+10
View File
@@ -0,0 +1,10 @@
language: go
sudo: false
go:
- 1.7
- 1.8
- 1.9
- master
matrix:
allow_failures:
- go: master
+175
View File
@@ -0,0 +1,175 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
+27
View File
@@ -0,0 +1,27 @@
[![GoDoc](https://godoc.org/github.com/xdg/stringprep?status.svg)](https://godoc.org/github.com/xdg/stringprep)
[![Build Status](https://travis-ci.org/xdg/stringprep.svg?branch=master)](https://travis-ci.org/xdg/stringprep)
# stringprep  Go implementation of RFC-3454 stringprep and RFC-4013 SASLprep
## Synopsis
```
import "github.com/xdg/stringprep"
prepped := stringprep.SASLprep.Prepare("TrustNô1")
```
## Description
This library provides an implementation of the stringprep algorithm
(RFC-3454) in Go, including all data tables.
A pre-built SASLprep (RFC-4013) profile is provided as well.
## Copyright and License
Copyright 2018 by David A. Golden. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"). You may
obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+73
View File
@@ -0,0 +1,73 @@
// Copyright 2018 by David A. Golden. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package stringprep
var errHasLCat = "BiDi string can't have runes from category L"
var errFirstRune = "BiDi string first rune must have category R or AL"
var errLastRune = "BiDi string last rune must have category R or AL"
// Check for prohibited characters from table C.8
func checkBiDiProhibitedRune(s string) error {
for _, r := range s {
if TableC8.Contains(r) {
return Error{Msg: errProhibited, Rune: r}
}
}
return nil
}
// Check for LCat characters from table D.2
func checkBiDiLCat(s string) error {
for _, r := range s {
if TableD2.Contains(r) {
return Error{Msg: errHasLCat, Rune: r}
}
}
return nil
}
// Check first and last characters are in table D.1; requires non-empty string
func checkBadFirstAndLastRandALCat(s string) error {
rs := []rune(s)
if !TableD1.Contains(rs[0]) {
return Error{Msg: errFirstRune, Rune: rs[0]}
}
n := len(rs) - 1
if !TableD1.Contains(rs[n]) {
return Error{Msg: errLastRune, Rune: rs[n]}
}
return nil
}
// Look for RandALCat characters from table D.1
func hasBiDiRandALCat(s string) bool {
for _, r := range s {
if TableD1.Contains(r) {
return true
}
}
return false
}
// Check that BiDi rules are satisfied ; let empty string pass this rule
func passesBiDiRules(s string) error {
if len(s) == 0 {
return nil
}
if err := checkBiDiProhibitedRune(s); err != nil {
return err
}
if hasBiDiRandALCat(s) {
if err := checkBiDiLCat(s); err != nil {
return err
}
if err := checkBadFirstAndLastRandALCat(s); err != nil {
return err
}
}
return nil
}
+10
View File
@@ -0,0 +1,10 @@
// Copyright 2018 by David A. Golden. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package stringprep provides data tables and algorithms for RFC-3454,
// including errata (as of 2018-02). It also provides a profile for
// SASLprep as defined in RFC-4013.
package stringprep
+14
View File
@@ -0,0 +1,14 @@
package stringprep
import "fmt"
// Error describes problems encountered during stringprep, including what rune
// was problematic.
type Error struct {
Msg string
Rune rune
}
func (e Error) Error() string {
return fmt.Sprintf("%s (rune: '\\u%04x')", e.Msg, e.Rune)
}
+21
View File
@@ -0,0 +1,21 @@
// Copyright 2018 by David A. Golden. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package stringprep
// Mapping represents a stringprep mapping, from a single rune to zero or more
// runes.
type Mapping map[rune][]rune
// Map maps a rune to a (possibly empty) rune slice via a stringprep Mapping.
// The ok return value is false if the rune was not found.
func (m Mapping) Map(r rune) (replacement []rune, ok bool) {
rs, ok := m[r]
if !ok {
return nil, false
}
return rs, true
}
+75
View File
@@ -0,0 +1,75 @@
package stringprep
import (
"golang.org/x/text/unicode/norm"
)
// Profile represents a stringprep profile.
type Profile struct {
Mappings []Mapping
Normalize bool
Prohibits []Set
CheckBiDi bool
}
var errProhibited = "prohibited character"
// Prepare transforms an input string to an output string following
// the rules defined in the profile as defined by RFC-3454.
func (p Profile) Prepare(s string) (string, error) {
// Optimistically, assume output will be same length as input
temp := make([]rune, 0, len(s))
// Apply maps
for _, r := range s {
rs, ok := p.applyMaps(r)
if ok {
temp = append(temp, rs...)
} else {
temp = append(temp, r)
}
}
// Normalize
var out string
if p.Normalize {
out = norm.NFKC.String(string(temp))
} else {
out = string(temp)
}
// Check prohibited
for _, r := range out {
if p.runeIsProhibited(r) {
return "", Error{Msg: errProhibited, Rune: r}
}
}
// Check BiDi allowed
if p.CheckBiDi {
if err := passesBiDiRules(out); err != nil {
return "", err
}
}
return out, nil
}
func (p Profile) applyMaps(r rune) ([]rune, bool) {
for _, m := range p.Mappings {
rs, ok := m.Map(r)
if ok {
return rs, true
}
}
return nil, false
}
func (p Profile) runeIsProhibited(r rune) bool {
for _, s := range p.Prohibits {
if s.Contains(r) {
return true
}
}
return false
}
+52
View File
@@ -0,0 +1,52 @@
package stringprep
var mapNonASCIISpaceToASCIISpace = Mapping{
0x00A0: []rune{0x0020},
0x1680: []rune{0x0020},
0x2000: []rune{0x0020},
0x2001: []rune{0x0020},
0x2002: []rune{0x0020},
0x2003: []rune{0x0020},
0x2004: []rune{0x0020},
0x2005: []rune{0x0020},
0x2006: []rune{0x0020},
0x2007: []rune{0x0020},
0x2008: []rune{0x0020},
0x2009: []rune{0x0020},
0x200A: []rune{0x0020},
0x200B: []rune{0x0020},
0x202F: []rune{0x0020},
0x205F: []rune{0x0020},
0x3000: []rune{0x0020},
}
// SASLprep is a pre-defined stringprep profile for user names and passwords
// as described in RFC-4013.
//
// Because the stringprep distinction between query and stored strings was
// intended for compatibility across profile versions, but SASLprep was never
// updated and is now deprecated, this profile only operates in stored
// strings mode, prohibiting unassigned code points.
var SASLprep Profile = saslprep
var saslprep = Profile{
Mappings: []Mapping{
TableB1,
mapNonASCIISpaceToASCIISpace,
},
Normalize: true,
Prohibits: []Set{
TableA1,
TableC1_2,
TableC2_1,
TableC2_2,
TableC3,
TableC4,
TableC5,
TableC6,
TableC7,
TableC8,
TableC9,
},
CheckBiDi: true,
}
+36
View File
@@ -0,0 +1,36 @@
// Copyright 2018 by David A. Golden. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package stringprep
import "sort"
// RuneRange represents a close-ended range of runes: [N,M]. For a range
// consisting of a single rune, N and M will be equal.
type RuneRange [2]rune
// Contains returns true if a rune is within the bounds of the RuneRange.
func (rr RuneRange) Contains(r rune) bool {
return rr[0] <= r && r <= rr[1]
}
func (rr RuneRange) isAbove(r rune) bool {
return r <= rr[0]
}
// Set represents a stringprep data table used to identify runes of a
// particular type.
type Set []RuneRange
// Contains returns true if a rune is within any of the RuneRanges in the
// Set.
func (s Set) Contains(r rune) bool {
i := sort.Search(len(s), func(i int) bool { return s[i].Contains(r) || s[i].isAbove(r) })
if i < len(s) && s[i].Contains(r) {
return true
}
return false
}
+3215
View File
File diff suppressed because it is too large Load Diff
+201
View File
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
+60
View File
@@ -0,0 +1,60 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on gopkg.in/mgo.v2/bson by Gustavo Niemeyer
// See THIRD-PARTY-NOTICES for original license terms.
// +build go1.9
package bson // import "go.mongodb.org/mongo-driver/bson"
import (
"go.mongodb.org/mongo-driver/bson/primitive"
)
// Zeroer allows custom struct types to implement a report of zero
// state. All struct types that don't implement Zeroer or where IsZero
// returns false are considered to be not zero.
type Zeroer interface {
IsZero() bool
}
// D represents a BSON Document. This type can be used to represent BSON in a concise and readable
// manner. It should generally be used when serializing to BSON. For deserializing, the Raw or
// Document types should be used.
//
// Example usage:
//
// bson.D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}}
//
// This type should be used in situations where order matters, such as MongoDB commands. If the
// order is not important, a map is more comfortable and concise.
type D = primitive.D
// E represents a BSON element for a D. It is usually used inside a D.
type E = primitive.E
// M is an unordered, concise representation of a BSON Document. It should generally be used to
// serialize BSON when the order of the elements of a BSON document do not matter. If the element
// order matters, use a D instead.
//
// Example usage:
//
// bson.M{"foo": "bar", "hello": "world", "pi": 3.14159}
//
// This type is handled in the encoders as a regular map[string]interface{}. The elements will be
// serialized in an undefined, random order, and the order will be different each time.
type M = primitive.M
// An A represents a BSON array. This type can be used to represent a BSON array in a concise and
// readable manner. It should generally be used when serializing to BSON. For deserializing, the
// RawArray or Array types should be used.
//
// Example usage:
//
// bson.A{"bar", "world", 3.14159, bson.D{{"qux", 12345}}}
//
type A = primitive.A
+91
View File
@@ -0,0 +1,91 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// +build !go1.9
package bson // import "go.mongodb.org/mongo-driver/bson"
import (
"math"
"strconv"
"strings"
)
// Zeroer allows custom struct types to implement a report of zero
// state. All struct types that don't implement Zeroer or where IsZero
// returns false are considered to be not zero.
type Zeroer interface {
IsZero() bool
}
// D represents a BSON Document. This type can be used to represent BSON in a concise and readable
// manner. It should generally be used when serializing to BSON. For deserializing, the Raw or
// Document types should be used.
//
// Example usage:
//
// primitive.D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}}
//
// This type should be used in situations where order matters, such as MongoDB commands. If the
// order is not important, a map is more comfortable and concise.
type D []E
// Map creates a map from the elements of the D.
func (d D) Map() M {
m := make(M, len(d))
for _, e := range d {
m[e.Key] = e.Value
}
return m
}
// E represents a BSON element for a D. It is usually used inside a D.
type E struct {
Key string
Value interface{}
}
// M is an unordered, concise representation of a BSON Document. It should generally be used to
// serialize BSON when the order of the elements of a BSON document do not matter. If the element
// order matters, use a D instead.
//
// Example usage:
//
// primitive.M{"foo": "bar", "hello": "world", "pi": 3.14159}
//
// This type is handled in the encoders as a regular map[string]interface{}. The elements will be
// serialized in an undefined, random order, and the order will be different each time.
type M map[string]interface{}
// An A represents a BSON array. This type can be used to represent a BSON array in a concise and
// readable manner. It should generally be used when serializing to BSON. For deserializing, the
// RawArray or Array types should be used.
//
// Example usage:
//
// primitive.A{"bar", "world", 3.14159, primitive.D{{"qux", 12345}}}
//
type A []interface{}
func formatDouble(f float64) string {
var s string
if math.IsInf(f, 1) {
s = "Infinity"
} else if math.IsInf(f, -1) {
s = "-Infinity"
} else if math.IsNaN(f) {
s = "NaN"
} else {
// Print exactly one decimalType place for integers; otherwise, print as many are necessary to
// perfectly represent it.
s = strconv.FormatFloat(f, 'G', -1, 64)
if !strings.ContainsRune(s, '.') {
s += ".0"
}
}
return s
}
+163
View File
@@ -0,0 +1,163 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec // import "go.mongodb.org/mongo-driver/bson/bsoncodec"
import (
"fmt"
"reflect"
"strings"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
// Marshaler is an interface implemented by types that can marshal themselves
// into a BSON document represented as bytes. The bytes returned must be a valid
// BSON document if the error is nil.
type Marshaler interface {
MarshalBSON() ([]byte, error)
}
// ValueMarshaler is an interface implemented by types that can marshal
// themselves into a BSON value as bytes. The type must be the valid type for
// the bytes returned. The bytes and byte type together must be valid if the
// error is nil.
type ValueMarshaler interface {
MarshalBSONValue() (bsontype.Type, []byte, error)
}
// Unmarshaler is an interface implemented by types that can unmarshal a BSON
// document representation of themselves. The BSON bytes can be assumed to be
// valid. UnmarshalBSON must copy the BSON bytes if it wishes to retain the data
// after returning.
type Unmarshaler interface {
UnmarshalBSON([]byte) error
}
// ValueUnmarshaler is an interface implemented by types that can unmarshal a
// BSON value representaiton of themselves. The BSON bytes and type can be
// assumed to be valid. UnmarshalBSONValue must copy the BSON value bytes if it
// wishes to retain the data after returning.
type ValueUnmarshaler interface {
UnmarshalBSONValue(bsontype.Type, []byte) error
}
// ValueEncoderError is an error returned from a ValueEncoder when the provided value can't be
// encoded by the ValueEncoder.
type ValueEncoderError struct {
Name string
Types []reflect.Type
Kinds []reflect.Kind
Received reflect.Value
}
func (vee ValueEncoderError) Error() string {
typeKinds := make([]string, 0, len(vee.Types)+len(vee.Kinds))
for _, t := range vee.Types {
typeKinds = append(typeKinds, t.String())
}
for _, k := range vee.Kinds {
if k == reflect.Map {
typeKinds = append(typeKinds, "map[string]*")
continue
}
typeKinds = append(typeKinds, k.String())
}
received := vee.Received.Kind().String()
if vee.Received.IsValid() {
received = vee.Received.Type().String()
}
return fmt.Sprintf("%s can only encode valid %s, but got %s", vee.Name, strings.Join(typeKinds, ", "), received)
}
// ValueDecoderError is an error returned from a ValueDecoder when the provided value can't be
// decoded by the ValueDecoder.
type ValueDecoderError struct {
Name string
Types []reflect.Type
Kinds []reflect.Kind
Received reflect.Value
}
func (vde ValueDecoderError) Error() string {
typeKinds := make([]string, 0, len(vde.Types)+len(vde.Kinds))
for _, t := range vde.Types {
typeKinds = append(typeKinds, t.String())
}
for _, k := range vde.Kinds {
if k == reflect.Map {
typeKinds = append(typeKinds, "map[string]*")
continue
}
typeKinds = append(typeKinds, k.String())
}
received := vde.Received.Kind().String()
if vde.Received.IsValid() {
received = vde.Received.Type().String()
}
return fmt.Sprintf("%s can only decode valid and settable %s, but got %s", vde.Name, strings.Join(typeKinds, ", "), received)
}
// EncodeContext is the contextual information required for a Codec to encode a
// value.
type EncodeContext struct {
*Registry
MinSize bool
}
// DecodeContext is the contextual information required for a Codec to decode a
// value.
type DecodeContext struct {
*Registry
Truncate bool
// Ancestor is the type of a containing document. This is mainly used to determine what type
// should be used when decoding an embedded document into an empty interface. For example, if
// Ancestor is a bson.M, BSON embedded document values being decoded into an empty interface
// will be decoded into a bson.M.
Ancestor reflect.Type
}
// ValueCodec is the interface that groups the methods to encode and decode
// values.
type ValueCodec interface {
ValueEncoder
ValueDecoder
}
// ValueEncoder is the interface implemented by types that can handle the encoding of a value.
type ValueEncoder interface {
EncodeValue(EncodeContext, bsonrw.ValueWriter, reflect.Value) error
}
// ValueEncoderFunc is an adapter function that allows a function with the correct signature to be
// used as a ValueEncoder.
type ValueEncoderFunc func(EncodeContext, bsonrw.ValueWriter, reflect.Value) error
// EncodeValue implements the ValueEncoder interface.
func (fn ValueEncoderFunc) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
return fn(ec, vw, val)
}
// ValueDecoder is the interface implemented by types that can handle the decoding of a value.
type ValueDecoder interface {
DecodeValue(DecodeContext, bsonrw.ValueReader, reflect.Value) error
}
// ValueDecoderFunc is an adapter function that allows a function with the correct signature to be
// used as a ValueDecoder.
type ValueDecoderFunc func(DecodeContext, bsonrw.ValueReader, reflect.Value) error
// DecodeValue implements the ValueDecoder interface.
func (fn ValueDecoderFunc) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
return fn(dc, vr, val)
}
// CodecZeroer is the interface implemented by Codecs that can also determine if
// a value of the type that would be encoded is zero.
type CodecZeroer interface {
IsTypeZero(interface{}) bool
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,648 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"encoding/json"
"errors"
"fmt"
"math"
"net/url"
"reflect"
"sync"
"time"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
var defaultValueEncoders DefaultValueEncoders
var bvwPool = bsonrw.NewBSONValueWriterPool()
var sliceWriterPool = sync.Pool{
New: func() interface{} {
sw := make(bsonrw.SliceWriter, 0, 0)
return &sw
},
}
func encodeElement(ec EncodeContext, dw bsonrw.DocumentWriter, e primitive.E) error {
vw, err := dw.WriteDocumentElement(e.Key)
if err != nil {
return err
}
if e.Value == nil {
return vw.WriteNull()
}
encoder, err := ec.LookupEncoder(reflect.TypeOf(e.Value))
if err != nil {
return err
}
err = encoder.EncodeValue(ec, vw, reflect.ValueOf(e.Value))
if err != nil {
return err
}
return nil
}
// DefaultValueEncoders is a namespace type for the default ValueEncoders used
// when creating a registry.
type DefaultValueEncoders struct{}
// RegisterDefaultEncoders will register the encoder methods attached to DefaultValueEncoders with
// the provided RegistryBuilder.
func (dve DefaultValueEncoders) RegisterDefaultEncoders(rb *RegistryBuilder) {
if rb == nil {
panic(errors.New("argument to RegisterDefaultEncoders must not be nil"))
}
rb.
RegisterEncoder(tByteSlice, ValueEncoderFunc(dve.ByteSliceEncodeValue)).
RegisterEncoder(tTime, ValueEncoderFunc(dve.TimeEncodeValue)).
RegisterEncoder(tEmpty, ValueEncoderFunc(dve.EmptyInterfaceEncodeValue)).
RegisterEncoder(tOID, ValueEncoderFunc(dve.ObjectIDEncodeValue)).
RegisterEncoder(tDecimal, ValueEncoderFunc(dve.Decimal128EncodeValue)).
RegisterEncoder(tJSONNumber, ValueEncoderFunc(dve.JSONNumberEncodeValue)).
RegisterEncoder(tURL, ValueEncoderFunc(dve.URLEncodeValue)).
RegisterEncoder(tValueMarshaler, ValueEncoderFunc(dve.ValueMarshalerEncodeValue)).
RegisterEncoder(tMarshaler, ValueEncoderFunc(dve.MarshalerEncodeValue)).
RegisterEncoder(tProxy, ValueEncoderFunc(dve.ProxyEncodeValue)).
RegisterEncoder(tJavaScript, ValueEncoderFunc(dve.JavaScriptEncodeValue)).
RegisterEncoder(tSymbol, ValueEncoderFunc(dve.SymbolEncodeValue)).
RegisterEncoder(tBinary, ValueEncoderFunc(dve.BinaryEncodeValue)).
RegisterEncoder(tUndefined, ValueEncoderFunc(dve.UndefinedEncodeValue)).
RegisterEncoder(tDateTime, ValueEncoderFunc(dve.DateTimeEncodeValue)).
RegisterEncoder(tNull, ValueEncoderFunc(dve.NullEncodeValue)).
RegisterEncoder(tRegex, ValueEncoderFunc(dve.RegexEncodeValue)).
RegisterEncoder(tDBPointer, ValueEncoderFunc(dve.DBPointerEncodeValue)).
RegisterEncoder(tTimestamp, ValueEncoderFunc(dve.TimestampEncodeValue)).
RegisterEncoder(tMinKey, ValueEncoderFunc(dve.MinKeyEncodeValue)).
RegisterEncoder(tMaxKey, ValueEncoderFunc(dve.MaxKeyEncodeValue)).
RegisterEncoder(tCoreDocument, ValueEncoderFunc(dve.CoreDocumentEncodeValue)).
RegisterEncoder(tCodeWithScope, ValueEncoderFunc(dve.CodeWithScopeEncodeValue)).
RegisterDefaultEncoder(reflect.Bool, ValueEncoderFunc(dve.BooleanEncodeValue)).
RegisterDefaultEncoder(reflect.Int, ValueEncoderFunc(dve.IntEncodeValue)).
RegisterDefaultEncoder(reflect.Int8, ValueEncoderFunc(dve.IntEncodeValue)).
RegisterDefaultEncoder(reflect.Int16, ValueEncoderFunc(dve.IntEncodeValue)).
RegisterDefaultEncoder(reflect.Int32, ValueEncoderFunc(dve.IntEncodeValue)).
RegisterDefaultEncoder(reflect.Int64, ValueEncoderFunc(dve.IntEncodeValue)).
RegisterDefaultEncoder(reflect.Uint, ValueEncoderFunc(dve.UintEncodeValue)).
RegisterDefaultEncoder(reflect.Uint8, ValueEncoderFunc(dve.UintEncodeValue)).
RegisterDefaultEncoder(reflect.Uint16, ValueEncoderFunc(dve.UintEncodeValue)).
RegisterDefaultEncoder(reflect.Uint32, ValueEncoderFunc(dve.UintEncodeValue)).
RegisterDefaultEncoder(reflect.Uint64, ValueEncoderFunc(dve.UintEncodeValue)).
RegisterDefaultEncoder(reflect.Float32, ValueEncoderFunc(dve.FloatEncodeValue)).
RegisterDefaultEncoder(reflect.Float64, ValueEncoderFunc(dve.FloatEncodeValue)).
RegisterDefaultEncoder(reflect.Array, ValueEncoderFunc(dve.ArrayEncodeValue)).
RegisterDefaultEncoder(reflect.Map, ValueEncoderFunc(dve.MapEncodeValue)).
RegisterDefaultEncoder(reflect.Slice, ValueEncoderFunc(dve.SliceEncodeValue)).
RegisterDefaultEncoder(reflect.String, ValueEncoderFunc(dve.StringEncodeValue)).
RegisterDefaultEncoder(reflect.Struct, &StructCodec{cache: make(map[reflect.Type]*structDescription), parser: DefaultStructTagParser}).
RegisterDefaultEncoder(reflect.Ptr, NewPointerCodec())
}
// BooleanEncodeValue is the ValueEncoderFunc for bool types.
func (dve DefaultValueEncoders) BooleanEncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Bool {
return ValueEncoderError{Name: "BooleanEncodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val}
}
return vw.WriteBoolean(val.Bool())
}
func fitsIn32Bits(i int64) bool {
return math.MinInt32 <= i && i <= math.MaxInt32
}
// IntEncodeValue is the ValueEncoderFunc for int types.
func (dve DefaultValueEncoders) IntEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
switch val.Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32:
return vw.WriteInt32(int32(val.Int()))
case reflect.Int:
i64 := val.Int()
if fitsIn32Bits(i64) {
return vw.WriteInt32(int32(i64))
}
return vw.WriteInt64(i64)
case reflect.Int64:
i64 := val.Int()
if ec.MinSize && fitsIn32Bits(i64) {
return vw.WriteInt32(int32(i64))
}
return vw.WriteInt64(i64)
}
return ValueEncoderError{
Name: "IntEncodeValue",
Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int},
Received: val,
}
}
// UintEncodeValue is the ValueEncoderFunc for uint types.
func (dve DefaultValueEncoders) UintEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
switch val.Kind() {
case reflect.Uint8, reflect.Uint16:
return vw.WriteInt32(int32(val.Uint()))
case reflect.Uint, reflect.Uint32, reflect.Uint64:
u64 := val.Uint()
if ec.MinSize && u64 <= math.MaxInt32 {
return vw.WriteInt32(int32(u64))
}
if u64 > math.MaxInt64 {
return fmt.Errorf("%d overflows int64", u64)
}
return vw.WriteInt64(int64(u64))
}
return ValueEncoderError{
Name: "UintEncodeValue",
Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint},
Received: val,
}
}
// FloatEncodeValue is the ValueEncoderFunc for float types.
func (dve DefaultValueEncoders) FloatEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
switch val.Kind() {
case reflect.Float32, reflect.Float64:
return vw.WriteDouble(val.Float())
}
return ValueEncoderError{Name: "FloatEncodeValue", Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, Received: val}
}
// StringEncodeValue is the ValueEncoderFunc for string types.
func (dve DefaultValueEncoders) StringEncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if val.Kind() != reflect.String {
return ValueEncoderError{
Name: "StringEncodeValue",
Kinds: []reflect.Kind{reflect.String},
Received: val,
}
}
return vw.WriteString(val.String())
}
// ObjectIDEncodeValue is the ValueEncoderFunc for primitive.ObjectID.
func (dve DefaultValueEncoders) ObjectIDEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tOID {
return ValueEncoderError{Name: "ObjectIDEncodeValue", Types: []reflect.Type{tOID}, Received: val}
}
return vw.WriteObjectID(val.Interface().(primitive.ObjectID))
}
// Decimal128EncodeValue is the ValueEncoderFunc for primitive.Decimal128.
func (dve DefaultValueEncoders) Decimal128EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tDecimal {
return ValueEncoderError{Name: "Decimal128EncodeValue", Types: []reflect.Type{tDecimal}, Received: val}
}
return vw.WriteDecimal128(val.Interface().(primitive.Decimal128))
}
// JSONNumberEncodeValue is the ValueEncoderFunc for json.Number.
func (dve DefaultValueEncoders) JSONNumberEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tJSONNumber {
return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: val}
}
jsnum := val.Interface().(json.Number)
// Attempt int first, then float64
if i64, err := jsnum.Int64(); err == nil {
return dve.IntEncodeValue(ec, vw, reflect.ValueOf(i64))
}
f64, err := jsnum.Float64()
if err != nil {
return err
}
return dve.FloatEncodeValue(ec, vw, reflect.ValueOf(f64))
}
// URLEncodeValue is the ValueEncoderFunc for url.URL.
func (dve DefaultValueEncoders) URLEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tURL {
return ValueEncoderError{Name: "URLEncodeValue", Types: []reflect.Type{tURL}, Received: val}
}
u := val.Interface().(url.URL)
return vw.WriteString(u.String())
}
// TimeEncodeValue is the ValueEncoderFunc for time.TIme.
func (dve DefaultValueEncoders) TimeEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tTime {
return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: val}
}
tt := val.Interface().(time.Time)
return vw.WriteDateTime(tt.Unix()*1000 + int64(tt.Nanosecond()/1e6))
}
// ByteSliceEncodeValue is the ValueEncoderFunc for []byte.
func (dve DefaultValueEncoders) ByteSliceEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tByteSlice {
return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val}
}
if val.IsNil() {
return vw.WriteNull()
}
return vw.WriteBinary(val.Interface().([]byte))
}
// MapEncodeValue is the ValueEncoderFunc for map[string]* types.
func (dve DefaultValueEncoders) MapEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Map || val.Type().Key().Kind() != reflect.String {
return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
}
if val.IsNil() {
// If we have a nill map but we can't WriteNull, that means we're probably trying to encode
// to a TopLevel document. We can't currently tell if this is what actually happened, but if
// there's a deeper underlying problem, the error will also be returned from WriteDocument,
// so just continue. The operations on a map reflection value are valid, so we can call
// MapKeys within mapEncodeValue without a problem.
err := vw.WriteNull()
if err == nil {
return nil
}
}
dw, err := vw.WriteDocument()
if err != nil {
return err
}
return dve.mapEncodeValue(ec, dw, val, nil)
}
// mapEncodeValue handles encoding of the values of a map. The collisionFn returns
// true if the provided key exists, this is mainly used for inline maps in the
// struct codec.
func (dve DefaultValueEncoders) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, val reflect.Value, collisionFn func(string) bool) error {
encoder, err := ec.LookupEncoder(val.Type().Elem())
if err != nil {
return err
}
keys := val.MapKeys()
for _, key := range keys {
if collisionFn != nil && collisionFn(key.String()) {
return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key)
}
vw, err := dw.WriteDocumentElement(key.String())
if err != nil {
return err
}
if enc, ok := encoder.(ValueEncoder); ok {
err = enc.EncodeValue(ec, vw, val.MapIndex(key))
if err != nil {
return err
}
continue
}
err = encoder.EncodeValue(ec, vw, val.MapIndex(key))
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}
// ArrayEncodeValue is the ValueEncoderFunc for array types.
func (dve DefaultValueEncoders) ArrayEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Array {
return ValueEncoderError{Name: "ArrayEncodeValue", Kinds: []reflect.Kind{reflect.Array}, Received: val}
}
// If we have a []primitive.E we want to treat it as a document instead of as an array.
if val.Type().Elem() == tE {
dw, err := vw.WriteDocument()
if err != nil {
return err
}
for idx := 0; idx < val.Len(); idx++ {
e := val.Index(idx).Interface().(primitive.E)
err = encodeElement(ec, dw, e)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}
aw, err := vw.WriteArray()
if err != nil {
return err
}
encoder, err := ec.LookupEncoder(val.Type().Elem())
if err != nil {
return err
}
for idx := 0; idx < val.Len(); idx++ {
vw, err := aw.WriteArrayElement()
if err != nil {
return err
}
err = encoder.EncodeValue(ec, vw, val.Index(idx))
if err != nil {
return err
}
}
return aw.WriteArrayEnd()
}
// SliceEncodeValue is the ValueEncoderFunc for slice types.
func (dve DefaultValueEncoders) SliceEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Slice {
return ValueEncoderError{Name: "SliceEncodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val}
}
if val.IsNil() {
return vw.WriteNull()
}
// If we have a []primitive.E we want to treat it as a document instead of as an array.
if val.Type().ConvertibleTo(tD) {
d := val.Convert(tD).Interface().(primitive.D)
dw, err := vw.WriteDocument()
if err != nil {
return err
}
for _, e := range d {
err = encodeElement(ec, dw, e)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}
aw, err := vw.WriteArray()
if err != nil {
return err
}
encoder, err := ec.LookupEncoder(val.Type().Elem())
if err != nil {
return err
}
for idx := 0; idx < val.Len(); idx++ {
vw, err := aw.WriteArrayElement()
if err != nil {
return err
}
err = encoder.EncodeValue(ec, vw, val.Index(idx))
if err != nil {
return err
}
}
return aw.WriteArrayEnd()
}
// EmptyInterfaceEncodeValue is the ValueEncoderFunc for interface{}.
func (dve DefaultValueEncoders) EmptyInterfaceEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tEmpty {
return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val}
}
if val.IsNil() {
return vw.WriteNull()
}
encoder, err := ec.LookupEncoder(val.Elem().Type())
if err != nil {
return err
}
return encoder.EncodeValue(ec, vw, val.Elem())
}
// ValueMarshalerEncodeValue is the ValueEncoderFunc for ValueMarshaler implementations.
func (dve DefaultValueEncoders) ValueMarshalerEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || !val.Type().Implements(tValueMarshaler) {
return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val}
}
fn := val.Convert(tValueMarshaler).MethodByName("MarshalBSONValue")
returns := fn.Call(nil)
if !returns[2].IsNil() {
return returns[2].Interface().(error)
}
t, data := returns[0].Interface().(bsontype.Type), returns[1].Interface().([]byte)
return bsonrw.Copier{}.CopyValueFromBytes(vw, t, data)
}
// MarshalerEncodeValue is the ValueEncoderFunc for Marshaler implementations.
func (dve DefaultValueEncoders) MarshalerEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || !val.Type().Implements(tMarshaler) {
return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val}
}
fn := val.Convert(tMarshaler).MethodByName("MarshalBSON")
returns := fn.Call(nil)
if !returns[1].IsNil() {
return returns[1].Interface().(error)
}
data := returns[0].Interface().([]byte)
return bsonrw.Copier{}.CopyValueFromBytes(vw, bsontype.EmbeddedDocument, data)
}
// ProxyEncodeValue is the ValueEncoderFunc for Proxy implementations.
func (dve DefaultValueEncoders) ProxyEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || !val.Type().Implements(tProxy) {
return ValueEncoderError{Name: "ProxyEncodeValue", Types: []reflect.Type{tProxy}, Received: val}
}
fn := val.Convert(tProxy).MethodByName("ProxyBSON")
returns := fn.Call(nil)
if !returns[1].IsNil() {
return returns[1].Interface().(error)
}
data := returns[0]
var encoder ValueEncoder
var err error
if data.Elem().IsValid() {
encoder, err = ec.LookupEncoder(data.Elem().Type())
} else {
encoder, err = ec.LookupEncoder(nil)
}
if err != nil {
return err
}
return encoder.EncodeValue(ec, vw, data.Elem())
}
// JavaScriptEncodeValue is the ValueEncoderFunc for the primitive.JavaScript type.
func (DefaultValueEncoders) JavaScriptEncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tJavaScript {
return ValueEncoderError{Name: "JavaScriptEncodeValue", Types: []reflect.Type{tJavaScript}, Received: val}
}
return vw.WriteJavascript(val.String())
}
// SymbolEncodeValue is the ValueEncoderFunc for the primitive.Symbol type.
func (DefaultValueEncoders) SymbolEncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tSymbol {
return ValueEncoderError{Name: "SymbolEncodeValue", Types: []reflect.Type{tSymbol}, Received: val}
}
return vw.WriteSymbol(val.String())
}
// BinaryEncodeValue is the ValueEncoderFunc for Binary.
func (DefaultValueEncoders) BinaryEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tBinary {
return ValueEncoderError{Name: "BinaryEncodeValue", Types: []reflect.Type{tBinary}, Received: val}
}
b := val.Interface().(primitive.Binary)
return vw.WriteBinaryWithSubtype(b.Data, b.Subtype)
}
// UndefinedEncodeValue is the ValueEncoderFunc for Undefined.
func (DefaultValueEncoders) UndefinedEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tUndefined {
return ValueEncoderError{Name: "UndefinedEncodeValue", Types: []reflect.Type{tUndefined}, Received: val}
}
return vw.WriteUndefined()
}
// DateTimeEncodeValue is the ValueEncoderFunc for DateTime.
func (DefaultValueEncoders) DateTimeEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tDateTime {
return ValueEncoderError{Name: "DateTimeEncodeValue", Types: []reflect.Type{tDateTime}, Received: val}
}
return vw.WriteDateTime(val.Int())
}
// NullEncodeValue is the ValueEncoderFunc for Null.
func (DefaultValueEncoders) NullEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tNull {
return ValueEncoderError{Name: "NullEncodeValue", Types: []reflect.Type{tNull}, Received: val}
}
return vw.WriteNull()
}
// RegexEncodeValue is the ValueEncoderFunc for Regex.
func (DefaultValueEncoders) RegexEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tRegex {
return ValueEncoderError{Name: "RegexEncodeValue", Types: []reflect.Type{tRegex}, Received: val}
}
regex := val.Interface().(primitive.Regex)
return vw.WriteRegex(regex.Pattern, regex.Options)
}
// DBPointerEncodeValue is the ValueEncoderFunc for DBPointer.
func (DefaultValueEncoders) DBPointerEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tDBPointer {
return ValueEncoderError{Name: "DBPointerEncodeValue", Types: []reflect.Type{tDBPointer}, Received: val}
}
dbp := val.Interface().(primitive.DBPointer)
return vw.WriteDBPointer(dbp.DB, dbp.Pointer)
}
// TimestampEncodeValue is the ValueEncoderFunc for Timestamp.
func (DefaultValueEncoders) TimestampEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tTimestamp {
return ValueEncoderError{Name: "TimestampEncodeValue", Types: []reflect.Type{tTimestamp}, Received: val}
}
ts := val.Interface().(primitive.Timestamp)
return vw.WriteTimestamp(ts.T, ts.I)
}
// MinKeyEncodeValue is the ValueEncoderFunc for MinKey.
func (DefaultValueEncoders) MinKeyEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tMinKey {
return ValueEncoderError{Name: "MinKeyEncodeValue", Types: []reflect.Type{tMinKey}, Received: val}
}
return vw.WriteMinKey()
}
// MaxKeyEncodeValue is the ValueEncoderFunc for MaxKey.
func (DefaultValueEncoders) MaxKeyEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tMaxKey {
return ValueEncoderError{Name: "MaxKeyEncodeValue", Types: []reflect.Type{tMaxKey}, Received: val}
}
return vw.WriteMaxKey()
}
// CoreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document.
func (DefaultValueEncoders) CoreDocumentEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tCoreDocument {
return ValueEncoderError{Name: "CoreDocumentEncodeValue", Types: []reflect.Type{tCoreDocument}, Received: val}
}
cdoc := val.Interface().(bsoncore.Document)
return bsonrw.Copier{}.CopyDocumentFromBytes(vw, cdoc)
}
// CodeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope.
func (dve DefaultValueEncoders) CodeWithScopeEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tCodeWithScope {
return ValueEncoderError{Name: "CodeWithScopeEncodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val}
}
cws := val.Interface().(primitive.CodeWithScope)
dw, err := vw.WriteCodeWithScope(string(cws.Code))
if err != nil {
return err
}
sw := sliceWriterPool.Get().(*bsonrw.SliceWriter)
defer sliceWriterPool.Put(sw)
*sw = (*sw)[:0]
scopeVW := bvwPool.Get(sw)
defer bvwPool.Put(scopeVW)
encoder, err := ec.LookupEncoder(reflect.TypeOf(cws.Scope))
if err != nil {
return err
}
err = encoder.EncodeValue(ec, scopeVW, reflect.ValueOf(cws.Scope))
if err != nil {
return err
}
err = bsonrw.Copier{}.CopyBytesToDocumentWriter(dw, *sw)
if err != nil {
return err
}
return dw.WriteDocumentEnd()
}
+61
View File
@@ -0,0 +1,61 @@
// Package bsoncodec provides a system for encoding values to BSON representations and decoding
// values from BSON representations. This package considers both binary BSON and ExtendedJSON as
// BSON representations. The types in this package enable a flexible system for handling this
// encoding and decoding.
//
// The codec system is composed of two parts:
//
// 1) ValueEncoders and ValueDecoders that handle encoding and decoding Go values to and from BSON
// representations.
//
// 2) A Registry that holds these ValueEncoders and ValueDecoders and provides methods for
// retrieving them.
//
// ValueEncoders and ValueDecoders
//
// The ValueEncoder interface is implemented by types that can encode a provided Go type to BSON.
// The value to encode is provided as a reflect.Value and a bsonrw.ValueWriter is used within the
// EncodeValue method to actually create the BSON representation. For convenience, ValueEncoderFunc
// is provided to allow use of a function with the correct signature as a ValueEncoder. An
// EncodeContext instance is provided to allow implementations to lookup further ValueEncoders and
// to provide configuration information.
//
// The ValueDecoder interface is the inverse of the ValueEncoder. Implementations should ensure that
// the value they receive is settable. Similar to ValueEncoderFunc, ValueDecoderFunc is provided to
// allow the use of a function with the correct signature as a ValueDecoder. A DecodeContext
// instance is provided and serves similar functionality to the EncodeContext.
//
// Registry and RegistryBuilder
//
// A Registry is an immutable store for ValueEncoders, ValueDecoders, and a type map. For looking up
// ValueEncoders and Decoders the Registry first attempts to find a ValueEncoder or ValueDecoder for
// the type provided; if one cannot be found it then checks to see if a registered ValueEncoder or
// ValueDecoder exists for an interface the type implements. Finally, the reflect.Kind of the type
// is used to lookup a default ValueEncoder or ValueDecoder for that kind. If no ValueEncoder or
// ValueDecoder can be found, an error is returned.
//
// The Registry also holds a type map. This allows users to retrieve the Go type that should be used
// when decoding a BSON value into an empty interface. This is primarily only used for the empty
// interface ValueDecoder.
//
// A RegistryBuilder is used to construct a Registry. The Register methods are used to associate
// either a reflect.Type or a reflect.Kind with a ValueEncoder or ValueDecoder. A RegistryBuilder
// returned from NewRegistryBuilder contains no registered ValueEncoders nor ValueDecoders and
// contains an empty type map.
//
// The RegisterTypeMapEntry method handles associating a BSON type with a Go type. For example, if
// you want to decode BSON int64 and int32 values into Go int instances, you would do the following:
//
// var regbuilder *RegistryBuilder = ... intType := reflect.TypeOf(int(0))
// regbuilder.RegisterTypeMapEntry(bsontype.Int64, intType).RegisterTypeMapEntry(bsontype.Int32,
// intType)
//
// DefaultValueEncoders and DefaultValueDecoders
//
// The DefaultValueEncoders and DefaultValueDecoders types provide a full set of ValueEncoders and
// ValueDecoders for handling a wide range of Go types, including all of the types within the
// primitive package. To make registering these codecs easier, a helper method on each type is
// provided. For the DefaultValueEncoders type the method is called RegisterDefaultEncoders and for
// the DefaultValueDecoders type the method is called RegisterDefaultDecoders, this method also
// handles registering type map entries for each BSON type.
package bsoncodec
+65
View File
@@ -0,0 +1,65 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import "fmt"
type mode int
const (
_ mode = iota
mTopLevel
mDocument
mArray
mValue
mElement
mCodeWithScope
mSpacer
)
func (m mode) String() string {
var str string
switch m {
case mTopLevel:
str = "TopLevel"
case mDocument:
str = "DocumentMode"
case mArray:
str = "ArrayMode"
case mValue:
str = "ValueMode"
case mElement:
str = "ElementMode"
case mCodeWithScope:
str = "CodeWithScopeMode"
case mSpacer:
str = "CodeWithScopeSpacerFrame"
default:
str = "UnknownMode"
}
return str
}
// TransitionError is an error returned when an invalid progressing a
// ValueReader or ValueWriter state machine occurs.
type TransitionError struct {
parent mode
current mode
destination mode
}
func (te TransitionError) Error() string {
if te.destination == mode(0) {
return fmt.Sprintf("invalid state transition: cannot read/write value while in %s", te.current)
}
if te.parent == mode(0) {
return fmt.Sprintf("invalid state transition: %s -> %s", te.current, te.destination)
}
return fmt.Sprintf("invalid state transition: %s -> %s; parent %s", te.current, te.destination, te.parent)
}
+110
View File
@@ -0,0 +1,110 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"reflect"
"sync"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
var defaultPointerCodec = &PointerCodec{
ecache: make(map[reflect.Type]ValueEncoder),
dcache: make(map[reflect.Type]ValueDecoder),
}
var _ ValueEncoder = &PointerCodec{}
var _ ValueDecoder = &PointerCodec{}
// PointerCodec is the Codec used for pointers.
type PointerCodec struct {
ecache map[reflect.Type]ValueEncoder
dcache map[reflect.Type]ValueDecoder
l sync.RWMutex
}
// NewPointerCodec returns a PointerCodec that has been initialized.
func NewPointerCodec() *PointerCodec {
return &PointerCodec{
ecache: make(map[reflect.Type]ValueEncoder),
dcache: make(map[reflect.Type]ValueDecoder),
}
}
// EncodeValue handles encoding a pointer by either encoding it to BSON Null if the pointer is nil
// or looking up an encoder for the type of value the pointer points to.
func (pc *PointerCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if val.Kind() != reflect.Ptr {
if !val.IsValid() {
return vw.WriteNull()
}
return ValueEncoderError{Name: "PointerCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val}
}
if val.IsNil() {
return vw.WriteNull()
}
pc.l.RLock()
enc, ok := pc.ecache[val.Type()]
pc.l.RUnlock()
if ok {
if enc == nil {
return ErrNoEncoder{Type: val.Type()}
}
return enc.EncodeValue(ec, vw, val.Elem())
}
enc, err := ec.LookupEncoder(val.Type().Elem())
pc.l.Lock()
pc.ecache[val.Type()] = enc
pc.l.Unlock()
if err != nil {
return err
}
return enc.EncodeValue(ec, vw, val.Elem())
}
// DecodeValue handles decoding a pointer by looking up a decoder for the type it points to and
// using that to decode. If the BSON value is Null, this method will set the pointer to nil.
func (pc *PointerCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Kind() != reflect.Ptr {
return ValueDecoderError{Name: "PointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val}
}
if vr.Type() == bsontype.Null {
val.Set(reflect.Zero(val.Type()))
return vr.ReadNull()
}
if val.IsNil() {
val.Set(reflect.New(val.Type().Elem()))
}
pc.l.RLock()
dec, ok := pc.dcache[val.Type()]
pc.l.RUnlock()
if ok {
if dec == nil {
return ErrNoDecoder{Type: val.Type()}
}
return dec.DecodeValue(dc, vr, val.Elem())
}
dec, err := dc.LookupDecoder(val.Type().Elem())
pc.l.Lock()
pc.dcache[val.Type()] = dec
pc.l.Unlock()
if err != nil {
return err
}
return dec.DecodeValue(dc, vr, val.Elem())
}
+14
View File
@@ -0,0 +1,14 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
// Proxy is an interface implemented by types that cannot themselves be directly encoded. Types
// that implement this interface with have ProxyBSON called during the encoding process and that
// value will be encoded in place for the implementer.
type Proxy interface {
ProxyBSON() (interface{}, error)
}
+384
View File
@@ -0,0 +1,384 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"errors"
"reflect"
"sync"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
// ErrNilType is returned when nil is passed to either LookupEncoder or LookupDecoder.
var ErrNilType = errors.New("cannot perform a decoder lookup on <nil>")
// ErrNotPointer is returned when a non-pointer type is provided to LookupDecoder.
var ErrNotPointer = errors.New("non-pointer provided to LookupDecoder")
// ErrNoEncoder is returned when there wasn't an encoder available for a type.
type ErrNoEncoder struct {
Type reflect.Type
}
func (ene ErrNoEncoder) Error() string {
if ene.Type == nil {
return "no encoder found for <nil>"
}
return "no encoder found for " + ene.Type.String()
}
// ErrNoDecoder is returned when there wasn't a decoder available for a type.
type ErrNoDecoder struct {
Type reflect.Type
}
func (end ErrNoDecoder) Error() string {
return "no decoder found for " + end.Type.String()
}
// ErrNoTypeMapEntry is returned when there wasn't a type available for the provided BSON type.
type ErrNoTypeMapEntry struct {
Type bsontype.Type
}
func (entme ErrNoTypeMapEntry) Error() string {
return "no type map entry found for " + entme.Type.String()
}
// ErrNotInterface is returned when the provided type is not an interface.
var ErrNotInterface = errors.New("The provided type is not an interface")
var defaultRegistry *Registry
func init() {
defaultRegistry = buildDefaultRegistry()
}
// A RegistryBuilder is used to build a Registry. This type is not goroutine
// safe.
type RegistryBuilder struct {
typeEncoders map[reflect.Type]ValueEncoder
interfaceEncoders []interfaceValueEncoder
kindEncoders map[reflect.Kind]ValueEncoder
typeDecoders map[reflect.Type]ValueDecoder
interfaceDecoders []interfaceValueDecoder
kindDecoders map[reflect.Kind]ValueDecoder
typeMap map[bsontype.Type]reflect.Type
}
// A Registry is used to store and retrieve codecs for types and interfaces. This type is the main
// typed passed around and Encoders and Decoders are constructed from it.
type Registry struct {
typeEncoders map[reflect.Type]ValueEncoder
typeDecoders map[reflect.Type]ValueDecoder
interfaceEncoders []interfaceValueEncoder
interfaceDecoders []interfaceValueDecoder
kindEncoders map[reflect.Kind]ValueEncoder
kindDecoders map[reflect.Kind]ValueDecoder
typeMap map[bsontype.Type]reflect.Type
mu sync.RWMutex
}
// NewRegistryBuilder creates a new empty RegistryBuilder.
func NewRegistryBuilder() *RegistryBuilder {
return &RegistryBuilder{
typeEncoders: make(map[reflect.Type]ValueEncoder),
typeDecoders: make(map[reflect.Type]ValueDecoder),
interfaceEncoders: make([]interfaceValueEncoder, 0),
interfaceDecoders: make([]interfaceValueDecoder, 0),
kindEncoders: make(map[reflect.Kind]ValueEncoder),
kindDecoders: make(map[reflect.Kind]ValueDecoder),
typeMap: make(map[bsontype.Type]reflect.Type),
}
}
func buildDefaultRegistry() *Registry {
rb := NewRegistryBuilder()
defaultValueEncoders.RegisterDefaultEncoders(rb)
defaultValueDecoders.RegisterDefaultDecoders(rb)
return rb.Build()
}
// RegisterCodec will register the provided ValueCodec for the provided type.
func (rb *RegistryBuilder) RegisterCodec(t reflect.Type, codec ValueCodec) *RegistryBuilder {
rb.RegisterEncoder(t, codec)
rb.RegisterDecoder(t, codec)
return rb
}
// RegisterEncoder will register the provided ValueEncoder to the provided type.
//
// The type registered will be used directly, so an encoder can be registered for a type and a
// different encoder can be registered for a pointer to that type.
func (rb *RegistryBuilder) RegisterEncoder(t reflect.Type, enc ValueEncoder) *RegistryBuilder {
if t == tEmpty {
rb.typeEncoders[t] = enc
return rb
}
switch t.Kind() {
case reflect.Interface:
for idx, ir := range rb.interfaceEncoders {
if ir.i == t {
rb.interfaceEncoders[idx].ve = enc
return rb
}
}
rb.interfaceEncoders = append(rb.interfaceEncoders, interfaceValueEncoder{i: t, ve: enc})
default:
rb.typeEncoders[t] = enc
}
return rb
}
// RegisterDecoder will register the provided ValueDecoder to the provided type.
//
// The type registered will be used directly, so a decoder can be registered for a type and a
// different decoder can be registered for a pointer to that type.
func (rb *RegistryBuilder) RegisterDecoder(t reflect.Type, dec ValueDecoder) *RegistryBuilder {
if t == nil {
rb.typeDecoders[nil] = dec
return rb
}
if t == tEmpty {
rb.typeDecoders[t] = dec
return rb
}
switch t.Kind() {
case reflect.Interface:
for idx, ir := range rb.interfaceDecoders {
if ir.i == t {
rb.interfaceDecoders[idx].vd = dec
return rb
}
}
rb.interfaceDecoders = append(rb.interfaceDecoders, interfaceValueDecoder{i: t, vd: dec})
default:
rb.typeDecoders[t] = dec
}
return rb
}
// RegisterDefaultEncoder will registr the provided ValueEncoder to the provided
// kind.
func (rb *RegistryBuilder) RegisterDefaultEncoder(kind reflect.Kind, enc ValueEncoder) *RegistryBuilder {
rb.kindEncoders[kind] = enc
return rb
}
// RegisterDefaultDecoder will register the provided ValueDecoder to the
// provided kind.
func (rb *RegistryBuilder) RegisterDefaultDecoder(kind reflect.Kind, dec ValueDecoder) *RegistryBuilder {
rb.kindDecoders[kind] = dec
return rb
}
// RegisterTypeMapEntry will register the provided type to the BSON type. The primary usage for this
// mapping is decoding situations where an empty interface is used and a default type needs to be
// created and decoded into.
//
// NOTE: It is unlikely that registering a type for BSON Embedded Document is actually desired. By
// registering a type map entry for BSON Embedded Document the type registered will be used in any
// case where a BSON Embedded Document will be decoded into an empty interface. For example, if you
// register primitive.M, the EmptyInterface decoder will always use primitive.M, even if an ancestor
// was a primitive.D.
func (rb *RegistryBuilder) RegisterTypeMapEntry(bt bsontype.Type, rt reflect.Type) *RegistryBuilder {
rb.typeMap[bt] = rt
return rb
}
// Build creates a Registry from the current state of this RegistryBuilder.
func (rb *RegistryBuilder) Build() *Registry {
registry := new(Registry)
registry.typeEncoders = make(map[reflect.Type]ValueEncoder)
for t, enc := range rb.typeEncoders {
registry.typeEncoders[t] = enc
}
registry.typeDecoders = make(map[reflect.Type]ValueDecoder)
for t, dec := range rb.typeDecoders {
registry.typeDecoders[t] = dec
}
registry.interfaceEncoders = make([]interfaceValueEncoder, len(rb.interfaceEncoders))
copy(registry.interfaceEncoders, rb.interfaceEncoders)
registry.interfaceDecoders = make([]interfaceValueDecoder, len(rb.interfaceDecoders))
copy(registry.interfaceDecoders, rb.interfaceDecoders)
registry.kindEncoders = make(map[reflect.Kind]ValueEncoder)
for kind, enc := range rb.kindEncoders {
registry.kindEncoders[kind] = enc
}
registry.kindDecoders = make(map[reflect.Kind]ValueDecoder)
for kind, dec := range rb.kindDecoders {
registry.kindDecoders[kind] = dec
}
registry.typeMap = make(map[bsontype.Type]reflect.Type)
for bt, rt := range rb.typeMap {
registry.typeMap[bt] = rt
}
return registry
}
// LookupEncoder will inspect the registry for an encoder that satisfies the
// type provided. An encoder registered for a specific type will take
// precedence over an encoder registered for an interface the type satisfies,
// which takes precedence over an encoder for the reflect.Kind of the value. If
// no encoder can be found, an error is returned.
func (r *Registry) LookupEncoder(t reflect.Type) (ValueEncoder, error) {
encodererr := ErrNoEncoder{Type: t}
r.mu.RLock()
enc, found := r.lookupTypeEncoder(t)
r.mu.RUnlock()
if found {
if enc == nil {
return nil, ErrNoEncoder{Type: t}
}
return enc, nil
}
enc, found = r.lookupInterfaceEncoder(t)
if found {
r.mu.Lock()
r.typeEncoders[t] = enc
r.mu.Unlock()
return enc, nil
}
if t == nil {
r.mu.Lock()
r.typeEncoders[t] = nil
r.mu.Unlock()
return nil, encodererr
}
enc, found = r.kindEncoders[t.Kind()]
if !found {
r.mu.Lock()
r.typeEncoders[t] = nil
r.mu.Unlock()
return nil, encodererr
}
r.mu.Lock()
r.typeEncoders[t] = enc
r.mu.Unlock()
return enc, nil
}
func (r *Registry) lookupTypeEncoder(t reflect.Type) (ValueEncoder, bool) {
enc, found := r.typeEncoders[t]
return enc, found
}
func (r *Registry) lookupInterfaceEncoder(t reflect.Type) (ValueEncoder, bool) {
if t == nil {
return nil, false
}
for _, ienc := range r.interfaceEncoders {
if !t.Implements(ienc.i) {
continue
}
return ienc.ve, true
}
return nil, false
}
// LookupDecoder will inspect the registry for a decoder that satisfies the
// type provided. A decoder registered for a specific type will take
// precedence over a decoder registered for an interface the type satisfies,
// which takes precedence over a decoder for the reflect.Kind of the value. If
// no decoder can be found, an error is returned.
func (r *Registry) LookupDecoder(t reflect.Type) (ValueDecoder, error) {
if t == nil {
return nil, ErrNilType
}
decodererr := ErrNoDecoder{Type: t}
r.mu.RLock()
dec, found := r.lookupTypeDecoder(t)
r.mu.RUnlock()
if found {
if dec == nil {
return nil, ErrNoDecoder{Type: t}
}
return dec, nil
}
dec, found = r.lookupInterfaceDecoder(t)
if found {
r.mu.Lock()
r.typeDecoders[t] = dec
r.mu.Unlock()
return dec, nil
}
dec, found = r.kindDecoders[t.Kind()]
if !found {
r.mu.Lock()
r.typeDecoders[t] = nil
r.mu.Unlock()
return nil, decodererr
}
r.mu.Lock()
r.typeDecoders[t] = dec
r.mu.Unlock()
return dec, nil
}
func (r *Registry) lookupTypeDecoder(t reflect.Type) (ValueDecoder, bool) {
dec, found := r.typeDecoders[t]
return dec, found
}
func (r *Registry) lookupInterfaceDecoder(t reflect.Type) (ValueDecoder, bool) {
for _, idec := range r.interfaceDecoders {
if !t.Implements(idec.i) && !reflect.PtrTo(t).Implements(idec.i) {
continue
}
return idec.vd, true
}
return nil, false
}
// LookupTypeMapEntry inspects the registry's type map for a Go type for the corresponding BSON
// type. If no type is found, ErrNoTypeMapEntry is returned.
func (r *Registry) LookupTypeMapEntry(bt bsontype.Type) (reflect.Type, error) {
t, ok := r.typeMap[bt]
if !ok || t == nil {
return nil, ErrNoTypeMapEntry{Type: bt}
}
return t, nil
}
type interfaceValueEncoder struct {
i reflect.Type
ve ValueEncoder
}
type interfaceValueDecoder struct {
i reflect.Type
vd ValueDecoder
}
+359
View File
@@ -0,0 +1,359 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"errors"
"fmt"
"reflect"
"sync"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
var defaultStructCodec = &StructCodec{
cache: make(map[reflect.Type]*structDescription),
parser: DefaultStructTagParser,
}
// Zeroer allows custom struct types to implement a report of zero
// state. All struct types that don't implement Zeroer or where IsZero
// returns false are considered to be not zero.
type Zeroer interface {
IsZero() bool
}
// StructCodec is the Codec used for struct values.
type StructCodec struct {
cache map[reflect.Type]*structDescription
l sync.RWMutex
parser StructTagParser
}
var _ ValueEncoder = &StructCodec{}
var _ ValueDecoder = &StructCodec{}
// NewStructCodec returns a StructCodec that uses p for struct tag parsing.
func NewStructCodec(p StructTagParser) (*StructCodec, error) {
if p == nil {
return nil, errors.New("a StructTagParser must be provided to NewStructCodec")
}
return &StructCodec{
cache: make(map[reflect.Type]*structDescription),
parser: p,
}, nil
}
// EncodeValue handles encoding generic struct types.
func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Struct {
return ValueEncoderError{Name: "StructCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
}
sd, err := sc.describeStruct(r.Registry, val.Type())
if err != nil {
return err
}
dw, err := vw.WriteDocument()
if err != nil {
return err
}
var rv reflect.Value
for _, desc := range sd.fl {
if desc.inline == nil {
rv = val.Field(desc.idx)
} else {
rv = val.FieldByIndex(desc.inline)
}
if desc.encoder == nil {
return ErrNoEncoder{Type: rv.Type()}
}
encoder := desc.encoder
iszero := sc.isZero
if iz, ok := encoder.(CodecZeroer); ok {
iszero = iz.IsTypeZero
}
if desc.omitEmpty && iszero(rv.Interface()) {
continue
}
vw2, err := dw.WriteDocumentElement(desc.name)
if err != nil {
return err
}
ectx := EncodeContext{Registry: r.Registry, MinSize: desc.minSize}
err = encoder.EncodeValue(ectx, vw2, rv)
if err != nil {
return err
}
}
if sd.inlineMap >= 0 {
rv := val.Field(sd.inlineMap)
collisionFn := func(key string) bool {
_, exists := sd.fm[key]
return exists
}
return defaultValueEncoders.mapEncodeValue(r, dw, rv, collisionFn)
}
return dw.WriteDocumentEnd()
}
// DecodeValue implements the Codec interface.
// By default, map types in val will not be cleared. If a map has existing key/value pairs, it will be extended with the new ones from vr.
// For slices, the decoder will set the length of the slice to zero and append all elements. The underlying array will not be cleared.
func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Kind() != reflect.Struct {
return ValueDecoderError{Name: "StructCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
}
switch vr.Type() {
case bsontype.Type(0), bsontype.EmbeddedDocument:
default:
return fmt.Errorf("cannot decode %v into a %s", vr.Type(), val.Type())
}
sd, err := sc.describeStruct(r.Registry, val.Type())
if err != nil {
return err
}
var decoder ValueDecoder
var inlineMap reflect.Value
if sd.inlineMap >= 0 {
inlineMap = val.Field(sd.inlineMap)
if inlineMap.IsNil() {
inlineMap.Set(reflect.MakeMap(inlineMap.Type()))
}
decoder, err = r.LookupDecoder(inlineMap.Type().Elem())
if err != nil {
return err
}
}
dr, err := vr.ReadDocument()
if err != nil {
return err
}
for {
name, vr, err := dr.ReadElement()
if err == bsonrw.ErrEOD {
break
}
if err != nil {
return err
}
fd, exists := sd.fm[name]
if !exists {
if sd.inlineMap < 0 {
// The encoding/json package requires a flag to return on error for non-existent fields.
// This functionality seems appropriate for the struct codec.
err = vr.Skip()
if err != nil {
return err
}
continue
}
elem := reflect.New(inlineMap.Type().Elem()).Elem()
err = decoder.DecodeValue(r, vr, elem)
if err != nil {
return err
}
inlineMap.SetMapIndex(reflect.ValueOf(name), elem)
continue
}
var field reflect.Value
if fd.inline == nil {
field = val.Field(fd.idx)
} else {
field = val.FieldByIndex(fd.inline)
}
if !field.CanSet() { // Being settable is a super set of being addressable.
return fmt.Errorf("cannot decode element '%s' into field %v; it is not settable", name, field)
}
if field.Kind() == reflect.Ptr && field.IsNil() {
field.Set(reflect.New(field.Type().Elem()))
}
field = field.Addr()
dctx := DecodeContext{Registry: r.Registry, Truncate: fd.truncate}
if fd.decoder == nil {
return ErrNoDecoder{Type: field.Elem().Type()}
}
if decoder, ok := fd.decoder.(ValueDecoder); ok {
err = decoder.DecodeValue(dctx, vr, field.Elem())
if err != nil {
return err
}
continue
}
err = fd.decoder.DecodeValue(dctx, vr, field)
if err != nil {
return err
}
}
return nil
}
func (sc *StructCodec) isZero(i interface{}) bool {
v := reflect.ValueOf(i)
// check the value validity
if !v.IsValid() {
return true
}
if z, ok := v.Interface().(Zeroer); ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
return z.IsZero()
}
switch v.Kind() {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return v.Len() == 0
case reflect.Bool:
return !v.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return v.Float() == 0
case reflect.Interface, reflect.Ptr:
return v.IsNil()
}
return false
}
type structDescription struct {
fm map[string]fieldDescription
fl []fieldDescription
inlineMap int
}
type fieldDescription struct {
name string
idx int
omitEmpty bool
minSize bool
truncate bool
inline []int
encoder ValueEncoder
decoder ValueDecoder
}
func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescription, error) {
// We need to analyze the struct, including getting the tags, collecting
// information about inlining, and create a map of the field name to the field.
sc.l.RLock()
ds, exists := sc.cache[t]
sc.l.RUnlock()
if exists {
return ds, nil
}
numFields := t.NumField()
sd := &structDescription{
fm: make(map[string]fieldDescription, numFields),
fl: make([]fieldDescription, 0, numFields),
inlineMap: -1,
}
for i := 0; i < numFields; i++ {
sf := t.Field(i)
if sf.PkgPath != "" {
// unexported, ignore
continue
}
encoder, err := r.LookupEncoder(sf.Type)
if err != nil {
encoder = nil
}
decoder, err := r.LookupDecoder(sf.Type)
if err != nil {
decoder = nil
}
description := fieldDescription{idx: i, encoder: encoder, decoder: decoder}
stags, err := sc.parser.ParseStructTags(sf)
if err != nil {
return nil, err
}
if stags.Skip {
continue
}
description.name = stags.Name
description.omitEmpty = stags.OmitEmpty
description.minSize = stags.MinSize
description.truncate = stags.Truncate
if stags.Inline {
switch sf.Type.Kind() {
case reflect.Map:
if sd.inlineMap >= 0 {
return nil, errors.New("(struct " + t.String() + ") multiple inline maps")
}
if sf.Type.Key() != tString {
return nil, errors.New("(struct " + t.String() + ") inline map must have a string keys")
}
sd.inlineMap = description.idx
case reflect.Struct:
inlinesf, err := sc.describeStruct(r, sf.Type)
if err != nil {
return nil, err
}
for _, fd := range inlinesf.fl {
if _, exists := sd.fm[fd.name]; exists {
return nil, fmt.Errorf("(struct %s) duplicated key %s", t.String(), fd.name)
}
if fd.inline == nil {
fd.inline = []int{i, fd.idx}
} else {
fd.inline = append([]int{i}, fd.inline...)
}
sd.fm[fd.name] = fd
sd.fl = append(sd.fl, fd)
}
default:
return nil, fmt.Errorf("(struct %s) inline fields must be either a struct or a map", t.String())
}
continue
}
if _, exists := sd.fm[description.name]; exists {
return nil, fmt.Errorf("struct %s) duplicated key %s", t.String(), description.name)
}
sd.fm[description.name] = description
sd.fl = append(sd.fl, description)
}
sc.l.Lock()
sc.cache[t] = sd
sc.l.Unlock()
return sd, nil
}
+119
View File
@@ -0,0 +1,119 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"reflect"
"strings"
)
// StructTagParser returns the struct tags for a given struct field.
type StructTagParser interface {
ParseStructTags(reflect.StructField) (StructTags, error)
}
// StructTagParserFunc is an adapter that allows a generic function to be used
// as a StructTagParser.
type StructTagParserFunc func(reflect.StructField) (StructTags, error)
// ParseStructTags implements the StructTagParser interface.
func (stpf StructTagParserFunc) ParseStructTags(sf reflect.StructField) (StructTags, error) {
return stpf(sf)
}
// StructTags represents the struct tag fields that the StructCodec uses during
// the encoding and decoding process.
//
// In the case of a struct, the lowercased field name is used as the key for each exported
// field but this behavior may be changed using a struct tag. The tag may also contain flags to
// adjust the marshalling behavior for the field.
//
// The properties are defined below:
//
// OmitEmpty Only include the field if it's not set to the zero value for the type or to
// empty slices or maps.
//
// MinSize Marshal an integer of a type larger than 32 bits value as an int32, if that's
// feasible while preserving the numeric value.
//
// Truncate When unmarshaling a BSON double, it is permitted to lose precision to fit within
// a float32.
//
// Inline Inline the field, which must be a struct or a map, causing all of its fields
// or keys to be processed as if they were part of the outer struct. For maps,
// keys must not conflict with the bson keys of other struct fields.
//
// Skip This struct field should be skipped. This is usually denoted by parsing a "-"
// for the name.
//
// TODO(skriptble): Add tags for undefined as nil and for null as nil.
type StructTags struct {
Name string
OmitEmpty bool
MinSize bool
Truncate bool
Inline bool
Skip bool
}
// DefaultStructTagParser is the StructTagParser used by the StructCodec by default.
// It will handle the bson struct tag. See the documentation for StructTags to see
// what each of the returned fields means.
//
// If there is no name in the struct tag fields, the struct field name is lowercased.
// The tag formats accepted are:
//
// "[<key>][,<flag1>[,<flag2>]]"
//
// `(...) bson:"[<key>][,<flag1>[,<flag2>]]" (...)`
//
// An example:
//
// type T struct {
// A bool
// B int "myb"
// C string "myc,omitempty"
// D string `bson:",omitempty" json:"jsonkey"`
// E int64 ",minsize"
// F int64 "myf,omitempty,minsize"
// }
//
// A struct tag either consisting entirely of '-' or with a bson key with a
// value consisting entirely of '-' will return a StructTags with Skip true and
// the remaining fields will be their default values.
var DefaultStructTagParser StructTagParserFunc = func(sf reflect.StructField) (StructTags, error) {
key := strings.ToLower(sf.Name)
tag, ok := sf.Tag.Lookup("bson")
if !ok && !strings.Contains(string(sf.Tag), ":") && len(sf.Tag) > 0 {
tag = string(sf.Tag)
}
var st StructTags
if tag == "-" {
st.Skip = true
return st, nil
}
for idx, str := range strings.Split(tag, ",") {
if idx == 0 && str != "" {
key = str
}
switch str {
case "omitempty":
st.OmitEmpty = true
case "minsize":
st.MinSize = true
case "truncate":
st.Truncate = true
case "inline":
st.Inline = true
}
}
st.Name = key
return st, nil
}
+80
View File
@@ -0,0 +1,80 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"encoding/json"
"net/url"
"reflect"
"time"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
var ptBool = reflect.TypeOf((*bool)(nil))
var ptInt8 = reflect.TypeOf((*int8)(nil))
var ptInt16 = reflect.TypeOf((*int16)(nil))
var ptInt32 = reflect.TypeOf((*int32)(nil))
var ptInt64 = reflect.TypeOf((*int64)(nil))
var ptInt = reflect.TypeOf((*int)(nil))
var ptUint8 = reflect.TypeOf((*uint8)(nil))
var ptUint16 = reflect.TypeOf((*uint16)(nil))
var ptUint32 = reflect.TypeOf((*uint32)(nil))
var ptUint64 = reflect.TypeOf((*uint64)(nil))
var ptUint = reflect.TypeOf((*uint)(nil))
var ptFloat32 = reflect.TypeOf((*float32)(nil))
var ptFloat64 = reflect.TypeOf((*float64)(nil))
var ptString = reflect.TypeOf((*string)(nil))
var tBool = reflect.TypeOf(false)
var tFloat32 = reflect.TypeOf(float32(0))
var tFloat64 = reflect.TypeOf(float64(0))
var tInt = reflect.TypeOf(int(0))
var tInt8 = reflect.TypeOf(int8(0))
var tInt16 = reflect.TypeOf(int16(0))
var tInt32 = reflect.TypeOf(int32(0))
var tInt64 = reflect.TypeOf(int64(0))
var tString = reflect.TypeOf("")
var tTime = reflect.TypeOf(time.Time{})
var tUint = reflect.TypeOf(uint(0))
var tUint8 = reflect.TypeOf(uint8(0))
var tUint16 = reflect.TypeOf(uint16(0))
var tUint32 = reflect.TypeOf(uint32(0))
var tUint64 = reflect.TypeOf(uint64(0))
var tEmpty = reflect.TypeOf((*interface{})(nil)).Elem()
var tByteSlice = reflect.TypeOf([]byte(nil))
var tByte = reflect.TypeOf(byte(0x00))
var tURL = reflect.TypeOf(url.URL{})
var tJSONNumber = reflect.TypeOf(json.Number(""))
var tValueMarshaler = reflect.TypeOf((*ValueMarshaler)(nil)).Elem()
var tValueUnmarshaler = reflect.TypeOf((*ValueUnmarshaler)(nil)).Elem()
var tMarshaler = reflect.TypeOf((*Marshaler)(nil)).Elem()
var tUnmarshaler = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
var tProxy = reflect.TypeOf((*Proxy)(nil)).Elem()
var tBinary = reflect.TypeOf(primitive.Binary{})
var tUndefined = reflect.TypeOf(primitive.Undefined{})
var tOID = reflect.TypeOf(primitive.ObjectID{})
var tDateTime = reflect.TypeOf(primitive.DateTime(0))
var tNull = reflect.TypeOf(primitive.Null{})
var tRegex = reflect.TypeOf(primitive.Regex{})
var tCodeWithScope = reflect.TypeOf(primitive.CodeWithScope{})
var tDBPointer = reflect.TypeOf(primitive.DBPointer{})
var tJavaScript = reflect.TypeOf(primitive.JavaScript(""))
var tSymbol = reflect.TypeOf(primitive.Symbol(""))
var tTimestamp = reflect.TypeOf(primitive.Timestamp{})
var tDecimal = reflect.TypeOf(primitive.Decimal128{})
var tMinKey = reflect.TypeOf(primitive.MinKey{})
var tMaxKey = reflect.TypeOf(primitive.MaxKey{})
var tD = reflect.TypeOf(primitive.D{})
var tA = reflect.TypeOf(primitive.A{})
var tE = reflect.TypeOf(primitive.E{})
var tCoreDocument = reflect.TypeOf(bsoncore.Document{})
+389
View File
@@ -0,0 +1,389 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"fmt"
"io"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
// Copier is a type that allows copying between ValueReaders, ValueWriters, and
// []byte values.
type Copier struct{}
// NewCopier creates a new copier with the given registry. If a nil registry is provided
// a default registry is used.
func NewCopier() Copier {
return Copier{}
}
// CopyDocument handles copying a document from src to dst.
func CopyDocument(dst ValueWriter, src ValueReader) error {
return Copier{}.CopyDocument(dst, src)
}
// CopyDocument handles copying one document from the src to the dst.
func (c Copier) CopyDocument(dst ValueWriter, src ValueReader) error {
dr, err := src.ReadDocument()
if err != nil {
return err
}
dw, err := dst.WriteDocument()
if err != nil {
return err
}
return c.copyDocumentCore(dw, dr)
}
// CopyDocumentFromBytes copies the values from a BSON document represented as a
// []byte to a ValueWriter.
func (c Copier) CopyDocumentFromBytes(dst ValueWriter, src []byte) error {
dw, err := dst.WriteDocument()
if err != nil {
return err
}
err = c.CopyBytesToDocumentWriter(dw, src)
if err != nil {
return err
}
return dw.WriteDocumentEnd()
}
// CopyBytesToDocumentWriter copies the values from a BSON document represented as a []byte to a
// DocumentWriter.
func (c Copier) CopyBytesToDocumentWriter(dst DocumentWriter, src []byte) error {
// TODO(skriptble): Create errors types here. Anything thats a tag should be a property.
length, rem, ok := bsoncore.ReadLength(src)
if !ok {
return fmt.Errorf("couldn't read length from src, not enough bytes. length=%d", len(src))
}
if len(src) < int(length) {
return fmt.Errorf("length read exceeds number of bytes available. length=%d bytes=%d", len(src), length)
}
rem = rem[:length-4]
var t bsontype.Type
var key string
var val bsoncore.Value
for {
t, rem, ok = bsoncore.ReadType(rem)
if !ok {
return io.EOF
}
if t == bsontype.Type(0) {
if len(rem) != 0 {
return fmt.Errorf("document end byte found before end of document. remaining bytes=%v", rem)
}
break
}
key, rem, ok = bsoncore.ReadKey(rem)
if !ok {
return fmt.Errorf("invalid key found. remaining bytes=%v", rem)
}
dvw, err := dst.WriteDocumentElement(key)
if err != nil {
return err
}
val, rem, ok = bsoncore.ReadValue(rem, t)
if !ok {
return fmt.Errorf("not enough bytes available to read type. bytes=%d type=%s", len(rem), t)
}
err = c.CopyValueFromBytes(dvw, t, val.Data)
if err != nil {
return err
}
}
return nil
}
// CopyDocumentToBytes copies an entire document from the ValueReader and
// returns it as bytes.
func (c Copier) CopyDocumentToBytes(src ValueReader) ([]byte, error) {
return c.AppendDocumentBytes(nil, src)
}
// AppendDocumentBytes functions the same as CopyDocumentToBytes, but will
// append the result to dst.
func (c Copier) AppendDocumentBytes(dst []byte, src ValueReader) ([]byte, error) {
if br, ok := src.(BytesReader); ok {
_, dst, err := br.ReadValueBytes(dst)
return dst, err
}
vw := vwPool.Get().(*valueWriter)
defer vwPool.Put(vw)
vw.reset(dst)
err := c.CopyDocument(vw, src)
dst = vw.buf
return dst, err
}
// CopyValueFromBytes will write the value represtend by t and src to dst.
func (c Copier) CopyValueFromBytes(dst ValueWriter, t bsontype.Type, src []byte) error {
if wvb, ok := dst.(BytesWriter); ok {
return wvb.WriteValueBytes(t, src)
}
vr := vrPool.Get().(*valueReader)
defer vrPool.Put(vr)
vr.reset(src)
vr.pushElement(t)
return c.CopyValue(dst, vr)
}
// CopyValueToBytes copies a value from src and returns it as a bsontype.Type and a
// []byte.
func (c Copier) CopyValueToBytes(src ValueReader) (bsontype.Type, []byte, error) {
return c.AppendValueBytes(nil, src)
}
// AppendValueBytes functions the same as CopyValueToBytes, but will append the
// result to dst.
func (c Copier) AppendValueBytes(dst []byte, src ValueReader) (bsontype.Type, []byte, error) {
if br, ok := src.(BytesReader); ok {
return br.ReadValueBytes(dst)
}
vw := vwPool.Get().(*valueWriter)
defer vwPool.Put(vw)
start := len(dst)
vw.reset(dst)
vw.push(mElement)
err := c.CopyValue(vw, src)
if err != nil {
return 0, dst, err
}
return bsontype.Type(vw.buf[start]), vw.buf[start+2:], nil
}
// CopyValue will copy a single value from src to dst.
func (c Copier) CopyValue(dst ValueWriter, src ValueReader) error {
var err error
switch src.Type() {
case bsontype.Double:
var f64 float64
f64, err = src.ReadDouble()
if err != nil {
break
}
err = dst.WriteDouble(f64)
case bsontype.String:
var str string
str, err = src.ReadString()
if err != nil {
return err
}
err = dst.WriteString(str)
case bsontype.EmbeddedDocument:
err = c.CopyDocument(dst, src)
case bsontype.Array:
err = c.copyArray(dst, src)
case bsontype.Binary:
var data []byte
var subtype byte
data, subtype, err = src.ReadBinary()
if err != nil {
break
}
err = dst.WriteBinaryWithSubtype(data, subtype)
case bsontype.Undefined:
err = src.ReadUndefined()
if err != nil {
break
}
err = dst.WriteUndefined()
case bsontype.ObjectID:
var oid primitive.ObjectID
oid, err = src.ReadObjectID()
if err != nil {
break
}
err = dst.WriteObjectID(oid)
case bsontype.Boolean:
var b bool
b, err = src.ReadBoolean()
if err != nil {
break
}
err = dst.WriteBoolean(b)
case bsontype.DateTime:
var dt int64
dt, err = src.ReadDateTime()
if err != nil {
break
}
err = dst.WriteDateTime(dt)
case bsontype.Null:
err = src.ReadNull()
if err != nil {
break
}
err = dst.WriteNull()
case bsontype.Regex:
var pattern, options string
pattern, options, err = src.ReadRegex()
if err != nil {
break
}
err = dst.WriteRegex(pattern, options)
case bsontype.DBPointer:
var ns string
var pointer primitive.ObjectID
ns, pointer, err = src.ReadDBPointer()
if err != nil {
break
}
err = dst.WriteDBPointer(ns, pointer)
case bsontype.JavaScript:
var js string
js, err = src.ReadJavascript()
if err != nil {
break
}
err = dst.WriteJavascript(js)
case bsontype.Symbol:
var symbol string
symbol, err = src.ReadSymbol()
if err != nil {
break
}
err = dst.WriteSymbol(symbol)
case bsontype.CodeWithScope:
var code string
var srcScope DocumentReader
code, srcScope, err = src.ReadCodeWithScope()
if err != nil {
break
}
var dstScope DocumentWriter
dstScope, err = dst.WriteCodeWithScope(code)
if err != nil {
break
}
err = c.copyDocumentCore(dstScope, srcScope)
case bsontype.Int32:
var i32 int32
i32, err = src.ReadInt32()
if err != nil {
break
}
err = dst.WriteInt32(i32)
case bsontype.Timestamp:
var t, i uint32
t, i, err = src.ReadTimestamp()
if err != nil {
break
}
err = dst.WriteTimestamp(t, i)
case bsontype.Int64:
var i64 int64
i64, err = src.ReadInt64()
if err != nil {
break
}
err = dst.WriteInt64(i64)
case bsontype.Decimal128:
var d128 primitive.Decimal128
d128, err = src.ReadDecimal128()
if err != nil {
break
}
err = dst.WriteDecimal128(d128)
case bsontype.MinKey:
err = src.ReadMinKey()
if err != nil {
break
}
err = dst.WriteMinKey()
case bsontype.MaxKey:
err = src.ReadMaxKey()
if err != nil {
break
}
err = dst.WriteMaxKey()
default:
err = fmt.Errorf("Cannot copy unknown BSON type %s", src.Type())
}
return err
}
func (c Copier) copyArray(dst ValueWriter, src ValueReader) error {
ar, err := src.ReadArray()
if err != nil {
return err
}
aw, err := dst.WriteArray()
if err != nil {
return err
}
for {
vr, err := ar.ReadValue()
if err == ErrEOA {
break
}
if err != nil {
return err
}
vw, err := aw.WriteArrayElement()
if err != nil {
return err
}
err = c.CopyValue(vw, vr)
if err != nil {
return err
}
}
return aw.WriteArrayEnd()
}
func (c Copier) copyDocumentCore(dw DocumentWriter, dr DocumentReader) error {
for {
key, vr, err := dr.ReadElement()
if err == ErrEOD {
break
}
if err != nil {
return err
}
vw, err := dw.WriteDocumentElement(key)
if err != nil {
return err
}
err = c.CopyValue(vw, vr)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}
+9
View File
@@ -0,0 +1,9 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package bsonrw contains abstractions for reading and writing
// BSON and BSON like types from sources.
package bsonrw // import "go.mongodb.org/mongo-driver/bson/bsonrw"
+731
View File
@@ -0,0 +1,731 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"errors"
"fmt"
"io"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
const maxNestingDepth = 200
// ErrInvalidJSON indicates the JSON input is invalid
var ErrInvalidJSON = errors.New("invalid JSON input")
type jsonParseState byte
const (
jpsStartState jsonParseState = iota
jpsSawBeginObject
jpsSawEndObject
jpsSawBeginArray
jpsSawEndArray
jpsSawColon
jpsSawComma
jpsSawKey
jpsSawValue
jpsDoneState
jpsInvalidState
)
type jsonParseMode byte
const (
jpmInvalidMode jsonParseMode = iota
jpmObjectMode
jpmArrayMode
)
type extJSONValue struct {
t bsontype.Type
v interface{}
}
type extJSONObject struct {
keys []string
values []*extJSONValue
}
type extJSONParser struct {
js *jsonScanner
s jsonParseState
m []jsonParseMode
k string
v *extJSONValue
err error
canonical bool
depth int
maxDepth int
emptyObject bool
}
// newExtJSONParser returns a new extended JSON parser, ready to to begin
// parsing from the first character of the argued json input. It will not
// perform any read-ahead and will therefore not report any errors about
// malformed JSON at this point.
func newExtJSONParser(r io.Reader, canonical bool) *extJSONParser {
return &extJSONParser{
js: &jsonScanner{r: r},
s: jpsStartState,
m: []jsonParseMode{},
canonical: canonical,
maxDepth: maxNestingDepth,
}
}
// peekType examines the next value and returns its BSON Type
func (ejp *extJSONParser) peekType() (bsontype.Type, error) {
var t bsontype.Type
var err error
ejp.advanceState()
switch ejp.s {
case jpsSawValue:
t = ejp.v.t
case jpsSawBeginArray:
t = bsontype.Array
case jpsInvalidState:
err = ejp.err
case jpsSawComma:
// in array mode, seeing a comma means we need to progress again to actually observe a type
if ejp.peekMode() == jpmArrayMode {
return ejp.peekType()
}
case jpsSawEndArray:
// this would only be a valid state if we were in array mode, so return end-of-array error
err = ErrEOA
case jpsSawBeginObject:
// peek key to determine type
ejp.advanceState()
switch ejp.s {
case jpsSawEndObject: // empty embedded document
t = bsontype.EmbeddedDocument
ejp.emptyObject = true
case jpsInvalidState:
err = ejp.err
case jpsSawKey:
t = wrapperKeyBSONType(ejp.k)
if t == bsontype.JavaScript {
// just saw $code, need to check for $scope at same level
_, err := ejp.readValue(bsontype.JavaScript)
if err != nil {
break
}
switch ejp.s {
case jpsSawEndObject: // type is TypeJavaScript
case jpsSawComma:
ejp.advanceState()
if ejp.s == jpsSawKey && ejp.k == "$scope" {
t = bsontype.CodeWithScope
} else {
err = fmt.Errorf("invalid extended JSON: unexpected key %s in CodeWithScope object", ejp.k)
}
case jpsInvalidState:
err = ejp.err
default:
err = ErrInvalidJSON
}
}
}
}
return t, err
}
// readKey parses the next key and its type and returns them
func (ejp *extJSONParser) readKey() (string, bsontype.Type, error) {
if ejp.emptyObject {
ejp.emptyObject = false
return "", 0, ErrEOD
}
// advance to key (or return with error)
switch ejp.s {
case jpsStartState:
ejp.advanceState()
if ejp.s == jpsSawBeginObject {
ejp.advanceState()
}
case jpsSawBeginObject:
ejp.advanceState()
case jpsSawValue, jpsSawEndObject, jpsSawEndArray:
ejp.advanceState()
switch ejp.s {
case jpsSawBeginObject, jpsSawComma:
ejp.advanceState()
case jpsSawEndObject:
return "", 0, ErrEOD
case jpsDoneState:
return "", 0, io.EOF
case jpsInvalidState:
return "", 0, ejp.err
default:
return "", 0, ErrInvalidJSON
}
case jpsSawKey: // do nothing (key was peeked before)
default:
return "", 0, invalidRequestError("key")
}
// read key
var key string
switch ejp.s {
case jpsSawKey:
key = ejp.k
case jpsSawEndObject:
return "", 0, ErrEOD
case jpsInvalidState:
return "", 0, ejp.err
default:
return "", 0, invalidRequestError("key")
}
// check for colon
ejp.advanceState()
if err := ensureColon(ejp.s, key); err != nil {
return "", 0, err
}
// peek at the value to determine type
t, err := ejp.peekType()
if err != nil {
return "", 0, err
}
return key, t, nil
}
// readValue returns the value corresponding to the Type returned by peekType
func (ejp *extJSONParser) readValue(t bsontype.Type) (*extJSONValue, error) {
if ejp.s == jpsInvalidState {
return nil, ejp.err
}
var v *extJSONValue
switch t {
case bsontype.Null, bsontype.Boolean, bsontype.String:
if ejp.s != jpsSawValue {
return nil, invalidRequestError(t.String())
}
v = ejp.v
case bsontype.Int32, bsontype.Int64, bsontype.Double:
// relaxed version allows these to be literal number values
if ejp.s == jpsSawValue {
v = ejp.v
break
}
fallthrough
case bsontype.Decimal128, bsontype.Symbol, bsontype.ObjectID, bsontype.MinKey, bsontype.MaxKey, bsontype.Undefined:
switch ejp.s {
case jpsSawKey:
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
// read value
ejp.advanceState()
if ejp.s != jpsSawValue || !ejp.ensureExtValueType(t) {
return nil, invalidJSONErrorForType("value", t)
}
v = ejp.v
// read end object
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("} after value", t)
}
default:
return nil, invalidRequestError(t.String())
}
case bsontype.Binary, bsontype.Regex, bsontype.Timestamp, bsontype.DBPointer:
if ejp.s != jpsSawKey {
return nil, invalidRequestError(t.String())
}
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
ejp.advanceState()
if t == bsontype.Binary && ejp.s == jpsSawValue {
// convert legacy $binary format
base64 := ejp.v
ejp.advanceState()
if ejp.s != jpsSawComma {
return nil, invalidJSONErrorForType(",", bsontype.Binary)
}
ejp.advanceState()
key, t, err := ejp.readKey()
if err != nil {
return nil, err
}
if key != "$type" {
return nil, invalidJSONErrorForType("$type", bsontype.Binary)
}
subType, err := ejp.readValue(t)
if err != nil {
return nil, err
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("2 key-value pairs and then }", bsontype.Binary)
}
v = &extJSONValue{
t: bsontype.EmbeddedDocument,
v: &extJSONObject{
keys: []string{"base64", "subType"},
values: []*extJSONValue{base64, subType},
},
}
break
}
// read KV pairs
if ejp.s != jpsSawBeginObject {
return nil, invalidJSONErrorForType("{", t)
}
keys, vals, err := ejp.readObject(2, true)
if err != nil {
return nil, err
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("2 key-value pairs and then }", t)
}
v = &extJSONValue{t: bsontype.EmbeddedDocument, v: &extJSONObject{keys: keys, values: vals}}
case bsontype.DateTime:
switch ejp.s {
case jpsSawValue:
v = ejp.v
case jpsSawKey:
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
ejp.advanceState()
switch ejp.s {
case jpsSawBeginObject:
keys, vals, err := ejp.readObject(1, true)
if err != nil {
return nil, err
}
v = &extJSONValue{t: bsontype.EmbeddedDocument, v: &extJSONObject{keys: keys, values: vals}}
case jpsSawValue:
if ejp.canonical {
return nil, invalidJSONError("{")
}
v = ejp.v
default:
if ejp.canonical {
return nil, invalidJSONErrorForType("object", t)
}
return nil, invalidJSONErrorForType("ISO-8601 Internet Date/Time Format as decribed in RFC-3339", t)
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("value and then }", t)
}
default:
return nil, invalidRequestError(t.String())
}
case bsontype.JavaScript:
switch ejp.s {
case jpsSawKey:
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
// read value
ejp.advanceState()
if ejp.s != jpsSawValue {
return nil, invalidJSONErrorForType("value", t)
}
v = ejp.v
// read end object or comma and just return
ejp.advanceState()
case jpsSawEndObject:
v = ejp.v
default:
return nil, invalidRequestError(t.String())
}
case bsontype.CodeWithScope:
if ejp.s == jpsSawKey && ejp.k == "$scope" {
v = ejp.v // this is the $code string from earlier
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
// read {
ejp.advanceState()
if ejp.s != jpsSawBeginObject {
return nil, invalidJSONError("$scope to be embedded document")
}
} else {
return nil, invalidRequestError(t.String())
}
case bsontype.EmbeddedDocument, bsontype.Array:
return nil, invalidRequestError(t.String())
}
return v, nil
}
// readObject is a utility method for reading full objects of known (or expected) size
// it is useful for extended JSON types such as binary, datetime, regex, and timestamp
func (ejp *extJSONParser) readObject(numKeys int, started bool) ([]string, []*extJSONValue, error) {
keys := make([]string, numKeys)
vals := make([]*extJSONValue, numKeys)
if !started {
ejp.advanceState()
if ejp.s != jpsSawBeginObject {
return nil, nil, invalidJSONError("{")
}
}
for i := 0; i < numKeys; i++ {
key, t, err := ejp.readKey()
if err != nil {
return nil, nil, err
}
switch ejp.s {
case jpsSawKey:
v, err := ejp.readValue(t)
if err != nil {
return nil, nil, err
}
keys[i] = key
vals[i] = v
case jpsSawValue:
keys[i] = key
vals[i] = ejp.v
default:
return nil, nil, invalidJSONError("value")
}
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, nil, invalidJSONError("}")
}
return keys, vals, nil
}
// advanceState reads the next JSON token from the scanner and transitions
// from the current state based on that token's type
func (ejp *extJSONParser) advanceState() {
if ejp.s == jpsDoneState || ejp.s == jpsInvalidState {
return
}
jt, err := ejp.js.nextToken()
if err != nil {
ejp.err = err
ejp.s = jpsInvalidState
return
}
valid := ejp.validateToken(jt.t)
if !valid {
ejp.err = unexpectedTokenError(jt)
ejp.s = jpsInvalidState
return
}
switch jt.t {
case jttBeginObject:
ejp.s = jpsSawBeginObject
ejp.pushMode(jpmObjectMode)
ejp.depth++
if ejp.depth > ejp.maxDepth {
ejp.err = nestingDepthError(jt.p, ejp.depth)
ejp.s = jpsInvalidState
}
case jttEndObject:
ejp.s = jpsSawEndObject
ejp.depth--
if ejp.popMode() != jpmObjectMode {
ejp.err = unexpectedTokenError(jt)
ejp.s = jpsInvalidState
}
case jttBeginArray:
ejp.s = jpsSawBeginArray
ejp.pushMode(jpmArrayMode)
case jttEndArray:
ejp.s = jpsSawEndArray
if ejp.popMode() != jpmArrayMode {
ejp.err = unexpectedTokenError(jt)
ejp.s = jpsInvalidState
}
case jttColon:
ejp.s = jpsSawColon
case jttComma:
ejp.s = jpsSawComma
case jttEOF:
ejp.s = jpsDoneState
if len(ejp.m) != 0 {
ejp.err = unexpectedTokenError(jt)
ejp.s = jpsInvalidState
}
case jttString:
switch ejp.s {
case jpsSawComma:
if ejp.peekMode() == jpmArrayMode {
ejp.s = jpsSawValue
ejp.v = extendJSONToken(jt)
return
}
fallthrough
case jpsSawBeginObject:
ejp.s = jpsSawKey
ejp.k = jt.v.(string)
return
}
fallthrough
default:
ejp.s = jpsSawValue
ejp.v = extendJSONToken(jt)
}
}
var jpsValidTransitionTokens = map[jsonParseState]map[jsonTokenType]bool{
jpsStartState: {
jttBeginObject: true,
jttBeginArray: true,
jttInt32: true,
jttInt64: true,
jttDouble: true,
jttString: true,
jttBool: true,
jttNull: true,
jttEOF: true,
},
jpsSawBeginObject: {
jttEndObject: true,
jttString: true,
},
jpsSawEndObject: {
jttEndObject: true,
jttEndArray: true,
jttComma: true,
jttEOF: true,
},
jpsSawBeginArray: {
jttBeginObject: true,
jttBeginArray: true,
jttEndArray: true,
jttInt32: true,
jttInt64: true,
jttDouble: true,
jttString: true,
jttBool: true,
jttNull: true,
},
jpsSawEndArray: {
jttEndObject: true,
jttEndArray: true,
jttComma: true,
jttEOF: true,
},
jpsSawColon: {
jttBeginObject: true,
jttBeginArray: true,
jttInt32: true,
jttInt64: true,
jttDouble: true,
jttString: true,
jttBool: true,
jttNull: true,
},
jpsSawComma: {
jttBeginObject: true,
jttBeginArray: true,
jttInt32: true,
jttInt64: true,
jttDouble: true,
jttString: true,
jttBool: true,
jttNull: true,
},
jpsSawKey: {
jttColon: true,
},
jpsSawValue: {
jttEndObject: true,
jttEndArray: true,
jttComma: true,
jttEOF: true,
},
jpsDoneState: {},
jpsInvalidState: {},
}
func (ejp *extJSONParser) validateToken(jtt jsonTokenType) bool {
switch ejp.s {
case jpsSawEndObject:
// if we are at depth zero and the next token is a '{',
// we can consider it valid only if we are not in array mode.
if jtt == jttBeginObject && ejp.depth == 0 {
return ejp.peekMode() != jpmArrayMode
}
case jpsSawComma:
switch ejp.peekMode() {
// the only valid next token after a comma inside a document is a string (a key)
case jpmObjectMode:
return jtt == jttString
case jpmInvalidMode:
return false
}
}
_, ok := jpsValidTransitionTokens[ejp.s][jtt]
return ok
}
// ensureExtValueType returns true if the current value has the expected
// value type for single-key extended JSON types. For example,
// {"$numberInt": v} v must be TypeString
func (ejp *extJSONParser) ensureExtValueType(t bsontype.Type) bool {
switch t {
case bsontype.MinKey, bsontype.MaxKey:
return ejp.v.t == bsontype.Int32
case bsontype.Undefined:
return ejp.v.t == bsontype.Boolean
case bsontype.Int32, bsontype.Int64, bsontype.Double, bsontype.Decimal128, bsontype.Symbol, bsontype.ObjectID:
return ejp.v.t == bsontype.String
default:
return false
}
}
func (ejp *extJSONParser) pushMode(m jsonParseMode) {
ejp.m = append(ejp.m, m)
}
func (ejp *extJSONParser) popMode() jsonParseMode {
l := len(ejp.m)
if l == 0 {
return jpmInvalidMode
}
m := ejp.m[l-1]
ejp.m = ejp.m[:l-1]
return m
}
func (ejp *extJSONParser) peekMode() jsonParseMode {
l := len(ejp.m)
if l == 0 {
return jpmInvalidMode
}
return ejp.m[l-1]
}
func extendJSONToken(jt *jsonToken) *extJSONValue {
var t bsontype.Type
switch jt.t {
case jttInt32:
t = bsontype.Int32
case jttInt64:
t = bsontype.Int64
case jttDouble:
t = bsontype.Double
case jttString:
t = bsontype.String
case jttBool:
t = bsontype.Boolean
case jttNull:
t = bsontype.Null
default:
return nil
}
return &extJSONValue{t: t, v: jt.v}
}
func ensureColon(s jsonParseState, key string) error {
if s != jpsSawColon {
return fmt.Errorf("invalid JSON input: missing colon after key \"%s\"", key)
}
return nil
}
func invalidRequestError(s string) error {
return fmt.Errorf("invalid request to read %s", s)
}
func invalidJSONError(expected string) error {
return fmt.Errorf("invalid JSON input; expected %s", expected)
}
func invalidJSONErrorForType(expected string, t bsontype.Type) error {
return fmt.Errorf("invalid JSON input; expected %s for %s", expected, t)
}
func unexpectedTokenError(jt *jsonToken) error {
switch jt.t {
case jttInt32, jttInt64, jttDouble:
return fmt.Errorf("invalid JSON input; unexpected number (%v) at position %d", jt.v, jt.p)
case jttString:
return fmt.Errorf("invalid JSON input; unexpected string (\"%v\") at position %d", jt.v, jt.p)
case jttBool:
return fmt.Errorf("invalid JSON input; unexpected boolean literal (%v) at position %d", jt.v, jt.p)
case jttNull:
return fmt.Errorf("invalid JSON input; unexpected null literal at position %d", jt.p)
case jttEOF:
return fmt.Errorf("invalid JSON input; unexpected end of input at position %d", jt.p)
default:
return fmt.Errorf("invalid JSON input; unexpected %c at position %d", jt.v.(byte), jt.p)
}
}
func nestingDepthError(p, depth int) error {
return fmt.Errorf("invalid JSON input; nesting too deep (%d levels) at position %d", depth, p)
}
+659
View File
@@ -0,0 +1,659 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"fmt"
"io"
"sync"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// ExtJSONValueReaderPool is a pool for ValueReaders that read ExtJSON.
type ExtJSONValueReaderPool struct {
pool sync.Pool
}
// NewExtJSONValueReaderPool instantiates a new ExtJSONValueReaderPool.
func NewExtJSONValueReaderPool() *ExtJSONValueReaderPool {
return &ExtJSONValueReaderPool{
pool: sync.Pool{
New: func() interface{} {
return new(extJSONValueReader)
},
},
}
}
// Get retrieves a ValueReader from the pool and uses src as the underlying ExtJSON.
func (bvrp *ExtJSONValueReaderPool) Get(r io.Reader, canonical bool) (ValueReader, error) {
vr := bvrp.pool.Get().(*extJSONValueReader)
return vr.reset(r, canonical)
}
// Put inserts a ValueReader into the pool. If the ValueReader is not a ExtJSON ValueReader nothing
// is inserted into the pool and ok will be false.
func (bvrp *ExtJSONValueReaderPool) Put(vr ValueReader) (ok bool) {
bvr, ok := vr.(*extJSONValueReader)
if !ok {
return false
}
bvr, _ = bvr.reset(nil, false)
bvrp.pool.Put(bvr)
return true
}
type ejvrState struct {
mode mode
vType bsontype.Type
depth int
}
// extJSONValueReader is for reading extended JSON.
type extJSONValueReader struct {
p *extJSONParser
stack []ejvrState
frame int
}
// NewExtJSONValueReader creates a new ValueReader from a given io.Reader
// It will interpret the JSON of r as canonical or relaxed according to the
// given canonical flag
func NewExtJSONValueReader(r io.Reader, canonical bool) (ValueReader, error) {
return newExtJSONValueReader(r, canonical)
}
func newExtJSONValueReader(r io.Reader, canonical bool) (*extJSONValueReader, error) {
ejvr := new(extJSONValueReader)
return ejvr.reset(r, canonical)
}
func (ejvr *extJSONValueReader) reset(r io.Reader, canonical bool) (*extJSONValueReader, error) {
p := newExtJSONParser(r, canonical)
typ, err := p.peekType()
if err != nil {
return nil, ErrInvalidJSON
}
var m mode
switch typ {
case bsontype.EmbeddedDocument:
m = mTopLevel
case bsontype.Array:
m = mArray
default:
m = mValue
}
stack := make([]ejvrState, 1, 5)
stack[0] = ejvrState{
mode: m,
vType: typ,
}
return &extJSONValueReader{
p: p,
stack: stack,
}, nil
}
func (ejvr *extJSONValueReader) advanceFrame() {
if ejvr.frame+1 >= len(ejvr.stack) { // We need to grow the stack
length := len(ejvr.stack)
if length+1 >= cap(ejvr.stack) {
// double it
buf := make([]ejvrState, 2*cap(ejvr.stack)+1)
copy(buf, ejvr.stack)
ejvr.stack = buf
}
ejvr.stack = ejvr.stack[:length+1]
}
ejvr.frame++
// Clean the stack
ejvr.stack[ejvr.frame].mode = 0
ejvr.stack[ejvr.frame].vType = 0
ejvr.stack[ejvr.frame].depth = 0
}
func (ejvr *extJSONValueReader) pushDocument() {
ejvr.advanceFrame()
ejvr.stack[ejvr.frame].mode = mDocument
ejvr.stack[ejvr.frame].depth = ejvr.p.depth
}
func (ejvr *extJSONValueReader) pushCodeWithScope() {
ejvr.advanceFrame()
ejvr.stack[ejvr.frame].mode = mCodeWithScope
}
func (ejvr *extJSONValueReader) pushArray() {
ejvr.advanceFrame()
ejvr.stack[ejvr.frame].mode = mArray
}
func (ejvr *extJSONValueReader) push(m mode, t bsontype.Type) {
ejvr.advanceFrame()
ejvr.stack[ejvr.frame].mode = m
ejvr.stack[ejvr.frame].vType = t
}
func (ejvr *extJSONValueReader) pop() {
switch ejvr.stack[ejvr.frame].mode {
case mElement, mValue:
ejvr.frame--
case mDocument, mArray, mCodeWithScope:
ejvr.frame -= 2 // we pop twice to jump over the vrElement: vrDocument -> vrElement -> vrDocument/TopLevel/etc...
}
}
func (ejvr *extJSONValueReader) skipDocument() error {
// read entire document until ErrEOD (using readKey and readValue)
_, typ, err := ejvr.p.readKey()
for err == nil {
_, err = ejvr.p.readValue(typ)
if err != nil {
break
}
_, typ, err = ejvr.p.readKey()
}
return err
}
func (ejvr *extJSONValueReader) skipArray() error {
// read entire array until ErrEOA (using peekType)
_, err := ejvr.p.peekType()
for err == nil {
_, err = ejvr.p.peekType()
}
return err
}
func (ejvr *extJSONValueReader) invalidTransitionErr(destination mode, name string, modes []mode) error {
te := TransitionError{
name: name,
current: ejvr.stack[ejvr.frame].mode,
destination: destination,
modes: modes,
action: "read",
}
if ejvr.frame != 0 {
te.parent = ejvr.stack[ejvr.frame-1].mode
}
return te
}
func (ejvr *extJSONValueReader) typeError(t bsontype.Type) error {
return fmt.Errorf("positioned on %s, but attempted to read %s", ejvr.stack[ejvr.frame].vType, t)
}
func (ejvr *extJSONValueReader) ensureElementValue(t bsontype.Type, destination mode, callerName string, addModes ...mode) error {
switch ejvr.stack[ejvr.frame].mode {
case mElement, mValue:
if ejvr.stack[ejvr.frame].vType != t {
return ejvr.typeError(t)
}
default:
modes := []mode{mElement, mValue}
if addModes != nil {
modes = append(modes, addModes...)
}
return ejvr.invalidTransitionErr(destination, callerName, modes)
}
return nil
}
func (ejvr *extJSONValueReader) Type() bsontype.Type {
return ejvr.stack[ejvr.frame].vType
}
func (ejvr *extJSONValueReader) Skip() error {
switch ejvr.stack[ejvr.frame].mode {
case mElement, mValue:
default:
return ejvr.invalidTransitionErr(0, "Skip", []mode{mElement, mValue})
}
defer ejvr.pop()
t := ejvr.stack[ejvr.frame].vType
switch t {
case bsontype.Array:
// read entire array until ErrEOA
err := ejvr.skipArray()
if err != ErrEOA {
return err
}
case bsontype.EmbeddedDocument:
// read entire doc until ErrEOD
err := ejvr.skipDocument()
if err != ErrEOD {
return err
}
case bsontype.CodeWithScope:
// read the code portion and set up parser in document mode
_, err := ejvr.p.readValue(t)
if err != nil {
return err
}
// read until ErrEOD
err = ejvr.skipDocument()
if err != ErrEOD {
return err
}
default:
_, err := ejvr.p.readValue(t)
if err != nil {
return err
}
}
return nil
}
func (ejvr *extJSONValueReader) ReadArray() (ArrayReader, error) {
switch ejvr.stack[ejvr.frame].mode {
case mTopLevel: // allow reading array from top level
case mArray:
return ejvr, nil
default:
if err := ejvr.ensureElementValue(bsontype.Array, mArray, "ReadArray", mTopLevel, mArray); err != nil {
return nil, err
}
}
ejvr.pushArray()
return ejvr, nil
}
func (ejvr *extJSONValueReader) ReadBinary() (b []byte, btype byte, err error) {
if err := ejvr.ensureElementValue(bsontype.Binary, 0, "ReadBinary"); err != nil {
return nil, 0, err
}
v, err := ejvr.p.readValue(bsontype.Binary)
if err != nil {
return nil, 0, err
}
b, btype, err = v.parseBinary()
ejvr.pop()
return b, btype, err
}
func (ejvr *extJSONValueReader) ReadBoolean() (bool, error) {
if err := ejvr.ensureElementValue(bsontype.Boolean, 0, "ReadBoolean"); err != nil {
return false, err
}
v, err := ejvr.p.readValue(bsontype.Boolean)
if err != nil {
return false, err
}
if v.t != bsontype.Boolean {
return false, fmt.Errorf("expected type bool, but got type %s", v.t)
}
ejvr.pop()
return v.v.(bool), nil
}
func (ejvr *extJSONValueReader) ReadDocument() (DocumentReader, error) {
switch ejvr.stack[ejvr.frame].mode {
case mTopLevel:
return ejvr, nil
case mElement, mValue:
if ejvr.stack[ejvr.frame].vType != bsontype.EmbeddedDocument {
return nil, ejvr.typeError(bsontype.EmbeddedDocument)
}
ejvr.pushDocument()
return ejvr, nil
default:
return nil, ejvr.invalidTransitionErr(mDocument, "ReadDocument", []mode{mTopLevel, mElement, mValue})
}
}
func (ejvr *extJSONValueReader) ReadCodeWithScope() (code string, dr DocumentReader, err error) {
if err = ejvr.ensureElementValue(bsontype.CodeWithScope, 0, "ReadCodeWithScope"); err != nil {
return "", nil, err
}
v, err := ejvr.p.readValue(bsontype.CodeWithScope)
if err != nil {
return "", nil, err
}
code, err = v.parseJavascript()
ejvr.pushCodeWithScope()
return code, ejvr, err
}
func (ejvr *extJSONValueReader) ReadDBPointer() (ns string, oid primitive.ObjectID, err error) {
if err = ejvr.ensureElementValue(bsontype.DBPointer, 0, "ReadDBPointer"); err != nil {
return "", primitive.NilObjectID, err
}
v, err := ejvr.p.readValue(bsontype.DBPointer)
if err != nil {
return "", primitive.NilObjectID, err
}
ns, oid, err = v.parseDBPointer()
ejvr.pop()
return ns, oid, err
}
func (ejvr *extJSONValueReader) ReadDateTime() (int64, error) {
if err := ejvr.ensureElementValue(bsontype.DateTime, 0, "ReadDateTime"); err != nil {
return 0, err
}
v, err := ejvr.p.readValue(bsontype.DateTime)
if err != nil {
return 0, err
}
d, err := v.parseDateTime()
ejvr.pop()
return d, err
}
func (ejvr *extJSONValueReader) ReadDecimal128() (primitive.Decimal128, error) {
if err := ejvr.ensureElementValue(bsontype.Decimal128, 0, "ReadDecimal128"); err != nil {
return primitive.Decimal128{}, err
}
v, err := ejvr.p.readValue(bsontype.Decimal128)
if err != nil {
return primitive.Decimal128{}, err
}
d, err := v.parseDecimal128()
ejvr.pop()
return d, err
}
func (ejvr *extJSONValueReader) ReadDouble() (float64, error) {
if err := ejvr.ensureElementValue(bsontype.Double, 0, "ReadDouble"); err != nil {
return 0, err
}
v, err := ejvr.p.readValue(bsontype.Double)
if err != nil {
return 0, err
}
d, err := v.parseDouble()
ejvr.pop()
return d, err
}
func (ejvr *extJSONValueReader) ReadInt32() (int32, error) {
if err := ejvr.ensureElementValue(bsontype.Int32, 0, "ReadInt32"); err != nil {
return 0, err
}
v, err := ejvr.p.readValue(bsontype.Int32)
if err != nil {
return 0, err
}
i, err := v.parseInt32()
ejvr.pop()
return i, err
}
func (ejvr *extJSONValueReader) ReadInt64() (int64, error) {
if err := ejvr.ensureElementValue(bsontype.Int64, 0, "ReadInt64"); err != nil {
return 0, err
}
v, err := ejvr.p.readValue(bsontype.Int64)
if err != nil {
return 0, err
}
i, err := v.parseInt64()
ejvr.pop()
return i, err
}
func (ejvr *extJSONValueReader) ReadJavascript() (code string, err error) {
if err = ejvr.ensureElementValue(bsontype.JavaScript, 0, "ReadJavascript"); err != nil {
return "", err
}
v, err := ejvr.p.readValue(bsontype.JavaScript)
if err != nil {
return "", err
}
code, err = v.parseJavascript()
ejvr.pop()
return code, err
}
func (ejvr *extJSONValueReader) ReadMaxKey() error {
if err := ejvr.ensureElementValue(bsontype.MaxKey, 0, "ReadMaxKey"); err != nil {
return err
}
v, err := ejvr.p.readValue(bsontype.MaxKey)
if err != nil {
return err
}
err = v.parseMinMaxKey("max")
ejvr.pop()
return err
}
func (ejvr *extJSONValueReader) ReadMinKey() error {
if err := ejvr.ensureElementValue(bsontype.MinKey, 0, "ReadMinKey"); err != nil {
return err
}
v, err := ejvr.p.readValue(bsontype.MinKey)
if err != nil {
return err
}
err = v.parseMinMaxKey("min")
ejvr.pop()
return err
}
func (ejvr *extJSONValueReader) ReadNull() error {
if err := ejvr.ensureElementValue(bsontype.Null, 0, "ReadNull"); err != nil {
return err
}
v, err := ejvr.p.readValue(bsontype.Null)
if err != nil {
return err
}
if v.t != bsontype.Null {
return fmt.Errorf("expected type null but got type %s", v.t)
}
ejvr.pop()
return nil
}
func (ejvr *extJSONValueReader) ReadObjectID() (primitive.ObjectID, error) {
if err := ejvr.ensureElementValue(bsontype.ObjectID, 0, "ReadObjectID"); err != nil {
return primitive.ObjectID{}, err
}
v, err := ejvr.p.readValue(bsontype.ObjectID)
if err != nil {
return primitive.ObjectID{}, err
}
oid, err := v.parseObjectID()
ejvr.pop()
return oid, err
}
func (ejvr *extJSONValueReader) ReadRegex() (pattern string, options string, err error) {
if err = ejvr.ensureElementValue(bsontype.Regex, 0, "ReadRegex"); err != nil {
return "", "", err
}
v, err := ejvr.p.readValue(bsontype.Regex)
if err != nil {
return "", "", err
}
pattern, options, err = v.parseRegex()
ejvr.pop()
return pattern, options, err
}
func (ejvr *extJSONValueReader) ReadString() (string, error) {
if err := ejvr.ensureElementValue(bsontype.String, 0, "ReadString"); err != nil {
return "", err
}
v, err := ejvr.p.readValue(bsontype.String)
if err != nil {
return "", err
}
if v.t != bsontype.String {
return "", fmt.Errorf("expected type string but got type %s", v.t)
}
ejvr.pop()
return v.v.(string), nil
}
func (ejvr *extJSONValueReader) ReadSymbol() (symbol string, err error) {
if err = ejvr.ensureElementValue(bsontype.Symbol, 0, "ReadSymbol"); err != nil {
return "", err
}
v, err := ejvr.p.readValue(bsontype.Symbol)
if err != nil {
return "", err
}
symbol, err = v.parseSymbol()
ejvr.pop()
return symbol, err
}
func (ejvr *extJSONValueReader) ReadTimestamp() (t uint32, i uint32, err error) {
if err = ejvr.ensureElementValue(bsontype.Timestamp, 0, "ReadTimestamp"); err != nil {
return 0, 0, err
}
v, err := ejvr.p.readValue(bsontype.Timestamp)
if err != nil {
return 0, 0, err
}
t, i, err = v.parseTimestamp()
ejvr.pop()
return t, i, err
}
func (ejvr *extJSONValueReader) ReadUndefined() error {
if err := ejvr.ensureElementValue(bsontype.Undefined, 0, "ReadUndefined"); err != nil {
return err
}
v, err := ejvr.p.readValue(bsontype.Undefined)
if err != nil {
return err
}
err = v.parseUndefined()
ejvr.pop()
return err
}
func (ejvr *extJSONValueReader) ReadElement() (string, ValueReader, error) {
switch ejvr.stack[ejvr.frame].mode {
case mTopLevel, mDocument, mCodeWithScope:
default:
return "", nil, ejvr.invalidTransitionErr(mElement, "ReadElement", []mode{mTopLevel, mDocument, mCodeWithScope})
}
name, t, err := ejvr.p.readKey()
if err != nil {
if err == ErrEOD {
if ejvr.stack[ejvr.frame].mode == mCodeWithScope {
_, err := ejvr.p.peekType()
if err != nil {
return "", nil, err
}
}
ejvr.pop()
}
return "", nil, err
}
ejvr.push(mElement, t)
return name, ejvr, nil
}
func (ejvr *extJSONValueReader) ReadValue() (ValueReader, error) {
switch ejvr.stack[ejvr.frame].mode {
case mArray:
default:
return nil, ejvr.invalidTransitionErr(mValue, "ReadValue", []mode{mArray})
}
t, err := ejvr.p.peekType()
if err != nil {
if err == ErrEOA {
ejvr.pop()
}
return nil, err
}
ejvr.push(mValue, t)
return ejvr, nil
}
+223
View File
@@ -0,0 +1,223 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/golang/go by The Go Authors
// See THIRD-PARTY-NOTICES for original license terms.
package bsonrw
import "unicode/utf8"
// safeSet holds the value true if the ASCII character with the given array
// position can be represented inside a JSON string without any further
// escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), and the backslash character ("\").
var safeSet = [utf8.RuneSelf]bool{
' ': true,
'!': true,
'"': false,
'#': true,
'$': true,
'%': true,
'&': true,
'\'': true,
'(': true,
')': true,
'*': true,
'+': true,
',': true,
'-': true,
'.': true,
'/': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
':': true,
';': true,
'<': true,
'=': true,
'>': true,
'?': true,
'@': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'V': true,
'W': true,
'X': true,
'Y': true,
'Z': true,
'[': true,
'\\': false,
']': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'{': true,
'|': true,
'}': true,
'~': true,
'\u007f': true,
}
// htmlSafeSet holds the value true if the ASCII character with the given
// array position can be safely represented inside a JSON string, embedded
// inside of HTML <script> tags, without any additional escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), the backslash character ("\"), HTML opening and closing
// tags ("<" and ">"), and the ampersand ("&").
var htmlSafeSet = [utf8.RuneSelf]bool{
' ': true,
'!': true,
'"': false,
'#': true,
'$': true,
'%': true,
'&': false,
'\'': true,
'(': true,
')': true,
'*': true,
'+': true,
',': true,
'-': true,
'.': true,
'/': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
':': true,
';': true,
'<': false,
'=': true,
'>': false,
'?': true,
'@': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'V': true,
'W': true,
'X': true,
'Y': true,
'Z': true,
'[': true,
'\\': false,
']': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'{': true,
'|': true,
'}': true,
'~': true,
'\u007f': true,
}
+481
View File
@@ -0,0 +1,481 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"encoding/base64"
"errors"
"fmt"
"math"
"strconv"
"time"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func wrapperKeyBSONType(key string) bsontype.Type {
switch string(key) {
case "$numberInt":
return bsontype.Int32
case "$numberLong":
return bsontype.Int64
case "$oid":
return bsontype.ObjectID
case "$symbol":
return bsontype.Symbol
case "$numberDouble":
return bsontype.Double
case "$numberDecimal":
return bsontype.Decimal128
case "$binary":
return bsontype.Binary
case "$code":
return bsontype.JavaScript
case "$scope":
return bsontype.CodeWithScope
case "$timestamp":
return bsontype.Timestamp
case "$regularExpression":
return bsontype.Regex
case "$dbPointer":
return bsontype.DBPointer
case "$date":
return bsontype.DateTime
case "$ref":
fallthrough
case "$id":
fallthrough
case "$db":
return bsontype.EmbeddedDocument // dbrefs aren't bson types
case "$minKey":
return bsontype.MinKey
case "$maxKey":
return bsontype.MaxKey
case "$undefined":
return bsontype.Undefined
}
return bsontype.EmbeddedDocument
}
func (ejv *extJSONValue) parseBinary() (b []byte, subType byte, err error) {
if ejv.t != bsontype.EmbeddedDocument {
return nil, 0, fmt.Errorf("$binary value should be object, but instead is %s", ejv.t)
}
binObj := ejv.v.(*extJSONObject)
bFound := false
stFound := false
for i, key := range binObj.keys {
val := binObj.values[i]
switch key {
case "base64":
if bFound {
return nil, 0, errors.New("duplicate base64 key in $binary")
}
if val.t != bsontype.String {
return nil, 0, fmt.Errorf("$binary base64 value should be string, but instead is %s", val.t)
}
base64Bytes, err := base64.StdEncoding.DecodeString(val.v.(string))
if err != nil {
return nil, 0, fmt.Errorf("invalid $binary base64 string: %s", val.v.(string))
}
b = base64Bytes
bFound = true
case "subType":
if stFound {
return nil, 0, errors.New("duplicate subType key in $binary")
}
if val.t != bsontype.String {
return nil, 0, fmt.Errorf("$binary subType value should be string, but instead is %s", val.t)
}
i, err := strconv.ParseInt(val.v.(string), 16, 64)
if err != nil {
return nil, 0, fmt.Errorf("invalid $binary subType string: %s", val.v.(string))
}
subType = byte(i)
stFound = true
default:
return nil, 0, fmt.Errorf("invalid key in $binary object: %s", key)
}
}
if !bFound {
return nil, 0, errors.New("missing base64 field in $binary object")
}
if !stFound {
return nil, 0, errors.New("missing subType field in $binary object")
}
return b, subType, nil
}
func (ejv *extJSONValue) parseDBPointer() (ns string, oid primitive.ObjectID, err error) {
if ejv.t != bsontype.EmbeddedDocument {
return "", primitive.NilObjectID, fmt.Errorf("$dbPointer value should be object, but instead is %s", ejv.t)
}
dbpObj := ejv.v.(*extJSONObject)
oidFound := false
nsFound := false
for i, key := range dbpObj.keys {
val := dbpObj.values[i]
switch key {
case "$ref":
if nsFound {
return "", primitive.NilObjectID, errors.New("duplicate $ref key in $dbPointer")
}
if val.t != bsontype.String {
return "", primitive.NilObjectID, fmt.Errorf("$dbPointer $ref value should be string, but instead is %s", val.t)
}
ns = val.v.(string)
nsFound = true
case "$id":
if oidFound {
return "", primitive.NilObjectID, errors.New("duplicate $id key in $dbPointer")
}
if val.t != bsontype.String {
return "", primitive.NilObjectID, fmt.Errorf("$dbPointer $id value should be string, but instead is %s", val.t)
}
oid, err = primitive.ObjectIDFromHex(val.v.(string))
if err != nil {
return "", primitive.NilObjectID, err
}
oidFound = true
default:
return "", primitive.NilObjectID, fmt.Errorf("invalid key in $dbPointer object: %s", key)
}
}
if !nsFound {
return "", oid, errors.New("missing $ref field in $dbPointer object")
}
if !oidFound {
return "", oid, errors.New("missing $id field in $dbPointer object")
}
return ns, oid, nil
}
const rfc3339Milli = "2006-01-02T15:04:05.999Z07:00"
func (ejv *extJSONValue) parseDateTime() (int64, error) {
switch ejv.t {
case bsontype.Int32:
return int64(ejv.v.(int32)), nil
case bsontype.Int64:
return ejv.v.(int64), nil
case bsontype.String:
return parseDatetimeString(ejv.v.(string))
case bsontype.EmbeddedDocument:
return parseDatetimeObject(ejv.v.(*extJSONObject))
default:
return 0, fmt.Errorf("$date value should be string or object, but instead is %s", ejv.t)
}
}
func parseDatetimeString(data string) (int64, error) {
t, err := time.Parse(rfc3339Milli, data)
if err != nil {
return 0, fmt.Errorf("invalid $date value string: %s", data)
}
return t.UnixNano() / 1e6, nil
}
func parseDatetimeObject(data *extJSONObject) (d int64, err error) {
dFound := false
for i, key := range data.keys {
val := data.values[i]
switch key {
case "$numberLong":
if dFound {
return 0, errors.New("duplicate $numberLong key in $date")
}
if val.t != bsontype.String {
return 0, fmt.Errorf("$date $numberLong field should be string, but instead is %s", val.t)
}
d, err = val.parseInt64()
if err != nil {
return 0, err
}
dFound = true
default:
return 0, fmt.Errorf("invalid key in $date object: %s", key)
}
}
if !dFound {
return 0, errors.New("missing $numberLong field in $date object")
}
return d, nil
}
func (ejv *extJSONValue) parseDecimal128() (primitive.Decimal128, error) {
if ejv.t != bsontype.String {
return primitive.Decimal128{}, fmt.Errorf("$numberDecimal value should be string, but instead is %s", ejv.t)
}
d, err := primitive.ParseDecimal128(ejv.v.(string))
if err != nil {
return primitive.Decimal128{}, fmt.Errorf("$invalid $numberDecimal string: %s", ejv.v.(string))
}
return d, nil
}
func (ejv *extJSONValue) parseDouble() (float64, error) {
if ejv.t == bsontype.Double {
return ejv.v.(float64), nil
}
if ejv.t != bsontype.String {
return 0, fmt.Errorf("$numberDouble value should be string, but instead is %s", ejv.t)
}
switch string(ejv.v.(string)) {
case "Infinity":
return math.Inf(1), nil
case "-Infinity":
return math.Inf(-1), nil
case "NaN":
return math.NaN(), nil
}
f, err := strconv.ParseFloat(ejv.v.(string), 64)
if err != nil {
return 0, err
}
return f, nil
}
func (ejv *extJSONValue) parseInt32() (int32, error) {
if ejv.t == bsontype.Int32 {
return ejv.v.(int32), nil
}
if ejv.t != bsontype.String {
return 0, fmt.Errorf("$numberInt value should be string, but instead is %s", ejv.t)
}
i, err := strconv.ParseInt(ejv.v.(string), 10, 64)
if err != nil {
return 0, err
}
if i < math.MinInt32 || i > math.MaxInt32 {
return 0, fmt.Errorf("$numberInt value should be int32 but instead is int64: %d", i)
}
return int32(i), nil
}
func (ejv *extJSONValue) parseInt64() (int64, error) {
if ejv.t == bsontype.Int64 {
return ejv.v.(int64), nil
}
if ejv.t != bsontype.String {
return 0, fmt.Errorf("$numberLong value should be string, but instead is %s", ejv.t)
}
i, err := strconv.ParseInt(ejv.v.(string), 10, 64)
if err != nil {
return 0, err
}
return i, nil
}
func (ejv *extJSONValue) parseJavascript() (code string, err error) {
if ejv.t != bsontype.String {
return "", fmt.Errorf("$code value should be string, but instead is %s", ejv.t)
}
return ejv.v.(string), nil
}
func (ejv *extJSONValue) parseMinMaxKey(minmax string) error {
if ejv.t != bsontype.Int32 {
return fmt.Errorf("$%sKey value should be int32, but instead is %s", minmax, ejv.t)
}
if ejv.v.(int32) != 1 {
return fmt.Errorf("$%sKey value must be 1, but instead is %d", minmax, ejv.v.(int32))
}
return nil
}
func (ejv *extJSONValue) parseObjectID() (primitive.ObjectID, error) {
if ejv.t != bsontype.String {
return primitive.NilObjectID, fmt.Errorf("$oid value should be string, but instead is %s", ejv.t)
}
return primitive.ObjectIDFromHex(ejv.v.(string))
}
func (ejv *extJSONValue) parseRegex() (pattern, options string, err error) {
if ejv.t != bsontype.EmbeddedDocument {
return "", "", fmt.Errorf("$regularExpression value should be object, but instead is %s", ejv.t)
}
regexObj := ejv.v.(*extJSONObject)
patFound := false
optFound := false
for i, key := range regexObj.keys {
val := regexObj.values[i]
switch string(key) {
case "pattern":
if patFound {
return "", "", errors.New("duplicate pattern key in $regularExpression")
}
if val.t != bsontype.String {
return "", "", fmt.Errorf("$regularExpression pattern value should be string, but instead is %s", val.t)
}
pattern = val.v.(string)
patFound = true
case "options":
if optFound {
return "", "", errors.New("duplicate options key in $regularExpression")
}
if val.t != bsontype.String {
return "", "", fmt.Errorf("$regularExpression options value should be string, but instead is %s", val.t)
}
options = val.v.(string)
optFound = true
default:
return "", "", fmt.Errorf("invalid key in $regularExpression object: %s", key)
}
}
if !patFound {
return "", "", errors.New("missing pattern field in $regularExpression object")
}
if !optFound {
return "", "", errors.New("missing options field in $regularExpression object")
}
return pattern, options, nil
}
func (ejv *extJSONValue) parseSymbol() (string, error) {
if ejv.t != bsontype.String {
return "", fmt.Errorf("$symbol value should be string, but instead is %s", ejv.t)
}
return ejv.v.(string), nil
}
func (ejv *extJSONValue) parseTimestamp() (t, i uint32, err error) {
if ejv.t != bsontype.EmbeddedDocument {
return 0, 0, fmt.Errorf("$timestamp value should be object, but instead is %s", ejv.t)
}
handleKey := func(key string, val *extJSONValue, flag bool) (uint32, error) {
if flag {
return 0, fmt.Errorf("duplicate %s key in $timestamp", key)
}
switch val.t {
case bsontype.Int32:
if val.v.(int32) < 0 {
return 0, fmt.Errorf("$timestamp %s number should be uint32: %s", key, string(val.v.(int32)))
}
return uint32(val.v.(int32)), nil
case bsontype.Int64:
if val.v.(int64) < 0 || uint32(val.v.(int64)) > math.MaxUint32 {
return 0, fmt.Errorf("$timestamp %s number should be uint32: %s", key, string(val.v.(int32)))
}
return uint32(val.v.(int64)), nil
default:
return 0, fmt.Errorf("$timestamp %s value should be uint32, but instead is %s", key, val.t)
}
}
tsObj := ejv.v.(*extJSONObject)
tFound := false
iFound := false
for j, key := range tsObj.keys {
val := tsObj.values[j]
switch key {
case "t":
if t, err = handleKey(key, val, tFound); err != nil {
return 0, 0, err
}
tFound = true
case "i":
if i, err = handleKey(key, val, iFound); err != nil {
return 0, 0, err
}
iFound = true
default:
return 0, 0, fmt.Errorf("invalid key in $timestamp object: %s", key)
}
}
if !tFound {
return 0, 0, errors.New("missing t field in $timestamp object")
}
if !iFound {
return 0, 0, errors.New("missing i field in $timestamp object")
}
return t, i, nil
}
func (ejv *extJSONValue) parseUndefined() error {
if ejv.t != bsontype.Boolean {
return fmt.Errorf("undefined value should be boolean, but instead is %s", ejv.t)
}
if !ejv.v.(bool) {
return fmt.Errorf("$undefined balue boolean should be true, but instead is %v", ejv.v.(bool))
}
return nil
}
+734
View File
@@ -0,0 +1,734 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"bytes"
"encoding/base64"
"fmt"
"go.mongodb.org/mongo-driver/bson/primitive"
"io"
"math"
"sort"
"strconv"
"strings"
"sync"
"time"
"unicode/utf8"
)
var ejvwPool = sync.Pool{
New: func() interface{} {
return new(extJSONValueWriter)
},
}
// ExtJSONValueWriterPool is a pool for ExtJSON ValueWriters.
type ExtJSONValueWriterPool struct {
pool sync.Pool
}
// NewExtJSONValueWriterPool creates a new pool for ValueWriter instances that write to ExtJSON.
func NewExtJSONValueWriterPool() *ExtJSONValueWriterPool {
return &ExtJSONValueWriterPool{
pool: sync.Pool{
New: func() interface{} {
return new(extJSONValueWriter)
},
},
}
}
// Get retrieves a ExtJSON ValueWriter from the pool and resets it to use w as the destination.
func (bvwp *ExtJSONValueWriterPool) Get(w io.Writer, canonical, escapeHTML bool) ValueWriter {
vw := bvwp.pool.Get().(*extJSONValueWriter)
if writer, ok := w.(*SliceWriter); ok {
vw.reset(*writer, canonical, escapeHTML)
vw.w = writer
return vw
}
vw.buf = vw.buf[:0]
vw.w = w
return vw
}
// Put inserts a ValueWriter into the pool. If the ValueWriter is not a ExtJSON ValueWriter, nothing
// happens and ok will be false.
func (bvwp *ExtJSONValueWriterPool) Put(vw ValueWriter) (ok bool) {
bvw, ok := vw.(*extJSONValueWriter)
if !ok {
return false
}
if _, ok := bvw.w.(*SliceWriter); ok {
bvw.buf = nil
}
bvw.w = nil
bvwp.pool.Put(bvw)
return true
}
type ejvwState struct {
mode mode
}
type extJSONValueWriter struct {
w io.Writer
buf []byte
stack []ejvwState
frame int64
canonical bool
escapeHTML bool
}
// NewExtJSONValueWriter creates a ValueWriter that writes Extended JSON to w.
func NewExtJSONValueWriter(w io.Writer, canonical, escapeHTML bool) (ValueWriter, error) {
if w == nil {
return nil, errNilWriter
}
return newExtJSONWriter(w, canonical, escapeHTML), nil
}
func newExtJSONWriter(w io.Writer, canonical, escapeHTML bool) *extJSONValueWriter {
stack := make([]ejvwState, 1, 5)
stack[0] = ejvwState{mode: mTopLevel}
return &extJSONValueWriter{
w: w,
buf: []byte{},
stack: stack,
canonical: canonical,
escapeHTML: escapeHTML,
}
}
func newExtJSONWriterFromSlice(buf []byte, canonical, escapeHTML bool) *extJSONValueWriter {
stack := make([]ejvwState, 1, 5)
stack[0] = ejvwState{mode: mTopLevel}
return &extJSONValueWriter{
buf: buf,
stack: stack,
canonical: canonical,
escapeHTML: escapeHTML,
}
}
func (ejvw *extJSONValueWriter) reset(buf []byte, canonical, escapeHTML bool) {
if ejvw.stack == nil {
ejvw.stack = make([]ejvwState, 1, 5)
}
ejvw.stack = ejvw.stack[:1]
ejvw.stack[0] = ejvwState{mode: mTopLevel}
ejvw.canonical = canonical
ejvw.escapeHTML = escapeHTML
ejvw.frame = 0
ejvw.buf = buf
ejvw.w = nil
}
func (ejvw *extJSONValueWriter) advanceFrame() {
if ejvw.frame+1 >= int64(len(ejvw.stack)) { // We need to grow the stack
length := len(ejvw.stack)
if length+1 >= cap(ejvw.stack) {
// double it
buf := make([]ejvwState, 2*cap(ejvw.stack)+1)
copy(buf, ejvw.stack)
ejvw.stack = buf
}
ejvw.stack = ejvw.stack[:length+1]
}
ejvw.frame++
}
func (ejvw *extJSONValueWriter) push(m mode) {
ejvw.advanceFrame()
ejvw.stack[ejvw.frame].mode = m
}
func (ejvw *extJSONValueWriter) pop() {
switch ejvw.stack[ejvw.frame].mode {
case mElement, mValue:
ejvw.frame--
case mDocument, mArray, mCodeWithScope:
ejvw.frame -= 2 // we pop twice to jump over the mElement: mDocument -> mElement -> mDocument/mTopLevel/etc...
}
}
func (ejvw *extJSONValueWriter) invalidTransitionErr(destination mode, name string, modes []mode) error {
te := TransitionError{
name: name,
current: ejvw.stack[ejvw.frame].mode,
destination: destination,
modes: modes,
action: "write",
}
if ejvw.frame != 0 {
te.parent = ejvw.stack[ejvw.frame-1].mode
}
return te
}
func (ejvw *extJSONValueWriter) ensureElementValue(destination mode, callerName string, addmodes ...mode) error {
switch ejvw.stack[ejvw.frame].mode {
case mElement, mValue:
default:
modes := []mode{mElement, mValue}
if addmodes != nil {
modes = append(modes, addmodes...)
}
return ejvw.invalidTransitionErr(destination, callerName, modes)
}
return nil
}
func (ejvw *extJSONValueWriter) writeExtendedSingleValue(key string, value string, quotes bool) {
var s string
if quotes {
s = fmt.Sprintf(`{"$%s":"%s"}`, key, value)
} else {
s = fmt.Sprintf(`{"$%s":%s}`, key, value)
}
ejvw.buf = append(ejvw.buf, []byte(s)...)
}
func (ejvw *extJSONValueWriter) WriteArray() (ArrayWriter, error) {
if err := ejvw.ensureElementValue(mArray, "WriteArray"); err != nil {
return nil, err
}
ejvw.buf = append(ejvw.buf, '[')
ejvw.push(mArray)
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteBinary(b []byte) error {
return ejvw.WriteBinaryWithSubtype(b, 0x00)
}
func (ejvw *extJSONValueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error {
if err := ejvw.ensureElementValue(mode(0), "WriteBinaryWithSubtype"); err != nil {
return err
}
var buf bytes.Buffer
buf.WriteString(`{"$binary":{"base64":"`)
buf.WriteString(base64.StdEncoding.EncodeToString(b))
buf.WriteString(fmt.Sprintf(`","subType":"%02x"}},`, btype))
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteBoolean(b bool) error {
if err := ejvw.ensureElementValue(mode(0), "WriteBoolean"); err != nil {
return err
}
ejvw.buf = append(ejvw.buf, []byte(strconv.FormatBool(b))...)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteCodeWithScope(code string) (DocumentWriter, error) {
if err := ejvw.ensureElementValue(mCodeWithScope, "WriteCodeWithScope"); err != nil {
return nil, err
}
var buf bytes.Buffer
buf.WriteString(`{"$code":`)
writeStringWithEscapes(code, &buf, ejvw.escapeHTML)
buf.WriteString(`,"$scope":{`)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.push(mCodeWithScope)
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteDBPointer(ns string, oid primitive.ObjectID) error {
if err := ejvw.ensureElementValue(mode(0), "WriteDBPointer"); err != nil {
return err
}
var buf bytes.Buffer
buf.WriteString(`{"$dbPointer":{"$ref":"`)
buf.WriteString(ns)
buf.WriteString(`","$id":{"$oid":"`)
buf.WriteString(oid.Hex())
buf.WriteString(`"}}},`)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteDateTime(dt int64) error {
if err := ejvw.ensureElementValue(mode(0), "WriteDateTime"); err != nil {
return err
}
t := time.Unix(dt/1e3, dt%1e3*1e6).UTC()
if ejvw.canonical || t.Year() < 1970 || t.Year() > 9999 {
s := fmt.Sprintf(`{"$numberLong":"%d"}`, dt)
ejvw.writeExtendedSingleValue("date", s, false)
} else {
ejvw.writeExtendedSingleValue("date", t.Format(rfc3339Milli), true)
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteDecimal128(d primitive.Decimal128) error {
if err := ejvw.ensureElementValue(mode(0), "WriteDecimal128"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("numberDecimal", d.String(), true)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteDocument() (DocumentWriter, error) {
if ejvw.stack[ejvw.frame].mode == mTopLevel {
ejvw.buf = append(ejvw.buf, '{')
return ejvw, nil
}
if err := ejvw.ensureElementValue(mDocument, "WriteDocument", mTopLevel); err != nil {
return nil, err
}
ejvw.buf = append(ejvw.buf, '{')
ejvw.push(mDocument)
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteDouble(f float64) error {
if err := ejvw.ensureElementValue(mode(0), "WriteDouble"); err != nil {
return err
}
s := formatDouble(f)
if ejvw.canonical {
ejvw.writeExtendedSingleValue("numberDouble", s, true)
} else {
switch s {
case "Infinity":
fallthrough
case "-Infinity":
fallthrough
case "NaN":
s = fmt.Sprintf(`{"$numberDouble":"%s"}`, s)
}
ejvw.buf = append(ejvw.buf, []byte(s)...)
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteInt32(i int32) error {
if err := ejvw.ensureElementValue(mode(0), "WriteInt32"); err != nil {
return err
}
s := strconv.FormatInt(int64(i), 10)
if ejvw.canonical {
ejvw.writeExtendedSingleValue("numberInt", s, true)
} else {
ejvw.buf = append(ejvw.buf, []byte(s)...)
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteInt64(i int64) error {
if err := ejvw.ensureElementValue(mode(0), "WriteInt64"); err != nil {
return err
}
s := strconv.FormatInt(i, 10)
if ejvw.canonical {
ejvw.writeExtendedSingleValue("numberLong", s, true)
} else {
ejvw.buf = append(ejvw.buf, []byte(s)...)
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteJavascript(code string) error {
if err := ejvw.ensureElementValue(mode(0), "WriteJavascript"); err != nil {
return err
}
var buf bytes.Buffer
writeStringWithEscapes(code, &buf, ejvw.escapeHTML)
ejvw.writeExtendedSingleValue("code", buf.String(), false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteMaxKey() error {
if err := ejvw.ensureElementValue(mode(0), "WriteMaxKey"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("maxKey", "1", false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteMinKey() error {
if err := ejvw.ensureElementValue(mode(0), "WriteMinKey"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("minKey", "1", false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteNull() error {
if err := ejvw.ensureElementValue(mode(0), "WriteNull"); err != nil {
return err
}
ejvw.buf = append(ejvw.buf, []byte("null")...)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteObjectID(oid primitive.ObjectID) error {
if err := ejvw.ensureElementValue(mode(0), "WriteObjectID"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("oid", oid.Hex(), true)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteRegex(pattern string, options string) error {
if err := ejvw.ensureElementValue(mode(0), "WriteRegex"); err != nil {
return err
}
var buf bytes.Buffer
buf.WriteString(`{"$regularExpression":{"pattern":`)
writeStringWithEscapes(pattern, &buf, ejvw.escapeHTML)
buf.WriteString(`,"options":"`)
buf.WriteString(sortStringAlphebeticAscending(options))
buf.WriteString(`"}},`)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteString(s string) error {
if err := ejvw.ensureElementValue(mode(0), "WriteString"); err != nil {
return err
}
var buf bytes.Buffer
writeStringWithEscapes(s, &buf, ejvw.escapeHTML)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteSymbol(symbol string) error {
if err := ejvw.ensureElementValue(mode(0), "WriteSymbol"); err != nil {
return err
}
var buf bytes.Buffer
writeStringWithEscapes(symbol, &buf, ejvw.escapeHTML)
ejvw.writeExtendedSingleValue("symbol", buf.String(), false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteTimestamp(t uint32, i uint32) error {
if err := ejvw.ensureElementValue(mode(0), "WriteTimestamp"); err != nil {
return err
}
var buf bytes.Buffer
buf.WriteString(`{"$timestamp":{"t":`)
buf.WriteString(strconv.FormatUint(uint64(t), 10))
buf.WriteString(`,"i":`)
buf.WriteString(strconv.FormatUint(uint64(i), 10))
buf.WriteString(`}},`)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteUndefined() error {
if err := ejvw.ensureElementValue(mode(0), "WriteUndefined"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("undefined", "true", false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteDocumentElement(key string) (ValueWriter, error) {
switch ejvw.stack[ejvw.frame].mode {
case mDocument, mTopLevel, mCodeWithScope:
ejvw.buf = append(ejvw.buf, []byte(fmt.Sprintf(`"%s":`, key))...)
ejvw.push(mElement)
default:
return nil, ejvw.invalidTransitionErr(mElement, "WriteDocumentElement", []mode{mDocument, mTopLevel, mCodeWithScope})
}
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteDocumentEnd() error {
switch ejvw.stack[ejvw.frame].mode {
case mDocument, mTopLevel, mCodeWithScope:
default:
return fmt.Errorf("incorrect mode to end document: %s", ejvw.stack[ejvw.frame].mode)
}
// close the document
if ejvw.buf[len(ejvw.buf)-1] == ',' {
ejvw.buf[len(ejvw.buf)-1] = '}'
} else {
ejvw.buf = append(ejvw.buf, '}')
}
switch ejvw.stack[ejvw.frame].mode {
case mCodeWithScope:
ejvw.buf = append(ejvw.buf, '}')
fallthrough
case mDocument:
ejvw.buf = append(ejvw.buf, ',')
case mTopLevel:
if ejvw.w != nil {
if _, err := ejvw.w.Write(ejvw.buf); err != nil {
return err
}
ejvw.buf = ejvw.buf[:0]
}
}
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteArrayElement() (ValueWriter, error) {
switch ejvw.stack[ejvw.frame].mode {
case mArray:
ejvw.push(mValue)
default:
return nil, ejvw.invalidTransitionErr(mValue, "WriteArrayElement", []mode{mArray})
}
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteArrayEnd() error {
switch ejvw.stack[ejvw.frame].mode {
case mArray:
// close the array
if ejvw.buf[len(ejvw.buf)-1] == ',' {
ejvw.buf[len(ejvw.buf)-1] = ']'
} else {
ejvw.buf = append(ejvw.buf, ']')
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
default:
return fmt.Errorf("incorrect mode to end array: %s", ejvw.stack[ejvw.frame].mode)
}
return nil
}
func formatDouble(f float64) string {
var s string
if math.IsInf(f, 1) {
s = "Infinity"
} else if math.IsInf(f, -1) {
s = "-Infinity"
} else if math.IsNaN(f) {
s = "NaN"
} else {
// Print exactly one decimalType place for integers; otherwise, print as many are necessary to
// perfectly represent it.
s = strconv.FormatFloat(f, 'G', -1, 64)
if !strings.ContainsRune(s, 'E') && !strings.ContainsRune(s, '.') {
s += ".0"
}
}
return s
}
var hexChars = "0123456789abcdef"
func writeStringWithEscapes(s string, buf *bytes.Buffer, escapeHTML bool) {
buf.WriteByte('"')
start := 0
for i := 0; i < len(s); {
if b := s[i]; b < utf8.RuneSelf {
if htmlSafeSet[b] || (!escapeHTML && safeSet[b]) {
i++
continue
}
if start < i {
buf.WriteString(s[start:i])
}
switch b {
case '\\', '"':
buf.WriteByte('\\')
buf.WriteByte(b)
case '\n':
buf.WriteByte('\\')
buf.WriteByte('n')
case '\r':
buf.WriteByte('\\')
buf.WriteByte('r')
case '\t':
buf.WriteByte('\\')
buf.WriteByte('t')
case '\b':
buf.WriteByte('\\')
buf.WriteByte('b')
case '\f':
buf.WriteByte('\\')
buf.WriteByte('f')
default:
// This encodes bytes < 0x20 except for \t, \n and \r.
// If escapeHTML is set, it also escapes <, >, and &
// because they can lead to security holes when
// user-controlled strings are rendered into JSON
// and served to some browsers.
buf.WriteString(`\u00`)
buf.WriteByte(hexChars[b>>4])
buf.WriteByte(hexChars[b&0xF])
}
i++
start = i
continue
}
c, size := utf8.DecodeRuneInString(s[i:])
if c == utf8.RuneError && size == 1 {
if start < i {
buf.WriteString(s[start:i])
}
buf.WriteString(`\ufffd`)
i += size
start = i
continue
}
// U+2028 is LINE SEPARATOR.
// U+2029 is PARAGRAPH SEPARATOR.
// They are both technically valid characters in JSON strings,
// but don't work in JSONP, which has to be evaluated as JavaScript,
// and can lead to security holes there. It is valid JSON to
// escape them, so we do so unconditionally.
// See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion.
if c == '\u2028' || c == '\u2029' {
if start < i {
buf.WriteString(s[start:i])
}
buf.WriteString(`\u202`)
buf.WriteByte(hexChars[c&0xF])
i += size
start = i
continue
}
i += size
}
if start < len(s) {
buf.WriteString(s[start:])
}
buf.WriteByte('"')
}
type sortableString []rune
func (ss sortableString) Len() int {
return len(ss)
}
func (ss sortableString) Less(i, j int) bool {
return ss[i] < ss[j]
}
func (ss sortableString) Swap(i, j int) {
oldI := ss[i]
ss[i] = ss[j]
ss[j] = oldI
}
func sortStringAlphebeticAscending(s string) string {
ss := sortableString([]rune(s))
sort.Sort(ss)
return string([]rune(ss))
}
+439
View File
@@ -0,0 +1,439 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"bytes"
"errors"
"fmt"
"io"
"math"
"strconv"
"strings"
"unicode"
)
type jsonTokenType byte
const (
jttBeginObject jsonTokenType = iota
jttEndObject
jttBeginArray
jttEndArray
jttColon
jttComma
jttInt32
jttInt64
jttDouble
jttString
jttBool
jttNull
jttEOF
)
type jsonToken struct {
t jsonTokenType
v interface{}
p int
}
type jsonScanner struct {
r io.Reader
buf []byte
pos int
lastReadErr error
}
// nextToken returns the next JSON token if one exists. A token is a character
// of the JSON grammar, a number, a string, or a literal.
func (js *jsonScanner) nextToken() (*jsonToken, error) {
c, err := js.readNextByte()
// keep reading until a non-space is encountered (break on read error or EOF)
for isWhiteSpace(c) && err == nil {
c, err = js.readNextByte()
}
if err == io.EOF {
return &jsonToken{t: jttEOF}, nil
} else if err != nil {
return nil, err
}
// switch on the character
switch c {
case '{':
return &jsonToken{t: jttBeginObject, v: byte('{'), p: js.pos - 1}, nil
case '}':
return &jsonToken{t: jttEndObject, v: byte('}'), p: js.pos - 1}, nil
case '[':
return &jsonToken{t: jttBeginArray, v: byte('['), p: js.pos - 1}, nil
case ']':
return &jsonToken{t: jttEndArray, v: byte(']'), p: js.pos - 1}, nil
case ':':
return &jsonToken{t: jttColon, v: byte(':'), p: js.pos - 1}, nil
case ',':
return &jsonToken{t: jttComma, v: byte(','), p: js.pos - 1}, nil
case '"': // RFC-8259 only allows for double quotes (") not single (')
return js.scanString()
default:
// check if it's a number
if c == '-' || isDigit(c) {
return js.scanNumber(c)
} else if c == 't' || c == 'f' || c == 'n' {
// maybe a literal
return js.scanLiteral(c)
} else {
return nil, fmt.Errorf("invalid JSON input. Position: %d. Character: %c", js.pos-1, c)
}
}
}
// readNextByte attempts to read the next byte from the buffer. If the buffer
// has been exhausted, this function calls readIntoBuf, thus refilling the
// buffer and resetting the read position to 0
func (js *jsonScanner) readNextByte() (byte, error) {
if js.pos >= len(js.buf) {
err := js.readIntoBuf()
if err != nil {
return 0, err
}
}
b := js.buf[js.pos]
js.pos++
return b, nil
}
// readNNextBytes reads n bytes into dst, starting at offset
func (js *jsonScanner) readNNextBytes(dst []byte, n, offset int) error {
var err error
for i := 0; i < n; i++ {
dst[i+offset], err = js.readNextByte()
if err != nil {
return err
}
}
return nil
}
// readIntoBuf reads up to 512 bytes from the scanner's io.Reader into the buffer
func (js *jsonScanner) readIntoBuf() error {
if js.lastReadErr != nil {
js.buf = js.buf[:0]
js.pos = 0
return js.lastReadErr
}
if cap(js.buf) == 0 {
js.buf = make([]byte, 0, 512)
}
n, err := js.r.Read(js.buf[:cap(js.buf)])
if err != nil {
js.lastReadErr = err
if n > 0 {
err = nil
}
}
js.buf = js.buf[:n]
js.pos = 0
return err
}
func isWhiteSpace(c byte) bool {
return c == ' ' || c == '\t' || c == '\r' || c == '\n'
}
func isDigit(c byte) bool {
return unicode.IsDigit(rune(c))
}
func isValueTerminator(c byte) bool {
return c == ',' || c == '}' || c == ']' || isWhiteSpace(c)
}
// scanString reads from an opening '"' to a closing '"' and handles escaped characters
func (js *jsonScanner) scanString() (*jsonToken, error) {
var b bytes.Buffer
var c byte
var err error
p := js.pos - 1
for {
c, err = js.readNextByte()
if err != nil {
if err == io.EOF {
return nil, errors.New("end of input in JSON string")
}
return nil, err
}
switch c {
case '\\':
c, err = js.readNextByte()
switch c {
case '"', '\\', '/', '\'':
b.WriteByte(c)
case 'b':
b.WriteByte('\b')
case 'f':
b.WriteByte('\f')
case 'n':
b.WriteByte('\n')
case 'r':
b.WriteByte('\r')
case 't':
b.WriteByte('\t')
case 'u':
us := make([]byte, 4)
err = js.readNNextBytes(us, 4, 0)
if err != nil {
return nil, fmt.Errorf("invalid unicode sequence in JSON string: %s", us)
}
s := fmt.Sprintf(`\u%s`, us)
s, err = strconv.Unquote(strings.Replace(strconv.Quote(s), `\\u`, `\u`, 1))
if err != nil {
return nil, err
}
b.WriteString(s)
default:
return nil, fmt.Errorf("invalid escape sequence in JSON string '\\%c'", c)
}
case '"':
return &jsonToken{t: jttString, v: b.String(), p: p}, nil
default:
b.WriteByte(c)
}
}
}
// scanLiteral reads an unquoted sequence of characters and determines if it is one of
// three valid JSON literals (true, false, null); if so, it returns the appropriate
// jsonToken; otherwise, it returns an error
func (js *jsonScanner) scanLiteral(first byte) (*jsonToken, error) {
p := js.pos - 1
lit := make([]byte, 4)
lit[0] = first
err := js.readNNextBytes(lit, 3, 1)
if err != nil {
return nil, err
}
c5, err := js.readNextByte()
if bytes.Equal([]byte("true"), lit) && (isValueTerminator(c5) || err == io.EOF) {
js.pos = int(math.Max(0, float64(js.pos-1)))
return &jsonToken{t: jttBool, v: true, p: p}, nil
} else if bytes.Equal([]byte("null"), lit) && (isValueTerminator(c5) || err == io.EOF) {
js.pos = int(math.Max(0, float64(js.pos-1)))
return &jsonToken{t: jttNull, v: nil, p: p}, nil
} else if bytes.Equal([]byte("fals"), lit) {
if c5 == 'e' {
c5, err = js.readNextByte()
if isValueTerminator(c5) || err == io.EOF {
js.pos = int(math.Max(0, float64(js.pos-1)))
return &jsonToken{t: jttBool, v: false, p: p}, nil
}
}
}
return nil, fmt.Errorf("invalid JSON literal. Position: %d, literal: %s", p, lit)
}
type numberScanState byte
const (
nssSawLeadingMinus numberScanState = iota
nssSawLeadingZero
nssSawIntegerDigits
nssSawDecimalPoint
nssSawFractionDigits
nssSawExponentLetter
nssSawExponentSign
nssSawExponentDigits
nssDone
nssInvalid
)
// scanNumber reads a JSON number (according to RFC-8259)
func (js *jsonScanner) scanNumber(first byte) (*jsonToken, error) {
var b bytes.Buffer
var s numberScanState
var c byte
var err error
t := jttInt64 // assume it's an int64 until the type can be determined
start := js.pos - 1
b.WriteByte(first)
switch first {
case '-':
s = nssSawLeadingMinus
case '0':
s = nssSawLeadingZero
default:
s = nssSawIntegerDigits
}
for {
c, err = js.readNextByte()
if err != nil && err != io.EOF {
return nil, err
}
switch s {
case nssSawLeadingMinus:
switch c {
case '0':
s = nssSawLeadingZero
b.WriteByte(c)
default:
if isDigit(c) {
s = nssSawIntegerDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
}
case nssSawLeadingZero:
switch c {
case '.':
s = nssSawDecimalPoint
b.WriteByte(c)
case 'e', 'E':
s = nssSawExponentLetter
b.WriteByte(c)
case '}', ']', ',':
s = nssDone
default:
if isWhiteSpace(c) || err == io.EOF {
s = nssDone
} else {
s = nssInvalid
}
}
case nssSawIntegerDigits:
switch c {
case '.':
s = nssSawDecimalPoint
b.WriteByte(c)
case 'e', 'E':
s = nssSawExponentLetter
b.WriteByte(c)
case '}', ']', ',':
s = nssDone
default:
if isWhiteSpace(c) || err == io.EOF {
s = nssDone
} else if isDigit(c) {
s = nssSawIntegerDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
}
case nssSawDecimalPoint:
t = jttDouble
if isDigit(c) {
s = nssSawFractionDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
case nssSawFractionDigits:
switch c {
case 'e', 'E':
s = nssSawExponentLetter
b.WriteByte(c)
case '}', ']', ',':
s = nssDone
default:
if isWhiteSpace(c) || err == io.EOF {
s = nssDone
} else if isDigit(c) {
s = nssSawFractionDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
}
case nssSawExponentLetter:
t = jttDouble
switch c {
case '+', '-':
s = nssSawExponentSign
b.WriteByte(c)
default:
if isDigit(c) {
s = nssSawExponentDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
}
case nssSawExponentSign:
if isDigit(c) {
s = nssSawExponentDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
case nssSawExponentDigits:
switch c {
case '}', ']', ',':
s = nssDone
default:
if isWhiteSpace(c) || err == io.EOF {
s = nssDone
} else if isDigit(c) {
s = nssSawExponentDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
}
}
switch s {
case nssInvalid:
return nil, fmt.Errorf("invalid JSON number. Position: %d", start)
case nssDone:
js.pos = int(math.Max(0, float64(js.pos-1)))
if t != jttDouble {
v, err := strconv.ParseInt(b.String(), 10, 64)
if err == nil {
if v < math.MinInt32 || v > math.MaxInt32 {
return &jsonToken{t: jttInt64, v: v, p: start}, nil
}
return &jsonToken{t: jttInt32, v: int32(v), p: start}, nil
}
}
v, err := strconv.ParseFloat(b.String(), 64)
if err != nil {
return nil, err
}
return &jsonToken{t: jttDouble, v: v, p: start}, nil
}
}
}
+108
View File
@@ -0,0 +1,108 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"fmt"
)
type mode int
const (
_ mode = iota
mTopLevel
mDocument
mArray
mValue
mElement
mCodeWithScope
mSpacer
)
func (m mode) String() string {
var str string
switch m {
case mTopLevel:
str = "TopLevel"
case mDocument:
str = "DocumentMode"
case mArray:
str = "ArrayMode"
case mValue:
str = "ValueMode"
case mElement:
str = "ElementMode"
case mCodeWithScope:
str = "CodeWithScopeMode"
case mSpacer:
str = "CodeWithScopeSpacerFrame"
default:
str = "UnknownMode"
}
return str
}
func (m mode) TypeString() string {
var str string
switch m {
case mTopLevel:
str = "TopLevel"
case mDocument:
str = "Document"
case mArray:
str = "Array"
case mValue:
str = "Value"
case mElement:
str = "Element"
case mCodeWithScope:
str = "CodeWithScope"
case mSpacer:
str = "CodeWithScopeSpacer"
default:
str = "Unknown"
}
return str
}
// TransitionError is an error returned when an invalid progressing a
// ValueReader or ValueWriter state machine occurs.
// If read is false, the error is for writing
type TransitionError struct {
name string
parent mode
current mode
destination mode
modes []mode
action string
}
func (te TransitionError) Error() string {
errString := fmt.Sprintf("%s can only %s", te.name, te.action)
if te.destination != mode(0) {
errString = fmt.Sprintf("%s a %s", errString, te.destination.TypeString())
}
errString = fmt.Sprintf("%s while positioned on a", errString)
for ind, m := range te.modes {
if ind != 0 && len(te.modes) > 2 {
errString = fmt.Sprintf("%s,", errString)
}
if ind == len(te.modes)-1 && len(te.modes) > 1 {
errString = fmt.Sprintf("%s or", errString)
}
errString = fmt.Sprintf("%s %s", errString, m.TypeString())
}
errString = fmt.Sprintf("%s but is positioned on a %s", errString, te.current.TypeString())
if te.parent != mode(0) {
errString = fmt.Sprintf("%s with parent %s", errString, te.parent.TypeString())
}
return errString
}
+63
View File
@@ -0,0 +1,63 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// ArrayReader is implemented by types that allow reading values from a BSON
// array.
type ArrayReader interface {
ReadValue() (ValueReader, error)
}
// DocumentReader is implemented by types that allow reading elements from a
// BSON document.
type DocumentReader interface {
ReadElement() (string, ValueReader, error)
}
// ValueReader is a generic interface used to read values from BSON. This type
// is implemented by several types with different underlying representations of
// BSON, such as a bson.Document, raw BSON bytes, or extended JSON.
type ValueReader interface {
Type() bsontype.Type
Skip() error
ReadArray() (ArrayReader, error)
ReadBinary() (b []byte, btype byte, err error)
ReadBoolean() (bool, error)
ReadDocument() (DocumentReader, error)
ReadCodeWithScope() (code string, dr DocumentReader, err error)
ReadDBPointer() (ns string, oid primitive.ObjectID, err error)
ReadDateTime() (int64, error)
ReadDecimal128() (primitive.Decimal128, error)
ReadDouble() (float64, error)
ReadInt32() (int32, error)
ReadInt64() (int64, error)
ReadJavascript() (code string, err error)
ReadMaxKey() error
ReadMinKey() error
ReadNull() error
ReadObjectID() (primitive.ObjectID, error)
ReadRegex() (pattern, options string, err error)
ReadString() (string, error)
ReadSymbol() (symbol string, err error)
ReadTimestamp() (t, i uint32, err error)
ReadUndefined() error
}
// BytesReader is a generic interface used to read BSON bytes from a
// ValueReader. This imterface is meant to be a superset of ValueReader, so that
// types that implement ValueReader may also implement this interface.
//
// The bytes of the value will be appended to dst.
type BytesReader interface {
ReadValueBytes(dst []byte) (bsontype.Type, []byte, error)
}
+882
View File
@@ -0,0 +1,882 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
"sync"
"unicode"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
var _ ValueReader = (*valueReader)(nil)
var vrPool = sync.Pool{
New: func() interface{} {
return new(valueReader)
},
}
// BSONValueReaderPool is a pool for ValueReaders that read BSON.
type BSONValueReaderPool struct {
pool sync.Pool
}
// NewBSONValueReaderPool instantiates a new BSONValueReaderPool.
func NewBSONValueReaderPool() *BSONValueReaderPool {
return &BSONValueReaderPool{
pool: sync.Pool{
New: func() interface{} {
return new(valueReader)
},
},
}
}
// Get retrieves a ValueReader from the pool and uses src as the underlying BSON.
func (bvrp *BSONValueReaderPool) Get(src []byte) ValueReader {
vr := bvrp.pool.Get().(*valueReader)
vr.reset(src)
return vr
}
// Put inserts a ValueReader into the pool. If the ValueReader is not a BSON ValueReader nothing
// is inserted into the pool and ok will be false.
func (bvrp *BSONValueReaderPool) Put(vr ValueReader) (ok bool) {
bvr, ok := vr.(*valueReader)
if !ok {
return false
}
bvr.reset(nil)
bvrp.pool.Put(bvr)
return true
}
// ErrEOA is the error returned when the end of a BSON array has been reached.
var ErrEOA = errors.New("end of array")
// ErrEOD is the error returned when the end of a BSON document has been reached.
var ErrEOD = errors.New("end of document")
type vrState struct {
mode mode
vType bsontype.Type
end int64
}
// valueReader is for reading BSON values.
type valueReader struct {
offset int64
d []byte
stack []vrState
frame int64
}
// NewBSONDocumentReader returns a ValueReader using b for the underlying BSON
// representation. Parameter b must be a BSON Document.
//
// TODO(skriptble): There's a lack of symmetry between the reader and writer, since the reader takes
// a []byte while the writer takes an io.Writer. We should have two versions of each, one that takes
// a []byte and one that takes an io.Reader or io.Writer. The []byte version will need to return a
// thing that can return the finished []byte since it might be reallocated when appended to.
func NewBSONDocumentReader(b []byte) ValueReader {
return newValueReader(b)
}
// NewBSONValueReader returns a ValueReader that starts in the Value mode instead of in top
// level document mode. This enables the creation of a ValueReader for a single BSON value.
func NewBSONValueReader(t bsontype.Type, val []byte) ValueReader {
stack := make([]vrState, 1, 5)
stack[0] = vrState{
mode: mValue,
vType: t,
}
return &valueReader{
d: val,
stack: stack,
}
}
func newValueReader(b []byte) *valueReader {
stack := make([]vrState, 1, 5)
stack[0] = vrState{
mode: mTopLevel,
}
return &valueReader{
d: b,
stack: stack,
}
}
func (vr *valueReader) reset(b []byte) {
if vr.stack == nil {
vr.stack = make([]vrState, 1, 5)
}
vr.stack = vr.stack[:1]
vr.stack[0] = vrState{mode: mTopLevel}
vr.d = b
vr.offset = 0
vr.frame = 0
}
func (vr *valueReader) advanceFrame() {
if vr.frame+1 >= int64(len(vr.stack)) { // We need to grow the stack
length := len(vr.stack)
if length+1 >= cap(vr.stack) {
// double it
buf := make([]vrState, 2*cap(vr.stack)+1)
copy(buf, vr.stack)
vr.stack = buf
}
vr.stack = vr.stack[:length+1]
}
vr.frame++
// Clean the stack
vr.stack[vr.frame].mode = 0
vr.stack[vr.frame].vType = 0
vr.stack[vr.frame].end = 0
}
func (vr *valueReader) pushDocument() error {
vr.advanceFrame()
vr.stack[vr.frame].mode = mDocument
size, err := vr.readLength()
if err != nil {
return err
}
vr.stack[vr.frame].end = int64(size) + vr.offset - 4
return nil
}
func (vr *valueReader) pushArray() error {
vr.advanceFrame()
vr.stack[vr.frame].mode = mArray
size, err := vr.readLength()
if err != nil {
return err
}
vr.stack[vr.frame].end = int64(size) + vr.offset - 4
return nil
}
func (vr *valueReader) pushElement(t bsontype.Type) {
vr.advanceFrame()
vr.stack[vr.frame].mode = mElement
vr.stack[vr.frame].vType = t
}
func (vr *valueReader) pushValue(t bsontype.Type) {
vr.advanceFrame()
vr.stack[vr.frame].mode = mValue
vr.stack[vr.frame].vType = t
}
func (vr *valueReader) pushCodeWithScope() (int64, error) {
vr.advanceFrame()
vr.stack[vr.frame].mode = mCodeWithScope
size, err := vr.readLength()
if err != nil {
return 0, err
}
vr.stack[vr.frame].end = int64(size) + vr.offset - 4
return int64(size), nil
}
func (vr *valueReader) pop() {
switch vr.stack[vr.frame].mode {
case mElement, mValue:
vr.frame--
case mDocument, mArray, mCodeWithScope:
vr.frame -= 2 // we pop twice to jump over the vrElement: vrDocument -> vrElement -> vrDocument/TopLevel/etc...
}
}
func (vr *valueReader) invalidTransitionErr(destination mode, name string, modes []mode) error {
te := TransitionError{
name: name,
current: vr.stack[vr.frame].mode,
destination: destination,
modes: modes,
action: "read",
}
if vr.frame != 0 {
te.parent = vr.stack[vr.frame-1].mode
}
return te
}
func (vr *valueReader) typeError(t bsontype.Type) error {
return fmt.Errorf("positioned on %s, but attempted to read %s", vr.stack[vr.frame].vType, t)
}
func (vr *valueReader) invalidDocumentLengthError() error {
return fmt.Errorf("document is invalid, end byte is at %d, but null byte found at %d", vr.stack[vr.frame].end, vr.offset)
}
func (vr *valueReader) ensureElementValue(t bsontype.Type, destination mode, callerName string) error {
switch vr.stack[vr.frame].mode {
case mElement, mValue:
if vr.stack[vr.frame].vType != t {
return vr.typeError(t)
}
default:
return vr.invalidTransitionErr(destination, callerName, []mode{mElement, mValue})
}
return nil
}
func (vr *valueReader) Type() bsontype.Type {
return vr.stack[vr.frame].vType
}
func (vr *valueReader) nextElementLength() (int32, error) {
var length int32
var err error
switch vr.stack[vr.frame].vType {
case bsontype.Array, bsontype.EmbeddedDocument, bsontype.CodeWithScope:
length, err = vr.peekLength()
case bsontype.Binary:
length, err = vr.peekLength()
length += 4 + 1 // binary length + subtype byte
case bsontype.Boolean:
length = 1
case bsontype.DBPointer:
length, err = vr.peekLength()
length += 4 + 12 // string length + ObjectID length
case bsontype.DateTime, bsontype.Double, bsontype.Int64, bsontype.Timestamp:
length = 8
case bsontype.Decimal128:
length = 16
case bsontype.Int32:
length = 4
case bsontype.JavaScript, bsontype.String, bsontype.Symbol:
length, err = vr.peekLength()
length += 4
case bsontype.MaxKey, bsontype.MinKey, bsontype.Null, bsontype.Undefined:
length = 0
case bsontype.ObjectID:
length = 12
case bsontype.Regex:
regex := bytes.IndexByte(vr.d[vr.offset:], 0x00)
if regex < 0 {
err = io.EOF
break
}
pattern := bytes.IndexByte(vr.d[regex+1:], 0x00)
if pattern < 0 {
err = io.EOF
break
}
length = int32(int64(regex) + 1 + int64(pattern) + 1 - vr.offset)
default:
return 0, fmt.Errorf("attempted to read bytes of unknown BSON type %v", vr.stack[vr.frame].vType)
}
return length, err
}
func (vr *valueReader) ReadValueBytes(dst []byte) (bsontype.Type, []byte, error) {
switch vr.stack[vr.frame].mode {
case mTopLevel:
length, err := vr.peekLength()
if err != nil {
return bsontype.Type(0), nil, err
}
dst, err = vr.appendBytes(dst, length)
if err != nil {
return bsontype.Type(0), nil, err
}
return bsontype.Type(0), dst, nil
case mElement, mValue:
length, err := vr.nextElementLength()
if err != nil {
return bsontype.Type(0), dst, err
}
dst, err = vr.appendBytes(dst, length)
t := vr.stack[vr.frame].vType
vr.pop()
return t, dst, err
default:
return bsontype.Type(0), nil, vr.invalidTransitionErr(0, "ReadValueBytes", []mode{mElement, mValue})
}
}
func (vr *valueReader) Skip() error {
switch vr.stack[vr.frame].mode {
case mElement, mValue:
default:
return vr.invalidTransitionErr(0, "Skip", []mode{mElement, mValue})
}
length, err := vr.nextElementLength()
if err != nil {
return err
}
err = vr.skipBytes(length)
vr.pop()
return err
}
func (vr *valueReader) ReadArray() (ArrayReader, error) {
if err := vr.ensureElementValue(bsontype.Array, mArray, "ReadArray"); err != nil {
return nil, err
}
err := vr.pushArray()
if err != nil {
return nil, err
}
return vr, nil
}
func (vr *valueReader) ReadBinary() (b []byte, btype byte, err error) {
if err := vr.ensureElementValue(bsontype.Binary, 0, "ReadBinary"); err != nil {
return nil, 0, err
}
length, err := vr.readLength()
if err != nil {
return nil, 0, err
}
btype, err = vr.readByte()
if err != nil {
return nil, 0, err
}
if btype == 0x02 {
length, err = vr.readLength()
if err != nil {
return nil, 0, err
}
}
b, err = vr.readBytes(length)
if err != nil {
return nil, 0, err
}
vr.pop()
return b, btype, nil
}
func (vr *valueReader) ReadBoolean() (bool, error) {
if err := vr.ensureElementValue(bsontype.Boolean, 0, "ReadBoolean"); err != nil {
return false, err
}
b, err := vr.readByte()
if err != nil {
return false, err
}
if b > 1 {
return false, fmt.Errorf("invalid byte for boolean, %b", b)
}
vr.pop()
return b == 1, nil
}
func (vr *valueReader) ReadDocument() (DocumentReader, error) {
switch vr.stack[vr.frame].mode {
case mTopLevel:
// read size
size, err := vr.readLength()
if err != nil {
return nil, err
}
if int(size) != len(vr.d) {
return nil, fmt.Errorf("invalid document length")
}
vr.stack[vr.frame].end = int64(size) + vr.offset - 4
return vr, nil
case mElement, mValue:
if vr.stack[vr.frame].vType != bsontype.EmbeddedDocument {
return nil, vr.typeError(bsontype.EmbeddedDocument)
}
default:
return nil, vr.invalidTransitionErr(mDocument, "ReadDocument", []mode{mTopLevel, mElement, mValue})
}
err := vr.pushDocument()
if err != nil {
return nil, err
}
return vr, nil
}
func (vr *valueReader) ReadCodeWithScope() (code string, dr DocumentReader, err error) {
if err := vr.ensureElementValue(bsontype.CodeWithScope, 0, "ReadCodeWithScope"); err != nil {
return "", nil, err
}
totalLength, err := vr.readLength()
if err != nil {
return "", nil, err
}
strLength, err := vr.readLength()
if err != nil {
return "", nil, err
}
strBytes, err := vr.readBytes(strLength)
if err != nil {
return "", nil, err
}
code = string(strBytes[:len(strBytes)-1])
size, err := vr.pushCodeWithScope()
if err != nil {
return "", nil, err
}
// The total length should equal:
// 4 (total length) + strLength + 4 (the length of str itself) + (document length)
componentsLength := int64(4+strLength+4) + size
if int64(totalLength) != componentsLength {
return "", nil, fmt.Errorf(
"length of CodeWithScope does not match lengths of components; total: %d; components: %d",
totalLength, componentsLength,
)
}
return code, vr, nil
}
func (vr *valueReader) ReadDBPointer() (ns string, oid primitive.ObjectID, err error) {
if err := vr.ensureElementValue(bsontype.DBPointer, 0, "ReadDBPointer"); err != nil {
return "", oid, err
}
ns, err = vr.readString()
if err != nil {
return "", oid, err
}
oidbytes, err := vr.readBytes(12)
if err != nil {
return "", oid, err
}
copy(oid[:], oidbytes)
vr.pop()
return ns, oid, nil
}
func (vr *valueReader) ReadDateTime() (int64, error) {
if err := vr.ensureElementValue(bsontype.DateTime, 0, "ReadDateTime"); err != nil {
return 0, err
}
i, err := vr.readi64()
if err != nil {
return 0, err
}
vr.pop()
return i, nil
}
func (vr *valueReader) ReadDecimal128() (primitive.Decimal128, error) {
if err := vr.ensureElementValue(bsontype.Decimal128, 0, "ReadDecimal128"); err != nil {
return primitive.Decimal128{}, err
}
b, err := vr.readBytes(16)
if err != nil {
return primitive.Decimal128{}, err
}
l := binary.LittleEndian.Uint64(b[0:8])
h := binary.LittleEndian.Uint64(b[8:16])
vr.pop()
return primitive.NewDecimal128(h, l), nil
}
func (vr *valueReader) ReadDouble() (float64, error) {
if err := vr.ensureElementValue(bsontype.Double, 0, "ReadDouble"); err != nil {
return 0, err
}
u, err := vr.readu64()
if err != nil {
return 0, err
}
vr.pop()
return math.Float64frombits(u), nil
}
func (vr *valueReader) ReadInt32() (int32, error) {
if err := vr.ensureElementValue(bsontype.Int32, 0, "ReadInt32"); err != nil {
return 0, err
}
vr.pop()
return vr.readi32()
}
func (vr *valueReader) ReadInt64() (int64, error) {
if err := vr.ensureElementValue(bsontype.Int64, 0, "ReadInt64"); err != nil {
return 0, err
}
vr.pop()
return vr.readi64()
}
func (vr *valueReader) ReadJavascript() (code string, err error) {
if err := vr.ensureElementValue(bsontype.JavaScript, 0, "ReadJavascript"); err != nil {
return "", err
}
vr.pop()
return vr.readString()
}
func (vr *valueReader) ReadMaxKey() error {
if err := vr.ensureElementValue(bsontype.MaxKey, 0, "ReadMaxKey"); err != nil {
return err
}
vr.pop()
return nil
}
func (vr *valueReader) ReadMinKey() error {
if err := vr.ensureElementValue(bsontype.MinKey, 0, "ReadMinKey"); err != nil {
return err
}
vr.pop()
return nil
}
func (vr *valueReader) ReadNull() error {
if err := vr.ensureElementValue(bsontype.Null, 0, "ReadNull"); err != nil {
return err
}
vr.pop()
return nil
}
func (vr *valueReader) ReadObjectID() (primitive.ObjectID, error) {
if err := vr.ensureElementValue(bsontype.ObjectID, 0, "ReadObjectID"); err != nil {
return primitive.ObjectID{}, err
}
oidbytes, err := vr.readBytes(12)
if err != nil {
return primitive.ObjectID{}, err
}
var oid primitive.ObjectID
copy(oid[:], oidbytes)
vr.pop()
return oid, nil
}
func (vr *valueReader) ReadRegex() (string, string, error) {
if err := vr.ensureElementValue(bsontype.Regex, 0, "ReadRegex"); err != nil {
return "", "", err
}
pattern, err := vr.readCString()
if err != nil {
return "", "", err
}
options, err := vr.readCString()
if err != nil {
return "", "", err
}
vr.pop()
return pattern, options, nil
}
func (vr *valueReader) ReadString() (string, error) {
if err := vr.ensureElementValue(bsontype.String, 0, "ReadString"); err != nil {
return "", err
}
vr.pop()
return vr.readString()
}
func (vr *valueReader) ReadSymbol() (symbol string, err error) {
if err := vr.ensureElementValue(bsontype.Symbol, 0, "ReadSymbol"); err != nil {
return "", err
}
vr.pop()
return vr.readString()
}
func (vr *valueReader) ReadTimestamp() (t uint32, i uint32, err error) {
if err := vr.ensureElementValue(bsontype.Timestamp, 0, "ReadTimestamp"); err != nil {
return 0, 0, err
}
i, err = vr.readu32()
if err != nil {
return 0, 0, err
}
t, err = vr.readu32()
if err != nil {
return 0, 0, err
}
vr.pop()
return t, i, nil
}
func (vr *valueReader) ReadUndefined() error {
if err := vr.ensureElementValue(bsontype.Undefined, 0, "ReadUndefined"); err != nil {
return err
}
vr.pop()
return nil
}
func (vr *valueReader) ReadElement() (string, ValueReader, error) {
switch vr.stack[vr.frame].mode {
case mTopLevel, mDocument, mCodeWithScope:
default:
return "", nil, vr.invalidTransitionErr(mElement, "ReadElement", []mode{mTopLevel, mDocument, mCodeWithScope})
}
t, err := vr.readByte()
if err != nil {
return "", nil, err
}
if t == 0 {
if vr.offset != vr.stack[vr.frame].end {
return "", nil, vr.invalidDocumentLengthError()
}
vr.pop()
return "", nil, ErrEOD
}
name, err := vr.readCString()
if err != nil {
return "", nil, err
}
vr.pushElement(bsontype.Type(t))
return name, vr, nil
}
func (vr *valueReader) ReadValue() (ValueReader, error) {
switch vr.stack[vr.frame].mode {
case mArray:
default:
return nil, vr.invalidTransitionErr(mValue, "ReadValue", []mode{mArray})
}
t, err := vr.readByte()
if err != nil {
return nil, err
}
if t == 0 {
if vr.offset != vr.stack[vr.frame].end {
return nil, vr.invalidDocumentLengthError()
}
vr.pop()
return nil, ErrEOA
}
_, err = vr.readCString()
if err != nil {
return nil, err
}
vr.pushValue(bsontype.Type(t))
return vr, nil
}
func (vr *valueReader) readBytes(length int32) ([]byte, error) {
if length < 0 {
return nil, fmt.Errorf("invalid length: %d", length)
}
if vr.offset+int64(length) > int64(len(vr.d)) {
return nil, io.EOF
}
start := vr.offset
vr.offset += int64(length)
return vr.d[start : start+int64(length)], nil
}
func (vr *valueReader) appendBytes(dst []byte, length int32) ([]byte, error) {
if vr.offset+int64(length) > int64(len(vr.d)) {
return nil, io.EOF
}
start := vr.offset
vr.offset += int64(length)
return append(dst, vr.d[start:start+int64(length)]...), nil
}
func (vr *valueReader) skipBytes(length int32) error {
if vr.offset+int64(length) > int64(len(vr.d)) {
return io.EOF
}
vr.offset += int64(length)
return nil
}
func (vr *valueReader) readByte() (byte, error) {
if vr.offset+1 > int64(len(vr.d)) {
return 0x0, io.EOF
}
vr.offset++
return vr.d[vr.offset-1], nil
}
func (vr *valueReader) readCString() (string, error) {
idx := bytes.IndexByte(vr.d[vr.offset:], 0x00)
if idx < 0 {
return "", io.EOF
}
start := vr.offset
// idx does not include the null byte
vr.offset += int64(idx) + 1
return string(vr.d[start : start+int64(idx)]), nil
}
func (vr *valueReader) skipCString() error {
idx := bytes.IndexByte(vr.d[vr.offset:], 0x00)
if idx < 0 {
return io.EOF
}
// idx does not include the null byte
vr.offset += int64(idx) + 1
return nil
}
func (vr *valueReader) readString() (string, error) {
length, err := vr.readLength()
if err != nil {
return "", err
}
if int64(length)+vr.offset > int64(len(vr.d)) {
return "", io.EOF
}
if length <= 0 {
return "", fmt.Errorf("invalid string length: %d", length)
}
if vr.d[vr.offset+int64(length)-1] != 0x00 {
return "", fmt.Errorf("string does not end with null byte, but with %v", vr.d[vr.offset+int64(length)-1])
}
start := vr.offset
vr.offset += int64(length)
if length == 2 {
asciiByte := vr.d[start]
if asciiByte > unicode.MaxASCII {
return "", fmt.Errorf("invalid ascii byte")
}
}
return string(vr.d[start : start+int64(length)-1]), nil
}
func (vr *valueReader) peekLength() (int32, error) {
if vr.offset+4 > int64(len(vr.d)) {
return 0, io.EOF
}
idx := vr.offset
return (int32(vr.d[idx]) | int32(vr.d[idx+1])<<8 | int32(vr.d[idx+2])<<16 | int32(vr.d[idx+3])<<24), nil
}
func (vr *valueReader) readLength() (int32, error) { return vr.readi32() }
func (vr *valueReader) readi32() (int32, error) {
if vr.offset+4 > int64(len(vr.d)) {
return 0, io.EOF
}
idx := vr.offset
vr.offset += 4
return (int32(vr.d[idx]) | int32(vr.d[idx+1])<<8 | int32(vr.d[idx+2])<<16 | int32(vr.d[idx+3])<<24), nil
}
func (vr *valueReader) readu32() (uint32, error) {
if vr.offset+4 > int64(len(vr.d)) {
return 0, io.EOF
}
idx := vr.offset
vr.offset += 4
return (uint32(vr.d[idx]) | uint32(vr.d[idx+1])<<8 | uint32(vr.d[idx+2])<<16 | uint32(vr.d[idx+3])<<24), nil
}
func (vr *valueReader) readi64() (int64, error) {
if vr.offset+8 > int64(len(vr.d)) {
return 0, io.EOF
}
idx := vr.offset
vr.offset += 8
return int64(vr.d[idx]) | int64(vr.d[idx+1])<<8 | int64(vr.d[idx+2])<<16 | int64(vr.d[idx+3])<<24 |
int64(vr.d[idx+4])<<32 | int64(vr.d[idx+5])<<40 | int64(vr.d[idx+6])<<48 | int64(vr.d[idx+7])<<56, nil
}
func (vr *valueReader) readu64() (uint64, error) {
if vr.offset+8 > int64(len(vr.d)) {
return 0, io.EOF
}
idx := vr.offset
vr.offset += 8
return uint64(vr.d[idx]) | uint64(vr.d[idx+1])<<8 | uint64(vr.d[idx+2])<<16 | uint64(vr.d[idx+3])<<24 |
uint64(vr.d[idx+4])<<32 | uint64(vr.d[idx+5])<<40 | uint64(vr.d[idx+6])<<48 | uint64(vr.d[idx+7])<<56, nil
}
+589
View File
@@ -0,0 +1,589 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"errors"
"fmt"
"io"
"math"
"strconv"
"sync"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
var _ ValueWriter = (*valueWriter)(nil)
var vwPool = sync.Pool{
New: func() interface{} {
return new(valueWriter)
},
}
// BSONValueWriterPool is a pool for BSON ValueWriters.
type BSONValueWriterPool struct {
pool sync.Pool
}
// NewBSONValueWriterPool creates a new pool for ValueWriter instances that write to BSON.
func NewBSONValueWriterPool() *BSONValueWriterPool {
return &BSONValueWriterPool{
pool: sync.Pool{
New: func() interface{} {
return new(valueWriter)
},
},
}
}
// Get retrieves a BSON ValueWriter from the pool and resets it to use w as the destination.
func (bvwp *BSONValueWriterPool) Get(w io.Writer) ValueWriter {
vw := bvwp.pool.Get().(*valueWriter)
if writer, ok := w.(*SliceWriter); ok {
vw.reset(*writer)
vw.w = writer
return vw
}
vw.buf = vw.buf[:0]
vw.w = w
return vw
}
// Put inserts a ValueWriter into the pool. If the ValueWriter is not a BSON ValueWriter, nothing
// happens and ok will be false.
func (bvwp *BSONValueWriterPool) Put(vw ValueWriter) (ok bool) {
bvw, ok := vw.(*valueWriter)
if !ok {
return false
}
if _, ok := bvw.w.(*SliceWriter); ok {
bvw.buf = nil
}
bvw.w = nil
bvwp.pool.Put(bvw)
return true
}
// This is here so that during testing we can change it and not require
// allocating a 4GB slice.
var maxSize = math.MaxInt32
var errNilWriter = errors.New("cannot create a ValueWriter from a nil io.Writer")
type errMaxDocumentSizeExceeded struct {
size int64
}
func (mdse errMaxDocumentSizeExceeded) Error() string {
return fmt.Sprintf("document size (%d) is larger than the max int32", mdse.size)
}
type vwMode int
const (
_ vwMode = iota
vwTopLevel
vwDocument
vwArray
vwValue
vwElement
vwCodeWithScope
)
func (vm vwMode) String() string {
var str string
switch vm {
case vwTopLevel:
str = "TopLevel"
case vwDocument:
str = "DocumentMode"
case vwArray:
str = "ArrayMode"
case vwValue:
str = "ValueMode"
case vwElement:
str = "ElementMode"
case vwCodeWithScope:
str = "CodeWithScopeMode"
default:
str = "UnknownMode"
}
return str
}
type vwState struct {
mode mode
key string
arrkey int
start int32
}
type valueWriter struct {
w io.Writer
buf []byte
stack []vwState
frame int64
}
func (vw *valueWriter) advanceFrame() {
if vw.frame+1 >= int64(len(vw.stack)) { // We need to grow the stack
length := len(vw.stack)
if length+1 >= cap(vw.stack) {
// double it
buf := make([]vwState, 2*cap(vw.stack)+1)
copy(buf, vw.stack)
vw.stack = buf
}
vw.stack = vw.stack[:length+1]
}
vw.frame++
}
func (vw *valueWriter) push(m mode) {
vw.advanceFrame()
// Clean the stack
vw.stack[vw.frame].mode = m
vw.stack[vw.frame].key = ""
vw.stack[vw.frame].arrkey = 0
vw.stack[vw.frame].start = 0
vw.stack[vw.frame].mode = m
switch m {
case mDocument, mArray, mCodeWithScope:
vw.reserveLength()
}
}
func (vw *valueWriter) reserveLength() {
vw.stack[vw.frame].start = int32(len(vw.buf))
vw.buf = append(vw.buf, 0x00, 0x00, 0x00, 0x00)
}
func (vw *valueWriter) pop() {
switch vw.stack[vw.frame].mode {
case mElement, mValue:
vw.frame--
case mDocument, mArray, mCodeWithScope:
vw.frame -= 2 // we pop twice to jump over the mElement: mDocument -> mElement -> mDocument/mTopLevel/etc...
}
}
// NewBSONValueWriter creates a ValueWriter that writes BSON to w.
//
// This ValueWriter will only write entire documents to the io.Writer and it
// will buffer the document as it is built.
func NewBSONValueWriter(w io.Writer) (ValueWriter, error) {
if w == nil {
return nil, errNilWriter
}
return newValueWriter(w), nil
}
func newValueWriter(w io.Writer) *valueWriter {
vw := new(valueWriter)
stack := make([]vwState, 1, 5)
stack[0] = vwState{mode: mTopLevel}
vw.w = w
vw.stack = stack
return vw
}
func newValueWriterFromSlice(buf []byte) *valueWriter {
vw := new(valueWriter)
stack := make([]vwState, 1, 5)
stack[0] = vwState{mode: mTopLevel}
vw.stack = stack
vw.buf = buf
return vw
}
func (vw *valueWriter) reset(buf []byte) {
if vw.stack == nil {
vw.stack = make([]vwState, 1, 5)
}
vw.stack = vw.stack[:1]
vw.stack[0] = vwState{mode: mTopLevel}
vw.buf = buf
vw.frame = 0
vw.w = nil
}
func (vw *valueWriter) invalidTransitionError(destination mode, name string, modes []mode) error {
te := TransitionError{
name: name,
current: vw.stack[vw.frame].mode,
destination: destination,
modes: modes,
action: "write",
}
if vw.frame != 0 {
te.parent = vw.stack[vw.frame-1].mode
}
return te
}
func (vw *valueWriter) writeElementHeader(t bsontype.Type, destination mode, callerName string, addmodes ...mode) error {
switch vw.stack[vw.frame].mode {
case mElement:
vw.buf = bsoncore.AppendHeader(vw.buf, t, vw.stack[vw.frame].key)
case mValue:
// TODO: Do this with a cache of the first 1000 or so array keys.
vw.buf = bsoncore.AppendHeader(vw.buf, t, strconv.Itoa(vw.stack[vw.frame].arrkey))
default:
modes := []mode{mElement, mValue}
if addmodes != nil {
modes = append(modes, addmodes...)
}
return vw.invalidTransitionError(destination, callerName, modes)
}
return nil
}
func (vw *valueWriter) WriteValueBytes(t bsontype.Type, b []byte) error {
if err := vw.writeElementHeader(t, mode(0), "WriteValueBytes"); err != nil {
return err
}
vw.buf = append(vw.buf, b...)
vw.pop()
return nil
}
func (vw *valueWriter) WriteArray() (ArrayWriter, error) {
if err := vw.writeElementHeader(bsontype.Array, mArray, "WriteArray"); err != nil {
return nil, err
}
vw.push(mArray)
return vw, nil
}
func (vw *valueWriter) WriteBinary(b []byte) error {
return vw.WriteBinaryWithSubtype(b, 0x00)
}
func (vw *valueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error {
if err := vw.writeElementHeader(bsontype.Binary, mode(0), "WriteBinaryWithSubtype"); err != nil {
return err
}
vw.buf = bsoncore.AppendBinary(vw.buf, btype, b)
vw.pop()
return nil
}
func (vw *valueWriter) WriteBoolean(b bool) error {
if err := vw.writeElementHeader(bsontype.Boolean, mode(0), "WriteBoolean"); err != nil {
return err
}
vw.buf = bsoncore.AppendBoolean(vw.buf, b)
vw.pop()
return nil
}
func (vw *valueWriter) WriteCodeWithScope(code string) (DocumentWriter, error) {
if err := vw.writeElementHeader(bsontype.CodeWithScope, mCodeWithScope, "WriteCodeWithScope"); err != nil {
return nil, err
}
// CodeWithScope is a different than other types because we need an extra
// frame on the stack. In the EndDocument code, we write the document
// length, pop, write the code with scope length, and pop. To simplify the
// pop code, we push a spacer frame that we'll always jump over.
vw.push(mCodeWithScope)
vw.buf = bsoncore.AppendString(vw.buf, code)
vw.push(mSpacer)
vw.push(mDocument)
return vw, nil
}
func (vw *valueWriter) WriteDBPointer(ns string, oid primitive.ObjectID) error {
if err := vw.writeElementHeader(bsontype.DBPointer, mode(0), "WriteDBPointer"); err != nil {
return err
}
vw.buf = bsoncore.AppendDBPointer(vw.buf, ns, oid)
vw.pop()
return nil
}
func (vw *valueWriter) WriteDateTime(dt int64) error {
if err := vw.writeElementHeader(bsontype.DateTime, mode(0), "WriteDateTime"); err != nil {
return err
}
vw.buf = bsoncore.AppendDateTime(vw.buf, dt)
vw.pop()
return nil
}
func (vw *valueWriter) WriteDecimal128(d128 primitive.Decimal128) error {
if err := vw.writeElementHeader(bsontype.Decimal128, mode(0), "WriteDecimal128"); err != nil {
return err
}
vw.buf = bsoncore.AppendDecimal128(vw.buf, d128)
vw.pop()
return nil
}
func (vw *valueWriter) WriteDouble(f float64) error {
if err := vw.writeElementHeader(bsontype.Double, mode(0), "WriteDouble"); err != nil {
return err
}
vw.buf = bsoncore.AppendDouble(vw.buf, f)
vw.pop()
return nil
}
func (vw *valueWriter) WriteInt32(i32 int32) error {
if err := vw.writeElementHeader(bsontype.Int32, mode(0), "WriteInt32"); err != nil {
return err
}
vw.buf = bsoncore.AppendInt32(vw.buf, i32)
vw.pop()
return nil
}
func (vw *valueWriter) WriteInt64(i64 int64) error {
if err := vw.writeElementHeader(bsontype.Int64, mode(0), "WriteInt64"); err != nil {
return err
}
vw.buf = bsoncore.AppendInt64(vw.buf, i64)
vw.pop()
return nil
}
func (vw *valueWriter) WriteJavascript(code string) error {
if err := vw.writeElementHeader(bsontype.JavaScript, mode(0), "WriteJavascript"); err != nil {
return err
}
vw.buf = bsoncore.AppendJavaScript(vw.buf, code)
vw.pop()
return nil
}
func (vw *valueWriter) WriteMaxKey() error {
if err := vw.writeElementHeader(bsontype.MaxKey, mode(0), "WriteMaxKey"); err != nil {
return err
}
vw.pop()
return nil
}
func (vw *valueWriter) WriteMinKey() error {
if err := vw.writeElementHeader(bsontype.MinKey, mode(0), "WriteMinKey"); err != nil {
return err
}
vw.pop()
return nil
}
func (vw *valueWriter) WriteNull() error {
if err := vw.writeElementHeader(bsontype.Null, mode(0), "WriteNull"); err != nil {
return err
}
vw.pop()
return nil
}
func (vw *valueWriter) WriteObjectID(oid primitive.ObjectID) error {
if err := vw.writeElementHeader(bsontype.ObjectID, mode(0), "WriteObjectID"); err != nil {
return err
}
vw.buf = bsoncore.AppendObjectID(vw.buf, oid)
vw.pop()
return nil
}
func (vw *valueWriter) WriteRegex(pattern string, options string) error {
if err := vw.writeElementHeader(bsontype.Regex, mode(0), "WriteRegex"); err != nil {
return err
}
vw.buf = bsoncore.AppendRegex(vw.buf, pattern, sortStringAlphebeticAscending(options))
vw.pop()
return nil
}
func (vw *valueWriter) WriteString(s string) error {
if err := vw.writeElementHeader(bsontype.String, mode(0), "WriteString"); err != nil {
return err
}
vw.buf = bsoncore.AppendString(vw.buf, s)
vw.pop()
return nil
}
func (vw *valueWriter) WriteDocument() (DocumentWriter, error) {
if vw.stack[vw.frame].mode == mTopLevel {
vw.reserveLength()
return vw, nil
}
if err := vw.writeElementHeader(bsontype.EmbeddedDocument, mDocument, "WriteDocument", mTopLevel); err != nil {
return nil, err
}
vw.push(mDocument)
return vw, nil
}
func (vw *valueWriter) WriteSymbol(symbol string) error {
if err := vw.writeElementHeader(bsontype.Symbol, mode(0), "WriteSymbol"); err != nil {
return err
}
vw.buf = bsoncore.AppendSymbol(vw.buf, symbol)
vw.pop()
return nil
}
func (vw *valueWriter) WriteTimestamp(t uint32, i uint32) error {
if err := vw.writeElementHeader(bsontype.Timestamp, mode(0), "WriteTimestamp"); err != nil {
return err
}
vw.buf = bsoncore.AppendTimestamp(vw.buf, t, i)
vw.pop()
return nil
}
func (vw *valueWriter) WriteUndefined() error {
if err := vw.writeElementHeader(bsontype.Undefined, mode(0), "WriteUndefined"); err != nil {
return err
}
vw.pop()
return nil
}
func (vw *valueWriter) WriteDocumentElement(key string) (ValueWriter, error) {
switch vw.stack[vw.frame].mode {
case mTopLevel, mDocument:
default:
return nil, vw.invalidTransitionError(mElement, "WriteDocumentElement", []mode{mTopLevel, mDocument})
}
vw.push(mElement)
vw.stack[vw.frame].key = key
return vw, nil
}
func (vw *valueWriter) WriteDocumentEnd() error {
switch vw.stack[vw.frame].mode {
case mTopLevel, mDocument:
default:
return fmt.Errorf("incorrect mode to end document: %s", vw.stack[vw.frame].mode)
}
vw.buf = append(vw.buf, 0x00)
err := vw.writeLength()
if err != nil {
return err
}
if vw.stack[vw.frame].mode == mTopLevel {
if vw.w != nil {
if sw, ok := vw.w.(*SliceWriter); ok {
*sw = vw.buf
} else {
_, err = vw.w.Write(vw.buf)
if err != nil {
return err
}
// reset buffer
vw.buf = vw.buf[:0]
}
}
}
vw.pop()
if vw.stack[vw.frame].mode == mCodeWithScope {
// We ignore the error here because of the gaurantee of writeLength.
// See the docs for writeLength for more info.
_ = vw.writeLength()
vw.pop()
}
return nil
}
func (vw *valueWriter) WriteArrayElement() (ValueWriter, error) {
if vw.stack[vw.frame].mode != mArray {
return nil, vw.invalidTransitionError(mValue, "WriteArrayElement", []mode{mArray})
}
arrkey := vw.stack[vw.frame].arrkey
vw.stack[vw.frame].arrkey++
vw.push(mValue)
vw.stack[vw.frame].arrkey = arrkey
return vw, nil
}
func (vw *valueWriter) WriteArrayEnd() error {
if vw.stack[vw.frame].mode != mArray {
return fmt.Errorf("incorrect mode to end array: %s", vw.stack[vw.frame].mode)
}
vw.buf = append(vw.buf, 0x00)
err := vw.writeLength()
if err != nil {
return err
}
vw.pop()
return nil
}
// NOTE: We assume that if we call writeLength more than once the same function
// within the same function without altering the vw.buf that this method will
// not return an error. If this changes ensure that the following methods are
// updated:
//
// - WriteDocumentEnd
func (vw *valueWriter) writeLength() error {
length := len(vw.buf)
if length > maxSize {
return errMaxDocumentSizeExceeded{size: int64(len(vw.buf))}
}
length = length - int(vw.stack[vw.frame].start)
start := vw.stack[vw.frame].start
vw.buf[start+0] = byte(length)
vw.buf[start+1] = byte(length >> 8)
vw.buf[start+2] = byte(length >> 16)
vw.buf[start+3] = byte(length >> 24)
return nil
}
+96
View File
@@ -0,0 +1,96 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// ArrayWriter is the interface used to create a BSON or BSON adjacent array.
// Callers must ensure they call WriteArrayEnd when they have finished creating
// the array.
type ArrayWriter interface {
WriteArrayElement() (ValueWriter, error)
WriteArrayEnd() error
}
// DocumentWriter is the interface used to create a BSON or BSON adjacent
// document. Callers must ensure they call WriteDocumentEnd when they have
// finished creating the document.
type DocumentWriter interface {
WriteDocumentElement(string) (ValueWriter, error)
WriteDocumentEnd() error
}
// ValueWriter is the interface used to write BSON values. Implementations of
// this interface handle creating BSON or BSON adjacent representations of the
// values.
type ValueWriter interface {
WriteArray() (ArrayWriter, error)
WriteBinary(b []byte) error
WriteBinaryWithSubtype(b []byte, btype byte) error
WriteBoolean(bool) error
WriteCodeWithScope(code string) (DocumentWriter, error)
WriteDBPointer(ns string, oid primitive.ObjectID) error
WriteDateTime(dt int64) error
WriteDecimal128(primitive.Decimal128) error
WriteDouble(float64) error
WriteInt32(int32) error
WriteInt64(int64) error
WriteJavascript(code string) error
WriteMaxKey() error
WriteMinKey() error
WriteNull() error
WriteObjectID(primitive.ObjectID) error
WriteRegex(pattern, options string) error
WriteString(string) error
WriteDocument() (DocumentWriter, error)
WriteSymbol(symbol string) error
WriteTimestamp(t, i uint32) error
WriteUndefined() error
}
// BytesWriter is the interface used to write BSON bytes to a ValueWriter.
// This interface is meant to be a superset of ValueWriter, so that types that
// implement ValueWriter may also implement this interface.
type BytesWriter interface {
WriteValueBytes(t bsontype.Type, b []byte) error
}
// SliceWriter allows a pointer to a slice of bytes to be used as an io.Writer.
type SliceWriter []byte
func (sw *SliceWriter) Write(p []byte) (int, error) {
written := len(p)
*sw = append(*sw, p...)
return written, nil
}
type writer []byte
func (w *writer) Write(p []byte) (int, error) {
index := len(*w)
return w.WriteAt(p, int64(index))
}
func (w *writer) WriteAt(p []byte, off int64) (int, error) {
newend := off + int64(len(p))
if newend < int64(len(*w)) {
newend = int64(len(*w))
}
if newend > int64(cap(*w)) {
buf := make([]byte, int64(2*cap(*w))+newend)
copy(buf, *w)
*w = buf
}
*w = []byte(*w)[:newend]
copy([]byte(*w)[off:], p)
return len(p), nil
}
+87
View File
@@ -0,0 +1,87 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package bsontype is a utility package that contains types for each BSON type and the
// a stringifier for the Type to enable easier debugging when working with BSON.
package bsontype // import "go.mongodb.org/mongo-driver/bson/bsontype"
// These constants uniquely refer to each BSON type.
const (
Double Type = 0x01
String Type = 0x02
EmbeddedDocument Type = 0x03
Array Type = 0x04
Binary Type = 0x05
Undefined Type = 0x06
ObjectID Type = 0x07
Boolean Type = 0x08
DateTime Type = 0x09
Null Type = 0x0A
Regex Type = 0x0B
DBPointer Type = 0x0C
JavaScript Type = 0x0D
Symbol Type = 0x0E
CodeWithScope Type = 0x0F
Int32 Type = 0x10
Timestamp Type = 0x11
Int64 Type = 0x12
Decimal128 Type = 0x13
MinKey Type = 0xFF
MaxKey Type = 0x7F
)
// Type represents a BSON type.
type Type byte
// String returns the string representation of the BSON type's name.
func (bt Type) String() string {
switch bt {
case '\x01':
return "double"
case '\x02':
return "string"
case '\x03':
return "embedded document"
case '\x04':
return "array"
case '\x05':
return "binary"
case '\x06':
return "undefined"
case '\x07':
return "objectID"
case '\x08':
return "boolean"
case '\x09':
return "UTC datetime"
case '\x0A':
return "null"
case '\x0B':
return "regex"
case '\x0C':
return "dbPointer"
case '\x0D':
return "javascript"
case '\x0E':
return "symbol"
case '\x0F':
return "code with scope"
case '\x10':
return "32-bit integer"
case '\x11':
return "timestamp"
case '\x12':
return "64-bit integer"
case '\x13':
return "128-bit decimal"
case '\xFF':
return "min key"
case '\x7F':
return "max key"
default:
return "invalid"
}
}
+106
View File
@@ -0,0 +1,106 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"fmt"
"reflect"
"sync"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonrw"
)
// This pool is used to keep the allocations of Decoders down. This is only used for the Marshal*
// methods and is not consumable from outside of this package. The Decoders retrieved from this pool
// must have both Reset and SetRegistry called on them.
var decPool = sync.Pool{
New: func() interface{} {
return new(Decoder)
},
}
// A Decoder reads and decodes BSON documents from a stream. It reads from a bsonrw.ValueReader as
// the source of BSON data.
type Decoder struct {
dc bsoncodec.DecodeContext
vr bsonrw.ValueReader
}
// NewDecoder returns a new decoder that uses the DefaultRegistry to read from vr.
func NewDecoder(vr bsonrw.ValueReader) (*Decoder, error) {
if vr == nil {
return nil, errors.New("cannot create a new Decoder with a nil ValueReader")
}
return &Decoder{
dc: bsoncodec.DecodeContext{Registry: DefaultRegistry},
vr: vr,
}, nil
}
// NewDecoderWithContext returns a new decoder that uses DecodeContext dc to read from vr.
func NewDecoderWithContext(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader) (*Decoder, error) {
if dc.Registry == nil {
dc.Registry = DefaultRegistry
}
if vr == nil {
return nil, errors.New("cannot create a new Decoder with a nil ValueReader")
}
return &Decoder{
dc: dc,
vr: vr,
}, nil
}
// Decode reads the next BSON document from the stream and decodes it into the
// value pointed to by val.
//
// The documentation for Unmarshal contains details about of BSON into a Go
// value.
func (d *Decoder) Decode(val interface{}) error {
if unmarshaler, ok := val.(Unmarshaler); ok {
// TODO(skriptble): Reuse a []byte here and use the AppendDocumentBytes method.
buf, err := bsonrw.Copier{}.CopyDocumentToBytes(d.vr)
if err != nil {
return err
}
return unmarshaler.UnmarshalBSON(buf)
}
rval := reflect.ValueOf(val)
if rval.Kind() != reflect.Ptr {
return fmt.Errorf("argument to Decode must be a pointer to a type, but got %v", rval)
}
rval = rval.Elem()
decoder, err := d.dc.LookupDecoder(rval.Type())
if err != nil {
return err
}
return decoder.DecodeValue(d.dc, d.vr, rval)
}
// Reset will reset the state of the decoder, using the same *DecodeContext used in
// the original construction but using vr for reading.
func (d *Decoder) Reset(vr bsonrw.ValueReader) error {
d.vr = vr
return nil
}
// SetRegistry replaces the current registry of the decoder with r.
func (d *Decoder) SetRegistry(r *bsoncodec.Registry) error {
d.dc.Registry = r
return nil
}
// SetContext replaces the current registry of the decoder with dc.
func (d *Decoder) SetContext(dc bsoncodec.DecodeContext) error {
d.dc = dc
return nil
}
+42
View File
@@ -0,0 +1,42 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package bson is a library for reading, writing, and manipulating BSON. The
// library has two families of types for representing BSON.
//
// The Raw family of types is used to validate and retrieve elements from a slice of bytes. This
// type is most useful when you want do lookups on BSON bytes without unmarshaling it into another
// type.
//
// Example:
// var raw bson.Raw = ... // bytes from somewhere
// err := raw.Validate()
// if err != nil { return err }
// val := raw.Lookup("foo")
// i32, ok := val.Int32OK()
// // do something with i32...
//
// The D family of types is used to build concise representations of BSON using native Go types.
// These types do not support automatic lookup.
//
// Example:
// bson.D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}}
//
//
// Marshaling and Unmarshaling are handled with the Marshal and Unmarshal family of functions. If
// you need to write or read BSON from a non-slice source, an Encoder or Decoder can be used with a
// bsonrw.ValueWriter or bsonrw.ValueReader.
//
// Example:
// b, err := bson.Marshal(bson.D{{"foo", "bar"}})
// if err != nil { return err }
// var fooer struct {
// Foo string
// }
// err = bson.Unmarshal(b, &fooer)
// if err != nil { return err }
// // do something with fooer...
package bson
+99
View File
@@ -0,0 +1,99 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"reflect"
"sync"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonrw"
)
// This pool is used to keep the allocations of Encoders down. This is only used for the Marshal*
// methods and is not consumable from outside of this package. The Encoders retrieved from this pool
// must have both Reset and SetRegistry called on them.
var encPool = sync.Pool{
New: func() interface{} {
return new(Encoder)
},
}
// An Encoder writes a serialization format to an output stream. It writes to a bsonrw.ValueWriter
// as the destination of BSON data.
type Encoder struct {
ec bsoncodec.EncodeContext
vw bsonrw.ValueWriter
}
// NewEncoder returns a new encoder that uses the DefaultRegistry to write to vw.
func NewEncoder(vw bsonrw.ValueWriter) (*Encoder, error) {
if vw == nil {
return nil, errors.New("cannot create a new Encoder with a nil ValueWriter")
}
return &Encoder{
ec: bsoncodec.EncodeContext{Registry: DefaultRegistry},
vw: vw,
}, nil
}
// NewEncoderWithContext returns a new encoder that uses EncodeContext ec to write to vw.
func NewEncoderWithContext(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter) (*Encoder, error) {
if ec.Registry == nil {
ec = bsoncodec.EncodeContext{Registry: DefaultRegistry}
}
if vw == nil {
return nil, errors.New("cannot create a new Encoder with a nil ValueWriter")
}
return &Encoder{
ec: ec,
vw: vw,
}, nil
}
// Encode writes the BSON encoding of val to the stream.
//
// The documentation for Marshal contains details about the conversion of Go
// values to BSON.
func (e *Encoder) Encode(val interface{}) error {
if marshaler, ok := val.(Marshaler); ok {
// TODO(skriptble): Should we have a MarshalAppender interface so that we can have []byte reuse?
buf, err := marshaler.MarshalBSON()
if err != nil {
return err
}
return bsonrw.Copier{}.CopyDocumentFromBytes(e.vw, buf)
}
encoder, err := e.ec.LookupEncoder(reflect.TypeOf(val))
if err != nil {
return err
}
return encoder.EncodeValue(e.ec, e.vw, reflect.ValueOf(val))
}
// Reset will reset the state of the encoder, using the same *EncodeContext used in
// the original construction but using vw.
func (e *Encoder) Reset(vw bsonrw.ValueWriter) error {
e.vw = vw
return nil
}
// SetRegistry replaces the current registry of the encoder with r.
func (e *Encoder) SetRegistry(r *bsoncodec.Registry) error {
e.ec.Registry = r
return nil
}
// SetContext replaces the current EncodeContext of the encoder with er.
func (e *Encoder) SetContext(ec bsoncodec.EncodeContext) error {
e.ec = ec
return nil
}
+156
View File
@@ -0,0 +1,156 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
const defaultDstCap = 256
var bvwPool = bsonrw.NewBSONValueWriterPool()
var extjPool = bsonrw.NewExtJSONValueWriterPool()
// Marshaler is an interface implemented by types that can marshal themselves
// into a BSON document represented as bytes. The bytes returned must be a valid
// BSON document if the error is nil.
type Marshaler interface {
MarshalBSON() ([]byte, error)
}
// ValueMarshaler is an interface implemented by types that can marshal
// themselves into a BSON value as bytes. The type must be the valid type for
// the bytes returned. The bytes and byte type together must be valid if the
// error is nil.
type ValueMarshaler interface {
MarshalBSONValue() (bsontype.Type, []byte, error)
}
// Marshal returns the BSON encoding of val.
//
// Marshal will use the default registry created by NewRegistry to recursively
// marshal val into a []byte. Marshal will inspect struct tags and alter the
// marshaling process accordingly.
func Marshal(val interface{}) ([]byte, error) {
return MarshalWithRegistry(DefaultRegistry, val)
}
// MarshalAppend will append the BSON encoding of val to dst. If dst is not
// large enough to hold the BSON encoding of val, dst will be grown.
func MarshalAppend(dst []byte, val interface{}) ([]byte, error) {
return MarshalAppendWithRegistry(DefaultRegistry, dst, val)
}
// MarshalWithRegistry returns the BSON encoding of val using Registry r.
func MarshalWithRegistry(r *bsoncodec.Registry, val interface{}) ([]byte, error) {
dst := make([]byte, 0, 256) // TODO: make the default cap a constant
return MarshalAppendWithRegistry(r, dst, val)
}
// MarshalWithContext returns the BSON encoding of val using EncodeContext ec.
func MarshalWithContext(ec bsoncodec.EncodeContext, val interface{}) ([]byte, error) {
dst := make([]byte, 0, 256) // TODO: make the default cap a constant
return MarshalAppendWithContext(ec, dst, val)
}
// MarshalAppendWithRegistry will append the BSON encoding of val to dst using
// Registry r. If dst is not large enough to hold the BSON encoding of val, dst
// will be grown.
func MarshalAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{}) ([]byte, error) {
return MarshalAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val)
}
// MarshalAppendWithContext will append the BSON encoding of val to dst using
// EncodeContext ec. If dst is not large enough to hold the BSON encoding of val, dst
// will be grown.
func MarshalAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interface{}) ([]byte, error) {
sw := new(bsonrw.SliceWriter)
*sw = dst
vw := bvwPool.Get(sw)
defer bvwPool.Put(vw)
enc := encPool.Get().(*Encoder)
defer encPool.Put(enc)
err := enc.Reset(vw)
if err != nil {
return nil, err
}
err = enc.SetContext(ec)
if err != nil {
return nil, err
}
err = enc.Encode(val)
if err != nil {
return nil, err
}
return *sw, nil
}
// MarshalExtJSON returns the extended JSON encoding of val.
func MarshalExtJSON(val interface{}, canonical, escapeHTML bool) ([]byte, error) {
return MarshalExtJSONWithRegistry(DefaultRegistry, val, canonical, escapeHTML)
}
// MarshalExtJSONAppend will append the extended JSON encoding of val to dst.
// If dst is not large enough to hold the extended JSON encoding of val, dst
// will be grown.
func MarshalExtJSONAppend(dst []byte, val interface{}, canonical, escapeHTML bool) ([]byte, error) {
return MarshalExtJSONAppendWithRegistry(DefaultRegistry, dst, val, canonical, escapeHTML)
}
// MarshalExtJSONWithRegistry returns the extended JSON encoding of val using Registry r.
func MarshalExtJSONWithRegistry(r *bsoncodec.Registry, val interface{}, canonical, escapeHTML bool) ([]byte, error) {
dst := make([]byte, 0, defaultDstCap)
return MarshalExtJSONAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val, canonical, escapeHTML)
}
// MarshalExtJSONWithContext returns the extended JSON encoding of val using Registry r.
func MarshalExtJSONWithContext(ec bsoncodec.EncodeContext, val interface{}, canonical, escapeHTML bool) ([]byte, error) {
dst := make([]byte, 0, defaultDstCap)
return MarshalExtJSONAppendWithContext(ec, dst, val, canonical, escapeHTML)
}
// MarshalExtJSONAppendWithRegistry will append the extended JSON encoding of
// val to dst using Registry r. If dst is not large enough to hold the BSON
// encoding of val, dst will be grown.
func MarshalExtJSONAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{}, canonical, escapeHTML bool) ([]byte, error) {
return MarshalExtJSONAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val, canonical, escapeHTML)
}
// MarshalExtJSONAppendWithContext will append the extended JSON encoding of
// val to dst using Registry r. If dst is not large enough to hold the BSON
// encoding of val, dst will be grown.
func MarshalExtJSONAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interface{}, canonical, escapeHTML bool) ([]byte, error) {
sw := new(bsonrw.SliceWriter)
*sw = dst
ejvw := extjPool.Get(sw, canonical, escapeHTML)
defer extjPool.Put(ejvw)
enc := encPool.Get().(*Encoder)
defer encPool.Put(enc)
err := enc.Reset(ejvw)
if err != nil {
return nil, err
}
err = enc.SetContext(ec)
if err != nil {
return nil, err
}
err = enc.Encode(val)
if err != nil {
return nil, err
}
return *sw, nil
}
+307
View File
@@ -0,0 +1,307 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on gopkg.in/mgo.v2/bson by Gustavo Niemeyer
// See THIRD-PARTY-NOTICES for original license terms.
package primitive
import (
"fmt"
"strconv"
"strings"
)
// Decimal128 holds decimal128 BSON values.
type Decimal128 struct {
h, l uint64
}
// NewDecimal128 creates a Decimal128 using the provide high and low uint64s.
func NewDecimal128(h, l uint64) Decimal128 {
return Decimal128{h: h, l: l}
}
// GetBytes returns the underlying bytes of the BSON decimal value as two uint16 values. The first
// contains the most first 8 bytes of the value and the second contains the latter.
func (d Decimal128) GetBytes() (uint64, uint64) {
return d.h, d.l
}
// String returns a string representation of the decimal value.
func (d Decimal128) String() string {
var pos int // positive sign
var e int // exponent
var h, l uint64 // significand high/low
if d.h>>63&1 == 0 {
pos = 1
}
switch d.h >> 58 & (1<<5 - 1) {
case 0x1F:
return "NaN"
case 0x1E:
return "-Infinity"[pos:]
}
l = d.l
if d.h>>61&3 == 3 {
// Bits: 1*sign 2*ignored 14*exponent 111*significand.
// Implicit 0b100 prefix in significand.
e = int(d.h>>47&(1<<14-1)) - 6176
//h = 4<<47 | d.h&(1<<47-1)
// Spec says all of these values are out of range.
h, l = 0, 0
} else {
// Bits: 1*sign 14*exponent 113*significand
e = int(d.h>>49&(1<<14-1)) - 6176
h = d.h & (1<<49 - 1)
}
// Would be handled by the logic below, but that's trivial and common.
if h == 0 && l == 0 && e == 0 {
return "-0"[pos:]
}
var repr [48]byte // Loop 5 times over 9 digits plus dot, negative sign, and leading zero.
var last = len(repr)
var i = len(repr)
var dot = len(repr) + e
var rem uint32
Loop:
for d9 := 0; d9 < 5; d9++ {
h, l, rem = divmod(h, l, 1e9)
for d1 := 0; d1 < 9; d1++ {
// Handle "-0.0", "0.00123400", "-1.00E-6", "1.050E+3", etc.
if i < len(repr) && (dot == i || l == 0 && h == 0 && rem > 0 && rem < 10 && (dot < i-6 || e > 0)) {
e += len(repr) - i
i--
repr[i] = '.'
last = i - 1
dot = len(repr) // Unmark.
}
c := '0' + byte(rem%10)
rem /= 10
i--
repr[i] = c
// Handle "0E+3", "1E+3", etc.
if l == 0 && h == 0 && rem == 0 && i == len(repr)-1 && (dot < i-5 || e > 0) {
last = i
break Loop
}
if c != '0' {
last = i
}
// Break early. Works without it, but why.
if dot > i && l == 0 && h == 0 && rem == 0 {
break Loop
}
}
}
repr[last-1] = '-'
last--
if e > 0 {
return string(repr[last+pos:]) + "E+" + strconv.Itoa(e)
}
if e < 0 {
return string(repr[last+pos:]) + "E" + strconv.Itoa(e)
}
return string(repr[last+pos:])
}
func divmod(h, l uint64, div uint32) (qh, ql uint64, rem uint32) {
div64 := uint64(div)
a := h >> 32
aq := a / div64
ar := a % div64
b := ar<<32 + h&(1<<32-1)
bq := b / div64
br := b % div64
c := br<<32 + l>>32
cq := c / div64
cr := c % div64
d := cr<<32 + l&(1<<32-1)
dq := d / div64
dr := d % div64
return (aq<<32 | bq), (cq<<32 | dq), uint32(dr)
}
var dNaN = Decimal128{0x1F << 58, 0}
var dPosInf = Decimal128{0x1E << 58, 0}
var dNegInf = Decimal128{0x3E << 58, 0}
func dErr(s string) (Decimal128, error) {
return dNaN, fmt.Errorf("cannot parse %q as a decimal128", s)
}
//ParseDecimal128 takes the given string and attempts to parse it into a valid
// Decimal128 value.
func ParseDecimal128(s string) (Decimal128, error) {
orig := s
if s == "" {
return dErr(orig)
}
neg := s[0] == '-'
if neg || s[0] == '+' {
s = s[1:]
}
if (len(s) == 3 || len(s) == 8) && (s[0] == 'N' || s[0] == 'n' || s[0] == 'I' || s[0] == 'i') {
if s == "NaN" || s == "nan" || strings.EqualFold(s, "nan") {
return dNaN, nil
}
if s == "Inf" || s == "inf" || strings.EqualFold(s, "inf") || strings.EqualFold(s, "infinity") {
if neg {
return dNegInf, nil
}
return dPosInf, nil
}
return dErr(orig)
}
var h, l uint64
var e int
var add, ovr uint32
var mul uint32 = 1
var dot = -1
var digits = 0
var i = 0
for i < len(s) {
c := s[i]
if mul == 1e9 {
h, l, ovr = muladd(h, l, mul, add)
mul, add = 1, 0
if ovr > 0 || h&((1<<15-1)<<49) > 0 {
return dErr(orig)
}
}
if c >= '0' && c <= '9' {
i++
if c > '0' || digits > 0 {
digits++
}
if digits > 34 {
if c == '0' {
// Exact rounding.
e++
continue
}
return dErr(orig)
}
mul *= 10
add *= 10
add += uint32(c - '0')
continue
}
if c == '.' {
i++
if dot >= 0 || i == 1 && len(s) == 1 {
return dErr(orig)
}
if i == len(s) {
break
}
if s[i] < '0' || s[i] > '9' || e > 0 {
return dErr(orig)
}
dot = i
continue
}
break
}
if i == 0 {
return dErr(orig)
}
if mul > 1 {
h, l, ovr = muladd(h, l, mul, add)
if ovr > 0 || h&((1<<15-1)<<49) > 0 {
return dErr(orig)
}
}
if dot >= 0 {
e += dot - i
}
if i+1 < len(s) && (s[i] == 'E' || s[i] == 'e') {
i++
eneg := s[i] == '-'
if eneg || s[i] == '+' {
i++
if i == len(s) {
return dErr(orig)
}
}
n := 0
for i < len(s) && n < 1e4 {
c := s[i]
i++
if c < '0' || c > '9' {
return dErr(orig)
}
n *= 10
n += int(c - '0')
}
if eneg {
n = -n
}
e += n
for e < -6176 {
// Subnormal.
var div uint32 = 1
for div < 1e9 && e < -6176 {
div *= 10
e++
}
var rem uint32
h, l, rem = divmod(h, l, div)
if rem > 0 {
return dErr(orig)
}
}
for e > 6111 {
// Clamped.
var mul uint32 = 1
for mul < 1e9 && e > 6111 {
mul *= 10
e--
}
h, l, ovr = muladd(h, l, mul, 0)
if ovr > 0 || h&((1<<15-1)<<49) > 0 {
return dErr(orig)
}
}
if e < -6176 || e > 6111 {
return dErr(orig)
}
}
if i < len(s) {
return dErr(orig)
}
h |= uint64(e+6176) & uint64(1<<14-1) << 49
if neg {
h |= 1 << 63
}
return Decimal128{h, l}, nil
}
func muladd(h, l uint64, mul uint32, add uint32) (resh, resl uint64, overflow uint32) {
mul64 := uint64(mul)
a := mul64 * (l & (1<<32 - 1))
b := a>>32 + mul64*(l>>32)
c := b>>32 + mul64*(h&(1<<32-1))
d := c>>32 + mul64*(h>>32)
a = a&(1<<32-1) + uint64(add)
b = b&(1<<32-1) + a>>32
c = c&(1<<32-1) + b>>32
d = d&(1<<32-1) + c>>32
return (d<<32 | c&(1<<32-1)), (b<<32 | a&(1<<32-1)), uint32(d >> 32)
}
+154
View File
@@ -0,0 +1,154 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on gopkg.in/mgo.v2/bson by Gustavo Niemeyer
// See THIRD-PARTY-NOTICES for original license terms.
package primitive
import (
"bytes"
"crypto/rand"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"sync/atomic"
"time"
)
// ErrInvalidHex indicates that a hex string cannot be converted to an ObjectID.
var ErrInvalidHex = errors.New("the provided hex string is not a valid ObjectID")
// ObjectID is the BSON ObjectID type.
type ObjectID [12]byte
// NilObjectID is the zero value for ObjectID.
var NilObjectID ObjectID
var objectIDCounter = readRandomUint32()
var processUnique = processUniqueBytes()
// NewObjectID generates a new ObjectID.
func NewObjectID() ObjectID {
var b [12]byte
binary.BigEndian.PutUint32(b[0:4], uint32(time.Now().Unix()))
copy(b[4:9], processUnique[:])
putUint24(b[9:12], atomic.AddUint32(&objectIDCounter, 1))
return b
}
// Hex returns the hex encoding of the ObjectID as a string.
func (id ObjectID) Hex() string {
return hex.EncodeToString(id[:])
}
func (id ObjectID) String() string {
return fmt.Sprintf("ObjectID(%q)", id.Hex())
}
// IsZero returns true if id is the empty ObjectID.
func (id ObjectID) IsZero() bool {
return bytes.Equal(id[:], NilObjectID[:])
}
// ObjectIDFromHex creates a new ObjectID from a hex string. It returns an error if the hex string is not a
// valid ObjectID.
func ObjectIDFromHex(s string) (ObjectID, error) {
b, err := hex.DecodeString(s)
if err != nil {
return NilObjectID, err
}
if len(b) != 12 {
return NilObjectID, ErrInvalidHex
}
var oid [12]byte
copy(oid[:], b[:])
return oid, nil
}
// MarshalJSON returns the ObjectID as a string
func (id ObjectID) MarshalJSON() ([]byte, error) {
return json.Marshal(id.Hex())
}
// UnmarshalJSON populates the byte slice with the ObjectID. If the byte slice is 64 bytes long, it
// will be populated with the hex representation of the ObjectID. If the byte slice is twelve bytes
// long, it will be populated with the BSON representation of the ObjectID. Otherwise, it will
// return an error.
func (id *ObjectID) UnmarshalJSON(b []byte) error {
var err error
switch len(b) {
case 12:
copy(id[:], b)
default:
// Extended JSON
var res interface{}
err := json.Unmarshal(b, &res)
if err != nil {
return err
}
str, ok := res.(string)
if !ok {
m, ok := res.(map[string]interface{})
if !ok {
return errors.New("not an extended JSON ObjectID")
}
oid, ok := m["$oid"]
if !ok {
return errors.New("not an extended JSON ObjectID")
}
str, ok = oid.(string)
if !ok {
return errors.New("not an extended JSON ObjectID")
}
}
if len(str) != 24 {
return fmt.Errorf("cannot unmarshal into an ObjectID, the length must be 12 but it is %d", len(str))
}
_, err = hex.Decode(id[:], []byte(str))
if err != nil {
return err
}
}
return err
}
func processUniqueBytes() [5]byte {
var b [5]byte
_, err := io.ReadFull(rand.Reader, b[:])
if err != nil {
panic(fmt.Errorf("cannot initialize objectid package with crypto.rand.Reader: %v", err))
}
return b
}
func readRandomUint32() uint32 {
var b [4]byte
_, err := io.ReadFull(rand.Reader, b[:])
if err != nil {
panic(fmt.Errorf("cannot initialize objectid package with crypto.rand.Reader: %v", err))
}
return (uint32(b[0]) << 0) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
}
func putUint24(b []byte, v uint32) {
b[0] = byte(v >> 16)
b[1] = byte(v >> 8)
b[2] = byte(v)
}
+156
View File
@@ -0,0 +1,156 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package primitive contains types similar to Go primitives for BSON types can do not have direct
// Go primitive representations.
package primitive // import "go.mongodb.org/mongo-driver/bson/primitive"
import (
"bytes"
"encoding/json"
"fmt"
"time"
)
// Binary represents a BSON binary value.
type Binary struct {
Subtype byte
Data []byte
}
// Equal compaes bp to bp2 and returns true is the are equal.
func (bp Binary) Equal(bp2 Binary) bool {
if bp.Subtype != bp2.Subtype {
return false
}
return bytes.Equal(bp.Data, bp2.Data)
}
// Undefined represents the BSON undefined value type.
type Undefined struct{}
// DateTime represents the BSON datetime value.
type DateTime int64
// MarshalJSON marshal to time type
func (d DateTime) MarshalJSON() ([]byte, error) {
return json.Marshal(time.Unix(int64(d)/1000, int64(d)%1000*1000000))
}
// Null repreesnts the BSON null value.
type Null struct{}
// Regex represents a BSON regex value.
type Regex struct {
Pattern string
Options string
}
func (rp Regex) String() string {
return fmt.Sprintf(`{"pattern": "%s", "options": "%s"}`, rp.Pattern, rp.Options)
}
// Equal compaes rp to rp2 and returns true is the are equal.
func (rp Regex) Equal(rp2 Regex) bool {
return rp.Pattern == rp2.Pattern && rp.Options == rp.Options
}
// DBPointer represents a BSON dbpointer value.
type DBPointer struct {
DB string
Pointer ObjectID
}
func (d DBPointer) String() string {
return fmt.Sprintf(`{"db": "%s", "pointer": "%s"}`, d.DB, d.Pointer)
}
// Equal compaes d to d2 and returns true is the are equal.
func (d DBPointer) Equal(d2 DBPointer) bool {
return d.DB == d2.DB && bytes.Equal(d.Pointer[:], d2.Pointer[:])
}
// JavaScript represents a BSON JavaScript code value.
type JavaScript string
// Symbol represents a BSON symbol value.
type Symbol string
// CodeWithScope represents a BSON JavaScript code with scope value.
type CodeWithScope struct {
Code JavaScript
Scope interface{}
}
func (cws CodeWithScope) String() string {
return fmt.Sprintf(`{"code": "%s", "scope": %v}`, cws.Code, cws.Scope)
}
// Timestamp represents a BSON timestamp value.
type Timestamp struct {
T uint32
I uint32
}
// Equal compaes tp to tp2 and returns true is the are equal.
func (tp Timestamp) Equal(tp2 Timestamp) bool {
return tp.T == tp2.T && tp.I == tp2.I
}
// MinKey represents the BSON minkey value.
type MinKey struct{}
// MaxKey represents the BSON maxkey value.
type MaxKey struct{}
// D represents a BSON Document. This type can be used to represent BSON in a concise and readable
// manner. It should generally be used when serializing to BSON. For deserializing, the Raw or
// Document types should be used.
//
// Example usage:
//
// primitive.D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}}
//
// This type should be used in situations where order matters, such as MongoDB commands. If the
// order is not important, a map is more comfortable and concise.
type D []E
// Map creates a map from the elements of the D.
func (d D) Map() M {
m := make(M, len(d))
for _, e := range d {
m[e.Key] = e.Value
}
return m
}
// E represents a BSON element for a D. It is usually used inside a D.
type E struct {
Key string
Value interface{}
}
// M is an unordered, concise representation of a BSON Document. It should generally be used to
// serialize BSON when the order of the elements of a BSON document do not matter. If the element
// order matters, use a D instead.
//
// Example usage:
//
// primitive.M{"foo": "bar", "hello": "world", "pi": 3.14159}
//
// This type is handled in the encoders as a regular map[string]interface{}. The elements will be
// serialized in an undefined, random order, and the order will be different each time.
type M map[string]interface{}
// An A represents a BSON array. This type can be used to represent a BSON array in a concise and
// readable manner. It should generally be used when serializing to BSON. For deserializing, the
// RawArray or Array types should be used.
//
// Example usage:
//
// primitive.A{"bar", "world", 3.14159, primitive.D{{"qux", 12345}}}
//
type A []interface{}
+111
View File
@@ -0,0 +1,111 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"reflect"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonrw"
)
var primitiveCodecs PrimitiveCodecs
// PrimitiveCodecs is a namespace for all of the default bsoncodec.Codecs for the primitive types
// defined in this package.
type PrimitiveCodecs struct{}
// RegisterPrimitiveCodecs will register the encode and decode methods attached to PrimitiveCodecs
// with the provided RegistryBuilder. if rb is nil, a new empty RegistryBuilder will be created.
func (pc PrimitiveCodecs) RegisterPrimitiveCodecs(rb *bsoncodec.RegistryBuilder) {
if rb == nil {
panic(errors.New("argument to RegisterPrimitiveCodecs must not be nil"))
}
rb.
RegisterEncoder(tRawValue, bsoncodec.ValueEncoderFunc(pc.RawValueEncodeValue)).
RegisterEncoder(tRaw, bsoncodec.ValueEncoderFunc(pc.RawEncodeValue)).
RegisterDecoder(tRawValue, bsoncodec.ValueDecoderFunc(pc.RawValueDecodeValue)).
RegisterDecoder(tRaw, bsoncodec.ValueDecoderFunc(pc.RawDecodeValue))
}
// RawValueEncodeValue is the ValueEncoderFunc for RawValue.
func (PrimitiveCodecs) RawValueEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tRawValue {
return bsoncodec.ValueEncoderError{Name: "RawValueEncodeValue", Types: []reflect.Type{tRawValue}, Received: val}
}
rawvalue := val.Interface().(RawValue)
return bsonrw.Copier{}.CopyValueFromBytes(vw, rawvalue.Type, rawvalue.Value)
}
// RawValueDecodeValue is the ValueDecoderFunc for RawValue.
func (PrimitiveCodecs) RawValueDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tRawValue {
return bsoncodec.ValueDecoderError{Name: "RawValueDecodeValue", Types: []reflect.Type{tRawValue}, Received: val}
}
t, value, err := bsonrw.Copier{}.CopyValueToBytes(vr)
if err != nil {
return err
}
val.Set(reflect.ValueOf(RawValue{Type: t, Value: value}))
return nil
}
// RawEncodeValue is the ValueEncoderFunc for Reader.
func (PrimitiveCodecs) RawEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tRaw {
return bsoncodec.ValueEncoderError{Name: "RawEncodeValue", Types: []reflect.Type{tRaw}, Received: val}
}
rdr := val.Interface().(Raw)
return bsonrw.Copier{}.CopyDocumentFromBytes(vw, rdr)
}
// RawDecodeValue is the ValueDecoderFunc for Reader.
func (PrimitiveCodecs) RawDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tRaw {
return bsoncodec.ValueDecoderError{Name: "RawDecodeValue", Types: []reflect.Type{tRaw}, Received: val}
}
if val.IsNil() {
val.Set(reflect.MakeSlice(val.Type(), 0, 0))
}
val.SetLen(0)
rdr, err := bsonrw.Copier{}.AppendDocumentBytes(val.Interface().(Raw), vr)
val.Set(reflect.ValueOf(rdr))
return err
}
func (pc PrimitiveCodecs) encodeRaw(ec bsoncodec.EncodeContext, dw bsonrw.DocumentWriter, raw Raw) error {
var copier bsonrw.Copier
elems, err := raw.Elements()
if err != nil {
return err
}
for _, elem := range elems {
dvw, err := dw.WriteDocumentElement(elem.Key())
if err != nil {
return err
}
val := elem.Value()
err = copier.CopyValueFromBytes(dvw, val.Type, val.Value)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}
+92
View File
@@ -0,0 +1,92 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"io"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
// ErrNilReader indicates that an operation was attempted on a nil bson.Reader.
var ErrNilReader = errors.New("nil reader")
var errValidateDone = errors.New("validation loop complete")
// Raw is a wrapper around a byte slice. It will interpret the slice as a
// BSON document. This type is a wrapper around a bsoncore.Document. Errors returned from the
// methods on this type and associated types come from the bsoncore package.
type Raw []byte
// NewFromIOReader reads in a document from the given io.Reader and constructs a Raw from
// it.
func NewFromIOReader(r io.Reader) (Raw, error) {
doc, err := bsoncore.NewDocumentFromReader(r)
return Raw(doc), err
}
// Validate validates the document. This method only validates the first document in
// the slice, to validate other documents, the slice must be resliced.
func (r Raw) Validate() (err error) { return bsoncore.Document(r).Validate() }
// Lookup search the document, potentially recursively, for the given key. If
// there are multiple keys provided, this method will recurse down, as long as
// the top and intermediate nodes are either documents or arrays.If an error
// occurs or if the value doesn't exist, an empty RawValue is returned.
func (r Raw) Lookup(key ...string) RawValue {
return convertFromCoreValue(bsoncore.Document(r).Lookup(key...))
}
// LookupErr searches the document and potentially subdocuments or arrays for the
// provided key. Each key provided to this method represents a layer of depth.
func (r Raw) LookupErr(key ...string) (RawValue, error) {
val, err := bsoncore.Document(r).LookupErr(key...)
return convertFromCoreValue(val), err
}
// Elements returns this document as a slice of elements. The returned slice will contain valid
// elements. If the document is not valid, the elements up to the invalid point will be returned
// along with an error.
func (r Raw) Elements() ([]RawElement, error) {
elems, err := bsoncore.Document(r).Elements()
relems := make([]RawElement, 0, len(elems))
for _, elem := range elems {
relems = append(relems, RawElement(elem))
}
return relems, err
}
// Values returns this document as a slice of values. The returned slice will contain valid values.
// If the document is not valid, the values up to the invalid point will be returned along with an
// error.
func (r Raw) Values() ([]RawValue, error) {
vals, err := bsoncore.Document(r).Values()
rvals := make([]RawValue, 0, len(vals))
for _, val := range vals {
rvals = append(rvals, convertFromCoreValue(val))
}
return rvals, err
}
// Index searches for and retrieves the element at the given index. This method will panic if
// the document is invalid or if the index is out of bounds.
func (r Raw) Index(index uint) RawElement { return RawElement(bsoncore.Document(r).Index(index)) }
// IndexErr searches for and retrieves the element at the given index.
func (r Raw) IndexErr(index uint) (RawElement, error) {
elem, err := bsoncore.Document(r).IndexErr(index)
return RawElement(elem), err
}
// String implements the fmt.Stringer interface.
func (r Raw) String() string { return bsoncore.Document(r).String() }
// readi32 is a helper function for reading an int32 from slice of bytes.
func readi32(b []byte) int32 {
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
return int32(b[0]) | int32(b[1])<<8 | int32(b[2])<<16 | int32(b[3])<<24
}
+51
View File
@@ -0,0 +1,51 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
// RawElement represents a BSON element in byte form. This type provides a simple way to
// transform a slice of bytes into a BSON element and extract information from it.
//
// RawElement is a thin wrapper around a bsoncore.Element.
type RawElement []byte
// Key returns the key for this element. If the element is not valid, this method returns an empty
// string. If knowing if the element is valid is important, use KeyErr.
func (re RawElement) Key() string { return bsoncore.Element(re).Key() }
// KeyErr returns the key for this element, returning an error if the element is not valid.
func (re RawElement) KeyErr() (string, error) { return bsoncore.Element(re).KeyErr() }
// Value returns the value of this element. If the element is not valid, this method returns an
// empty Value. If knowing if the element is valid is important, use ValueErr.
func (re RawElement) Value() RawValue { return convertFromCoreValue(bsoncore.Element(re).Value()) }
// ValueErr returns the value for this element, returning an error if the element is not valid.
func (re RawElement) ValueErr() (RawValue, error) {
val, err := bsoncore.Element(re).ValueErr()
return convertFromCoreValue(val), err
}
// Validate ensures re is a valid BSON element.
func (re RawElement) Validate() error { return bsoncore.Element(re).Validate() }
// String implements the fmt.Stringer interface. The output will be in extended JSON format.
func (re RawElement) String() string {
doc := bsoncore.BuildDocument(nil, re)
j, err := MarshalExtJSON(Raw(doc), true, false)
if err != nil {
return "<malformed>"
}
return string(j)
}
// DebugString outputs a human readable version of RawElement. It will attempt to stringify the
// valid components of the element even if the entire element is not valid.
func (re RawElement) DebugString() string { return bsoncore.Element(re).DebugString() }
+287
View File
@@ -0,0 +1,287 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"errors"
"fmt"
"reflect"
"time"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
// ErrNilContext is returned when the provided DecodeContext is nil.
var ErrNilContext = errors.New("DecodeContext cannot be nil")
// ErrNilRegistry is returned when the provided registry is nil.
var ErrNilRegistry = errors.New("Registry cannot be nil")
// RawValue represents a BSON value in byte form. It can be used to hold unprocessed BSON or to
// defer processing of BSON. Type is the BSON type of the value and Value are the raw bytes that
// represent the element.
//
// This type wraps bsoncore.Value for most of it's functionality.
type RawValue struct {
Type bsontype.Type
Value []byte
r *bsoncodec.Registry
}
// Unmarshal deserializes BSON into the provided val. If RawValue cannot be unmarshaled into val, an
// error is returned. This method will use the registry used to create the RawValue, if the RawValue
// was created from partial BSON processing, or it will use the default registry. Users wishing to
// specify the registry to use should use UnmarshalWithRegistry.
func (rv RawValue) Unmarshal(val interface{}) error {
reg := rv.r
if reg == nil {
reg = DefaultRegistry
}
return rv.UnmarshalWithRegistry(reg, val)
}
// Equal compares rv and rv2 and returns true if they are equal.
func (rv RawValue) Equal(rv2 RawValue) bool {
if rv.Type != rv2.Type {
return false
}
if !bytes.Equal(rv.Value, rv2.Value) {
return false
}
return true
}
// UnmarshalWithRegistry performs the same unmarshalling as Unmarshal but uses the provided registry
// instead of the one attached or the default registry.
func (rv RawValue) UnmarshalWithRegistry(r *bsoncodec.Registry, val interface{}) error {
if r == nil {
return ErrNilRegistry
}
vr := bsonrw.NewBSONValueReader(rv.Type, rv.Value)
rval := reflect.ValueOf(val)
if rval.Kind() != reflect.Ptr {
return fmt.Errorf("argument to Unmarshal* must be a pointer to a type, but got %v", rval)
}
rval = rval.Elem()
dec, err := r.LookupDecoder(rval.Type())
if err != nil {
return err
}
return dec.DecodeValue(bsoncodec.DecodeContext{Registry: r}, vr, rval)
}
// UnmarshalWithContext performs the same unmarshalling as Unmarshal but uses the provided DecodeContext
// instead of the one attached or the default registry.
func (rv RawValue) UnmarshalWithContext(dc *bsoncodec.DecodeContext, val interface{}) error {
if dc == nil {
return ErrNilContext
}
vr := bsonrw.NewBSONValueReader(rv.Type, rv.Value)
rval := reflect.ValueOf(val)
if rval.Kind() != reflect.Ptr {
return fmt.Errorf("argument to Unmarshal* must be a pointer to a type, but got %v", rval)
}
rval = rval.Elem()
dec, err := dc.LookupDecoder(rval.Type())
if err != nil {
return err
}
return dec.DecodeValue(*dc, vr, rval)
}
func convertFromCoreValue(v bsoncore.Value) RawValue { return RawValue{Type: v.Type, Value: v.Data} }
func convertToCoreValue(v RawValue) bsoncore.Value { return bsoncore.Value{Type: v.Type, Data: v.Value} }
// Validate ensures the value is a valid BSON value.
func (rv RawValue) Validate() error { return convertToCoreValue(rv).Validate() }
// IsNumber returns true if the type of v is a numeric BSON type.
func (rv RawValue) IsNumber() bool { return convertToCoreValue(rv).IsNumber() }
// String implements the fmt.String interface. This method will return values in extended JSON
// format. If the value is not valid, this returns an empty string
func (rv RawValue) String() string { return convertToCoreValue(rv).String() }
// DebugString outputs a human readable version of Document. It will attempt to stringify the
// valid components of the document even if the entire document is not valid.
func (rv RawValue) DebugString() string { return convertToCoreValue(rv).DebugString() }
// Double returns the float64 value for this element.
// It panics if e's BSON type is not bsontype.Double.
func (rv RawValue) Double() float64 { return convertToCoreValue(rv).Double() }
// DoubleOK is the same as Double, but returns a boolean instead of panicking.
func (rv RawValue) DoubleOK() (float64, bool) { return convertToCoreValue(rv).DoubleOK() }
// StringValue returns the string value for this element.
// It panics if e's BSON type is not bsontype.String.
//
// NOTE: This method is called StringValue to avoid a collision with the String method which
// implements the fmt.Stringer interface.
func (rv RawValue) StringValue() string { return convertToCoreValue(rv).StringValue() }
// StringValueOK is the same as StringValue, but returns a boolean instead of
// panicking.
func (rv RawValue) StringValueOK() (string, bool) { return convertToCoreValue(rv).StringValueOK() }
// Document returns the BSON document the Value represents as a Document. It panics if the
// value is a BSON type other than document.
func (rv RawValue) Document() Raw { return Raw(convertToCoreValue(rv).Document()) }
// DocumentOK is the same as Document, except it returns a boolean
// instead of panicking.
func (rv RawValue) DocumentOK() (Raw, bool) {
doc, ok := convertToCoreValue(rv).DocumentOK()
return Raw(doc), ok
}
// Array returns the BSON array the Value represents as an Array. It panics if the
// value is a BSON type other than array.
func (rv RawValue) Array() Raw { return Raw(convertToCoreValue(rv).Array()) }
// ArrayOK is the same as Array, except it returns a boolean instead
// of panicking.
func (rv RawValue) ArrayOK() (Raw, bool) {
doc, ok := convertToCoreValue(rv).ArrayOK()
return Raw(doc), ok
}
// Binary returns the BSON binary value the Value represents. It panics if the value is a BSON type
// other than binary.
func (rv RawValue) Binary() (subtype byte, data []byte) { return convertToCoreValue(rv).Binary() }
// BinaryOK is the same as Binary, except it returns a boolean instead of
// panicking.
func (rv RawValue) BinaryOK() (subtype byte, data []byte, ok bool) {
return convertToCoreValue(rv).BinaryOK()
}
// ObjectID returns the BSON objectid value the Value represents. It panics if the value is a BSON
// type other than objectid.
func (rv RawValue) ObjectID() primitive.ObjectID { return convertToCoreValue(rv).ObjectID() }
// ObjectIDOK is the same as ObjectID, except it returns a boolean instead of
// panicking.
func (rv RawValue) ObjectIDOK() (primitive.ObjectID, bool) { return convertToCoreValue(rv).ObjectIDOK() }
// Boolean returns the boolean value the Value represents. It panics if the
// value is a BSON type other than boolean.
func (rv RawValue) Boolean() bool { return convertToCoreValue(rv).Boolean() }
// BooleanOK is the same as Boolean, except it returns a boolean instead of
// panicking.
func (rv RawValue) BooleanOK() (bool, bool) { return convertToCoreValue(rv).BooleanOK() }
// DateTime returns the BSON datetime value the Value represents as a
// unix timestamp. It panics if the value is a BSON type other than datetime.
func (rv RawValue) DateTime() int64 { return convertToCoreValue(rv).DateTime() }
// DateTimeOK is the same as DateTime, except it returns a boolean instead of
// panicking.
func (rv RawValue) DateTimeOK() (int64, bool) { return convertToCoreValue(rv).DateTimeOK() }
// Time returns the BSON datetime value the Value represents. It panics if the value is a BSON
// type other than datetime.
func (rv RawValue) Time() time.Time { return convertToCoreValue(rv).Time() }
// TimeOK is the same as Time, except it returns a boolean instead of
// panicking.
func (rv RawValue) TimeOK() (time.Time, bool) { return convertToCoreValue(rv).TimeOK() }
// Regex returns the BSON regex value the Value represents. It panics if the value is a BSON
// type other than regex.
func (rv RawValue) Regex() (pattern, options string) { return convertToCoreValue(rv).Regex() }
// RegexOK is the same as Regex, except it returns a boolean instead of
// panicking.
func (rv RawValue) RegexOK() (pattern, options string, ok bool) {
return convertToCoreValue(rv).RegexOK()
}
// DBPointer returns the BSON dbpointer value the Value represents. It panics if the value is a BSON
// type other than DBPointer.
func (rv RawValue) DBPointer() (string, primitive.ObjectID) { return convertToCoreValue(rv).DBPointer() }
// DBPointerOK is the same as DBPoitner, except that it returns a boolean
// instead of panicking.
func (rv RawValue) DBPointerOK() (string, primitive.ObjectID, bool) {
return convertToCoreValue(rv).DBPointerOK()
}
// JavaScript returns the BSON JavaScript code value the Value represents. It panics if the value is
// a BSON type other than JavaScript code.
func (rv RawValue) JavaScript() string { return convertToCoreValue(rv).JavaScript() }
// JavaScriptOK is the same as Javascript, excepti that it returns a boolean
// instead of panicking.
func (rv RawValue) JavaScriptOK() (string, bool) { return convertToCoreValue(rv).JavaScriptOK() }
// Symbol returns the BSON symbol value the Value represents. It panics if the value is a BSON
// type other than symbol.
func (rv RawValue) Symbol() string { return convertToCoreValue(rv).Symbol() }
// SymbolOK is the same as Symbol, excepti that it returns a boolean
// instead of panicking.
func (rv RawValue) SymbolOK() (string, bool) { return convertToCoreValue(rv).SymbolOK() }
// CodeWithScope returns the BSON JavaScript code with scope the Value represents.
// It panics if the value is a BSON type other than JavaScript code with scope.
func (rv RawValue) CodeWithScope() (string, Raw) {
code, scope := convertToCoreValue(rv).CodeWithScope()
return code, Raw(scope)
}
// CodeWithScopeOK is the same as CodeWithScope, except that it returns a boolean instead of
// panicking.
func (rv RawValue) CodeWithScopeOK() (string, Raw, bool) {
code, scope, ok := convertToCoreValue(rv).CodeWithScopeOK()
return code, Raw(scope), ok
}
// Int32 returns the int32 the Value represents. It panics if the value is a BSON type other than
// int32.
func (rv RawValue) Int32() int32 { return convertToCoreValue(rv).Int32() }
// Int32OK is the same as Int32, except that it returns a boolean instead of
// panicking.
func (rv RawValue) Int32OK() (int32, bool) { return convertToCoreValue(rv).Int32OK() }
// Timestamp returns the BSON timestamp value the Value represents. It panics if the value is a
// BSON type other than timestamp.
func (rv RawValue) Timestamp() (t, i uint32) { return convertToCoreValue(rv).Timestamp() }
// TimestampOK is the same as Timestamp, except that it returns a boolean
// instead of panicking.
func (rv RawValue) TimestampOK() (t, i uint32, ok bool) { return convertToCoreValue(rv).TimestampOK() }
// Int64 returns the int64 the Value represents. It panics if the value is a BSON type other than
// int64.
func (rv RawValue) Int64() int64 { return convertToCoreValue(rv).Int64() }
// Int64OK is the same as Int64, except that it returns a boolean instead of
// panicking.
func (rv RawValue) Int64OK() (int64, bool) { return convertToCoreValue(rv).Int64OK() }
// Decimal128 returns the decimal the Value represents. It panics if the value is a BSON type other than
// decimal.
func (rv RawValue) Decimal128() primitive.Decimal128 { return convertToCoreValue(rv).Decimal128() }
// Decimal128OK is the same as Decimal128, except that it returns a boolean
// instead of panicking.
func (rv RawValue) Decimal128OK() (primitive.Decimal128, bool) {
return convertToCoreValue(rv).Decimal128OK()
}
+24
View File
@@ -0,0 +1,24 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import "go.mongodb.org/mongo-driver/bson/bsoncodec"
// DefaultRegistry is the default bsoncodec.Registry. It contains the default codecs and the
// primitive codecs.
var DefaultRegistry = NewRegistryBuilder().Build()
// NewRegistryBuilder creates a new RegistryBuilder configured with the default encoders and
// deocders from the bsoncodec.DefaultValueEncoders and bsoncodec.DefaultValueDecoders types and the
// PrimitiveCodecs type in this package.
func NewRegistryBuilder() *bsoncodec.RegistryBuilder {
rb := bsoncodec.NewRegistryBuilder()
bsoncodec.DefaultValueEncoders{}.RegisterDefaultEncoders(rb)
bsoncodec.DefaultValueDecoders{}.RegisterDefaultDecoders(rb)
primitiveCodecs.RegisterPrimitiveCodecs(rb)
return rb
}
+85
View File
@@ -0,0 +1,85 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"reflect"
"time"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// These constants uniquely refer to each BSON type.
const (
TypeDouble = bsontype.Double
TypeString = bsontype.String
TypeEmbeddedDocument = bsontype.EmbeddedDocument
TypeArray = bsontype.Array
TypeBinary = bsontype.Binary
TypeUndefined = bsontype.Undefined
TypeObjectID = bsontype.ObjectID
TypeBoolean = bsontype.Boolean
TypeDateTime = bsontype.DateTime
TypeNull = bsontype.Null
TypeRegex = bsontype.Regex
TypeDBPointer = bsontype.DBPointer
TypeJavaScript = bsontype.JavaScript
TypeSymbol = bsontype.Symbol
TypeCodeWithScope = bsontype.CodeWithScope
TypeInt32 = bsontype.Int32
TypeTimestamp = bsontype.Timestamp
TypeInt64 = bsontype.Int64
TypeDecimal128 = bsontype.Decimal128
TypeMinKey = bsontype.MinKey
TypeMaxKey = bsontype.MaxKey
)
var tBinary = reflect.TypeOf(primitive.Binary{})
var tBool = reflect.TypeOf(false)
var tCodeWithScope = reflect.TypeOf(primitive.CodeWithScope{})
var tDBPointer = reflect.TypeOf(primitive.DBPointer{})
var tDecimal = reflect.TypeOf(primitive.Decimal128{})
var tD = reflect.TypeOf(D{})
var tA = reflect.TypeOf(A{})
var tDateTime = reflect.TypeOf(primitive.DateTime(0))
var tUndefined = reflect.TypeOf(primitive.Undefined{})
var tNull = reflect.TypeOf(primitive.Null{})
var tRawValue = reflect.TypeOf(RawValue{})
var tFloat32 = reflect.TypeOf(float32(0))
var tFloat64 = reflect.TypeOf(float64(0))
var tInt = reflect.TypeOf(int(0))
var tInt8 = reflect.TypeOf(int8(0))
var tInt16 = reflect.TypeOf(int16(0))
var tInt32 = reflect.TypeOf(int32(0))
var tInt64 = reflect.TypeOf(int64(0))
var tJavaScript = reflect.TypeOf(primitive.JavaScript(""))
var tOID = reflect.TypeOf(primitive.ObjectID{})
var tRaw = reflect.TypeOf(Raw(nil))
var tRegex = reflect.TypeOf(primitive.Regex{})
var tString = reflect.TypeOf("")
var tSymbol = reflect.TypeOf(primitive.Symbol(""))
var tTime = reflect.TypeOf(time.Time{})
var tTimestamp = reflect.TypeOf(primitive.Timestamp{})
var tUint = reflect.TypeOf(uint(0))
var tUint8 = reflect.TypeOf(uint8(0))
var tUint16 = reflect.TypeOf(uint16(0))
var tUint32 = reflect.TypeOf(uint32(0))
var tUint64 = reflect.TypeOf(uint64(0))
var tMinKey = reflect.TypeOf(primitive.MinKey{})
var tMaxKey = reflect.TypeOf(primitive.MaxKey{})
var tEmpty = reflect.TypeOf((*interface{})(nil)).Elem()
var tEmptySlice = reflect.TypeOf([]interface{}(nil))
var zeroVal reflect.Value
// this references the quantity of milliseconds between zero time and
// the unix epoch. useful for making sure that we convert time.Time
// objects correctly to match the legacy bson library's handling of
// time.Time values.
const zeroEpochMs = int64(62135596800000)
+101
View File
@@ -0,0 +1,101 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
// Unmarshaler is an interface implemented by types that can unmarshal a BSON
// document representation of themselves. The BSON bytes can be assumed to be
// valid. UnmarshalBSON must copy the BSON bytes if it wishes to retain the data
// after returning.
type Unmarshaler interface {
UnmarshalBSON([]byte) error
}
// ValueUnmarshaler is an interface implemented by types that can unmarshal a
// BSON value representaiton of themselves. The BSON bytes and type can be
// assumed to be valid. UnmarshalBSONValue must copy the BSON value bytes if it
// wishes to retain the data after returning.
type ValueUnmarshaler interface {
UnmarshalBSONValue(bsontype.Type, []byte) error
}
// Unmarshal parses the BSON-encoded data and stores the result in the value
// pointed to by val. If val is nil or not a pointer, Unmarshal returns
// InvalidUnmarshalError.
func Unmarshal(data []byte, val interface{}) error {
return UnmarshalWithRegistry(DefaultRegistry, data, val)
}
// UnmarshalWithRegistry parses the BSON-encoded data using Registry r and
// stores the result in the value pointed to by val. If val is nil or not
// a pointer, UnmarshalWithRegistry returns InvalidUnmarshalError.
func UnmarshalWithRegistry(r *bsoncodec.Registry, data []byte, val interface{}) error {
vr := bsonrw.NewBSONDocumentReader(data)
return unmarshalFromReader(bsoncodec.DecodeContext{Registry: r}, vr, val)
}
// UnmarshalWithContext parses the BSON-encoded data using DecodeContext dc and
// stores the result in the value pointed to by val. If val is nil or not
// a pointer, UnmarshalWithRegistry returns InvalidUnmarshalError.
func UnmarshalWithContext(dc bsoncodec.DecodeContext, data []byte, val interface{}) error {
vr := bsonrw.NewBSONDocumentReader(data)
return unmarshalFromReader(dc, vr, val)
}
// UnmarshalExtJSON parses the extended JSON-encoded data and stores the result
// in the value pointed to by val. If val is nil or not a pointer, Unmarshal
// returns InvalidUnmarshalError.
func UnmarshalExtJSON(data []byte, canonical bool, val interface{}) error {
return UnmarshalExtJSONWithRegistry(DefaultRegistry, data, canonical, val)
}
// UnmarshalExtJSONWithRegistry parses the extended JSON-encoded data using
// Registry r and stores the result in the value pointed to by val. If val is
// nil or not a pointer, UnmarshalWithRegistry returns InvalidUnmarshalError.
func UnmarshalExtJSONWithRegistry(r *bsoncodec.Registry, data []byte, canonical bool, val interface{}) error {
ejvr, err := bsonrw.NewExtJSONValueReader(bytes.NewReader(data), canonical)
if err != nil {
return err
}
return unmarshalFromReader(bsoncodec.DecodeContext{Registry: r}, ejvr, val)
}
// UnmarshalExtJSONWithContext parses the extended JSON-encoded data using
// DecodeContext dc and stores the result in the value pointed to by val. If val is
// nil or not a pointer, UnmarshalWithRegistry returns InvalidUnmarshalError.
func UnmarshalExtJSONWithContext(dc bsoncodec.DecodeContext, data []byte, canonical bool, val interface{}) error {
ejvr, err := bsonrw.NewExtJSONValueReader(bytes.NewReader(data), canonical)
if err != nil {
return err
}
return unmarshalFromReader(dc, ejvr, val)
}
func unmarshalFromReader(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val interface{}) error {
dec := decPool.Get().(*Decoder)
defer decPool.Put(dec)
err := dec.Reset(vr)
if err != nil {
return err
}
err = dec.SetContext(dc)
if err != nil {
return err
}
return dec.Decode(val)
}
+49
View File
@@ -0,0 +1,49 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package event // import "go.mongodb.org/mongo-driver/event"
import (
"context"
"go.mongodb.org/mongo-driver/bson"
)
// CommandStartedEvent represents an event generated when a command is sent to a server.
type CommandStartedEvent struct {
Command bson.Raw
DatabaseName string
CommandName string
RequestID int64
ConnectionID string
}
// CommandFinishedEvent represents a generic command finishing.
type CommandFinishedEvent struct {
DurationNanos int64
CommandName string
RequestID int64
ConnectionID string
}
// CommandSucceededEvent represents an event generated when a command's execution succeeds.
type CommandSucceededEvent struct {
CommandFinishedEvent
Reply bson.Raw
}
// CommandFailedEvent represents an event generated when a command's execution fails.
type CommandFailedEvent struct {
CommandFinishedEvent
Failure string
}
// CommandMonitor represents a monitor that is triggered for different events.
type CommandMonitor struct {
Started func(context.Context, *CommandStartedEvent)
Succeeded func(context.Context, *CommandSucceededEvent)
Failed func(context.Context, *CommandFailedEvent)
}
+74
View File
@@ -0,0 +1,74 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package internal
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
"errors"
"fmt"
)
// Implements the connection.Connection interface by reading and writing wire messages
// to a channel
type ChannelConn struct {
WriteErr error
Written chan wiremessage.WireMessage
ReadResp chan wiremessage.WireMessage
ReadErr chan error
}
func (c *ChannelConn) WriteWireMessage(ctx context.Context, wm wiremessage.WireMessage) error {
select {
case c.Written <- wm:
default:
c.WriteErr = errors.New("could not write wiremessage to written channel")
}
return c.WriteErr
}
func (c *ChannelConn) ReadWireMessage(ctx context.Context) (wiremessage.WireMessage, error) {
var wm wiremessage.WireMessage
var err error
select {
case wm = <-c.ReadResp:
case err = <-c.ReadErr:
case <-ctx.Done():
}
return wm, err
}
func (c *ChannelConn) Close() error {
return nil
}
func (c *ChannelConn) Expired() bool {
return false
}
func (c *ChannelConn) Alive() bool {
return true
}
func (c *ChannelConn) ID() string {
return "faked"
}
// Create a OP_REPLY wiremessage from a BSON document
func MakeReply(doc bsonx.Doc) (wiremessage.WireMessage, error) {
rdr, err := doc.MarshalBSON()
if err != nil {
return nil, errors.New(fmt.Sprintf("could not create document: %v", err))
}
return wiremessage.Reply{
NumberReturned: 1,
Documents: []bson.Raw{rdr},
}, nil
}

Some files were not shown because too many files have changed in this diff Show More